Transformations

At its core, JAX is an extensible system for transforming numerical functions. This section will discuss four that are of primary interest: `grad()`, `jit()`, `vmap()`, and `pmap()`.

Automatic differentiation with `grad`

JAX has roughly the same API as Autograd. The most popular function is `jax.grad()` for reverse-mode gradients:

```
from jax import grad
import jax.numpy as jnp

def tanh(x):  # Define a function
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)  # Obtain its gradient function
print(grad_tanh(1.0))   # Evaluate it at x = 1.0
```
```
0.4199743
```

You can differentiate to any order with `grad()`.

```
print(grad(grad(grad(tanh)))(1.0))
```
```
0.6216266
```
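One way to sanity-check these values is to compare them with the closed-form derivatives of tanh. The sketch below redefines the same `tanh` so that it runs on its own; the variable names and the comparison against `jnp.tanh` are choices made for this illustration, not part of the original example.

```python
from jax import grad
import jax.numpy as jnp

def tanh(x):
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

x = 1.0
t = jnp.tanh(x)

# Closed forms: tanh'(x) = 1 - t^2 and tanh'''(x) = (1 - t^2) * (6 t^2 - 2)
first_exact = 1.0 - t ** 2
third_exact = (1.0 - t ** 2) * (6.0 * t ** 2 - 2.0)

first_ad = grad(tanh)(x)
third_ad = grad(grad(grad(tanh)))(x)

print(first_ad, first_exact)
print(third_ad, third_exact)
```

Each pair of printed values should agree to within single-precision rounding.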

For more advanced autodiff, you can use `jax.vjp()` for reverse-mode vector-Jacobian products and `jax.jvp()` for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another, and with other JAX transformations. Here's one way to compose those to make a function that efficiently computes full Hessian matrices:

```
from jax import jit, jacfwd, jacrev

def hessian(fun):
  return jit(jacfwd(jacrev(fun)))
```
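As a quick usage sketch, the `hessian` helper can be applied to a function whose second derivatives are known in closed form. The test function `cubic` is hypothetical and chosen only because its Hessian is the diagonal matrix with entries `6 * x`.

```python
import jax.numpy as jnp
from jax import jit, jacfwd, jacrev

def hessian(fun):
  return jit(jacfwd(jacrev(fun)))

def cubic(x):  # hypothetical test function: f(x) = sum(x_i ** 3)
  return jnp.sum(x ** 3)

x = jnp.array([1.0, 2.0, 3.0])
H = hessian(cubic)(x)
print(H)  # should match jnp.diag(6 * x)
```

Here `jacrev` computes the gradient in reverse mode and `jacfwd` differentiates that gradient in forward mode, which is typically an efficient ordering for scalar-valued functions.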

As with Autograd, you're free to use differentiation with Python control structures:

```
def abs_val(x):
  if x > 0:
    return x
  else:
    return -x

abs_val_grad = grad(abs_val)
print(abs_val_grad(1.0))
```
```
1.0
```
```
print(abs_val_grad(-1.0))
```
```
-1.0
```
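The same holds for loops. The following sketch differentiates a function built from an ordinary Python `for` loop; the function `power` and its integer argument `n` are hypothetical names introduced for this illustration.

```python
from jax import grad

def power(x, n):
  result = 1.0
  for _ in range(n):  # plain Python loop; n is an ordinary int
    result = result * x
  return result

dpower = grad(power)  # differentiates with respect to x, the first argument
print(dpower(2.0, 3))  # d/dx x^3 = 3 x^2, evaluated at x = 2
```

By default `grad()` differentiates only with respect to the first argument, so `n` is passed through unchanged and can drive the loop.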

See the reference docs on automatic differentiation and the JAX Autodiff Cookbook for more.

Compilation with `jit`

You can use XLA to compile your functions end-to-end with `jax.jit()`, used either as an `@jit` decorator or as a higher-order function.

```
import jax.numpy as jnp
from jax import jit

def slow_f(x):
  # Element-wise ops see a large benefit from fusion
  return x * x + x * 2.0

x = jnp.ones((5000, 5000))
%timeit slow_f(x).block_until_ready()  # IPython magic; waits for the async result
```
```
135 ms ± 1.19 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
```
fast_f = jit(slow_f)

# Results are the same
assert jnp.allclose(slow_f(x), fast_f(x))

%timeit fast_f(x).block_until_ready()  # IPython magic; waits for the async result
```
```
45.2 ms ± 415 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```

You can mix `jit()` and `grad()` and any other JAX transformation however you like.
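As a small sketch of that composability, the snippet below uses `jit()` as a decorator and then combines it with `grad()` in both orders. The variable names are illustrative only.

```python
import jax.numpy as jnp
from jax import grad, jit

@jit  # decorator form
def tanh(x):
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_then_jit = jit(grad(tanh))  # compile the gradient function
jit_then_grad = grad(jit(tanh))  # differentiate a compiled function

a = grad_then_jit(1.0)
b = jit_then_grad(1.0)
print(a, b)
```

Both orderings should produce the same derivative up to floating-point rounding.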

Using `jit()` puts constraints on the kind of Python control flow the function can use; see 🔪 JAX - The Sharp Bits 🔪 for more.
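To make the constraint concrete, here is a hedged sketch: a Python `if` that branches on a traced value fails under `jit()`, while `jnp.where` or `lax.cond` express the same logic in a traceable way. The exact exception class may differ between JAX versions; the code assumes only that it subclasses `TypeError`. The function names are hypothetical.

```python
import jax.numpy as jnp
from jax import jit, lax

def abs_val(x):
  if x > 0:  # branches on the value of x, which is abstract under jit
    return x
  return -x

try:
  jit(abs_val)(1.0)
  raised = False
except TypeError:  # assumption: the tracer error subclasses TypeError
  raised = True

@jit
def abs_val_where(x):
  return jnp.where(x > 0, x, -x)  # evaluates both sides, then selects

@jit
def abs_val_cond(x):
  return lax.cond(x > 0, lambda v: v, lambda v: -v, x)  # structured conditional

print(raised, abs_val_where(-2.0), abs_val_cond(-3.0))
```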

Auto-vectorization with `vmap`

`jax.vmap()` is the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a function's primitive operations for better performance.
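A minimal sketch of those semantics: a function written for a single vector is mapped over the leading axis of a matrix, once with an explicit Python loop and once with `vmap()`. The function `squared_norm` and the sample data are made up for this illustration.

```python
import jax.numpy as jnp
from jax import vmap

def squared_norm(v):  # written for a single vector
  return jnp.dot(v, v)

xs = jnp.arange(12.0).reshape(4, 3)  # a batch of 4 vectors

looped = jnp.stack([squared_norm(v) for v in xs])  # loop on the outside
batched = vmap(squared_norm)(xs)                   # loop pushed into the primitives

print(looped)
print(batched)
```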

Using `vmap()` can save you from having to carry around batch dimensions in your code. For example, consider this simple unbatched neural network prediction function:

```
def predict(params, input_vec):
  assert input_vec.ndim == 1
  activations = input_vec
  for W, b in params:
    outputs = jnp.dot(W, activations) + b  # `activations` on the right-hand side!
    activations = jnp.tanh(outputs)
  return outputs
```

We often instead write `jnp.dot(inputs, W)` to allow for a batch dimension on the left side of `inputs`, but we've written this particular prediction function to apply only to single input vectors. If we wanted to apply this function to a batch of inputs at once, semantically we could just write

```
# Create some sample inputs & parameters
import numpy as np
k, N = 10, 5
input_batch = np.random.rand(k, N)
params = [
  (np.random.rand(N, N), np.random.rand(N)),
  (np.random.rand(N, N), np.random.rand(N)),
]
```
```
from functools import partial
predictions = jnp.stack(list(map(partial(predict, params), input_batch)))
```

But pushing one example through the network at a time would be slow! It's better to vectorize the computation, so that at every layer we're doing matrix-matrix multiplication rather than matrix-vector multiplication.

The `vmap()` function does that transformation for us. That is, if we write:

```
from jax import vmap
predictions = vmap(partial(predict, params))(input_batch)
# or, alternatively
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
```

then the `vmap()` function will push the outer loop inside the function, and our machine will end up executing matrix-matrix multiplications exactly as if we'd done the batching by hand.
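For comparison, a hand-batched version might look like the sketch below. The function `predict_batched` is a hypothetical rewrite that puts the batch dimension on the left, and the seeded random inputs are assumptions made so the snippet runs on its own.

```python
import numpy as np
import jax.numpy as jnp
from jax import vmap

def predict(params, input_vec):
  assert input_vec.ndim == 1
  activations = input_vec
  for W, b in params:
    outputs = jnp.dot(W, activations) + b
    activations = jnp.tanh(outputs)
  return outputs

def predict_batched(params, inputs):  # hypothetical hand-batched variant
  activations = inputs                # shape (batch, N)
  for W, b in params:
    outputs = jnp.dot(activations, W.T) + b  # batch dimension on the left
    activations = jnp.tanh(outputs)
  return outputs

k, N = 10, 5
rng = np.random.default_rng(0)
input_batch = rng.random((k, N))
params = [(rng.random((N, N)), rng.random(N)) for _ in range(2)]

via_vmap = vmap(predict, in_axes=(None, 0))(params, input_batch)
via_hand = predict_batched(params, input_batch)
print(jnp.max(jnp.abs(via_vmap - via_hand)))  # difference should be tiny
```

The two results should agree up to floating-point rounding, while `predict` itself never had to mention a batch axis.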

It's easy enough to manually batch a simple neural network without `vmap()`, but in other cases manual vectorization can be impractical or impossible. Take the problem of efficiently computing per-example gradients: that is, for a fixed set of parameters, we want to compute the gradient of our loss function evaluated separately at each example in a batch. With `vmap()`, it's easy:

```
# create a sample loss function & inputs
def loss(params, x, y0):
  y = predict(params, x)
  return jnp.sum((y - y0) ** 2)

inputs = np.random.rand(k, N)
targets = np.random.rand(k, N)
```
```
per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets)
```
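To see what this returns, the self-contained sketch below inspects the result: it has the same nested structure as `params`, with an extra leading axis of size `k` on every array. The seeded random data and the single-example reference gradient are assumptions added for illustration.

```python
from functools import partial
import numpy as np
import jax.numpy as jnp
from jax import grad, vmap

def predict(params, input_vec):
  activations = input_vec
  for W, b in params:
    outputs = jnp.dot(W, activations) + b
    activations = jnp.tanh(outputs)
  return outputs

def loss(params, x, y0):
  y = predict(params, x)
  return jnp.sum((y - y0) ** 2)

k, N = 10, 5
rng = np.random.default_rng(0)
params = [(rng.random((N, N)), rng.random(N)) for _ in range(2)]
inputs = rng.random((k, N))
targets = rng.random((k, N))

per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets)

# One (dW, db) pair per layer, each with a leading batch axis of size k
for dW, db in per_example_gradients:
  print(dW.shape, db.shape)

# Reference: first-layer gradient for example 0, computed on its own
dW0, db0 = grad(loss)(params, inputs[0], targets[0])[0]
```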

Of course, `vmap()` can be arbitrarily composed with `jit()`, `grad()`, and any other JAX transformation! We use `vmap()` with both forward- and reverse-mode automatic differentiation for fast Jacobian and Hessian matrix calculations in `jax.jacfwd()`, `jax.jacrev()`, and `jax.hessian()`.
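The idea behind a reverse-mode Jacobian can be sketched in a few lines: take one VJP function and use `vmap()` to pull back every standard basis vector at once. This is an illustrative simplification under the name `jacrev_sketch` (hypothetical), not the actual implementation of `jax.jacrev()`; the test function `f` is also made up.

```python
import jax.numpy as jnp
from jax import jacrev, vjp, vmap

def jacrev_sketch(f):  # hypothetical name; a simplified sketch only
  def jacfun(x):
    y, vjp_fun = vjp(f, x)
    # Pull back each standard basis vector of the output space in one batch
    J, = vmap(vjp_fun)(jnp.eye(len(y)))
    return J
  return jacfun

def f(x):  # hypothetical vector-valued test function
  return jnp.sin(x) * x[0]

x = jnp.array([1.0, 2.0, 3.0])
J_sketch = jacrev_sketch(f)(x)
J_ref = jacrev(f)(x)
print(jnp.max(jnp.abs(J_sketch - J_ref)))  # difference should be tiny
```

Each row of the Jacobian is one vector-Jacobian product, so batching the VJP over the identity matrix builds all rows without a Python loop.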

SPMD programming with `pmap`

For parallel programming of multiple accelerators, like multiple GPUs, use `jax.pmap()`. With `pmap()` you write single-program multiple-data (SPMD) programs, including fast parallel collective communication operations. Applying `pmap()` will mean that the function you write is compiled by XLA (similarly to `jit()`), then replicated and executed in parallel across devices.

Here's an example on an 8-core machine:

```
from jax import random, pmap
import jax.numpy as jnp

# Create 8 random 5000 x 6000 matrices, one per core
keys = random.split(random.PRNGKey(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)

# Run a local matmul on each device in parallel (no data transfer)
result = pmap(lambda x: jnp.dot(x, x.T))(mats)  # result.shape is (8, 5000, 5000)

# Compute the mean on each device in parallel and print the result
print(pmap(jnp.mean)(result))
```
```
[1.1566514 1.1805066 1.2052315 1.2045699 1.1876893 1.2037915 1.2322364
 1.2015213]
```
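The example above assumes eight devices. A variant that adapts to whatever is available might look like the following sketch, which sizes the leading axis with `jax.local_device_count()` and uses much smaller matrices; both choices are assumptions made so that it also runs on a single-device machine.

```python
import jax
import jax.numpy as jnp
from jax import random, pmap

n = jax.local_device_count()  # the mapped axis must not exceed this

# One small random matrix per device
keys = random.split(random.PRNGKey(0), n)
mats = pmap(lambda key: random.normal(key, (50, 60)))(keys)

# A local matmul on each device, then a per-device mean
result = pmap(lambda x: jnp.dot(x, x.T))(mats)
means = pmap(jnp.mean)(result)

print(mats.shape, result.shape, means.shape)
```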

In addition to expressing pure maps, you can use fast collective communication operations between devices:

```
from functools import partial
from jax import lax

@partial(pmap, axis_name='i')
def normalize(x):
  return x / lax.psum(x, 'i')

print(normalize(jnp.arange(4.)))
```
```
[0.         0.16666667 0.33333334 0.5       ]
```
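The output above requires four devices. As a hedged, device-count-agnostic sketch, the same `normalize` can be fed one positive value per available device, after which the results should sum to one. The input values are chosen arbitrarily for this illustration.

```python
from functools import partial
import jax
import jax.numpy as jnp
from jax import lax, pmap

@partial(pmap, axis_name='i')
def normalize(x):
  return x / lax.psum(x, 'i')  # psum adds x across all devices on axis 'i'

n = jax.local_device_count()
x = jnp.arange(1.0, n + 1.0)  # one positive value per device
out = normalize(x)
print(out, out.sum())
```

Starting the values at one rather than zero avoids a division by zero when only a single device is present.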

You can even nest `pmap` functions for more sophisticated communication patterns.

It all composes, so you're free to differentiate through parallel computations:

```
from jax import grad

@pmap
def f(x):
  y = jnp.sin(x)
  @pmap
  def g(z):
    return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()
  return grad(lambda w: jnp.sum(g(w)))(x)

x = jnp.arange(8.0).reshape(2, 4)
print(f(x))
```
```
[[  0.          6.883216    7.438035    1.1543589]
 [-12.774312  -16.18599    -4.7163434  11.089487 ]]
```
```
print(grad(lambda x: f(x).sum())(x))
```
```
[[-37.80803  -21.112206  18.502668  43.11701 ]
 [-76.852104  33.34725  112.88069   88.631256]]
```

When reverse-mode differentiating a `pmap()` function (e.g. with `grad()`), the backward pass of the computation is parallelized just like the forward pass.
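A smaller check of the same idea, sketched under the assumption that any number of local devices may be present: the gradient of a summed `pmap` of `jnp.sin` should equal `jnp.cos` of the input. The helper name `total` is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import grad, pmap

n = jax.local_device_count()
x = jnp.arange(float(n))  # one scalar per device

def total(x):
  return pmap(jnp.sin)(x).sum()  # parallel forward pass, then a reduction

g = grad(total)(x)
print(g, jnp.cos(x))
```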

See the SPMD Cookbook and the SPMD MNIST classifier from scratch example for more.