JAX reference documentation¶
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more.
For an introduction to JAX, start at the JAX GitHub page.
JAX Quickstart¶
JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.
With its updated version of Autograd, JAX can automatically differentiate native Python and NumPy code. It can differentiate through a large subset of Python’s features, including loops, ifs, recursion, and closures, and it can even take derivatives of derivatives of derivatives. It supports reverse-mode as well as forward-mode differentiation, and the two can be composed arbitrarily to any order.
What’s new is that JAX uses XLA to compile and run your NumPy code on accelerators, like GPUs and TPUs. Compilation happens under the hood by default, with library calls getting just-in-time compiled and executed. But JAX even lets you just-in-time compile your own Python functions into XLA-optimized kernels using a one-function API. Compilation and automatic differentiation can be composed arbitrarily, so you can express sophisticated algorithms and get maximal performance without having to leave Python.
[1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
Multiplying Matrices¶
We’ll be generating random data in the following examples. One big difference between NumPy and JAX is how you generate random numbers. For more details, see Common Gotchas in JAX.
[2]:
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[0.372111 0.2642311 0.18252774 0.7368198 0.44030386 0.15214427
0.6713536 0.59086424 0.73168874 0.56730247]
Let’s dive right in and multiply two big matrices.
[3]:
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready() # runs on the GPU
474 ms ± 14.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
We added that block_until_ready because JAX uses asynchronous execution by default.
JAX NumPy functions work on regular NumPy arrays.
[4]:
import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()
484 ms ± 13.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
That’s slower because it has to transfer data to the GPU every time. You can ensure that an NDArray is backed by device memory using device_put.
[5]:
from jax import device_put
x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()
441 ms ± 9.61 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
The output of device_put still acts like an NDArray, but it only copies values back to the CPU when they’re needed for printing, plotting, saving to disk, branching, etc. The behavior of device_put is equivalent to the function jit(lambda x: x), but it’s faster.
If you have a GPU (or TPU!) these calls run on the accelerator and have the potential to be much faster than on CPU.
[6]:
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit np.dot(x, x.T)
224 ms ± 7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX is much more than just a GPU-backed NumPy. It also comes with a few program transformations that are useful when writing numerical code. For now, there are three main ones:

jit, for speeding up your code
grad, for taking derivatives
vmap, for automatic vectorization or batching

Let’s go over these, one by one. We’ll also end up composing these in interesting ways.
Using jit to speed up functions¶
JAX runs transparently on the GPU (or CPU, if you don’t have one, and TPU coming soon!). However, in the above example, JAX is dispatching kernels to the GPU one operation at a time. If we have a sequence of operations, we can use the @jit decorator to compile multiple operations together using XLA. Let’s try that.
[7]:
def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()
5.55 ms ± 68.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
We can speed it up with @jit, which will jit-compile the first time selu is called and will be cached thereafter.
[8]:
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()
991 µs ± 4.91 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Taking derivatives with grad¶
In addition to evaluating numerical functions, we also want to transform them. One transformation is automatic differentiation. In JAX, just like in Autograd, you can compute gradients with the grad function.
[9]:
def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))
x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
[0.25 0.19661197 0.10499357]
Let’s verify with finite differences that our result is correct.
[10]:
def first_finite_differences(f, x):
  eps = 1e-3
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                    for v in jnp.eye(len(x))])
print(first_finite_differences(sum_logistic, x_small))
[0.24998187 0.1964569 0.10502338]
Taking derivatives is as easy as calling grad. grad and jit compose and can be mixed arbitrarily. In the above example we jitted sum_logistic and then took its derivative. We can go further:
[11]:
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
-0.035325594
For more advanced autodiff, you can use jax.vjp for reverse-mode vector-Jacobian products and jax.jvp for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another, and with other JAX transformations. Here’s one way to compose them to make a function that efficiently computes full Hessian matrices:
[12]:
from jax import jacfwd, jacrev
def hessian(fun):
  return jit(jacfwd(jacrev(fun)))
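As a quick sanity check on this composition, here is a hedged, self-contained sketch (the cubic function g below is invented for illustration): the Hessian of \(\sum_i x_i^3\) is the diagonal matrix \(\mathrm{diag}(6x)\), and hessian recovers it.

```python
import jax.numpy as jnp
from jax import jit, jacfwd, jacrev

def hessian(fun):
  return jit(jacfwd(jacrev(fun)))

# Toy scalar function: g(x) = sum(x**3), with gradient 3*x**2
# and Hessian diag(6*x).
def g(x):
  return jnp.sum(x ** 3)

x = jnp.array([1.0, 2.0])
H = hessian(g)(x)
# H is the 2x2 matrix diag([6.0, 12.0]).
```

Here jacrev produces the gradient (efficient in reverse-mode for a scalar-valued function), and jacfwd then differentiates that vector-valued map efficiently in forward-mode.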
Auto-vectorization with vmap¶
JAX has one more transformation in its API that you might find useful: vmap, the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a function’s primitive operations for better performance. When composed with jit, it can be just as fast as adding the batch dimensions by hand.
We’re going to work with a simple example, and promote matrix-vector products into matrix-matrix products using vmap. Although this is easy to do by hand in this specific case, the same technique can apply to more complicated functions.
[13]:
mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))
def apply_matrix(v):
  return jnp.dot(mat, v)
Given a function such as apply_matrix, we can loop over a batch dimension in Python, but usually the performance of doing so is poor.
[14]:
def naively_batched_apply_matrix(v_batched):
  return jnp.stack([apply_matrix(v) for v in v_batched])
print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
Naively batched
4.71 ms ± 90.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
We know how to batch this operation manually. In this case, jnp.dot handles extra batch dimensions transparently.
[15]:
@jit
def batched_apply_matrix(v_batched):
  return jnp.dot(v_batched, mat.T)
print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()
Manually batched
96.3 µs ± 1.19 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
However, suppose we had a more complicated function without batching support. We can use vmap to add batching support automatically.
[16]:
@jit
def vmap_batched_apply_matrix(v_batched):
  return vmap(apply_matrix)(v_batched)
print('Autovectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
Autovectorized with vmap
115 µs ± 1.75 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Of course, vmap can be arbitrarily composed with jit, grad, and any other JAX transformation.
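One composition that comes up constantly in practice is per-example gradients: vmap(grad(f)) gives the gradient for every example in a batch in one vectorized call. A minimal sketch, assuming a made-up squared-error loss (not from the text above):

```python
import jax.numpy as jnp
from jax import grad, jit, vmap

# Hypothetical per-example loss: squared error of a linear model.
def loss(w, x, y):
  return (jnp.dot(x, w) - y) ** 2

# grad(loss) differentiates with respect to w for a single example;
# vmap maps it over the batch axes of x and y while holding w fixed.
per_example_grads = jit(vmap(grad(loss), in_axes=(None, 0, 0)))

w = jnp.array([1.0, 2.0])
xs = jnp.ones((4, 2))
ys = jnp.zeros(4)
grads = per_example_grads(w, xs, ys)  # shape (4, 2): one gradient per example
```

The in_axes=(None, 0, 0) argument tells vmap not to batch over w, but to map over the leading axis of xs and ys.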
This is just a taste of what JAX can do. We’re really excited to see what you do with it!
The Autodiff Cookbook¶
alexbw@, mattjj@
JAX has a pretty general automatic differentiation system. In this notebook, we’ll go through a whole bunch of neat autodiff ideas that you can cherry pick for your own work, starting with the basics.
[1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
key = random.PRNGKey(0)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Gradients¶
Starting with grad¶
You can differentiate a function with grad:
[2]:
grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0))
0.070650816
grad takes a function and returns a function. If you have a Python function f that evaluates the mathematical function \(f\), then grad(f) is a Python function that evaluates the mathematical function \(\nabla f\). That means grad(f)(x) represents the value \(\nabla f(x)\).
Since grad operates on functions, you can apply it to its own output to differentiate as many times as you like:
[3]:
print(grad(grad(jnp.tanh))(2.0))
print(grad(grad(grad(jnp.tanh)))(2.0))
-0.13621867
0.25265405
Let’s look at computing gradients with grad in a linear logistic regression model. First, the setup:
[4]:
def sigmoid(x):
  return 0.5 * (jnp.tanh(x / 2) + 1)

# Outputs probability of a label being true.
def predict(W, b, inputs):
  return sigmoid(jnp.dot(inputs, W) + b)

# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12, 0.77],
                    [0.88, -1.08, 0.15],
                    [0.52, 0.06, -1.30],
                    [0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])

# Training loss is the negative log-likelihood of the training examples.
def loss(W, b):
  preds = predict(W, b, inputs)
  label_probs = preds * targets + (1 - preds) * (1 - targets)
  return -jnp.sum(jnp.log(label_probs))
# Initialize random model coefficients
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())
Use the grad function with its argnums argument to differentiate a function with respect to positional arguments.
[5]:
# Differentiate `loss` with respect to the first positional argument:
W_grad = grad(loss, argnums=0)(W, b)
print('W_grad', W_grad)
# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print('W_grad', W_grad)
# But we can choose different values too, and drop the keyword:
b_grad = grad(loss, 1)(W, b)
print('b_grad', b_grad)
# Including tuple values
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print('W_grad', W_grad)
print('b_grad', b_grad)
W_grad [-0.16965581 -0.8774648  -1.4901345 ]
W_grad [-0.16965581 -0.8774648  -1.4901345 ]
b_grad -0.29227245
W_grad [-0.16965581 -0.8774648  -1.4901345 ]
b_grad -0.29227245
This grad API has a direct correspondence to the excellent notation in Spivak’s classic Calculus on Manifolds (1965), also used in Sussman and Wisdom’s Structure and Interpretation of Classical Mechanics (2015) and their Functional Differential Geometry (2013). Both books are open-access. See in particular the “Prologue” section of Functional Differential Geometry for a defense of this notation.
Essentially, when using the argnums argument, if f is a Python function for evaluating the mathematical function \(f\), then the Python expression grad(f, i) evaluates to a Python function for evaluating \(\partial_i f\).
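For example, with a toy two-argument function (invented for illustration), grad(f, 1) evaluates \(\partial_1 f\):

```python
from jax import grad

# f(x, y) = x * y**2, so the partial derivatives are
# ∂₀f(x, y) = y**2 and ∂₁f(x, y) = 2*x*y.
f = lambda x, y: x * y ** 2

print(grad(f, 0)(3.0, 2.0))  # 4.0
print(grad(f, 1)(3.0, 2.0))  # 12.0
```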
Differentiating with respect to nested lists, tuples, and dicts¶
Differentiating with respect to standard Python containers just works, so use tuples, lists, and dicts (and arbitrary nesting) however you like.
[6]:
def loss2(params_dict):
  preds = predict(params_dict['W'], params_dict['b'], inputs)
  label_probs = preds * targets + (1 - preds) * (1 - targets)
  return -jnp.sum(jnp.log(label_probs))
print(grad(loss2)({'W': W, 'b': b}))
{'W': DeviceArray([-0.16965581, -0.8774648 , -1.4901345 ], dtype=float32), 'b': DeviceArray(-0.29227245, dtype=float32)}
You can register your own container types to work with not just grad but all the JAX transformations (jit, vmap, etc.).
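As a sketch of what that registration looks like, using jax.tree_util.register_pytree_node (the Point class below is invented for illustration):

```python
from jax import grad
from jax.tree_util import register_pytree_node

# A hypothetical container type.
class Point:
  def __init__(self, x, y):
    self.x, self.y = x, y

# Tell JAX how to flatten a Point into a tuple of children plus
# static auxiliary data, and how to rebuild one from those pieces.
register_pytree_node(
    Point,
    lambda pt: ((pt.x, pt.y), None),         # flatten
    lambda aux, children: Point(*children))  # unflatten

def dist_sq(pt):
  return pt.x ** 2 + pt.y ** 2

# grad now returns a Point whose fields hold the partial derivatives.
g = grad(dist_sq)(Point(1.0, 2.0))
# g.x is 2*x = 2.0 and g.y is 2*y = 4.0
```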
Evaluate a function and its gradient using value_and_grad¶
Another convenient function is value_and_grad for efficiently computing both a function’s value as well as its gradient’s value:
[7]:
from jax import value_and_grad
loss_value, Wb_grad = value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)
print('loss value', loss(W, b))
loss value 3.0519395
loss value 3.0519395
Checking against numerical differences¶
A great thing about derivatives is that they’re straightforward to check with finite differences:
[8]:
# Set a step size for finite differences calculations
eps = 1e-4

# Check b_grad with scalar finite differences
b_grad_numerical = (loss(W, b + eps / 2.) - loss(W, b - eps / 2.)) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))

# Check W_grad with finite differences in a random direction
key, subkey = random.split(key)
vec = random.normal(subkey, W.shape)
unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))
W_grad_numerical = (loss(W + eps / 2. * unitvec, b) - loss(W - eps / 2. * unitvec, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(W, b), unitvec))
b_grad_numerical -0.29325485
b_grad_autodiff -0.29227245
W_dirderiv_numerical -0.19788742
W_dirderiv_autodiff -0.19909072
JAX provides a simple convenience function that does essentially the same thing, but checks up to any order of differentiation that you like:
[9]:
from jax.test_util import check_grads
check_grads(loss, (W, b), order=2) # check up to 2nd order derivatives
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Hessian-vector products with grad-of-grad¶
One thing we can do with higher-order grad is build a Hessian-vector product function. (Later on we’ll write an even more efficient implementation that mixes both forward- and reverse-mode, but this one will use pure reverse-mode.)
A Hessian-vector product function can be useful in a truncated Newton Conjugate-Gradient algorithm for minimizing smooth convex functions, or for studying the curvature of neural network training objectives (e.g. 1, 2, 3, 4).
For a scalar-valued function \(f : \mathbb{R}^n \to \mathbb{R}\) with continuous second derivatives (so that the Hessian matrix is symmetric), the Hessian at a point \(x \in \mathbb{R}^n\) is written as \(\partial^2 f(x)\). A Hessian-vector product function is then able to evaluate
\(\qquad v \mapsto \partial^2 f(x) \cdot v\)
for any \(v \in \mathbb{R}^n\).
The trick is not to instantiate the full Hessian matrix: if \(n\) is large, perhaps in the millions or billions in the context of neural networks, then that might be impossible to store.
Luckily, grad already gives us a way to write an efficient Hessian-vector product function. We just have to use the identity
\(\qquad \partial^2 f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)\),
where \(g(x) = \partial f(x) \cdot v\) is a new scalar-valued function that dots the gradient of \(f\) at \(x\) with the vector \(v\). Notice that we’re only ever differentiating scalar-valued functions of vector-valued arguments, which is exactly where we know grad is efficient.
In JAX code, we can just write this:
[10]:
def hvp(f, x, v):
  return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
This example shows that you can freely use lexical closure, and JAX will never get perturbed or confused.
We’ll check this implementation a few cells down, once we see how to compute dense Hessian matrices. We’ll also write an even better version that uses both forward-mode and reverse-mode.
Jacobians and Hessians using jacfwd and jacrev¶
You can compute full Jacobian matrices using the jacfwd and jacrev functions:
[11]:
from jax import jacfwd, jacrev
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
J = jacfwd(f)(W)
print("jacfwd result, with shape", J.shape)
print(J)
J = jacrev(f)(W)
print("jacrev result, with shape", J.shape)
print(J)
jacfwd result, with shape (4, 3)
[[ 0.05981752 0.12883773 0.08857594]
[ 0.04015911 0.04928619 0.0068453 ]
[ 0.12188289 0.01406341 0.3047072 ]
[ 0.00140426 0.00472514 0.00263773]]
jacrev result, with shape (4, 3)
[[ 0.05981752 0.12883773 0.08857594]
[ 0.04015911 0.04928619 0.0068453 ]
[ 0.12188289 0.01406341 0.3047072 ]
[ 0.00140426 0.00472514 0.00263773]]
These two functions compute the same values (up to machine numerics), but differ in their implementation: jacfwd uses forward-mode automatic differentiation, which is more efficient for “tall” Jacobian matrices, while jacrev uses reverse-mode, which is more efficient for “wide” Jacobian matrices. For matrices that are near-square, jacfwd probably has an edge over jacrev.
You can also use jacfwd and jacrev with container types:
[12]:
def predict_dict(params, inputs):
  return predict(params['W'], params['b'], inputs)

J_dict = jacrev(predict_dict)({'W': W, 'b': b}, inputs)
for k, v in J_dict.items():
  print("Jacobian from {} to logits is".format(k))
  print(v)
Jacobian from W to logits is
[[ 0.05981752 0.12883773 0.08857594]
[ 0.04015911 0.04928619 0.0068453 ]
[ 0.12188289 0.01406341 0.3047072 ]
[ 0.00140426 0.00472514 0.00263773]]
Jacobian from b to logits is
[0.11503369 0.04563536 0.23439017 0.00189765]
For more details on forward- and reverse-mode, as well as how to implement jacfwd and jacrev as efficiently as possible, read on!
Using a composition of two of these functions gives us a way to compute dense Hessian matrices:
[13]:
def hessian(f):
  return jacfwd(jacrev(f))
H = hessian(f)(W)
print("hessian, with shape", H.shape)
print(H)
hessian, with shape (4, 3, 3)
[[[ 0.02285464 0.04922538 0.03384245]
[ 0.04922538 0.10602391 0.07289143]
[ 0.03384245 0.07289144 0.05011286]]
[[0.03195212 0.03921397 0.00544638]
[ 0.03921397 0.04812624 0.0066842 ]
[0.00544638 0.0066842 0.00092836]]
[[0.01583708 0.00182736 0.0395927 ]
[0.00182736 0.00021085 0.00456839]
[ 0.0395927 0.00456839 0.09898175]]
[[0.0010352 0.00348331 0.0019445 ]
[ 0.00348331 0.01172087 0.00654297]
[0.0019445 0.00654297 0.0036525 ]]]
This shape makes sense: if we start with a function \(f : \mathbb{R}^n \to \mathbb{R}^m\), then at a point \(x \in \mathbb{R}^n\) we expect to get the shapes
\(f(x) \in \mathbb{R}^m\), the value of \(f\) at \(x\),
\(\partial f(x) \in \mathbb{R}^{m \times n}\), the Jacobian matrix at \(x\),
\(\partial^2 f(x) \in \mathbb{R}^{m \times n \times n}\), the Hessian at \(x\),
and so on.
To implement hessian, we could have used jacfwd(jacrev(f)) or jacrev(jacfwd(f)) or any other composition of the two. But forward-over-reverse is typically the most efficient. That’s because in the inner Jacobian computation we’re often differentiating a function with a wide Jacobian (maybe like a loss function \(f : \mathbb{R}^n \to \mathbb{R}\)), while in the outer Jacobian computation we’re differentiating a function with a square Jacobian (since \(\nabla f : \mathbb{R}^n \to \mathbb{R}^n\)), which is where forward-mode wins out.
How it’s made: two foundational autodiff functions¶
Jacobian-Vector products (JVPs, aka forward-mode autodiff)¶
JAX includes efficient and general implementations of both forward- and reverse-mode automatic differentiation. The familiar grad function is built on reverse-mode, but to explain the difference between the two modes, and when each can be useful, we need a bit of math background.
JVPs in math¶
Mathematically, given a function \(f : \mathbb{R}^n \to \mathbb{R}^m\), the Jacobian of \(f\) evaluated at an input point \(x \in \mathbb{R}^n\), denoted \(\partial f(x)\), is often thought of as a matrix in \(\mathbb{R}^m \times \mathbb{R}^n\):
\(\qquad \partial f(x) \in \mathbb{R}^{m \times n}\).
But we can also think of \(\partial f(x)\) as a linear map, which maps the tangent space of the domain of \(f\) at the point \(x\) (which is just another copy of \(\mathbb{R}^n\)) to the tangent space of the codomain of \(f\) at the point \(f(x)\) (a copy of \(\mathbb{R}^m\)):
\(\qquad \partial f(x) : \mathbb{R}^n \to \mathbb{R}^m\).
This map is called the pushforward map of \(f\) at \(x\). The Jacobian matrix is just the matrix for this linear map in a standard basis.
If we don’t commit to one specific input point \(x\), then we can think of the function \(\partial f\) as first taking an input point and returning the Jacobian linear map at that input point:
\(\qquad \partial f : \mathbb{R}^n \to \mathbb{R}^n \to \mathbb{R}^m\).
In particular, we can uncurry things so that given input point \(x \in \mathbb{R}^n\) and a tangent vector \(v \in \mathbb{R}^n\), we get back an output tangent vector in \(\mathbb{R}^m\). We call that mapping, from \((x, v)\) pairs to output tangent vectors, the Jacobianvector product, and write it as
\(\qquad (x, v) \mapsto \partial f(x) v\)
JVPs in JAX code¶
Back in Python code, JAX’s jvp function models this transformation. Given a Python function that evaluates \(f\), JAX’s jvp is a way to get a Python function for evaluating \((x, v) \mapsto (f(x), \partial f(x) v)\).
[14]:
from jax import jvp
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
key, subkey = random.split(key)
v = random.normal(subkey, W.shape)
# Push forward the vector `v` along `f` evaluated at `W`
y, u = jvp(f, (W,), (v,))
In terms of Haskell-like type signatures, we could write

jvp :: (a -> b) -> a -> T a -> (b, T b)

where we use T a to denote the type of the tangent space for a. In words, jvp takes as arguments a function of type a -> b, a value of type a, and a tangent vector value of type T a. It gives back a pair consisting of a value of type b and an output tangent vector of type T b.
The jvp-transformed function is evaluated much like the original function, but paired up with each primal value of type a it pushes along tangent values of type T a. For each primitive numerical operation that the original function would have applied, the jvp-transformed function executes a “JVP rule” for that primitive that both evaluates the primitive on the primals and applies the primitive’s JVP at those primal values.

That evaluation strategy has some immediate implications about computational complexity: since we evaluate JVPs as we go, we don’t need to store anything for later, and so the memory cost is independent of the depth of the computation. In addition, the FLOP cost of the jvp-transformed function is about 3x the cost of just evaluating the function (one unit of work for evaluating the original function, for example sin(x); one unit for linearizing, like cos(x); and one unit for applying the linearized function to a vector, like cos_x * v). Put another way, for a fixed primal point \(x\), we can evaluate \(v \mapsto \partial f(x) \cdot v\) for about the same marginal cost as evaluating \(f\).
That memory complexity sounds pretty compelling! So why don’t we see forward-mode very often in machine learning?
To answer that, first think about how you could use a JVP to build a full Jacobian matrix. If we apply a JVP to a one-hot tangent vector, it reveals one column of the Jacobian matrix, corresponding to the nonzero entry we fed in. So we can build a full Jacobian one column at a time, and to get each column costs about the same as one function evaluation. That will be efficient for functions with “tall” Jacobians, but inefficient for “wide” Jacobians.
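That column-at-a-time recipe can be written down directly. A minimal sketch with a toy function (invented for illustration) from \(\mathbb{R}^2\) to \(\mathbb{R}^2\):

```python
import jax.numpy as jnp
from jax import jvp

# Toy function, invented for illustration.
def g(x):
  return jnp.array([x[0] * x[1], x[0] + x[1]])

x = jnp.array([2.0, 3.0])

# One JVP per one-hot tangent vector reveals one column of the Jacobian.
cols = [jvp(g, (x,), (e,))[1] for e in jnp.eye(len(x))]
J = jnp.stack(cols, axis=1)
# Column j of J is the derivative of g along the basis vector e_j:
# here J = [[3., 2.], [1., 1.]].
```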
If you’re doing gradient-based optimization in machine learning, you probably want to minimize a loss function from parameters in \(\mathbb{R}^n\) to a scalar loss value in \(\mathbb{R}\). That means the Jacobian of this function is a very wide matrix: \(\partial f(x) \in \mathbb{R}^{1 \times n}\), which we often identify with the gradient vector \(\nabla f(x) \in \mathbb{R}^n\). Building that matrix one column at a time, with each call taking a similar number of FLOPs to evaluating the original function, sure seems inefficient! In particular, for training neural networks, where \(f\) is a training loss function and \(n\) can be in the millions or billions, this approach just won’t scale.

To do better for functions like this, we just need to use reverse-mode.
Vector-Jacobian products (VJPs, aka reverse-mode autodiff)¶
Where forward-mode gives us back a function for evaluating Jacobian-vector products, which we can then use to build Jacobian matrices one column at a time, reverse-mode is a way to get back a function for evaluating vector-Jacobian products (equivalently Jacobian-transpose-vector products), which we can use to build Jacobian matrices one row at a time.
VJPs in math¶
Let’s again consider a function \(f : \mathbb{R}^n \to \mathbb{R}^m\). Starting from our notation for JVPs, the notation for VJPs is pretty simple:
\(\qquad (x, v) \mapsto v \partial f(x)\),
where \(v\) is an element of the cotangent space of \(f\) at \(x\) (isomorphic to another copy of \(\mathbb{R}^m\)). When being rigorous, we should think of \(v\) as a linear map \(v : \mathbb{R}^m \to \mathbb{R}\), and when we write \(v \partial f(x)\) we mean function composition \(v \circ \partial f(x)\), where the types work out because \(\partial f(x) : \mathbb{R}^n \to \mathbb{R}^m\). But in the common case we can identify \(v\) with a vector in \(\mathbb{R}^m\) and use the two almost interchangeably, just like we might sometimes flip between “column vectors” and “row vectors” without much comment.
With that identification, we can alternatively think of the linear part of a VJP as the transpose (or adjoint conjugate) of the linear part of a JVP:
\(\qquad (x, v) \mapsto \partial f(x)^\mathsf{T} v\).
For a given point \(x\), we can write the signature as
\(\qquad \partial f(x)^\mathsf{T} : \mathbb{R}^m \to \mathbb{R}^n\).
The corresponding map on cotangent spaces is often called the pullback of \(f\) at \(x\). The key for our purposes is that it goes from something that looks like the output of \(f\) to something that looks like the input of \(f\), just like we might expect from a transposed linear function.
VJPs in JAX code¶
Switching from math back to Python, the JAX function vjp can take a Python function for evaluating \(f\) and give us back a Python function for evaluating the VJP \((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))\).
[15]:
from jax import vjp
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
y, vjp_fun = vjp(f, W)
key, subkey = random.split(key)
u = random.normal(subkey, y.shape)
# Pull back the covector `u` along `f` evaluated at `W`
v = vjp_fun(u)
In terms of Haskell-like type signatures, we could write

vjp :: (a -> b) -> a -> (b, CT b -> CT a)

where we use CT a to denote the type for the cotangent space for a. In words, vjp takes as arguments a function of type a -> b and a point of type a, and gives back a pair consisting of a value of type b and a linear map of type CT b -> CT a.
This is great because it lets us build Jacobian matrices one row at a time, and the FLOP cost for evaluating \((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))\) is only about three times the cost of evaluating \(f\). In particular, if we want the gradient of a function \(f : \mathbb{R}^n \to \mathbb{R}\), we can do it in just one call. That’s how grad is efficient for gradient-based optimization, even for objectives like neural network training loss functions on millions or billions of parameters.
There’s a cost, though: though the FLOPs are friendly, memory scales with the depth of the computation. Also, the implementation is traditionally more complex than that of forwardmode, though JAX has some tricks up its sleeve (that’s a story for a future notebook!).
For more on how reversemode works, see this tutorial video from the Deep Learning Summer School in 2017.
Hessian-vector products using both forward- and reverse-mode¶
In a previous section, we implemented a Hessian-vector product function just using reverse-mode (assuming continuous second derivatives):
[16]:
def hvp(f, x, v):
  return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
That’s efficient, but we can do even better and save some memory by using forward-mode together with reverse-mode.
Mathematically, given a function \(f : \mathbb{R}^n \to \mathbb{R}\) to differentiate, a point \(x \in \mathbb{R}^n\) at which to linearize the function, and a vector \(v \in \mathbb{R}^n\), the Hessian-vector product function we want is
\((x, v) \mapsto \partial^2 f(x) v\)
Consider the helper function \(g : \mathbb{R}^n \to \mathbb{R}^n\) defined to be the derivative (or gradient) of \(f\), namely \(g(x) = \partial f(x)\). All we need is its JVP, since that will give us
\((x, v) \mapsto \partial g(x) v = \partial^2 f(x) v\).
We can translate that almost directly into code:
[17]:
from jax import jvp, grad

# forward-over-reverse
def hvp(f, primals, tangents):
  return jvp(grad(f), primals, tangents)[1]
Even better, since we didn’t have to call jnp.dot directly, this hvp function works with arrays of any shape and with arbitrary container types (like vectors stored as nested lists/dicts/tuples), and doesn’t even have a dependence on jax.numpy.
Here’s an example of how to use it:
[18]:
def f(X):
  return jnp.sum(jnp.tanh(X)**2)

key, subkey1, subkey2 = random.split(key, 3)
X = random.normal(subkey1, (30, 40))
V = random.normal(subkey2, (30, 40))

ans1 = hvp(f, (X,), (V,))
ans2 = jnp.tensordot(hessian(f)(X), V, 2)

print(jnp.allclose(ans1, ans2, 1e-4, 1e-4))
True
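To illustrate the point about container types, the same hvp works when the primal point is a dict of parameters. A hedged sketch (the loss function f_dict and its shapes are invented for illustration):

```python
import jax.numpy as jnp
from jax import jvp, grad

# The forward-over-reverse hvp from above.
def hvp(f, primals, tangents):
  return jvp(grad(f), primals, tangents)[1]

# A hypothetical loss over a dict of parameters.
def f_dict(params):
  return jnp.sum(jnp.tanh(params['w']) ** 2) + jnp.sum(params['b'] ** 2)

params = {'w': jnp.ones((3,)), 'b': jnp.zeros((2,))}
tangent = {'w': jnp.ones((3,)), 'b': jnp.ones((2,))}

out = hvp(f_dict, (params,), (tangent,))
# out mirrors the structure of params; e.g. out['b'] equals 2 * tangent['b'],
# since the Hessian of sum(b**2) is twice the identity.
```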
Another way you might consider writing this is using reverse-over-forward:
[19]:
# reverse-over-forward
def hvp_revfwd(f, primals, tangents):
  g = lambda primals: jvp(f, primals, tangents)[1]
  return grad(g)(primals)
That’s not quite as good, though, because forward-mode has less overhead than reverse-mode, and since the outer differentiation operator here has to differentiate a larger computation than the inner one, keeping forward-mode on the outside works best:
[20]:
# reverse-over-reverse, only works for single arguments
def hvp_revrev(f, primals, tangents):
  x, = primals
  v, = tangents
  return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)

print("Forward over reverse")
%timeit -n10 -r3 hvp(f, (X,), (V,))
print("Reverse over forward")
%timeit -n10 -r3 hvp_revfwd(f, (X,), (V,))
print("Reverse over reverse")
%timeit -n10 -r3 hvp_revrev(f, (X,), (V,))
print("Naive full Hessian materialization")
%timeit -n10 -r3 jnp.tensordot(hessian(f)(X), V, 2)
Forward over reverse
6.43 ms ± 338 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over forward
9.55 ms ± 2.55 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over reverse
12.2 ms ± 1.85 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Naive full Hessian materialization
53.9 ms ± 2.06 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Composing VJPs, JVPs, and vmap¶
Jacobian-Matrix and Matrix-Jacobian products¶
Now that we have jvp and vjp transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX’s vmap transformation to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products.
[21]:
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

# Pull back the covectors `m_i` along `f`, evaluated at `W`, for all `i`.
# First, use a list comprehension to loop over rows in the matrix M.
def loop_mjp(f, x, M):
  y, vjp_fun = vjp(f, x)
  return jnp.vstack([vjp_fun(mi) for mi in M])

# Now, use vmap to build a computation that does a single fast matrix-matrix
# multiply, rather than an outer loop over vector-matrix multiplies.
def vmap_mjp(f, x, M):
  y, vjp_fun = vjp(f, x)
  outs, = vmap(vjp_fun)(M)
  return outs

key = random.PRNGKey(0)
num_covecs = 128
U = random.normal(key, (num_covecs,) + y.shape)

loop_vs = loop_mjp(f, W, M=U)
print('Non-vmapped Matrix-Jacobian product')
%timeit -n10 -r3 loop_mjp(f, W, M=U)

print('\nVmapped Matrix-Jacobian product')
vmap_vs = vmap_mjp(f, W, M=U)
%timeit -n10 -r3 vmap_mjp(f, W, M=U)

assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'
Nonvmapped MatrixJacobian product
159 ms ± 2.83 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Vmapped MatrixJacobian product
5.95 ms ± 199 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
[22]:
def loop_jmp(f, W, M):
    # jvp immediately returns the primal and tangent values as a tuple,
    # so we'll compute and select the tangents in a list comprehension
    return jnp.vstack([jvp(f, (W,), (mi,))[1] for mi in M])

def vmap_jmp(f, W, M):
    _jvp = lambda s: jvp(f, (W,), (s,))[1]
    return vmap(_jvp)(M)

num_vecs = 128
S = random.normal(key, (num_vecs,) + W.shape)

loop_vs = loop_jmp(f, W, M=S)
print('Non-vmapped Jacobian-Matrix product')
%timeit -n10 -r3 loop_jmp(f, W, M=S)

vmap_vs = vmap_jmp(f, W, M=S)
print('\nVmapped Jacobian-Matrix product')
%timeit -n10 -r3 vmap_jmp(f, W, M=S)

assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical'
Non-vmapped Jacobian-Matrix product
503 ms ± 12.2 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)

Vmapped Jacobian-Matrix product
5.04 ms ± 39.2 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
The implementation of jacfwd and jacrev¶
Now that we’ve seen fast Jacobian-matrix and matrix-Jacobian products, it’s not hard to guess how to write jacfwd and jacrev. We just use the same technique to push-forward or pull-back an entire standard basis (isomorphic to an identity matrix) at once.
[23]:
from jax import jacrev as builtin_jacrev

def our_jacrev(f):
    def jacfun(x):
        y, vjp_fun = vjp(f, x)
        # Use vmap to do a matrix-Jacobian product.
        # Here, the matrix is the Euclidean basis, so we get all
        # entries in the Jacobian at once.
        J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))
        return J
    return jacfun

assert jnp.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!'
[24]:
from jax import jacfwd as builtin_jacfwd

def our_jacfwd(f):
    def jacfun(x):
        _jvp = lambda s: jvp(f, (x,), (s,))[1]
        Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
        return jnp.transpose(Jt)
    return jacfun

assert jnp.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!'
Interestingly, Autograd couldn’t do this. Our implementation of reverse-mode jacobian in Autograd had to pull back one vector at a time with an outer-loop map. Pushing one vector at a time through the computation is much less efficient than batching it all together with vmap.
Another thing that Autograd couldn’t do is jit. Interestingly, no matter how much Python dynamism you use in your function to be differentiated, we could always use jit on the linear part of the computation. For example:
[25]:
def f(x):
    try:
        if x < 3:
            return 2 * x ** 3
        else:
            raise ValueError
    except ValueError:
        return jnp.pi * x

y, f_vjp = vjp(f, 4.)
print(jit(f_vjp)(1.))
(DeviceArray(3.1415927, dtype=float32),)
Complex numbers and differentiation¶
JAX is great at complex numbers and differentiation. To support both holomorphic and non-holomorphic differentiation, it helps to think in terms of JVPs and VJPs.
Consider a complex-to-complex function \(f: \mathbb{C} \to \mathbb{C}\) and identify it with a corresponding function \(g: \mathbb{R}^2 \to \mathbb{R}^2\),
[26]:
def f(z):
    x, y = jnp.real(z), jnp.imag(z)
    return u(x, y) + v(x, y) * 1j

def g(x, y):
    return (u(x, y), v(x, y))
That is, we’ve decomposed \(f(z) = u(x, y) + v(x, y) i\) where \(z = x + y i\), and identified \(\mathbb{C}\) with \(\mathbb{R}^2\) to get \(g\).
Since \(g\) only involves real inputs and outputs, we already know how to write a Jacobianvector product for it, say given a tangent vector \((c, d) \in \mathbb{R}^2\), namely
\(\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix}\).
To get a JVP for the original function \(f\) applied to a tangent vector \(c + di \in \mathbb{C}\), we just use the same definition and identify the result as another complex number,
\(\partial f(x + y i)(c + d i) = \begin{matrix} \begin{bmatrix} 1 & i \end{bmatrix} \\ ~ \end{matrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix}\).
That’s our definition of the JVP of a \(\mathbb{C} \to \mathbb{C}\) function! Notice it doesn’t matter whether or not \(f\) is holomorphic: the JVP is unambiguous.
Here’s a check:
[27]:
def check(seed):
    key = random.PRNGKey(seed)

    # random coeffs for u and v
    key, subkey = random.split(key)
    a, b, c, d = random.uniform(subkey, (4,))

    def fun(z):
        x, y = jnp.real(z), jnp.imag(z)
        return u(x, y) + v(x, y) * 1j

    def u(x, y):
        return a * x + b * y

    def v(x, y):
        return c * x + d * y

    # primal point
    key, subkey = random.split(key)
    x, y = random.uniform(subkey, (2,))
    z = x + y * 1j

    # tangent vector
    key, subkey = random.split(key)
    c, d = random.uniform(subkey, (2,))
    z_dot = c + d * 1j

    # check jvp
    _, ans = jvp(fun, (z,), (z_dot,))
    expected = (grad(u, 0)(x, y) * c +
                grad(u, 1)(x, y) * d +
                grad(v, 0)(x, y) * c * 1j +
                grad(v, 1)(x, y) * d * 1j)
    print(jnp.allclose(ans, expected))
[28]:
check(0)
check(1)
check(2)
True
True
True
What about VJPs? We do something pretty similar: for a cotangent vector \(c + di \in \mathbb{C}\) we define the VJP of \(f\) as
\((c + di)^* \; \partial f(x + y i) = \begin{matrix} \begin{bmatrix} c & -d \end{bmatrix} \\ ~ \end{matrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} 1 \\ -i \end{bmatrix}\).
What’s with the negatives? They’re just to take care of complex conjugation, and the fact that we’re working with covectors.
Here’s a check of the VJP rules:
[29]:
def check(seed):
    key = random.PRNGKey(seed)

    # random coeffs for u and v
    key, subkey = random.split(key)
    a, b, c, d = random.uniform(subkey, (4,))

    def fun(z):
        x, y = jnp.real(z), jnp.imag(z)
        return u(x, y) + v(x, y) * 1j

    def u(x, y):
        return a * x + b * y

    def v(x, y):
        return c * x + d * y

    # primal point
    key, subkey = random.split(key)
    x, y = random.uniform(subkey, (2,))
    z = x + y * 1j

    # cotangent vector
    key, subkey = random.split(key)
    c, d = random.uniform(subkey, (2,))
    z_bar = jnp.array(c + d * 1j)  # for dtype control

    # check vjp
    _, fun_vjp = vjp(fun, z)
    ans, = fun_vjp(z_bar)
    expected = (grad(u, 0)(x, y) * c +
                grad(v, 0)(x, y) * (-d) +
                grad(u, 1)(x, y) * c * (-1j) +
                grad(v, 1)(x, y) * (-d) * (-1j))
    assert jnp.allclose(ans, expected, atol=1e-5, rtol=1e-5)
[30]:
check(0)
check(1)
check(2)
What about convenience wrappers like grad, jacfwd, and jacrev?
For \(\mathbb{R} \to \mathbb{R}\) functions, recall we defined grad(f)(x) as being vjp(f, x)[1](1.0), which works because applying a VJP to a 1.0 value reveals the gradient (i.e. Jacobian, or derivative). We can do the same thing for \(\mathbb{C} \to \mathbb{R}\) functions: we can still use 1.0 as the cotangent vector, and we just get out a complex number result summarizing the full Jacobian:
[31]:
def f(z):
    x, y = jnp.real(z), jnp.imag(z)
    return x**2 + y**2

z = 3. + 4j
grad(f)(z)
[31]:
DeviceArray(6.-8.j, dtype=complex64)
For general \(\mathbb{C} \to \mathbb{C}\) functions, the Jacobian has 4 real-valued degrees of freedom (as in the 2x2 Jacobian matrices above), so we can’t hope to represent all of them within a complex number. But we can for holomorphic functions! A holomorphic function is precisely a \(\mathbb{C} \to \mathbb{C}\) function with the special property that its derivative can be represented as a single complex number. (The Cauchy-Riemann equations ensure that the above 2x2 Jacobians have the special form of a scale-and-rotate matrix in the complex plane, i.e. the action of a single complex number under multiplication.) And we can reveal that one complex number using a single call to vjp with a covector of 1.0.
Because this only works for holomorphic functions, to use this trick we need to promise JAX that our function is holomorphic; otherwise, JAX will raise an error when grad is used for a complex-output function:
[32]:
def f(z):
    return jnp.sin(z)

z = 3. + 4j
grad(f, holomorphic=True)(z)
[32]:
DeviceArray(-27.034946-3.8511534j, dtype=complex64)
All the holomorphic=True promise does is disable the error when the output is complex-valued. We can still write holomorphic=True when the function isn’t holomorphic, but the answer we get out won’t represent the full Jacobian. Instead, it’ll be the Jacobian of the function where we just discard the imaginary part of the output:
[33]:
def f(z):
    return jnp.conjugate(z)

z = 3. + 4j
grad(f, holomorphic=True)(z)  # f is not actually holomorphic!
[33]:
DeviceArray(1.-0.j, dtype=complex64)
There are some useful upshots for how grad works here:
We can use grad on holomorphic \(\mathbb{C} \to \mathbb{C}\) functions.
We can use grad to optimize \(f : \mathbb{C} \to \mathbb{R}\) functions, like real-valued loss functions of complex parameters x, by taking steps in the direction of the conjugate of grad(f)(x).
If we have an \(\mathbb{R} \to \mathbb{R}\) function that just happens to use some complex-valued operations internally (some of which must be non-holomorphic, e.g. FFTs used in convolutions) then grad still works and we get the same result that an implementation using only real values would have given.
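As a small sketch of that last point (our own toy example, not from the original notebook): a real-to-real function that round-trips through complex FFT operations internally still has an ordinary real gradient.

```python
import jax.numpy as jnp
from jax import grad

def f(x):
    # Round-trip through the frequency domain: complex values appear
    # internally, but the input and output are real.
    v = x * jnp.ones(4)
    return jnp.real(jnp.fft.ifft(jnp.fft.fft(v)))[0]

# The round-trip is the identity, so the derivative is 1.
print(grad(f)(2.0))
```

Taking the real part at the end is itself a non-holomorphic operation, yet grad composes through it without complaint.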
In any case, JVPs and VJPs are always unambiguous. And if we wanted to compute the full Jacobian matrix of a nonholomorphic \(\mathbb{C} \to \mathbb{C}\) function, we can do it with JVPs or VJPs!
You should expect complex numbers to work everywhere in JAX. Here’s differentiating through a Cholesky decomposition of a complex matrix:
[34]:
A = jnp.array([[5.,     2.+3j,   5j],
               [2.-3j,  7.,      1.+7j],
               [-5j,    1.-7j,   12.]])

def f(X):
    L = jnp.linalg.cholesky(X)
    return jnp.sum((L - jnp.sin(L))**2)

grad(f, holomorphic=True)(A)
[34]:
DeviceArray([[-0.7534186  +0.j       , -3.0509028 -10.940545j ,
               5.9896846  +3.542303j ],
             [-3.0509028 +10.940545j , -8.904491  +0.j        ,
              -5.1351523  -6.559373j ],
             [ 5.9896846  -3.542303j , -5.1351523  +6.559373j ,
               0.01320427 +0.j       ]], dtype=complex64)
More advanced autodiff¶
In this notebook, we worked through some easy, and then progressively more complicated, applications of automatic differentiation in JAX. We hope you now feel that taking derivatives in JAX is easy and powerful.
There’s a whole world of other autodiff tricks and functionality out there. Topics we didn’t cover, but hope to in an “Advanced Autodiff Cookbook”, include:
Gauss-Newton Vector Products, linearizing once
Custom VJPs and JVPs
Efficient derivatives at fixed-points
Estimating the trace of a Hessian using random Hessian-vector products.
Forward-mode autodiff using only reverse-mode autodiff.
Taking derivatives with respect to custom data types.
Checkpointing (binomial checkpointing for efficient reverse-mode, not model snapshotting).
Optimizing VJPs with Jacobian pre-accumulation.
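As a hedged sketch of one of those topics (Hutchinson's Hessian-trace estimator; the cubic function g below is made up purely for illustration), random Hessian-vector products can estimate a trace without ever materializing the Hessian:

```python
import jax.numpy as jnp
from jax import grad, jvp, random

def hvp(f, x, v):
    # Forward-over-reverse Hessian-vector product, as derived earlier.
    return jvp(grad(f), (x,), (v,))[1]

def g(x):
    return jnp.sum(x ** 3)

x = jnp.ones(3)
key = random.PRNGKey(0)
# Rademacher probe vectors: E[v^T H v] = trace(H).
vs = random.rademacher(key, (100, 3), dtype=jnp.float32)
est = jnp.mean(jnp.stack([v @ hvp(g, x, v) for v in vs]))
print(est)  # trace of H = diag(6 * x) is 18 here
```

For this diagonal Hessian every Rademacher probe gives the trace exactly; in general the mean converges at the usual Monte Carlo rate.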
Autobatching log-densities example¶
This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.
Inspired by a notebook by @davmre.
[1]:
import functools
import itertools
import re
import sys
import time
from matplotlib.pyplot import *
import jax
from jax import lax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import random
import numpy as np
import scipy as sp
Generate a fake binary classification dataset¶
[2]:
np.random.seed(10009)
num_features = 10
num_points = 100
true_beta = np.random.randn(num_features).astype(jnp.float32)
all_x = np.random.randn(num_points, num_features).astype(jnp.float32)
y = (np.random.rand(num_points) < sp.special.expit(all_x.dot(true_beta))).astype(jnp.int32)
[3]:
y
[3]:
array([0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0,
1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0,
1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1,
1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0], dtype=int32)
Write the log-joint function for the model¶
We’ll write a nonbatched version, a manually batched version, and an autobatched version.
Nonbatched¶
[4]:
def log_joint(beta):
    result = 0.
    # Note that no `axis` parameter is provided to `jnp.sum`.
    result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.))
    result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y - 1) * jnp.dot(all_x, beta))))
    return result
[5]:
log_joint(np.random.randn(num_features))
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[5]:
DeviceArray(-213.23558, dtype=float32)
[6]:
# This doesn't work, because we didn't write `log_joint()` to handle batching.
try:
    batch_size = 10
    batched_test_beta = np.random.randn(batch_size, num_features)

    log_joint(np.random.randn(batch_size, num_features))
except ValueError as e:
    print("Caught expected exception " + str(e))
Caught expected exception Incompatible shapes for broadcasting: ((100, 10), (1, 100))
Manually batched¶
[7]:
def batched_log_joint(beta):
    result = 0.
    # Here (and below) `sum` needs an `axis` parameter. At best, forgetting to set axis
    # or setting it incorrectly yields an error; at worst, it silently changes the
    # semantics of the model.
    result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.),
                              axis=-1)
    # Note the multiple transposes. Getting this right is not rocket science,
    # but it's also not totally mindless. (I didn't get it right on the first
    # try.)
    result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y - 1) * jnp.dot(all_x, beta.T).T)),
                              axis=-1)
    return result
[8]:
batch_size = 10
batched_test_beta = np.random.randn(batch_size, num_features)
batched_log_joint(batched_test_beta)
[8]:
DeviceArray([-147.84032, -207.02205, -109.26076, -243.80833, -163.02908,
             -143.84848, -160.28772, -113.7717 , -126.60544, -190.81989], dtype=float32)
Autobatched with vmap¶
It just works.
[9]:
vmap_batched_log_joint = jax.vmap(log_joint)
vmap_batched_log_joint(batched_test_beta)
[9]:
DeviceArray([-147.84032, -207.02205, -109.26076, -243.80833, -163.02908,
             -143.84848, -160.28772, -113.7717 , -126.60544, -190.81989], dtype=float32)
Self-contained variational inference example¶
A little code is copied from above.
Set up the (batched) log-joint function¶
[10]:
@jax.jit
def log_joint(beta):
    result = 0.
    # Note that no `axis` parameter is provided to `jnp.sum`.
    result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=10.))
    result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y - 1) * jnp.dot(all_x, beta))))
    return result

batched_log_joint = jax.jit(jax.vmap(log_joint))
Define the ELBO and its gradient¶
[11]:
def elbo(beta_loc, beta_log_scale, epsilon):
    beta_sample = beta_loc + jnp.exp(beta_log_scale) * epsilon
    return jnp.mean(batched_log_joint(beta_sample), 0) + jnp.sum(beta_log_scale - 0.5 * np.log(2*np.pi))

elbo = jax.jit(elbo)
elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))
Optimize the ELBO using SGD¶
[12]:
def normal_sample(key, shape):
    """Convenience function for quasi-stateful RNG."""
    new_key, sub_key = random.split(key)
    return new_key, random.normal(sub_key, shape)

normal_sample = jax.jit(normal_sample, static_argnums=(1,))

key = random.PRNGKey(10003)

beta_loc = jnp.zeros(num_features, jnp.float32)
beta_log_scale = jnp.zeros(num_features, jnp.float32)

step_size = 0.01
batch_size = 128
epsilon_shape = (batch_size, num_features)
for i in range(1000):
    key, epsilon = normal_sample(key, epsilon_shape)
    elbo_val, (beta_loc_grad, beta_log_scale_grad) = elbo_val_and_grad(
        beta_loc, beta_log_scale, epsilon)
    beta_loc += step_size * beta_loc_grad
    beta_log_scale += step_size * beta_log_scale_grad
    if i % 10 == 0:
        print('{}\t{}'.format(i, elbo_val))
0 -180.8538818359375
10 -113.06045532226562
20 -102.73725891113281
30 -99.787353515625
40 -98.90898132324219
50 -98.29745483398438
60 -98.18630981445312
70 -97.5797348022461
80 -97.28600311279297
90 -97.469970703125
100 -97.4771728515625
110 -97.58067321777344
120 -97.49435424804688
130 -97.50271606445312
140 -96.86395263671875
150 -97.44197082519531
160 -97.06939697265625
170 -96.84028625488281
180 -97.21336364746094
190 -97.56502532958984
200 -97.26398468017578
210 -97.11979675292969
220 -97.39593505859375
230 -97.16830444335938
240 -97.118408203125
250 -97.24345397949219
260 -97.2978744506836
270 -96.69285583496094
280 -96.9643783569336
290 -97.30055236816406
300 -96.63594055175781
310 -97.03518676757812
320 -97.52909851074219
330 -97.28812408447266
340 -97.0732192993164
350 -97.15620422363281
360 -97.25882720947266
370 -97.19515228271484
380 -97.13092041015625
390 -97.11727905273438
400 -96.93873596191406
410 -97.26676940917969
420 -97.35324096679688
430 -97.21007537841797
440 -97.28434753417969
450 -97.16310119628906
460 -97.2612533569336
470 -97.21343994140625
480 -97.23997497558594
490 -97.14913177490234
500 -97.23528289794922
510 -96.9342041015625
520 -97.21209716796875
530 -96.82577514648438
540 -97.01286315917969
550 -96.94176483154297
560 -97.16522216796875
570 -97.29165649414062
580 -97.42939758300781
590 -97.24371337890625
600 -97.15219116210938
610 -97.49844360351562
620 -96.99070739746094
630 -96.88957977294922
640 -96.89970397949219
650 -97.13794708251953
660 -97.43707275390625
670 -96.99235534667969
680 -97.15623474121094
690 -97.18690490722656
700 -97.11160278320312
710 -97.78105163574219
720 -97.23226165771484
730 -97.16206359863281
740 -96.99581909179688
750 -96.66722869873047
760 -97.16796112060547
770 -97.51435089111328
780 -97.28901672363281
790 -96.91226196289062
800 -97.1709976196289
810 -97.29047393798828
820 -97.16242980957031
830 -97.1910629272461
840 -97.56382751464844
850 -97.00193786621094
860 -96.86555480957031
870 -96.76337432861328
880 -96.83661651611328
890 -97.12179565429688
900 -97.09554290771484
910 -97.0682373046875
920 -97.11947631835938
930 -96.8792953491211
940 -97.45625305175781
950 -96.69280242919922
960 -97.29376220703125
970 -97.3353042602539
980 -97.34962463378906
990 -97.09674835205078
Display the results¶
Coverage isn’t quite as good as we might like, but it’s not bad, and nobody said variational inference was exact.
[13]:
figure(figsize=(7, 7))
plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')
plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label='Approximated Posterior $2\sigma$ Error Bars')
plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')
plot_scale = 3
plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')
xlabel('True beta')
ylabel('Estimated beta')
legend(loc='best')
[13]:
<matplotlib.legend.Legend at 0x7fe7687b2a10>
🔪 JAX - The Sharp Bits 🔪¶
levskaya@ mattjj@
When walking about the countryside of Italy, the people will not hesitate to tell you that JAX has “una anima di pura programmazione funzionale”.
JAX is a language for expressing and composing transformations of numerical programs. JAX is also able to compile numerical programs for CPU or accelerators (GPU/TPU). JAX works great for many numerical and scientific programs, but only if they are written with certain constraints that we describe below.
[1]:
import numpy as np
from jax import grad, jit
from jax import lax
from jax import random
import jax
import jax.numpy as jnp
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib import rcParams
rcParams['image.interpolation'] = 'nearest'
rcParams['image.cmap'] = 'viridis'
rcParams['axes.grid'] = False
🔪 Pure functions¶
JAX transformation and compilation are designed to work only on Python functions that are functionally pure: all the input data is passed through the function parameters, all the results are output through the function results. A pure function will always return the same result if invoked with the same inputs.
Here are some examples of functions that are not functionally pure, for which JAX behaves differently than the Python interpreter. Note that these behaviors are not guaranteed by the JAX system; the proper way to use JAX is to use it only on functionally pure Python functions.
[2]:
def impure_print_side_effect(x):
    print("Executing function")  # This is a side-effect
    return x

# The side-effects appear during the first run
print("First call: ", jit(impure_print_side_effect)(4.))

# Subsequent runs with parameters of same type and shape may not show the side-effect
# This is because JAX now invokes a cached compilation of the function
print("Second call: ", jit(impure_print_side_effect)(5.))

# JAX re-runs the Python function when the type or shape of the argument changes
print("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([5.])))
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Executing function
First call: 4.0
Second call: 5.0
Executing function
Third call, different type: [5.]
[3]:
g = 0.
def impure_uses_globals(x):
    return x + g

# JAX captures the value of the global during the first run
print("First call: ", jit(impure_uses_globals)(4.))
g = 10.  # Update the global

# Subsequent runs may silently use the cached value of the globals
print("Second call: ", jit(impure_uses_globals)(5.))

# JAX re-runs the Python function when the type or shape of the argument changes
# This will end up reading the latest value of the global
print("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))
First call: 4.0
Second call: 5.0
Third call, different type: [14.]
[4]:
g = 0.
def impure_saves_global(x):
    global g
    g = x
    return x

# JAX runs once the transformed function with special Traced values for arguments
print("First call: ", jit(impure_saves_global)(4.))
print("Saved global: ", g)  # Saved global has an internal JAX value
First call: 4.0
Saved global: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
A Python function can be functionally pure even if it actually uses stateful objects internally, as long as it does not read or write external state:
[5]:
def pure_uses_internal_state(x):
    state = dict(even=0, odd=0)
    for i in range(10):
        state['even' if i % 2 == 0 else 'odd'] += x
    return state['even'] + state['odd']

print(jit(pure_uses_internal_state)(5.))
50.0
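As an aside (our own sketch, not part of the original notebook): the idiomatic functional way to express a stateful loop like the one above is to thread the state through the computation explicitly, for example with lax.scan:

```python
import jax.numpy as jnp
from jax import jit, lax

def pure_scan_version(x):
    def body(carry, _):
        # Thread the running total through `carry` instead of mutating state.
        return carry + x, None
    total, _ = lax.scan(body, 0., jnp.arange(10))
    return total

print(jit(pure_scan_version)(5.))  # 50.0
```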
It is not recommended to use iterators in any JAX function you want to jit or in any control-flow primitive. The reason is that an iterator is a Python object that introduces state in order to retrieve the next element, and is therefore incompatible with JAX’s functional programming model. The code below includes some examples of incorrect attempts to use iterators with JAX. Most of them return an error, but some give unexpected results.
[6]:
import jax.numpy as jnp
import jax.lax as lax
from jax import make_jaxpr

# lax.fori_loop
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i, x: x + array[i], 0))  # expected result 45
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i, x: x + next(iterator), 0))  # unexpected result 0

# lax.scan
def func11(arr, extra):
    ones = jnp.ones(arr.shape)
    def body(carry, aelems):
        ae1, ae2 = aelems
        return (carry + ae1 * ae2 + extra, carry)
    return lax.scan(body, 0., (arr, ones))
make_jaxpr(func11)(jnp.arange(16), 5.)
# make_jaxpr(func11)(iter(range(16)), 5.)  # throws error

# lax.cond
array_operand = jnp.array([0.])
lax.cond(True, array_operand, lambda x: x + 1, array_operand, lambda x: x - 1)
iter_operand = iter(range(10))
# lax.cond(True, iter_operand, lambda x: next(x) + 1, iter_operand, lambda x: next(x) - 1)  # throws error
45
0
🔪 In-Place Updates¶
In Numpy you’re used to doing this:
[7]:
numpy_array = np.zeros((3,3), dtype=np.float32)
print("original array:")
print(numpy_array)
# In place, mutating update
numpy_array[1, :] = 1.0
print("updated array:")
print(numpy_array)
original array:
[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]
updated array:
[[0. 0. 0.]
[1. 1. 1.]
[0. 0. 0.]]
If we try to update a JAX device array in-place, however, we get an error! (☉_☉)
[8]:
jax_array = jnp.zeros((3,3), dtype=jnp.float32)

# In place update of JAX's array will yield an error!
try:
    jax_array[1, :] = 1.0
except Exception as e:
    print("Exception {}".format(e))
Exception '<class 'jax.interpreters.xla.DeviceArray'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?
What gives?!
Allowing mutation of variables in-place makes program analysis and transformation very difficult. JAX requires a pure functional expression of a numerical program.
Instead, JAX offers the functional update functions: index_update, index_add, index_min, index_max, and the index helper.
️⚠️ Inside jit’d code and lax.while_loop or lax.fori_loop, the size of slices can’t be a function of argument values but only of argument shapes – the slice start indices have no such restriction. See the Control Flow section below for more information on this limitation.
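A hedged sketch of that restriction (the function name here is our own): lax.dynamic_slice encodes exactly this rule, in that the start indices may be traced values while the slice sizes must be static Python constants.

```python
import jax.numpy as jnp
from jax import jit, lax

@jit
def take_window(x, start):
    # OK: `start` may be a traced value, but the size (3,) is a static constant.
    return lax.dynamic_slice(x, (start,), (3,))

print(take_window(jnp.arange(10), 4))  # [4 5 6]
```

Making the size depend on an argument value instead would fail to trace, for the same reason as the if x < 3 example in the Control Flow section.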
[9]:
from jax.ops import index, index_add, index_update
index_update¶
If the input values of index_update aren’t reused, jit-compiled code will perform these operations in-place.
[10]:
jax_array = jnp.zeros((3, 3))
print("original array:")
print(jax_array)
new_jax_array = index_update(jax_array, index[1, :], 1.)
print("old array unchanged:")
print(jax_array)
print("new array:")
print(new_jax_array)
original array:
[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]
old array unchanged:
[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]
new array:
[[0. 0. 0.]
[1. 1. 1.]
[0. 0. 0.]]
index_add¶
If the input values of index_add aren’t reused, jit-compiled code will perform these operations in-place.
[11]:
print("original array:")
jax_array = jnp.ones((5, 6))
print(jax_array)
new_jax_array = index_add(jax_array, index[::2, 3:], 7.)
print("new array postaddition:")
print(new_jax_array)
original array:
[[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]]
new array postaddition:
[[1. 1. 1. 8. 8. 8.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 8. 8. 8.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 8. 8. 8.]]
🔪 Out-of-Bounds Indexing¶
In Numpy, you are used to errors being thrown when you index an array outside of its bounds, like this:
[12]:
try:
    np.arange(10)[11]
except Exception as e:
    print("Exception {}".format(e))
Exception index 11 is out of bounds for axis 0 with size 10
However, raising an error on other accelerators can be more difficult. Therefore, JAX does not raise an error; instead, the index is clamped to the bounds of the array, meaning that for this example the last value of the array will be returned.
[13]:
jnp.arange(10)[11]
[13]:
DeviceArray(9, dtype=int32)
Note that due to this behavior jnp.nanargmin and jnp.nanargmax return -1 for slices consisting of NaNs, whereas Numpy would throw an error.
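For instance (a small example of our own, contrasting the two libraries):

```python
import numpy as np
import jax.numpy as jnp

# JAX returns -1 for an all-NaN slice, per the behavior described above.
print(jnp.nanargmin(jnp.array([jnp.nan, jnp.nan])))
# NumPy raises instead.
try:
    np.nanargmin(np.array([np.nan, np.nan]))
except ValueError as e:
    print("NumPy raises:", e)
```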
🔪 Random Numbers¶
If all scientific papers whose results are in doubt because of bad ``rand()``s were to disappear from library shelves, there would be a gap on each shelf about as big as your fist.  Numerical Recipes
RNGs and State¶
You’re used to stateful pseudorandom number generators (PRNGs) from numpy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness:
[14]:
print(np.random.random())
print(np.random.random())
print(np.random.random())
0.22201696529622084
0.11493328212975451
0.02865958507133981
Underneath the hood, numpy uses the Mersenne Twister PRNG to power its pseudorandom functions. The PRNG has a period of \(2^{19937}-1\) and at any point can be described by 624 32-bit unsigned ints and a position indicating how much of this “entropy” has been used up.
[15]:
np.random.seed(0)
rng_state = np.random.get_state()
#print(rng_state)
# > ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,
# 2481403966, 4042607538, 337614300, ... 614 more numbers...,
# 3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)
This pseudorandom state vector is automagically updated behind the scenes every time a random number is needed, “consuming” 2 of the uint32s in the Mersenne twister state vector:
[16]:
_ = np.random.uniform()
rng_state = np.random.get_state()
#print(rng_state)
# > ('MT19937', array([2443250962, 1093594115, 1878467924,
# ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)
# Let's exhaust the entropy in this PRNG state vector
for i in range(311):
    _ = np.random.uniform()
rng_state = np.random.get_state()
#print(rng_state)
# > ('MT19937', array([2443250962, 1093594115, 1878467924,
# ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)
# Next call iterates the RNG state for a new batch of fake "entropy".
_ = np.random.uniform()
rng_state = np.random.get_state()
# print(rng_state)
# > ('MT19937', array([1499117434, 2949980591, 2242547484,
# 4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)
The problem with magic PRNG state is that it’s hard to reason about how it’s being used and updated across different threads, processes, and devices, and it’s very easy to screw up when the details of entropy production and consumption are hidden from the end user.
The Mersenne Twister PRNG is also known to have a number of problems: it has a large 2.5kB state size, which leads to problematic initialization issues; it fails modern BigCrush tests; and it is generally slow.
JAX PRNG¶
JAX instead implements an explicit PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern Threefry counter-based PRNG that’s splittable. That is, its design allows us to fork the PRNG state into new PRNGs for use with parallel stochastic generation.
The random state is described by two unsigned 32-bit ints that we call a key:
[17]:
from jax import random
key = random.PRNGKey(0)
key
[17]:
DeviceArray([0, 0], dtype=uint32)
JAX’s random functions produce pseudorandom numbers from the PRNG state, but do not change the state!
Reusing the same state will cause sadness and monotony, depriving the end-user of life-giving chaos:
[18]:
print(random.normal(key, shape=(1,)))
print(key)
# No no no!
print(random.normal(key, shape=(1,)))
print(key)
[0.20584235]
[0 0]
[0.20584235]
[0 0]
Instead, we split the PRNG to get usable subkeys every time we need a new pseudorandom number:
[19]:
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print("    \---SPLIT --> new key   ", key)
print("             \--> new subkey", subkey, "--> normal", normal_pseudorandom)
old key [0 0]
    \---SPLIT --> new key    [4146024105  967050713]
             \--> new subkey [2718843009 1272950319] --> normal [1.2515389]
We propagate the key and make new subkeys whenever we need a new random number:
[20]:
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print("    \---SPLIT --> new key   ", key)
print("             \--> new subkey", subkey, "--> normal", normal_pseudorandom)
old key [4146024105  967050713]
    \---SPLIT --> new key    [2384771982 3928867769]
             \--> new subkey [1278412471 2182328957] --> normal [0.58665067]
We can generate more than one subkey at a time:
[21]:
key, *subkeys = random.split(key, 4)
for subkey in subkeys:
    print(random.normal(subkey, shape=(1,)))
[0.37533444]
[0.9864503]
[0.1455319]
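A related pattern (our own aside, not from the original notebook): a batch of split keys can also be consumed in parallel with vmap, one independent stream per key.

```python
import jax
from jax import random

key = random.PRNGKey(42)
subkeys = random.split(key, 3)  # three independent keys
# Draw one normal sample per key, batched with vmap.
samples = jax.vmap(lambda k: random.normal(k, shape=(1,)))(subkeys)
print(samples.shape)  # (3, 1)
```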
🔪 Control Flow¶
✔ python control_flow + autodiff ✔¶
If you just want to apply grad to your python functions, you can use regular python control-flow constructs with no problems, as if you were using Autograd (or Pytorch or TF Eager).
[22]:
def f(x):
    if x < 3:
        return 3. * x ** 2
    else:
        return 4 * x

print(grad(f)(2.))  # ok!
print(grad(f)(4.))  # ok!
12.0
4.0
python control flow + JIT¶
Using control flow with jit
is more complicated, and by default it has more constraints.
This works:
[23]:
@jit
def f(x):
    for i in range(3):
        x = 2 * x
    return x

print(f(3))
24
So does this:
[24]:
@jit
def g(x):
    y = 0.
    for i in range(x.shape[0]):
        y = y + x[i]
    return y

print(g(jnp.array([1., 2., 3.])))
6.0
But this doesn’t, at least by default:
[25]:
@jit
def f(x):
    if x < 3:
        return 3. * x ** 2
    else:
        return 4 * x

# This will fail!
try:
    f(2)
except Exception as e:
    print("Exception {}".format(e))
Exception Abstract tracer value encountered where concrete value is expected.
The problem arose with the `bool` function.
While tracing the function f at <ipythoninput1b42e45c0293f>:1, this concrete value was not available in Python because it depends on the value of the arguments to f at <ipythoninput1b42e45c0293f>:1 at flattened positions [0], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).
You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions, though at the cost of more recompiles.
See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.
Encountered tracer value: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>
What gives!?
When we jit
compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don’t have to recompile on each function evaluation.
For example, if we evaluate an @jit
function on the array jnp.array([1., 2., 3.], jnp.float32)
, we might want to compile code that we can reuse to evaluate the function on jnp.array([4., 5., 6.], jnp.float32)
to save on compile time.
To get a view of your Python code that is valid for many different argument values, JAX traces it on abstract values that represent sets of possible inputs. There are multiple different levels of abstraction, and different transformations use different abstraction levels.
By default, jit
traces your code on the ShapedArray
abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value ShapedArray((3,), jnp.float32)
, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.
But there’s a tradeoff here: if we trace a Python function on a ShapedArray((), jnp.float32)
that isn’t committed to a specific concrete value, when we hit a line like if x < 3
, the expression x < 3
evaluates to an abstract ShapedArray((), jnp.bool_)
that represents the set {True, False}
. When Python attempts to coerce that to a concrete True
or False
, we get an error: we don’t know which branch to take, and can’t continue tracing! The tradeoff is that with higher
levels of abstraction we gain a more general view of the Python code (and thus save on recompilations), but we require more constraints on the Python code to complete the trace.
The good news is that you can control this tradeoff yourself. By having jit
trace on more refined abstract values, you can relax the traceability constraints. For example, using the static_argnums
argument to jit
, we can specify to trace on concrete values of some arguments. Here’s that example function again:
[26]:
def f(x):
if x < 3:
return 3. * x ** 2
else:
return 4 * x
f = jit(f, static_argnums=(0,))
print(f(2.))
12.0
Here’s another example, this time involving a loop:
[27]:
def f(x, n):
y = 0.
for i in range(n):
y = y + x[i]
return y
f = jit(f, static_argnums=(1,))
f(jnp.array([2., 3., 4.]), 2)
[27]:
DeviceArray(5., dtype=float32)
In effect, the loop gets statically unrolled. JAX can also trace at higher levels of abstraction, like Unshaped
, but that's not currently the default for any transformation.
⚠️ functions with argument-value dependent shapes ⚠️
These controlflow issues also come up in a more subtle way: numerical functions we want to jit can’t specialize the shapes of internal arrays on argument values (specializing on argument shapes is ok). As a trivial example, let’s make a function whose output happens to depend on the input variable length
.
[28]:
def example_fun(length, val):
return jnp.ones((length,)) * val
# unjit'd works fine
print(example_fun(5, 4))
bad_example_jit = jit(example_fun)
# this will fail:
try:
print(bad_example_jit(10, 4))
except Exception as e:
print("Exception {}".format(e))
# static_argnums tells JAX to recompile on changes at these argument positions:
good_example_jit = jit(example_fun, static_argnums=(0,))
# first compile
print(good_example_jit(10, 4))
# recompiles
print(good_example_jit(5, 4))
[4. 4. 4. 4. 4.]
Exception Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]
[4. 4. 4. 4. 4.]
static_argnums
can be handy if length
in our example rarely changes, but it would be disastrous if it changed a lot!
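When both branches are cheap to evaluate, a common alternative to static_argnums (standard JAX practice, though not shown above) is to replace the Python branch with jnp.where, which stays fully traceable under jit:

```python
from jax import jit
import jax.numpy as jnp

@jit
def f(x):
    # both branches are computed, then selected elementwise by the predicate
    return jnp.where(x < 3, 3. * x ** 2, 4. * x)

print(f(2.))  # 12.0
print(f(4.))  # 16.0
```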
Lastly, if your function has global side-effects, JAX's tracer can cause weird things to happen. A common gotcha is trying to print arrays inside jit'd functions:
[29]:
@jit
def f(x):
print(x)
y = 2 * x
print(y)
return y
f(2)
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>
[29]:
DeviceArray(4, dtype=int32)
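If you do need to see runtime values from inside a jit'd function, newer versions of JAX (newer than the one this notebook was written against) provide jax.debug.print, which prints the concrete runtime value instead of the tracer:

```python
import jax
from jax import jit

@jit
def f(x):
    jax.debug.print("x = {x}", x=x)  # prints the runtime value, not a tracer
    return 2 * x

print(f(2))
```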
Structured control flow primitives¶
There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable, and that avoids un-rolling large loops. Then you can use these 4 structured control flow primitives:
lax.cond: differentiable
lax.while_loop: fwd-mode-differentiable
lax.fori_loop: fwd-mode-differentiable
lax.scan: differentiable
cond¶
python equivalent:
def cond(pred, true_operand, true_fun, false_operand, false_fun):
if pred:
return true_fun(true_operand)
else:
return false_fun(false_operand)
[30]:
from jax import lax
operand = jnp.array([0.])
lax.cond(True, operand, lambda x: x+1, operand, lambda x: x-1)
# --> array([1.], dtype=float32)
lax.cond(False, operand, lambda x: x+1, operand, lambda x: x-1)
# --> array([-1.], dtype=float32)
[30]:
DeviceArray([-1.], dtype=float32)
while_loop¶
python equivalent:
def while_loop(cond_fun, body_fun, init_val):
val = init_val
while cond_fun(val):
val = body_fun(val)
return val
[31]:
init_val = 0
cond_fun = lambda x: x<10
body_fun = lambda x: x+1
lax.while_loop(cond_fun, body_fun, init_val)
# --> array(10, dtype=int32)
[31]:
DeviceArray(10, dtype=int32)
fori_loop¶
python equivalent:
def fori_loop(start, stop, body_fun, init_val):
val = init_val
for i in range(start, stop):
val = body_fun(i, val)
return val
[32]:
init_val = 0
start = 0
stop = 10
body_fun = lambda i,x: x+i
lax.fori_loop(start, stop, body_fun, init_val)
# --> array(45, dtype=int32)
[32]:
DeviceArray(45, dtype=int32)
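scan¶
lax.scan appears in the list above but has no example yet; here's a minimal sketch (a cumulative sum, not from the original notebook) showing its carry/stacked-output structure:

```python
import jax.numpy as jnp
from jax import lax

def cumsum(xs):
    def step(carry, x):
        carry = carry + x
        return carry, carry  # (new carry, per-step output to stack)
    total, ys = lax.scan(step, 0., xs)
    return total, ys

total, ys = cumsum(jnp.array([1., 2., 3., 4.]))
print(total)  # 10.0
print(ys)     # the running sums [1., 3., 6., 10.]
```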
Summary¶
\(\ast\) = argument-value-independent loop condition: unrolls the loop
🔪 Convolutions¶
JAX and XLA offer the very general N-dimensional conv_general_dilated function, but it's not very obvious how to use it. We'll give some examples of the common use cases.
For the most common kinds of convolutions, see also the convenience functions lax.conv and lax.conv_with_general_padding, as well as jax.numpy.convolve and jax.scipy.signal.convolve/jax.scipy.signal.convolve2d for an interface similar to that of the numpy and scipy packages.
For a survey of the family of convolutional operators, A guide to convolution arithmetic is highly recommended reading!
Let’s define a simple diagonal edge kernel:
[33]:
# 2D kernel - HWIO layout
import numpy as np               # plain NumPy arrays are mutable, unlike jnp arrays
import matplotlib.pyplot as plt  # (both imports appear earlier in the full notebook)
kernel = np.zeros((3, 3, 3, 3), dtype=jnp.float32)
kernel += np.array([[1, 1, 0],
                    [1, 0,-1],
                    [0,-1,-1]])[:, :, np.newaxis, np.newaxis]
print("Edge Conv kernel:")
plt.imshow(kernel[:, :, 0, 0]);
Edge Conv kernel:
[33]:
<matplotlib.image.AxesImage at 0x7f73fc7d7b50>
And we’ll make a simple synthetic image:
[34]:
# NHWC layout
img = np.zeros((1, 200, 198, 3), dtype=jnp.float32)
for k in range(3):
x = 30 + 60*k
y = 20 + 60*k
img[0, x:x+10, y:y+10, k] = 1.0
print("Original Image:")
plt.imshow(img[0]);
Original Image:
[34]:
<matplotlib.image.AxesImage at 0x7f73fc6ced90>
lax.conv and lax.conv_with_general_padding¶
These are the simple convenience functions for convolutions.
⚠️ The convenience lax.conv
and lax.conv_with_general_padding
helper functions assume NCHW images and OIHW kernels.
[35]:
out = lax.conv(jnp.transpose(img,[0,3,1,2]), # lhs = NCHW image tensor
jnp.transpose(kernel,[3,2,0,1]), # rhs = OIHW conv kernel tensor
(1, 1), # window strides
'SAME') # padding mode
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(np.array(out)[0,0,:,:]);
out shape: (1, 3, 200, 198)
First output channel:
[35]:
<matplotlib.image.AxesImage at 0x7f73fc14dcd0>
[36]:
out = lax.conv_with_general_padding(
jnp.transpose(img,[0,3,1,2]), # lhs = NCHW image tensor
jnp.transpose(kernel,[2,3,0,1]), # rhs = IOHW conv kernel tensor
(1, 1), # window strides
((2,2),(2,2)), # general padding 2x2
(1,1), # lhs/image dilation
(1,1)) # rhs/kernel dilation
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(np.array(out)[0,0,:,:]);
out shape: (1, 3, 202, 200)
First output channel:
[36]:
<matplotlib.image.AxesImage at 0x7f73fc0bbad0>
Dimension Numbers define dimensional layout for conv_general_dilated¶
The important argument is the 3-tuple of axis layout arguments: (Input Layout, Kernel Layout, Output Layout)
N - batch dimension
H - spatial height
W - spatial width
C - channel dimension
I - kernel input channel dimension
O - kernel output channel dimension
⚠️ To demonstrate the flexibility of dimension numbers we choose a NHWC image and HWIO kernel convention for lax.conv_general_dilated
below.
[37]:
dn = lax.conv_dimension_numbers(img.shape, # only ndim matters, not shape
kernel.shape, # only ndim matters, not shape
('NHWC', 'HWIO', 'NHWC')) # the important bit
print(dn)
ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))
SAME padding, no stride, no dilation¶
[38]:
out = lax.conv_general_dilated(img, # lhs = image tensor
kernel, # rhs = conv kernel tensor
(1,1), # window strides
'SAME', # padding mode
(1,1), # lhs/image dilation
(1,1), # rhs/kernel dilation
dn) # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(np.array(out)[0,:,:,0]);
out shape: (1, 200, 198, 3)
First output channel:
[38]:
<matplotlib.image.AxesImage at 0x7f73fc04c450>
VALID padding, no stride, no dilation¶
[39]:
out = lax.conv_general_dilated(img, # lhs = image tensor
kernel, # rhs = conv kernel tensor
(1,1), # window strides
'VALID', # padding mode
(1,1), # lhs/image dilation
(1,1), # rhs/kernel dilation
dn) # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, "DIFFERENT from above!")
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(np.array(out)[0,:,:,0]);
out shape: (1, 198, 196, 3) DIFFERENT from above!
First output channel:
[39]:
<matplotlib.image.AxesImage at 0x7f73f47c8150>
SAME padding, 2,2 stride, no dilation¶
[40]:
out = lax.conv_general_dilated(img, # lhs = image tensor
kernel, # rhs = conv kernel tensor
(2,2), # window strides
'SAME', # padding mode
(1,1), # lhs/image dilation
(1,1), # rhs/kernel dilation
dn) # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, " < half the size of above")
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(np.array(out)[0,:,:,0]);
out shape: (1, 100, 99, 3) < half the size of above
First output channel:
[40]:
<matplotlib.image.AxesImage at 0x7f73fc7517d0>
VALID padding, no stride, rhs kernel dilation ~ Atrous convolution (excessive to illustrate)¶
[41]:
out = lax.conv_general_dilated(img, # lhs = image tensor
kernel, # rhs = conv kernel tensor
(1,1), # window strides
'VALID', # padding mode
(1,1), # lhs/image dilation
(12,12), # rhs/kernel dilation
dn) # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape)
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(np.array(out)[0,:,:,0]);
out shape: (1, 176, 174, 3)
First output channel:
[41]:
<matplotlib.image.AxesImage at 0x7f73f477e890>
VALID padding, no stride, lhs=input dilation ~ Transposed Convolution¶
[42]:
out = lax.conv_general_dilated(img, # lhs = image tensor
kernel, # rhs = conv kernel tensor
(1,1), # window strides
((0, 0), (0, 0)), # padding mode
(2,2), # lhs/image dilation
(1,1), # rhs/kernel dilation
dn) # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, "< larger than original!")
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(np.array(out)[0,:,:,0]);
out shape: (1, 397, 393, 3) < larger than original!
First output channel:
[42]:
<matplotlib.image.AxesImage at 0x7f73f4700810>
We can use the last to, for instance, implement transposed convolutions:
[43]:
# The following is equivalent to tensorflow:
# N,H,W,C = img.shape
# out = tf.nn.conv2d_transpose(img, kernel, (N,2*H,2*W,C), (1,2,2,1))
# transposed conv = 180deg kernel rotation plus LHS dilation
# rotate kernel 180deg:
kernel_rot = jnp.rot90(jnp.rot90(kernel, axes=(0,1)), axes=(0,1))
# need a custom output padding:
padding = ((2, 1), (2, 1))
out = lax.conv_general_dilated(img, # lhs = image tensor
kernel_rot, # rhs = conv kernel tensor
(1,1), # window strides
padding, # padding mode
(2,2), # lhs/image dilation
(1,1), # rhs/kernel dilation
dn) # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, "< transposed_conv")
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(np.array(out)[0,:,:,0]);
out shape: (1, 400, 396, 3) < transposed_conv
First output channel:
[43]:
<matplotlib.image.AxesImage at 0x7f73f4682f90>
1D Convolutions¶
You aren’t limited to 2D convolutions; a simple 1D demo is below:
[44]:
# 1D kernel - WIO layout
kernel = np.array([[[1, 0, -1], [-1, 0, 1]],
                   [[1, 1, 1], [-1, -1, -1]]],
                  dtype=jnp.float32).transpose([2,1,0])
# 1D data  NWC layout
data = np.zeros((1, 200, 2), dtype=jnp.float32)
for i in range(2):
for k in range(2):
x = 35*i + 30 + 60*k
data[0, x:x+30, k] = 1.0
print("in shapes:", data.shape, kernel.shape)
plt.figure(figsize=(10,5))
plt.plot(data[0]);
dn = lax.conv_dimension_numbers(data.shape, kernel.shape,
('NWC', 'WIO', 'NWC'))
print(dn)
out = lax.conv_general_dilated(data, # lhs = image tensor
kernel, # rhs = conv kernel tensor
(1,), # window strides
'SAME', # padding mode
(1,), # lhs/image dilation
(1,), # rhs/kernel dilation
dn) # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape)
plt.figure(figsize=(10,5))
plt.plot(out[0]);
in shapes: (1, 200, 2) (3, 2, 2)
ConvDimensionNumbers(lhs_spec=(0, 2, 1), rhs_spec=(2, 1, 0), out_spec=(0, 2, 1))
out shape: (1, 200, 2)
[44]:
[<matplotlib.lines.Line2D at 0x7f73f46cf3d0>,
<matplotlib.lines.Line2D at 0x7f73f45b99d0>]
3D Convolutions¶
[45]:
# Random 3D kernel  HWDIO layout
kernel = np.array([
[[0, 0, 0], [0, 1, 0], [0, 0, 0]],
[[0, 1, 0], [1, 0, -1], [0, -1, 0]],
[[0, 0, 0], [0, -1, 0], [0, 0, 0]]],
dtype=jnp.float32)[:, :, :, np.newaxis, np.newaxis]
# 3D data  NHWDC layout
data = np.zeros((1, 30, 30, 30, 1), dtype=jnp.float32)
x, y, z = np.mgrid[0:1:30j, 0:1:30j, 0:1:30j]
data += (np.sin(2*x*jnp.pi)*np.cos(2*y*jnp.pi)*np.cos(2*z*jnp.pi))[None,:,:,:,None]
print("in shapes:", data.shape, kernel.shape)
dn = lax.conv_dimension_numbers(data.shape, kernel.shape,
('NHWDC', 'HWDIO', 'NHWDC'))
print(dn)
out = lax.conv_general_dilated(data, # lhs = image tensor
kernel, # rhs = conv kernel tensor
(1,1,1), # window strides
'SAME', # padding mode
(1,1,1), # lhs/image dilation
(1,1,1), # rhs/kernel dilation
dn) # dimension_numbers
print("out shape: ", out.shape)
# Make some simple 3d density plots:
from mpl_toolkits.mplot3d import Axes3D
def make_alpha(cmap):
my_cmap = cmap(jnp.arange(cmap.N))
my_cmap[:,-1] = jnp.linspace(0, 1, cmap.N)**3
return mpl.colors.ListedColormap(my_cmap)
my_cmap = make_alpha(plt.cm.viridis)
fig = plt.figure()
ax = fig.gca(projection='3d')
ax.scatter(x.ravel(), y.ravel(), z.ravel(), c=data.ravel(), cmap=my_cmap)
ax.axis('off')
ax.set_title('input')
fig = plt.figure()
ax = fig.gca(projection='3d')
ax.scatter(x.ravel(), y.ravel(), z.ravel(), c=out.ravel(), cmap=my_cmap)
ax.axis('off')
ax.set_title('3D conv output');
in shapes: (1, 30, 30, 30, 1) (3, 3, 3, 1, 1)
ConvDimensionNumbers(lhs_spec=(0, 4, 1, 2, 3), rhs_spec=(4, 3, 0, 1, 2), out_spec=(0, 4, 1, 2, 3))
out shape: (1, 30, 30, 30, 1)
[45]:
Text(0.5, 0.92, '3D conv output')
🔪 NaNs¶
Debugging NaNs¶
If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by:
setting the JAX_DEBUG_NANS=True environment variable;
adding from jax.config import config and config.update("jax_debug_nans", True) near the top of your main file;
adding from jax.config import config and config.parse_flags_with_absl() to your main file, then set the option using a command-line flag like --jax_debug_nans=True;
This will cause computations to error out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an @jit
. For code under an @jit
, the output of every @jit
function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of @jit
at a time.
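As a self-contained sketch of switching the checker on programmatically (using the jax.config.update spelling, equivalent to option #2 above):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)  # same effect as JAX_DEBUG_NANS=True

try:
    jnp.divide(0., 0.)  # 0/0 produces a nan, so the checker raises immediately
except FloatingPointError as e:
    print("caught:", e)

jax.config.update("jax_debug_nans", False)  # remember to switch it off again
```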
There could be tricky situations that arise, like nans that only occur under a @jit
but don't get produced in de-optimized mode. In that case you'll see a warning message print out but your code will continue to execute.
If the nans are being produced in the backward pass of a gradient evaluation, when an exception is raised several frames up in the stack trace you will be in the backward_pass function, which is essentially a simple jaxpr interpreter that walks the sequence of primitive operations in reverse. In the example below, we started an ipython repl with the command line env JAX_DEBUG_NANS=True ipython
, then ran this:
In [1]: import jax.numpy as jnp
In [2]: jnp.divide(0., 0.)

FloatingPointError                        Traceback (most recent call last)
<ipython-input-2-f2e2c413b437> in <module>()
----> 1 jnp.divide(0., 0.)
.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
343 return floor_divide(x1, x2)
344 else:
> 345 return true_divide(x1, x2)
346
347
.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
332 x1, x2 = _promote_shapes(x1, x2)
333 return lax.div(lax.convert_element_type(x1, result_dtype),
> 334 lax.convert_element_type(x2, result_dtype))
335
336
.../jax/jax/lax.pyc in div(x, y)
244 def div(x, y):
245 r"""Elementwise division: :math:`x \over y`."""
> 246 return div_p.bind(x, y)
247
248 def rem(x, y):
... stack trace ...
.../jax/jax/interpreters/xla.pyc in handle_result(device_buffer)
103 py_val = device_buffer.to_py()
104 if np.any(np.isnan(py_val)):
> 105 raise FloatingPointError("invalid value")
106 else:
107 return DeviceArray(device_buffer, *result_shape)
FloatingPointError: invalid value
The nan generated was caught. By running %debug
, we can get a post-mortem debugger. This also works with functions under @jit
, as the example below shows.
In [4]: from jax import jit
In [5]: @jit
...: def f(x, y):
...: a = x * y
...: b = (x + y) / (x - y)
...: c = a + 2
...: return a + b * c
...:
In [6]: x = jnp.array([2., 0.])
In [7]: y = jnp.array([3., 0.])
In [8]: f(x, y)
Invalid value encountered in the output of a jit function. Calling the deoptimized version.

FloatingPointError                        Traceback (most recent call last)
<ipython-input-8-811b7ddb3300> in <module>()
----> 1 f(x, y)
... stack trace ...
<ipython-input-5-619b39acbaac> in f(x, y)
      2 def f(x, y):
      3     a = x * y
----> 4     b = (x + y) / (x - y)
      5     c = a + 2
      6     return a + b * c
.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
343 return floor_divide(x1, x2)
344 else:
> 345 return true_divide(x1, x2)
346
347
.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
332 x1, x2 = _promote_shapes(x1, x2)
333 return lax.div(lax.convert_element_type(x1, result_dtype),
> 334 lax.convert_element_type(x2, result_dtype))
335
336
.../jax/jax/lax.pyc in div(x, y)
244 def div(x, y):
245 r"""Elementwise division: :math:`x \over y`."""
> 246 return div_p.bind(x, y)
247
248 def rem(x, y):
... stack trace ...
When this code sees a nan in the output of an @jit
function, it calls into the de-optimized code, so we still get a clear stack trace. And we can run a post-mortem debugger with %debug
to inspect all the values to figure out the error.
⚠️ You shouldn’t have the NaN-checker on if you’re not debugging, as it can introduce lots of device-host round-trips and performance regressions!
Double (64bit) precision¶
At the moment, JAX by default enforces single-precision numbers to mitigate the NumPy API’s tendency to aggressively promote operands to double
. This is the desired behavior for many machine-learning applications, but it may catch you by surprise!
[46]:
x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
x.dtype
[46]:
dtype('float32')
To use doubleprecision numbers, you need to set the jax_enable_x64
configuration variable at startup.
There are a few ways to do this:
You can enable 64-bit mode by setting the environment variable JAX_ENABLE_X64=True.
You can manually set the jax_enable_x64 configuration flag at startup:
# again, this only works on startup!
from jax.config import config
config.update("jax_enable_x64", True)
You can parse command-line flags with
absl.app.run(main)
from jax.config import config
config.config_with_absl()
If you want JAX to run absl parsing for you, i.e. you don’t want to do
absl.app.run(main)
, you can instead use
from jax.config import config
if __name__ == '__main__':
# calls config.config_with_absl() *and* runs absl parsing
config.parse_flags_with_absl()
Note that #2-#4 work for any of JAX’s configuration options.
We can then confirm that x64
mode is enabled:
[47]:
import jax.numpy as jnp
from jax import random
x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
x.dtype # --> dtype('float64')
[47]:
dtype('float32')
Caveats¶
⚠️ XLA doesn’t support 64bit convolutions on all backends!
Fin.¶
If something’s not covered here that has caused you weeping and gnashing of teeth, please let us know and we’ll extend these introductory advisories!
Custom derivative rules for JAX-transformable Python functions¶
mattjj@ Mar 19 2020, last updated Oct 14 2020
There are two ways to define differentiation rules in JAX:
using jax.custom_jvp and jax.custom_vjp to define custom differentiation rules for Python functions that are already JAX-transformable; and
defining new core.Primitive instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems.
This notebook is about #1. To read instead about #2, see the notebook on adding primitives.
For an introduction to JAX’s automatic differentiation API, see The Autodiff Cookbook. This notebook assumes some familiarity with jax.jvp and jax.grad, and the mathematical meaning of JVPs and VJPs.
TL;DR¶
Custom JVPs with jax.custom_jvp
¶
[1]:
import jax.numpy as jnp
from jax import custom_jvp
@custom_jvp
def f(x, y):
return jnp.sin(x) * y
@f.defjvp
def f_jvp(primals, tangents):
x, y = primals
x_dot, y_dot = tangents
primal_out = f(x, y)
tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
return primal_out, tangent_out
[2]:
from jax import jvp, grad
print(f(2., 3.))
y, y_dot = jvp(f, (2., 3.), (1., 0.))
print(y)
print(y_dot)
print(grad(f)(2., 3.))
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
2.7278922
2.7278922
-1.2484405
-1.2484405
[3]:
# Equivalent alternative using the defjvps convenience wrapper
@custom_jvp
def f(x, y):
return jnp.sin(x) * y
f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)
[4]:
print(f(2., 3.))
y, y_dot = jvp(f, (2., 3.), (1., 0.))
print(y)
print(y_dot)
print(grad(f)(2., 3.))
2.7278922
2.7278922
-1.2484405
-1.2484405
Custom VJPs with jax.custom_vjp
¶
[5]:
from jax import custom_vjp
@custom_vjp
def f(x, y):
return jnp.sin(x) * y
def f_fwd(x, y):
# Returns primal output and residuals to be used in backward pass by f_bwd.
return f(x, y), (jnp.cos(x), jnp.sin(x), y)
def f_bwd(res, g):
cos_x, sin_x, y = res # Gets residuals computed in f_fwd
return (cos_x * g * y, sin_x * g)
f.defvjp(f_fwd, f_bwd)
[6]:
print(grad(f)(2., 3.))
-1.2484405
Example problems¶
To get an idea of what problems jax.custom_jvp
and jax.custom_vjp
are meant to solve, let’s go over a few examples. A more thorough introduction to the jax.custom_jvp
and jax.custom_vjp
APIs is in the next section.
Numerical stability¶
One application of jax.custom_jvp
is to improve the numerical stability of differentiation.
Say we want to write a function called log1pexp
, which computes \(x \mapsto \log ( 1 + e^x )\). We can write that using jax.numpy
:
[7]:
import jax.numpy as jnp
def log1pexp(x):
return jnp.log(1. + jnp.exp(x))
log1pexp(3.)
[7]:
DeviceArray(3.0485873, dtype=float32)
Since it’s written in terms of jax.numpy
, it’s JAX-transformable:
[8]:
from jax import jit, grad, vmap
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))
3.0485873
0.95257413
[0.5 0.7310586 0.8807971]
But there’s a numerical stability problem lurking here:
[9]:
print(grad(log1pexp)(100.))
nan
That doesn’t seem right! After all, the derivative of \(x \mapsto \log (1 + e^x)\) is \(x \mapsto \frac{e^x}{1 + e^x}\), and so for large values of \(x\) we’d expect the value to be about 1.
We can get a bit more insight into what’s going on by looking at the jaxpr for the gradient computation:
[10]:
from jax import make_jaxpr
make_jaxpr(grad(log1pexp))(100.)
[10]:
{ lambda ; a.
let b = exp a
c = add b 1.0
_ = log c
d = div 1.0 c
e = mul d b
in (e,) }
Stepping through how the jaxpr would be evaluated, we can see that the last line would involve multiplying values that floating point math will round to 0 and \(\infty\), respectively, which is never a good idea. That is, we’re effectively evaluating lambda x: (1 / (1 + jnp.exp(x))) * jnp.exp(x)
for large x
, which effectively turns into 0. * jnp.inf
.
Instead of generating such large and small values, hoping for a cancellation that floats can’t always provide, we’d rather just express the derivative function as a more numerically stable program. In particular, we can write a program that more closely evaluates the equal mathematical expression \(1 - \frac{1}{1 + e^x}\), with no cancellation in sight.
This problem is interesting because even though our definition of log1pexp
could already be JAXdifferentiated (and transformed with jit
, vmap
, …), we’re not happy with the result of applying standard autodiff rules to the primitives comprising log1pexp
and composing the result. Instead, we’d like to specify how the whole function log1pexp
should be differentiated, as a unit, and thus arrange those exponentials better.
This is one application of custom derivative rules for Python functions that are already JAX transformable: specifying how a composite function should be differentiated, while still using its original Python definition for other transformations (like jit
, vmap
, …).
Here’s a solution using jax.custom_jvp
:
[11]:
from jax import custom_jvp
@custom_jvp
def log1pexp(x):
return jnp.log(1. + jnp.exp(x))
@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
x, = primals
x_dot, = tangents
ans = log1pexp(x)
ans_dot = (1 - 1/(1 + jnp.exp(x))) * x_dot
return ans, ans_dot
[12]:
print(grad(log1pexp)(100.))
1.0
[13]:
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))
3.0485873
0.95257413
[0.5 0.7310586 0.8807971]
Here’s a defjvps
convenience wrapper to express the same thing:
[14]:
@custom_jvp
def log1pexp(x):
return jnp.log(1. + jnp.exp(x))
log1pexp.defjvps(lambda t, ans, x: (1 - 1/(1 + jnp.exp(x))) * t)
[15]:
print(grad(log1pexp)(100.))
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))
1.0
3.0485873
0.95257413
[0.5 0.7310586 0.8807971]
Enforcing a differentiation convention¶
A related application is to enforce a differentiation convention, perhaps at a boundary.
Consider the function \(f : \mathbb{R}_+ \mapsto \mathbb{R}_+\) with \(f(x) = \frac{x}{1 + \sqrt{x}}\), where we take \(\mathbb{R}_+ = [0, \infty)\). We might implement \(f\) as a program like this:
[16]:
def f(x):
return x / (1 + jnp.sqrt(x))
As a mathematical function on \(\mathbb{R}\) (the full real line), \(f\) is not differentiable at zero (because the limit defining the derivative doesn’t exist from the left). Correspondingly, autodiff produces a nan
value:
[17]:
print(grad(f)(0.))
nan
But mathematically if we think of \(f\) as a function on \(\mathbb{R}_+\) then it is differentiable at 0 [Rudin’s Principles of Mathematical Analysis Definition 5.1, or Tao’s Analysis I 3rd ed. Definition 10.1.1 and Example 10.1.6]. Alternatively, we might say as a convention we want to consider the directional derivative from the right. So there is a sensible value for the Python function grad(f)
to return at 0.0
, namely 1.0
. By default, JAX’s machinery for differentiation
assumes all functions are defined over \(\mathbb{R}\) and thus doesn’t produce 1.0
here.
We can use a custom JVP rule! In particular, we can define the JVP rule in terms of the derivative function \(x \mapsto \frac{\sqrt{x} + 2}{2(\sqrt{x} + 1)^2}\) on \(\mathbb{R}_+\):
[18]:
@custom_jvp
def f(x):
return x / (1 + jnp.sqrt(x))
@f.defjvp
def f_jvp(primals, tangents):
x, = primals
x_dot, = tangents
ans = f(x)
ans_dot = ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * x_dot
return ans, ans_dot
[19]:
print(grad(f)(0.))
1.0
Here’s the convenience wrapper version:
[20]:
@custom_jvp
def f(x):
return x / (1 + jnp.sqrt(x))
f.defjvps(lambda t, ans, x: ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * t)
[21]:
print(grad(f)(0.))
1.0
Gradient clipping¶
While in some cases we want to express a mathematical differentiation computation, in other cases we may even want to take a step away from mathematics to adjust the computation autodiff performs. One canonical example is reverse-mode gradient clipping.
For gradient clipping, we can use jnp.clip
together with a jax.custom_vjp
reverse-mode-only rule:
[22]:
from functools import partial
from jax import custom_vjp
@custom_vjp
def clip_gradient(lo, hi, x):
return x # identity function
def clip_gradient_fwd(lo, hi, x):
return x, (lo, hi) # save bounds as residuals
def clip_gradient_bwd(res, g):
lo, hi = res
return (None, None, jnp.clip(g, lo, hi)) # use None to indicate zero cotangents for lo and hi
clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
[23]:
import matplotlib.pyplot as plt
from jax import vmap
t = jnp.linspace(0, 10, 1000)
plt.plot(jnp.sin(t))
plt.plot(vmap(grad(jnp.sin))(t))
[23]:
[<matplotlib.lines.Line2D at 0x7fe954047310>]
[24]:
def clip_sin(x):
x = clip_gradient(-0.75, 0.75, x)
return jnp.sin(x)
plt.plot(clip_sin(t))
plt.plot(vmap(grad(clip_sin))(t))
[24]:
[<matplotlib.lines.Line2D at 0x7fe951f5cd10>]
Python debugging¶
Another application that is motivated by development workflow rather than numerics is to set a pdb
debugger trace in the backward pass of reverse-mode autodiff.
When trying to track down the source of a nan
runtime error, or just examine carefully the cotangent (gradient) values being propagated, it can be useful to insert a debugger at a point in the backward pass that corresponds to a specific point in the primal computation. You can do that with jax.custom_vjp
.
We’ll defer an example until the next section.
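As a minimal sketch of the idea (with a hypothetical helper name inspect; not the notebook's own deferred example): an identity function whose custom backward pass is a convenient place to print, or drop into pdb to examine, the cotangent:

```python
import jax.numpy as jnp
from jax import custom_vjp, grad

@custom_vjp
def inspect(x):
    return x  # identity in the primal computation

def inspect_fwd(x):
    return x, None  # no residuals needed

def inspect_bwd(_, g):
    # import pdb; pdb.set_trace()  # uncomment to examine g interactively
    print("cotangent:", g)
    return (g,)  # pass the cotangent through unchanged

inspect.defvjp(inspect_fwd, inspect_bwd)

print(grad(lambda x: jnp.sin(inspect(x)))(0.))  # cos(0) = 1.0
```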
Implicit function differentiation of iterative implementations¶
This example gets pretty deep in the mathematical weeds!
Another application for jax.custom_vjp
is reverse-mode differentiation of functions that are JAX-transformable (by jit
, vmap
, …) but not efficiently JAX-differentiable for some reason, perhaps because they involve lax.while_loop
. (It’s not possible to produce an XLA HLO program that efficiently computes the reverse-mode derivative of an XLA HLO While loop because that would require a program with unbounded memory use, which isn’t possible to express in XLA HLO, at least without side-effecting interactions through infeed/outfeed.)
For example, consider this fixed_point
routine which computes a fixed point by iteratively applying a function in a while_loop
:
[25]:
from jax.lax import while_loop

def fixed_point(f, a, x_guess):
  def cond_fun(carry):
    x_prev, x = carry
    return jnp.abs(x_prev - x) > 1e-6

  def body_fun(carry):
    _, x = carry
    return x, f(a, x)

  _, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
  return x_star
This is an iterative procedure for numerically solving the equation \(x = f(a, x)\) for \(x\), by iterating \(x_{t+1} = f(a, x_t)\) until \(x_{t+1}\) is sufficiently close to \(x_t\). The result \(x^*\) depends on the parameters \(a\), and so we can think of there being a function \(a \mapsto x^*(a)\) that is implicitly defined by the equation \(x = f(a, x)\).
We can use fixed_point
to run iterative procedures to convergence, for example running Newton’s method to calculate square roots while only executing adds, multiplies, and divides:
[26]:
def newton_sqrt(a):
  update = lambda a, x: 0.5 * (x + a / x)
  return fixed_point(update, a, a)
[27]:
print(newton_sqrt(2.))
1.4142135
We can vmap
or jit
the function as well:
[28]:
print(jit(vmap(newton_sqrt))(jnp.array([1., 2., 3., 4.])))
[1. 1.4142135 1.7320509 2. ]
We can't apply reverse-mode automatic differentiation because of the while_loop, but it turns out we wouldn't want to anyway: instead of differentiating through the implementation of fixed_point and all its iterations, we can exploit the mathematical structure to do something that is much more memory-efficient (and FLOP-efficient in this case, too!). We can instead use the implicit function theorem [Prop A.25 of Bertsekas's Nonlinear Programming, 2nd ed.], which guarantees (under some conditions) the existence of the mathematical objects we're about to use. In essence, we linearize at the solution and solve those linear equations iteratively to compute the derivatives we want.
Consider again the equation \(x = f(a, x)\) and the function \(x^*\). We want to evaluate vector-Jacobian products like \(v^\mathsf{T} \mapsto v^\mathsf{T} \partial x^*(a_0)\).
At least in an open neighborhood around the point \(a_0\) at which we want to differentiate, let’s assume that the equation \(x^*(a) = f(a, x^*(a))\) holds for all \(a\). Since the two sides are equal as functions of \(a\), their derivatives must be equal as well, so let’s differentiate both sides:
\(\qquad \partial x^*(a) = \partial_0 f(a, x^*(a)) + \partial_1 f(a, x^*(a)) \partial x^*(a)\).
Setting \(A = \partial_1 f(a_0, x^*(a_0))\) and \(B = \partial_0 f(a_0, x^*(a_0))\), we can write the quantity we’re after more simply as
\(\qquad \partial x^*(a_0) = B + A \partial x^*(a_0)\),
or, by rearranging,
\(\qquad \partial x^*(a_0) = (I - A)^{-1} B\).
That means we can evaluate vectorJacobian products like
\(\qquad v^\mathsf{T} \partial x^*(a_0) = v^\mathsf{T} (I - A)^{-1} B = w^\mathsf{T} B\),
where \(w^\mathsf{T} = v^\mathsf{T} (I - A)^{-1}\), or equivalently \(w^\mathsf{T} = v^\mathsf{T} + w^\mathsf{T} A\), or equivalently \(w^\mathsf{T}\) is the fixed point of the map \(u^\mathsf{T} \mapsto v^\mathsf{T} + u^\mathsf{T} A\). That last characterization gives us a way to write the VJP for fixed_point in terms of a call to fixed_point! Moreover, after expanding \(A\) and \(B\) back out, we can see we need only to evaluate VJPs of \(f\) at \((a_0, x^*(a_0))\).
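As a quick numerical sanity check of the formula \(\partial x^*(a_0) = (I - A)^{-1} B\), we can estimate \(A\) and \(B\) by finite differences for the Newton square-root update \(f(a, x) = \tfrac{1}{2}(x + a/x)\), whose fixed point is \(\sqrt{a}\), and compare against the known derivative of the square root. This is an illustrative sketch, not part of the original notebook:

```python
import math

# Newton square-root update: f(a, x) = 0.5 * (x + a / x); its fixed point is sqrt(a)
def f(a, x):
  return 0.5 * (x + a / x)

a0 = 2.0
x_star = math.sqrt(a0)
eps = 1e-6

# Finite-difference estimates of A = d1 f and B = d0 f at (a0, x*)
A = (f(a0, x_star + eps) - f(a0, x_star)) / eps
B = (f(a0 + eps, x_star) - f(a0, x_star)) / eps

# Scalar instance of the implicit-function-theorem formula
dxstar_da = B / (1 - A)
print(dxstar_da, 1 / (2 * math.sqrt(a0)))  # both approximately 0.3535
```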
Here’s the upshot:
[29]:
from functools import partial
from jax import custom_vjp, vjp

@partial(custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, a, x_guess):
  def cond_fun(carry):
    x_prev, x = carry
    return jnp.abs(x_prev - x) > 1e-6

  def body_fun(carry):
    _, x = carry
    return x, f(a, x)

  _, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
  return x_star

def fixed_point_fwd(f, a, x_init):
  x_star = fixed_point(f, a, x_init)
  return x_star, (a, x_star)

def fixed_point_rev(f, res, x_star_bar):
  a, x_star = res
  _, vjp_a = vjp(lambda a: f(a, x_star), a)
  a_bar, = vjp_a(fixed_point(partial(rev_iter, f),
                             (a, x_star, x_star_bar),
                             x_star_bar))
  return a_bar, jnp.zeros_like(x_star)

def rev_iter(f, packed, u):
  a, x_star, x_star_bar = packed
  _, vjp_x = vjp(lambda x: f(a, x), x_star)
  return x_star_bar + vjp_x(u)[0]

fixed_point.defvjp(fixed_point_fwd, fixed_point_rev)
[30]:
print(newton_sqrt(2.))
1.4142135
[31]:
print(grad(newton_sqrt)(2.))
print(grad(grad(newton_sqrt))(2.))
0.35355338
-0.088388346
We can check our answers by differentiating jnp.sqrt
, which uses a totally different implementation:
[32]:
print(grad(jnp.sqrt)(2.))
print(grad(grad(jnp.sqrt))(2.))
0.35355338
-0.08838835
A limitation to this approach is that the argument f can't close over any values involved in differentiation. That is, you might notice that we kept the parameter a explicit in the argument list of fixed_point. While other JAX mechanisms can handle closed-over transformation-traced values in the arguments to higher-order functions (as is done for the control flow primitives like lax.cond, lax.scan, and lax.while_loop itself), jax.custom_vjp used as above cannot. A fixed_point routine that used a bit more of JAX's internals could have a more convenient and robust API.
Basic usage of jax.custom_jvp and jax.custom_vjp APIs¶
Use jax.custom_jvp to define forward-mode (and, indirectly, reverse-mode) rules¶
Here's a canonical basic example of using jax.custom_jvp:
[33]:
from jax import custom_jvp
import jax.numpy as jnp

# f :: a -> b
@custom_jvp
def f(x):
  return jnp.sin(x)

# f_jvp :: (a, T a) -> (b, T b)
def f_jvp(primals, tangents):
  x, = primals
  t, = tangents
  return f(x), jnp.cos(x) * t

f.defjvp(f_jvp)
[34]:
from jax import jvp
print(f(3.))
y, y_dot = jvp(f, (3.,), (1.,))
print(y)
print(y_dot)
0.14112
0.14112
-0.9899925
In words, we start with a primal function f that takes inputs of type a and produces outputs of type b. We associate with it a JVP rule function f_jvp that takes a pair of inputs representing the primal inputs of type a and the corresponding tangent inputs of type T a, and produces a pair of outputs representing the primal outputs of type b and tangent outputs of type T b. The tangent outputs should be a linear function of the tangent inputs.
You can also use f.defjvp
as a decorator, as in
@custom_jvp
def f(x):
  ...

@f.defjvp
def f_jvp(primals, tangents):
  ...
Even though we defined only a JVP rule and no VJP rule, we can use both forward- and reverse-mode differentiation on f. JAX will automatically transpose the linear computation on tangent values from our custom JVP rule, computing the VJP as efficiently as if we had written the rule by hand:
[35]:
from jax import grad
print(grad(f)(3.))
print(grad(grad(f))(3.))
-0.9899925
-0.14112
For automatic transposition to work, the JVP rule’s output tangents must be linear as a function of the input tangents. Otherwise a transposition error is raised.
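For instance, a rule whose output tangent is quadratic in the input tangent cannot be transposed. The following illustrative sketch (not from the original notebook; the exact error type may vary by JAX version) shows reverse mode failing on such a rule, even though the rule definition itself is accepted:

```python
from jax import custom_jvp, grad
import jax.numpy as jnp

@custom_jvp
def g(x):
  return jnp.sin(x)

@g.defjvp
def g_jvp(primals, tangents):
  x, = primals
  t, = tangents
  return g(x), t ** 2  # quadratic in the tangent: not transposable

try:
  grad(g)(1.0)  # reverse mode must transpose the tangent computation
except Exception as e:
  print("transposition failed:", type(e).__name__)
```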
Multiple arguments work like this:
[36]:
@custom_jvp
def f(x, y):
  return x ** 2 * y

@f.defjvp
def f_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = f(x, y)
  tangent_out = 2 * x * y * x_dot + x ** 2 * y_dot
  return primal_out, tangent_out
[37]:
print(grad(f)(2., 3.))
12.0
The defjvps
convenience wrapper lets us define a JVP for each argument separately, and the results are computed separately then summed:
[38]:
@custom_jvp
def f(x):
  return jnp.sin(x)

f.defjvps(lambda t, ans, x: jnp.cos(x) * t)
[39]:
print(grad(f)(3.))
-0.9899925
Here’s a defjvps
example with multiple arguments:
[40]:
@custom_jvp
def f(x, y):
  return x ** 2 * y

f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot,
          lambda y_dot, primal_out, x, y: x ** 2 * y_dot)
[41]:
print(grad(f)(2., 3.))
print(grad(f, 0)(2., 3.)) # same as above
print(grad(f, 1)(2., 3.))
12.0
12.0
4.0
As a shorthand, with defjvps
you can pass a None
value to indicate that the JVP for a particular argument is zero:
[42]:
@custom_jvp
def f(x, y):
  return x ** 2 * y

f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot,
          None)
[43]:
print(grad(f)(2., 3.))
print(grad(f, 0)(2., 3.)) # same as above
print(grad(f, 1)(2., 3.))
12.0
12.0
0.0
Calling a jax.custom_jvp function with keyword arguments, or writing a jax.custom_jvp function definition with default arguments, are both allowed so long as they can be unambiguously mapped to positional arguments based on the function signature retrieved by the standard library inspect.signature mechanism.
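As an illustrative sketch (not from the original notebook, assuming a reasonably recent JAX), a custom_jvp function with a default argument can be called with that argument as a keyword, and it is resolved to a positional argument for the rule:

```python
from jax import custom_jvp, grad

@custom_jvp
def g(x, y=2.0):  # a default argument is allowed
  return x ** 2 * y

@g.defjvp
def g_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  return g(x, y), 2 * x * y * x_dot + x ** 2 * y_dot

print(grad(g)(3.0))         # y takes its default value: 2 * 3 * 2 = 12.0
print(grad(g)(3.0, y=4.0))  # keyword argument resolved positionally: 24.0
```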
When you’re not performing differentiation, the function f
is called just as if it weren’t decorated by jax.custom_jvp
:
[44]:
@custom_jvp
def f(x):
  print('called f!')  # a harmless side-effect
  return jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
  print('called f_jvp!')  # a harmless side-effect
  x, = primals
  t, = tangents
  return f(x), jnp.cos(x) * t
[45]:
from jax import vmap, jit
print(f(3.))
called f!
0.14112
[46]:
print(vmap(f)(jnp.arange(3.)))
print(jit(f)(3.))
called f!
[0. 0.84147096 0.9092974 ]
called f!
0.14112
The custom JVP rule is invoked during differentiation, whether forward or reverse:
[47]:
y, y_dot = jvp(f, (3.,), (1.,))
print(y_dot)
called f_jvp!
called f!
-0.9899925
[48]:
print(grad(f)(3.))
called f_jvp!
called f!
-0.9899925
Notice that f_jvp calls f to compute the primal outputs. In the context of higher-order differentiation, each application of a differentiation transform will use the custom JVP rule if and only if the rule calls the original f to compute the primal outputs. (This represents a kind of fundamental tradeoff, where we can't make use of intermediate values from the evaluation of f in our rule and also have the rule apply in all orders of higher-order differentiation.)
[49]:
grad(grad(f))(3.)
called f_jvp!
called f_jvp!
called f!
[49]:
DeviceArray(-0.14112, dtype=float32)
You can use Python control flow with jax.custom_jvp
:
[50]:
@custom_jvp
def f(x):
  if x > 0:
    return jnp.sin(x)
  else:
    return jnp.cos(x)

@f.defjvp
def f_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  ans = f(x)
  if x > 0:
    return ans, 2 * x_dot
  else:
    return ans, 3 * x_dot
[51]:
print(grad(f)(1.))
print(grad(f)(-1.))
2.0
3.0
Use jax.custom_vjp to define custom reverse-mode-only rules¶
While jax.custom_jvp suffices for controlling both forward- and, via JAX's automatic transposition, reverse-mode differentiation behavior, in some cases we may want to directly control a VJP rule, for example in the latter two example problems presented above. We can do that with jax.custom_vjp:
[52]:
from jax import custom_vjp
import jax.numpy as jnp

# f :: a -> b
@custom_vjp
def f(x):
  return jnp.sin(x)

# f_fwd :: a -> (b, c)
def f_fwd(x):
  return f(x), jnp.cos(x)

# f_bwd :: (c, CT b) -> CT a
def f_bwd(cos_x, y_bar):
  return (cos_x * y_bar,)

f.defvjp(f_fwd, f_bwd)
[53]:
from jax import grad
print(f(3.))
print(grad(f)(3.))
0.14112
-0.9899925
In words, we again start with a primal function f that takes inputs of type a and produces outputs of type b. We associate with it two functions, f_fwd and f_bwd, which describe how to perform the forward and backward passes of reverse-mode autodiff, respectively.
The function f_fwd
describes the forward pass, not only the primal computation but also what values to save for use on the backward pass. Its input signature is just like that of the primal function f
, in that it takes a primal input of type a
. But as output it produces a pair, where the first element is the primal output b
and the second element is any “residual” data of type c
to be stored for use by the backward pass. (This second output is analogous to PyTorch’s
save_for_backward mechanism.)
The function f_bwd
describes the backward pass. It takes two inputs, where the first is the residual data of type c
produced by f_fwd
and the second is the output cotangents of type CT b
corresponding to the output of the primal function. It produces an output of type CT a
representing the cotangents corresponding to the input of the primal function. In particular, the output of f_bwd
must be a sequence (e.g. a tuple) of length equal to the number of arguments to the
primal function.
So multiple arguments work like this:
[54]:
from jax import custom_vjp

@custom_vjp
def f(x, y):
  return jnp.sin(x) * y

def f_fwd(x, y):
  return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
  cos_x, sin_x, y = res
  return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)
[55]:
print(grad(f)(2., 3.))
-1.2484405
Calling a jax.custom_vjp function with keyword arguments, or writing a jax.custom_vjp function definition with default arguments, are both allowed so long as they can be unambiguously mapped to positional arguments based on the function signature retrieved by the standard library inspect.signature mechanism.
As with jax.custom_jvp, the custom VJP rule comprising f_fwd and f_bwd is not invoked if differentiation is not applied. If the function is evaluated, or transformed with jit, vmap, or other non-differentiation transformations, then only f is called.
[56]:
@custom_vjp
def f(x):
  print("called f!")
  return jnp.sin(x)

def f_fwd(x):
  print("called f_fwd!")
  return f(x), jnp.cos(x)

def f_bwd(cos_x, y_bar):
  print("called f_bwd!")
  return (cos_x * y_bar,)

f.defvjp(f_fwd, f_bwd)
[57]:
print(f(3.))
called f!
0.14112
[58]:
print(grad(f)(3.))
called f_fwd!
called f!
called f_bwd!
-0.9899925
[59]:
from jax import vjp
y, f_vjp = vjp(f, 3.)
print(y)
called f_fwd!
called f!
0.14112
[60]:
print(f_vjp(1.))
called f_bwd!
(DeviceArray(-0.9899925, dtype=float32),)
Forward-mode autodiff cannot be used on the jax.custom_vjp function and will raise an error:
[61]:
from jax import jvp
try:
  jvp(f, (3.,), (1.,))
except TypeError as e:
  print('ERROR! {}'.format(e))
called f_fwd!
called f!
ERROR! can't apply forward-mode autodiff (jvp) to a custom_vjp function. If you want to use both forward- and reverse-mode, use jax.custom_jvp instead.
We can use jax.custom_vjp
together with pdb
to insert a debugger trace in the backward pass:
[62]:
import pdb

@custom_vjp
def debug(x):
  return x  # acts like identity

def debug_fwd(x):
  return x, x

def debug_bwd(x, g):
  import pdb; pdb.set_trace()
  return g

debug.defvjp(debug_fwd, debug_bwd)
[63]:
def foo(x):
  y = x ** 2
  y = debug(y)  # insert pdb in corresponding backward pass step
  return jnp.sin(y)

grad(foo)(3.)
> <ipython-input-113-b19a2dc1abf7>(12)debug_bwd()
-> return g
(Pdb) p x
DeviceArray(9., dtype=float32)
(Pdb) p g
DeviceArray(-0.91113025, dtype=float32)
(Pdb) q
More features and details¶
Working with list
/ tuple
/ dict
containers (and other pytrees)¶
You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any pytrees are permissible, so long as their structures are consistent according to the type constraints.
Here’s a contrived example with jax.custom_jvp
:
[64]:
from collections import namedtuple
Point = namedtuple("Point", ["x", "y"])

@custom_jvp
def f(pt):
  x, y = pt.x, pt.y
  return {'a': x ** 2,
          'b': (jnp.sin(x), jnp.cos(y))}

@f.defjvp
def f_jvp(primals, tangents):
  pt, = primals
  pt_dot, = tangents
  ans = f(pt)
  ans_dot = {'a': 2 * pt.x * pt_dot.x,
             'b': (jnp.cos(pt.x) * pt_dot.x, -jnp.sin(pt.y) * pt_dot.y)}
  return ans, ans_dot

def fun(pt):
  dct = f(pt)
  return dct['a'] + dct['b'][0]
[65]:
pt = Point(1., 2.)
print(f(pt))
{'a': 1.0, 'b': (DeviceArray(0.84147096, dtype=float32), DeviceArray(0.41614684, dtype=float32))}
[66]:
print(grad(fun)(pt))
Point(x=DeviceArray(2.5403023, dtype=float32), y=DeviceArray(-0., dtype=float32))
And an analogous contrived example with jax.custom_vjp
:
[67]:
@custom_vjp
def f(pt):
  x, y = pt.x, pt.y
  return {'a': x ** 2,
          'b': (jnp.sin(x), jnp.cos(y))}

def f_fwd(pt):
  return f(pt), pt

def f_bwd(pt, g):
  a_bar, (b0_bar, b1_bar) = g['a'], g['b']
  x_bar = 2 * pt.x * a_bar + jnp.cos(pt.x) * b0_bar
  y_bar = -jnp.sin(pt.y) * b1_bar
  return (Point(x_bar, y_bar),)

f.defvjp(f_fwd, f_bwd)

def fun(pt):
  dct = f(pt)
  return dct['a'] + dct['b'][0]
[68]:
pt = Point(1., 2.)
print(f(pt))
{'a': 1.0, 'b': (DeviceArray(0.84147096, dtype=float32), DeviceArray(0.41614684, dtype=float32))}
[69]:
print(grad(fun)(pt))
Point(x=DeviceArray(2.5403023, dtype=float32), y=DeviceArray(-0., dtype=float32))
Handling non-differentiable arguments¶
Some use cases, like the final example problem, call for non-differentiable arguments like function-valued arguments to be passed to functions with custom differentiation rules, and for those arguments to also be passed to the rules themselves. In the case of fixed_point, the function argument f was such a non-differentiable argument. A similar situation arises with jax.experimental.odeint.
jax.custom_jvp
with nondiff_argnums
¶
Use the optional nondiff_argnums
parameter to jax.custom_jvp
to indicate arguments like these. Here’s an example with jax.custom_jvp
:
[70]:
from functools import partial

@partial(custom_jvp, nondiff_argnums=(0,))
def app(f, x):
  return f(x)

@app.defjvp
def app_jvp(f, primals, tangents):
  x, = primals
  x_dot, = tangents
  return f(x), 2. * x_dot
[71]:
print(app(lambda x: x ** 3, 3.))
27.0
[72]:
print(grad(app, 1)(lambda x: x ** 3, 3.))
2.0
Notice the gotcha here: no matter where in the argument list these parameters appear, they’re placed at the start of the signature of the corresponding JVP rule. Here’s another example:
[73]:
@partial(custom_jvp, nondiff_argnums=(0, 2))
def app2(f, x, g):
  return f(g(x))

@app2.defjvp
def app2_jvp(f, g, primals, tangents):
  x, = primals
  x_dot, = tangents
  return f(g(x)), 3. * x_dot
[74]:
print(app2(lambda x: x ** 3, 3., lambda y: 5 * y))
3375.0
[75]:
print(grad(app2, 1)(lambda x: x ** 3, 3., lambda y: 5 * y))
3.0
jax.custom_vjp
with nondiff_argnums
¶
A similar option exists for jax.custom_vjp
, and similarly the convention is that the nondifferentiable arguments are passed as the first arguments to the rules, no matter where they appear in the original function’s signature. Here’s an example:
[76]:
@partial(custom_vjp, nondiff_argnums=(0,))
def app(f, x):
  return f(x)

def app_fwd(f, x):
  return f(x), x

def app_bwd(f, x, g):
  return (5 * g,)

app.defvjp(app_fwd, app_bwd)
[77]:
print(app(lambda x: x ** 2, 4.))
16.0
[78]:
print(grad(app, 1)(lambda x: x ** 2, 4.))
5.0
See fixed_point
above for another usage example.
You don't need to use nondiff_argnums with array-valued arguments, for example ones with integer dtype. Instead, nondiff_argnums should only be used for argument values that don't correspond to JAX types (essentially don't correspond to array types), like Python callables or strings. If JAX detects that an argument indicated by nondiff_argnums contains a JAX Tracer, then an error is raised. The clip_gradient function above is a good example of not using nondiff_argnums for integer-dtype array arguments.
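For illustration, here's a hypothetical sketch (not from the original text) of a custom_jvp rule whose second argument is an integer and is handled as an ordinary argument rather than via nondiff_argnums; JAX supplies a symbolically-zero tangent for it:

```python
from jax import custom_jvp, grad

@custom_jvp
def scale(x, n):  # n may be an integer-dtype argument; no nondiff_argnums needed
  return x * n

@scale.defjvp
def scale_jvp(primals, tangents):
  x, n = primals
  x_dot, _ = tangents  # the tangent for the integer argument is symbolically zero
  return x * n, n * x_dot

print(grad(scale)(2.0, 3))  # differentiates w.r.t. x only: 3.0
```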
How JAX primitives work¶
necula@google.com, October 2019.
JAX implements certain transformations of Python functions, e.g., jit
, grad
, vmap
, or pmap
. The Python functions to be transformed must be JAX-traceable, which means that as the Python function executes the only operations it applies to the data are either inspections of data attributes such as shape or type, or special operations called JAX primitives. In particular, a JAX-traceable function is sometimes invoked by JAX with abstract arguments. An example of a JAX abstract value is ShapedArray(float32[2,2]), which captures the type and the shape of values, but not the concrete data values. JAX primitives know how to operate on both concrete data values and on the JAX abstract values.
The JAX-transformed functions must themselves be JAX-traceable functions, to ensure that these transformations can be composed, e.g., jit(jacfwd(grad(f))).
There are predefined JAX primitives corresponding to most XLA operations, e.g., add, matmul, sin, cos, indexing. JAX comes with an implementation of numpy functions in terms of JAX primitives, which means that Python programs using JAX's implementation of numpy are JAX-traceable and therefore transformable. Other libraries can be made JAX-traceable by implementing them in terms of JAX primitives.
The set of JAX primitives is extensible. Instead of reimplementing a function in terms of predefined JAX primitives, one can define a new primitive that encapsulates the behavior of the function.
The goal of this document is to explain the interface that a JAX primitive must support in order to allow JAX to perform all its transformations.
Consider that we want to add to JAX support for a multiply-add function with three arguments, defined mathematically as "multiply_add(x, y, z) = x * y + z". This function operates on 3 identically-shaped tensors of floating point values and performs the operations pointwise.
Using existing primitives¶
The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other functions that are themselves written using JAX primitives, e.g., those defined in the jax.lax
module:
[1]:
from jax import lax
from jax import api

def multiply_add_lax(x, y, z):
  """Implementation of multiply-add using the jax.lax primitives."""
  return lax.add(lax.mul(x, y), z)

def square_add_lax(a, b):
  """A square-add function using the newly defined multiply-add."""
  return multiply_add_lax(a, a, b)

print("square_add_lax = ", square_add_lax(2., 10.))

# Differentiate w.r.t. the first argument
print("grad(square_add_lax) = ", api.grad(square_add_lax, argnums=0)(2.0, 10.))
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
square_add_lax = 14.0
grad(square_add_lax) = 4.0
In order to understand how JAX is internally using the primitives, we add some helpers for tracing function calls.
[2]:
#@title Helper functions (execute this cell)
import functools
import traceback

_indentation = 0
def _trace(msg=None):
  """Print a message at current indentation."""
  if msg is not None:
    print(" " * _indentation + msg)

def _trace_indent(msg=None):
  """Print a message and then indent the rest."""
  global _indentation
  _trace(msg)
  _indentation = 1 + _indentation

def _trace_unindent(msg=None):
  """Unindent then print a message."""
  global _indentation
  _indentation = _indentation - 1
  _trace(msg)

def trace(name):
  """A decorator for functions to trace arguments and results."""
  def trace_func(func):  # pylint: disable=missing-docstring
    def pp(v):
      """Print certain values more succinctly"""
      vtype = str(type(v))
      if "jax.lib.xla_bridge._JaxComputationBuilder" in vtype:
        return "<JaxComputationBuilder>"
      elif "jaxlib.xla_extension.XlaOp" in vtype:
        return "<XlaOp at 0x{:x}>".format(id(v))
      elif ("partial_eval.JaxprTracer" in vtype or
            "batching.BatchTracer" in vtype or
            "ad.JVPTracer" in vtype):
        return "Traced<{}>".format(v.aval)
      elif isinstance(v, tuple):
        return "({})".format(pp_values(v))
      else:
        return str(v)

    def pp_values(args):
      return ", ".join([pp(arg) for arg in args])

    @functools.wraps(func)
    def func_wrapper(*args):
      _trace_indent("call {}({})".format(name, pp_values(args)))
      res = func(*args)
      _trace_unindent("< {} = {}".format(name, pp(res)))
      return res

    return func_wrapper
  return trace_func

class expectNotImplementedError(object):
  """Context manager to check for NotImplementedError."""
  def __enter__(self): pass
  def __exit__(self, type, value, tb):
    global _indentation
    _indentation = 0
    if type is NotImplementedError:
      print("\nFound expected exception:")
      traceback.print_exc(limit=3)
      return True
    elif type is None:  # No exception
      assert False, "Expected NotImplementedError"
    else:
      return False
Instead of using jax.lax
primitives directly, we can use other functions that are already written in terms of those primitives, such as those in jax.numpy
:
[3]:
import jax.numpy as jnp
import numpy as np

@trace("multiply_add_numpy")
def multiply_add_numpy(x, y, z):
  return jnp.add(jnp.multiply(x, y), z)

@trace("square_add_numpy")
def square_add_numpy(a, b):
  return multiply_add_numpy(a, a, b)

print("\nNormal evaluation:")
print("square_add_numpy = ", square_add_numpy(2., 10.))
print("\nGradient evaluation:")
print("grad(square_add_numpy) = ", api.grad(square_add_numpy)(2.0, 10.))
Normal evaluation:
call square_add_numpy(2.0, 10.0)
call multiply_add_numpy(2.0, 2.0, 10.0)
< multiply_add_numpy = 14.0
< square_add_numpy = 14.0
square_add_numpy = 14.0
Gradient evaluation:
call square_add_numpy(Traced<ConcreteArray(2.0, weak_type=True)>, 10.0)
call multiply_add_numpy(Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ConcreteArray(2.0, weak_type=True)>, 10.0)
< multiply_add_numpy = Traced<ConcreteArray(14.0)>
< square_add_numpy = Traced<ConcreteArray(14.0)>
grad(square_add_numpy) = 4.0
Notice that in the process of computing grad, JAX invokes square_add_numpy and multiply_add_numpy with special arguments ConcreteArray(...) (described further below in this colab). It is important to remember that a JAX-traceable function must be able to operate not only on concrete arguments but also on special abstract arguments that JAX may use to abstract the function execution.
The JAX traceability property is satisfied as long as the function is written in terms of JAX primitives.
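One way to see the primitives a traceable function records is to trace it and print the resulting jaxpr; this illustrative sketch uses `jax.make_jaxpr`, which is not used in the original text:

```python
from jax import make_jaxpr
import jax.numpy as jnp

def multiply_add(x, y, z):
  return jnp.add(jnp.multiply(x, y), z)

# Tracing with abstract arguments records the sequence of primitives applied
print(make_jaxpr(multiply_add)(2.0, 2.0, 10.0))
```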
Defining new JAX primitives¶
The right way to add support for multiply-add is in terms of existing JAX primitives, as shown above. However, in order to demonstrate how JAX primitives work let us pretend that we want to add a new primitive to JAX for the multiply-add functionality.
[4]:
from jax import core
multiply_add_p = core.Primitive("multiply_add")  # Create the primitive

@trace("multiply_add_prim")
def multiply_add_prim(x, y, z):
  """The JAX-traceable way to use the JAX primitive.

  Note that the traced arguments must be passed as positional arguments
  to `bind`.
  """
  return multiply_add_p.bind(x, y, z)

@trace("square_add_prim")
def square_add_prim(a, b):
  """A square-add function implemented using the new JAX primitive."""
  return multiply_add_prim(a, a, b)
If we try to call the newly defined functions we get an error, because we have not yet told JAX anything about the semantics of the new primitive.
[5]:
with expectNotImplementedError():
  square_add_prim(2., 10.)
call square_add_prim(2.0, 10.0)
call multiply_add_prim(2.0, 2.0, 10.0)
Found expected exception:
Traceback (most recent call last):
File "<ipythoninput1acee329b29d0>", line 2, in <module>
square_add_prim(2., 10.)
File "<ipythoninput1756fd2c18f40>", line 48, in func_wrapper
res = func(*args)
File "<ipythoninput1c5402c1795f0>", line 16, in square_add_prim
return multiply_add_prim(a, a, b)
NotImplementedError: Evaluation rule for 'multiply_add' not implemented
Primal evaluation rules¶
[6]:
@trace("multiply_add_impl")
def multiply_add_impl(x, y, z):
  """Concrete implementation of the primitive.

  This function does not need to be JAX traceable.
  Args:
    x, y, z: the concrete arguments of the primitive. Will only be called with
      concrete values.
  Returns:
    the concrete result of the primitive.
  """
  # Note that we can use the original numpy, which is not JAX traceable
  return np.add(np.multiply(x, y), z)

# Now we register the primal implementation with JAX
multiply_add_p.def_impl(multiply_add_impl)
[6]:
<function __main__.multiply_add_impl(x, y, z)>
[7]:
assert square_add_prim(2., 10.) == 14.
call square_add_prim(2.0, 10.0)
call multiply_add_prim(2.0, 2.0, 10.0)
call multiply_add_impl(2.0, 2.0, 10.0)
< multiply_add_impl = 14.0
< multiply_add_prim = 14.0
< square_add_prim = 14.0
JIT¶
If we now try to use jit
we get a NotImplementedError
:
[8]:
with expectNotImplementedError():
  api.jit(square_add_prim)(2., 10.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>)
Found expected exception:
Traceback (most recent call last):
File "<ipythoninput1d4853f4fcae2>", line 2, in <module>
api.jit(square_add_prim)(2., 10.)
File "<ipythoninput1756fd2c18f40>", line 48, in func_wrapper
res = func(*args)
File "<ipythoninput1c5402c1795f0>", line 16, in square_add_prim
return multiply_add_prim(a, a, b)
jax._src.traceback_util.FilteredStackTrace: NotImplementedError: Abstract evaluation for 'multiply_add' not implemented
The stack trace above excludes JAXinternal frames.
The following is the original exception that occurred, unmodified.

The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<ipythoninput1d4853f4fcae2>", line 2, in <module>
api.jit(square_add_prim)(2., 10.)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.7/sitepackages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.7/sitepackages/jax/api.py", line 371, in f_jitted
return cpp_jitted_f(*args, **kwargs)
NotImplementedError: Abstract evaluation for 'multiply_add' not implemented
Abstract evaluation rules¶
In order to JIT the function, and for other transformations as well, JAX first evaluates it abstractly using only the shape and type of the arguments. This abstract evaluation serves multiple purposes:
Gets the sequence of JAX primitives that are used in the computation. This sequence will be compiled.
Computes the shape and type of all vectors and operations used in the computation.
For example, the abstraction of a vector with 3 elements may be ShapedArray(float32[3])
, or ConcreteArray([1., 2., 3.])
. In the latter case, JAX uses the actual concrete value wrapped as an abstract value.
[9]:
from jax import abstract_arrays

@trace("multiply_add_abstract_eval")
def multiply_add_abstract_eval(xs, ys, zs):
  """Abstract evaluation of the primitive.

  This function does not need to be JAX traceable. It will be invoked with
  abstractions of the actual arguments.
  Args:
    xs, ys, zs: abstractions of the arguments.
  Result:
    a ShapedArray for the result of the primitive.
  """
  assert xs.shape == ys.shape
  assert xs.shape == zs.shape
  return abstract_arrays.ShapedArray(xs.shape, xs.dtype)

# Now we register the abstract evaluation with JAX
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)
[9]:
<function __main__.multiply_add_abstract_eval(xs, ys, zs)>
If we reattempt to JIT, we see how the abstract evaluation proceeds, but we get another error, about missing the actual XLA compilation rule:
[10]:
with expectNotImplementedError():
  api.jit(square_add_prim)(2., 10.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
< multiply_add_abstract_eval = ShapedArray(float32[])
< multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>
< square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>
Found expected exception:
Traceback (most recent call last):
File "<ipythoninput1d4853f4fcae2>", line 2, in <module>
api.jit(square_add_prim)(2., 10.)
jax._src.traceback_util.FilteredStackTrace: NotImplementedError: XLA translation rule for primitive 'multiply_add' not found
The stack trace above excludes JAXinternal frames.
The following is the original exception that occurred, unmodified.

The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<ipythoninput1d4853f4fcae2>", line 2, in <module>
api.jit(square_add_prim)(2., 10.)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.7/sitepackages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.7/sitepackages/jax/api.py", line 371, in f_jitted
return cpp_jitted_f(*args, **kwargs)
NotImplementedError: XLA translation rule for primitive 'multiply_add' not found
XLA Compilation rules¶
JAX compilation works by compiling each primitive into a graph of XLA operations.
This is the biggest hurdle to adding new functionality to JAX, because the set of XLA operations is limited, and JAX already has pre-defined primitives for most of them. However, XLA includes a CustomCall operation that can be used to encapsulate arbitrary functionality defined using C++.
[11]:
from jax.lib import xla_client
@trace("multiply_add_xla_translation")
def multiply_add_xla_translation(c, xc, yc, zc):
  """The compilation to XLA of the primitive.

  Given an XlaBuilder and XlaOps for each argument, return the XlaOp for the
  result of the function.

  Does not need to be a JAX-traceable function.
  """
  return xla_client.ops.Add(xla_client.ops.Mul(xc, yc), zc)
# Now we register the XLA compilation rule with JAX
# TODO: for GPU? and TPU?
from jax.interpreters import xla
xla.backend_specific_translations['cpu'][multiply_add_p] = multiply_add_xla_translation
Now we can JIT successfully. Notice below that JAX first evaluates the function abstractly, which triggers the multiply_add_abstract_eval function, and then compiles the set of primitives it has encountered, including multiply_add. At this point JAX invokes multiply_add_xla_translation.
[12]:
assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14.
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
< multiply_add_abstract_eval = ShapedArray(float32[])
< multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>
< square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>
call multiply_add_xla_translation(<jaxlib.xla_extension.XlaBuilder object at 0x7fdfa87236f0>, <XlaOp at 0x7fdfa8723a70>, <XlaOp at 0x7fdfa8723a70>, <XlaOp at 0x7fdfa8723770>)
< multiply_add_xla_translation = <XlaOp at 0x7fdfa8723bf0>
Below is another use of jit, where we compile only with respect to the first argument. Notice how the second argument to square_add_prim is concrete, which leads to the third argument of multiply_add_abstract_eval being a ConcreteArray. We see that multiply_add_abstract_eval may be used with both ShapedArray and ConcreteArray.
[13]:
assert api.jit(lambda x, y: square_add_prim(x, y),
static_argnums=1)(2., 10.) == 14.
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, 10.0)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, 10.0)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
< multiply_add_abstract_eval = ShapedArray(float32[])
< multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>
< square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>
call multiply_add_xla_translation(<jaxlib.xla_extension.XlaBuilder object at 0x7fdfb4072fb0>, <XlaOp at 0x7fdfa8751ef0>, <XlaOp at 0x7fdfa8751ef0>, <XlaOp at 0x7fdfb50d9270>)
< multiply_add_xla_translation = <XlaOp at 0x7fdfa8751e30>
Forward differentiation¶
JAX implements forward differentiation in the form of a Jacobian-vector product (see the JAX autodiff cookbook).
If we now attempt to compute the jvp function, we get an error because we have not yet told JAX how to differentiate the multiply_add primitive.
[14]:
# The second argument `(2., 10.)` contains the argument values
# at which we evaluate the Jacobian, and the third `(1., 1.)`
# contains the values of the tangents for the arguments.
with expectNotImplementedError():
  api.jvp(square_add_prim, (2., 10.), (1., 1.))
call square_add_prim(Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ConcreteArray(10.0, weak_type=True)>)
call multiply_add_prim(Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ConcreteArray(10.0, weak_type=True)>)
Found expected exception:
Traceback (most recent call last):
File "<ipython-input-1-f07eb564206f>", line 5, in <module>
api.jvp(square_add_prim, (2., 10.), (1., 1.))
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.7/site-packages/jax/api.py", line 1657, in jvp
return _jvp(lu.wrap_init(fun), primals, tangents)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.7/site-packages/jax/api.py", line 1684, in _jvp
out_primals, out_tangents = ad.jvp(flat_fun).call_wrapped(ps_flat, ts_flat)
NotImplementedError: Differentiation rule for 'multiply_add' not implemented
[15]:
from jax.interpreters import ad
@trace("multiply_add_value_and_jvp")
def multiply_add_value_and_jvp(arg_values, arg_tangents):
  """Evaluates the primal output and the tangents (Jacobian-vector product).

  Given values of the arguments and perturbation of the arguments (tangents),
  compute the output of the primitive and the perturbation of the output.

  This method must be JAX-traceable. JAX may invoke it with abstract values
  for the arguments and tangents.

  Args:
    arg_values: a tuple of arguments
    arg_tangents: a tuple with the tangents of the arguments. The tuple has
      the same length as the arg_values. Some of the tangents may also be the
      special value ad.Zero to specify a zero tangent.
  Returns:
    a pair of the primal output and the tangent.
  """
  x, y, z = arg_values
  xt, yt, zt = arg_tangents
  _trace("Primal evaluation:")
  # Now we have a JAX-traceable computation of the output.
  # Normally, we can use the multiply_add primitive itself to compute the
  # primal output.
  primal_out = multiply_add_prim(x, y, z)

  _trace("Tangent evaluation:")
  # We must use a JAX-traceable way to compute the tangent. It turns out that
  # the output tangent can be computed as (xt * y + x * yt + zt),
  # which we can implement in a JAX-traceable way using the same
  # "multiply_add_prim" primitive.

  # We do need to deal specially with Zero. Here we just turn it into a
  # proper tensor of 0s (of the same shape as 'x').
  # An alternative would be to check for Zero and perform algebraic
  # simplification of the output tangent computation.
  def make_zero(tan):
    return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan

  output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt)))
  return (primal_out, output_tangent)

# Register the forward differentiation rule with JAX
ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp
[16]:
# Tangent is: xt*y + x*yt + zt = 1.*2. + 2.*1. + 1. = 5.
assert api.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.)
call square_add_prim(Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ConcreteArray(10.0, weak_type=True)>)
call multiply_add_prim(Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ConcreteArray(10.0, weak_type=True)>)
call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (1.0, 1.0, 1.0))
Primal evaluation:
call multiply_add_prim(2.0, 2.0, 10.0)
call multiply_add_impl(2.0, 2.0, 10.0)
< multiply_add_impl = 14.0
< multiply_add_prim = 14.0
Tangent evaluation:
call multiply_add_prim(2.0, 1.0, 1.0)
call multiply_add_impl(2.0, 1.0, 1.0)
< multiply_add_impl = 3.0
< multiply_add_prim = 3.0
call multiply_add_prim(1.0, 2.0, 3.0)
call multiply_add_impl(1.0, 2.0, 3.0)
< multiply_add_impl = 5.0
< multiply_add_prim = 5.0
< multiply_add_value_and_jvp = (14.0, 5.0)
< multiply_add_prim = Traced<ConcreteArray(14.0)>
< square_add_prim = Traced<ConcreteArray(14.0)>
TO EXPLAIN:
Why is JAX using ConcreteArray in square_add_prim? There is no abstract evaluation going on here.
Not sure how to explain that multiply_add_prim is invoked with ConcreteValue, yet we do not call the multiply_add_abstract_eval.
I think it would be useful to show the jaxpr here
JIT of forward differentiation¶
We can apply JIT to the forward differentiation function:
[17]:
assert api.jit(lambda arg_values, arg_tangents:
api.jvp(square_add_prim, arg_values, arg_tangents))(
(2., 10.), (1., 1.)) == (14., 5.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
call multiply_add_value_and_jvp((Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>), (Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>))
Primal evaluation:
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
< multiply_add_abstract_eval = ShapedArray(float32[])
< multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>
Tangent evaluation:
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
< multiply_add_abstract_eval = ShapedArray(float32[])
< multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
< multiply_add_abstract_eval = ShapedArray(float32[])
< multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>
< multiply_add_value_and_jvp = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>)
< multiply_add_prim = Traced<ShapedArray(float32[])>
< square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_xla_translation(<jaxlib.xla_extension.XlaBuilder object at 0x7fdfa86b52b0>, <XlaOp at 0x7fdfa86b5af0>, <XlaOp at 0x7fdfa86b5af0>, <XlaOp at 0x7fdfa86b5ab0>)
< multiply_add_xla_translation = <XlaOp at 0x7fdfa86b5c30>
call multiply_add_xla_translation(<jaxlib.xla_extension.XlaBuilder object at 0x7fdfa86b52b0>, <XlaOp at 0x7fdfa86b5af0>, <XlaOp at 0x7fdfa86b59f0>, <XlaOp at 0x7fdfa86b52f0>)
< multiply_add_xla_translation = <XlaOp at 0x7fdfa86b5770>
call multiply_add_xla_translation(<jaxlib.xla_extension.XlaBuilder object at 0x7fdfa86b52b0>, <XlaOp at 0x7fdfa86b59f0>, <XlaOp at 0x7fdfa86b5af0>, <XlaOp at 0x7fdfa86b5770>)
< multiply_add_xla_translation = <XlaOp at 0x7fdfa86b5e70>
Notice that first we evaluate multiply_add_value_and_jvp abstractly, which in turn abstractly evaluates both the primal and the tangent computations (a total of 3 invocations of the ma primitive). Then we compile the 3 occurrences of the primitive.
Reverse differentiation¶
If we now attempt to use reverse differentiation, we see that JAX starts by using multiply_add_value_and_jvp to compute the forward differentiation for abstract values, but then runs into a NotImplementedError.
When computing the reverse differentiation, JAX first performs an abstract evaluation of the forward differentiation code multiply_add_value_and_jvp to obtain a trace of primitives that compute the output tangent. Observe that JAX performs this abstract evaluation with concrete values for the differentiation point, and abstract values for the tangents. Observe also that JAX uses the special abstract tangent value Zero for the tangent corresponding to the 3rd argument of ma. This reflects the fact that we do not differentiate w.r.t. the 2nd argument of square_add_prim, which flows into the 3rd argument of multiply_add_prim.
Observe also that during the abstract evaluation of the tangent we pass the value 0.0 as the tangent for the 3rd argument. This is due to the use of the make_zero function in the definition of multiply_add_value_and_jvp.
[18]:
# This is reverse differentiation w.r.t. the first argument of square_add_prim
with expectNotImplementedError():
  api.grad(square_add_prim)(2., 10.)
call square_add_prim(Traced<ConcreteArray(2.0, weak_type=True)>, 10.0)
call multiply_add_prim(Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ConcreteArray(2.0, weak_type=True)>, 10.0)
call multiply_add_value_and_jvp((Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ConcreteArray(2.0, weak_type=True)>, 10.0), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
Primal evaluation:
call multiply_add_prim(Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ConcreteArray(2.0, weak_type=True)>, 10.0)
call multiply_add_impl(2.0, 2.0, 10.0)
< multiply_add_impl = 14.0
< multiply_add_prim = 14.0
Tangent evaluation:
call multiply_add_prim(Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, 0.0)
call multiply_add_abstract_eval(ConcreteArray(2.0, weak_type=True), ShapedArray(float32[], weak_type=True), ConcreteArray(0.0))
< multiply_add_abstract_eval = ShapedArray(float32[])
< multiply_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ShapedArray(float32[])>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ConcreteArray(2.0, weak_type=True), ShapedArray(float32[]))
< multiply_add_abstract_eval = ShapedArray(float32[])
< multiply_add_prim = Traced<ShapedArray(float32[])>
< multiply_add_value_and_jvp = (14.0, Traced<ShapedArray(float32[])>)
< multiply_add_prim = Traced<ConcreteArray(14.0)>
< square_add_prim = Traced<ConcreteArray(14.0)>
Found expected exception:
Traceback (most recent call last):
File "<ipython-input-1-a915b4bc91d2>", line 3, in <module>
api.grad(square_add_prim)(2., 10.)
jax._src.traceback_util.FilteredStackTrace: NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'multiply_add' not implemented
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.7/site-packages/jax/interpreters/ad.py", line 254, in get_primitive_transpose
return primitive_transposes[p]
KeyError: multiply_add
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<ipython-input-1-a915b4bc91d2>", line 3, in <module>
api.grad(square_add_prim)(2., 10.)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.7/site-packages/jax/api.py", line 706, in grad_f
_, g = value_and_grad_f(*args, **kwargs)
NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'multiply_add' not implemented
The above error occurs because JAX is still missing a piece it needs to turn the forward differentiation code into reverse differentiation.
Transposition¶
As explained above, when computing reverse differentiation JAX obtains a trace of primitives that compute the tangent using forward differentiation. Then, JAX interprets this trace abstractly, backwards, and for each primitive it applies a transposition rule.
To understand what is going on, consider for now the simpler example of the function "f(x, y) = x * y + y". Assume we need to differentiate at the point (2., 4.). JAX will produce the following JVP tangent calculation of ft from the tangents of the inputs xt and yt:
a = xt * 4.
b = 2. * yt
c = a + b
ft = c + yt
By construction, the tangent calculation is always linear in the input tangents. The only nonlinear operator that may arise in the tangent calculation is multiplication, but then one of the operands is constant.
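The tangent computation above can be replayed in plain Python and checked against a finite-difference approximation of the directional derivative (a minimal sketch; the names f and jvp_f are introduced here for illustration and are not part of the tutorial code):

```python
def f(x, y):
  # the example function f(x, y) = x * y + y
  return x * y + y

def jvp_f(xt, yt, x=2., y=4.):
  # the JVP tangent calculation from the text, at the point (2., 4.)
  a = xt * y
  b = x * yt
  c = a + b
  ft = c + yt
  return ft

# Compare against a finite-difference approximation along direction (1., 1.)
eps = 1e-6
fd = (f(2. + eps, 4. + eps) - f(2., 4.)) / eps
assert abs(jvp_f(1., 1.) - fd) < 1e-4  # both approximately 7.0
```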
JAX will produce the reverse differentiation computation by processing the JVP computation backwards. For each operation in the tangent computation, it accumulates the cotangents of the variables used by the operation, using the cotangent of the result of the operation:
# Initialize cotangents of inputs and intermediate vars
xct = yct = act = bct = cct = 0.
# Initialize cotangent of the output
fct = 1.
# Process "ft = c + yt"
cct += fct
yct += fct
# Process "c = a + b"
act += cct
bct += cct
# Process "b = 2. * yt"
yct += 2. * bct
# Process "a = xt * 4."
xct += act * 4.
One can verify that this computation produces xct = 4. and yct = 3., which are the partial derivatives of the function f.
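The backward pass above can be executed directly as plain Python to confirm these values (a small sketch; backward_f is a hypothetical helper name, not part of the tutorial code):

```python
def backward_f(fct=1., x=2., y=4.):
  # Replay of the cotangent accumulation for f(x, y) = x * y + y at (2., 4.)
  xct = yct = act = bct = cct = 0.
  # Process "ft = c + yt"
  cct += fct
  yct += fct
  # Process "c = a + b"
  act += cct
  bct += cct
  # Process "b = 2. * yt" (2. is the constant value of x)
  yct += x * bct
  # Process "a = xt * 4." (4. is the constant value of y)
  xct += act * y
  return xct, yct

assert backward_f() == (4., 3.)  # df/dx = y = 4, df/dy = x + 1 = 3
```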
JAX knows, for each primitive that may appear in a JVP calculation, how to transpose it. Conceptually, if the primitive p(x, y, z) is linear in the arguments y and z for a constant value of x, e.g., p(x, y, z) = y*cy + z*cz, then the transposition of the primitive is:
p_transpose(out_ct, x, _, _) = (None, out_ct*cy, out_ct*cz)
Notice that p_transpose takes the cotangent of the output of the primitive and a value corresponding to each argument of the primitive. For the linear arguments, the transposition gets an undefined _ value, and for the other arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value None returned for the constant arguments.
In particular:
add_transpose(out_ct, _, _) = (out_ct, out_ct)
mult_transpose(out_ct, x, _) = (None, x * out_ct)
mult_transpose(out_ct, _, y) = (out_ct * y, None)
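These rules can be written down as small Python functions and used to replay the backward pass for the example "f(x, y) = x * y + y" (a sketch only; the function names and the convention that None marks a constant argument follow the text above):

```python
def add_transpose(out_ct, _x, _y):
  # addition is linear in both arguments: the output cotangent flows to both
  return (out_ct, out_ct)

def mult_transpose_const_x(out_ct, x, _y):
  # x is a known constant, y is the linear argument
  return (None, x * out_ct)

def mult_transpose_const_y(out_ct, _x, y):
  # y is a known constant, x is the linear argument
  return (out_ct * y, None)

# Backward replay of the JVP trace at (2., 4.) using the rules above
fct = 1.
cct, yct = add_transpose(fct, None, None)        # ft = c + yt
act, bct = add_transpose(cct, None, None)        # c = a + b
_, yct2 = mult_transpose_const_x(bct, 2., None)  # b = 2. * yt
xct, _ = mult_transpose_const_y(act, None, 4.)   # a = xt * 4.
assert (xct, yct + yct2) == (4., 3.)
```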
[19]:
@trace("multiply_add_transpose")
def multiply_add_transpose(ct, x, y, z):
  """Evaluates the transpose of a linear primitive.

  This method is only used when computing the backward gradient following
  value_and_jvp, and is only needed for primitives that are used in the JVP
  calculation for some other primitive. We need transposition for
  multiply_add_prim, because we have used multiply_add_prim in the computation
  of the output_tangent in multiply_add_value_and_jvp.

  In our case, multiply_add is not a linear primitive. However, it is used
  linearly w.r.t. tangents in multiply_add_value_and_jvp:
    output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))

  Always one of the first two multiplicative arguments is a constant.

  Args:
    ct: the cotangent of the output of the primitive.
    x, y, z: values of the arguments. The arguments that are used linearly
      get an ad.UndefinedPrimal value. The other arguments get a constant
      value.
  Returns:
    a tuple with the cotangent of the inputs, with the value None
    corresponding to the constant arguments.
  """
  if not ad.is_undefined_primal(x):
    # This use of multiply_add is with a constant "x"
    assert ad.is_undefined_primal(y)
    ct_y = ad.Zero(y.aval) if type(ct) is ad.Zero else multiply_add_prim(x, ct, lax.zeros_like_array(x))
    res = None, ct_y, ct
  else:
    # This use of multiply_add is with a constant "y"
    assert ad.is_undefined_primal(x)
    ct_x = ad.Zero(x.aval) if type(ct) is ad.Zero else multiply_add_prim(ct, y, lax.zeros_like_array(y))
    res = ct_x, None, ct
  return res

ad.primitive_transposes[multiply_add_p] = multiply_add_transpose
Now we can complete the run of grad:
[20]:
assert api.grad(square_add_prim)(2., 10.) == 4.
call square_add_prim(Traced<ConcreteArray(2.0, weak_type=True)>, 10.0)
call multiply_add_prim(Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ConcreteArray(2.0, weak_type=True)>, 10.0)
call multiply_add_value_and_jvp((Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ConcreteArray(2.0, weak_type=True)>, 10.0), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
Primal evaluation:
call multiply_add_prim(Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ConcreteArray(2.0, weak_type=True)>, 10.0)
call multiply_add_impl(2.0, 2.0, 10.0)
< multiply_add_impl = 14.0
< multiply_add_prim = 14.0
Tangent evaluation:
call multiply_add_prim(Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, 0.0)
call multiply_add_abstract_eval(ConcreteArray(2.0, weak_type=True), ShapedArray(float32[], weak_type=True), ConcreteArray(0.0))
< multiply_add_abstract_eval = ShapedArray(float32[])
< multiply_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ShapedArray(float32[])>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ConcreteArray(2.0, weak_type=True), ShapedArray(float32[]))
< multiply_add_abstract_eval = ShapedArray(float32[])
< multiply_add_prim = Traced<ShapedArray(float32[])>
< multiply_add_value_and_jvp = (14.0, Traced<ShapedArray(float32[])>)
< multiply_add_prim = Traced<ConcreteArray(14.0)>
< square_add_prim = Traced<ConcreteArray(14.0)>
call multiply_add_transpose(1.0, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), 2.0, UndefinedPrimal(ShapedArray(float32[])))
call multiply_add_prim(1.0, 2.0, 0.0)
call multiply_add_impl(1.0, 2.0, 0.0)
< multiply_add_impl = 2.0
< multiply_add_prim = 2.0
< multiply_add_transpose = (2.0, None, 1.0)
call multiply_add_transpose(1.0, 2.0, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), 0.0)
call multiply_add_prim(2.0, 1.0, 0.0)
call multiply_add_impl(2.0, 1.0, 0.0)
< multiply_add_impl = 2.0
< multiply_add_prim = 2.0
< multiply_add_transpose = (None, 2.0, 1.0)
Notice the two calls to multiply_add_transpose. They correspond to the two uses of multiply_add_prim in the computation of the output_tangent in multiply_add_value_and_jvp. The first call to transpose corresponds to the last use of multiply_add_prim: multiply_add_prim(xt, y, ...) where y is the constant 2.0.
JIT of reverse differentiation¶
Notice that the abstract evaluation of multiply_add_value_and_jvp now uses only abstract values, while in the absence of JIT we used ConcreteArray.
[21]:
assert api.jit(api.grad(square_add_prim))(2., 10.) == 4.
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>)
call multiply_add_value_and_jvp((Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
Primal evaluation:
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
< multiply_add_abstract_eval = ShapedArray(float32[])
< multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>
Tangent evaluation:
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
< multiply_add_abstract_eval = ShapedArray(float32[])
< multiply_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[])>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
< multiply_add_abstract_eval = ShapedArray(float32[])
< multiply_add_prim = Traced<ShapedArray(float32[])>
< multiply_add_value_and_jvp = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[])>)
< multiply_add_prim = Traced<ShapedArray(float32[])>
< square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_transpose(1.0, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, UndefinedPrimal(ShapedArray(float32[])))
call multiply_add_prim(1.0, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>)
call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
< multiply_add_abstract_eval = ShapedArray(float32[])
< multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>
< multiply_add_transpose = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>, None, 1.0)
call multiply_add_transpose(1.0, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, 1.0, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[]), ShapedArray(float32[]))
< multiply_add_abstract_eval = ShapedArray(float32[])
< multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>
< multiply_add_transpose = (None, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>, 1.0)
call multiply_add_xla_translation(<jaxlib.xla_extension.XlaBuilder object at 0x7fdfa86b5570>, <XlaOp at 0x7fdfa86d9170>, <XlaOp at 0x7fdfa86d9170>, <XlaOp at 0x7fdfa86d9730>)
< multiply_add_xla_translation = <XlaOp at 0x7fdfe099feb0>
call multiply_add_xla_translation(<jaxlib.xla_extension.XlaBuilder object at 0x7fdfa86b5570>, <XlaOp at 0x7fdfe099f430>, <XlaOp at 0x7fdfa86d9170>, <XlaOp at 0x7fdfb4072fb0>)
< multiply_add_xla_translation = <XlaOp at 0x7fdfe099f770>
call multiply_add_xla_translation(<jaxlib.xla_extension.XlaBuilder object at 0x7fdfa86b5570>, <XlaOp at 0x7fdfa86d9170>, <XlaOp at 0x7fdfe3c8bc30>, <XlaOp at 0x7fdfbcd0c170>)
< multiply_add_xla_translation = <XlaOp at 0x7fdfa8751e70>
Batching¶
The batching transformation takes a pointwise computation and turns it into a computation on vectors. If we try it right now, we get a NotImplementedError:
[22]:
# The arguments are two vectors instead of two scalars
with expectNotImplementedError():
  api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),
                                                   np.array([10., 20.]))
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
Found expected exception:
Traceback (most recent call last):
File "<ipython-input-1-5e627b0f2eb2>", line 4, in <module>
np.array([10., 20.]))
File "<ipython-input-1-756fd2c18f40>", line 48, in func_wrapper
res = func(*args)
File "<ipython-input-1-c5402c1795f0>", line 16, in square_add_prim
return multiply_add_prim(a, a, b)
jax._src.traceback_util.FilteredStackTrace: NotImplementedError: Batching rule for 'multiply_add' not implemented
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.7/site-packages/jax/interpreters/batching.py", line 281, in get_primitive_batcher
return primitive_batchers[p]
KeyError: multiply_add
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<ipython-input-1-5e627b0f2eb2>", line 4, in <module>
np.array([10., 20.]))
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.7/site-packages/jax/api.py", line 1189, in batched_fun
axis_name=axis_name)
NotImplementedError: Batching rule for 'multiply_add' not implemented
We need to tell JAX how to evaluate the batched version of the primitive. In this particular case, multiply_add_prim already operates pointwise for any dimension of input vectors, so the batched version can use the same multiply_add_prim implementation.
[23]:
from jax.interpreters import batching
@trace("multiply_add_batch")
def multiply_add_batch(vector_arg_values, batch_axes):
  """Computes the batched version of the primitive.

  This must be a JAX-traceable function.

  Since the multiply_add primitive already operates pointwise on arbitrary
  dimension tensors, to batch it we can use the primitive itself. This works as
  long as all the inputs have the same dimensions and are batched along the
  same axes. The result is batched along the axis that the inputs are batched
  along.

  Args:
    vector_arg_values: a tuple of arguments, each being a tensor of matching
      shape.
    batch_axes: the axes that are being batched. See vmap documentation.
  Returns:
    a tuple of the result, and the result axis that was batched.
  """
  assert batch_axes[0] == batch_axes[1]
  assert batch_axes[0] == batch_axes[2]
  _trace("Using multiply_add to compute the batch:")
  res = multiply_add_prim(*vector_arg_values)
  return res, batch_axes[0]

batching.primitive_batchers[multiply_add_p] = multiply_add_batch
[24]:
assert np.allclose(api.vmap(square_add_prim, in_axes=0, out_axes=0)(
np.array([2., 3.]),
np.array([10., 20.])),
[14., 29.])
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
call multiply_add_batch(([2. 3.], [2. 3.], [10. 20.]), (0, 0, 0))
Using multiply_add to compute the batch:
call multiply_add_prim([2. 3.], [2. 3.], [10. 20.])
call multiply_add_impl([2. 3.], [2. 3.], [10. 20.])
< multiply_add_impl = [14. 29.]
< multiply_add_prim = [14. 29.]
< multiply_add_batch = ([14. 29.], 0)
< multiply_add_prim = Traced<ShapedArray(float32[])>
< square_add_prim = Traced<ShapedArray(float32[])>
JIT of batching¶
[25]:
assert np.allclose(api.jit(api.vmap(square_add_prim, in_axes=0, out_axes=0))
(np.array([2., 3.]),
np.array([10., 20.])),
[14., 29.])
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
call multiply_add_batch((Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=0/1)>), (0, 0, 0))
Using multiply_add to compute the batch:
call multiply_add_prim(Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=0/1)>)
call multiply_add_abstract_eval(ShapedArray(float32[2]), ShapedArray(float32[2]), ShapedArray(float32[2]))
< multiply_add_abstract_eval = ShapedArray(float32[2])
< multiply_add_prim = Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=0/1)>
< multiply_add_batch = (Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=0/1)>, 0)
< multiply_add_prim = Traced<ShapedArray(float32[])>
< square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_xla_translation(<jaxlib.xla_extension.XlaBuilder object at 0x7fdfa87231b0>, <XlaOp at 0x7fdfa8723c70>, <XlaOp at 0x7fdfa8723c70>, <XlaOp at 0x7fdfa87239b0>)
< multiply_add_xla_translation = <XlaOp at 0x7fdfa8723870>
Writing custom Jaxpr interpreters in JAX¶
JAX offers several composable function transformations (jit
, grad
, vmap
, etc.) that enable writing concise, accelerated code.
Here we show how to add your own function transformations to the system, by writing a custom Jaxpr interpreter. And we’ll get composability with all the other transformations for free.
This example uses internal JAX APIs, which may break at any time. Anything not in the API Documentation (https://jax.readthedocs.io/en/latest/jax.html) should be assumed internal.
[1]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, grad, vmap
from jax import random
What is JAX doing?¶
JAX provides a NumPylike API for numerical computing which can be used as is, but JAX’s true power comes from composable function transformations. Take the jit
function transformation, which takes in a function and returns a semantically identical function that is lazily compiled by XLA for accelerators.
[2]:
x = random.normal(random.PRNGKey(0), (5000, 5000))
def f(w, b, x):
  return jnp.tanh(jnp.dot(x, w) + b)
fast_f = jit(f)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
When we call fast_f
, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about JAX's tracing machinery, you can refer to the "How it works" section in the README.
Jaxpr tracer¶
A tracer of special importance in Jax is the Jaxpr tracer, which records ops into a Jaxpr (Jax expression). A Jaxpr is a data structure that can be evaluated like a mini functional programming language and thus Jaxprs are a useful intermediate representation for function transformation.
To get a first look at Jaxprs, consider the make_jaxpr
transformation. make_jaxpr
is essentially a "pretty-printing" transformation: it transforms a function into one that, given example arguments, produces a Jaxpr representation of its computation. Although we can't generally use the Jaxprs that it returns, it is useful for debugging and introspection. Let's use it to look at how some example Jaxprs are structured.
[3]:
def examine_jaxpr(typed_jaxpr):
  jaxpr = typed_jaxpr.jaxpr
  print("invars:", jaxpr.invars)
  print("outvars:", jaxpr.outvars)
  print("constvars:", jaxpr.constvars)
  for eqn in jaxpr.eqns:
    print("equation:", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)
  print()
  print("jaxpr:", jaxpr)

def foo(x):
  return x + 1

print("foo")
print("=====")
examine_jaxpr(jax.make_jaxpr(foo)(5))

print()

def bar(w, b, x):
  return jnp.dot(w, x) + b + jnp.ones(5), x

print("bar")
print("=====")
examine_jaxpr(jax.make_jaxpr(bar)(jnp.ones((5, 10)), jnp.ones(5), jnp.ones(10)))
foo
=====
invars: [a]
outvars: [b]
constvars: []
equation: [a, 1] add [b] {}
jaxpr: { lambda ; a.
let b = add a 1
in (b,) }
bar
=====
invars: [a, b, c]
outvars: [g, c]
constvars: []
equation: [a, c] dot_general [d] {'dimension_numbers': (((1,), (0,)), ((), ())), 'precision': None}
equation: [d, b] add [e] {}
equation: [1.0] broadcast_in_dim [f] {'shape': (5,), 'broadcast_dimensions': ()}
equation: [e, f] add [g] {}
jaxpr: { lambda ; a b c.
let d = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))
precision=None ] a c
e = add d b
f = broadcast_in_dim[ broadcast_dimensions=( )
shape=(5,) ] 1.0
g = add e f
in (g, c) }
jaxpr.invars - the invars of a Jaxpr are a list of the input variables to the Jaxpr, analogous to arguments in Python functions.
jaxpr.outvars - the outvars of a Jaxpr are the variables that are returned by the Jaxpr. Every Jaxpr has multiple outputs.
jaxpr.constvars - the constvars are a list of variables that are also inputs to the Jaxpr, but correspond to constants from the trace (we'll go over these in more detail later).
jaxpr.eqns - a list of equations, which are essentially let-bindings. Each equation is a list of input variables, a list of output variables, and a primitive, which is used to evaluate inputs to produce outputs. Each equation also has params, a dictionary of parameters.
All together, a Jaxpr encapsulates a simple program that can be evaluated with inputs to produce an output. We’ll go over how exactly to do this later. The important thing to note now is that a Jaxpr is a data structure that can be manipulated and evaluated in whatever way we want.
Why are Jaxprs useful?¶
Jaxprs are simple program representations that are easy to transform. And because Jax lets us stage out Jaxprs from Python functions, it gives us a way to transform numerical programs written in Python.
Your first interpreter: invert
¶
Let’s try to implement a simple function “inverter”, which takes in the output of the original function and returns the inputs that produced those outputs. For now, let’s focus on simple, unary functions which are composed of other invertible unary functions.
Goal:
def f(x):
  return jnp.exp(jnp.tanh(x))
f_inv = inverse(f)
assert jnp.allclose(f_inv(f(1.0)), 1.0)
The way we’ll implement this is by (1) tracing f
into a Jaxpr, then (2) interpreting the Jaxpr backwards. While interpreting the Jaxpr backwards, for each equation we’ll look up the primitive’s inverse in a table and apply it.
1. Tracing a function¶
We also need to pull out the constants created during the trace so that we can pass them into the Jaxpr. Conveniently, jax.make_jaxpr returns a ClosedJaxpr, which bundles the Jaxpr together with those constants (its literals).
[4]:
# Importing Jax functions useful for tracing/interpreting.
import numpy as np
from functools import wraps
from jax import core
from jax import lax
from jax.util import safe_map
Here we trace the function with jax.make_jaxpr, which returns a ClosedJaxpr: the traced Jaxpr together with the constants (literals) created during the trace.
[5]:
def f(x):
  return jnp.exp(jnp.tanh(x))
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
print(closed_jaxpr)
print(closed_jaxpr.literals)
{ lambda ; a.
let b = tanh a
c = exp b
in (c,) }
[]
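The literals list above is empty because nothing was captured during the trace. As an illustrative sketch (not from the original example), a value captured from outside the function shows up as a constant of the trace; the exact printed form varies by JAX version:

```python
import numpy as np
import jax
import jax.numpy as jnp

const = np.arange(5.0)  # a NumPy array captured by the closure below

def g(x):
    return x + const

closed_jaxpr = jax.make_jaxpr(g)(jnp.ones(5))
print(closed_jaxpr)           # the captured array may appear as a constvar
print(closed_jaxpr.literals)  # ...with its value stored here
```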
2. Evaluating a Jaxpr¶
Before we write a custom Jaxpr interpreter, let’s first implement the “default” interpreter, eval_jaxpr
, which evaluates the Jaxpr as-is, computing the same values that the original, untransformed Python function would.
To do this, we first create an environment to store the values for each of the variables, and update the environment with each equation we evaluate in the Jaxpr.
[6]:
def eval_jaxpr(jaxpr, consts, *args):
  # Mapping from variable -> value
  env = {}

  def read(var):
    # Literals are values baked into the Jaxpr
    if type(var) is core.Literal:
      return var.val
    return env[var]

  def write(var, val):
    env[var] = val

  # Bind args and consts to environment
  write(core.unitvar, core.unit)
  safe_map(write, jaxpr.invars, args)
  safe_map(write, jaxpr.constvars, consts)

  # Loop through equations and evaluate primitives using `bind`
  for eqn in jaxpr.eqns:
    # Read inputs to equation from environment
    invals = safe_map(read, eqn.invars)
    # `bind` is how a primitive is called
    outvals = eqn.primitive.bind(*invals, **eqn.params)
    # Primitives may return multiple outputs or not
    if not eqn.primitive.multiple_results:
      outvals = [outvals]
    # Write the results of the primitive into the environment
    safe_map(write, eqn.outvars, outvals)
  # Read the final result of the Jaxpr from the environment
  return safe_map(read, jaxpr.outvars)
[7]:
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, jnp.ones(5))
[7]:
[DeviceArray([2.1416876, 2.1416876, 2.1416876, 2.1416876, 2.1416876], dtype=float32)]
Notice that eval_jaxpr
will always return a flat list even if the original function does not.
Furthermore, this interpreter does not handle sub-jaxprs, which we will not cover in this guide. You can refer to core.eval_jaxpr (link) to see the edge cases that this interpreter does not cover.
Custom inverse
Jaxpr interpreter¶
An inverse
interpreter doesn’t look too different from eval_jaxpr
. We’ll first set up the registry which will map primitives to their inverses. We’ll then write a custom interpreter that looks up primitives in the registry.
It turns out that this interpreter will also look similar to the "transpose" interpreter used in reverse-mode automatic differentiation found here.
[8]:
inverse_registry = {}
We’ll now register inverses for some of the primitives. By convention, primitives in Jax end in _p
and a lot of the popular ones live in lax
.
[9]:
inverse_registry[lax.exp_p] = jnp.log
inverse_registry[lax.tanh_p] = jnp.arctanh
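The registry is an ordinary dict, so adding rules is cheap. As an illustrative sketch (these extra rules are ours, not part of the original example), two more unary primitives that are invertible on suitable domains:

```python
import jax.numpy as jnp
from jax import lax

inverse_registry = {}
inverse_registry[lax.exp_p] = jnp.log
inverse_registry[lax.tanh_p] = jnp.arctanh
# Hypothetical additions: log and sin are also invertible unary primitives
# (sin only on [-pi/2, pi/2], so arcsin recovers the principal branch).
inverse_registry[lax.log_p] = jnp.exp
inverse_registry[lax.sin_p] = jnp.arcsin
```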
inverse
will first trace the function, then custom-interpret the Jaxpr. Let's set up a simple skeleton.
[10]:
def inverse(fun):
  @wraps(fun)
  def wrapped(*args, **kwargs):
    # Since we assume unary functions, we won't worry about
    # flattening and unflattening arguments
    closed_jaxpr = jax.make_jaxpr(fun)(*args, **kwargs)
    out = inverse_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
    return out[0]
  return wrapped
Now we just need to define inverse_jaxpr
, which will walk through the Jaxpr backward and invert primitives when it can.
[11]:
def inverse_jaxpr(jaxpr, consts, *args):
  env = {}

  def read(var):
    if type(var) is core.Literal:
      return var.val
    return env[var]

  def write(var, val):
    env[var] = val

  # Args now correspond to Jaxpr outvars
  write(core.unitvar, core.unit)
  safe_map(write, jaxpr.outvars, args)
  safe_map(write, jaxpr.constvars, consts)

  # Looping backward
  for eqn in jaxpr.eqns[::-1]:
    # outvars are now invars
    invals = safe_map(read, eqn.outvars)
    if eqn.primitive not in inverse_registry:
      raise NotImplementedError(
          "{} does not have a registered inverse.".format(eqn.primitive))
    # Assuming a unary function
    outval = inverse_registry[eqn.primitive](*invals)
    safe_map(write, eqn.invars, [outval])
  return safe_map(read, jaxpr.invars)
That’s it!
[12]:
def f(x):
  return jnp.exp(jnp.tanh(x))
f_inv = inverse(f)
assert jnp.allclose(f_inv(f(1.0)), 1.0)
Importantly, you can trace through a Jaxpr interpreter.
[13]:
jax.make_jaxpr(inverse(f))(f(1.))
[13]:
{ lambda ; a.
let b = log a
c = atanh b
in (c,) }
That’s all it takes to add a new transformation to a system, and you get composition with all the others for free! For example, we can use jit
, vmap
, and grad
with inverse
!
[14]:
jit(vmap(grad(inverse(f))))((jnp.arange(5) + 1.) / 5.)
[14]:
DeviceArray([-3.1440797, 15.584931 , 2.2551253, 1.3155028, 1. ], dtype=float32)
Exercises for the reader¶
Handle primitives with multiple arguments where inputs are partially known, for example lax.add_p, lax.mul_p.
Handle xla_call and xla_pmap primitives, which will not work with both eval_jaxpr and inverse_jaxpr as written.
Change Log¶
These are the release notes for JAX.
jax 0.2.7 (Unreleased)¶
Breaking changes:
jax.experimental.optix has been deleted, in favor of the standalone optax Python package.
Indexing of JAX arrays with non-tuple sequences now raises a TypeError. This type of indexing has been deprecated in NumPy since v1.16, and in JAX since v0.2.4. See #4564.
jax 0.2.6 (Nov 18 2020)¶
New Features:
Add support for shape-polymorphic tracing for the jax.experimental.jax2tf converter. See README.md.
Breaking change cleanup
Raise an error on non-hashable static arguments for jax.jit and xla_computation. See cb48f42.
Improve consistency of type promotion behavior (#4744):
Adding a complex Python scalar to a JAX floating point number respects the precision of the JAX float. For example, jnp.float32(1) + 1j now returns complex64, where previously it returned complex128.
Results of type promotion with 3 or more terms involving uint64, a signed int, and a third type are now independent of the order of arguments. For example: jnp.result_type(jnp.uint64, jnp.int64, jnp.float16) and jnp.result_type(jnp.float16, jnp.uint64, jnp.int64) both return float16, where previously the first returned float64 and the second returned float16.
The contents of the (undocumented) jax.lax_linalg linear algebra module are now exposed publicly as jax.lax.linalg.
jax.random.PRNGKey now produces the same results in and out of JIT compilation (#4877). This required changing the result for a given seed in a few particular cases:
With jax_enable_x64=False, negative seeds passed as Python integers now return a different result outside JIT mode. For example, jax.random.PRNGKey(-1) previously returned [4294967295, 4294967295], and now returns [0, 4294967295]. This matches the behavior in JIT.
Seeds outside the range representable by int64 outside JIT now result in an OverflowError rather than a TypeError. This matches the behavior in JIT.
To recover the keys returned previously for negative integers with jax_enable_x64=False outside JIT, you can use: key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)
DeviceArray now raises RuntimeError instead of ValueError when trying to access its value while it has been deleted.
jaxlib 0.1.58 (Unreleased)¶
Fixed a bug that meant JAX sometimes returned platform-specific types (e.g., np.cint) instead of standard types (e.g., np.int32). (#4903)
jaxlib 0.1.57 (November 12 2020)¶
Fixed manylinux2010 compliance issues in GPU wheels.
Switched the CPU FFT implementation from Eigen to PocketFFT.
Fixed a bug where the hash of bfloat16 values was not correctly initialized and could change (#4651).
Add support for retaining ownership when passing arrays to DLPack (#4636).
Fixed a bug for batched triangular solves with sizes greater than 128 but not a multiple of 128.
Fixed a bug when performing concurrent FFTs on multiple GPUs (#3518).
Fixed a bug in profiler where tools are missing (#4427).
Dropped support for CUDA 10.0.
jax 0.2.5 (October 27 2020)¶
Improvements:
Ensure that check_jaxpr does not perform FLOPS. See #4650.
Expanded the set of JAX primitives converted by jax2tf. See primitives_with_limited_support.md.
jax 0.2.4 (October 19 2020)¶
jaxlib 0.1.56 (October 14, 2020)¶
jax 0.2.3 (October 14 2020)¶
The reason for another release so soon is that we need to temporarily roll back a new jit fast path while we look into a performance degradation.
jax 0.2.2 (October 13 2020)¶
jax 0.2.1 (October 6 2020)¶
Improvements:
As a benefit of omnistaging, the host_callback functions are executed (in program order) even if the result of the jax.experimental.host_callback.id_print() / jax.experimental.host_callback.id_tap() is not used in the computation.
jax 0.1.77 (September 15 2020)¶
Breaking changes:
New simplified interface for jax.experimental.host_callback.id_tap() (#4101)
jaxlib 0.1.55 (September 8, 2020)¶
Update XLA:
Fix bug in DLPackManagedTensorToBuffer (#4196)
jax 0.1.76 (September 8, 2020)¶
jax 0.1.75 (July 30, 2020)¶
Bug Fixes:
make jnp.abs() work for unsigned inputs (#3914)
Improvements:
“Omnistaging” behavior added behind a flag, disabled by default (#3370)
jax 0.1.74 (July 29, 2020)¶
New Features:
BFGS (#3101)
TPU support for half-precision arithmetic (#3878)
Bug Fixes:
Prevent some accidental dtype warnings (#3874)
Fix a multithreading bug in custom derivatives (#3845, #3869)
Improvements:
Faster searchsorted implementation (#3873)
Better test coverage for jax.numpy sorting algorithms (#3836)
jaxlib 0.1.52 (July 22, 2020)¶
Update XLA.
jax 0.1.73 (July 22, 2020)¶
The minimum jaxlib version is now 0.1.51.
New Features:
jax.image.resize. (#3703)
hfft and ihfft (#3664)
jax.numpy.intersect1d (#3726)
jax.numpy.lexsort (#3812)
lax.scan and the scan primitive support an unroll parameter for loop unrolling when lowering to XLA (#3738).
Bug Fixes:
Fix reduction repeated axis error (#3618)
Fix shape rule for lax.pad for input dimensions of size 0. (#3608)
make psum transpose handle zero cotangents (#3653)
Fix shape error when taking JVP of reduce-prod over size 0 axis. (#3729)
Support differentiation through jax.lax.all_to_all (#3733)
address nan issue in jax.scipy.special.zeta (#3777)
Improvements:
Many improvements to jax2tf
Reimplement argmin/argmax using a single pass variadic reduction. (#3611)
Enable XLA SPMD partitioning by default. (#3151)
Add support for 0d transpose convolution (#3643)
Make LU gradient work for low-rank matrices (#3610)
support multiple_results and custom JVPs in jet (#3657)
Generalize reduce-window padding to support (lo, hi) pairs. (#3728)
Implement complex convolutions on CPU and GPU. (#3735)
Make jnp.take work for empty slices of empty arrays. (#3751)
Relax dimension ordering rules for dot_general. (#3778)
Enable buffer donation for GPU. (#3800)
Add support for base dilation and window dilation to reduce window op… (#3803)
jaxlib 0.1.51 (July 2, 2020)¶
Update XLA.
Add new runtime support for host_callback.
jax 0.1.72 (June 28, 2020)¶
Bug fixes:
Fix an odeint bug introduced in the previous release, see #3587.
jax 0.1.71 (June 25, 2020)¶
The minimum jaxlib version is now 0.1.48.
Bug fixes:
Allow jax.experimental.ode.odeint dynamics functions to close over values with respect to which we're differentiating #3562.
jaxlib 0.1.50 (June 25, 2020)¶
Add support for CUDA 11.0.
Drop support for CUDA 9.2 (we only maintain support for the last four CUDA versions.)
Update XLA.
jaxlib 0.1.49 (June 19, 2020)¶
Bug fixes:
Fix build issue that could result in slow compiles (https://github.com/tensorflow/tensorflow/commit/f805153a25b00d12072bd728e91bb1621bfcf1b1)
jaxlib 0.1.48 (June 12, 2020)¶
New features:
Adds support for fast traceback collection.
Adds preliminary support for on-device heap profiling.
Implements np.nextafter for bfloat16 types.
Complex128 support for FFTs on CPU and GPU.
Bugfixes:
Improved float64 tanh accuracy on GPU.
float64 scatters on GPU are much faster.
Complex matrix multiplication on CPU should be much faster.
Stable sorts on CPU should actually be stable now.
Concurrency bug fix in CPU backend.
jax 0.1.70 (June 8, 2020)¶
New features:
lax.switch introduces indexed conditionals with multiple branches, together with a generalization of the cond primitive #3318.
jax 0.1.69 (June 3, 2020)¶
jax 0.1.68 (May 21, 2020)¶
jax 0.1.67 (May 12, 2020)¶
New features:
Support for reduction over subsets of a pmapped axis using axis_index_groups #2382.
Experimental support for printing and calling host-side Python functions from compiled code. See id_print and id_tap (#3006).
Notable changes:
The visibility of names exported from jax.numpy has been tightened. This may break code that was making use of names that were previously exported accidentally.
jaxlib 0.1.47 (May 8, 2020)¶
Fixes crash for outfeed.
jaxlib 0.1.46 (May 5, 2020)¶
Fixes crash for linear algebra functions on Mac OS X (#432).
Fixes an illegal instruction crash caused by using AVX512 instructions when an operating system or hypervisor disabled them (#2906).
jax 0.1.65 (April 30, 2020)¶
New features:
Differentiation of determinants of singular matrices #2809.
Bug fixes:
jaxlib 0.1.45 (April 21, 2020)¶
Fixes segfault: https://github.com/google/jax/issues/2755
Plumb is_stable option on Sort HLO through to Python.
jax 0.1.64 (April 21, 2020)¶
New features:
Add syntactic sugar for functional indexed updates #2684.
Add jax.numpy.unique() #2760.
Add jax.numpy.rint() #2724.
Add more primitive rules for jax.experimental.jet().
Bug fixes:
Better errors:
Improves error message for reversemode differentiation of
lax.while_loop()
#2129.
jaxlib 0.1.44 (April 16, 2020)¶
Fixes a bug where if multiple GPUs of different models were present, JAX would only compile programs suitable for the first GPU.
Bugfix for batch_group_count convolutions.
Added precompiled SASS for more GPU versions to avoid startup PTX compilation hang.
jax 0.1.63 (April 12, 2020)¶
Added jax.custom_jvp and jax.custom_vjp from #2026, see the tutorial notebook. Deprecated jax.custom_transforms and removed it from the docs (though it still works).
Add scipy.sparse.linalg.cg #2566.
Changed how Tracers are printed to show more useful information for debugging #2591.
Made jax.numpy.isclose handle nan and inf correctly #2501.
Added several new rules for jax.experimental.jet #2537.
Fixed jax.experimental.stax.BatchNorm when scale/center isn't provided.
Fix some missing cases of broadcasting in jax.numpy.einsum #2512.
Implement jax.numpy.cumsum and jax.numpy.cumprod in terms of a parallel prefix scan #2596 and make reduce_prod differentiable to arbitrary order #2597.
Add batch_group_count to conv_general_dilated #2635.
Add docstring for test_util.check_grads #2656.
Add callback_transform #2665.
Implement rollaxis, convolve/correlate 1d & 2d, copysign, trunc, roots, and quantile/percentile interpolation options.
jaxlib 0.1.43 (March 31, 2020)¶
Fixed a performance regression for Resnet50 on GPU.
jax 0.1.62 (March 21, 2020)¶
JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
Removed the internal function lax._safe_mul, which implemented the convention 0. * nan == 0. This change means some programs when differentiated will produce nans when they previously produced correct values, though it ensures nans rather than silently incorrect results are produced for other programs. See #2447 and #1052 for details.
Added an all_gather parallel convenience function.
More type annotations in core code.
jaxlib 0.1.42 (March 19, 2020)¶
jaxlib 0.1.41 broke cloud TPU support due to an API incompatibility. This release fixes it again.
JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
jax 0.1.61 (March 17, 2020)¶
Fixes Python 3.5 support. This will be the last JAX or jaxlib release that supports Python 3.5.
jax 0.1.60 (March 17, 2020)¶
New features:
jax.pmap() has static_broadcast_argnums argument which allows the user to specify arguments that should be treated as compile-time constants and should be broadcasted to all devices. It works analogously to static_argnums in jax.jit().
Improved error messages for when tracers are mistakenly saved in global state.
Added jax.nn.one_hot() utility function.
Added jax.experimental.jet for exponentially faster higher-order automatic differentiation.
Added more correctness checking to arguments of jax.lax.broadcast_in_dim().
The minimum jaxlib version is now 0.1.41.
jaxlib 0.1.40 (March 4, 2020)¶
Adds experimental support in Jaxlib for TensorFlow profiler, which allows tracing of CPU and GPU computations from TensorBoard.
Includes prototype support for multi-host GPU computations that communicate via NCCL.
Improves performance of NCCL collectives on GPU.
Adds TopK, CustomCallWithoutLayout, CustomCallWithLayout, IGammaGradA and RandomGamma implementations.
Supports device assignments known at XLA compilation time.
jax 0.1.59 (February 11, 2020)¶
Breaking changes
The minimum jaxlib version is now 0.1.38.
Simplified Jaxpr by removing Jaxpr.freevars and Jaxpr.bound_subjaxprs. The call primitives (xla_call, xla_pmap, sharded_call, and remat_call) get a new parameter call_jaxpr with a fully-closed (no constvars) jaxpr. Also, added a new field call_primitive to primitives.
New features:
Reverse-mode automatic differentiation (e.g. grad) of lax.cond, making it now differentiable in both modes (https://github.com/google/jax/pull/2091)
JAX now supports DLPack, which allows sharing CPU and GPU arrays in a zero-copy way with other libraries, such as PyTorch.
JAX GPU DeviceArrays now support __cuda_array_interface__, which is another zero-copy protocol for sharing GPU arrays with other libraries such as CuPy and Numba.
JAX CPU device buffers now implement the Python buffer protocol, which allows zero-copy buffer sharing between JAX and NumPy.
Added JAX_SKIP_SLOW_TESTS environment variable to skip tests known as slow.
jaxlib 0.1.39 (February 11, 2020)¶
Updates XLA.
jaxlib 0.1.38 (January 29, 2020)¶
CUDA 9.0 is no longer supported.
CUDA 10.2 wheels are now built by default.
jax 0.1.58 (January 28, 2020)¶
Breaking changes
JAX has dropped Python 2 support, because Python 2 reached its end of life on January 1, 2020. Please update to Python 3.5 or newer.
New features
Forwardmode automatic differentiation (jvp) of while loop (https://github.com/google/jax/pull/1980)
New NumPy and SciPy functions:
Batched Cholesky decomposition on GPU now uses a more efficient batched kernel.
Notable bug fixes¶
With the Python 3 upgrade, JAX no longer depends on fastcache, which should help with installation.
JAX Frequently Asked Questions (FAQ)¶
We are collecting here answers to frequently asked questions. Contributions welcome!
jit changes the behavior of my function¶
If you have a Python function that changes behavior after using jax.jit()
, perhaps
your function uses global state, or has side-effects. In the following code, the
impure_func
uses the global y
and has a side-effect due to print
:
y = 0
# @jit # Different behavior with jit
def impure_func(x):
  print("Inside:", y)
  return x + y

for y in range(3):
  print("Result:", impure_func(y))
Without jit
the output is:
Inside: 0
Result: 0
Inside: 1
Result: 2
Inside: 2
Result: 4
and with jit
it is:
Inside: 0
Result: 0
Result: 1
Result: 2
For jax.jit()
, the function is executed once using the Python interpreter, at which time the
Inside
printing happens, and the first value of y
is observed. Then, the function
is compiled and cached, and executed multiple times with different values of x
, but
with the same first value of y
.
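A minimal fix, sketched here (the function name is ours, for illustration), is to make the dependence on y explicit by passing it as an argument, so the jitted function is pure:

```python
import jax

@jax.jit
def pure_func(x, y):
    # y is now an explicit argument instead of global state,
    # so each new value of y is seen and traced by JAX correctly.
    return x + y

results = [int(pure_func(y, y)) for y in range(3)]
# results matches the un-jitted behavior above: [0, 2, 4]
```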
Additional reading:
jit-decorated function is very slow to compile¶
If your jit-decorated function takes tens of seconds (or more!) to run the
first time you call it, but executes quickly when called again, JAX is taking a
long time to trace or compile your code.
This is usually a symptom of your function generating a large amount of
code in JAX's internal representation, typically because it makes heavy use of
Python control flow such as for loops. For a handful of loop iterations
Python is OK, but if you need many loop iterations, you should rewrite your
code to make use of JAX’s
structured control flow primitives
(such as lax.scan()
) or avoid wrapping the loop with jit
(you can
still use jit-decorated functions inside the loop).
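As an illustrative sketch (the function names are ours), a loop with many iterations can be expressed as a single lax.fori_loop equation instead of being unrolled at trace time:

```python
import jax
from jax import lax

def unrolled(x):
    # Traces 10000 separate add equations into the jaxpr.
    for _ in range(10000):
        x = x + 1
    return x

def structured(x):
    # Traces a single fori_loop equation, regardless of the trip count.
    return lax.fori_loop(0, 10000, lambda i, x: x + 1, x)

result = jax.jit(structured)(0.0)
```

Running jax.make_jaxpr on each version makes the difference in traced program size immediately visible.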
If you’re not sure if this is the problem, you can try running
jax.make_jaxpr()
on your function. You can expect slow compilation if the
output is many hundreds or thousands of lines long.
Sometimes it isn’t obvious how to rewrite your code to avoid Python loops
because your code makes use of many arrays with different shapes. The
recommended solution in this case is to make use of functions like
jax.numpy.where()
to do your computation on padded arrays with fixed
shape. The JAX team is exploring a “masking” transformation to make such code
easier to write.
If your functions are slow to compile for another reason, please open an issue on GitHub.
Controlling data and computation placement on devices¶
Let’s first look at the principles of data and computation placement in JAX.
In JAX, the computation follows data placement. JAX arrays have two placement properties: 1) the device where the data resides; and 2) whether it is committed to the device or not (the data is sometimes referred to as being sticky to the device).
By default, JAX arrays are placed uncommitted on the default device
(jax.devices()[0]
), which is the first GPU by default. If no GPU is
present, jax.devices()[0]
is the first CPU. The default device can
be set to “cpu” or “gpu” manually by setting the environment variable
JAX_PLATFORM_NAME
or the absl flag jax_platform_name
.
>>> from jax import numpy as jnp
>>> print(jnp.ones(3).device_buffer.device())
gpu:0
Computations involving uncommitted data are performed on the default device and the results are uncommitted on the default device.
Data can also be placed explicitly on a device using jax.device_put()
with a device
parameter, in which case the data becomes committed to the device:
>>> import jax
>>> from jax import device_put
>>> print(device_put(1, jax.devices()[2]).device_buffer.device())
gpu:2
Computations involving some committed inputs will happen on the committed device and the result will be committed on the same device. Invoking an operation on arguments that are committed to more than one device will raise an error.
You can also use jax.device_put()
without a device
parameter. If the data
is already on a device (committed or not), it’s left asis. If the data isn’t on any
device—that is, it’s a regular Python or NumPy value—it’s placed uncommitted on the default
device.
Jitted functions behave like any other primitive operations—they will follow the data and will show errors if invoked on data committed on more than one device.
jax.device_put(jnp.zeros(...), jax.devices()[1]) or similar will actually create the
or similar will actually create the
array of zeros on jax.devices()[1]
, instead of creating the array on the default
device then moving it. This is thanks to some laziness in array creation, which holds
for all the constant creation operations (ones
, full
, eye
, etc).
(As of April 2020, jax.jit()
has a device parameter that affects the device
placement. That parameter is experimental, is likely to be removed or changed,
and its use is not recommended.)
For a worked-out example, we recommend reading through
test_computation_follows_data
in
multi_device_test.py.
"Abstract tracer value encountered where concrete value is expected" error¶
If you are getting an error that a library function is called with “Abstract tracer value encountered where concrete value is expected”, you may need to change how you invoke JAX transformations. Below is an example and a couple of possible solutions, followed by the details of what is actually happening, if you are curious or the simple solution does not work for you.
Some library functions take arguments that specify shapes or axes,
such as the second and third arguments for jax.numpy.split()
:
# def np.split(arr, num_sections: Union[int, Sequence[int]], axis: int):
np.split(np.zeros(2), 2, 0) # works
If you try the following code:
jax.jit(np.split)(np.zeros(4), 2, 0)
you will get the following error:
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected (in jax.numpy.split argument 1).
Use transformation parameters such as `static_argnums` for `jit` to avoid tracing input values.
See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-where-concrete-value-is-expected-error`.
Encountered value: Traced<ShapedArray(int32[], weak_type=True):JaxprTrace(level=1/1)>
You must change the way you use jax.jit()
to ensure that the num_sections
and axis
arguments use their concrete values (2
and 0
respectively).
The best mechanism is to use special transformation parameters
to declare some arguments to be static, e.g., static_argnums
for jax.jit()
:
jax.jit(np.split, static_argnums=(1, 2))(np.zeros(4), 2, 0)
An alternative is to apply the transformation to a closure
that encapsulates the arguments to be protected, either manually as below
or by using functools.partial
:
jax.jit(lambda arr: np.split(arr, 2, 0))(np.zeros(4))
Note that a new closure is created at every invocation, which defeats the compilation caching mechanism; this is why static_argnums is preferred.
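As a small sketch (using the same split example as above), the functools.partial variant looks like this; the partial is created and jitted once, so the compilation cache is reused across calls:

```python
import functools

import jax
import jax.numpy as jnp

# Bind the static arguments once with functools.partial, then jit the result.
# Because `indices_or_sections` and `axis` are baked in as ordinary Python
# values, they never become tracers.
split_in_two = jax.jit(
    functools.partial(jnp.split, indices_or_sections=2, axis=0))

pieces = split_in_two(jnp.zeros(4))
print(len(pieces))      # -> 2
print(pieces[0].shape)  # -> (2,)
```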
To understand more subtleties having to do with tracers vs. regular values, and concrete vs. abstract values, you may want to read Different kinds of JAX values.
Different kinds of JAX values¶
In the process of transforming functions, JAX replaces some function arguments with special tracer values.
You could see this if you use a print
statement:
def func(x):
print(x)
return np.cos(x)
res = jax.jit(func)(0.)
The above code does return the correct value 1.
but it also prints
Traced<ShapedArray(float32[])>
for the value of x
. Normally, JAX
handles these tracer values internally in a transparent way, e.g.,
in the numeric JAX primitives that are used to implement the
jax.numpy
functions. This is why np.cos
works in the example above.
More precisely, a tracer value is introduced for the argument of
a JAX-transformed function, except the arguments identified by special
parameters such as static_argnums
for jax.jit()
or
static_broadcasted_argnums
for jax.pmap()
. Typically, computations
that involve at least a tracer value will produce a tracer value. Besides tracer
values, there are regular Python values: values that are computed outside JAX
transformations, or arise from the above-mentioned static arguments of certain JAX
transformations, or are computed solely from other regular Python values.
These are the values that are used everywhere in the absence of JAX transformations.
A tracer value carries an abstract value, e.g., ShapedArray
with information
about the shape and dtype of an array. We will refer here to such tracers as
abstract tracers. Some tracers, e.g., those that are
introduced for arguments of autodiff transformations, carry ConcreteArray
abstract values that actually include the regular array data, and are used,
e.g., for resolving conditionals. We will refer here to such tracers
as concrete tracers. Tracer values computed from these concrete tracers,
perhaps in combination with regular values, result in concrete tracers.
A concrete value is either a regular value or a concrete tracer.
Most often values computed from tracer values are themselves tracer values.
There are very few exceptions, when a computation can be entirely done
using the abstract value carried by a tracer, in which case the result
can be a regular value. For example, getting the shape of a tracer
with ShapedArray
abstract value. Another example is explicitly
casting a concrete tracer value to a regular type, e.g., int(x)
or
x.astype(float)
.
Another such situation is for bool(x)
, which produces a Python bool when
concreteness makes it possible. That case is especially salient because
of how often it arises in control flow.
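For instance (a small illustrative sketch, not from the original text), a tracer's shape is part of its abstract value, so it can be used as a regular Python value even under jit():

```python
import jax
import jax.numpy as jnp

@jax.jit
def scale_by_length(x):
    n = x.shape[0]  # a regular Python int: the shape is part of the abstract value
    return x * n

result = scale_by_length(jnp.ones(3))
print(result)  # -> [3. 3. 3.]
```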
Here is how the transformations introduce abstract or concrete tracers:
- jax.jit(): introduces abstract tracers for all positional arguments except those denoted by static_argnums, which remain regular values.
- jax.pmap(): introduces abstract tracers for all positional arguments except those denoted by static_broadcasted_argnums.
- jax.vmap(), jax.make_jaxpr(), xla_computation(): introduce abstract tracers for all positional arguments.
- jax.jvp() and jax.grad() introduce concrete tracers for all positional arguments. An exception is when these transformations are within an outer transformation and the actual arguments are themselves abstract tracers; in that case, the tracers introduced by the autodiff transformations are also abstract tracers.
- All higher-order control-flow primitives (lax.cond(), lax.while_loop(), lax.fori_loop(), lax.scan()), when they process their functionals, introduce abstract tracers, whether or not there is a JAX transformation in progress.
All of this is relevant when you have code that can operate only on regular Python values, such as code that has conditional controlflow based on data:
def divide(x, y):
return x / y if y >= 1. else 0.
If we want to apply jax.jit(), we must specify static_argnums=1 to ensure
that y stays a regular value. This is due to the boolean expression
y >= 1., which requires concrete values (regular values or concrete tracers). The
same would happen if we explicitly wrote bool(y >= 1.), or int(y),
or float(y).
Interestingly, jax.grad(divide)(3., 2.) works because jax.grad()
uses concrete tracers, and resolves the conditional using the concrete
value of y.
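To make this concrete, here is a runnable sketch of both approaches side by side:

```python
import jax

def divide(x, y):
    return x / y if y >= 1. else 0.

# Under jit, y must be declared static so the Python conditional
# sees a concrete value rather than an abstract tracer.
jit_divide = jax.jit(divide, static_argnums=1)

print(jit_divide(3., 2.))        # -> 1.5
print(jax.grad(divide)(3., 2.))  # -> 0.5 (d/dx of x/y at y=2)
```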
Gradients contain NaN where using where
¶
If you define a function using where to avoid an undefined value, and you
are not careful, you may obtain a NaN under reverse-mode differentiation:
def my_log(x):
return np.where(x > 0., np.log(x), 0.)
my_log(0.) ==> 0. # Ok
jax.grad(my_log)(0.) ==> NaN
A short explanation is that during the grad computation the adjoint corresponding
to the undefined np.log(x) is a NaN, and it gets accumulated into the
adjoint of the np.where. The correct way to write such functions is to ensure
that there is an np.where inside the partially-defined function, so
that the adjoint is always finite:
def safe_for_grad_log(x):
return np.log(np.where(x > 0., x, 1.))
safe_for_grad_log(0.) ==> 0. # Ok
jax.grad(safe_for_grad_log)(0.) ==> 0. # Ok
The inner np.where
may be needed in addition to the original one, e.g.:
def my_log_or_y(x, y):
  """Return log(x) if x > 0 or y"""
  return np.where(x > 0., np.log(np.where(x > 0., x, 1.)), y)
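A quick numeric check of the double-where pattern, both value and gradient (a sketch; jnp is used in place of the np alias above):

```python
import jax
import jax.numpy as jnp

def my_log_or_y(x, y):
    """Return log(x) if x > 0, else y."""
    return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.)), y)

print(my_log_or_y(0., 7.))            # -> 7.0
print(jax.grad(my_log_or_y)(0., 7.))  # -> 0.0 (finite, not NaN)
```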
Understanding Jaxprs¶
Updated: May 3, 2020 (for commit f1a46fe).
Conceptually, one can think of JAX transformations as first trace-specializing the Python function to be transformed into a small and well-behaved intermediate form that is then interpreted with transformation-specific interpretation rules. One of the reasons JAX can pack so much power into such a small software package is that it starts with a familiar and flexible programming interface (Python with NumPy) and it uses the actual Python interpreter to do most of the heavy lifting to distill the essence of the computation into a simple statically-typed expression language with limited higher-order features. That language is the jaxpr language.
Not all Python programs can be processed this way, but it turns out that many scientific computing and machine learning programs can.
Before we proceed, it is important to point out that not all JAX transformations literally materialize a jaxpr as described above; some, e.g., differentiation or batching, will apply transformations incrementally during tracing. Nevertheless, if one wants to understand how JAX works internally, or to make use of the result of JAX tracing, it is useful to understand jaxprs.
A jaxpr instance represents a function with one or more typed parameters (input
variables) and one or more typed results. The results depend only on the input
variables; there are no free variables captured from enclosing scopes. The
inputs and outputs have types, which in JAX are represented as abstract values.
There are two related representations in the code for jaxprs,
jax.core.Jaxpr
and jax.core.ClosedJaxpr
. A
jax.core.ClosedJaxpr
represents a partially-applied
jax.core.Jaxpr
, and is what you obtain when you use
jax.make_jaxpr()
to inspect jaxprs. It has the following fields:
- jaxpr: is a jax.core.Jaxpr representing the actual computation content of the function (described below).
- consts: is a list of constants.
The most interesting part of the ClosedJaxpr is the actual execution content,
represented as a jax.core.Jaxpr
as printed using the following
grammar:
jaxpr ::= { lambda Var* ; Var+.
let Eqn*
in [Expr+] }
 where:
- The parameters of the jaxpr are shown as two lists of variables separated by ;. The first list of variables are the ones that have been introduced to stand for constants that have been hoisted out. These are called the constvars, and in a jax.core.ClosedJaxpr the consts field holds the corresponding values. The second list of variables, called invars, corresponds to the inputs of the traced Python function.
- Eqn* is a list of equations, defining intermediate variables referring to intermediate expressions. Each equation defines one or more variables as the result of applying a primitive on some atomic expressions. Each equation uses only input variables and intermediate variables defined by previous equations.
- Expr+ is a list of output atomic expressions (literals or variables) for the jaxpr.
Equations are printed as follows:
Eqn ::= let Var+ = Primitive [ Param* ] Expr+
 where:
- Var+ are one or more intermediate variables to be defined as the output of a primitive invocation (some primitives can return multiple values).
- Expr+ are one or more atomic expressions, each either a variable or a literal constant. A special variable unitvar or literal unit, printed as *, represents a value that is not needed in the rest of the computation and has been elided. That is, units are just placeholders.
- Param* are zero or more named parameters to the primitive, printed in square brackets. Each parameter is shown as Name = Value.
Most jaxpr primitives are first-order (they take just one or more Expr as arguments):
Primitive := add | sub | sin | mul | ...
The jaxpr primitives are documented in the jax.lax
module.
For example, here is the jaxpr produced for the function func1 below:
>>> from jax import make_jaxpr
>>> import jax.numpy as jnp
>>> def func1(first, second):
... temp = first + jnp.sin(second) * 3.
... return jnp.sum(temp)
...
>>> print(make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a b.
let c = sin b
d = mul c 3.0
e = add a d
f = reduce_sum[ axes=(0,) ] e
in (f,) }
Here there are no constvars, a
and b
are the input variables
and they correspond respectively to
first
and second
function parameters. The scalar literal 3.0
is kept
inline.
The reduce_sum primitive has a named parameter axes, in addition to the operand e.
Note that JAX traces through Python-level control flow and higher-order functions
when it extracts the jaxpr. This means that just because a Python program contains
functions and control flow, the resulting jaxpr does not have
to contain control-flow or higher-order features.
For example, when tracing the function func3
JAX will inline the call to
inner
and the conditional if second.shape[0] > 4
, and will produce the same
jaxpr as before:
>>> def func2(inner, first, second):
... temp = first + inner(second) * 3.
... return jnp.sum(temp)
...
>>> def inner(second):
... if second.shape[0] > 4:
... return jnp.sin(second)
... else:
... assert False
...
>>> def func3(first, second):
... return func2(inner, first, second)
...
>>> print(make_jaxpr(func3)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a b.
let c = sin b
d = mul c 3.0
e = add a d
f = reduce_sum[ axes=(0,) ] e
in (f,) }
Handling PyTrees¶
In jaxpr there are no tuple types; instead primitives take multiple inputs and produce multiple outputs. When processing a function that has structured inputs or outputs, JAX will flatten those and in jaxpr they will appear as lists of inputs and outputs. For more details, please see the documentation for PyTrees (notebooks/JAX_pytrees).
For example, the following code produces an identical jaxpr to what we saw before (with two input vars, one for each element of the input tuple):
>>> def func4(arg): # Arg is a pair
... temp = arg[0] + jnp.sin(arg[1]) * 3.
... return jnp.sum(temp)
...
>>> print(make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8))))
{ lambda ; a b.
let c = sin b
d = mul c 3.0
e = add a d
f = reduce_sum[ axes=(0,) ] e
in (f,) }
Constant Vars¶
Some values in jaxprs are constants, in that their value does not depend on the jaxpr’s arguments. When these values are scalars they are represented directly in the jaxpr equations; nonscalar array constants are instead hoisted out to the toplevel jaxpr, where they correspond to constant variables (“constvars”). These constvars differ from the other jaxpr parameters (“invars”) only as a bookkeeping convention.
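For example (an illustrative sketch, assuming the current make_jaxpr behavior), a non-scalar NumPy constant shows up in the consts field of the resulting ClosedJaxpr:

```python
import numpy as np
from jax import make_jaxpr

# The non-scalar array np.ones(3) is a constant: it does not depend on x,
# so it is hoisted out of the jaxpr as a constvar.
def add_ones(x):
    return x + np.ones(3)

closed = make_jaxpr(add_ones)(np.zeros(3))
print(closed.jaxpr.constvars)  # one constvar standing for np.ones(3)
print(len(closed.consts))      # -> 1
```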
Higherorder primitives¶
jaxpr includes several higher-order primitives. They are more complicated because they include sub-jaxprs.
Conditionals¶
JAX traces through normal Python conditionals. To capture a
conditional expression for dynamic execution, one must use the
jax.lax.switch()
and jax.lax.cond()
constructors,
which have the signatures:
lax.switch(index: int, branches: Sequence[A -> B], operand: A) -> B
lax.cond(pred: bool, true_body: A -> B, false_body: A -> B, operand: A) -> B
Both of these will bind a primitive called cond
internally. The
cond
primitive in jaxprs reflects the more general signature of
lax.switch()
: it takes an integer denoting the index of the branch
to execute (clamped into the valid indexing range).
For example:
>>> from jax import lax
>>>
>>> def one_of_three(index, arg):
... return lax.switch(index, [lambda x: x + 1.,
...                        lambda x: x - 2.,
... lambda x: x + 3.],
... arg)
...
>>> print(make_jaxpr(one_of_three)(1, 5.))
{ lambda ; a b.
let c = clamp 0 a 2
d = cond[ branches=( { lambda ; a.
let b = add a 1.0
in (b,) }
{ lambda ; a.
let b = sub a 2.0
in (b,) }
{ lambda ; a.
let b = add a 3.0
in (b,) } )
linear=(False,) ] c b
in (d,) }
The cond primitive has a number of parameters:
- branches are jaxprs that correspond to the branch functionals. In this example, those functionals each take one input variable, corresponding to x.
- linear is a tuple of booleans that is used internally by the automatic-differentiation machinery to encode which of the input parameters are used linearly in the conditional.
The above instance of the cond primitive takes two operands. The first
one (c
) is the branch index, then b
is the operand (arg
) to
be passed to whichever jaxpr in branches
is selected by the branch
index.
Another example, using lax.cond()
:
>>> from jax import lax
>>>
>>> def func7(arg):
... return lax.cond(arg >= 0.,
... lambda xtrue: xtrue + 3.,
... lambda xfalse: xfalse  3.,
... arg)
...
>>> print(make_jaxpr(func7)(5.))
{ lambda ; a.
let b = ge a 0.0
c = convert_element_type[ new_dtype=int32
old_dtype=bool ] b
d = cond[ branches=( { lambda ; a.
let b = sub a 3.0
in (b,) }
{ lambda ; a.
let b = add a 3.0
in (b,) } )
linear=(False,) ] c a
in (d,) }
In this case, the boolean predicate is converted to an integer index
(0 or 1), and branches
are jaxprs that correspond to the false and
true branch functionals, in that order. Again, each functional takes
one input variable, corresponding to xtrue
and xfalse
respectively.
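As a runnable sketch of lax.cond with the signature above:

```python
from jax import lax

def abs_plus_three(x):
    # True branch when x >= 0., false branch otherwise.
    return lax.cond(x >= 0.,
                    lambda xtrue: xtrue + 3.,
                    lambda xfalse: -xfalse + 3.,
                    x)

print(abs_plus_three(5.))   # -> 8.0
print(abs_plus_three(-5.))  # -> 8.0
```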
The following example shows a more complicated situation when the input
to the branch functionals is a tuple, and the false branch functional
contains a constant jnp.array([1]) that is hoisted as a constvar:
>>> def func8(arg1, arg2): # arg2 is a pair
... return lax.cond(arg1 >= 0.,
... lambda xtrue: xtrue[0],
... lambda xfalse: jnp.array([1]) + xfalse[1],
... arg2)
...
>>> print(make_jaxpr(func8)(5., (jnp.zeros(1), 2.)))
{ lambda a ; b c d.
let e = ge b 0.0
f = convert_element_type[ new_dtype=int32
old_dtype=bool ] e
g = cond[ branches=( { lambda ; a b c.
let d = convert_element_type[ new_dtype=float32
old_dtype=int32 ] a
e = add d c
in (e,) }
{ lambda ; f_ a b.
let
in (a,) } )
linear=(False, False, False) ] f a c d
in (g,) }
While¶
Just like for conditionals, Python loops are inlined during tracing.
If you want to capture a loop for dynamic execution, you must use one of several
special operations, jax.lax.while_loop()
(a primitive)
and jax.lax.fori_loop()
(a helper that generates a while_loop primitive):
lax.while_loop(cond_fun: (C -> bool), body_fun: (C -> C), init: C) -> C
lax.fori_loop(start: int, end: int, body: (int -> C -> C), init: C) -> C
In the above signatures, “C” stands for the type of the loop “carry” value. For example, here is a fori loop:
>>> import numpy as np
>>>
>>> def func10(arg, n):
... ones = jnp.ones(arg.shape) # A constant
... return lax.fori_loop(0, n,
... lambda i, carry: carry + ones * 3. + arg,
... arg + ones)
...
>>> print(make_jaxpr(func10)(np.ones(16), 5))
{ lambda ; a b.
let c = broadcast_in_dim[ broadcast_dimensions=( )
shape=(16,) ] 1.0
d = add a c
_ _ e = while[ body_jaxpr={ lambda ; a b c d e.
let f = add c 1
g = mul a 3.0
h = add e g
i = add h b
in (f, d, i) }
body_nconsts=2
cond_jaxpr={ lambda ; a b c.
let d = lt a b
in (d,) }
cond_nconsts=0 ] c a 0 b d
in (e,) }
The while primitive takes 5 arguments: c a 0 b d, as follows:
- 0 constants for cond_jaxpr (since cond_nconsts is 0)
- 2 constants for body_jaxpr (c and a)
- 3 parameters for the initial value of the carry
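A minimal runnable fori_loop, matching the signature above (an illustrative sketch, separate from the traced example):

```python
from jax import lax

def sum_below(n):
    # Adds i to the carry on each iteration: 0 + 0 + 1 + ... + (n - 1).
    return lax.fori_loop(0, n, lambda i, total: total + i, 0)

print(sum_below(5))  # -> 10
```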
Scan¶
JAX supports a special form of loop over the elements of an array (with
statically known shape). The fact that there are a fixed number of iterations
makes this form of looping easily reverse-differentiable. Such loops are
constructed with the jax.lax.scan()
function:
lax.scan(body_fun: (C -> A -> (C, B)), init_carry: C, in_arr: Array[A]) -> (C, Array[B])
Here C
is the type of the scan carry, A
is the element type of the
input array(s), and B
is the element type of the output array(s).
For example, consider the function func11 below:
>>> def func11(arr, extra):
... ones = jnp.ones(arr.shape) # A constant
... def body(carry, aelems):
...   # carry: running dot-product of the two arrays
... # aelems: a pair with corresponding elements from the two arrays
... ae1, ae2 = aelems
... return (carry + ae1 * ae2 + extra, carry)
... return lax.scan(body, 0., (arr, ones))
...
>>> print(make_jaxpr(func11)(np.ones(16), 5.))
{ lambda ; a b.
let c = broadcast_in_dim[ broadcast_dimensions=( )
shape=(16,) ] 1.0
d e = scan[ jaxpr={ lambda ; a b c d.
let e = mul c d
f = add b e
g = add f a
in (g, b) }
length=16
linear=(False, False, False, False)
num_carry=1
num_consts=1
reverse=False
unroll=1 ] b 0.0 a c
in (d, e) }
The linear
parameter describes for each of the input variables whether they
are guaranteed to be used linearly in the body. Once the scan goes through
linearization, more arguments will be linear.
The scan primitive takes 4 arguments: b 0.0 a c, of which:
- one is the free variable for the body
- one is the initial value of the carry
- the next 2 are the arrays over which the scan operates
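Here is a separate minimal runnable scan (a sketch): the carry accumulates a running sum, and the per-element outputs are stacked into an array:

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    new_carry = carry + x
    return new_carry, new_carry  # (next carry, per-element output)

total, running = jax.lax.scan(step, 0., jnp.arange(4.))
print(total)    # -> 6.0
print(running)  # -> [0. 1. 3. 6.]
```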
XLA_call¶
The call primitive arises from JIT compilation, and it encapsulates a sub-jaxpr along with parameters that specify the backend and the device on which the computation should run. For example:
>>> from jax import jit
>>>
>>> def func12(arg):
... @jit
... def inner(x):
... return x + arg * jnp.ones(1) # Include a constant in the inner function
... return arg + inner(arg  2.)
...
>>> print(make_jaxpr(func12)(1.))
{ lambda ; a.
let b = sub a 2.0
c = xla_call[ backend=None
call_jaxpr={ lambda ; a b.
let c = broadcast_in_dim[ broadcast_dimensions=( )
shape=(1,) ] 1.0
d = mul a c
e = add b d
in (e,) }
device=None
donated_invars=(False, False)
name=inner ] a b
d = add a c
in (d,) }
XLA_pmap¶
If you use the jax.pmap()
transformation, the function to be mapped is
captured using the xla_pmap
primitive. Consider this example:
>>> from jax import pmap
>>>
>>> def func13(arr, extra):
... def inner(x):
... # use a free variable "extra" and a constant jnp.ones(1)
... return (x + extra + jnp.ones(1)) / lax.psum(x, axis_name='rows')
... return pmap(inner, axis_name='rows')(arr)
...
>>> print(make_jaxpr(func13)(jnp.ones((1, 3)), 5.))
{ lambda ; a b.
let c = xla_pmap[ axis_name=rows
axis_size=1
backend=None
call_jaxpr={ lambda ; a b.
let c = add b a
d = broadcast_in_dim[ broadcast_dimensions=( )
shape=(1,) ] 1.0
e = add c d
f = psum[ axis_index_groups=None
axis_name=rows ] b
g = div e f
in (g,) }
devices=None
donated_invars=(False, False)
global_arg_shapes=(None,)
global_axis_size=None
in_axes=(None, 0)
name=inner
out_axes=(0,) ] b a
in (c,) }
The xla_pmap
primitive specifies the name of the axis (parameter rows
)
and the body of the function to be mapped as the call_jaxpr
parameter. The value of this parameter is a jaxpr with two input variables, printed above as a and b.
The parameter in_axes
specifies which of the input variables should be
mapped and which should be broadcast. In our example, the value of extra
is broadcast, the other input values are mapped.
Asynchronous dispatch¶
JAX uses asynchronous dispatch to hide Python overheads. Consider the following program:
>>> import numpy as np
>>> import jax.numpy as jnp
>>> from jax import random
>>> x = random.uniform(random.PRNGKey(0), (1000, 1000))
>>> jnp.dot(x, x) + 3.
DeviceArray([[258.01971436, 249.64862061, 257.13372803, ...,
236.67948914, 250.68939209, 241.36853027],
[265.65979004, 256.28912354, 262.18252563, ...,
242.03181458, 256.16757202, 252.44122314],
[262.38916016, 255.72747803, 261.23059082, ...,
240.83563232, 255.41094971, 249.62471008],
...,
[259.15814209, 253.09197998, 257.72174072, ...,
242.23876953, 250.72680664, 247.16642761],
[271.22662354, 261.91204834, 265.33398438, ...,
248.26651001, 262.05389404, 261.33700562],
[257.16134644, 254.7543335, 259.08300781, ..., 241.59848022,
248.62597656, 243.22348022]], dtype=float32)
When an operation such as jnp.dot(x, x)
is executed, JAX does not wait
for the operation to complete before returning control to the Python program.
Instead, JAX returns a DeviceArray
value, which is a future,
i.e., a value that will be produced in the future on an accelerator device but
isn’t necessarily available immediately. We can inspect the shape or type of a
DeviceArray
without waiting for the computation that produced it to
complete, and we can even pass it to another JAX computation, as we do with the
addition operation here. Only if we actually inspect the value of the array from
the host, for example by printing it or by converting it into a plain old
numpy.ndarray
will JAX force the Python code to wait for the
computation to complete.
Asynchronous dispatch is useful since it allows Python code to “run ahead” of an accelerator device, keeping Python code out of the critical path. Provided the Python code enqueues work on the device faster than it can be executed, and provided that the Python code does not actually need to inspect the output of a computation on the host, then a Python program can enqueue arbitrary amounts of work and avoid having the accelerator wait.
Asynchronous dispatch has a slightly surprising consequence for microbenchmarks.
>>> %time jnp.dot(x, x)
CPU times: user 267 µs, sys: 93 µs, total: 360 µs
Wall time: 269 µs
DeviceArray([[255.01972961, 246.64862061, 254.13371277, ...,
233.67948914, 247.68939209, 238.36853027],
[262.65979004, 253.28910828, 259.18252563, ...,
239.03181458, 253.16757202, 249.44122314],
[259.38916016, 252.72747803, 258.23059082, ...,
237.83563232, 252.41094971, 246.62471008],
...,
[256.15814209, 250.09197998, 254.72172546, ...,
239.23876953, 247.72680664, 244.16642761],
[268.22662354, 258.91204834, 262.33398438, ...,
245.26651001, 259.05389404, 258.33700562],
[254.16134644, 251.7543335, 256.08300781, ..., 238.59848022,
245.62597656, 240.22348022]], dtype=float32)
269 µs is a surprisingly small time for a 1000x1000 matrix multiplication on CPU!
However it turns out that asynchronous dispatch is misleading us and we are not
timing the execution of the matrix multiplication, only the time to dispatch
the work. To measure the true cost of the operation we must either read the
value on the host (e.g., convert it to a plain old hostside numpy array), or
use the block_until_ready()
method on a
DeviceArray
value to wait for the computation that produced it to
complete.
>>> %time np.asarray(jnp.dot(x, x))
CPU times: user 61.1 ms, sys: 0 ns, total: 61.1 ms
Wall time: 8.09 ms
array([[255.01973, 246.64862, 254.13371, ..., 233.67949, 247.68939,
238.36853],
[262.6598 , 253.28911, 259.18253, ..., 239.03181, 253.16757,
249.44122],
[259.38916, 252.72748, 258.2306 , ..., 237.83563, 252.41095,
246.62471],
...,
[256.15814, 250.09198, 254.72173, ..., 239.23877, 247.7268 ,
244.16643],
[268.22662, 258.91205, 262.33398, ..., 245.26651, 259.0539 ,
258.337 ],
[254.16135, 251.75433, 256.083 , ..., 238.59848, 245.62598,
240.22348]], dtype=float32)
>>> %time jnp.dot(x, x).block_until_ready()
CPU times: user 50.3 ms, sys: 928 µs, total: 51.2 ms
Wall time: 4.92 ms
DeviceArray([[255.01972961, 246.64862061, 254.13371277, ...,
233.67948914, 247.68939209, 238.36853027],
[262.65979004, 253.28910828, 259.18252563, ...,
239.03181458, 253.16757202, 249.44122314],
[259.38916016, 252.72747803, 258.23059082, ...,
237.83563232, 252.41094971, 246.62471008],
...,
[256.15814209, 250.09197998, 254.72172546, ...,
239.23876953, 247.72680664, 244.16642761],
[268.22662354, 258.91204834, 262.33398438, ...,
245.26651001, 259.05389404, 258.33700562],
[254.16134644, 251.7543335, 256.08300781, ..., 238.59848022,
245.62597656, 240.22348022]], dtype=float32)
Blocking without transferring the result back to Python is usually faster, and is often the best choice when writing microbenchmarks of computation times.
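The timing difference can also be reproduced without notebook magics; this sketch uses time.perf_counter (absolute numbers will vary by machine):

```python
import time

import jax.numpy as jnp
from jax import random

x = random.uniform(random.PRNGKey(0), (1000, 1000))
jnp.dot(x, x).block_until_ready()  # warm up: compile and run once

t0 = time.perf_counter()
y = jnp.dot(x, x)                  # returns a future almost immediately
dispatch_time = time.perf_counter() - t0

t0 = time.perf_counter()
jnp.dot(x, x).block_until_ready()  # waits for the actual computation
compute_time = time.perf_counter() - t0

print(dispatch_time, compute_time)  # dispatch alone is typically much cheaper
```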
Concurrency¶
JAX has some limited support for Python concurrency.
Concurrency support is experimental and only lightly tested; please report any bugs.
Clients may call JAX APIs (e.g., jit()
or grad()
)
concurrently from separate Python threads.
It is not permitted to manipulate JAX trace values concurrently from multiple
threads. In other words, while it is permissible to call functions that use JAX
tracing (e.g., jit()
) from multiple threads, you must not use
threading to manipulate JAX values inside the implementation of the function
f that is passed to jit()
. The most likely outcome if you do this
is a mysterious error from JAX.
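For example, calling a jitted function from several Python threads is supported (a sketch); what is not supported is using threads inside the traced function to manipulate tracer values:

```python
import threading

import jax
import jax.numpy as jnp

@jax.jit
def double(x):
    return x * 2.

results = [None] * 4

def worker(i):
    # Each thread makes its own independent call into the jitted function.
    results[i] = double(jnp.float32(i))

threads = [threading.Thread(target=worker, args=(i,)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print([float(r) for r in results])  # -> [0.0, 2.0, 4.0, 6.0]
```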
GPU memory allocation¶
JAX will preallocate 90% of currently-available GPU memory when the first JAX operation is run. Preallocating minimizes allocation overhead and memory fragmentation, but can sometimes cause out-of-memory (OOM) errors. If your JAX process fails with OOM, the following environment variables can be used to override the default behavior:
XLA_PYTHON_CLIENT_PREALLOCATE=false
This disables the preallocation behavior. JAX will instead allocate GPU memory as needed, potentially decreasing the overall memory usage. However, this behavior is more prone to GPU memory fragmentation, meaning a JAX program that uses most of the available GPU memory may OOM with preallocation disabled.
XLA_PYTHON_CLIENT_MEM_FRACTION=.XX
If preallocation is enabled, this makes JAX preallocate XX% of currently-available GPU memory, instead of the default 90%. Lowering the amount preallocated can fix OOMs that occur when the JAX program starts.
XLA_PYTHON_CLIENT_ALLOCATOR=platform
This makes JAX allocate exactly what is needed on demand, and deallocate memory that is no longer needed (note that this is the only configuration that will deallocate GPU memory, instead of reusing it). This is very slow, so is not recommended for general use, but may be useful for running with the minimal possible GPU memory footprint or debugging OOM failures.
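These variables are set in the environment before the process starts, for example (train.py is a hypothetical script name):

```shell
# Disable preallocation entirely (allocate on demand):
XLA_PYTHON_CLIENT_PREALLOCATE=false python train.py

# Or preallocate only 50% of currently-available GPU memory:
XLA_PYTHON_CLIENT_MEM_FRACTION=.50 python train.py
```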
Common causes of OOM failures¶
- Running multiple JAX processes concurrently. Either use XLA_PYTHON_CLIENT_MEM_FRACTION to give each process an appropriate amount of memory, or set XLA_PYTHON_CLIENT_PREALLOCATE=false.
- Running JAX and GPU TensorFlow concurrently. TensorFlow also preallocates by default, so this is similar to running multiple JAX processes concurrently. One solution is to use CPU-only TensorFlow (e.g. if you’re only doing data loading with TF). You can prevent TensorFlow from using the GPU with the command tf.config.experimental.set_visible_devices([], "GPU"). Alternatively, use XLA_PYTHON_CLIENT_MEM_FRACTION or XLA_PYTHON_CLIENT_PREALLOCATE. There are also similar options to configure TensorFlow’s GPU memory allocation (gpu_memory_fraction and allow_growth in TF1, which should be set in a tf.ConfigProto passed to tf.Session; see Using GPUs: Limiting GPU memory growth for TF2).
- Running JAX on the display GPU. Use XLA_PYTHON_CLIENT_MEM_FRACTION or XLA_PYTHON_CLIENT_PREALLOCATE.
Profiling JAX programs¶
TensorBoard profiling¶
TensorBoard’s profiler can be used to profile JAX programs. TensorBoard is a great way to acquire and visualize performance traces and profiles of your program, including activity on GPU and TPU. The end result looks something like this:
Installation¶
# Requires TensorFlow and TensorBoard version >= 2.2
pip install --upgrade tensorflow tensorboard_plugin_profile
Usage¶
The following are instructions for capturing a manually-triggered N-second trace from a running program.
Start a TensorBoard server:
tensorboard --logdir /tmp/tensorboard/
You should be able to load TensorBoard at http://localhost:6006/. You can specify a different port with the --port flag. See Profiling on a remote machine below if running JAX on a remote server.
In the Python program or process you’d like to profile, add the following somewhere near the beginning:
import jax.profiler
server = jax.profiler.start_server(9999)
This starts the profiler server that TensorBoard connects to. The profiler server must be running before you move on to the next step. It will remain alive and listening until the object returned by
start_server()
is destroyed.
If you’d like to profile a snippet of a long-running program (e.g. a long training loop), you can put this at the beginning of the program and start your program as usual. If you’d like to profile a short program (e.g. a microbenchmark), one option is to start the profiler server in an IPython shell, and run the short program with %run after starting the capture in the next step. Another option is to start the profiler server at the beginning of the program and use time.sleep() to give you enough time to start the capture.
Open http://localhost:6006/#profile, and click the “CAPTURE PROFILE” button in the upper left. Enter “localhost:9999” as the profile service URL (this is the address of the profiler server you started in the previous step). Enter the number of milliseconds you’d like to profile for, and click “CAPTURE”.
If the code you’d like to profile isn’t already running (e.g. if you started the profiler server in a Python shell), run it while the capture is running.
After the capture finishes, TensorBoard should automatically refresh. (Not all of the TensorBoard profiling features are hooked up with JAX, so it may initially look like nothing was captured.) On the left under “Tools”, select “trace_viewer”.
You should now see a timeline of the execution. You can use the WASD keys to navigate the trace, and click or drag to select events to see more details at the bottom. See these TensorFlow docs for more details on using the trace viewer.
By default, the events in the trace viewer are mostly low-level internal JAX functions. You can add your own events and functions by using jax.profiler.TraceContext() and jax.profiler.trace_function() in your code and capturing a new profile.
Troubleshooting¶
GPU profiling¶
Programs running on GPU should produce traces for the GPU streams near the top of the trace viewer. If you’re only seeing the host traces, check your program logs and/or output for the following error messages.
If you get an error like: Could not load dynamic library 'libcupti.so.10.1'
Full error:
W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libcupti.so.10.1'; dlerror: libcupti.so.10.1: cannot open shared object file: No such file or directory
2020-06-12 13:19:59.822799: E external/org_tensorflow/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1422] function cupti_interface_->Subscribe( &subscriber_, (CUpti_CallbackFunc)ApiCallback, this) failed with error CUPTI could not be loaded or symbol could not be found.
Add the path to libcupti.so
to the environment variable LD_LIBRARY_PATH
.
(Try locate libcupti.so
to find the path.) For example:
export LD_LIBRARY_PATH=/usr/local/cuda-10.1/extras/CUPTI/lib64/:$LD_LIBRARY_PATH
If you get an error like: failed with error CUPTI_ERROR_INSUFFICIENT_PRIVILEGES
Full error:
E external/org_tensorflow/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1445] function cupti_interface_->EnableCallback( 0 , subscriber_, CUPTI_CB_DOMAIN_DRIVER_API, cbid) failed with error CUPTI_ERROR_INSUFFICIENT_PRIVILEGES
2020-06-12 14:31:54.097791: E external/org_tensorflow/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1487] function cupti_interface_->ActivityDisable(activity) failed with error CUPTI_ERROR_NOT_INITIALIZED
Run the following commands (note this requires a reboot):
echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"' | sudo tee -a /etc/modprobe.d/nvidia-kernel-common.conf
sudo update-initramfs -u
sudo reboot now
See NVIDIA’s documentation on this error for more information.
Profiling on a remote machine¶
If the JAX program you’d like to profile is running on a remote machine, one option is to run all the instructions above on the remote machine (in particular, start the TensorBoard server on the remote machine), then use SSH local port forwarding to access the TensorBoard web UI from your local machine. Use the following SSH command to forward the default TensorBoard port 6006 from the local to the remote machine:
ssh -L 6006:localhost:6006 <remote server address>
Nsight¶
NVIDIA’s Nsight tools can be used to trace and profile JAX code on GPU. For details, see the Nsight documentation.
XLA profiling¶
XLA has some built-in support for profiling on both CPU and GPU. To use XLA’s profiling features from JAX, set the environment variables TF_CPP_MIN_LOG_LEVEL=0 and XLA_FLAGS=--xla_hlo_profile. XLA will log profiling information about each computation JAX runs. For example:
$ TF_CPP_MIN_LOG_LEVEL=0 XLA_FLAGS=--xla_hlo_profile ipython
...
In [1]: from jax import lax
In [2]: lax.add(1, 2)
2019-08-08 20:47:52.659030: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7fe2c719e200 executing computations on platform Host. Devices:
2019-08-08 20:47:52.659054: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:175] StreamExecutor device (0): Host, Default Version
/Users/phawkins/p/jax/jax/lib/xla_bridge.py:114: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
2019-08-08 20:47:52.674813: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] Execution profile for primitive_computation.4: (0.0324 us @ f_nom)
2019-08-08 20:47:52.674832: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] 94 cycles (100.% 100Σ) :: 0.0 usec ( 0.0 optimal) :: 30.85MFLOP/s :: :: 353.06MiB/s :: 0.128B/cycle :: [total] [entry]
2019-08-08 20:47:52.674838: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] 94 cycles (100.00% 100Σ) :: 0.0 usec ( 0.0 optimal) :: 30.85MFLOP/s :: :: 353.06MiB/s :: 0.128B/cycle :: %add.3 = s32[] add(s32[] %parameter.1, s32[] %parameter.2)
2019-08-08 20:47:52.674842: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174]
2019-08-08 20:47:52.674846: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] ********** microseconds report **********
2019-08-08 20:47:52.674909: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] There are 0 microseconds in total.
2019-08-08 20:47:52.674921: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] There are 0 microseconds ( 0.00%) not accounted for by the data.
2019-08-08 20:47:52.674925: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] There are 1 ops.
2019-08-08 20:47:52.674928: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174]
2019-08-08 20:47:52.674932: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] ********** categories table for microseconds **********
2019-08-08 20:47:52.674935: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174]
2019-08-08 20:47:52.674939: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] 0 (100.00% Σ100.00%) non-fusion elementwise (1 ops)
2019-08-08 20:47:52.674942: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] * 100.00% %add.3 = s32[] add(s32[], s32[])
2019-08-08 20:47:52.675673: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174]
2019-08-08 20:47:52.675682: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174]
2019-08-08 20:47:52.675688: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] ********** MiB read+written report **********
2019-08-08 20:47:52.675692: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] There are 0 MiB read+written in total.
2019-08-08 20:47:52.675697: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] There are 0 MiB read+written ( 0.00%) not accounted for by the data.
2019-08-08 20:47:52.675700: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] There are 3 ops.
2019-08-08 20:47:52.675703: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174]
2019-08-08 20:47:52.675812: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] ********** categories table for MiB read+written **********
2019-08-08 20:47:52.675823: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174]
2019-08-08 20:47:52.675827: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] 0 (100.00% Σ100.00%) non-fusion elementwise (1 ops)
2019-08-08 20:47:52.675832: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] * 100.00% %add.3 = s32[] add(s32[], s32[])
2019-08-08 20:47:52.675835: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] 0 ( 0.00% Σ100.00%) ... (1 more categories)
2019-08-08 20:47:52.675839: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174]
Out[2]: DeviceArray(3, dtype=int32)
Device Memory Profiling¶
The JAX Device Memory Profiler allows us to explore how and why JAX programs are using GPU or TPU memory. For example, it can be used to:
Figure out which arrays and executables are in GPU memory at a given time, or
Track down memory leaks.
Installation¶
The JAX device memory profiler emits output that can be interpreted using pprof (https://github.com/google/pprof). Start by installing pprof by following its installation instructions.
At the time of writing, installing pprof requires first installing Go and Graphviz, and then running
go get -u github.com/google/pprof
which installs pprof as $GOPATH/bin/pprof, where GOPATH defaults to ~/go.
Note
The version of pprof from https://github.com/google/pprof is not the same as the older tool of the same name distributed as part of the gperftools package. The gperftools version of pprof will not work with JAX.
Understanding how a JAX program is using GPU or TPU memory¶
A common use of the device memory profiler is to figure out why a JAX program is using a large amount of GPU or TPU memory, for example if trying to debug an out-of-memory problem.
To capture a device memory profile to disk, use
jax.profiler.save_device_memory_profile()
. For example, consider the
following Python program:
import jax
import jax.numpy as jnp
import jax.profiler
def func1(x):
    return jnp.tile(x, 10) * 0.5

def func2(x):
    y = func1(x)
    return y, jnp.tile(x, 10) + 1
x = jax.random.normal(jax.random.PRNGKey(42), (1000, 1000))
y, z = func2(x)
z.block_until_ready()
jax.profiler.save_device_memory_profile("memory.prof")
If we first run the program above and then execute
pprof --web memory.prof
pprof
opens a web browser containing the following visualization of the device
memory profile in callgraph format:
The callgraph is a visualization of
the Python stack at the point the allocation of each live buffer was made.
For example, in this specific case, the visualization shows that
func2
and its callees were responsible for allocating 76.30MB, of which
38.15MB was allocated inside the call from func1
to func2
.
For more information about how to interpret callgraph visualizations, see the
pprof documentation.
Functions compiled with jax.jit()
are opaque to the device memory profiler.
That is, any memory allocated inside a jit-compiled function will be attributed to the function as a whole.
In the example, the call to block_until_ready()
is to ensure that func2
completes before the device memory profile is collected. See
Asynchronous dispatch for more details.
Debugging memory leaks¶
We can also use the JAX device memory profiler to track down memory leaks by using
pprof
to visualize the change in memory usage between two device memory profiles
taken at different times. For example, consider the following program, which accumulates JAX arrays into a constantly-growing Python list.
import jax
import jax.numpy as jnp
import jax.profiler
def afunction():
    return jax.random.normal(jax.random.PRNGKey(77), (1000000,))

z = afunction()

def anotherfunc():
    arrays = []
    for i in range(1, 10):
        x = jax.random.normal(jax.random.PRNGKey(42), (i, 10000))
        arrays.append(x)
        x.block_until_ready()
        jax.profiler.save_device_memory_profile(f"memory{i}.prof")
anotherfunc()
If we simply visualize the device memory profile at the end of execution
(memory9.prof
), it may not be obvious that each iteration of the loop in
anotherfunc
accumulates more device memory allocations:
pprof --web memory9.prof
The large but fixed allocation inside afunction
dominates the profile but does
not grow over time.
By using pprof’s --diff_base feature to visualize the change in memory usage across loop iterations, we can identify why the memory usage of the program increases over time:
pprof --web --diff_base memory1.prof memory9.prof
The visualization shows that the memory growth can be attributed to the call to normal inside anotherfunc.
Pytrees¶
What is a pytree?¶
In JAX, we use the term pytree to refer to a tree-like structure built out of container-like Python objects. Classes are considered container-like if they are in the pytree registry, which by default includes lists, tuples, and dicts. That is:
any object whose type is not in the pytree container registry is considered a leaf pytree;
any object whose type is in the pytree container registry, and which contains pytrees, is considered a pytree.
For each entry in the pytree container registry, a container-like type is registered with a pair of functions that specify how to convert an instance of the container type to a (children, metadata) pair and how to convert such a pair back to an instance of the container type. Using these functions, JAX can canonicalize any tree of registered container objects into tuples.
Example pytrees:
[1, "a", object()] # 3 leaves
(1, (2, 3), ()) # 3 leaves
[1, {"k1": 2, "k2": (3, 4)}, 5] # 5 leaves
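These leaf counts can be checked with jax.tree_util.tree_leaves; a quick sketch, assuming a standard JAX installation:

```python
from jax.tree_util import tree_leaves

# Registered containers (list, tuple, dict) are traversed; everything
# else, including strings and arbitrary objects, counts as a leaf.
print(len(tree_leaves([1, "a", object()])))               # 3
print(len(tree_leaves((1, (2, 3), ()))))                  # 3
print(len(tree_leaves([1, {"k1": 2, "k2": (3, 4)}, 5])))  # 5
```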
JAX can be extended to consider other container types as pytrees; see Extending pytrees below.
Pytrees and JAX functions¶
Many JAX functions, like jax.lax.scan
, operate over pytrees of arrays.
JAX function transformations can be applied to functions that accept as input
and produce as output pytrees of arrays.
Applying optional parameters to pytrees¶
Some JAX function transformations take optional parameters that specify how certain input or output values should be treated (e.g. the in_axes and out_axes arguments to vmap). These parameters can also be pytrees, and their structure must correspond to the pytree structure of the corresponding arguments. In particular, to be able to “match up” leaves in these parameter pytrees with values in the argument pytrees, the parameter pytrees are often constrained to be tree prefixes of the argument pytrees.
For example, if we pass the following input to vmap (note that the input arguments to a function are considered a tuple):
(a1, {"k1": a2, "k2": a3})
We can use the following in_axes pytree to specify that only the k2 argument is mapped (axis=0) and the rest aren’t mapped over (axis=None):
(None, {"k1": None, "k2": 0})
The optional parameter pytree structure must match that of the main input pytree. However, the optional parameters can optionally be specified as a “prefix” pytree, meaning that a single leaf value can be applied to an entire sub-pytree. For example, if we have the same vmap input as above, but wish to only map over the dictionary argument, we can use:
(None, 0) # equivalent to (None, {"k1": 0, "k2": 0})
Or, if we want every argument to be mapped, we can simply write a single leaf value that is applied over the entire argument tuple pytree:
0
This happens to be the default in_axes value for vmap!
The same logic applies to other optional parameters that refer to specific input or output values of a transformed function, e.g. vmap’s out_axes.
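The prefix semantics above can be sketched end-to-end; the function f and the array shapes here are illustrative, assuming a standard JAX installation:

```python
import jax
import jax.numpy as jnp

def f(a, d):
    # `a` is shared across the batch; both entries of `d` are batched.
    return a + d["k1"] + d["k2"]

a = jnp.ones(3)
d = {"k1": jnp.ones((5, 3)), "k2": jnp.ones((5, 3))}

# Full in_axes pytree, matching the argument structure exactly.
out_full = jax.vmap(f, in_axes=(None, {"k1": 0, "k2": 0}))(a, d)

# Prefix pytree: the single leaf 0 covers the whole dict argument.
out_prefix = jax.vmap(f, in_axes=(None, 0))(a, d)

print(out_full.shape, out_prefix.shape)  # (5, 3) (5, 3)
```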
Developer information¶
This is primarily JAX internal documentation; end-users are not expected to need to understand this to use JAX, except when registering new user-defined container types with JAX. Some of these details may change.
Internal pytree handling¶
JAX flattens pytrees into lists of leaves at the api.py
boundary (and also
in control flow primitives). This keeps downstream JAX internals simpler:
transformations like grad
, jit
, and vmap
can handle user functions
that accept and return the myriad different Python containers, while all the
other parts of the system can operate on functions that only take (multiple)
array arguments and always return a flat list of arrays.
When JAX flattens a pytree it will produce a list of leaves and a treedef
object that encodes the structure of the original value. The treedef
can
then be used to construct a matching structured value after transforming the
leaves. Pytrees are tree-like, rather than DAG-like or graph-like, in that we
handle them assuming referential transparency and that they can’t contain
reference cycles.
Here is a simple example:
from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node
import jax.numpy as jnp
# The structured value to be transformed
value_structured = [1., (2., 3.)]
# The leaves in value_flat correspond to the `*` markers in value_tree
value_flat, value_tree = tree_flatten(value_structured)
print("value_flat={}\nvalue_tree={}".format(value_flat, value_tree))
# Transform the flat value list using an elementwise numeric transformer
transformed_flat = list(map(lambda v: v * 2., value_flat))
print("transformed_flat={}".format(transformed_flat))
# Reconstruct the structured output, using the original structure
transformed_structured = tree_unflatten(value_tree, transformed_flat)
print("transformed_structured={}".format(transformed_structured))
# Output:
# value_flat=[1.0, 2.0, 3.0]
# value_tree=PyTreeDef(list, [*,PyTreeDef(tuple, [*,*])])
# transformed_flat=[2.0, 4.0, 6.0]
# transformed_structured=[2.0, (4.0, 6.0)]
By default, pytree containers can be lists, tuples, dicts, namedtuples, None, and OrderedDicts. Other types of values, including numeric and ndarray values, are treated as leaves:
from collections import namedtuple
Point = namedtuple('Point', ['x', 'y'])
example_containers = [
    (1., [2., 3.]),
    (1., {'b': 2., 'a': 3.}),
    1.,
    None,
    jnp.zeros(2),
    Point(1., 2.)
]

def show_example(structured):
    flat, tree = tree_flatten(structured)
    unflattened = tree_unflatten(tree, flat)
    print("structured={}\n flat={}\n tree={}\n unflattened={}".format(
        structured, flat, tree, unflattened))

for structured in example_containers:
    show_example(structured)
# Output:
# structured=(1.0, [2.0, 3.0])
# flat=[1.0, 2.0, 3.0]
# tree=PyTreeDef(tuple, [*,PyTreeDef(list, [*,*])])
# unflattened=(1.0, [2.0, 3.0])
# structured=(1.0, {'b': 2.0, 'a': 3.0})
# flat=[1.0, 3.0, 2.0]
# tree=PyTreeDef(tuple, [*,PyTreeDef(dict[['a', 'b']], [*,*])])
# unflattened=(1.0, {'a': 3.0, 'b': 2.0})
# structured=1.0
# flat=[1.0]
# tree=*
# unflattened=1.0
# structured=None
# flat=[]
# tree=PyTreeDef(None, [])
# unflattened=None
# structured=[0. 0.]
# flat=[DeviceArray([0., 0.], dtype=float32)]
# tree=*
# unflattened=[0. 0.]
# structured=Point(x=1.0, y=2.0)
# flat=[1.0, 2.0]
# tree=PyTreeDef(namedtuple[<class '__main__.Point'>], [*,*])
# unflattened=Point(x=1.0, y=2.0)
Extending pytrees¶
By default, any part of a structured value that is not recognized as an internal pytree node (i.e. container-like) is treated as a leaf:
class Special(object):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return "Special(x={}, y={})".format(self.x, self.y)
show_example(Special(1., 2.))
# Output:
# structured=Special(x=1.0, y=2.0)
# flat=[Special(x=1.0, y=2.0)]
# tree=*
# unflattened=Special(x=1.0, y=2.0)
The set of Python types that are considered internal pytree nodes is extensible, through a global registry of types. Values of registered types are traversed recursively:
class RegisteredSpecial(Special):
    def __repr__(self):
        return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)

def special_flatten(v):
    """Specifies a flattening recipe.

    Params:
      v: the value of registered type to flatten.
    Returns:
      a pair of an iterable with the children to be flattened recursively,
      and some opaque auxiliary data to pass back to the unflattening recipe.
      The auxiliary data is stored in the treedef for use during unflattening.
      The auxiliary data could be used, e.g., for dictionary keys.
    """
    children = (v.x, v.y)
    aux_data = None
    return (children, aux_data)

def special_unflatten(aux_data, children):
    """Specifies an unflattening recipe.

    Params:
      aux_data: the opaque data that was specified during flattening of the
        current treedef.
      children: the unflattened children
    Returns:
      a reconstructed object of the registered type, using the specified
      children and auxiliary data.
    """
    return RegisteredSpecial(*children)

# Global registration
register_pytree_node(
    RegisteredSpecial,
    special_flatten,    # tell JAX what are the children nodes
    special_unflatten   # tell JAX how to pack back into a RegisteredSpecial
)
show_example(RegisteredSpecial(1., 2.))
# Output:
# structured=RegisteredSpecial(x=1.0, y=2.0)
# flat=[1.0, 2.0]
# tree=PyTreeDef(<class '__main__.RegisteredSpecial'>[None], [*,*])
# unflattened=RegisteredSpecial(x=1.0, y=2.0)
JAX sometimes needs to compare treedefs for equality. Therefore, care must be taken to ensure that the auxiliary data specified in the flattening recipe supports a meaningful equality comparison.
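For instance, two treedefs compare equal exactly when their container structure and auxiliary data match, regardless of the leaf values; a small sketch, assuming a standard JAX installation:

```python
from jax.tree_util import tree_flatten

_, t1 = tree_flatten({"a": 1, "b": 2})
_, t2 = tree_flatten({"a": 10, "b": 20})  # same structure, different leaves
_, t3 = tree_flatten([1, 2])              # different container structure

print(t1 == t2)  # True: leaf values are not part of the treedef
print(t1 == t3)  # False
```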
The full set of functions for operating on pytrees is in the jax.tree_util module.
Rank promotion warning¶
NumPy broadcasting rules allow automatic promotion of arguments from one rank (number of array axes) to another. This behavior can be convenient when intended but can also lead to surprising bugs where a silent rank promotion masks an underlying shape error.
Here’s an example of rank promotion:
>>> import numpy as np
>>> x = np.arange(12).reshape(4, 3)
>>> y = np.array([0, 1, 0])
>>> x + y
array([[ 0, 2, 2],
[ 3, 5, 5],
[ 6, 8, 8],
[ 9, 11, 11]])
To avoid potential surprises, jax.numpy is configurable so that expressions requiring rank promotion lead to a warning, an error, or can be allowed just like regular NumPy. The configuration option is named jax_numpy_rank_promotion and it can take on the string values allow, warn, and raise. The default setting is warn, which raises a warning on the first occurrence of rank promotion. The raise setting raises an error on rank promotion, and allow allows rank promotion without warning or error.
As with most other JAX configuration options, you can set this option in
several ways. One is by using jax.config
in your code:
from jax.config import config
config.update("jax_numpy_rank_promotion", "allow")
You can also set the option using the environment variable JAX_NUMPY_RANK_PROMOTION, for example as JAX_NUMPY_RANK_PROMOTION='raise'. Finally, when using absl-py the option can be set with a command-line flag.
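For example, under the raise setting a rank-promoting expression fails loudly. A small sketch (using jax.config.update via the jax module; the exact exception message varies between JAX versions):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_numpy_rank_promotion", "raise")

try:
    jnp.ones((4, 3)) + jnp.ones(3)  # (4, 3) + (3,) requires rank promotion
    print("no error raised")
except Exception as err:
    print("rank promotion rejected:", type(err).__name__)

# Restore permissive behavior for the rest of the session.
jax.config.update("jax_numpy_rank_promotion", "allow")
```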
Type promotion semantics¶
JAX’s type promotion rules (i.e., the result of jax.numpy.promote_types() for each pair of types) are determined via the following type promotion lattice:
where, for example:
b1 means np.bool_,
i2 means np.int16,
u4 means np.uint32,
bf means np.bfloat16,
f2 means np.float16,
c8 means np.complex128,
i* means Python int,
f* means Python float, and
c* means Python complex.
Promotion between any two types is given by their join on this lattice, which generates the following binary promotion table:
    b1  u1  u2  u4  u8  i1  i2  i4  i8  bf  f2  f4  f8  c4  c8  i*  f*  c*
b1  b1  u1  u2  u4  u8  i1  i2  i4  i8  bf  f2  f4  f8  c4  c8  i8  f8  c8 
u1  u1  u1  u2  u4  u8  i2  i2  i4  i8  bf  f2  f4  f8  c4  c8  u1  f8  c8 
u2  u2  u2  u2  u4  u8  i4  i4  i4  i8  bf  f2  f4  f8  c4  c8  u2  f8  c8 
u4  u4  u4  u4  u4  u8  i8  i8  i8  i8  bf  f2  f4  f8  c4  c8  u4  f8  c8 
u8  u8  u8  u8  u8  u8  f8  f8  f8  f8  bf  f2  f4  f8  c4  c8  u8  f8  c8 
i1  i1  i2  i4  i8  f8  i1  i2  i4  i8  bf  f2  f4  f8  c4  c8  i1  f8  c8 
i2  i2  i2  i4  i8  f8  i2  i2  i4  i8  bf  f2  f4  f8  c4  c8  i2  f8  c8 
i4  i4  i4  i4  i8  f8  i4  i4  i4  i8  bf  f2  f4  f8  c4  c8  i4  f8  c8 
i8  i8  i8  i8  i8  f8  i8  i8  i8  i8  bf  f2  f4  f8  c4  c8  i8  f8  c8 
bf  bf  bf  bf  bf  bf  bf  bf  bf  bf  bf  f4  f4  f8  c4  c8  bf  bf  c4 
f2  f2  f2  f2  f2  f2  f2  f2  f2  f2  f4  f2  f4  f8  c4  c8  f2  f2  c4 
f4  f4  f4  f4  f4  f4  f4  f4  f4  f4  f4  f4  f4  f8  c4  c8  f4  f4  c4 
f8  f8  f8  f8  f8  f8  f8  f8  f8  f8  f8  f8  f8  f8  c8  c8  f8  f8  c8 
c4  c4  c4  c4  c4  c4  c4  c4  c4  c4  c4  c4  c4  c8  c4  c8  c4  c4  c4 
c8  c8  c8  c8  c8  c8  c8  c8  c8  c8  c8  c8  c8  c8  c8  c8  c8  c8  c8 
i*  i8  u1  u2  u4  u8  i1  i2  i4  i8  bf  f2  f4  f8  c4  c8  i8  f8  c8 
f*  f8  f8  f8  f8  f8  f8  f8  f8  f8  bf  f2  f4  f8  c4  c8  f8  f8  c8 
c*  c8  c8  c8  c8  c8  c8  c8  c8  c8  c4  c4  c4  c8  c4  c8  c8  c8  c8 
JAX’s type promotion rules differ from those of NumPy, as given by
numpy.promote_types()
, in those cells highlighted with a green background
in the table above. There are three key differences:
When promoting a Python scalar value against a typed JAX value of the same category, JAX always prefers the precision of the JAX value. For example, jnp.int16(1) + 1 will return int16 rather than promoting to int64 as in NumPy.
When promoting an integer or boolean type against a floating-point or complex type, JAX always prefers the type of the floating-point or complex type.
JAX supports the bfloat16 non-standard 16-bit floating point type (jax.numpy.bfloat16), which is useful for neural network training. The only notable promotion behavior is with respect to IEEE-754 float16, with which bfloat16 promotes to a float32.
These differences are motivated by the fact that accelerator devices, such as GPUs and TPUs, either pay a significant performance penalty to use 64-bit floating point types (GPUs) or do not support 64-bit floating point types at all (TPUs). Classic NumPy’s promotion rules are too willing to over-promote to 64-bit types, which is problematic for a system designed to run on accelerators.
JAX uses floating point promotion rules that are more suited to modern accelerator devices and are less aggressive about promoting floating point types. The promotion rules used by JAX for floating-point types are similar to those used by PyTorch.
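The key differences can be checked directly with jax.numpy.promote_types(); a quick sketch, assuming a standard JAX installation:

```python
import jax.numpy as jnp

# A Python scalar against a typed JAX value: the JAX value's precision wins.
print((jnp.int16(1) + 1).dtype)                      # int16

# Integer against floating-point: the floating-point type wins, unwidened.
print(jnp.promote_types(jnp.int32, jnp.float16))     # float16

# bfloat16 against IEEE-754 float16 promotes to float32.
print(jnp.promote_types(jnp.bfloat16, jnp.float16))  # float32
```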
Building from source¶
First, obtain the JAX source code:
git clone https://github.com/google/jax
cd jax
Building JAX involves two steps:
1. Building or installing jaxlib, the C++ support library for jax.
2. Installing the jax Python package.
Building or installing jaxlib¶
Installing jaxlib with pip¶
If you’re only modifying Python portions of JAX, we recommend installing jaxlib from a prebuilt wheel using pip:
pip install jaxlib
See the JAX readme for full guidance on pip installation (e.g., for GPU support).
Building jaxlib from source¶
To build jaxlib from source, you must also install some prerequisites:
a C++ compiler (g++, clang, or MSVC)
On Ubuntu or Debian you can install the necessary prerequisites with:
sudo apt install g++ python python3-dev
If you are building on a Mac, make sure XCode and the XCode command line tools are installed.
See below for Windows build instructions.
Python packages: numpy, scipy, six, wheel. The six package is required during the jaxlib build only, and is not required at install time.
You can install the necessary Python dependencies using pip:
pip install numpy scipy six wheel
To build jaxlib with CUDA support, you can run:
python build/build.py --enable_cuda
pip install dist/*.whl  # installs jaxlib (includes XLA)
See python build/build.py --help for configuration options, including ways to specify the paths to CUDA and CUDNN, which you must have installed. Here python should be the name of your Python 3 interpreter; on some systems, you may need to use python3 instead. By default, the wheel is written to the dist/ subdirectory of the current directory.
To build jaxlib without CUDA GPU support (CPU only), drop the --enable_cuda:
python build/build.py
pip install dist/*.whl # installs jaxlib (includes XLA)
Additional Notes for Building jaxlib from source on Windows¶
On Windows, follow Install Visual Studio to set up a C++ toolchain. Visual Studio 2019 version 16.5 or newer is required. If you need to build with CUDA enabled, follow the CUDA Installation Guide to set up a CUDA environment.
It is recommended to use Anaconda or Miniconda to set up a Python environment.
Some targets of Bazel use bash utilities to do scripting, so MSYS2 is needed. See Installing Bazel on Windows for more details. Install the following packages:
pacman -S patch realpath
Once everything is installed, open PowerShell and make sure MSYS2 is in the path of the current session. Ensure bazel, patch and realpath are accessible. Activate the conda environment. The following command builds with CUDA enabled; adjust it to whatever is suitable for you:
python .\build\build.py `
  --enable_cuda `
  --cuda_path='C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1' `
  --cudnn_path='C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1' `
  --cuda_compute_capabilities='6.1' `
  --cuda_version='10.1' `
  --cudnn_version='7.6.5'
To build with debug information, add the flag --bazel_options='--copt=/Z7'.
Installing jax¶
Once jaxlib has been installed, you can install jax by running:
pip install -e .  # installs jax
To upgrade to the latest version from GitHub, just run git pull from the JAX repository root, and rebuild by running build.py or upgrading jaxlib if necessary. You shouldn’t have to reinstall jax because pip install -e sets up symbolic links from site-packages into the repository.
Running the tests¶
To run all the JAX tests, we recommend using pytest-xdist, which can run tests in parallel. First, install pytest-xdist and pytest-benchmark by running pip install pytest-xdist pytest-benchmark.
Then, from the repository root directory run:
pytest -n auto tests
JAX generates test cases combinatorially, and you can control the number of cases that are generated and checked for each test (default is 10). The automated tests currently use 25:
JAX_NUM_GENERATED_CASES=25 pytest -n auto tests
The automated tests also run the tests with default 64-bit floats and ints:
JAX_ENABLE_X64=1 JAX_NUM_GENERATED_CASES=25 pytest -n auto tests
You can run a more specific set of tests using pytest’s built-in selection mechanisms, or alternatively you can run a specific test file directly to see more detailed information about the cases being run:
python tests/lax_numpy_test.py --num_generated_cases=5
You can skip a few tests known to be slow by passing the environment variable JAX_SKIP_SLOW_TESTS=1.
To specify a particular set of tests to run from a test file, you can pass a string or regular expression via the --test_targets flag. For example, you can run all the tests of jax.numpy.pad using:
python tests/lax_numpy_test.py --test_targets="testPad"
The Colab notebooks are tested for errors as part of the documentation build.
Note that to run the full pmap tests on a (multi-core) CPU-only machine, you can run:
pytest tests/pmap_tests.py
I.e. don’t use the -n auto option, since that effectively runs each test on a single-core worker.
Type checking¶
We use mypy to check the type hints. To check types locally the same way as Travis checks them:
pip install mypy
mypy --config=mypy.ini --show-error-codes jax
Update documentation¶
To rebuild the documentation, install several packages:
pip install -r docs/requirements.txt
You must also install pandoc in order to regenerate the notebooks. See Install Pandoc, or use Miniconda, which I have used successfully on the Mac: conda install -c conda-forge pandoc.
If you do not want to install pandoc, then you should regenerate the documentation without the notebooks.
Run one of the following commands at the top level:
sphinx-build -b html docs docs/build/html  # with the notebooks
sphinx-build -b html -D nbsphinx_execute=never docs docs/build/html  # without the notebooks
You can then see the generated documentation in docs/build/html/index.html.
Update notebooks¶
Open the notebook with http://colab.research.google.com (then Upload from your local repo), update it as needed, Run all cells, then Download ipynb. You may want to test that it executes properly, using sphinx-build as explained above.
Some of the notebooks are built automatically as part of the Travis pre-submit checks and as part of the Read the Docs build. The build will fail if cells raise errors. If the errors are intentional, you can either catch them, or tag the cell with raises-exceptions metadata (example PR). You have to add this metadata by hand in the .ipynb file. It will be preserved when somebody else re-saves the notebook.
We exclude some notebooks from the build, e.g., because they contain long computations. See exclude_patterns in conf.py.
Documentation building on readthedocs.io¶
JAX’s auto-generated documentation is at jax.readthedocs.io.
The documentation building is controlled for the entire project by the
readthedocs JAX settings. The current settings
trigger a documentation build as soon as code is pushed to the GitHub master
branch.
For each code version, the building process is driven by the .readthedocs.yml and the docs/conf.py configuration files.
For each automated documentation build you can see the documentation build logs.
If you want to test the documentation generation on Readthedocs, you can push code to the test-docs branch. That branch is also built automatically, and you can see the generated documentation here.
For a local test, I was able to do it in a fresh directory by replaying the commands I saw in the Readthedocs logs:
mkvirtualenv jax-docs  # A new virtualenv
mkdir jax-docs  # A new directory
cd jax-docs
git clone --no-single-branch --depth 50 https://github.com/google/jax
cd jax
git checkout --force origin/test-docs
git clean -d -f -f
workon jax-docs
python -m pip install --upgrade --no-cache-dir pip
python -m pip install --upgrade --no-cache-dir -I Pygments==2.3.1 setuptools==41.0.1 docutils==0.14 mock==1.0.1 pillow==5.4.1 "alabaster>=0.7,<0.8,!=0.7.5" commonmark==0.8.1 recommonmark==0.5.0 'sphinx<2' 'sphinx-rtd-theme<0.5' 'readthedocs-sphinx-ext<1.1'
python -m pip install --exists-action=w --no-cache-dir -r docs/requirements.txt
cd docs
python `which sphinx-build` -T -E -b html -d _build/doctrees-readthedocs -D language=en . _build/html
Internal APIs¶
core¶
Public API: jax package¶
Subpackages¶
jax.numpy package¶
Implements the NumPy API, using the primitives in jax.lax.
While JAX tries to follow the NumPy API as closely as possible, sometimes JAX cannot follow NumPy exactly.
Notably, since JAX arrays are immutable, NumPy APIs that mutate arrays in-place cannot be implemented in JAX. However, JAX is often able to provide an alternative API that is purely functional. For example, instead of in-place array updates (x[i] = y), JAX provides the alternative pure indexed update function jax.ops.index_update().
NumPy is very aggressive at promoting values to float64 type. JAX is sometimes less aggressive about type promotion.
A small number of NumPy operations that have data-dependent output shapes are incompatible with jax.jit() compilation. The XLA compiler requires that shapes of arrays be known at compile time. While it would be possible to provide a JAX implementation of an API such as numpy.nonzero(), we would be unable to JIT-compile it because the shape of its output depends on the contents of the input data.
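As a concrete illustration of the nonzero() limitation described above (run against a current JAX install; the exact exception type varies across versions):

```python
import jax
import jax.numpy as jnp

x = jnp.array([0, 1, 0, 2])

# Outside of jit this works: the output shape is computed eagerly
# from the actual values in x.
print(jnp.nonzero(x))

# Under jit, the output shape would depend on the *values* in x,
# which XLA cannot know at compile time, so tracing fails.
try:
    jax.jit(jnp.nonzero)(x)
except Exception as err:
    print(type(err).__name__)
```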
Not every function in NumPy is implemented; contributions are welcome!

Calculate the absolute value elementwise. 

Calculate the absolute value elementwise. 

Add arguments elementwise. 

Test whether all array elements along a given axis evaluate to True. 

Returns True if two arrays are elementwise equal within a tolerance. 

Test whether all array elements along a given axis evaluate to True. 

Return the maximum of an array or maximum along an axis. 

Return the minimum of an array or minimum along an axis. 

Return the angle of the complex argument. 

Test whether any array element along a given axis evaluates to True. 

Append values to the end of an array. 

Apply a function to 1D slices along the given axis. 

Apply a function repeatedly over multiple axes. 

Return evenly spaced values within a given interval. 

Trigonometric inverse cosine, elementwise. 

Inverse hyperbolic cosine, elementwise. 

Inverse sine, elementwise. 

Inverse hyperbolic sine elementwise. 

Trigonometric inverse tangent, elementwise. 

Elementwise arc tangent of x1/x2 choosing the quadrant correctly. 

Inverse hyperbolic tangent elementwise. 

Returns the indices of the maximum values along an axis. 

Returns the indices of the minimum values along an axis. 

Returns the indices that would sort an array. 

Find the indices of array elements that are nonzero, grouped by element. 

Round an array to the given number of decimals. 

Create an array. 

True if two arrays have the same shape and elements, False otherwise. 

Returns True if input arrays are shape consistent and all elements equal. 

Return the string representation of an array. 

Split an array into multiple subarrays. 

Return a string representation of the data in an array. 

Convert the input to an array. 

Convert inputs to arrays with at least one dimension. 

View inputs as arrays with at least two dimensions. 

View inputs as arrays with at least three dimensions. 

Compute the weighted average along the specified axis. 

Return the Bartlett window. 

Count number of occurrences of each value in array of nonnegative ints. 

Compute the bitwise AND of two arrays elementwise. 

Compute bitwise inversion, or bitwise NOT, elementwise. 

Compute the bitwise OR of two arrays elementwise. 

Compute the bitwise XOR of two arrays elementwise. 

Return the Blackman window. 

Assemble an ndarray from nested lists of blocks. 

Like Numpy’s broadcast_arrays but doesn’t return views. 

Broadcast an array to a new shape. 

Returns True if cast between data types can occur according to the casting rule. 

Return the cuberoot of an array, elementwise. 
alias of 


Return the ceiling of the input, elementwise. 
Abstract base class of all character string scalar types. 


Construct an array from an index array and a set of arrays to choose from. 

Clip (limit) the values in an array. 

Stack 1D arrays as columns into a 2D array. 
alias of 

Abstract base class of all complex number scalar types that are made up of floatingpoint numbers. 

The warning raised when casting a complex dtype to a real dtype. 


Return selected slices of an array along given axis. 

Join a sequence of arrays along an existing axis. 

Return the complex conjugate, elementwise. 

Return the complex conjugate, elementwise. 

Returns the discrete, linear convolution of two onedimensional sequences. 

Change the sign of x1 to that of x2, elementwise. 

Return Pearson productmoment correlation coefficients. 

Crosscorrelation of two 1dimensional sequences. 

Cosine elementwise. 

Hyperbolic cosine, elementwise. 

Counts the number of nonzero values in the array a. 

Estimate a covariance matrix, given data and weights. 

Return the cross product of two (arrays of) vectors. 
alias of 


Return the cumulative product of elements along a given axis. 

Return the cumulative product of elements along a given axis. 

Return the cumulative sum of the elements along a given axis. 

Convert angles from degrees to radians. 

Convert angles from radians to degrees. 

Extract a diagonal or construct a diagonal array. 

Create a twodimensional array with the flattened input as a diagonal. 

Return the indices to access the main diagonal of an array. 

Return the indices to access the main diagonal of an ndimensional array. 

Return specified diagonals. 

Calculate the nth discrete difference along the given axis. 

Return the indices of the bins to which each value in input array belongs. 

Returns a true division of the inputs, elementwise. 

Return elementwise quotient and remainder simultaneously. 

Dot product of two arrays. 
alias of 


Split array into multiple subarrays along the 3rd axis (depth). 

Stack arrays in sequence depth wise (along third axis). 

Create a data type object. 

The differences between consecutive elements of an array. 

Evaluates the Einstein summation convention on the operands. 

Evaluates the lowest cost contraction order for an einsum expression by considering the creation of intermediate arrays. 

Return a new array of given shape and type, filled with zeros. 

Return an array of zeros with the same shape and type as a given array. 

Return (x1 == x2) elementwise. 

Calculate the exponential of all elements in the input array. 

Calculate 2**p for all p in the input array. 

Expand the shape of an array. 

Calculate exp(x) - 1 for all elements in the array. 

Return the elements of an array that satisfy some condition. 

Return a 2D array with ones on the diagonal and zeros elsewhere. 

Compute the absolute values elementwise. 

Machine limits for floating point types. 

Round to nearest integer towards zero. 

Return indices that are nonzero in the flattened version of a. 
Abstract base class of all scalar types without predefined length. 


Reverse the order of elements in an array along the given axis. 

Flip array in the left/right direction. 

Flip array in the up/down direction. 
alias of 

Abstract base class of all floatingpoint scalar types. 


First array elements raised to powers from second array, elementwise. 

Return the floor of the input, elementwise. 

Return the largest integer smaller or equal to the division of the inputs. 

Elementwise maximum of array elements. 

Elementwise minimum of array elements. 

Return the elementwise remainder of division. 

Decompose the elements of x into mantissa and twos exponent. 

Return a new array of given shape and type, filled with fill_value. 

Return a full array with the same shape and type as a given array. 

Returns the greatest common divisor of |x1| and |x2|. 

Return numbers spaced evenly on a log scale (a geometric progression). 

Return the gradient of an Ndimensional array. 

Return the truth value of (x1 > x2) elementwise. 

Return the truth value of (x1 >= x2) elementwise. 

Return the Hamming window. 

Return the Hanning window. 

Compute the Heaviside step function. 

Compute the histogram of a set of data. 

Function to calculate only the edges of the bins used by the histogram function. 

Compute the bidimensional histogram of two data samples. 

Compute the multidimensional histogram of some data. 

Split an array into multiple subarrays horizontally (columnwise). 

Stack arrays in sequence horizontally (column wise). 

Given the “legs” of a right triangle, return its hypotenuse. 

Modified Bessel function of the first kind, order 0. 

Return the identity array. 

Machine limits for integer types. 

Return the imaginary part of the complex argument. 

Test whether each element of a 1D array is also present in a second array. 

Return an array representing the indices of a grid. 
Abstract base class of all numeric scalar types with a (potentially) inexact representation of the values in its range, such as floatingpoint numbers. 


Inner product of two arrays. 
alias of 

Abstract base class of all integer scalar types. 


Onedimensional linear interpolation. 

Find the intersection of two arrays. 

Compute bitwise inversion, or bitwise NOT, elementwise. 

Returns a boolean array where two arrays are elementwise equal within a tolerance. 

Returns a bool array, where True if input element is complex. 

Check for a complex type or an array of complex numbers. 

Test elementwise for finiteness (not infinity or not Not a Number). 

Calculates element in test_elements, broadcasting over element only. 

Test elementwise for positive or negative infinity. 

Test elementwise for NaN and return result as a boolean array. 

Test elementwise for negative infinity, return result as bool array. 

Test elementwise for positive infinity, return result as bool array. 

Returns a bool array, where True if input element is real. 

Return True if x is a not complex type or an array of complex numbers. 

Returns True if the type of element is a scalar type. 

Returns True if first argument is a typecode lower/equal in type hierarchy. 

Determine if the first argument is a subclass of the second argument. 

Check whether or not an object can be iterated over. 

Construct an open mesh from multiple sequences. 

Return the Kaiser window. 

Kronecker product of two arrays. 

Returns the lowest common multiple of |x1| and |x2|. 

Returns x1 * 2**x2, elementwise. 

Shift the bits of an integer to the left. 

Return the truth value of (x1 < x2) elementwise. 

Return the truth value of (x1 <= x2) elementwise. 

Perform an indirect stable sort using a sequence of keys. 

Return evenly spaced numbers over a specified interval. 

Load arrays or pickled objects from .npy, .npz or pickled files. 

Natural logarithm, elementwise. 

Return the base 10 logarithm of the input array, elementwise. 

Return the natural logarithm of one plus the input array, elementwise. 

Base2 logarithm of x. 
Logarithm of the sum of exponentiations of the inputs. 

Logarithm of the sum of exponentiations of the inputs in base2. 


Compute the truth value of x1 AND x2 elementwise. 

Compute the truth value of NOT x elementwise. 

Compute the truth value of x1 OR x2 elementwise. 

Compute the truth value of x1 XOR x2, elementwise. 

Return numbers spaced evenly on a log scale. 

Return the indices to access (n, n) arrays, given a masking function. 

Matrix product of two arrays. 

Return the maximum of an array or maximum along an axis. 

Elementwise maximum of array elements. 

Compute the arithmetic mean along the specified axis. 

Compute the median along the specified axis. 

Return coordinate matrices from coordinate vectors. 

Return the minimum of an array or minimum along an axis. 

Elementwise minimum of array elements. 

Return elementwise remainder of division. 

Return the fractional and integral parts of an array, elementwise. 

Move axes of an array to new positions. 

Return a copy of an array sorted along the first axis. 

Multiply arguments elementwise. 

Return the indices of the maximum values in the specified axis ignoring NaNs. 

Return the indices of the minimum values in the specified axis ignoring NaNs. 

Return the cumulative product of array elements over a given axis treating Not a Numbers (NaNs) as one. 

Return the cumulative sum of array elements over a given axis treating Not a Numbers (NaNs) as zero. 

Return the maximum of an array or maximum along an axis, ignoring any NaNs. 

Compute the arithmetic mean along the specified axis, ignoring NaNs. 

Compute the median along the specified axis, while ignoring NaNs. 

Return minimum of an array or minimum along an axis, ignoring any NaNs. 

Compute the qth percentile of the data along the specified axis, while ignoring nan values. 

Return the product of array elements over a given axis treating Not a Numbers (NaNs) as ones. 

Compute the qth quantile of the data along the specified axis, while ignoring nan values. 

Compute the standard deviation along the specified axis, while ignoring NaNs. 

Return the sum of array elements over a given axis treating Not a Numbers (NaNs) as zero. 

Replace NaN with zero and infinity with large finite numbers (default behaviour) or with the numbers defined by the user. 

Compute the variance along the specified axis, while ignoring NaNs. 



Return the number of dimensions of an array. 

Numerical negative, elementwise. 

Return the next floatingpoint value after x1 towards x2, elementwise. 

Return the indices of the elements that are nonzero. 

Return (x1 != x2) elementwise. 
Abstract base class of all numeric scalar types. 

Any Python object. 


Return a new array of given shape and type, filled with ones. 

Return an array of ones with the same shape and type as a given array. 

Compute the outer product of two vectors. 

Packs the elements of a binaryvalued array into bits in a uint8 array. 

Pad an array. 

Compute the qth percentile of the data along the specified axis. 

Evaluate a piecewisedefined function. 

Find the sum of two polynomials. 

Return the derivative of the specified order of a polynomial. 

Find the product of two polynomials. 

Difference (subtraction) of two polynomials. 

Evaluate a polynomial at specific values. 

Numerical positive, elementwise. 

First array elements raised to powers from second array, elementwise. 

Return the product of array elements over a given axis. 

Return the product of array elements over a given axis. 

Returns the type to which a binary operation should cast its arguments. 

Range of values (maximum  minimum) along an axis. 

Compute the qth quantile of the data along the specified axis. 

Convert angles from radians to degrees. 

Convert angles from degrees to radians. 

Return a contiguous flattened array. 

Converts a tuple of index arrays into an array of flat indices, applying boundary modes to the multi-index. 

Return the real part of the complex argument. 

Return the reciprocal of the argument, elementwise. 

Return elementwise remainder of division. 

Repeat elements of an array. 

Gives a new shape to an array without changing its data. 

Returns the type that results from applying the NumPy type promotion rules to the arguments. 

Shift the bits of an integer to the right. 

Round elements of the array to the nearest integer. 

Roll array elements along a given axis. 

Roll the specified axis backwards, until it lies in a given position. 

Return the roots of a polynomial with coefficients given in p. 

Rotate an array by 90 degrees in the plane specified by axes. 

Round an array to the given number of decimals. 

Stack arrays in sequence vertically (row wise). 

Save an array to a binary file in NumPy .npy format. 

Save several arrays into a single file in uncompressed .npz format. 

Find indices where elements should be inserted to maintain order. 

Return an array drawn from elements in choicelist, depending on conditions. 

Set printing options. 

Find the set difference of two arrays. 

Return the shape of an array. 

Returns an elementwise indication of the sign of a number. 

Returns elementwise True where signbit is set (less than zero). 
Abstract base class of all signed integer scalar types. 


Trigonometric sine, elementwise. 

Return the sinc function. 
alias of 


Hyperbolic sine, elementwise. 

Return the number of elements along a given axis. 

Test whether any array element along a given axis evaluates to True. 

Return a sorted copy of an array. 

Sort a complex array using the real part first, then the imaginary part. 

Split an array into multiple subarrays as views into ary. 

Return the nonnegative squareroot of an array, elementwise. 

Return the elementwise square of the input. 

Remove singledimensional entries from the shape of an array. 

Join a sequence of arrays along a new axis. 

Compute the standard deviation along the specified axis. 

Subtract arguments, elementwise. 

Sum of array elements over a given axis. 

Interchange two axes of an array. 

Take elements from an array along an axis. 

Take values from the input array by matching 1d index and data slices. 

Compute tangent elementwise. 

Compute hyperbolic tangent elementwise. 

Compute tensor dot product along specified axes. 

Construct an array by repeating A the number of times given by reps. 

Return the sum along diagonals of the array. 

Reverse or permute the axes of an array; returns the modified array. 

Integrate along the given axis using the composite trapezoidal rule. 

An array with ones at and below the given diagonal and zeros elsewhere. 

Lower triangle of an array. 

Return the indices for the lowertriangle of an (n, m) array. 

Return the indices for the lowertriangle of arr. 

Trim the leading and/or trailing zeros from a 1D array or sequence. 

Upper triangle of an array. 

Return the indices for the uppertriangle of an (n, m) array. 

Return the indices for the uppertriangle of arr. 

Returns a true division of the inputs, elementwise. 

Return the truncated value of the input, elementwise. 

Find the unique elements of an array. 

Unpacks elements of a uint8 array into a binaryvalued output array. 

Converts a flat index or array of flat indices into a tuple of coordinate arrays. 
Abstract base class of all unsigned integer scalar types. 


Unwrap by changing deltas between values to 2*pi complement. 

Generate a Vandermonde matrix. 

Compute the variance along the specified axis. 

Return the dot product of two vectors. 

Define a vectorized function with broadcasting. 

Split an array into multiple subarrays vertically (rowwise). 

Stack arrays in sequence vertically (row wise). 

Return elements chosen from x or y depending on condition. 

Return a new array of given shape and type, filled with zeros. 

Return an array of zeros with the same shape and type as a given array. 
jax.numpy.fft¶

Compute the onedimensional discrete Fourier Transform. 

Compute the 2dimensional discrete Fourier Transform 

Return the Discrete Fourier Transform sample frequencies. 

Compute the Ndimensional discrete Fourier Transform. 

Shift the zerofrequency component to the center of the spectrum. 

Compute the FFT of a signal that has Hermitian symmetry, i.e., a real spectrum. 

Compute the onedimensional inverse discrete Fourier Transform. 

Compute the 2dimensional inverse discrete Fourier Transform. 

Compute the Ndimensional inverse discrete Fourier Transform. 

The inverse of fftshift. Although identical for even-length x, the functions differ by one sample for odd-length x. 

Compute the inverse FFT of a signal that has Hermitian symmetry. 

Compute the inverse of the npoint DFT for real input. 

Compute the 2dimensional inverse FFT of a real array. 

Compute the inverse of the Ndimensional FFT of real input. 

Compute the onedimensional discrete Fourier Transform for real input. 

Compute the 2dimensional FFT of a real array. 

Return the Discrete Fourier Transform sample frequencies (for usage with rfft, irfft). 

Compute the Ndimensional discrete Fourier Transform for real input. 
jax.numpy.linalg¶

Cholesky decomposition. 

Compute the condition number of a matrix. 
Compute the determinant of an array. 


Compute the eigenvalues and right eigenvectors of a square array. 

Return the eigenvalues and eigenvectors of a complex Hermitian (conjugate symmetric) or a real symmetric matrix. 

Compute the eigenvalues of a general matrix. 