# Transformations

At its core, JAX is an extensible system for transforming numerical functions.
This section discusses four transformations of primary interest: `grad()`, `jit()`, `vmap()`, and `pmap()`.

## Automatic differentiation with `grad`


JAX has roughly the same API as Autograd.
The most popular function is `jax.grad()` for reverse-mode gradients:

```
from jax import grad
import jax.numpy as jnp

def tanh(x):  # Define a function
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)  # Obtain its gradient function
print(grad_tanh(1.0))   # Evaluate it at x = 1.0
```

```
0.4199743
```
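As a quick sanity check, the value above can be compared with the closed-form derivative of tanh, which is 1 - tanh(x)^2. The following is a minimal sketch; the helper name `analytic_grad_tanh` is introduced here for illustration and is not part of JAX.

```python
from jax import grad
import jax.numpy as jnp

def tanh(x):
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

def analytic_grad_tanh(x):
    # Hypothetical helper: closed-form derivative d/dx tanh(x) = 1 - tanh(x)^2
    return 1.0 - jnp.tanh(x) ** 2

auto = grad(tanh)(1.0)           # derivative computed by automatic differentiation
exact = analytic_grad_tanh(1.0)  # derivative computed from the formula
print(auto, exact)
```

The two values should agree up to floating-point rounding.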

You can differentiate to any order by nesting calls to `grad()`:

```
print(grad(grad(grad(tanh)))(1.0))
```

```
0.6216266
```

For more advanced autodiff, you can use `jax.vjp()` for reverse-mode vector-Jacobian products and `jax.jvp()` for forward-mode Jacobian-vector products.
The two can be composed arbitrarily with one another, and with other JAX transformations.
Here's one way to compose them to make a function that efficiently computes full Hessian matrices:

```
from jax import jit, jacfwd, jacrev

def hessian(fun):
    return jit(jacfwd(jacrev(fun)))
```
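A short usage sketch may help here. The test function `cubic` is an assumption chosen for illustration, because its Hessian is known in closed form: for f(x) = sum(x_i^3) the Hessian is the diagonal matrix with entries 6 * x_i.

```python
import jax.numpy as jnp
from jax import jit, jacfwd, jacrev

def hessian(fun):
    # Forward-mode Jacobian of the reverse-mode gradient
    return jit(jacfwd(jacrev(fun)))

def cubic(x):
    # Illustrative scalar-valued function: f(x) = sum(x_i^3)
    return jnp.sum(x ** 3)

x = jnp.array([1.0, 2.0, 3.0])
H = hessian(cubic)(x)
print(H)  # should match jnp.diag(6.0 * x)
```

Forward-over-reverse is a common choice because the inner function is scalar-valued, where reverse mode is cheap, while the outer differentiation acts on a function with as many outputs as inputs.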

As with Autograd, you're free to use differentiation with Python control structures:

```
def abs_val(x):
    if x > 0:
        return x
    else:
        return -x

abs_val_grad = grad(abs_val)
print(abs_val_grad(1.0))
```

```
1.0
```

```
print(abs_val_grad(-1.0))
```

```
-1.0
```

See the reference docs on automatic differentiation and the JAX Autodiff Cookbook for more.
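To make the `jax.jvp()` and `jax.vjp()` functions mentioned earlier in this section more concrete, here is a minimal sketch. The function `f` and the vectors `v` and `u` are illustrative assumptions; the full Jacobian is computed only to check the two products.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative map from R^3 to R^2
    return jnp.array([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 3.0])

# Forward mode: push a tangent vector v through f, giving J @ v
v = jnp.array([1.0, 0.0, 0.0])
y, jv = jax.jvp(f, (x,), (v,))

# Reverse mode: pull a cotangent vector u back through f, giving u @ J
u = jnp.array([0.0, 1.0])
y2, vjp_fun = jax.vjp(f, x)
(uj,) = vjp_fun(u)  # one cotangent per primal input

J = jax.jacfwd(f)(x)  # full 2 x 3 Jacobian, for comparison only
print(jv, uj)
```

Neither product materializes the full Jacobian, which is why these two functions are the building blocks for larger derivative computations.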

## Compilation with `jit`


You can use XLA to compile your functions end-to-end with `jax.jit()`, used either as an `@jit` decorator or as a higher-order function.

```
import jax.numpy as jnp
from jax import jit

def slow_f(x):
    # Element-wise ops see a large benefit from fusion
    return x * x + x * 2.0

x = jnp.ones((5000, 5000))
%timeit slow_f(x).block_until_ready()
```

```
138 ms ± 544 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```

```
fast_f = jit(slow_f)
# Results are the same
assert jnp.allclose(slow_f(x), fast_f(x))
%timeit fast_f(x).block_until_ready()
```

```
43.4 ms ± 280 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```

You can mix `jit()` and `grad()` and any other JAX transformation however you like.
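As a small illustration of that composability, the sketch below applies the two transformations in both orders to a function of the same form as `slow_f`. The names `df`, `fast_df`, and `df_of_fast` are chosen here for illustration.

```python
import jax.numpy as jnp
from jax import grad, jit

def f(x):
    return x * x + x * 2.0

df = grad(f)                # derivative of f, which is 2x + 2
fast_df = jit(grad(f))      # compile the derivative
df_of_fast = grad(jit(f))   # differentiate the compiled function

print(df(3.0), fast_df(3.0), df_of_fast(3.0))
```

All three calls should return the same value, since the order of composition changes how the computation is staged but not its result.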

Using `jit()` puts constraints on the kind of Python control flow the function can use; see 🔪 JAX - The Sharp Bits 🔪 for more.
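For instance, a Python `if` that branches on the value of a traced argument, as in `abs_val` above, is rejected under `jit()` with default settings. The sketch below shows two possible rewrites that express the branch inside the traced computation; the function names are illustrative, and these are not the only options.

```python
import jax.numpy as jnp
from jax import jit, lax

@jit
def abs_val_where(x):
    # Select between the two candidate values instead of branching in Python
    return jnp.where(x > 0, x, -x)

@jit
def abs_val_cond(x):
    # lax.cond stages both branches and picks one at run time
    return lax.cond(x > 0, lambda v: v, lambda v: -v, x)

print(abs_val_where(-3.0), abs_val_cond(-3.0))
```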

## Auto-vectorization with `vmap`


`jax.vmap()` is the vectorizing map.
It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a function's primitive operations for better performance.
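A minimal sketch of those semantics, using an illustrative function `dot` written for a single pair of vectors: mapping it with `vmap()` gives the same values as an explicit Python loop over the leading axis.

```python
import jax.numpy as jnp
from jax import vmap

def dot(a, b):
    # Written for a single pair of 1-D vectors
    return jnp.dot(a, b)

a = jnp.arange(6.0).reshape(3, 2)  # three vectors of length 2
b = jnp.ones((3, 2))

batched = vmap(dot)(a, b)  # maps over axis 0 of both inputs
looped = jnp.stack([dot(ai, bi) for ai, bi in zip(a, b)])
print(batched, looped)
```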

Using `vmap()` can save you from having to carry around batch dimensions in your code.
For example, consider this simple *unbatched* neural network prediction function:

```
def predict(params, input_vec):
    assert input_vec.ndim == 1
    activations = input_vec
    for W, b in params:
        outputs = jnp.dot(W, activations) + b  # `input_vec` on the right-hand side!
        activations = jnp.tanh(outputs)
    return outputs
```

We often instead write `jnp.dot(inputs, W)` to allow for a batch dimension on the left side of `inputs`, but we've written this particular prediction function to apply only to single input vectors.
If we wanted to apply this function to a batch of inputs at once, semantically we could just write:

```
# Create some sample inputs & parameters
import numpy as np
k, N = 10, 5
input_batch = np.random.rand(k, N)
params = [
    (np.random.rand(N, N), np.random.rand(N)),
    (np.random.rand(N, N), np.random.rand(N)),
]
```

```
from functools import partial
predictions = jnp.stack(list(map(partial(predict, params), input_batch)))
```

But pushing one example through the network at a time would be slow! It's better to vectorize the computation, so that at every layer we're doing matrix-matrix multiplication rather than matrix-vector multiplication.

The `vmap()` function does that transformation for us. That is, if we write:

```
from jax import vmap
predictions = vmap(partial(predict, params))(input_batch)
# or, alternatively
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
```

then the `vmap()` function will push the outer loop inside the function, and our machine will end up executing matrix-matrix multiplications exactly as if we'd done the batching by hand.
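The equivalence between the looped and the vectorized version can be checked numerically. The sketch below repeats `predict` so that it runs on its own; the fixed random seed is an assumption added for reproducibility.

```python
import numpy as np
import jax.numpy as jnp
from jax import vmap

def predict(params, input_vec):
    assert input_vec.ndim == 1
    activations = input_vec
    for W, b in params:
        outputs = jnp.dot(W, activations) + b
        activations = jnp.tanh(outputs)
    return outputs

rng = np.random.default_rng(0)  # fixed seed, an assumption for reproducibility
k, N = 10, 5
input_batch = rng.random((k, N))
params = [(rng.random((N, N)), rng.random(N)) for _ in range(2)]

# One example at a time, then stacked into a (k, N) array
looped = jnp.stack([predict(params, x) for x in input_batch])
# The whole batch at once; `params` is shared, `input_batch` is mapped over axis 0
batched = vmap(predict, in_axes=(None, 0))(params, input_batch)
print(looped.shape, batched.shape)
```

Note that the `assert input_vec.ndim == 1` inside `predict` still holds under `vmap()`, because the function continues to see a single unbatched example.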

It's easy enough to manually batch a simple neural network without `vmap()`, but in other cases manual vectorization can be impractical or impossible.
Take the problem of efficiently computing per-example gradients: that is, for a fixed set of parameters, we want to compute the gradient of our loss function evaluated separately at each example in a batch.
With `vmap()`, it's easy:

```
# create a sample loss function & inputs
def loss(params, x, y0):
    y = predict(params, x)
    return jnp.sum((y - y0) ** 2)

inputs = np.random.rand(k, N)
targets = np.random.rand(k, N)
```

```
per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets)
```
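It can help to inspect what comes back. The sketch below repeats the definitions so that it runs on its own, with a fixed seed as an added assumption. The result has the same nested structure as `params`, and each array gains a leading batch axis of size `k`.

```python
from functools import partial
import numpy as np
import jax.numpy as jnp
from jax import grad, vmap

def predict(params, input_vec):
    activations = input_vec
    for W, b in params:
        outputs = jnp.dot(W, activations) + b
        activations = jnp.tanh(outputs)
    return outputs

def loss(params, x, y0):
    y = predict(params, x)
    return jnp.sum((y - y0) ** 2)

rng = np.random.default_rng(0)  # fixed seed, an assumption for reproducibility
k, N = 10, 5
params = [(rng.random((N, N)), rng.random(N)) for _ in range(2)]
inputs = rng.random((k, N))
targets = rng.random((k, N))

per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets)

# Gradients for the first layer: one (N, N) matrix and one (N,) vector per example
dW0, db0 = per_example_gradients[0]
print(dW0.shape, db0.shape)  # (10, 5, 5) (10, 5)

# Gradient for the first example alone, computed without vmap
single = grad(loss)(params, inputs[0], targets[0])
```

Slicing index 0 out of the batched result should reproduce `single`, the gradient computed for the first example on its own.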

Of course, `vmap()` can be arbitrarily composed with `jit()`, `grad()`, and any other JAX transformation!
We use `vmap()` with both forward- and reverse-mode automatic differentiation for fast Jacobian and Hessian matrix calculations in `jax.jacfwd()`, `jax.jacrev()`, and `jax.hessian()`.
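A brief sketch of those three functions in use follows. The functions `f` and `g` are illustrative assumptions; the point is only that the forward- and reverse-mode Jacobians agree and that `jax.hessian()` returns a square, symmetric matrix for a scalar-valued function.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative map from R^3 to R^3
    return jnp.sin(x) * jnp.sum(x ** 2)

def g(x):
    # Scalar-valued function built from f, so that it has a Hessian matrix
    return jnp.sum(f(x))

x = jnp.array([0.5, 1.0, 1.5])

J_fwd = jax.jacfwd(f)(x)  # Jacobian built column by column (forward mode)
J_rev = jax.jacrev(f)(x)  # Jacobian built row by row (reverse mode)
H = jax.hessian(g)(x)     # 3 x 3 matrix of second derivatives
print(J_fwd.shape, J_rev.shape, H.shape)
```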

## SPMD programming with `pmap`


For parallel programming of multiple accelerators, like multiple GPUs, use `jax.pmap()`.
With `pmap()` you write single-program multiple-data (SPMD) programs, including fast parallel collective communication operations.
Applying `pmap()` means that the function you write is compiled by XLA (similarly to `jit()`), then replicated and executed in parallel across devices.

Here's an example on an 8-core machine:

```
from jax import random, pmap
import jax.numpy as jnp
# Create 8 random 5000 x 6000 matrices, one per core
keys = random.split(random.PRNGKey(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)
# Run a local matmul on each device in parallel (no data transfer)
result = pmap(lambda x: jnp.dot(x, x.T))(mats) # result.shape is (8, 5000, 5000)
# Compute the mean on each device in parallel and print the result
print(pmap(jnp.mean)(result))
```

```
[1.1566514 1.1805066 1.2052315 1.2045699 1.1876893 1.2037915 1.2322364
1.2015213]
```

In addition to expressing pure maps, you can use fast collective communication operations between devices:

```
from functools import partial
from jax import lax

@partial(pmap, axis_name='i')
def normalize(x):
    return x / lax.psum(x, 'i')

print(normalize(jnp.arange(4.)))
```

```
[0. 0.16666667 0.33333334 0.5 ]
```
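The example above assumes at least four devices, since the mapped axis cannot be larger than the number of available devices. As a hedged variant that should also run on a machine with a single device, the sketch below sizes the input from `jax.local_device_count()`; starting the values at 1 instead of 0 is a change made here to avoid dividing by zero when only one device is present.

```python
from functools import partial
import jax
import jax.numpy as jnp
from jax import lax, pmap

n = jax.local_device_count()  # often 1 on a CPU-only machine

@partial(pmap, axis_name='i')
def normalize(x):
    # psum adds up x across all devices along the named axis 'i'
    return x / lax.psum(x, 'i')

x = jnp.arange(1.0, n + 1.0)  # one scalar per device: 1, 2, ..., n
out = normalize(x)
print(out, out.sum())
```

Whatever the device count, the normalized values should sum to 1.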

You can even nest `pmap` functions for more sophisticated communication patterns.

It all composes, so you're free to differentiate through parallel computations:

```
from jax import grad

@pmap
def f(x):
    y = jnp.sin(x)

    @pmap
    def g(z):
        return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()

    return grad(lambda w: jnp.sum(g(w)))(x)

x = jnp.arange(8.0).reshape(2, 4)
print(f(x))
```

```
[[ 0. 6.883216 7.438035 1.1543589]
[-12.774312 -16.18599 -4.7163434 11.089487 ]]
```

```
print(grad(lambda x: f(x).sum())(x))
```

```
[[-37.80803 -21.112206 18.502668 43.11701 ]
[-76.852104 33.34725 112.88069 88.631256]]
```

When reverse-mode differentiating a `pmap()` function (e.g. with `grad()`), the backward pass of the computation is parallelized just like the forward pass.

See the SPMD Cookbook and the SPMD MNIST classifier from scratch example for more.