#star/use #star/great
# JAX: Composable transformations of Python+NumPy programs
[[maclaurin_autograd_2023|Autograd]] plus [[accelerated linear algebra]].
**composable function transformations**
Yes finally! Functional programming in machine learning!!!
`jax.numpy` is a mostly drop-in replacement for [[NumPy]]. Biggest difference is that arrays are **immutable** (yay purity); update with `x.at[idx].set(val)` instead of in-place assignment
`jax.lax` is a lower-level API, stricter but more powerful. basically a wrapper around [[accelerated linear algebra|xla]]
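A minimal sketch of the immutability difference (plain NumPy vs `jax.numpy`, nothing hypothetical here):

```python
import numpy as np
import jax.numpy as jnp

x_np = np.arange(5)
x_np[0] = 10                      # fine: NumPy arrays are mutable

x_jnp = jnp.arange(5)
# x_jnp[0] = 10                   # TypeError: JAX arrays are immutable
x_jnp = x_jnp.at[0].set(10)       # functional update: returns a new array
```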
## [[jit compiler]]
JAX dispatches operations one at a time by default; `jax.jit` compiles a whole function with [[accelerated linear algebra|XLA]] so the block is optimized and executed as one unit
- requires array shapes to be static and known at compile time
- it *traces* the function with abstract stand-ins that carry only the shape and dtype of each input. the trace **doesn't know about concrete values**
- will be recompiled when different shapes / dtypes are passed as input (kinda like [[Julia]]'s JIT)
- so e.g. Python control flow can't depend on *values* of traced variables. instead, use the `lax` alternatives (`lax.cond`, `lax.while_loop`, `lax.fori_loop`; see the sketch after this list)
- it also follows that output shapes can't depend on values (no dynamically-sized arrays). instead use things like masking / zeroing out via `jnp.where`
- to avoid tracing an argument, mark it as *static* (`static_argnums` / `static_argnames`). it must be **hashable** (i.e. no arrays). the fn then gets recompiled for each new *value* of that arg (see sketch below)
- **operations can also be static.** computations that only feed into trace-time decisions (e.g. computing a shape or an axis index) are executed at compile time, so write them with plain numpy instead of jax; `jnp` would produce traced values, which have no concrete values at trace time.
- tracing encodes the sequence of operations as a **jaxpr** (jax expression), which XLA then compiles; inspect it with `jax.make_jaxpr`
- see [[COMPSCI 2520R]] for interesting discussion of jaxpr usage
- also [[memo programming language]]
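A small sketch tying the bullets above together: value-dependent control flow via `lax.cond`, masking via `jnp.where`, a static argument, and inspecting the jaxpr. The toy functions `relu_ish`, `keep_positive`, `sum_over` are made up for illustration; the JAX APIs are real.

```python
import jax
import jax.numpy as jnp
from jax import lax
from functools import partial

# Python `if` on a traced value fails under jit; lax.cond takes a traced predicate
@jax.jit
def relu_ish(x):
    return lax.cond(jnp.sum(x) > 0, lambda v: v, lambda v: -v, x)

# no dynamically-sized outputs: mask with jnp.where instead of filtering
@jax.jit
def keep_positive(x):
    return jnp.where(x > 0, x, 0.0)          # same shape as x, negatives zeroed

# a static (hashable) argument: recompiled for each new value of `axis`
@partial(jax.jit, static_argnums=1)
def sum_over(x, axis):
    return jnp.sum(x, axis=axis)

x = jnp.arange(6.0).reshape(2, 3)
print(relu_ish(x))
print(keep_positive(x - 2.0))
print(sum_over(x, 0))                        # compiles for axis=0
print(sum_over(x, 1))                        # recompiles for axis=1

# the traced computation, as a jaxpr
print(jax.make_jaxpr(lambda v: jnp.where(v > 0, v, 0.0))(x))
```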
for loops under `jit`, see [[functional programming#scan]] (`lax.scan`); a minimal sketch is below
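A minimal `lax.scan` sketch (a cumulative sum; the `step` / `carry` names are just illustrative). scan keeps the loop inside XLA instead of unrolling it at trace time:

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def cumsum(xs):
    # step: (carry, x) -> (new_carry, per-step output)
    def step(carry, x):
        total = carry + x
        return total, total
    init = jnp.zeros((), dtype=xs.dtype)     # carry shape/dtype must stay fixed
    _, ys = lax.scan(step, init, xs)
    return ys

print(cumsum(jnp.arange(5.0)))               # [ 0.  1.  3.  6. 10.]
```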
## sharp bits
- **use pure functions**: no side effects, outputs determined only by explicit inputs (side effects run once during tracing, then silently disappear)
- don't drive computation with Python loops/iterators: tracing unrolls them (slow compilation, huge jaxprs), and iterators don't play well with the transformations at all
- no external mutable state: globals get baked in at trace time, and the cached compiled fn won't see later changes. careful about caching behaviour
- **out-of-bounds** indexing doesn't raise: reads are clamped to the valid range and out-of-bounds updates are dropped, so results (and gradients) can be silently wrong
- pass *ndarrays*, not Python lists or tuples of scalars (each element gets traced separately instead of the array being treated as a single unit, so performance suffers); structured arguments should be pytrees
- [[pseudorandom]]ness: NumPy's stateful global PRNG doesn't fit the pure-function model. jax uses explicit 8-byte keys (a pair of uint32) that you pass to every function that needs randomness
- never reuse a key: *split* it each time; can generate multiple subkeys at once (sketch after this list)
- debugging **NaNs**:
- env `JAX_DEBUG_NANS=True`
- `jax.config.update("jax_debug_nans", True)` (older versions: `from jax.config import config; config.update(...)`)
- or import config and run `config.parse_flags_with_absl()` and pass `--jax_debug_nans=True` on cli
- when profiling, make sure to add `.block_until_ready()` since by default jax uses async dispatch (see the sketch after this list)
- `AttributeError: module 'jax.interpreters.xla' has no attribute 'DeviceArray'`
- Try upgrading [[Chex]]
- `jnp.linalg.inv` doesn't check for singular matrices? (unlike NumPy, which raises `LinAlgError`)
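A sketch of the PRNG-splitting and profiling bullets above (the matmul and the shapes are arbitrary toy choices):

```python
import time
import jax
import jax.numpy as jnp

# explicit PRNG state: split the key, never reuse one
key = jax.random.PRNGKey(0)
key, sub1, sub2 = jax.random.split(key, num=3)    # several subkeys at once
w = jax.random.normal(sub1, (1000, 1000))
b = jax.random.normal(sub2, (1000,))

# profiling: block on the result, otherwise you only time the async dispatch
f = jax.jit(lambda w, b: jnp.tanh(w @ w + b))
f(w, b).block_until_ready()                       # warm-up (includes compilation)
t0 = time.perf_counter()
f(w, b).block_until_ready()
print(f"steady-state time: {time.perf_counter() - t0:.4f}s")
```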
See also [[autodiff#sources]]
## docs
Some cool [[reinforcement learning software]] stuff made with Jax
![[library.base#jax reinforcement learning]]
- main docs have ecosystem links
- [GitHub - n2cholas/awesome-jax: JAX - A curated list of resources](https://github.com/n2cholas/awesome-jax)
- https://github.com/jax-ml (collection of projects)
- [Learning JAX as a PyTorch developer · Patrick Kidger](https://kidger.site/thoughts/torch2jax/)
[[Patrick Kidger]] has built lots of cool JAX software
actual docs
- JAX on Metal:
- https://developer.apple.com/metal/jax/
- https://github.com/jax-ml/jax/issues?q=is%3Aissue+is%3Aopen+metal
- https://github.com/google/jax#installation
- [How to Think in JAX — JAX documentation](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html)
- [[2020BabuschkinEtAlDeepMindJAXEcosystem|The DeepMind JAX Ecosystem]]
- [Autodidax: JAX core from scratch — JAX documentation](https://jax.readthedocs.io/en/latest/autodidax.html)
- [[deep learning programming framework]]
- Very interesting walkthrough!