Advanced Automatic Differentiation in JAX


Authors: Vlatimir Mikulik & Matteo Hessel

Computing gradients is a critical part of modern machine learning methods. This section considers a few advanced topics in automatic differentiation that come up in machine learning practice.

While understanding how automatic differentiation works under the hood isn’t crucial for using JAX in most contexts, we encourage the reader to check out this quite accessible video to get a deeper sense of what’s going on.

The Autodiff Cookbook is a more advanced and more detailed explanation of how these ideas are implemented in the JAX backend. It’s not necessary to understand this to do most things in JAX. However, some features (like defining custom derivatives) depend on understanding this, so it’s worth knowing this explanation exists if you ever need to use them.

Higher-order derivatives

JAX’s autodiff makes it easy to compute higher-order derivatives, because the functions that compute derivatives are themselves differentiable. Thus, higher-order derivatives are as easy as stacking transformations.

We illustrate this in the single-variable case:

The derivative of \(f(x) = x^3 + 2x^2 - 3x + 1\) can be computed as:

import jax

f = lambda x: x**3 + 2*x**2 - 3*x + 1

dfdx = jax.grad(f)

The higher-order derivatives of \(f\) are:

\[\begin{split} \begin{array}{l} f'(x) = 3x^2 + 4x -3\\ f''(x) = 6x + 4\\ f'''(x) = 6\\ f^{iv}(x) = 0 \end{array} \end{split}\]

Computing any of these in JAX is as easy as chaining the grad function:

d2fdx = jax.grad(dfdx)
d3fdx = jax.grad(d2fdx)
d4fdx = jax.grad(d3fdx)

Evaluating the above at \(x=1\) would give us:

\[\begin{split} \begin{array}{l} f'(1) = 4\\ f''(1) = 10\\ f'''(1) = 6\\ f^{iv}(1) = 0 \end{array} \end{split}\]

Using JAX:
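A minimal sketch of that evaluation, with \(f\) and its derivative functions redefined so the snippet is self-contained. The printed values should agree with the hand-computed ones above.

```python
import jax

# Same polynomial as above, redefined so the snippet runs on its own.
f = lambda x: x**3 + 2*x**2 - 3*x + 1

dfdx = jax.grad(f)
d2fdx = jax.grad(dfdx)
d3fdx = jax.grad(d2fdx)
d4fdx = jax.grad(d3fdx)

# jax.grad requires a floating-point input, hence 1. rather than 1
print(dfdx(1.))
print(d2fdx(1.))
print(d3fdx(1.))
print(d4fdx(1.))
```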


In the multivariable case, higher-order derivatives are more complicated. The second-order derivative of a function is represented by its Hessian matrix, defined according to

\[(\mathbf{H}f)_{i,j} = \frac{\partial^2 f}{\partial_i\partial_j}.\]

The Hessian of a real-valued function of several variables, \(f: \mathbb R^n\to\mathbb R\), can be identified with the Jacobian of its gradient. JAX provides two transformations for computing the Jacobian of a function, jax.jacfwd and jax.jacrev, corresponding to forward- and reverse-mode autodiff. They give the same answer, but one can be more efficient than the other in different circumstances – see the video about autodiff linked above for an explanation.

def hessian(f):
  return jax.jacfwd(jax.grad(f))

Let’s double check this is correct on the dot-product \(f: \mathbf{x} \mapsto \mathbf{x} ^\top \mathbf{x}\).

If \(i=j\), \(\frac{\partial^2 f}{\partial_i\partial_j}(\mathbf{x}) = 2\). Otherwise, \(\frac{\partial^2 f}{\partial_i\partial_j}(\mathbf{x}) = 0\).

import jax.numpy as jnp

def f(x):
  return jnp.dot(x, x)

hessian(f)(jnp.array([1., 2., 3.]))
DeviceArray([[2., 0., 0.],
             [0., 2., 0.],
             [0., 0., 2.]], dtype=float32)
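The claim that jax.jacfwd and jax.jacrev give the same answer can also be checked numerically. The sketch below swaps the outer Jacobian transformation and compares the two Hessians. The function g is a hypothetical example chosen only because its Hessian is not constant.

```python
import jax
import jax.numpy as jnp

def g(x):  # hypothetical function with a non-trivial Hessian
  return jnp.sum(jnp.sin(x) * x ** 2)

x = jnp.array([1., 2., 3.])

hess_fwd = jax.jacfwd(jax.grad(g))(x)  # forward-over-reverse
hess_rev = jax.jacrev(jax.grad(g))(x)  # reverse-over-reverse

# The two results should agree up to floating-point error.
print(jnp.max(jnp.abs(hess_fwd - hess_rev)))
```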

Often, however, we aren’t interested in computing the full Hessian itself, and doing so can be very inefficient. The Autodiff Cookbook explains some tricks, like the Hessian-vector product, that let you use the Hessian without materialising the whole matrix.
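As an illustration of the Hessian-vector product idea, here is a minimal sketch that pushes a tangent vector through the gradient function with jax.jvp (forward-over-reverse). The helper name hvp and the test function g are assumptions made for this example, not part of the text above.

```python
import jax
import jax.numpy as jnp

def hvp(f, x, v):
  """Hessian-vector product H(x) @ v without building H."""
  # jax.jvp returns (primal_out, tangent_out); the tangent of grad(f)
  # along v is exactly H(x) @ v.
  return jax.jvp(jax.grad(f), (x,), (v,))[1]

def g(x):  # hypothetical test function; its Hessian is diag(6 * x)
  return jnp.sum(x ** 3)

x = jnp.array([1., 2., 3.])
v = jnp.array([1., 0., -1.])

# Reference value computed from the explicit Hessian, for comparison only.
explicit = jax.jacfwd(jax.grad(g))(x) @ v

print(hvp(g, x, v))
print(explicit)
```

The hvp version never forms the \(n \times n\) matrix, which is what matters when \(n\) is the number of parameters of a large model.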

If you plan to work with higher-order derivatives in JAX, we strongly recommend reading the Autodiff Cookbook.

Higher-order optimization

Some meta-learning techniques, such as Model-Agnostic Meta-Learning (MAML), require differentiating through gradient updates. In other frameworks this can be quite cumbersome, but in JAX it’s much easier:

def meta_loss_fn(params, data):
  """Computes the loss after one step of SGD."""
  grads = jax.grad(loss_fn)(params, data)
  return loss_fn(params - lr * grads, data)

meta_grads = jax.grad(meta_loss_fn)(params, data)
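The snippet above leaves loss_fn, lr, params and data undefined. A runnable sketch under assumed definitions follows: the linear model, the mean-squared-error loss, the step size and the toy data are all hypothetical choices made for illustration.

```python
import jax
import jax.numpy as jnp

lr = 0.1  # hypothetical inner-loop step size

def loss_fn(params, data):
  """Hypothetical mean-squared-error loss for a linear model."""
  x, y = data
  return jnp.mean((x @ params - y) ** 2)

def meta_loss_fn(params, data):
  """Computes the loss after one step of SGD."""
  grads = jax.grad(loss_fn)(params, data)
  return loss_fn(params - lr * grads, data)

# Hypothetical toy parameters and data.
params = jnp.array([0.5, -0.5])
data = (jnp.array([[1., 2.], [3., 4.]]), jnp.array([1., 2.]))

# Differentiates through the inner gradient step, so second derivatives
# of loss_fn enter the result.
meta_grads = jax.grad(meta_loss_fn)(params, data)
print(meta_grads)
```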

Stopping gradients

Autodiff enables automatic computation of the gradient of a function with respect to its inputs. Sometimes, however, we might want some additional control: for instance, we might want to avoid back-propagating gradients through some subset of the computational graph.

Consider for instance the TD(0) (temporal difference) reinforcement learning update. This is used to learn to estimate the value of a state in an environment from experience of interacting with the environment. Let’s assume the value estimate \(v_{\theta}(s_{t-1})\) in a state \(s_{t-1}\) is parameterised by a linear function.

# Value function and initial parameters
value_fn = lambda theta, state: jnp.dot(theta, state)
theta = jnp.array([0.1, -0.1, 0.])

Consider a transition from a state \(s_{t-1}\) to a state \(s_t\) during which we observed the reward \(r_t\):

# An example transition.
s_tm1 = jnp.array([1., 2., -1.])
r_t = jnp.array(1.)
s_t = jnp.array([2., 1., 0.])

The TD(0) update to the network parameters is:

\[ \Delta \theta = (r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})) \nabla v_{\theta}(s_{t-1}) \]

This update is not the gradient of any loss function.

However, up to a constant factor of \(-2\), it can be written as the gradient of the pseudo loss function

\[ L(\theta) = [r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})]^2 \]

if the dependency of the target \(r_t + v_{\theta}(s_t)\) on the parameter \(\theta\) is ignored.
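To make the relationship explicit, here is a short derivation under the stated assumption that the target \(r_t + v_{\theta}(s_t)\) is held constant. Applying the chain rule to \(L\) gives

\[ \nabla_{\theta} L(\theta) = -2\,(r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1}))\, \nabla v_{\theta}(s_{t-1}) = -2\,\Delta \theta \]

so a gradient-descent step on \(L\) moves \(\theta\) in the direction of the TD(0) update.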

How can we implement this in JAX? If we write the pseudo loss naively we get:

def td_loss(theta, s_tm1, r_t, s_t):
  v_tm1 = value_fn(theta, s_tm1)
  target = r_t + value_fn(theta, s_t)
  return (target - v_tm1) ** 2

td_update = jax.grad(td_loss)
delta_theta = td_update(theta, s_tm1, r_t, s_t)
delta_theta

DeviceArray([ 2.4, -2.4,  2.4], dtype=float32)

But td_update will not compute a TD(0) update, because the gradient computation will include the dependency of target on \(\theta\).

We can use jax.lax.stop_gradient to force JAX to ignore the dependency of the target on \(\theta\):

def td_loss(theta, s_tm1, r_t, s_t):
  v_tm1 = value_fn(theta, s_tm1)
  target = r_t + value_fn(theta, s_t)
  return (jax.lax.stop_gradient(target) - v_tm1) ** 2

td_update = jax.grad(td_loss)
delta_theta = td_update(theta, s_tm1, r_t, s_t)
delta_theta

DeviceArray([-2.4, -4.8,  2.4], dtype=float32)

This will treat target as if it did not depend on the parameters \(\theta\) and compute the correct update to the parameters.

jax.lax.stop_gradient may also be useful in other settings, for instance if you want the gradient from some loss to affect only a subset of the parameters of the neural network (because, for instance, the other parameters are trained using a different loss).
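A minimal sketch of the parameter-subset case, assuming a hypothetical two-part model whose parameters live in a dict with keys 'body' and 'head'. The names, shapes and loss are illustrative only. Wrapping params['body'] in stop_gradient makes this loss contribute zero gradient to it.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x):
  """Hypothetical loss that should only train params['head']."""
  # stop_gradient blocks gradients from flowing back into params['body'].
  hidden = jnp.tanh(x @ jax.lax.stop_gradient(params['body']))
  return jnp.dot(hidden, params['head']) ** 2

params = {
    'body': jnp.array([[0.1, 0.2], [0.3, 0.4]]),
    'head': jnp.array([1., -1.]),
}
x = jnp.array([1., 2.])

grads = jax.grad(loss_fn)(params, x)
print(grads['body'])  # zeros: this loss does not update the body
print(grads['head'])  # non-zero gradient for the head
```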

Straight-through estimator using stop_gradient

The straight-through estimator is a trick for defining a ‘gradient’ of a function that is otherwise non-differentiable. Given a non-differentiable function \(f : \mathbb{R}^n \to \mathbb{R}^n\) that is used as part of a larger function that we wish to find a gradient of, we simply pretend during the backward pass that \(f\) is the identity function. This can be implemented neatly using jax.lax.stop_gradient:

def f(x):
  return jnp.round(x)  # non-differentiable

def straight_through_f(x):
  return x + jax.lax.stop_gradient(f(x) - x)

print("f(x): ", f(3.2))
print("straight_through_f(x):", straight_through_f(3.2))

print("grad(f)(x):", jax.grad(f)(3.2))
print("grad(straight_through_f)(x):", jax.grad(straight_through_f)(3.2))
f(x):  3.0
straight_through_f(x): 3.0
grad(f)(x): 0.0
grad(straight_through_f)(x): 1.0
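The estimator is most useful when the rounding sits inside a larger function. The sketch below compares a hypothetical downstream loss with and without the straight-through trick. The loss itself is an assumption made for illustration.

```python
import jax
import jax.numpy as jnp

def straight_through_round(x):
  # Forward pass: round(x), up to floating-point error. Backward pass: identity.
  return x + jax.lax.stop_gradient(jnp.round(x) - x)

def hard_loss(x):  # hypothetical loss using plain rounding
  return (jnp.round(x) - 1.0) ** 2

def st_loss(x):  # same loss with the straight-through estimator
  return (straight_through_round(x) - 1.0) ** 2

print(jax.grad(hard_loss)(3.2))  # zero: rounding blocks the gradient
print(jax.grad(st_loss)(3.2))    # approximately 2 * (round(3.2) - 1)
```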

Per-example gradients

While most ML systems compute gradients and updates from batches of data, for reasons of computational efficiency and/or variance reduction, it is sometimes necessary to have access to the gradient/update associated with each specific sample in the batch.

For instance, this is needed to prioritise data based on gradient magnitude, or to apply clipping / normalisations on a sample by sample basis.

In many frameworks (PyTorch, TF, Theano) it is often not trivial to compute per-example gradients, because the library directly accumulates the gradient over the batch. Naive workarounds, such as computing a separate loss per example and then aggregating the resulting gradients, are typically very inefficient.

In JAX, computing per-sample gradients is both easy and efficient.

Just combine the jit, vmap and grad transformations together:

perex_grads = jax.jit(jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0)))

# Test it:
batched_s_tm1 = jnp.stack([s_tm1, s_tm1])
batched_r_t = jnp.stack([r_t, r_t])
batched_s_t = jnp.stack([s_t, s_t])

perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
DeviceArray([[-2.4, -4.8,  2.4],
             [-2.4, -4.8,  2.4]], dtype=float32)

Let’s walk through this one transformation at a time.

First, we apply jax.grad to td_loss to obtain a function that computes the gradient of the loss w.r.t. the parameters on single (unbatched) inputs:

dtdloss_dtheta = jax.grad(td_loss)

dtdloss_dtheta(theta, s_tm1, r_t, s_t)
DeviceArray([-2.4, -4.8,  2.4], dtype=float32)

This function computes one row of the array above.

Then, we vectorise this function using jax.vmap. This adds a batch dimension to all inputs and outputs. Now, given a batch of inputs, we produce a batch of outputs – each output in the batch corresponds to the gradient for the corresponding member of the input batch.

almost_perex_grads = jax.vmap(dtdloss_dtheta)

batched_theta = jnp.stack([theta, theta])
almost_perex_grads(batched_theta, batched_s_tm1, batched_r_t, batched_s_t)
DeviceArray([[-2.4, -4.8,  2.4],
             [-2.4, -4.8,  2.4]], dtype=float32)

This isn’t quite what we want, because we have to manually feed this function a batch of thetas, whereas we actually want to use a single theta. We fix this by adding in_axes to the jax.vmap, specifying theta as None, and the other args as 0. This makes the resulting function add an extra axis only to the other arguments, leaving theta unbatched, as we want:

inefficient_perex_grads = jax.vmap(dtdloss_dtheta, in_axes=(None, 0, 0, 0))

inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
DeviceArray([[-2.4, -4.8,  2.4],
             [-2.4, -4.8,  2.4]], dtype=float32)

Almost there! This does what we want, but is slower than it has to be. Now, we wrap the whole thing in a jax.jit to get the compiled, efficient version of the same function:

perex_grads = jax.jit(inefficient_perex_grads)

perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
DeviceArray([[-2.4, -4.8,  2.4],
             [-2.4, -4.8,  2.4]], dtype=float32)
%timeit inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()
%timeit perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()
100 loops, best of 5: 7.74 ms per loop
10000 loops, best of 5: 86.2 µs per loop
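Returning to the per-sample clipping use case mentioned at the start of this section, the same jit, vmap and grad combination can be extended as sketched below. The loss, the helper names clip_by_norm and clipped_mean_grad, the max_norm value and the toy data are all assumptions made for this example.

```python
import jax
import jax.numpy as jnp

def loss_fn(theta, x, y):
  """Hypothetical squared-error loss for a single example."""
  return (jnp.dot(theta, x) - y) ** 2

def clip_by_norm(g, max_norm):
  # Rescale g so that its L2 norm is at most max_norm.
  norm = jnp.linalg.norm(g)
  return g * jnp.minimum(1.0, max_norm / (norm + 1e-12))

@jax.jit
def clipped_mean_grad(theta, xs, ys, max_norm=1.0):
  # One gradient per example: theta is shared (None), xs and ys are batched (0).
  per_example = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))(theta, xs, ys)
  # Clip each example's gradient separately, then average over the batch.
  clipped = jax.vmap(clip_by_norm, in_axes=(0, None))(per_example, max_norm)
  return jnp.mean(clipped, axis=0)

theta = jnp.array([0.1, -0.1, 0.])
xs = jnp.array([[1., 2., -1.], [2., 1., 0.]])
ys = jnp.array([1., 0.])

print(clipped_mean_grad(theta, xs, ys))
```

Clipping before averaging bounds the influence of any single example, which is not possible once the gradients have already been summed over the batch.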