JAX: High-Performance Array Computing
JAX provides a familiar NumPy-style API for ease of adoption by researchers and engineers.
JAX includes composable function transformations for compilation, batching, automatic differentiation, and parallelization.
The same code executes on multiple backends, including CPU, GPU, and TPU.
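As a minimal sketch of how the NumPy-style API and the composable transformations fit together, the example below writes a plain mean-squared-error loss with `jax.numpy`, then composes `jax.grad`, `jax.vmap`, and `jax.jit` to get compiled per-example gradients. The function `loss` and the toy data are illustrative assumptions, not part of the original text.

```python
import jax
import jax.numpy as jnp

# A plain NumPy-style function: jax.numpy mirrors much of the numpy API.
def loss(w, x, y):
    pred = jnp.dot(x, w)               # linear model prediction for one example
    return jnp.mean((pred - y) ** 2)   # squared error

# Composing transformations:
#   jax.grad -> gradient of loss with respect to w (its first argument)
#   jax.vmap -> map that gradient over the leading batch axis of x and y (w is shared)
#   jax.jit  -> compile the whole composed function with XLA
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # 3 examples, 2 features
ys = jnp.array([1.0, 1.0, 1.0])

g = per_example_grads(w, xs, ys)
print(g.shape)  # (3, 2)
```

Because each transformation takes a function and returns a function, the order can be rearranged (for example, `jax.vmap(jax.jit(...))`) without rewriting `loss`.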
JAX 0.4.1 introduces new parallelism APIs, including breaking changes to
jax.experimental.pjit() and a new unified
Please see the Distributed arrays and automatic parallelization tutorial and the jax.Array migration guide for more information.
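A minimal sketch of sharding a `jax.Array` across devices follows. It assumes the `jax.sharding` module (`Mesh`, `NamedSharding`, `PartitionSpec`) as found in current JAX releases; exact module paths may differ in 0.4.1 itself, and the mesh axis name `"data"` is an arbitrary choice. On a single-device machine the mesh simply has one entry, so the code still runs.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D mesh over every visible device; "data" is an arbitrary axis name.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
n = devices.size

# Split the leading axis across the "data" mesh axis; keep the second axis whole.
sharding = NamedSharding(mesh, P("data", None))

# Leading dimension is a multiple of the device count so it divides evenly.
x = jnp.arange(4 * n, dtype=jnp.float32).reshape(2 * n, 2)
x_sharded = jax.device_put(x, sharding)

# jit-compiled code consumes the sharded jax.Array directly.
y = jax.jit(lambda a: jnp.sin(a) * 2.0)(x_sharded)

print(isinstance(y, jax.Array))           # True
print(len(x_sharded.addressable_shards))  # one shard per local device
```

The program is written once against a single logical array; the sharding annotation, not the function body, determines how the data is laid out across devices.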
Installation (pick the command that matches your platform):
# CPU only
pip install "jax[cpu]"
# NVIDIA GPU (CUDA)
pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# TPU
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
For more information about supported accelerators and platforms, and for other installation options, see the Install Guide in the project README.
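After installing, a quick way to check which backend JAX picked up is sketched below. It uses only `jax.default_backend()` and `jax.devices()`; the small matrix product is an arbitrary smoke test, not something the installation requires.

```python
import jax
import jax.numpy as jnp

# Which backend did JAX pick up? Typically "cpu", "gpu", or "tpu".
backend = jax.default_backend()
print(backend)

# Devices visible to this process on that backend.
print(jax.devices())

# The same array code runs unchanged on whichever backend is active.
x = jnp.ones((3, 3))
y = (x @ x).block_until_ready()  # wait for the asynchronous computation to finish
print(float(y[0, 0]))  # 3.0
```

If a GPU or TPU wheel was installed but the backend still reports `cpu`, the accelerator runtime was most likely not found, and the Install Guide is the place to check.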