Quickstart

JAX is a library for array-oriented numerical computation (à la NumPy), with automatic differentiation and JIT compilation to enable high-performance machine learning research.

This document provides a quick overview of essential JAX features, so you can get started with JAX quickly:

  • JAX provides a unified NumPy-like interface to computations that run on CPU, GPU, or TPU, in local or distributed settings.

  • JAX features built-in Just-In-Time (JIT) compilation via OpenXLA, an open-source machine learning compiler ecosystem.

  • JAX functions support efficient evaluation of gradients via its automatic differentiation transformations.

  • JAX functions can be automatically vectorized to efficiently map them over arrays representing batches of inputs.

Installation

JAX can be installed for CPU on Linux, Windows, and macOS directly from the Python Package Index:

pip install "jax[cpu]"

or, for NVIDIA GPU:

pip install -U "jax[cuda12]"

For more detailed platform-specific installation information, check out Installing JAX.
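After installing, a quick way to confirm that JAX imports correctly and to see which backend it selected is a short check like the one below. The version string and device list will depend on your installation and hardware.

```python
import jax
import jax.numpy as jnp

# Report the installed version and the devices JAX can see.
print(jax.__version__)
print(jax.devices())          # CPU devices, or GPU/TPU devices if available
backend = jax.default_backend()
print(backend)                # typically "cpu", "gpu", or "tpu"

# Run a tiny computation to confirm the backend works end to end.
doubled = jnp.arange(3.0) * 2
print(doubled)
```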

JAX as NumPy

Most JAX usage is through the familiar jax.numpy API, which is typically imported under the jnp alias:

import jax.numpy as jnp

With this import, you can immediately use JAX in a similar manner to typical NumPy programs, including using NumPy-style array creation functions, Python functions and operators, and array attributes and methods:

def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(5.0)
print(selu(x))
[0.        1.05      2.1       3.1499999 4.2      ]
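To make the NumPy analogy concrete, here is a small sketch of array attributes, methods, and interoperability with NumPy. It repeats the selu definition so the snippet runs on its own. JAX arrays are instances of jax.Array; jax.numpy functions also accept NumPy arrays, and np.asarray converts a JAX array back into a NumPy ndarray.

```python
import numpy as np
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(5.0)
y = selu(x)

# Familiar NumPy-style attributes and methods work on JAX arrays.
print(y.shape, y.dtype)   # (5,) float32 under the default (32-bit) config
print(y.sum(), y.max())

# jax.numpy functions accept NumPy arrays as input...
y_from_np = selu(np.arange(5.0))
# ...and np.asarray converts a JAX array back to a NumPy ndarray.
y_np = np.asarray(y)
print(type(y_np))         # <class 'numpy.ndarray'>
```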

You’ll find a few differences between JAX arrays and NumPy arrays once you begin digging in; these are explored in 🔪 JAX - The Sharp Bits 🔪.
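One of the first differences you are likely to hit is that JAX arrays are immutable: NumPy-style in-place item assignment raises a TypeError, and the .at[...] property provides functional updates that return a new array instead, leaving the original untouched.

```python
import jax.numpy as jnp

x = jnp.arange(5.0)

# NumPy-style in-place assignment is not allowed on JAX arrays.
try:
    x[0] = 10.0
    assignment_failed = False
except TypeError:
    assignment_failed = True  # JAX arrays do not support item assignment

# Functional equivalents return updated copies; x itself is unchanged.
y = x.at[0].set(10.0)
z = x.at[1:3].add(1.0)
print(x)  # [0. 1. 2. 3. 4.]
print(y)  # [10.  1.  2.  3.  4.]
print(z)  # [0. 2. 3. 3. 4.]
```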

Just-in-time compilation with jax.jit()

JAX runs transparently on GPU or TPU, falling back to CPU if no accelerator is available. However, in the example above, JAX dispatches kernels to the device one operation at a time. When we have a sequence of operations, we can use the jax.jit() function to compile the whole sequence together using XLA.

We can use IPython’s %timeit to quickly benchmark our selu function, using block_until_ready() to account for JAX’s asynchronous dispatch (see Asynchronous dispatch):

from jax import random

key = random.key(1701)
x = random.normal(key, (1_000_000,))
%timeit selu(x).block_until_ready()
2.81 ms ± 23.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Notice that we’ve used jax.random to generate some random numbers; for details on how to generate random numbers in JAX, check out Pseudorandom numbers.
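%timeit is an IPython magic, so it is not available in a plain Python script. A rough stand-in using the standard-library timeit module is sketched below. It redefines selu so it runs standalone, makes one warm-up call so one-time costs are excluded, and calls block_until_ready() inside the timed lambda so each measurement waits for the computation to finish. Absolute numbers will vary by machine.

```python
import timeit
import jax.numpy as jnp
from jax import random

def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

key = random.key(1701)
x = random.normal(key, (1_000_000,))

selu(x).block_until_ready()  # warm-up call, excluded from timing

n_loops = 20
# Without block_until_ready(), we would mostly be timing asynchronous dispatch.
total_s = timeit.timeit(lambda: selu(x).block_until_ready(), number=n_loops)
per_loop_ms = 1e3 * total_s / n_loops
print(f"{per_loop_ms:.2f} ms per loop (mean of {n_loops} loops)")
```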

We can speed up this function with the jax.jit() transformation, which compiles selu the first time it is called and caches the compiled version for subsequent calls with the same input shapes and dtypes.

from jax import jit

selu_jit = jit(selu)
_ = selu_jit(x)  # compiles on first call
%timeit selu_jit(x).block_until_ready()
848 µs ± 5.34 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

The timings above represent execution on CPU, but the same code can run on GPU or TPU, typically with an even greater speedup.
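The compile-once, reuse-afterwards behavior can be observed directly. In the sketch below, the hypothetical counter trace_count is a Python side effect, so it runs only while JAX traces the function, not when a cached compilation is reused. Calls with the same input shape and dtype hit the cache, while a new shape triggers a fresh trace and compile. (Side effects like this are generally best avoided inside jitted functions; it is used here only to make tracing visible.)

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, incremented only during tracing

@jax.jit
def scaled_sum(x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time only
    return jnp.sum(2.0 * x)

scaled_sum(jnp.ones(3))   # first call: trace and compile
scaled_sum(jnp.ones(3))   # same shape/dtype: reuse the cached compilation
after_same_shape = trace_count
print(after_same_shape)   # 1

scaled_sum(jnp.ones(4))   # new shape: retrace and recompile
print(trace_count)        # 2
```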

For more on JIT compilation in JAX, check out Just-in-time compilation.

Taking derivatives with jax.grad()

In addition to transforming functions via JIT compilation, JAX also provides other transformations. One such transformation is jax.grad(), which performs automatic differentiation (autodiff):

from jax import grad

def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
[0.25       0.19661197 0.10499357]

Let’s verify with finite differences that our result is correct.

def first_finite_differences(f, x, eps=1E-3):
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                   for v in jnp.eye(len(x))])

print(first_finite_differences(sum_logistic, x_small))
[0.24998187 0.1965761  0.10502338]
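When you need both a function’s value and its gradient, as in a typical training loop, jax.value_and_grad() computes them in one call. The sketch below repeats the definitions above so it runs on its own and compares the gradient against the finite-difference estimate; the loose tolerance allows for float32 rounding in the finite differences.

```python
import jax
import jax.numpy as jnp

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

def first_finite_differences(f, x, eps=1e-3):
    # Central differences along each coordinate direction.
    return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                      for v in jnp.eye(len(x))])

x_small = jnp.arange(3.0)
value, gradient = jax.value_and_grad(sum_logistic)(x_small)
print(value)     # sigmoid(0) + sigmoid(1) + sigmoid(2), about 2.1119
print(gradient)

fd = first_finite_differences(sum_logistic, x_small)
print(jnp.allclose(gradient, fd, atol=1e-3))  # True
```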

The grad() and jit() transformations compose and can be mixed arbitrarily. For example, we can take the derivative of sum_logistic, jit-compile the result, and keep nesting transformations to obtain higher-order derivatives:

print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
-0.0353256
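The nested call above computes the third derivative of the logistic function σ(x) = 1 / (1 + e^(-x)) at x = 1. Starting from σ' = σ(1 - σ) and differentiating twice more gives σ''' = σ(1 - σ)(1 - 6σ + 6σ²), which lets us check the result in closed form:

```python
import jax
import jax.numpy as jnp

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

# Third derivative via nested grad (jit only affects speed, not the value).
third = jax.grad(jax.grad(jax.grad(sum_logistic)))(1.0)

# Closed form: sigma''' = sigma (1 - sigma) (1 - 6 sigma + 6 sigma^2).
s = jax.nn.sigmoid(1.0)
closed_form = s * (1 - s) * (1 - 6 * s + 6 * s**2)
print(third, closed_form)  # both approximately -0.0353
```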

Beyond scalar-valued functions, the jax.jacobian() transformation can be used to compute the full Jacobian matrix for vector-valued functions:

from jax import jacobian
print(jacobian(jnp.exp)(x_small))
[[1.        0.        0.       ]
 [0.        2.7182817 0.       ]
 [0.        0.        7.389056 ]]
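For a function from ℝⁿ to ℝᵐ, the Jacobian has shape (m, n): one row per output and one column per input. The function f below is a made-up example mapping ℝ³ to ℝ², whose Jacobian is [[x₁, x₀, 0], [0, 0, cos x₂]]. The sketch also checks that jax.jacfwd() and jax.jacrev() produce the same matrix.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical example function: R^3 -> R^2.
    return jnp.array([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 3.0])
J = jax.jacobian(f)(x)
print(J.shape)  # (2, 3)
print(J)        # [[2, 1, 0], [0, 0, cos(3)]] up to float32 precision

# Forward- and reverse-mode Jacobians agree.
print(jnp.allclose(jax.jacfwd(f)(x), jax.jacrev(f)(x)))  # True
```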

For more advanced autodiff operations, you can use jax.vjp() for reverse-mode vector-Jacobian products, and jax.jvp() and jax.linearize() for forward-mode Jacobian-vector products. These can be composed arbitrarily with one another and with other JAX transformations. For example, jax.jvp() and jax.vjp() are used to define jax.jacfwd() and jax.jacrev(), which compute Jacobians in forward and reverse mode, respectively. Here’s one way to compose them to make a function that efficiently computes full Hessian matrices:

from jax import jacfwd, jacrev
def hessian(fun):
  return jit(jacfwd(jacrev(fun)))
print(hessian(sum_logistic)(x_small))
[[-0.         -0.         -0.        ]
 [-0.         -0.09085776 -0.        ]
 [-0.         -0.         -0.07996249]]

This kind of composition produces efficient code in practice; this is more or less how JAX’s built-in jax.hessian() function is implemented.
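To make the forward/reverse distinction concrete: jax.jvp() pushes a tangent vector v forward and returns J @ v, while jax.vjp() returns a function that pulls a cotangent u back to u @ J. For the scalar-valued sum_logistic, a JVP along a basis vector yields one partial derivative, and a VJP with cotangent 1 yields the full gradient. The final line checks the hand-rolled hessian from above against the built-in jax.hessian().

```python
import jax
import jax.numpy as jnp

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

def hessian(fun):
    return jax.jit(jax.jacfwd(jax.jacrev(fun)))

x_small = jnp.arange(3.0)
full_grad = jax.grad(sum_logistic)(x_small)

# Forward mode: directional derivative along the first coordinate (J @ v).
v = jnp.array([1.0, 0.0, 0.0])
y, jv = jax.jvp(sum_logistic, (x_small,), (v,))

# Reverse mode: pull back a cotangent of 1 through the scalar output (u @ J).
y_rev, vjp_fn = jax.vjp(sum_logistic, x_small)
(vj,) = vjp_fn(jnp.ones_like(y_rev))

print(jv, full_grad[0])             # both 0.25
print(jnp.allclose(vj, full_grad))  # True
print(jnp.allclose(hessian(sum_logistic)(x_small),
                   jax.hessian(sum_logistic)(x_small)))  # True
```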

For more on automatic differentiation in JAX, check out Automatic differentiation.

Auto-vectorization with jax.vmap()

Another useful transformation is vmap(), the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of explicitly looping over function calls, it transforms the function into a natively vectorized version for better performance. When composed with jit(), it can be just as performant as manually rewriting your function to operate over an extra batch dimension.

We’re going to work with a simple example, and promote matrix-vector products into matrix-matrix products using vmap(). Although this is easy to do by hand in this specific case, the same technique can apply to more complicated functions.

key1, key2 = random.split(key)
mat = random.normal(key1, (150, 100))
batched_x = random.normal(key2, (10, 100))

def apply_matrix(x):
  return jnp.dot(mat, x)

The apply_matrix function maps a vector to a vector, but we may want to apply it row-wise across a matrix. We could do this by looping over the batch dimension in Python, but this usually results in poor performance.

def naively_batched_apply_matrix(v_batched):
  return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
Naively batched
970 µs ± 2.84 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

A programmer familiar with the jnp.dot function might recognize that apply_matrix can be rewritten to avoid explicit looping, using the built-in batching semantics of jnp.dot:

import numpy as np

@jit
def batched_apply_matrix(batched_x):
  return jnp.dot(batched_x, mat.T)

np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
                           batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()
Manually batched
13.7 µs ± 133 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

However, as functions become more complicated, this kind of manual batching becomes more difficult and error-prone. The vmap() transformation is designed to automatically transform a function into a batch-aware version:

from jax import vmap

@jit
def vmap_batched_apply_matrix(batched_x):
  return vmap(apply_matrix)(batched_x)

np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
                           vmap_batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
Auto-vectorized with vmap
20.6 µs ± 89.7 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
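apply_matrix closes over the global mat. If you would rather pass the matrix as an argument, vmap()’s in_axes parameter specifies which arguments to map over: None leaves an argument unbatched (shared across the batch), and 0 maps over its leading axis. The name apply_matrix_explicit is hypothetical; the shapes and keys mirror the example above.

```python
import jax
import jax.numpy as jnp
from jax import random

key1, key2 = random.split(random.key(1701))
mat = random.normal(key1, (150, 100))
batched_x = random.normal(key2, (10, 100))

def apply_matrix_explicit(m, x):
    # Same computation as apply_matrix, but with the matrix as an argument.
    return jnp.dot(m, x)

# Share m across the batch (None); map over the rows of batched_x (0).
batched_apply = jax.jit(jax.vmap(apply_matrix_explicit, in_axes=(None, 0)))
out = batched_apply(mat, batched_x)
print(out.shape)  # (10, 150)
```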

As you would expect, vmap() can be arbitrarily composed with jit(), grad(), and any other JAX transformation.
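A common instance of this composition is computing per-example gradients: vmap(grad(f)) evaluates the gradient of f separately for each row of a batch, without a Python loop. The sketch reuses sum_logistic from earlier and treats each row of a hypothetical batch as one input.

```python
import jax
import jax.numpy as jnp

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

batch = jnp.arange(12.0).reshape(4, 3) / 4.0  # hypothetical batch of 4 inputs

# grad gives one gradient per input; vmap maps it over the batch axis.
per_example_grads = jax.jit(jax.vmap(jax.grad(sum_logistic)))(batch)
print(per_example_grads.shape)  # (4, 3)

# Same result as looping over the batch in Python.
looped = jnp.stack([jax.grad(sum_logistic)(row) for row in batch])
print(jnp.allclose(per_example_grads, looped, atol=1e-6))  # True
```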

For more on automatic vectorization in JAX, check out Automatic vectorization.

This is just a taste of what JAX can do. We’re really excited to see what you do with it!