JAX: High-Performance Array Computing

JAX is Autograd and XLA, brought together for high-performance numerical computing.

Familiar API

JAX provides a familiar NumPy-style API for ease of adoption by researchers and engineers.
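
As a small illustration of the NumPy-style API, the sketch below uses `jax.numpy` the way one would use NumPy. The array values and the names `x`, `col_means`, and `gram` are made up for this example.

```python
import jax.numpy as jnp

# Build a 2x3 array with the same calls NumPy offers.
x = jnp.arange(6.0).reshape(2, 3)   # [[0, 1, 2], [3, 4, 5]]

# Reductions and matrix products follow NumPy conventions.
col_means = x.mean(axis=0)          # mean of each column, shape (3,)
gram = x @ x.T                      # 2x2 matrix of row dot products
```

One difference worth knowing: JAX arrays are immutable, so in-place assignment such as `x[0, 0] = 1` is not supported.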


Transformations

JAX includes composable function transformations for compilation, batching, automatic differentiation, and parallelization.
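
A minimal sketch of how these transformations compose, assuming a toy function `loss` invented for this example: `jax.grad` differentiates it, `jax.vmap` batches the gradient over several inputs, and `jax.jit` compiles the result. Parallelization across devices is not shown here.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Hypothetical toy function: squared dot product of weights and one input.
    return jnp.dot(w, x) ** 2

# Differentiate with respect to the first argument, w.
grad_loss = jax.grad(loss)

# Batch over the leading axis of x while sharing the same w.
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0))

# Compile the batched gradient with XLA.
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
grads = fast_batched_grad(w, xs)    # one gradient per row of xs, shape (3, 2)
```

Each transformation takes a function and returns a function, so they can be applied in whatever order the computation calls for.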

Run Anywhere

The same code executes on multiple backends, including CPU, GPU, and TPU.
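
To see which backend a program is using, the following sketch queries JAX and then runs a small computation that contains no device-specific code. Which devices are listed depends on the installed `jaxlib` build and the available hardware.

```python
import jax
import jax.numpy as jnp

# Name of the backend JAX selected by default, returned as a string.
backend = jax.default_backend()

# Devices visible to JAX on that backend.
devices = jax.devices()

# The computation itself is written once, with no reference to the backend.
x = jnp.ones((4, 4))
y = jnp.sum(x @ x)

print(backend, devices, y)
```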


JAX 0.4.1 introduces new parallelism APIs, including breaking changes to jax.experimental.pjit() and a new unified jax.Array type. Please see the Distributed arrays and automatic parallelization tutorial and the jax.Array migration guide for more information.



pip install "jax[cpu]"
pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

For more information about supported accelerators and platforms, and for other installation options, see the Install Guide in the project README.