#star/use #star/great

# JAX: Composable transformations of Python+NumPy programs

[[maclaurin_autograd_2023|Autograd]] plus [[accelerated linear algebra]]: **composable function transformations**. Yes finally! Functional programming in machine learning!!!

`jax.numpy` is a mostly drop-in replacement for [[NumPy]]. The biggest difference is that arrays are **immutable** (yay purity).

`jax.lax` is a lower-level API: stricter but more powerful. Basically a wrapper around [[accelerated linear algebra|XLA]].

## [[jit compiler]]

By default JAX dispatches operations one at a time. JIT compilation instead optimizes a whole function as one block (via XLA) and runs it at once.
- requires array shapes to be static, i.e. known at compile time
- it **traces** the function to determine its effect on inputs of a given shape and dtype. Tracing **doesn't know about values**:
	- the function is recompiled whenever inputs with new shapes / dtypes are passed (kinda like [[Julia]]'s JIT)
	- so e.g. no control flow can depend on the *values* of traced variables. Instead, use the `lax` alternatives (`lax.cond`, `lax.while_loop`, `lax.fori_loop`)
	- it also follows that shapes can't depend on values (no dynamically-sized arrays). Instead, use tricks like zeroing out entries via `jnp.where`
- to avoid tracing an argument, mark it as *static*. It must be **hashable** (i.e. no arrays), and the function then gets recompiled for each new *value* of that argument
- **operations can also be static.** These are executed at compile (trace) time and should be written with `numpy` instead of `jnp`, since `jnp` calls return traced variables that have no concrete values at compile time
	- i.e. anything computed purely from static information, like `x.shape`, can use plain NumPy and becomes a compile-time constant
- tracing records the sequence of operations as a **jaxpr** (JAX expression), which is then lowered to XLA and compiled
	- see [[COMPSCI 2520R]] for an interesting discussion of jaxpr usage
	- also [[memo programming language]]

See [[functional programming#scan]]

## sharp bits

- **use pure functions**
	- don't use Python iterators (they carry hidden state). Also note that Python loops get unrolled during tracing, which slows compilation
	- no external mutable state
	- careful about caching behaviour: `jit` caches the traced function, so a global read during tracing is baked in as a constant and later changes to it are ignored until something forces a retrace
- **out-of-bounds** indexing doesn't raise an error and should be treated as **undefined** behaviour: in practice reads clamp to the nearest valid index and updates are silently dropped (which gets weird for autodiff)
- `jnp` functions expect arrays, not Python lists or tuples. These are rejected rather than converted implicitly, because a list would be traced element by element (poor performance, since the array isn't treated as a "single unit"). Convert explicitly with `jnp.array`. Containers of arrays are instead handled as pytrees
- [[pseudorandom]]ness: NumPy's global-state RNG doesn't fit pure functions. JAX instead uses explicit 8-byte keys that you pass to every function that needs randomness
	- never reuse a key: *split* it each time. You can generate multiple subkeys at once
- debugging **NaNs**:
	- env `JAX_DEBUG_NANS=True`
	- `jax.config.update("jax_debug_nans", True)` (older versions: `from jax.config import config; config.update("jax_debug_nans", True)`)
	- or call `jax.config.parse_flags_with_absl()` and pass `--jax_debug_nans=True` on the CLI
- when profiling, add `.block_until_ready()`, since JAX uses asynchronous dispatch by default (otherwise you only time the dispatch)
- `AttributeError: module 'jax.interpreters.xla' has no attribute 'DeviceArray'`
	- try upgrading [[Chex]]
- `jnp.linalg.inv` doesn't raise on a singular matrix. You get inf/nan entries instead

See also [[autodiff#sources]]

# docs

Some cool [[reinforcement learning software]] stuff made with JAX:

![[library.base#jax reinforcement learning]]

- main docs have ecosystem links
- [GitHub - n2cholas/awesome-jax: JAX - A curated list of resources](https://github.com/n2cholas/awesome-jax)
- https://github.com/jax-ml (collection of projects)
- [Learning JAX as a PyTorch developer · Patrick Kidger](https://kidger.site/thoughts/torch2jax/)

[[Patrick Kidger]] has built lots of cool JAX software.

actual docs:
- JAX on Metal:
	- https://developer.apple.com/metal/jax/
	- https://github.com/jax-ml/jax/issues?q=is%3Aissue+is%3Aopen+metal
- https://github.com/google/jax#installation
- [How to Think in JAX — JAX documentation](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html)
- [[2020BabuschkinEtAlDeepMindJAXEcosystem|The DeepMind JAX Ecosystem]]
- [Autodidax: JAX core from scratch — JAX documentation](https://jax.readthedocs.io/en/latest/autodidax.html)
	- [[deep learning programming framework]]
	- Very interesting walkthrough!
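
A minimal sketch of the "arrays are immutable" point from the intro. It assumes a recent JAX where arrays expose the `.at[...]` functional-update syntax, and the values are arbitrary. Item assignment raises a `TypeError`, and `.at[...].set` / `.add` return a new array while leaving the original untouched.

```python
import jax.numpy as jnp

x = jnp.zeros(3)

# In-place assignment is rejected: JAX arrays are immutable.
try:
    x[0] = 1.0
    mutated = True
except TypeError:
    mutated = False

# Functional updates return a *new* array; x itself is unchanged.
y = x.at[0].set(1.0)
z = y.at[1:].add(2.0)

print(x, z)
```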
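
The tracing rules from the jit compiler section can be sketched as follows. The function names are made up for illustration, while `jnp.where`, `lax.cond`, and `static_argnums` are real JAX APIs. A plain `if x > 0:` inside a jitted function would fail here, because `x` is a tracer that has a shape and dtype but no value.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def relu_where(x):
    # Value-dependent choice without Python `if`: both options exist
    # elementwise and jnp.where selects per element.
    return jnp.where(x > 0, x, 0.0)


@jax.jit
def sin_or_cos(x):
    # lax.cond branches on a traced scalar predicate inside compiled code.
    return lax.cond(x > 0, jnp.sin, jnp.cos, x)


@partial(jax.jit, static_argnums=1)
def repeat_add(x, n):
    # `n` is static: a hashable Python int, so Python control flow on it works.
    # The loop is unrolled during tracing, and each new value of `n` triggers a
    # recompilation. Passing an array as `n` would fail (arrays aren't hashable).
    for _ in range(n):
        x = x + 1.0
    return x


print(relu_where(jnp.array([-1.0, 2.0])))
print(sin_or_cos(0.5), sin_or_cos(-0.5))
print(repeat_add(jnp.array(0.0), 3))
```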
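
A sketch of what *static operations* look like in practice. The `flatten` helper is hypothetical. `x.shape` is a plain tuple of ints at trace time, so `np.prod` runs during tracing and produces a concrete integer that `reshape` can use.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def flatten(x):
    # Static: computed with NumPy from the (static) shape, once, at trace time.
    size = int(np.prod(x.shape))
    return x.reshape(size)


# By contrast, `x.reshape(jnp.prod(jnp.array(x.shape)))` inside a jitted
# function raises an error: jnp turns the shape into a traced array with no
# concrete value, and reshape needs a concrete integer.

print(flatten(jnp.ones((2, 3))).shape)
```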
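
To look at the jaxpr a function traces to, `jax.make_jaxpr` can be used. This is a minimal sketch with an arbitrary function `f`. Only the shapes and dtypes of the example inputs matter, and the exact printed format varies between JAX versions.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    return jnp.sin(x) * y + 1.0


# make_jaxpr traces f with abstract values (here f32[3]); the concrete numbers
# inside the example inputs are not used.
closed = jax.make_jaxpr(f)(jnp.ones(3), jnp.ones(3))
print(closed)

# The result is a ClosedJaxpr: typed inputs/outputs plus a flat list of equations.
print(closed.in_avals, closed.out_avals)
print(len(closed.jaxpr.eqns))
```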
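
A minimal sketch of the caching pitfall listed under *sharp bits*. The global `scale` and the function `scaled` are hypothetical. The global is read once while tracing, so the cached compiled version keeps the old value until a new input shape forces a retrace.

```python
import jax
import jax.numpy as jnp

scale = 2.0  # hypothetical global, read while tracing


@jax.jit
def scaled(x):
    # `scale` is captured as a constant when this function is traced.
    return x * scale


first = scaled(jnp.ones(2))   # traces and compiles for f32[2] with scale = 2.0
scale = 10.0
second = scaled(jnp.ones(2))  # same shape/dtype: cached version reused, still 2.0
third = scaled(jnp.ones(3))   # new shape: retraced, now picks up 10.0
print(first, second, third)
```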
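
A sketch of the out-of-bounds indexing behaviour. It assumes a recent JAX where `.at[...].get` and `.at[...].set` accept a `mode` argument. Spelling out `mode` makes the behaviour explicit instead of relying on the default, whose clamping is only what happens in practice.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(3.0)  # [0., 1., 2.]

# No IndexError: with a traced index the bounds can't be checked at trace time.
# Under the default mode, the read is typically clamped.
read = jax.jit(lambda a, i: a[i])
unchecked = read(x, 10)

# Explicit modes:
clipped = x.at[10].get(mode="clip")                     # nearest valid index
filled = x.at[10].get(mode="fill", fill_value=jnp.nan)  # sentinel instead of a clamp
dropped = x.at[10].set(99.0, mode="drop")               # update silently ignored

print(unchecked, clipped, filled, dropped)
```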
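
A minimal sketch of the "pass arrays, not lists" rule. `jnp.sum` rejects a plain Python list with a `TypeError`, so the list is converted with `jnp.array` first.

```python
import jax.numpy as jnp

try:
    jnp.sum([1, 2, 3])  # a plain list is rejected rather than converted
    accepted_list = True
except TypeError:
    accepted_list = False

total = jnp.sum(jnp.array([1, 2, 3]))  # explicit conversion: one array, one unit
print(accepted_list, total)
```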
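
A sketch of explicit key handling with `jax.random`. It assumes the default (threefry) PRNG implementation, where `jax.random.PRNGKey` returns a `uint32` array of shape `(2,)` (the 8 bytes mentioned above). Newer code can use `jax.random.key` for typed keys instead, but that variant isn't shown here.

```python
import jax

key = jax.random.PRNGKey(0)  # legacy raw key: uint32 array of shape (2,)

# Don't reuse a key for fresh randomness: split off a subkey and consume that.
key, subkey = jax.random.split(key)
a = jax.random.normal(subkey, (3,))

# Using the *same* key again reproduces exactly the same numbers.
b = jax.random.normal(subkey, (3,))

# Several subkeys at once.
key, *subkeys = jax.random.split(key, num=4)
samples = [jax.random.uniform(k) for k in subkeys]
print(a, samples)
```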
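
A sketch of the NaN checker in action. It assumes a JAX version where `jax.config.update` is available at the top level, and the function name `log_fn` is hypothetical. With the flag on, the jitted `log` of a negative number raises `FloatingPointError` instead of silently returning `nan`.

```python
import jax
import jax.numpy as jnp

# Same effect as setting the env var JAX_DEBUG_NANS=True.
jax.config.update("jax_debug_nans", True)


@jax.jit
def log_fn(x):
    return jnp.log(x)


caught = None
try:
    log_fn(jnp.array(-1.0))  # log(-1) is nan, so the checker raises
except FloatingPointError as err:
    caught = err

jax.config.update("jax_debug_nans", False)  # restore the default
print(type(caught).__name__)
```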
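
A minimal timing sketch for the profiling note. The matrix size and the `matmul` helper are arbitrary. The first call is a warm-up that includes tracing and compilation, and `block_until_ready()` makes the timer wait for the finished result rather than just the asynchronous dispatch.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def matmul(a):
    return a @ a


x = jnp.ones((500, 500))
matmul(x).block_until_ready()  # warm-up: includes tracing + compilation

start = time.perf_counter()
# Without block_until_ready, the timer would stop once the work is dispatched,
# not when it has actually finished.
out = matmul(x).block_until_ready()
elapsed = time.perf_counter() - start
print(f"{elapsed * 1e3:.2f} ms")
```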