Transformations

At its core, JAX is an extensible system for transforming numerical functions. This section will discuss four that are of primary interest: grad(), jit(), vmap(), and pmap().

Automatic differentiation with grad

JAX has roughly the same API as Autograd. The most popular function is jax.grad() for reverse-mode gradients:

from jax import grad
import jax.numpy as jnp

def tanh(x):  # Define a function
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)  # Obtain its gradient function
print(grad_tanh(1.0))   # Evaluate it at x = 1.0
0.4199743

You can differentiate to any order by nesting calls to grad():

print(grad(grad(grad(tanh)))(1.0))
0.6216266
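
As a quick sanity check on nested differentiation, the sketch below compares grad(grad(...)) against the closed-form second derivative of tanh. It uses jnp.tanh rather than the hand-written tanh above, and the names d2_tanh and analytic_d2 are illustrative only.

```python
import jax.numpy as jnp
from jax import grad

# Second derivative obtained by applying grad twice
d2_tanh = grad(grad(jnp.tanh))

def analytic_d2(x):
  # Closed form: d^2/dx^2 tanh(x) = -2 * tanh(x) * (1 - tanh(x)^2)
  t = jnp.tanh(x)
  return -2.0 * t * (1.0 - t ** 2)

x0 = 1.0
# The two values should agree up to float32 rounding
print(d2_tanh(x0), analytic_d2(x0))
```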

For more advanced autodiff, you can use jax.vjp() for reverse-mode vector-Jacobian products and jax.jvp() for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another, and with other JAX transformations. Here’s one way to compose forward and reverse mode, using the jax.jacfwd() and jax.jacrev() wrappers built on them, to make a function that efficiently computes full Hessian matrices:

from jax import jit, jacfwd, jacrev

def hessian(fun):
  return jit(jacfwd(jacrev(fun)))
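
To make the pieces above concrete, here is a small sketch that calls jax.jvp(), jax.vjp(), and the hessian() helper on a quadratic form. The matrix A, the function quadratic, and the test vectors are assumptions chosen so the results are easy to verify by hand: the gradient of x^T A x is (A + A^T) x and its Hessian is A + A^T.

```python
import jax
import jax.numpy as jnp
from jax import jit, jacfwd, jacrev

def hessian(fun):
  # Forward-over-reverse composition, as defined above
  return jit(jacfwd(jacrev(fun)))

A = jnp.array([[2.0, 1.0],
               [0.0, 3.0]])

def quadratic(x):
  # Scalar-valued f(x) = x^T A x
  return x @ A @ x

x = jnp.array([1.0, -1.0])
v = jnp.array([1.0, 0.0])

# Forward mode: f(x) together with the directional derivative J(x) @ v
y, jvp_out = jax.jvp(quadratic, (x,), (v,))

# Reverse mode: f(x) together with a function mapping a cotangent to cotangent @ J(x)
y2, vjp_fun = jax.vjp(quadratic, x)
(grad_x,) = vjp_fun(jnp.ones_like(y2))  # cotangent 1.0 recovers the gradient

H = hessian(quadratic)(x)  # 2 x 2 matrix, equal to A + A^T for this function
```

For a scalar-valued function, pulling back the cotangent 1.0 through vjp_fun gives the same result as grad().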

As with Autograd, you’re free to use differentiation with Python control structures:

def abs_val(x):
  if x > 0:
    return x
  else:
    return -x

abs_val_grad = grad(abs_val)
print(abs_val_grad(1.0))   # prints 1.0
print(abs_val_grad(-1.0))  # prints -1.0
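
Python loops are handled the same way as branches. The following sketch, with a hypothetical power function that is not part of the original example, differentiates through an ordinary for loop:

```python
from jax import grad

def power(x, n):
  # Repeated multiplication with a plain Python loop
  result = 1.0
  for _ in range(n):
    result = result * x
  return result

# grad differentiates with respect to the first argument (x) by default;
# n is passed through as an ordinary Python int
dpower = grad(power)
d = dpower(2.0, 3)  # derivative of x**3 at x = 2.0
```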

See the reference docs on automatic differentiation and the JAX Autodiff Cookbook for more.

Compilation with jit

You can use XLA to compile your functions end-to-end with jax.jit() used either as an @jit decorator or as a higher-order function.

import jax.numpy as jnp
from jax import jit

def slow_f(x):
  # Element-wise ops see a large benefit from fusion
  return x * x + x * 2.0

x = jnp.ones((5000, 5000))
%timeit slow_f(x).block_until_ready()
138 ms ± 544 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

fast_f = jit(slow_f)

# Results are the same
assert jnp.allclose(slow_f(x), fast_f(x))

%timeit fast_f(x).block_until_ready()
43.4 ms ± 280 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
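
The same function can also be written in decorator form. The sketch below is meant for use outside IPython, so it swaps %timeit for time.perf_counter. The array size and timing variable names are arbitrary choices, and the actual timings depend on your hardware.

```python
import time
import jax.numpy as jnp
from jax import jit

@jit
def fast_f(x):
  return x * x + x * 2.0

x = jnp.ones((1000, 1000))

start = time.perf_counter()
fast_f(x).block_until_ready()        # first call traces and compiles the function
first_call = time.perf_counter() - start

start = time.perf_counter()
out = fast_f(x).block_until_ready()  # later calls with the same shape and dtype reuse the compiled code
second_call = time.perf_counter() - start

print(f"first call: {first_call:.4f} s, second call: {second_call:.4f} s")
```

Calling block_until_ready() matters here because JAX dispatches work asynchronously. Without it, the timer would measure only the dispatch.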

You can mix jit() and grad() and any other JAX transformation however you like.
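
As a minimal illustration of mixing the two, the sketch below applies jit() and grad() in both orders to a made-up function f and checks the result against the known derivative of sin(x)^2, which is sin(2x).

```python
import jax.numpy as jnp
from jax import grad, jit

def f(x):
  return jnp.sum(jnp.sin(x) ** 2)

x = jnp.linspace(0.0, 1.0, 5)

g1 = jit(grad(f))(x)  # compile the gradient function
g2 = grad(jit(f))(x)  # differentiate through a compiled function

print(jnp.allclose(g1, g2))
```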

Using jit() puts constraints on the kind of Python control flow the function can use; see 🔪 JAX - The Sharp Bits 🔪 for more.
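
To illustrate one such constraint, the sketch below reuses abs_val from the grad section. Under jit(), x is an abstract tracer, so the Python test x > 0 cannot be resolved to True or False. Two common workarounds are shown. The names abs_where and abs_static are illustrative, and this is not an exhaustive list of options.

```python
import jax
import jax.numpy as jnp
from jax import jit

def abs_val(x):
  if x > 0:   # Python branch on the value of x
    return x
  else:
    return -x

# Tracing fails because the branch needs a concrete boolean
try:
  jit(abs_val)(1.0)
  failed = False
except jax.errors.ConcretizationTypeError:
  failed = True

# Workaround 1: express the branch with array operations
abs_where = jit(lambda x: jnp.where(x > 0, x, -x))

# Workaround 2: mark the argument as static, so jit specializes
# (and recompiles) for each distinct value of x
abs_static = jit(abs_val, static_argnums=0)

print(failed, abs_where(-2.0), abs_static(-2.0))
```

Static arguments must be hashable, and every new value triggers a recompile, so the second approach suits arguments that take few distinct values.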

Auto-vectorization with vmap

jax.vmap() is the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a function’s primitive operations for better performance.
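
A minimal sketch of these semantics, using a hypothetical dot function written for a single pair of vectors: the explicit Python loop and the vmap() version compute the same values, but vmap() expresses the whole batch as array operations.

```python
import jax.numpy as jnp
from jax import vmap

def dot(u, v):
  # Written for one pair of 1-D vectors
  return jnp.dot(u, v)

us = jnp.arange(12.0).reshape(4, 3)
vs = jnp.ones((4, 3))

# Loop on the outside: one call per row
looped = jnp.stack([dot(u, v) for u, v in zip(us, vs)])

# Loop pushed into the primitives: a single batched call over axis 0
batched = vmap(dot)(us, vs)
```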

Using vmap() can save you from having to carry around batch dimensions in your code. For example, consider this simple unbatched neural network prediction function:

def predict(params, input_vec):
  assert input_vec.ndim == 1
  activations = input_vec
  for W, b in params:
    outputs = jnp.dot(W, activations) + b  # `activations` on the right-hand side!
    activations = jnp.tanh(outputs)
  return outputs  # no activation on the last layer

We often instead write jnp.dot(inputs, W) to allow for a batch dimension on the left side of inputs, but we’ve written this particular prediction function to apply only to single input vectors. If we wanted to apply this function to a batch of inputs at once, semantically we could just write:

# Create some sample inputs & parameters
import numpy as np
k, N = 10, 5
input_batch = np.random.rand(k, N)
params = [
  (np.random.rand(N, N), np.random.rand(N)),
  (np.random.rand(N, N), np.random.rand(N)),
]
from functools import partial
predictions = jnp.stack(list(map(partial(predict, params), input_batch)))

But pushing one example through the network at a time would be slow! It’s better to vectorize the computation, so that at every layer we’re doing matrix-matrix multiplication rather than matrix-vector multiplication.

The vmap() function does that transformation for us. That is, if we write:

from jax import vmap
predictions = vmap(partial(predict, params))(input_batch)
# or, alternatively
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)

then the vmap() function will push the outer loop inside the function, and our machine will end up executing matrix-matrix multiplications exactly as if we’d done the batching by hand.
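
The following self-contained sketch repeats the setup above and checks that the mapped and vmapped versions agree. It also shows in_axes selecting a different batch axis. The variable names looped, batched, and batched_t are illustrative.

```python
from functools import partial
import numpy as np
import jax.numpy as jnp
from jax import vmap

def predict(params, input_vec):
  assert input_vec.ndim == 1
  activations = input_vec
  for W, b in params:
    outputs = jnp.dot(W, activations) + b
    activations = jnp.tanh(outputs)
  return outputs

k, N = 10, 5
input_batch = np.random.rand(k, N)
params = [
  (np.random.rand(N, N), np.random.rand(N)),
  (np.random.rand(N, N), np.random.rand(N)),
]

# One example at a time
looped = jnp.stack(list(map(partial(predict, params), input_batch)))

# in_axes=(None, 0): do not map over params, map over axis 0 of the inputs
batched = vmap(predict, in_axes=(None, 0))(params, input_batch)

# The batch axis can sit elsewhere: here the inputs are (N, k), mapped over axis 1
batched_t = vmap(predict, in_axes=(None, 1))(params, input_batch.T)
```

Note that the assert inside predict still passes under vmap(), because the function only ever sees the shape of a single example.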

It’s easy enough to manually batch a simple neural network without vmap(), but in other cases manual vectorization can be impractical or impossible. Take the problem of efficiently computing per-example gradients: that is, for a fixed set of parameters, we want to compute the gradient of our loss function evaluated separately at each example in a batch. With vmap(), it’s easy:

# create a sample loss function & inputs
def loss(params, x, y0):
  y = predict(params, x)
  return jnp.sum((y - y0) ** 2)

inputs = np.random.rand(k, N)
targets = np.random.rand(k, N)
per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets)
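
To see what comes back, the self-contained sketch below repeats the setup and inspects the result. The gradients have the same nested structure as params, with an extra leading axis of size k. The names dW0, db0, and single are illustrative.

```python
from functools import partial
import numpy as np
import jax.numpy as jnp
from jax import grad, vmap

def predict(params, input_vec):
  activations = input_vec
  for W, b in params:
    outputs = jnp.dot(W, activations) + b
    activations = jnp.tanh(outputs)
  return outputs

def loss(params, x, y0):
  y = predict(params, x)
  return jnp.sum((y - y0) ** 2)

k, N = 10, 5
params = [
  (np.random.rand(N, N), np.random.rand(N)),
  (np.random.rand(N, N), np.random.rand(N)),
]
inputs = np.random.rand(k, N)
targets = np.random.rand(k, N)

per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets)

# First layer: weights gain a batch axis -> (k, N, N); biases -> (k, N)
dW0, db0 = per_example_gradients[0]
print(dW0.shape, db0.shape)

# Reference: the gradient for the first example computed on its own
single = grad(loss)(params, inputs[0], targets[0])
```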

Of course, vmap() can be arbitrarily composed with jit(), grad(), and any other JAX transformation! We use vmap() with both forward- and reverse-mode automatic differentiation for fast Jacobian and Hessian matrix calculations in jax.jacfwd(), jax.jacrev(), and jax.hessian().
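
As a rough sketch of how vmap() and reverse mode fit together, a Jacobian can be assembled by pulling back every standard basis cotangent through jax.vjp() in one vmapped call. The helper my_jacrev below is hypothetical and simplified, and it assumes a function with 1-D input and 1-D output. It is not how jax.jacrev() is actually implemented in full.

```python
import jax
import jax.numpy as jnp
from jax import vmap

def my_jacrev(f):
  # Hypothetical, simplified stand-in for jax.jacrev (1-D outputs only)
  def jacobian(x):
    y, vjp_fun = jax.vjp(f, x)
    # One standard basis cotangent per output element
    basis = jnp.eye(y.size, dtype=y.dtype)
    # Each row of the Jacobian is a vector-Jacobian product; vmap batches them
    (J,) = vmap(vjp_fun)(basis)
    return J
  return jacobian

def f(x):
  return jnp.array([x[0] * x[1], jnp.sin(x[0]), x[1] ** 2])

x = jnp.array([1.0, 2.0])
J = my_jacrev(f)(x)  # shape (3, 2): one row per output, one column per input
```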

SPMD programming with pmap

For parallel programming of multiple accelerators, like multiple GPUs, use jax.pmap(). With pmap() you write single-program multiple-data (SPMD) programs, including fast parallel collective communication operations. Applying pmap() will mean that the function you write is compiled by XLA (similarly to jit()), then replicated and executed in parallel across devices.

Here’s an example on an 8-core machine:

from jax import random, pmap
import jax.numpy as jnp

# Create 8 random 5000 x 6000 matrices, one per core
keys = random.split(random.PRNGKey(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)

# Run a local matmul on each device in parallel (no data transfer)
result = pmap(lambda x: jnp.dot(x, x.T))(mats)  # result.shape is (8, 5000, 5000)

# Compute the mean on each device in parallel and print the result
print(pmap(jnp.mean)(result))
[1.1566514 1.1805066 1.2052315 1.2045699 1.1876893 1.2037915 1.2322364
 1.2015213]
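
The example above assumes exactly 8 devices. A variant that adapts to whatever hardware is present can query jax.local_device_count(), as in the sketch below. The smaller matrix size is an arbitrary choice to keep it quick on a CPU-only machine, where the device count is typically 1.

```python
import jax
import jax.numpy as jnp
from jax import random, pmap

n = jax.local_device_count()  # number of devices available to this process

# One PRNG key, and therefore one random matrix, per device
keys = random.split(random.PRNGKey(0), n)
mats = pmap(lambda key: random.normal(key, (500, 600)))(keys)

# The leading axis of every pmap input must not exceed the device count
result = pmap(lambda x: jnp.dot(x, x.T))(mats)  # shape (n, 500, 500)
means = pmap(jnp.mean)(result)                  # shape (n,)
print(means)
```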

In addition to expressing pure maps, you can use fast collective communication operations, such as lax.psum(), between devices:

from functools import partial
from jax import lax

@partial(pmap, axis_name='i')
def normalize(x):
  return x / lax.psum(x, 'i')

print(normalize(jnp.arange(4.)))
[0.         0.16666667 0.33333334 0.5       ]
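
The call above needs at least 4 devices, since the input has 4 elements along the mapped axis. The sketch below is a device-count-agnostic variant: it sizes the input from jax.local_device_count(), so it also runs on a single device (where the sum is over one element).

```python
from functools import partial
import jax
import jax.numpy as jnp
from jax import lax, pmap

n = jax.local_device_count()

@partial(pmap, axis_name='i')
def normalize(x):
  # lax.psum sums x across all devices participating in the axis named 'i'
  return x / lax.psum(x, 'i')

x = jnp.arange(1.0, n + 1.0)  # one scalar per device: 1, 2, ..., n
out = normalize(x)
print(out)                     # entries sum to 1 across devices
```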

You can even nest pmap functions for more sophisticated communication patterns.

It all composes, so you’re free to differentiate through parallel computations:

from jax import grad

@pmap
def f(x):
  y = jnp.sin(x)
  @pmap
  def g(z):
    return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()
  return grad(lambda w: jnp.sum(g(w)))(x)

x = jnp.arange(8.0).reshape(2, 4)
print(f(x))
[[  0.          6.883216    7.438035    1.1543589]
 [-12.774312  -16.18599    -4.7163434  11.089487 ]]
print(grad(lambda x: f(x).sum())(x))
[[-37.80803  -21.112206  18.502668  43.11701 ]
 [-76.852104  33.34725  112.88069   88.631256]]

When reverse-mode differentiating a pmap() function (e.g. with grad()), the backward pass of the computation is parallelized just like the forward pass.
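
A smaller sketch of the same idea that does not need nested pmap or a fixed device count: differentiate the sum of a pmapped function and compare with the known derivative. The function f here is a made-up example, separate from the one above.

```python
import jax
import jax.numpy as jnp
from jax import grad, pmap

n = jax.local_device_count()

@pmap
def f(x):
  # Runs independently on each device's slice of x
  return jnp.sin(x) ** 2

x = jnp.linspace(0.0, 1.0, n * 4).reshape(n, 4)

# Reverse-mode gradient through the parallel computation
g = grad(lambda x: f(x).sum())(x)

# d/dx sin(x)^2 = sin(2x)
print(jnp.allclose(g, jnp.sin(2.0 * x), atol=1e-5))
```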

See the SPMD Cookbook and the SPMD MNIST classifier from scratch example for more.