
# The Autodiff Cookbook

alexbw@, mattjj@

JAX has a pretty general automatic differentiation system. In this notebook, we’ll go through a whole bunch of neat autodiff ideas that you can cherry-pick for your own work, starting with the basics.

[1]:

import jax.numpy as np
from jax import grad, jit, vmap
from jax import random

key = random.PRNGKey(0)



### Starting with grad

You can differentiate a function with grad:

[2]:

grad_tanh = grad(np.tanh)
print(grad_tanh(2.0))

0.070650816


grad takes a function and returns a function. If you have a Python function f that evaluates the mathematical function $$f$$, then grad(f) is a Python function that evaluates the mathematical function $$\nabla f$$. That means grad(f)(x) represents the value $$\nabla f(x)$$.
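Here is a quick sanity check of that correspondence. It is a minimal sketch built on a made-up function (`sum_sin` is hypothetical, not part of the original notebook). For $$f(x) = \sum_i \sin(x_i)$$ the gradient should be $$\nabla f(x) = \cos(x)$$ elementwise, and grad(f)(x) should reproduce it:

```python
import jax
import jax.numpy as jnp


# Hypothetical example: f(x) = sum_i sin(x_i), whose gradient is cos(x).
def sum_sin(x):
    return jnp.sum(jnp.sin(x))


x_pt = jnp.array([0.0, 0.5, 1.0])
grad_sum_sin = jax.grad(sum_sin)   # grad maps a function to a function
print(grad_sum_sin(x_pt))          # evaluates the gradient at x_pt
print(jnp.cos(x_pt))               # analytic gradient, for comparison
```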

Since grad operates on functions, you can apply it to its own output to differentiate as many times as you like:

[3]:

print(grad(grad(np.tanh))(2.0))
print(grad(grad(grad(np.tanh)))(2.0))

-0.13621867
0.25265405


Let’s look at computing gradients with grad in a linear logistic regression model. First, the setup:

[4]:

def sigmoid(x):
    return 0.5 * (np.tanh(x / 2) + 1)

# Outputs probability of a label being true.
def predict(W, b, inputs):
    return sigmoid(np.dot(inputs, W) + b)

# Build a toy dataset.
inputs = np.array([[0.52, 1.12,  0.77],
                   [0.88, -1.08, 0.15],
                   [0.52, 0.06, -1.30],
                   [0.74, -2.49, 1.39]])
targets = np.array([True, True, False, True])

# Training loss is the negative log-likelihood of the training examples.
def loss(W, b):
    preds = predict(W, b, inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -np.sum(np.log(label_probs))

# Initialize random model coefficients
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())


Use the grad function with its argnums argument to differentiate a function with respect to positional arguments.

[5]:

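# Differentiate loss with respect to the first positional argument:
W_grad = grad(loss, argnums=0)(W, b)
print('W_grad', W_grad)

# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)

# But we can choose different values too, and drop the keyword:
b_grad = grad(loss, 1)(W, b)

# Including tuple values
W_grad, b_grad = grad(loss, (0, 1))(W, b)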

W_grad [-0.16965581 -0.8774648  -1.4901345 ]


This grad API has a direct correspondence to the excellent notation in Spivak’s classic *Calculus on Manifolds* (1965), also used in Sussman and Wisdom’s *Structure and Interpretation of Classical Mechanics* (2015) and their *Functional Differential Geometry* (2013). Both books are open-access. See in particular the “Prologue” section of *Functional Differential Geometry* for a defense of this notation.

Essentially, when using the argnums argument, if f is a Python function for evaluating the mathematical function $$f$$, then the Python expression grad(f, i) evaluates to a Python function for evaluating $$\partial_i f$$.
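To make the $$\partial_i f$$ reading concrete, here is a small sketch with a hypothetical two-argument function $$f(x, y) = x y^2$$ (named `xy2` below). For this function $$\partial_0 f = y^2$$ and $$\partial_1 f = 2xy$$:

```python
import jax


# Hypothetical f(x, y) = x * y**2, so d_0 f = y**2 and d_1 f = 2 * x * y.
def xy2(x, y):
    return x * y ** 2


d0 = jax.grad(xy2, 0)(3.0, 2.0)          # d_0 f(3, 2) = 4
d1 = jax.grad(xy2, 1)(3.0, 2.0)          # d_1 f(3, 2) = 12
both = jax.grad(xy2, (0, 1))(3.0, 2.0)   # tuple: (d_0 f, d_1 f)
print(d0, d1, both)
```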

### Differentiating with respect to nested lists, tuples, and dicts

Differentiating with respect to standard Python containers just works, so use tuples, lists, and dicts (and arbitrary nesting) however you like.

[6]:

def loss2(params_dict):
    preds = predict(params_dict['W'], params_dict['b'], inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -np.sum(np.log(label_probs))

print(grad(loss2)({'W': W, 'b': b}))


{'W': DeviceArray([-0.16965581, -0.8774648 , -1.4901345 ], dtype=float32), 'b': DeviceArray(-0.29227245, dtype=float32)}


You can register your own container types to work with not just grad but all the JAX transformations (jit, vmap, etc.).
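One way to do this is `jax.tree_util.register_pytree_node`, sketched below under the assumption of a reasonably recent JAX release. It takes the class plus a flatten function and an unflatten function. The `Params` class and `toy_loss` are hypothetical stand-ins for your own types:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node


class Params:
    """Hypothetical container: a weight vector W and a scalar bias b."""
    def __init__(self, W, b):
        self.W = W
        self.b = b


def params_flatten(p):
    # Return (children, aux_data): the array leaves JAX should trace,
    # plus any static metadata (none here).
    return (p.W, p.b), None


def params_unflatten(aux_data, children):
    # Rebuild a Params from the (possibly transformed) leaves.
    return Params(*children)


register_pytree_node(Params, params_flatten, params_unflatten)


def toy_loss(p):
    return jnp.sum(p.W ** 2) + p.b ** 2


p = Params(jnp.array([1.0, 2.0]), jnp.array(3.0))
g = jax.grad(toy_loss)(p)       # the gradient comes back as a Params
print(g.W, g.b)                 # d/dW = 2W, d/db = 2b
print(jax.jit(toy_loss)(p))     # jit accepts the custom container too
```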

### Evaluate a function and its gradient using value_and_grad

Another convenient function is value_and_grad, which efficiently computes both a function’s value and its gradient’s value:

[7]:

from jax import value_and_grad
loss_value, Wb_grad = value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)
print('loss value', loss(W, b))

loss value 3.0519395
loss value 3.0519395


### Checking against numerical differences

A great thing about derivatives is that they’re straightforward to check with finite differences:

[8]:

# Set a step size for finite differences calculations
eps = 1e-4

# Check b_grad with scalar finite differences
b_grad_numerical = (loss(W, b + eps / 2.) - loss(W, b - eps / 2.)) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))

# Check W_grad with finite differences in a random direction
key, subkey = random.split(key)
vec = random.normal(subkey, W.shape)
unitvec = vec / np.sqrt(np.vdot(vec, vec))
W_grad_numerical = (loss(W + eps / 2. * unitvec, b) - loss(W - eps / 2. * unitvec, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', np.vdot(grad(loss)(W, b), unitvec))

b_grad_numerical -0.29325485
b_grad_autodiff -0.29227245
W_dirderiv_numerical -0.19788742
W_dirderiv_autodiff -0.19909072


JAX provides a simple convenience function that does essentially the same thing, but checks up to any order of differentiation that you like:

[9]:

from jax.test_util import check_grads
check_grads(loss, (W, b), order=2)  # check up to 2nd order derivatives


### Hessian-vector products with grad-of-grad

One thing we can do with higher-order grad is build a Hessian-vector product function. (Later on we’ll write an even more efficient implementation that mixes both forward- and reverse-mode, but this one will use pure reverse-mode.)

A Hessian-vector product function can be useful in a truncated Newton Conjugate-Gradient algorithm for minimizing smooth convex functions, or for studying the curvature of neural network training objectives (e.g. 1, 2, 3, 4).

For a scalar-valued function $$f : \mathbb{R}^n \to \mathbb{R}$$, the Hessian at a point $$x \in \mathbb{R}^n$$ is written as $$\partial^2 f(x)$$. A Hessian-vector product function is then able to evaluate

$$\qquad v \mapsto \partial^2 f(x) \cdot v$$

for any $$v \in \mathbb{R}^n$$.

The trick is not to instantiate the full Hessian matrix: if $$n$$ is large, perhaps in the millions or billions in the context of neural networks, then that might be impossible to store.

Luckily, grad already gives us a way to write an efficient Hessian-vector product function. We just have to use the identity

$$\qquad \partial^2 f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)$$,

where $$g(x) = \partial f(x) \cdot v$$ is a new scalar-valued function that dots the gradient of $$f$$ at $$x$$ with the vector $$v$$. Notice that we’re only ever differentiating scalar-valued functions of vector-valued arguments, which is exactly where we know grad is efficient.

In JAX code, we can just write this:

[10]:

def hvp(f, x, v):
    return grad(lambda x: np.vdot(grad(f)(x), v))(x)


This example shows that you can freely use lexical closure, and JAX will never get perturbed or confused.

We’ll check this implementation a few cells down, once we see how to compute dense Hessian matrices. We’ll also write an even better version that uses both forward-mode and reverse-mode.

### Jacobians and Hessians using jacfwd and jacrev

You can compute full Jacobian matrices using the jacfwd and jacrev functions:

[11]:

from jax import jacfwd, jacrev

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

J = jacfwd(f)(W)
print("jacfwd result, with shape", J.shape)
print(J)

J = jacrev(f)(W)
print("jacrev result, with shape", J.shape)
print(J)

jacfwd result, with shape (4, 3)
[[ 0.05981752  0.12883773  0.08857594]
[ 0.04015911 -0.04928619  0.0068453 ]
[ 0.12188289  0.01406341 -0.3047072 ]
[ 0.00140426 -0.00472514  0.00263773]]
jacrev result, with shape (4, 3)
[[ 0.05981752  0.12883773  0.08857594]
[ 0.04015911 -0.04928619  0.0068453 ]
[ 0.12188289  0.01406341 -0.3047072 ]
[ 0.00140426 -0.00472514  0.00263773]]


These two functions compute the same values (up to machine numerics), but differ in their implementation: jacfwd uses forward-mode automatic differentiation, which is more efficient for “tall” Jacobian matrices, while jacrev uses reverse-mode, which is more efficient for “wide” Jacobian matrices. For matrices that are near-square, jacfwd probably has an edge over jacrev.

You can also use jacfwd and jacrev with container types:

[12]:

def predict_dict(params, inputs):
    return predict(params['W'], params['b'], inputs)

J_dict = jacrev(predict_dict)({'W': W, 'b': b}, inputs)
for k, v in J_dict.items():
    print("Jacobian from {} to logits is".format(k))
    print(v)

Jacobian from W to logits is
[[ 0.05981752  0.12883773  0.08857594]
[ 0.04015911 -0.04928619  0.0068453 ]
[ 0.12188289  0.01406341 -0.3047072 ]
[ 0.00140426 -0.00472514  0.00263773]]
Jacobian from b to logits is
[0.11503369 0.04563536 0.23439017 0.00189765]


For more details on forward- and reverse-mode, as well as how to implement jacfwd and jacrev as efficiently as possible, read on!

Using a composition of two of these functions gives us a way to compute dense Hessian matrices:

[13]:

def hessian(f):
    return jacfwd(jacrev(f))

H = hessian(f)(W)
print("hessian, with shape", H.shape)
print(H)

hessian, with shape (4, 3, 3)
[[[ 0.02285464  0.04922538  0.03384245]
[ 0.04922538  0.10602391  0.07289143]
[ 0.03384245  0.07289144  0.05011286]]

[[-0.03195212  0.03921397 -0.00544638]
[ 0.03921397 -0.04812624  0.0066842 ]
[-0.00544638  0.0066842  -0.00092836]]

[[-0.01583708 -0.00182736  0.0395927 ]
[-0.00182736 -0.00021085  0.00456839]
[ 0.0395927   0.00456839 -0.09898175]]

[[-0.0010352   0.00348331 -0.0019445 ]
[ 0.00348331 -0.01172087  0.00654297]
[-0.0019445   0.00654297 -0.0036525 ]]]


This shape makes sense: if we start with a function $$f : \mathbb{R}^n \to \mathbb{R}^m$$, then at a point $$x \in \mathbb{R}^n$$ we expect to get the shapes

• $$f(x) \in \mathbb{R}^m$$, the value of $$f$$ at $$x$$,
• $$\partial f(x) \in \mathbb{R}^{m \times n}$$, the Jacobian matrix at $$x$$,
• $$\partial^2 f(x) \in \mathbb{R}^{m \times n \times n}$$, the Hessian at $$x$$,

and so on.

To implement hessian, we could have used jacrev(jacrev(f)) or jacrev(jacfwd(f)) or any other composition of the two, but forward-over-reverse is typically the most efficient. That’s because in the inner Jacobian computation we’re often differentiating a function with a wide Jacobian (maybe like a loss function $$f : \mathbb{R}^n \to \mathbb{R}$$), while in the outer Jacobian computation we’re differentiating a function with a square Jacobian (since $$\nabla f : \mathbb{R}^n \to \mathbb{R}^n$$), which is where forward-mode wins out.
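The choice of composition should only affect cost, not the answer. The sketch below checks this by comparing all four compositions, plus JAX's built-in `jax.hessian`, on a hypothetical function $$g : \mathbb{R}^3 \to \mathbb{R}^2$$. It also confirms the $$m \times n \times n$$ shape described above:

```python
import jax
import jax.numpy as jnp


# Hypothetical vector-valued function g : R^3 -> R^2, so m = 2 and n = 3.
def g(x):
    return jnp.stack([jnp.sum(jnp.sin(x) * x), jnp.prod(x)])


x_pt = jnp.array([0.3, -1.2, 2.0])

H_fr = jax.jacfwd(jax.jacrev(g))(x_pt)   # forward-over-reverse
H_rr = jax.jacrev(jax.jacrev(g))(x_pt)   # reverse-over-reverse
H_rf = jax.jacrev(jax.jacfwd(g))(x_pt)   # reverse-over-forward
H_ff = jax.jacfwd(jax.jacfwd(g))(x_pt)   # forward-over-forward

print(H_fr.shape)  # (m, n, n) = (2, 3, 3)
for H_other in (H_rr, H_rf, H_ff, jax.hessian(g)(x_pt)):
    print(jnp.allclose(H_fr, H_other, atol=1e-4))
```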

## How it’s made: two foundational autodiff functions

### Jacobian-Vector products (JVPs, aka forward-mode autodiff)

JAX includes efficient and general implementations of both forward- and reverse-mode automatic differentiation. The familiar grad function is built on reverse-mode, but to explain the difference in the two modes, and when each can be useful, we need a bit of math background.

#### JVPs in math

Mathematically, given a function $$f : \mathbb{R}^n \to \mathbb{R}^m$$, the Jacobian of $$f$$ evaluated at an input point $$x \in \mathbb{R}^n$$, denoted $$\partial f(x)$$, is often thought of as an $$m \times n$$ matrix:

$$\qquad \partial f(x) \in \mathbb{R}^{m \times n}$$.

But we can also think of $$\partial f(x)$$ as a linear map, which maps the tangent space of the domain of $$f$$ at the point $$x$$ (which is just another copy of $$\mathbb{R}^n$$) to the tangent space of the codomain of $$f$$ at the point $$f(x)$$ (a copy of $$\mathbb{R}^m$$):

$$\qquad \partial f(x) : \mathbb{R}^n \to \mathbb{R}^m$$.

This map is called the pushforward map of $$f$$ at $$x$$. The Jacobian matrix is just the matrix for this linear map in a standard basis.

If we don’t commit to one specific input point $$x$$, then we can think of the function $$\partial f$$ as first taking an input point and returning the Jacobian linear map at that input point:

$$\qquad \partial f : \mathbb{R}^n \to \mathbb{R}^n \to \mathbb{R}^m$$.

In particular, we can uncurry things so that given input point $$x \in \mathbb{R}^n$$ and a tangent vector $$v \in \mathbb{R}^n$$, we get back an output tangent vector in $$\mathbb{R}^m$$. We call that mapping, from $$(x, v)$$ pairs to output tangent vectors, the Jacobian-vector product, and write it as

$$\qquad (x, v) \mapsto \partial f(x) v$$
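JAX exposes both readings of $$\partial f$$. In the sketch below (the function `f2` and the points are hypothetical), `jax.linearize` plays the curried role: it returns the linear map $$\partial f(x)$$ for a fixed $$x$$. `jax.jvp` plays the uncurried role of evaluating $$(x, v) \mapsto \partial f(x) v$$ in one call. Both should agree with the Jacobian matrix applied to $$v$$:

```python
import jax
import jax.numpy as jnp


# Hypothetical f2 : R^3 -> R^2.
def f2(x):
    return jnp.stack([jnp.sin(x[0]) * x[1], x[1] * x[2] ** 2])


x_pt = jnp.array([0.5, 1.5, -2.0])
v_t = jnp.array([1.0, 0.0, 2.0])

# Curried reading: fix x once and get back the linear map df(x) as a function.
y_lin, df_x = jax.linearize(f2, x_pt)

# Uncurried reading: evaluate (x, v) -> (f(x), df(x) v) in one call.
y_jvp, jv = jax.jvp(f2, (x_pt,), (v_t,))

print(jnp.allclose(df_x(v_t), jv))                   # same tangent output
print(jnp.allclose(jax.jacfwd(f2)(x_pt) @ v_t, jv))  # Jacobian matrix times v
```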

#### JVPs in JAX code

Back in Python code, JAX’s jvp function models this transformation. Given a Python function that evaluates $$f$$, JAX’s jvp is a way to get a Python function for evaluating $$(x, v) \mapsto (f(x), \partial f(x) v)$$.

[14]:

from jax import jvp

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

key, subkey = random.split(key)
v = random.normal(subkey, W.shape)

# Push forward the vector v along f evaluated at W
y, u = jvp(f, (W,), (v,))


In terms of Haskell-like type signatures, we could write

jvp :: (a -> b) -> a -> T a -> (b, T b)


where we use T a to denote the type of the tangent space for a. In words, jvp takes as arguments a function of type a -> b, a value of type a, and a tangent vector value of type T a. It gives back a pair consisting of a value of type b and an output tangent vector of type T b.

The jvp-transformed function is evaluated much like the original function, but paired up with each primal value of type a it pushes along tangent values of type T a. For each primitive numerical operation that the original function would have applied, the jvp-transformed function executes a “JVP rule” for that primitive that both evaluates the primitive on the primals and applies the primitive’s JVP at those primal values.

That evaluation strategy has some immediate implications about computational complexity: since we evaluate JVPs as we go, we don’t need to store anything for later, and so the memory cost is independent of the depth of the computation. In addition, the FLOP cost of the jvp-transformed function is about 3x the cost of just evaluating the function (one unit of work for evaluating the original function, for example sin(x); one unit for linearizing, like cos(x); and one unit for applying the linearized function to a vector, like cos_x * v). Put another way, for a fixed primal point $$x$$, we can evaluate $$v \mapsto \partial f(x) \cdot v$$ for about the same marginal cost as evaluating $$f$$.
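The “JVP rule” strategy described above can be sketched with a toy forward-mode evaluator. Everything here (the `Dual` class, `dual_sin`, `dual_mul`) is hypothetical and only illustrates the idea; it is not JAX's implementation. Each rule does the three units of work mentioned: evaluate, linearize, and apply the linearization to the tangent. Nothing is stored for later:

```python
import jax
import jax.numpy as jnp


class Dual:
    """Hypothetical pair of (primal value, tangent value)."""
    def __init__(self, primal, tangent):
        self.primal = primal
        self.tangent = tangent


def dual_sin(d):
    # JVP rule for sin: evaluate sin(x), linearize (cos(x)),
    # then apply the linearization to the incoming tangent.
    cos_x = jnp.cos(d.primal)
    return Dual(jnp.sin(d.primal), cos_x * d.tangent)


def dual_mul(a, b):
    # JVP rule for multiplication: the product rule.
    return Dual(a.primal * b.primal,
                a.tangent * b.primal + a.primal * b.tangent)


def fn(x):
    return jnp.sin(x) * x


x0, v0 = jnp.array(1.3), jnp.array(0.7)
d = Dual(x0, v0)
ours = dual_mul(dual_sin(d), d)       # evaluate fn on (primal, tangent) pairs
primal_out, tangent_out = jax.jvp(fn, (x0,), (v0,))

print(ours.primal, primal_out)    # both: sin(x0) * x0
print(ours.tangent, tangent_out)  # both: (cos(x0) * x0 + sin(x0)) * v0
```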

That memory complexity sounds pretty compelling! So why don’t we see forward-mode very often in machine learning?

To answer that, first think about how you could use a JVP to build a full Jacobian matrix. If we apply a JVP to a one-hot tangent vector, it reveals one column of the Jacobian matrix, corresponding to the nonzero entry we fed in. So we can build a full Jacobian one column at a time, and to get each column costs about the same as one function evaluation. That will be efficient for functions with “tall” Jacobians, but inefficient for “wide” Jacobians.
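The following is a minimal sketch of building a Jacobian one column at a time (the tall function `h` and the point `x_pt` are hypothetical). It makes one `jax.jvp` call per one-hot input vector, stacks the resulting columns, and compares them against `jax.jacfwd`:

```python
import jax
import jax.numpy as jnp


# Hypothetical "tall" function h : R^2 -> R^4 (more outputs than inputs).
def h(x):
    return jnp.stack([x[0] ** 2, x[0] * x[1], jnp.sin(x[1]), x[0] + x[1]])


x_pt = jnp.array([1.0, 2.0])

# One JVP per input dimension: pushing forward the one-hot vector e_i
# reveals column i of the Jacobian.
basis = jnp.eye(x_pt.shape[0])
cols = [jax.jvp(h, (x_pt,), (e_i,))[1] for e_i in basis]
J_by_columns = jnp.stack(cols, axis=1)   # shape (4, 2)

print(J_by_columns.shape)
print(jnp.allclose(J_by_columns, jax.jacfwd(h)(x_pt)))
```

Only two JVPs are needed here because `h` has two inputs; a function with millions of inputs would need millions of them.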

If you’re doing gradient-based optimization in machine learning, you probably want to minimize a loss function from parameters in $$\mathbb{R}^n$$ to a scalar loss value in $$\mathbb{R}$$. That means the Jacobian of this function is a very wide matrix: $$\partial f(x) \in \mathbb{R}^{1 \times n}$$, which we often identify with the gradient vector $$\nabla f(x) \in \mathbb{R}^n$$. Building that matrix one column at a time, with each call taking a similar number of FLOPs to evaluating the original function, sure seems inefficient! In particular, for training neural networks, where $$f$$ is a training loss function and $$n$$ can be in the millions or billions, this approach just won’t scale.

To do better for functions like this, we just need to use reverse-mode.

### Vector-Jacobian products (VJPs, aka reverse-mode autodiff)

Where forward-mode gives us back a function for evaluating Jacobian-vector products, which we can then use to build Jacobian matrices one column at a time, reverse-mode is a way to get back a function for evaluating vector-Jacobian products (equivalently Jacobian-transpose-vector products), which we can use to build Jacobian matrices one row at a time.

#### VJPs in math

Let’s again consider a function $$f : \mathbb{R}^n \to \mathbb{R}^m$$. Starting from our notation for JVPs, the notation for VJPs is pretty simple:

$$\qquad (x, v) \mapsto v \partial f(x)$$,

where $$v$$ is an element of the cotangent space of the codomain of $$f$$ at the point $$f(x)$$ (isomorphic to another copy of $$\mathbb{R}^m$$). When being rigorous, we should think of $$v$$ as a linear map $$v : \mathbb{R}^m \to \mathbb{R}$$, and when we write $$v \partial f(x)$$ we mean function composition $$v \circ \partial f(x)$$, where the types work out because $$\partial f(x) : \mathbb{R}^n \to \mathbb{R}^m$$. But in the common case we can identify $$v$$ with a vector in $$\mathbb{R}^m$$ and use the two almost interchangeably, just like we might sometimes flip between “column vectors” and “row vectors” without much comment.

With that identification, we can alternatively think of the linear part of a VJP as the transpose (or adjoint conjugate) of the linear part of a JVP:

$$\qquad (x, v) \mapsto \partial f(x)^\mathsf{T} v$$.

For a given point $$x$$, we can write the signature as

$$\qquad \partial f(x)^\mathsf{T} : \mathbb{R}^m \to \mathbb{R}^n$$.

The corresponding map on cotangent spaces is often called the pullback of $$f$$ at $$x$$. The key for our purposes is that it goes from something that looks like the output of $$f$$ to something that looks like the input of $$f$$, just like we might expect from a transposed linear function.
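Here is a quick numerical check of the transpose picture, using a hypothetical $$k : \mathbb{R}^3 \to \mathbb{R}^2$$. Pulling back a covector $$u \in \mathbb{R}^2$$ with `jax.vjp` should match $$\partial k(x)^\mathsf{T} u$$ computed from the explicit Jacobian. The result should live in $$\mathbb{R}^3$$, shaped like the input:

```python
import jax
import jax.numpy as jnp


# Hypothetical k : R^3 -> R^2, so the pullback maps R^2 -> R^3.
def k(x):
    return jnp.stack([x[0] * x[1], jnp.exp(x[2]) + x[0]])


x_pt = jnp.array([0.5, -1.0, 0.2])
u_ct = jnp.array([2.0, -3.0])       # a covector on the output side

out_k, pullback = jax.vjp(k, x_pt)
(u_pulled,) = pullback(u_ct)        # vjp returns one cotangent per primal input

J_k = jax.jacrev(k)(x_pt)           # explicit Jacobian, shape (2, 3)
print(u_pulled.shape)               # (3,), shaped like the input of k
print(jnp.allclose(u_pulled, J_k.T @ u_ct))
```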

#### VJPs in JAX code

Switching from math back to Python, the JAX function vjp can take a Python function for evaluating $$f$$ and give us back a Python function for evaluating the VJP $$(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$$.

[15]:

from jax import vjp

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

y, vjp_fun = vjp(f, W)

key, subkey = random.split(key)
u = random.normal(subkey, y.shape)

# Pull back the covector u along f evaluated at W
v = vjp_fun(u)


In terms of Haskell-like type signatures, we could write

vjp :: (a -> b) -> a -> (b, CT b -> CT a)


where we use CT a to denote the type for the cotangent space for a. In words, vjp takes as arguments a function of type a -> b and a point of type a, and gives back a pair consisting of a value of type b and a linear map of type CT b -> CT a.

This is great because it lets us build Jacobian matrices one row at a time, and the FLOP cost for evaluating $$(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$$ is only about three times the cost of evaluating $$f$$. In particular, if we want the gradient of a function $$f : \mathbb{R}^n \to \mathbb{R}$$, we can do it in just one call. That’s how grad is efficient for gradient-based optimization, even for objectives like neural network training loss functions on millions or billions of parameters.
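The following sketch illustrates the “one call” claim (`scalar_loss` and `w_vec` are hypothetical). For a scalar output, pulling back the single cotangent 1.0 through `jax.vjp` should reproduce `jax.grad`:

```python
import jax
import jax.numpy as jnp


# Hypothetical scalar-valued "loss" on R^n.
def scalar_loss(w):
    return jnp.sum(jnp.tanh(w) ** 2)


w_vec = jnp.linspace(-1.0, 1.0, 5)
loss_val, pullback = jax.vjp(scalar_loss, w_vec)

# The output space is 1-dimensional, so a single reverse pass with the
# covector 1.0 yields the entire gradient (one full row of the Jacobian).
(grad_via_vjp,) = pullback(jnp.ones_like(loss_val))
print(jnp.allclose(grad_via_vjp, jax.grad(scalar_loss)(w_vec)))
```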

There’s a cost, though: while the FLOPs are friendly, memory scales with the depth of the computation. Also, the implementation is traditionally more complex than that of forward-mode, though JAX has some tricks up its sleeve (that’s a story for a future notebook!).

For more on how reverse-mode works, see this tutorial video from the Deep Learning Summer School in 2017.

### Hessian-vector products using both forward- and reverse-mode

In a previous section, we implemented a Hessian-vector product function just using reverse-mode:

[16]:

def hvp(f, x, v):
    return grad(lambda x: np.vdot(grad(f)(x), v))(x)


That’s efficient, but we can do even better and save some memory by using forward-mode together with reverse-mode.

Mathematically, given a function $$f : \mathbb{R}^n \to \mathbb{R}$$ to differentiate, a point $$x \in \mathbb{R}^n$$ at which to linearize the function, and a vector $$v \in \mathbb{R}^n$$, the Hessian-vector product function we want is

$$(x, v) \mapsto \partial^2 f(x) v$$

Consider the helper function $$g : \mathbb{R}^n \to \mathbb{R}^n$$ defined to be the derivative (or gradient) of $$f$$, namely $$g(x) = \partial f(x)$$. All we need is its JVP, since that will give us

$$(x, v) \mapsto \partial g(x) v = \partial^2 f(x) v$$.

We can translate that almost directly into code:

[17]:

from jax import jvp, grad

# forward-over-reverse
def hvp(f, primals, tangents):
    return jvp(grad(f), primals, tangents)[1]


Even better, since we didn’t have to call np.dot directly, this hvp function works with arrays of any shape and with arbitrary container types (like vectors stored as nested lists/dicts/tuples), and doesn’t even have a dependence on jax.numpy.
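To illustrate the container point, this sketch applies the same forward-over-reverse recipe to a hypothetical objective over a dict of arrays. The recipe is redefined as `tree_hvp` so the block stands alone. The objective's Hessian is block-diagonal, so the expected product is easy to work out by hand: $$6 a \odot v_a$$ for `'a'` and $$2 v_b$$ for `'b'`:

```python
import jax
import jax.numpy as jnp


def tree_hvp(fun, primals, tangents):
    # Forward-over-reverse: push tangents forward through grad(fun).
    return jax.jvp(jax.grad(fun), primals, tangents)[1]


# Hypothetical objective over a dict of arrays. Its Hessian is
# block-diagonal: diag(6 * a) for 'a' and 2 * I for 'b'.
def dict_objective(params):
    return jnp.sum(params['a'] ** 3) + jnp.sum(params['b'] ** 2)


params = {'a': jnp.array([1.0, 2.0]),
          'b': jnp.array([[1.0, -1.0], [0.5, 3.0]])}
vecs = {'a': jnp.array([1.0, -1.0]),
        'b': jnp.ones((2, 2))}

hvp_out = tree_hvp(dict_objective, (params,), (vecs,))
print(hvp_out['a'])   # 6 * a * v_a = [6., -12.]
print(hvp_out['b'])   # 2 * v_b, i.e. all 2s
```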

Here’s an example of how to use it:

[18]:

def f(X):
    return np.sum(np.tanh(X)**2)

key, subkey1, subkey2 = random.split(key, 3)
X = random.normal(subkey1, (30, 40))
V = random.normal(subkey2, (30, 40))

ans1 = hvp(f, (X,), (V,))
ans2 = np.tensordot(hessian(f)(X), V, 2)

print(np.allclose(ans1, ans2, 1e-4, 1e-4))

True


Another way you might consider writing this is using reverse-over-forward:

[19]:

# reverse-over-forward
def hvp_revfwd(f, primals, tangents):
    g = lambda primals: jvp(f, primals, tangents)[1]
    return grad(g)(primals)


That’s not quite as good, though, because forward-mode has less overhead than reverse-mode, and since the outer differentiation operator here has to differentiate a larger computation than the inner one, keeping forward-mode on the outside works best:

[20]:

# reverse-over-reverse, only works for single arguments
def hvp_revrev(f, primals, tangents):
    x, = primals
    v, = tangents
    return grad(lambda x: np.vdot(grad(f)(x), v))(x)

print("Forward over reverse")
%timeit -n10 -r3 hvp(f, (X,), (V,))
print("Reverse over forward")
%timeit -n10 -r3 hvp_revfwd(f, (X,), (V,))
print("Reverse over reverse")
%timeit -n10 -r3 hvp_revrev(f, (X,), (V,))

print("Naive full Hessian materialization")
%timeit -n10 -r3 np.tensordot(hessian(f)(X), V, 2)

Forward over reverse
9.82 ms ± 232 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over forward
13.6 ms ± 2.47 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over reverse
16.6 ms ± 2.02 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Naive full Hessian materialization
27.3 ms ± 1.8 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)


## Composing VJPs, JVPs, and vmap

### Jacobian-Matrix and Matrix-Jacobian products

Now that we have jvp and vjp transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX’s vmap transformation (https://github.com/google/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products.

[21]:

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

# Pull back the covectors m_i along f, evaluated at W, for all i.
# First, use a list comprehension to loop over rows in the matrix M.
def loop_mjp(f, x, M):
    y, vjp_fun = vjp(f, x)
    # vjp_fun returns a 1-tuple (one cotangent per primal input), hence [0]
    return np.vstack([vjp_fun(mi)[0] for mi in M])

# Now, use vmap to build a computation that does a single fast matrix-matrix
# multiply, rather than an outer loop over vector-matrix multiplies.
def vmap_mjp(f, x, M):
    y, vjp_fun = vjp(f, x)
    outs, = vmap(vjp_fun)(M)
    return outs

key = random.PRNGKey(0)
num_covecs = 128
U = random.normal(key, (num_covecs,) + y.shape)

loop_vs = loop_mjp(f, W, M=U)
print('Non-vmapped Matrix-Jacobian product')
%timeit -n10 -r3 loop_mjp(f, W, M=U)

print('\nVmapped Matrix-Jacobian product')
vmap_vs = vmap_mjp(f, W, M=U)
%timeit -n10 -r3 vmap_mjp(f, W, M=U)

assert np.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'

Non-vmapped Matrix-Jacobian product
133 ms ± 1.74 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)

Vmapped Matrix-Jacobian product
5.53 ms ± 202 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)

[22]:

def loop_jmp(f, x, M):
    # jvp immediately returns the primal and tangent values as a tuple,
    # so we'll compute and select the tangents in a list comprehension
    return np.vstack([jvp(f, (x,), (mi,))[1] for mi in M])

def vmap_jmp(f, x, M):
    _jvp = lambda s: jvp(f, (x,), (s,))[1]
    return vmap(_jvp)(M)

num_vecs = 128
S = random.normal(key, (num_vecs,) + W.shape)

loop_vs = loop_jmp(f, W, M=S)
print('Non-vmapped Jacobian-Matrix product')
%timeit -n10 -r3 loop_jmp(f, W, M=S)
vmap_vs = vmap_jmp(f, W, M=S)
print('\nVmapped Jacobian-Matrix product')
%timeit -n10 -r3 vmap_jmp(f, W, M=S)

assert np.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical'

Non-vmapped Jacobian-Matrix product
441 ms ± 8.68 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)

Vmapped Jacobian-Matrix product
4.52 ms ± 132 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)


### The implementation of jacfwd and jacrev

Now that we’ve seen fast Jacobian-matrix and matrix-Jacobian products, it’s not hard to guess how to write jacfwd and jacrev. We just use the same technique to push-forward or pull-back an entire standard basis (isomorphic to an identity matrix) at once.

[23]:

from jax import jacrev as builtin_jacrev

def our_jacrev(f):
    def jacfun(x):
        y, vjp_fun = vjp(f, x)
        # Use vmap to do a matrix-Jacobian product.
        # Here, the matrix is the Euclidean basis, so we get all
        # entries in the Jacobian at once.
        J, = vmap(vjp_fun, in_axes=0)(np.eye(len(y)))
        return J
    return jacfun

assert np.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!'

[24]:

from jax import jacfwd as builtin_jacfwd

def our_jacfwd(f):
    def jacfun(x):
        _jvp = lambda s: jvp(f, (x,), (s,))[1]
        Jt = vmap(_jvp, in_axes=1)(np.eye(len(x)))
        return np.transpose(Jt)
    return jacfun

assert np.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!'


Interestingly, Autograd couldn’t do this. Our implementation of reverse-mode jacobian in Autograd (https://github.com/HIPS/autograd/blob/96a03f44da43cd7044c61ac945c483955deba957/autograd/differential_operators.py#L60) had to pull back one vector at a time with an outer-loop map. Pushing one vector at a time through the computation is much less efficient than batching it all together with vmap.

Another thing that Autograd couldn’t do is jit. Notably, no matter how much Python dynamism you use in the function being differentiated, we can always use jit on the linear part of the computation. For example:

[25]:

def f(x):
    try:
        if x < 3:
            return 2 * x ** 3
        else:
            raise ValueError
    except ValueError:
        return np.pi * x

y, f_vjp = vjp(f, 4.)
print(jit(f_vjp)(1.))

(DeviceArray(3.1415927, dtype=float32),)


## Complex numbers and differentiation

JAX is great at complex numbers and differentiation. To support both holomorphic and non-holomorphic differentiation, JAX follows Autograd’s convention for encoding complex derivatives.

Consider a complex-to-complex function $$f: \mathbb{C} \to \mathbb{C}$$ that we break down into its component real-to-real functions:

[26]:

def f(z):
    x, y = real(z), imag(z)
    return u(x, y) + v(x, y) * 1j


That is, we’ve decomposed $$f(z) = u(x, y) + v(x, y) i$$ where $$z = x + y i$$. We define grad(f) to correspond to

[27]:

def grad_f(z):
    x, y = real(z), imag(z)
    return grad(u, 0)(x, y) - grad(u, 1)(x, y) * 1j


In math symbols, that means we define $$\partial f(z) \triangleq \partial_0 u(x, y) - \partial_1 u(x, y) i$$. So we throw out $$v$$, ignoring the complex component function of $$f$$ entirely!

This convention covers three important cases:

1. If f evaluates a holomorphic function, then we get the usual complex derivative, since $$\partial_0 u = \partial_1 v$$ and $$\partial_1 u = - \partial_0 v$$.
2. If f evaluates the real-valued loss function of a complex parameter x, then we get a result that we can use in gradient-based optimization by taking steps in the direction of the conjugate of grad(f)(x).
3. If f evaluates a real-to-real function, but its implementation uses complex primitives internally (some of which must be non-holomorphic, e.g. FFTs used in convolutions) then we get the same result that an implementation that only used real primitives would have given.
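The sketch below covers cases 1 and 2. It assumes a recent JAX release, where, as far as we know, `jax.grad` requires real outputs unless `holomorphic=True` is passed. `sq_norm` is a hypothetical real-valued loss $$x^2 + y^2$$ of $$z = x + yi$$. The convention above predicts $$2x - 2yi$$ for it, and the conjugate $$2x + 2yi$$ is the usual ascent direction in the plane. For the holomorphic $$z^2$$, we expect the ordinary derivative $$2z$$:

```python
import jax
import jax.numpy as jnp


# Case 2: hypothetical real-valued loss of a complex parameter,
# sq_norm(z) = x**2 + y**2 with z = x + y i, so d_0 u = 2x and d_1 u = 2y.
def sq_norm(z):
    x, y = jnp.real(z), jnp.imag(z)
    return x ** 2 + y ** 2


z0 = 3.0 + 4.0j
g_c = jax.grad(sq_norm)(z0)
print(g_c)            # convention above: 2x - 2y i = 6 - 8i
print(jnp.conj(g_c))  # 6 + 8i, i.e. the ascent direction (2x, 2y)

# Case 1: the holomorphic z**2, whose complex derivative is 2z.
dz = jax.grad(lambda w: w ** 2, holomorphic=True)(z0)
print(dz)             # 2z = 6 + 8i
```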

By throwing away v entirely, this convention does not handle the case where f evaluates a non-holomorphic function and you want to evaluate all of $$\partial_0 u$$, $$\partial_1 u$$, $$\partial_0 v$$, and $$\partial_1 v$$ at once. But in that case the answer would have to contain four real values, and so there’s no way to express it as a single complex number.
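If you do need all four partials of a non-holomorphic function, one workaround is to treat $$\mathbb{C}$$ as $$\mathbb{R}^2$$ explicitly and take an ordinary real Jacobian. This is a sketch, not something the convention above provides. For the hypothetical $$f(z) = \bar{z}^2$$ we have $$u = x^2 - y^2$$ and $$v = -2xy$$, so the Jacobian at $$(x, y) = (3, 4)$$ should be [[6, -8], [-8, -6]]:

```python
import jax
import jax.numpy as jnp


# Hypothetical non-holomorphic f(z) = conj(z)**2, viewed as a map R^2 -> R^2.
# With z = x + y i: u = x**2 - y**2 and v = -2 x y.
def f_as_real(xy):
    z = xy[0] + 1j * xy[1]
    w = jnp.conj(z) * jnp.conj(z)
    return jnp.stack([jnp.real(w), jnp.imag(w)])


J_uv = jax.jacfwd(f_as_real)(jnp.array([3.0, 4.0]))
# Rows are (u, v), columns are (x, y):
# [[d_0 u, d_1 u], [d_0 v, d_1 v]] = [[2x, -2y], [-2y, -2x]]
print(J_uv)   # [[6, -8], [-8, -6]] at (x, y) = (3, 4)
```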

You should expect complex numbers to work everywhere in JAX. Here’s differentiating through a Cholesky decomposition of a complex matrix:

[28]:

A = np.array([[5.,    2.+3j,    5j],
              [2.-3j,   7.,  1.+7j],
              [-5j,  1.-7j,    12.]])

def f(X):
    L = np.linalg.cholesky(X)
    return np.sum((L - np.sin(L))**2)

grad(f, holomorphic=True)(A)



[28]:

DeviceArray([[-0.7534186  +0.j       , -3.0509028 -10.940544j ,
5.9896836  +3.5423026j],
[-3.0509028 +10.940544j , -8.9044895  +0.j       ,
-5.1351523  -6.559372j ],
[ 5.9896836  -3.5423026j, -5.1351523  +6.559372j ,
0.01320427 +0.j       ]], dtype=complex64)


For primitives’ JVP rules, writing the primals as $$z = a + bi$$ and the tangents as $$t = c + di$$, we define the Jacobian-vector product $$t \mapsto \partial f(z) \cdot t$$ as

$$t \mapsto \begin{matrix} \begin{bmatrix} 1 & 1 \end{bmatrix} \\ ~ \end{matrix} \begin{bmatrix} \partial_0 u(a, b) & -\partial_0 v(a, b) \\ - \partial_1 u(a, b) i & \partial_1 v(a, b) i \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix}$$.

See Chapter 4 of Dougal’s PhD thesis for more details.