JAX: High-Performance Array Computing
JAX is Autograd and XLA, brought together for high-performance numerical computing.
Familiar API
JAX provides a familiar NumPy-style API for ease of adoption by researchers and engineers.
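As a rough illustration of the NumPy-style API, the snippet below evaluates the same expression with `numpy` and with `jax.numpy` and checks that the results agree. It is a minimal sketch; the arrays and the identity being evaluated are arbitrary choices. It also shows one well-known difference: JAX arrays are immutable, so indexed updates use `.at[...].set(...)` instead of in-place assignment.

```python
import numpy as np
import jax.numpy as jnp

# The same expression written against NumPy and against jax.numpy.
x_np = np.linspace(0.0, 1.0, 5)
x_jx = jnp.linspace(0.0, 1.0, 5)

y_np = np.sin(x_np) ** 2 + np.cos(x_np) ** 2
y_jx = jnp.sin(x_jx) ** 2 + jnp.cos(x_jx) ** 2

# JAX defaults to float32, so compare with a tolerance rather than exactly.
print(np.allclose(y_np, y_jx))  # True

# JAX arrays are immutable: instead of z[1] = 5.0, use the functional
# .at[...].set(...) update, which returns a new array.
z = jnp.zeros(3)
z = z.at[1].set(5.0)
print(np.asarray(z))  # [0. 5. 0.]
```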
Transformations
JAX includes composable function transformations for compilation, batching, automatic differentiation, and parallelization.
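The sketch below composes three of these transformations on a toy squared-error loss (a hypothetical example chosen for illustration): `jax.grad` differentiates with respect to the weights, `jax.vmap` maps that gradient over a batch of examples while sharing the weights, and `jax.jit` compiles the composed function with XLA. Parallelization (for example `jax.pmap`) is left out because it needs more than one device to be meaningful.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared error of a linear model for a single example.
    pred = jnp.dot(x, w)
    return (pred - y) ** 2

# grad: derivative of loss with respect to its first argument, w.
grad_loss = jax.grad(loss)

# vmap: apply grad_loss to each (x, y) pair, broadcasting the same w.
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# jit: compile the composed function with XLA.
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array([1.0, -2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([0.0, 0.0, 0.0])

per_example_grads = fast_batched_grad(w, xs, ys)
print(per_example_grads.shape)  # (3, 2)
```

Each row is the analytic gradient `2 * (x·w - y) * x` for one example, so for these inputs the rows are `[2, 0]`, `[0, -4]`, and `[-2, -2]`. Because the transformations compose, the order can be rearranged (for example `jit` inside `vmap`) without changing the result.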
Run Anywhere
The same code executes on multiple backends, including CPU, GPU, and TPU.
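A minimal sketch of what this means in practice, assuming only that JAX is installed: the function body never mentions a device, and the same jitted function is called once on the default backend and once on an array explicitly placed on the host CPU with `jax.device_put`. `normalize` is a hypothetical helper; on a CPU-only install both calls run on the same device.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Backend JAX selected on this machine, typically "cpu", "gpu", or "tpu".
print("default backend:", jax.default_backend())

@jax.jit
def normalize(v):
    # Hypothetical helper: scale v to unit length. The source is the same
    # for every backend; XLA compiles it for whichever device runs it.
    return v / jnp.linalg.norm(v)

v = jnp.array([3.0, 4.0])

# Run on the default device (an accelerator, if one is available).
out_default = normalize(v)

# Run the identical function on an array committed to the host CPU.
cpu = jax.devices("cpu")[0]
out_cpu = normalize(jax.device_put(v, cpu))

# Copy both results to host memory before comparing, since they may live
# on different devices.
print(np.allclose(np.asarray(out_default), np.asarray(out_cpu)))  # True
```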
Installation
CPU only:
pip install "jax[cpu]"
NVIDIA GPU (CUDA):
pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
TPU:
pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
For more information about supported accelerators and platforms, and for other installation options, see the Install Guide in the project README.
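After installing, a quick check along these lines (a sketch, not an official verification script) prints the installed version, lists the devices JAX can see, and runs a small computation end to end. On a working GPU or TPU install, the device list should include accelerator devices rather than only the CPU.

```python
import jax
import jax.numpy as jnp

# Installed JAX version.
print("jax version:", jax.__version__)

# Every device visible to the default backend.
for d in jax.devices():
    print("device:", d)

# A tiny computation forces compilation and execution on the default device.
x = jnp.arange(4.0)
total = float(jnp.sum(x * x))
print(total)  # 14.0
```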