JAX: High-Performance Array Computing

JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.

If you’re looking to train neural networks, use Flax and start with its documentation. Related libraries include Optax, for gradient-based optimization, and Orbax, for checkpointing. For an end-to-end transformer library built on JAX, see MaxText.

Familiar API

JAX provides a NumPy-style API (jax.numpy), so researchers and engineers who already know NumPy can adopt it with little friction.
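
As a minimal sketch, the same expression can be written against NumPy and jax.numpy with nearly identical code. The variable names are illustrative only. JAX defaults to float32 while NumPy defaults to float64, so results agree up to floating-point tolerance. One notable difference is that JAX arrays are immutable: an update through the .at[] property returns a new array instead of modifying the original in place.

```python
import numpy as np
import jax.numpy as jnp

# The same computation expressed with NumPy and with jax.numpy.
x_np = np.linspace(0.0, 1.0, 5)
x_jnp = jnp.linspace(0.0, 1.0, 5)

y_np = np.sum(np.sin(x_np) ** 2)
y_jnp = jnp.sum(jnp.sin(x_jnp) ** 2)

# Results match up to float32 (JAX default) vs float64 (NumPy default) precision.
print(np.allclose(y_np, y_jnp))

# JAX arrays are immutable: .at[...].set(...) returns a new array.
a = jnp.zeros(3)
b = a.at[1].set(5.0)  # `a` is left unchanged
print(a, b)
```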

Transformations

JAX includes composable function transformations for compilation, batching, automatic differentiation, and parallelization.
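
To show how the transformations compose, the sketch below takes a per-example squared-error loss for a hypothetical linear model. It differentiates the loss with jax.grad, maps it over a batch with jax.vmap, and compiles the result with jax.jit. The loss function and data are assumptions made for illustration. Parallelization across devices is a separate transformation and is not shown here.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared error of a linear model on a single (x, y) example (illustrative).
    pred = jnp.dot(x, w)
    return (pred - y) ** 2

# grad: gradient of `loss` with respect to its first argument, w.
grad_loss = jax.grad(loss)

# vmap: apply the per-example gradient across a batch of (x, y) pairs,
# sharing one w for all examples (in_axes=None means "do not map w").
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# jit: compile the composed function with XLA.
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array([1.0, -1.0])
xs = jnp.array([[1.0, 2.0], [3.0, 4.0], [0.5, 0.5]])
ys = jnp.array([0.0, 1.0, 2.0])

g = fast_batched_grad(w, xs, ys)
print(g.shape)  # one gradient row per example
```

Because each transformation takes a function and returns a function, the order can be rearranged. For example, jit can be applied to the inner gradient before batching, and the result is the same.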

Run Anywhere

The same code runs on multiple backends, including CPU, GPU, and TPU.
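
A minimal sketch of backend-agnostic code follows. The function contains no device-specific logic. JAX reports which backend it selected, and arrays can be placed on any visible device with jax.device_put. The normalize function is a hypothetical example, and the printed backend and device list depend on the machine and the installed jaxlib.

```python
import jax
import jax.numpy as jnp

# Backend JAX selected by default on this machine (e.g. "cpu", "gpu", or "tpu").
print(jax.default_backend())

# Devices visible to this process on the default backend.
devices = jax.devices()
print(devices)

@jax.jit
def normalize(v):
    # Scale a vector to unit L2 norm; no device-specific code is needed.
    return v / jnp.linalg.norm(v)

# Explicitly place the input on the first available device, then run.
v = jax.device_put(jnp.array([3.0, 4.0]), devices[0])
out = normalize(v)
print(out)
```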
