# Advanced Automatic Differentiation in JAX

*Authors: Vlatimir Mikulik & Matteo Hessel*

Computing gradients is a critical part of modern machine learning methods. This section considers a few advanced topics in the areas of automatic differentiation as it relates to modern machine learning.

While understanding how automatic differentiation works under the hood isn’t crucial for using JAX in most contexts, we encourage the reader to check out this quite accessible video to get a deeper sense of what’s going on.

The Autodiff Cookbook is a more advanced and more detailed explanation of how these ideas are implemented in the JAX backend. It’s not necessary to understand this to do most things in JAX. However, some features (like defining custom derivatives) depend on understanding this, so it’s worth knowing this explanation exists if you ever need to use them.

## Higher-order derivatives

JAX’s autodiff makes it easy to compute higher-order derivatives, because the functions that compute derivatives are themselves differentiable. Thus, higher-order derivatives are as easy as stacking transformations.

We illustrate this in the single-variable case:

The derivative of \(f(x) = x^3 + 2x^2 - 3x + 1\) can be computed as:

```
import jax
f = lambda x: x**3 + 2*x**2 - 3*x + 1
dfdx = jax.grad(f)
```

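The higher-order derivatives of \(f\) are:

\[
\begin{array}{l}
f'(x) = 3x^2 + 4x - 3\\
f''(x) = 6x + 4\\
f'''(x) = 6\\
f^{iv}(x) = 0
\end{array}
\]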

Computing any of these in JAX is as easy as chaining the `grad` function:

```
d2fdx = jax.grad(dfdx)
d3fdx = jax.grad(d2fdx)
d4fdx = jax.grad(d3fdx)
```

Evaluating the above at \(x=1\) would give us:

\[
\begin{array}{l}
f'(1) = 4\\
f''(1) = 10\\
f'''(1) = 6\\
f^{iv}(1) = 0
\end{array}
\]

Using JAX:

```
print(dfdx(1.))
print(d2fdx(1.))
print(d3fdx(1.))
print(d4fdx(1.))
```

```
4.0
10.0
6.0
0.0
```
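The chaining pattern can also be wrapped in a small helper. The sketch below is only an illustration: `nth_derivative` is a hypothetical name rather than a JAX function, and the hand-written `analytic` derivatives are there purely to check the result at a second point, \(x=2\).

```python
import jax

def nth_derivative(f, n):
    """Return the n-th derivative of a scalar function by applying jax.grad n times."""
    for _ in range(n):
        f = jax.grad(f)
    return f

f = lambda x: x**3 + 2*x**2 - 3*x + 1

# Derivatives of f worked out by hand, used only as a reference.
analytic = [
    lambda x: 3*x**2 + 4*x - 3,
    lambda x: 6*x + 4,
    lambda x: 6.0,
    lambda x: 0.0,
]

x = 2.0
derivs = [float(nth_derivative(f, n)(x)) for n in range(1, 5)]
expected = [float(g(x)) for g in analytic]
print(derivs)
print(expected)
```

Each call to `jax.grad` builds a new function, so the loop simply stacks transformations; nothing is evaluated until the final function is called.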

In the multivariable case, higher-order derivatives are more complicated. The second-order derivative of a function is represented by its Hessian matrix, defined according to

\[(\mathbf{H}f)_{i,j} = \frac{\partial^2 f}{\partial_i\partial_j}.\]

The Hessian of a real-valued function of several variables, \(f: \mathbb R^n\to\mathbb R\), can be identified with the Jacobian of its gradient. JAX provides two transformations for computing the Jacobian of a function, `jax.jacfwd` and `jax.jacrev`, corresponding to forward- and reverse-mode autodiff. They give the same answer, but one can be more efficient than the other in different circumstances; see the video about autodiff linked above for an explanation.

```
def hessian(f):
  return jax.jacfwd(jax.grad(f))
```

Let’s double check this is correct on the dot-product \(f: \mathbf{x} \mapsto \mathbf{x} ^\top \mathbf{x}\).

If \(i=j\), \(\frac{\partial^2 f}{\partial_i\partial_j}(\mathbf{x}) = 2\). Otherwise, \(\frac{\partial^2 f}{\partial_i\partial_j}(\mathbf{x}) = 0\).

```
import jax.numpy as jnp

def f(x):
  return jnp.dot(x, x)

hessian(f)(jnp.array([1., 2., 3.]))
```

```
DeviceArray([[2., 0., 0.],
             [0., 2., 0.],
             [0., 0., 2.]], dtype=float32)
```
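To see that the two Jacobian transformations agree, the following sketch computes the Hessian of a small test function with both and compares them with `jax.hessian`. The test function and the variable names are illustrative choices, not part of the original example.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative scalar function with a non-diagonal Hessian.
    return jnp.sum(x**3) + x[0] * x[1]

x = jnp.array([1.0, 2.0, 3.0])

hess_fwd_over_rev = jax.jacfwd(jax.grad(f))(x)  # forward-mode Jacobian of the gradient
hess_rev_over_rev = jax.jacrev(jax.grad(f))(x)  # reverse-mode Jacobian of the gradient
hess_builtin = jax.hessian(f)(x)                # JAX's convenience wrapper

print(hess_fwd_over_rev)
print(jnp.allclose(hess_fwd_over_rev, hess_rev_over_rev))
print(jnp.allclose(hess_fwd_over_rev, hess_builtin))
```

For this function the diagonal entries are \(6x_i\) and the only non-zero off-diagonal entries are the two coupling \(x_0\) and \(x_1\), which gives a quick way to check the output by hand.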

Often, however, we aren’t interested in computing the full Hessian itself, and doing so can be very inefficient. The Autodiff Cookbook explains some tricks, like the Hessian-vector product, that let you use the Hessian without materialising the whole matrix.
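As a rough illustration of the Hessian-vector product, the sketch below uses forward-over-reverse differentiation via `jax.jvp` applied to `jax.grad(f)`. The helper name `hvp` and the test function are assumptions made for this example; the full Hessian is built here only to check the matrix-free result.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative scalar-valued function with a non-trivial Hessian.
    return jnp.sum(jnp.sin(x) * x**2)

def hvp(f, x, v):
    """Hessian-vector product H(x) @ v without forming H."""
    # jax.jvp pushes the tangent v through grad(f); the second output is H(x) @ v.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([0.5, -1.0, 2.0])

full = jax.jacfwd(jax.grad(f))(x) @ v  # materialises the 3x3 Hessian first
matrix_free = hvp(f, x, v)             # never builds the matrix

print(full)
print(matrix_free)
```

The matrix-free version costs a small constant multiple of one gradient evaluation, whereas the full Hessian grows quadratically with the number of inputs.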

If you plan to work with higher-order derivatives in JAX, we strongly recommend reading the Autodiff Cookbook.

## Higher-order optimization

Some meta-learning techniques, such as Model-Agnostic Meta-Learning (MAML), require differentiating through gradient updates. In other frameworks this can be quite cumbersome, but in JAX it’s much easier:

```
def meta_loss_fn(params, data):
  """Computes the loss after one step of SGD."""
  grads = jax.grad(loss_fn)(params, data)
  return loss_fn(params - lr * grads, data)

meta_grads = jax.grad(meta_loss_fn)(params, data)
```
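The snippet above leaves `loss_fn`, `lr`, `params` and `data` undefined. A minimal runnable sketch follows, assuming a least-squares linear regression loss, a small step size and toy data; all of these are hypothetical choices made only to have something concrete to differentiate.

```python
import jax
import jax.numpy as jnp

lr = 0.01  # inner-loop step size (hypothetical value)

def loss_fn(params, data):
    # Hypothetical task: least-squares linear regression.
    x, y = data
    return jnp.mean((x @ params - y) ** 2)

def meta_loss_fn(params, data):
    """Computes the loss after one step of SGD."""
    grads = jax.grad(loss_fn)(params, data)
    return loss_fn(params - lr * grads, data)

# Toy data, made up for this example.
params = jnp.array([0.5, -0.5])
x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 2.0, 3.0])
data = (x, y)

# Differentiates through the inner jax.grad call, so second-order terms are included.
meta_grads = jax.grad(meta_loss_fn)(params, data)
print(meta_grads)
```

Because the inner update is ordinary JAX code, the outer `jax.grad` traces through it like any other computation; no special API is needed for the second-order terms.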

## Stopping gradients

Auto-diff enables automatic computation of the gradient of a function with respect to its inputs. Sometimes, however, we might want some additional control: for instance, we might want to avoid back-propagating gradients through some subset of the computational graph.

Consider for instance the TD(0) (temporal difference) reinforcement learning update. This is used to learn to estimate the *value* of a state in an environment from experience of interacting with the environment. Let’s assume the value estimate \(v_{\theta}(s_{t-1})\) in a state \(s_{t-1}\) is parameterised by a linear function.

```
# Value function and initial parameters
value_fn = lambda theta, state: jnp.dot(theta, state)
theta = jnp.array([0.1, -0.1, 0.])
```

Consider a transition from a state \(s_{t-1}\) to a state \(s_t\) during which we observed the reward \(r_t\):

```
# An example transition.
s_tm1 = jnp.array([1., 2., -1.])
r_t = jnp.array(1.)
s_t = jnp.array([2., 1., 0.])
```

The TD(0) update to the network parameters is:

\[\Delta \theta = (r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})) \nabla v_{\theta}(s_{t-1})\]

This update is not the gradient of any loss function.

However, it can be **written** as the gradient of the pseudo loss function

\[L(\theta) = - \frac{1}{2} [r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})]^2\]

if the dependency of the target \(r_t + v_{\theta}(s_t)\) on the parameter \(\theta\) is ignored.

How can we implement this in JAX? If we write the pseudo loss naively we get:

```
def td_loss(theta, s_tm1, r_t, s_t):
  v_tm1 = value_fn(theta, s_tm1)
  target = r_t + value_fn(theta, s_t)
  return (target - v_tm1) ** 2

td_update = jax.grad(td_loss)
delta_theta = td_update(theta, s_tm1, r_t, s_t)

delta_theta
```

```
DeviceArray([ 2.4, -2.4, 2.4], dtype=float32)
```

But `td_update` will **not** compute a TD(0) update, because the gradient computation will include the dependency of `target` on \(\theta\).

We can use `jax.lax.stop_gradient` to force JAX to ignore the dependency of the target on \(\theta\):

```
def td_loss(theta, s_tm1, r_t, s_t):
  v_tm1 = value_fn(theta, s_tm1)
  target = r_t + value_fn(theta, s_t)
  return (jax.lax.stop_gradient(target) - v_tm1) ** 2

td_update = jax.grad(td_loss)
delta_theta = td_update(theta, s_tm1, r_t, s_t)

delta_theta
```

```
DeviceArray([-2.4, -4.8, 2.4], dtype=float32)
```

This will treat `target` as if it did **not** depend on the parameters \(\theta\) and compute the correct update to the parameters.
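One way to check this is to compute the TD(0) update by hand, as the TD error times the gradient of the value estimate at \(s_{t-1}\), and compare it with the gradient of `td_loss`. The sketch below repeats the definitions so that it runs on its own. Because `td_loss` as written is a plain squared error, its gradient should equal the hand-written update scaled by a constant factor of \(-2\); the names `td_error`, `manual_update` and `loss_grad` are introduced here for illustration.

```python
import jax
import jax.numpy as jnp

value_fn = lambda theta, state: jnp.dot(theta, state)
theta = jnp.array([0.1, -0.1, 0.])
s_tm1 = jnp.array([1., 2., -1.])
r_t = jnp.array(1.)
s_t = jnp.array([2., 1., 0.])

def td_loss(theta, s_tm1, r_t, s_t):
    v_tm1 = value_fn(theta, s_tm1)
    target = r_t + value_fn(theta, s_t)
    return (jax.lax.stop_gradient(target) - v_tm1) ** 2

# TD(0) update written out by hand: TD error times the gradient of v(s_{t-1}).
td_error = r_t + value_fn(theta, s_t) - value_fn(theta, s_tm1)
manual_update = td_error * jax.grad(value_fn)(theta, s_tm1)

loss_grad = jax.grad(td_loss)(theta, s_tm1, r_t, s_t)

print(manual_update)
print(loss_grad)
# The squared-error form differs from the update only by the factor -2.
print(jnp.allclose(loss_grad, -2.0 * manual_update))
```

A gradient-descent step on `td_loss` therefore moves `theta` in the direction of the TD(0) update, with the factor of 2 absorbed into the learning rate.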

`jax.lax.stop_gradient` may also be useful in other settings, for instance if you want the gradient from some loss to only affect a subset of the parameters of the neural network (because, for instance, the other parameters are trained using a different loss).
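A minimal sketch of that second use case, assuming a hypothetical two-part model (an `encoder` matrix followed by a linear `head`) and made-up inputs: applying `stop_gradient` to the encoder output means this particular loss produces a zero gradient for the encoder parameters and a regular gradient for the head.

```python
import jax
import jax.numpy as jnp

# Hypothetical two-part model: an encoder followed by a linear head.
params = {
    "encoder": jnp.array([[1.0, 0.5], [-0.5, 2.0]]),
    "head": jnp.array([0.3, -0.7]),
}

def head_only_loss(params, x, y):
    features = jnp.tanh(params["encoder"] @ x)
    # Block gradients here so this loss cannot update the encoder.
    features = jax.lax.stop_gradient(features)
    prediction = jnp.dot(params["head"], features)
    return (prediction - y) ** 2

x = jnp.array([1.0, -1.0])
y = jnp.array(0.5)

grads = jax.grad(head_only_loss)(params, x, y)
print(grads["encoder"])  # zeros: no gradient flows past stop_gradient
print(grads["head"])     # regular gradient for the head parameters
```

The forward computation is unchanged; only the backward pass is cut, so a separate loss could still train the encoder.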

## Straight-through estimator using `stop_gradient`

The straight-through estimator is a trick for defining a ‘gradient’ of a function that is otherwise non-differentiable. Given a non-differentiable function \(f : \mathbb{R}^n \to \mathbb{R}^n\) that is used as part of a larger function that we wish to find a gradient of, we simply pretend during the backward pass that \(f\) is the identity function. This can be implemented neatly using `jax.lax.stop_gradient`:

```
def f(x):
  return jnp.round(x)  # non-differentiable

def straight_through_f(x):
  # Create an exactly-zero expression with Sterbenz lemma that has
  # an exactly-one gradient.
  zero = x - jax.lax.stop_gradient(x)
  return zero + jax.lax.stop_gradient(f(x))

print("f(x): ", f(3.2))
print("straight_through_f(x):", straight_through_f(3.2))

print("grad(f)(x):", jax.grad(f)(3.2))
print("grad(straight_through_f)(x):", jax.grad(straight_through_f)(3.2))
```

```
f(x): 3.0
straight_through_f(x): 3.0
grad(f)(x): 0.0
grad(straight_through_f)(x): 1.0
```
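The estimator is more interesting inside a larger function. In the sketch below, which uses an illustrative sum-of-squares loss and the hypothetical name `straight_through_round`, the gradient of the loss with respect to `x` comes out as `2 * round(x)`, as if rounding were the identity in the backward pass, whereas differentiating through plain `jnp.round` gives zeros.

```python
import jax
import jax.numpy as jnp

def straight_through_round(x):
    # Forward pass: round(x). Backward pass: gradient of the identity.
    zero = x - jax.lax.stop_gradient(x)
    return zero + jax.lax.stop_gradient(jnp.round(x))

def ste_loss(x):
    return jnp.sum(straight_through_round(x) ** 2)

def plain_loss(x):
    return jnp.sum(jnp.round(x) ** 2)

x = jnp.array([0.2, 1.7, -2.4])

ste_grad = jax.grad(ste_loss)(x)      # uses the rounded values, identity backward
plain_grad = jax.grad(plain_loss)(x)  # rounding blocks all gradient signal

print(ste_grad)
print(plain_grad)
```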

## Per-example gradients

While most ML systems compute gradients and updates from batches of data, for reasons of computational efficiency and/or variance reduction, it is sometimes necessary to have access to the gradient/update associated with each specific sample in the batch.

For instance, this is needed to prioritise data based on gradient magnitude, or to apply clipping / normalisations on a sample by sample basis.

In many frameworks (PyTorch, TF, Theano) it is often not trivial to compute per-example gradients, because the library directly accumulates the gradient over the batch. Naive workarounds, such as computing a separate loss per example and then aggregating the resulting gradients, are typically very inefficient.

In JAX we can define the code to compute the gradient per-sample in an easy but efficient way.

Just combine the `jit`, `vmap` and `grad` transformations together:

```
perex_grads = jax.jit(jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0)))
# Test it:
batched_s_tm1 = jnp.stack([s_tm1, s_tm1])
batched_r_t = jnp.stack([r_t, r_t])
batched_s_t = jnp.stack([s_t, s_t])
perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
```

```
DeviceArray([[-2.4, -4.8,  2.4],
             [-2.4, -4.8,  2.4]], dtype=float32)
```

Let’s walk through this one transformation at a time.

First, we apply `jax.grad` to `td_loss` to obtain a function that computes the gradient of the loss w.r.t. the parameters on single (unbatched) inputs:

```
dtdloss_dtheta = jax.grad(td_loss)
dtdloss_dtheta(theta, s_tm1, r_t, s_t)
```

```
DeviceArray([-2.4, -4.8, 2.4], dtype=float32)
```

This function computes one row of the array above.

Then, we vectorise this function using `jax.vmap`. This adds a batch dimension to all inputs and outputs. Now, given a batch of inputs, we produce a batch of outputs: each output in the batch corresponds to the gradient for the corresponding member of the input batch.

```
almost_perex_grads = jax.vmap(dtdloss_dtheta)
batched_theta = jnp.stack([theta, theta])
almost_perex_grads(batched_theta, batched_s_tm1, batched_r_t, batched_s_t)
```

```
DeviceArray([[-2.4, -4.8,  2.4],
             [-2.4, -4.8,  2.4]], dtype=float32)
```

This isn’t quite what we want, because we have to manually feed this function a batch of `theta`s, whereas we actually want to use a single `theta`. We fix this by adding `in_axes` to the `jax.vmap`, specifying `theta` as `None`, and the other args as `0`. This makes the resulting function add an extra axis only to the other arguments, leaving `theta` unbatched, as we want:

```
inefficient_perex_grads = jax.vmap(dtdloss_dtheta, in_axes=(None, 0, 0, 0))
inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
```

```
DeviceArray([[-2.4, -4.8,  2.4],
             [-2.4, -4.8,  2.4]], dtype=float32)
```

Almost there! This does what we want, but is slower than it has to be. Now, we wrap the whole thing in a `jax.jit` to get the compiled, efficient version of the same function:

```
perex_grads = jax.jit(inefficient_perex_grads)
perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
```

```
DeviceArray([[-2.4, -4.8,  2.4],
             [-2.4, -4.8,  2.4]], dtype=float32)
```

```
%timeit inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()
%timeit perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()
```

```
100 loops, best of 5: 7.74 ms per loop
10000 loops, best of 5: 86.2 µs per loop
```
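As an illustration of the per-sample clipping use case mentioned earlier, the sketch below clips each example’s gradient to a maximum L2 norm before averaging. The loss function, the data and the name `clipped_mean_grad` are hypothetical, and this is one simple way to arrange the computation rather than a reference implementation.

```python
import jax
import jax.numpy as jnp

def loss_fn(theta, x, y):
    # Hypothetical per-example loss: squared error of a linear model.
    return (jnp.dot(theta, x) - y) ** 2

@jax.jit
def clipped_mean_grad(theta, xs, ys, max_norm):
    # One gradient per example: theta is shared, xs and ys are batched.
    per_example = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))(theta, xs, ys)
    # Rescale each row so that its L2 norm is at most max_norm.
    norms = jnp.linalg.norm(per_example, axis=1, keepdims=True)
    scale = jnp.minimum(1.0, max_norm / jnp.maximum(norms, 1e-12))
    return jnp.mean(per_example * scale, axis=0)

theta = jnp.array([0.1, -0.1, 0.0])
xs = jnp.array([[1.0, 2.0, -1.0], [10.0, 0.0, 0.0]])
ys = jnp.array([1.0, -5.0])

clipped = clipped_mean_grad(theta, xs, ys, 1.0)
unclipped = clipped_mean_grad(theta, xs, ys, 1e6)  # threshold too large to matter
print(clipped)
print(unclipped)
```

With a very large threshold the result reduces to the ordinary batch-averaged gradient, which makes for a convenient sanity check.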