JAX reference documentation

JAX is Autograd and XLA, brought together for high-performance numerical computing and machine learning research. It provides composable transformations of Python+NumPy programs: differentiate, vectorize, parallelize, Just-In-Time compile to GPU/TPU, and more.
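The following is a minimal sketch of how these transformations compose. It differentiates a small NumPy-style function with `jax.grad`, vectorizes the resulting gradient over a batch with `jax.vmap`, and JIT-compiles the composition with `jax.jit`. The function `loss` and the input values are illustrative assumptions, not part of the original text.

```python
import jax
import jax.numpy as jnp

# An illustrative (hypothetical) scalar-valued function in NumPy style.
def loss(w, x):
    return jnp.sum(jnp.tanh(w * x) ** 2)

# Differentiate with respect to the first argument, w.
grad_loss = jax.grad(loss)

# Vectorize the gradient over a batch of x values; w is shared (in_axes=None).
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0))

# JIT-compile the composed transformation with XLA.
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array(0.5)
xs = jnp.array([1.0, 2.0, 3.0])
print(fast_batched_grad(w, xs))  # one gradient value per element of xs
```

Because each transformation takes a function and returns a function, they can be stacked in any order. Here `jit` wraps the already vectorized gradient, so XLA compiles the whole batched computation at once.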


JAX 0.4.0 introduces new parallelism APIs, including breaking changes to jax.experimental.pjit() and a new unified jax.Array type. Please see the Parallelism with JAX tutorial and the jax.Array migration guide for more information.
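As a small illustration of what "unified" means here, the sketch below checks that an eagerly created array and the output of a jit-compiled function are both instances of `jax.Array`, and that each carries a `sharding` describing its device placement. It assumes a JAX release in which `jax.Array` is the default array type; the doubling function is a hypothetical example.

```python
import jax
import jax.numpy as jnp

# An array created eagerly with jax.numpy.
x = jnp.arange(8.0)

# The output of a (hypothetical) jit-compiled doubling function.
y = jax.jit(lambda a: a * 2)(x)

# Both share a single array type.
print(isinstance(x, jax.Array), isinstance(y, jax.Array))

# Every jax.Array records how it is placed across devices.
print(y.sharding)
```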

API documentation

Indices and tables