JAX reference documentation
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more.
For an introduction to JAX, start at the JAX GitHub page.
JAX Quickstart
JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.
With its updated version of Autograd, JAX can automatically differentiate native Python and NumPy code. It can differentiate through a large subset of Python's features, including loops, ifs, recursion, and closures, and it can even take derivatives of derivatives of derivatives. It supports reverse-mode as well as forward-mode differentiation, and the two can be composed arbitrarily to any order.
What's new is that JAX uses XLA to compile and run your NumPy code on accelerators, like GPUs and TPUs. Compilation happens under the hood by default, with library calls getting just-in-time compiled and executed. But JAX even lets you just-in-time compile your own Python functions into XLA-optimized kernels using a one-function API. Compilation and automatic differentiation can be composed arbitrarily, so you can express sophisticated algorithms and get maximal performance without having to leave Python.
[1]:
import jax.numpy as np
from jax import grad, jit, vmap
from jax import random
Multiplying Matrices
We'll be generating random data in the following examples. One big difference between NumPy and JAX is how you generate random numbers. For more details, see the readme.
[2]:
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/testdocs/lib/python3.7/site-packages/jax/lib/xla_bridge.py:123: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
[0.372111 0.26423106 0.18252774 0.7368198 0.44030386 0.15214427
0.6713536 0.5908642 0.73168874 0.5673025 ]
Let's dive right in and multiply two big matrices.
[3]:
size = 3000
x = random.normal(key, (size, size), dtype=np.float32)
%timeit np.dot(x, x.T).block_until_ready() # runs on the GPU
566 ms ± 135 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
We added that block_until_ready because JAX uses asynchronous execution by default.
JAX NumPy functions work on regular NumPy arrays.
[4]:
import numpy as onp  # original CPU-backed NumPy
x = onp.random.normal(size=(size, size)).astype(onp.float32)
%timeit np.dot(x, x.T).block_until_ready()
1.08 s ± 464 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
That's slower because it has to transfer data to the GPU every time. You can ensure that an NDArray is backed by device memory using device_put.
[5]:
from jax import device_put
x = onp.random.normal(size=(size, size)).astype(onp.float32)
x = device_put(x)
%timeit np.dot(x, x.T).block_until_ready()
645 ms ± 27.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
The output of device_put still acts like an NDArray, but it only copies values back to the CPU when they're needed for printing, plotting, saving to disk, branching, etc. The behavior of device_put is equivalent to the function jit(lambda x: x), but it's faster.
If you have a GPU (or TPU!) these calls run on the accelerator and have the potential to be much faster than on CPU.
[6]:
x = onp.random.normal(size=(size, size)).astype(onp.float32)
%timeit onp.dot(x, x.T)
547 ms ± 134 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX is much more than just a GPU-backed NumPy. It also comes with a few program transformations that are useful when writing numerical code. For now, there are three main ones:
- jit, for speeding up your code
- grad, for taking derivatives
- vmap, for automatic vectorization or batching.
Let's go over these, one-by-one. We'll also end up composing these in interesting ways.
Using jit to speed up functions
JAX runs transparently on the GPU (or CPU, if you don't have one, and TPU coming soon!). However, in the above example, JAX is dispatching kernels to the GPU one operation at a time. If we have a sequence of operations, we can use the @jit decorator to compile multiple operations together using XLA. Let's try that.
[7]:
def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)
x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()
3.77 ms ± 627 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
We can speed it up with @jit, which will jit-compile the first time selu is called and will be cached thereafter.
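To see when compilation actually happens, here's a small sketch (the counter and function names are invented for illustration): Python side effects run only while JAX traces a function, so counting them shows that a second call with the same shape and dtype reuses the cached compilation, while a new shape triggers a fresh trace.

```python
import jax.numpy as np
from jax import jit

trace_count = 0

def slow_f(x):
    global trace_count
    trace_count += 1          # Python side effects only run during tracing
    return x * 2.0

fast_f = jit(slow_f)
fast_f(np.arange(3.))         # first call: traced and compiled
fast_f(np.arange(3.) + 1.0)   # same shape/dtype: cached compilation reused
fast_f(np.arange(4.))         # new shape: traced and compiled again
print(trace_count)            # 2
```

This is also why a print inside a jitted function only fires the first time it's called with a given input shape.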
[8]:
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()
1.28 ms ± 130 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Taking derivatives with grad
In addition to evaluating numerical functions, we also want to transform them. One transformation is automatic differentiation. In JAX, just like in Autograd, you can compute gradients with the grad function.
[9]:
def sum_logistic(x):
    return np.sum(1.0 / (1.0 + np.exp(-x)))
x_small = np.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
[0.25 0.19661197 0.10499357]
Let's verify with finite differences that our result is correct.
[10]:
def first_finite_differences(f, x):
    eps = 1e-3
    return np.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                     for v in np.eye(len(x))])
print(first_finite_differences(sum_logistic, x_small))
[0.24998187 0.1964569 0.10502338]
Taking derivatives is as easy as calling grad. grad and jit compose and can be mixed arbitrarily. In the above example we jitted sum_logistic and then took its derivative. We can go further:
[11]:
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
-0.03532558
For more advanced autodiff, you can use jax.vjp for reverse-mode vector-Jacobian products and jax.jvp for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another, and with other JAX transformations. Here's one way to compose them to make a function that efficiently computes full Hessian matrices:
[12]:
from jax import jacfwd, jacrev
def hessian(fun):
    return jit(jacfwd(jacrev(fun)))
Auto-vectorization with vmap
JAX has one more transformation in its API that you might find useful: vmap, the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a function's primitive operations for better performance. When composed with jit, it can be just as fast as adding the batch dimensions by hand.
We're going to work with a simple example, and promote matrix-vector products into matrix-matrix products using vmap. Although this is easy to do by hand in this specific case, the same technique can apply to more complicated functions.
[13]:
mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))
def apply_matrix(v):
    return np.dot(mat, v)
Given a function such as apply_matrix, we can loop over a batch dimension in Python, but usually the performance of doing so is poor.
[14]:
def naively_batched_apply_matrix(v_batched):
    return np.stack([apply_matrix(v) for v in v_batched])
print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
Naively batched
3.89 ms ± 275 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
We know how to batch this operation manually. In this case, np.dot handles extra batch dimensions transparently.
[15]:
@jit
def batched_apply_matrix(v_batched):
    return np.dot(v_batched, mat.T)
print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()
Manually batched
221 µs ± 13.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
However, suppose we had a more complicated function without batching support. We can use vmap to add batching support automatically.
[16]:
@jit
def vmap_batched_apply_matrix(v_batched):
    return vmap(apply_matrix)(v_batched)
print('Autovectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
Autovectorized with vmap
250 µs ± 23.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Of course, vmap can be arbitrarily composed with jit, grad, and any other JAX transformation.
This is just a taste of what JAX can do. We're really excited to see what you do with it!
The Autodiff Cookbook
alexbw@, mattjj@
JAX has a pretty general automatic differentiation system. In this notebook, we'll go through a whole bunch of neat autodiff ideas that you can cherry-pick for your own work, starting with the basics.
[1]:
import jax.numpy as np
from jax import grad, jit, vmap
from jax import random
key = random.PRNGKey(0)
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/testdocs/lib/python3.7/site-packages/jax/lib/xla_bridge.py:123: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
Gradients
Starting with grad
You can differentiate a function with grad:
[2]:
grad_tanh = grad(np.tanh)
print(grad_tanh(2.0))
0.070650816
grad takes a function and returns a function. If you have a Python function f that evaluates the mathematical function \(f\), then grad(f) is a Python function that evaluates the mathematical function \(\nabla f\). That means grad(f)(x) represents the value \(\nabla f(x)\).
Since grad operates on functions, you can apply it to its own output to differentiate as many times as you like:
[3]:
print(grad(grad(np.tanh))(2.0))
print(grad(grad(grad(np.tanh)))(2.0))
-0.13621867
0.25265405
Let's look at computing gradients with grad in a linear logistic regression model. First, the setup:
[4]:
def sigmoid(x):
    return 0.5 * (np.tanh(x / 2) + 1)

# Outputs probability of a label being true.
def predict(W, b, inputs):
    return sigmoid(np.dot(inputs, W) + b)

# Build a toy dataset.
inputs = np.array([[0.52, 1.12,  0.77],
                   [0.88, -1.08, 0.15],
                   [0.52, 0.06, -1.30],
                   [0.74, -2.49, 1.39]])
targets = np.array([True, True, False, True])

# Training loss is the negative log-likelihood of the training examples.
def loss(W, b):
    preds = predict(W, b, inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -np.sum(np.log(label_probs))

# Initialize random model coefficients
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())
Use the grad function with its argnums argument to differentiate a function with respect to positional arguments.
[5]:
# Differentiate `loss` with respect to the first positional argument:
W_grad = grad(loss, argnums=0)(W, b)
print('W_grad', W_grad)
# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print('W_grad', W_grad)
# But we can choose different values too, and drop the keyword:
b_grad = grad(loss, 1)(W, b)
print('b_grad', b_grad)
# Including tuple values
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print('W_grad', W_grad)
print('b_grad', b_grad)
W_grad [0.16965581 0.8774648 1.4901345 ]
W_grad [0.16965581 0.8774648 1.4901345 ]
b_grad 0.29227245
W_grad [0.16965581 0.8774648 1.4901345 ]
b_grad 0.29227245
This grad API has a direct correspondence to the excellent notation in Spivak's classic Calculus on Manifolds (1965), also used in Sussman and Wisdom's Structure and Interpretation of Classical Mechanics (2015) and their Functional Differential Geometry (2013). Both books are open-access. See in particular the "Prologue" section of Functional Differential Geometry for a defense of this notation.
Essentially, when using the argnums argument, if f is a Python function for evaluating the mathematical function \(f\), then the Python expression grad(f, i) evaluates to a Python function for evaluating \(\partial_i f\).
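A toy example of that correspondence (the function here is made up purely to illustrate argnums):

```python
from jax import grad

# Illustrative two-argument function: f(x, y) = x**2 * y
f = lambda x, y: x ** 2 * y

dx = grad(f, 0)(3.0, 4.0)  # evaluates (d_0 f)(x, y) = 2 * x * y
dy = grad(f, 1)(3.0, 4.0)  # evaluates (d_1 f)(x, y) = x ** 2
print(dx, dy)  # 24.0 9.0
```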
Differentiating with respect to nested lists, tuples, and dicts
Differentiating with respect to standard Python containers just works, so use tuples, lists, and dicts (and arbitrary nesting) however you like.
[6]:
def loss2(params_dict):
    preds = predict(params_dict['W'], params_dict['b'], inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -np.sum(np.log(label_probs))

print(grad(loss2)({'W': W, 'b': b}))
print(grad(loss2)({'W': W, 'b': b}))
{'W': DeviceArray([0.16965581, 0.8774648 , 1.4901345 ], dtype=float32), 'b': DeviceArray(0.29227245, dtype=float32)}
You can register your own container types to work with not just grad but all the JAX transformations (jit, vmap, etc.).
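As a sketch of what that registration looks like (the Params class and its fields are invented for illustration; the hook is jax.tree_util.register_pytree_node), you tell JAX how to flatten your type into array leaves and how to rebuild it:

```python
import jax.numpy as np
from jax import grad
from jax.tree_util import register_pytree_node

# A hypothetical parameter container, purely for illustration.
class Params:
    def __init__(self, w, b):
        self.w, self.b = w, b

register_pytree_node(
    Params,
    lambda p: ((p.w, p.b), None),              # flatten: (children, aux data)
    lambda aux, children: Params(*children))   # unflatten: rebuild from children

def loss(params):
    return np.sum((params.w * 2.0 + params.b) ** 2)

# Gradients come back wrapped in the same container type.
g = grad(loss)(Params(np.ones(3), np.zeros(3)))
print(g.w, g.b)
```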
Evaluate a function and its gradient using value_and_grad
Another convenient function is value_and_grad for efficiently computing both a function's value as well as its gradient's value:
[7]:
from jax import value_and_grad
loss_value, Wb_grad = value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)
print('loss value', loss(W, b))
loss value 3.0519395
loss value 3.0519395
Checking against numerical differences
A great thing about derivatives is that they're straightforward to check with finite differences:
[8]:
# Set a step size for finite differences calculations
eps = 1e-4

# Check b_grad with scalar finite differences
b_grad_numerical = (loss(W, b + eps / 2.) - loss(W, b - eps / 2.)) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))

# Check W_grad with finite differences in a random direction
key, subkey = random.split(key)
vec = random.normal(subkey, W.shape)
unitvec = vec / np.sqrt(np.vdot(vec, vec))
W_grad_numerical = (loss(W + eps / 2. * unitvec, b) - loss(W - eps / 2. * unitvec, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', np.vdot(grad(loss)(W, b), unitvec))
b_grad_numerical 0.29325485
b_grad_autodiff 0.29227245
W_dirderiv_numerical 0.19788742
W_dirderiv_autodiff 0.19909072
JAX provides a simple convenience function that does essentially the same thing, but checks up to any order of differentiation that you like:
[9]:
from jax.test_util import check_grads
check_grads(loss, (W, b), order=2) # check up to 2nd order derivatives
Hessian-vector products with grad-of-grad
One thing we can do with higher-order grad is build a Hessian-vector product function. (Later on we'll write an even more efficient implementation that mixes both forward- and reverse-mode, but this one will use pure reverse-mode.)
A Hessian-vector product function can be useful in a truncated Newton Conjugate-Gradient algorithm for minimizing smooth convex functions, or for studying the curvature of neural network training objectives (e.g. 1, 2, 3, 4).
For a scalar-valued function \(f : \mathbb{R}^n \to \mathbb{R}\), the Hessian at a point \(x \in \mathbb{R}^n\) is written as \(\partial^2 f(x)\). A Hessian-vector product function is then able to evaluate
\(\qquad v \mapsto \partial^2 f(x) \cdot v\)
for any \(v \in \mathbb{R}^n\).
The trick is not to instantiate the full Hessian matrix: if \(n\) is large, perhaps in the millions or billions in the context of neural networks, then that might be impossible to store.
Luckily, grad already gives us a way to write an efficient Hessian-vector product function. We just have to use the identity
\(\qquad \partial^2 f(x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)\),
where \(g(x) = \partial f(x) \cdot v\) is a new scalar-valued function that dots the gradient of \(f\) at \(x\) with the vector \(v\). Notice that we're only ever differentiating scalar-valued functions of vector-valued arguments, which is exactly where we know grad is efficient.
In JAX code, we can just write this:
[10]:
def hvp(f, x, v):
    return grad(lambda x: np.vdot(grad(f)(x), v))(x)
This example shows that you can freely use lexical closure, and JAX will never get perturbed or confused.
We'll check this implementation a few cells down, once we see how to compute dense Hessian matrices. We'll also write an even better version that uses both forward-mode and reverse-mode.
Jacobians and Hessians using jacfwd and jacrev
You can compute full Jacobian matrices using the jacfwd and jacrev functions:
[11]:
from jax import jacfwd, jacrev
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
J = jacfwd(f)(W)
print("jacfwd result, with shape", J.shape)
print(J)
J = jacrev(f)(W)
print("jacrev result, with shape", J.shape)
print(J)
jacfwd result, with shape (4, 3)
[[ 0.05981752 0.12883773 0.08857594]
[ 0.04015911 0.04928619 0.0068453 ]
[ 0.12188289 0.01406341 0.3047072 ]
[ 0.00140426 0.00472514 0.00263773]]
jacrev result, with shape (4, 3)
[[ 0.05981752 0.12883773 0.08857594]
[ 0.04015911 0.04928619 0.0068453 ]
[ 0.12188289 0.01406341 0.3047072 ]
[ 0.00140426 0.00472514 0.00263773]]
These two functions compute the same values (up to machine numerics), but differ in their implementation: jacfwd uses forward-mode automatic differentiation, which is more efficient for "tall" Jacobian matrices, while jacrev uses reverse-mode, which is more efficient for "wide" Jacobian matrices. For matrices that are near-square, jacfwd probably has an edge over jacrev.
You can also use jacfwd and jacrev with container types:
[12]:
def predict_dict(params, inputs):
    return predict(params['W'], params['b'], inputs)

J_dict = jacrev(predict_dict)({'W': W, 'b': b}, inputs)
for k, v in J_dict.items():
    print("Jacobian from {} to logits is".format(k))
    print(v)
Jacobian from W to logits is
[[ 0.05981752 0.12883773 0.08857594]
[ 0.04015911 0.04928619 0.0068453 ]
[ 0.12188289 0.01406341 0.3047072 ]
[ 0.00140426 0.00472514 0.00263773]]
Jacobian from b to logits is
[0.11503369 0.04563536 0.23439017 0.00189765]
For more details on forward- and reverse-mode, as well as how to implement jacfwd and jacrev as efficiently as possible, read on!
Using a composition of two of these functions gives us a way to compute dense Hessian matrices:
[13]:
def hessian(f):
    return jacfwd(jacrev(f))
H = hessian(f)(W)
print("hessian, with shape", H.shape)
print(H)
hessian, with shape (4, 3, 3)
[[[ 0.02285464 0.04922538 0.03384245]
[ 0.04922538 0.10602391 0.07289143]
[ 0.03384245 0.07289144 0.05011286]]
[[0.03195212 0.03921397 0.00544638]
[ 0.03921397 0.04812624 0.0066842 ]
[0.00544638 0.0066842 0.00092836]]
[[0.01583708 0.00182736 0.0395927 ]
[0.00182736 0.00021085 0.00456839]
[ 0.0395927 0.00456839 0.09898175]]
[[0.0010352 0.00348331 0.0019445 ]
[ 0.00348331 0.01172087 0.00654297]
[0.0019445 0.00654297 0.0036525 ]]]
This shape makes sense: if we start with a function \(f : \mathbb{R}^n \to \mathbb{R}^m\), then at a point \(x \in \mathbb{R}^n\) we expect to get the shapes
 \(f(x) \in \mathbb{R}^m\), the value of \(f\) at \(x\),
 \(\partial f(x) \in \mathbb{R}^{m \times n}\), the Jacobian matrix at \(x\),
 \(\partial^2 f(x) \in \mathbb{R}^{m \times n \times n}\), the Hessian at \(x\),
and so on.
To implement hessian, we could have used jacrev(jacrev(f)) or jacrev(jacfwd(f)) or any other composition of the two. But forward-over-reverse is typically the most efficient. That's because in the inner Jacobian computation we're often differentiating a function with a wide Jacobian (maybe like a loss function \(f : \mathbb{R}^n \to \mathbb{R}\)), while in the outer Jacobian computation we're differentiating a function with a square Jacobian (since \(\nabla f : \mathbb{R}^n \to \mathbb{R}^n\)), which is where forward-mode wins out.
How it's made: two foundational autodiff functions
Jacobian-Vector products (JVPs, aka forward-mode autodiff)
JAX includes efficient and general implementations of both forward- and reverse-mode automatic differentiation. The familiar grad function is built on reverse-mode, but to explain the difference between the two modes, and when each can be useful, we need a bit of math background.
JVPs in math
Mathematically, given a function \(f : \mathbb{R}^n \to \mathbb{R}^m\), the Jacobian of \(f\) evaluated at an input point \(x \in \mathbb{R}^n\), denoted \(\partial f(x)\), is often thought of as a matrix in \(\mathbb{R}^{m \times n}\):
\(\qquad \partial f(x) \in \mathbb{R}^{m \times n}\).
But we can also think of \(\partial f(x)\) as a linear map, which maps the tangent space of the domain of \(f\) at the point \(x\) (which is just another copy of \(\mathbb{R}^n\)) to the tangent space of the codomain of \(f\) at the point \(f(x)\) (a copy of \(\mathbb{R}^m\)):
\(\qquad \partial f(x) : \mathbb{R}^n \to \mathbb{R}^m\).
This map is called the pushforward map of \(f\) at \(x\). The Jacobian matrix is just the matrix for this linear map in a standard basis.
If we don't commit to one specific input point \(x\), then we can think of the function \(\partial f\) as first taking an input point and returning the Jacobian linear map at that input point:
\(\qquad \partial f : \mathbb{R}^n \to \mathbb{R}^n \to \mathbb{R}^m\).
In particular, we can uncurry things so that given an input point \(x \in \mathbb{R}^n\) and a tangent vector \(v \in \mathbb{R}^n\), we get back an output tangent vector in \(\mathbb{R}^m\). We call that mapping, from \((x, v)\) pairs to output tangent vectors, the Jacobian-vector product, and write it as
\(\qquad (x, v) \mapsto \partial f(x) v\)
JVPs in JAX code
Back in Python code, JAX's jvp function models this transformation. Given a Python function that evaluates \(f\), JAX's jvp is a way to get a Python function for evaluating \((x, v) \mapsto (f(x), \partial f(x) v)\).
[14]:
from jax import jvp
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
key, subkey = random.split(key)
v = random.normal(subkey, W.shape)
# Push forward the vector `v` along `f` evaluated at `W`
y, u = jvp(f, (W,), (v,))
In terms of Haskell-like type signatures, we could write
jvp :: (a -> b) -> a -> T a -> (b, T b)
where we use T a to denote the type of the tangent space for a. In words, jvp takes as arguments a function of type a -> b, a value of type a, and a tangent vector value of type T a. It gives back a pair consisting of a value of type b and an output tangent vector of type T b.
The jvp-transformed function is evaluated much like the original function, but paired up with each primal value of type a it pushes along tangent values of type T a. For each primitive numerical operation that the original function would have applied, the jvp-transformed function executes a "JVP rule" for that primitive that both evaluates the primitive on the primals and applies the primitive's JVP at those primal values.
That evaluation strategy has some immediate implications about computational complexity: since we evaluate JVPs as we go, we don't need to store anything for later, and so the memory cost is independent of the depth of the computation. In addition, the FLOP cost of the jvp-transformed function is about 3x the cost of just evaluating the function (one unit of work for evaluating the original function, for example sin(x); one unit for linearizing, like cos(x); and one unit for applying the linearized function to a vector, like cos_x * v). Put another way, for a fixed primal point \(x\), we can evaluate \(v \mapsto \partial f(x) \cdot v\) for about the same marginal cost as evaluating \(f\).
That memory complexity sounds pretty compelling! So why don't we see forward-mode very often in machine learning?
To answer that, first think about how you could use a JVP to build a full Jacobian matrix. If we apply a JVP to a one-hot tangent vector, it reveals one column of the Jacobian matrix, corresponding to the nonzero entry we fed in. So we can build a full Jacobian one column at a time, and to get each column costs about the same as one function evaluation. That will be efficient for functions with "tall" Jacobians, but inefficient for "wide" Jacobians.
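That column-at-a-time recipe can be written directly as a (deliberately naive) sketch; jacfwd does essentially this, but replaces the Python loop with vmap:

```python
import jax.numpy as np
from jax import jvp

def jacobian_by_columns(f, x):
    # One JVP per one-hot tangent vector yields one Jacobian column.
    cols = [jvp(f, (x,), (e,))[1] for e in np.eye(len(x))]
    return np.stack(cols, axis=-1)

f = lambda x: np.stack([x[0] * x[1], x[0] + x[1]])
J = jacobian_by_columns(f, np.array([2.0, 3.0]))
print(J)  # [[3. 2.]
          #  [1. 1.]]
```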
If you're doing gradient-based optimization in machine learning, you probably want to minimize a loss function from parameters in \(\mathbb{R}^n\) to a scalar loss value in \(\mathbb{R}\). That means the Jacobian of this function is a very wide matrix: \(\partial f(x) \in \mathbb{R}^{1 \times n}\), which we often identify with the gradient vector \(\nabla f(x) \in \mathbb{R}^n\). Building that matrix one column at a time, with each call taking a similar number of FLOPs to evaluating the original function, sure seems inefficient! In particular, for training neural networks, where \(f\) is a training loss function and \(n\) can be in the millions or billions, this approach just won't scale.
To do better for functions like this, we just need to use reverse-mode.
Vector-Jacobian products (VJPs, aka reverse-mode autodiff)
Where forward-mode gives us back a function for evaluating Jacobian-vector products, which we can then use to build Jacobian matrices one column at a time, reverse-mode is a way to get back a function for evaluating vector-Jacobian products (equivalently Jacobian-transpose-vector products), which we can use to build Jacobian matrices one row at a time.
VJPs in math
Let's again consider a function \(f : \mathbb{R}^n \to \mathbb{R}^m\). Starting from our notation for JVPs, the notation for VJPs is pretty simple:
\(\qquad (x, v) \mapsto v \partial f(x)\),
where \(v\) is an element of the cotangent space of \(f\) at \(x\) (isomorphic to another copy of \(\mathbb{R}^m\)). When being rigorous, we should think of \(v\) as a linear map \(v : \mathbb{R}^m \to \mathbb{R}\), and when we write \(v \partial f(x)\) we mean function composition \(v \circ \partial f(x)\), where the types work out because \(\partial f(x) : \mathbb{R}^n \to \mathbb{R}^m\). But in the common case we can identify \(v\) with a vector in \(\mathbb{R}^m\) and use the two almost interchangeably, just like we might sometimes flip between "column vectors" and "row vectors" without much comment.
With that identification, we can alternatively think of the linear part of a VJP as the transpose (or adjoint conjugate) of the linear part of a JVP:
\(\qquad (x, v) \mapsto \partial f(x)^\mathsf{T} v\).
For a given point \(x\), we can write the signature as
\(\qquad \partial f(x)^\mathsf{T} : \mathbb{R}^m \to \mathbb{R}^n\).
The corresponding map on cotangent spaces is often called the pullback of \(f\) at \(x\). The key for our purposes is that it goes from something that looks like the output of \(f\) to something that looks like the input of \(f\), just like we might expect from a transposed linear function.
VJPs in JAX code
Switching from math back to Python, the JAX function vjp can take a Python function for evaluating \(f\) and give us back a Python function for evaluating the VJP \((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))\).
[15]:
from jax import vjp
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
y, vjp_fun = vjp(f, W)
key, subkey = random.split(key)
u = random.normal(subkey, y.shape)
# Pull back the covector `u` along `f` evaluated at `W`
v = vjp_fun(u)
In terms of Haskell-like type signatures, we could write
vjp :: (a -> b) -> a -> (b, CT b -> CT a)
where we use CT a to denote the type for the cotangent space for a. In words, vjp takes as arguments a function of type a -> b and a point of type a, and gives back a pair consisting of a value of type b and a linear map of type CT b -> CT a.
This is great because it lets us build Jacobian matrices one row at a time, and the FLOP cost for evaluating \((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))\) is only about three times the cost of evaluating \(f\). In particular, if we want the gradient of a function \(f : \mathbb{R}^n \to \mathbb{R}\), we can do it in just one call. That's how grad is efficient for gradient-based optimization, even for objectives like neural network training loss functions on millions or billions of parameters.
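Concretely, for a scalar-valued function a single pullback of the covector 1.0 yields the entire gradient. A self-contained toy (the function here is arbitrary):

```python
import jax.numpy as np
from jax import vjp, grad

f = lambda x: np.sum(np.tanh(x))   # scalar-valued: R^n -> R

x = np.arange(3.)
y, vjp_fun = vjp(f, x)
gradient, = vjp_fun(np.ones_like(y))  # one VJP call gives the whole gradient
print(np.allclose(gradient, grad(f)(x)))  # True
```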
There's a cost, though: though the FLOPs are friendly, memory scales with the depth of the computation. Also, the implementation is traditionally more complex than that of forward-mode, though JAX has some tricks up its sleeve (that's a story for a future notebook!).
For more on how reverse-mode works, see this tutorial video from the Deep Learning Summer School in 2017.
Hessian-vector products using both forward- and reverse-mode
In a previous section, we implemented a Hessian-vector product function just using reverse-mode:
[16]:
def hvp(f, x, v):
    return grad(lambda x: np.vdot(grad(f)(x), v))(x)
That's efficient, but we can do even better and save some memory by using forward-mode together with reverse-mode.
Mathematically, given a function \(f : \mathbb{R}^n \to \mathbb{R}\) to differentiate, a point \(x \in \mathbb{R}^n\) at which to linearize the function, and a vector \(v \in \mathbb{R}^n\), the Hessian-vector product function we want is
\((x, v) \mapsto \partial^2 f(x) v\)
Consider the helper function \(g : \mathbb{R}^n \to \mathbb{R}^n\) defined to be the derivative (or gradient) of \(f\), namely \(g(x) = \partial f(x)\). All we need is its JVP, since that will give us
\((x, v) \mapsto \partial g(x) v = \partial^2 f(x) v\).
We can translate that almost directly into code:
[17]:
from jax import jvp, grad

# forward-over-reverse
def hvp(f, primals, tangents):
    return jvp(grad(f), primals, tangents)[1]
Even better, since we didn't have to call np.dot directly, this hvp function works with arrays of any shape and with arbitrary container types (like vectors stored as nested lists/dicts/tuples), and doesn't even have a dependence on jax.numpy.
Here's an example of how to use it:
[18]:
def f(X):
    return np.sum(np.tanh(X)**2)

key, subkey1, subkey2 = random.split(key, 3)
X = random.normal(subkey1, (30, 40))
V = random.normal(subkey2, (30, 40))

ans1 = hvp(f, (X,), (V,))
ans2 = np.tensordot(hessian(f)(X), V, 2)

print(np.allclose(ans1, ans2, 1e-4, 1e-4))
True
Another way you might consider writing this is using reverse-over-forward:
[19]:
# reverse-over-forward
def hvp_revfwd(f, primals, tangents):
    g = lambda primals: jvp(f, primals, tangents)[1]
    return grad(g)(primals)
That's not quite as good, though, because forward-mode has less overhead than reverse-mode, and since the outer differentiation operator here has to differentiate a larger computation than the inner one, keeping forward-mode on the outside works best:
[20]:
# reverse-over-reverse, only works for single arguments
def hvp_revrev(f, primals, tangents):
    x, = primals
    v, = tangents
    return grad(lambda x: np.vdot(grad(f)(x), v))(x)

print("Forward over reverse")
%timeit -n10 -r3 hvp(f, (X,), (V,))
print("Reverse over forward")
%timeit -n10 -r3 hvp_revfwd(f, (X,), (V,))
print("Reverse over reverse")
%timeit -n10 -r3 hvp_revrev(f, (X,), (V,))
print("Naive full Hessian materialization")
%timeit -n10 -r3 np.tensordot(hessian(f)(X), V, 2)
Forward over reverse
9.74 ms ± 421 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over forward
14.3 ms ± 2.96 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over reverse
15.9 ms ± 1.86 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Naive full Hessian materialization
29.8 ms ± 3.82 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Composing VJPs, JVPs, and vmap
Jacobian-Matrix and Matrix-Jacobian products
Now that we have jvp and vjp transformations that give us functions to push forward or pull back single vectors at a time, we can use JAX's vmap transformation (https://github.com/google/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products.
[21]:
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

# Pull back the covectors `m_i` along `f`, evaluated at `W`, for all `i`.
# First, use a list comprehension to loop over rows in the matrix M.
def loop_mjp(f, x, M):
    y, vjp_fun = vjp(f, x)
    return np.vstack([vjp_fun(mi) for mi in M])

# Now, use vmap to build a computation that does a single fast matrix-matrix
# multiply, rather than an outer loop over vector-matrix multiplies.
def vmap_mjp(f, x, M):
    y, vjp_fun = vjp(f, x)
    outs, = vmap(vjp_fun)(M)
    return outs

key = random.PRNGKey(0)
num_covecs = 128
U = random.normal(key, (num_covecs,) + y.shape)

loop_vs = loop_mjp(f, W, M=U)
print('Non-vmapped Matrix-Jacobian product')
%timeit -n10 -r3 loop_mjp(f, W, M=U)

print('\nVmapped Matrix-Jacobian product')
vmap_vs = vmap_mjp(f, W, M=U)
%timeit -n10 -r3 vmap_mjp(f, W, M=U)

assert np.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'
Non-vmapped Matrix-Jacobian product
129 ms ± 3 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)

Vmapped Matrix-Jacobian product
5.77 ms ± 215 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
[22]:
def loop_jmp(f, x, M):
  # jvp immediately returns the primal and tangent values as a tuple,
  # so we'll compute and select the tangents in a list comprehension
  return np.vstack([jvp(f, (x,), (mi,))[1] for mi in M])
def vmap_jmp(f, x, M):
  _jvp = lambda s: jvp(f, (x,), (s,))[1]
  return vmap(_jvp)(M)
num_vecs = 128
S = random.normal(key, (num_vecs,) + W.shape)
loop_vs = loop_jmp(f, W, M=S)
print('Non-vmapped Jacobian-Matrix product')
%timeit -n10 -r3 loop_jmp(f, W, M=S)
vmap_vs = vmap_jmp(f, W, M=S)
print('\nVmapped Jacobian-Matrix product')
%timeit -n10 -r3 vmap_jmp(f, W, M=S)
assert np.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical'
Non-vmapped Jacobian-Matrix product
423 ms ± 7.21 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)

Vmapped Jacobian-Matrix product
4.43 ms ± 137 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
The implementation of jacfwd and jacrev¶
Now that we've seen fast Jacobian-matrix and matrix-Jacobian products, it's not hard to guess how to write jacfwd and jacrev. We just use the same technique to push forward or pull back an entire standard basis (isomorphic to an identity matrix) at once.
[23]:
from jax import jacrev as builtin_jacrev
def our_jacrev(f):
  def jacfun(x):
    y, vjp_fun = vjp(f, x)
    # Use vmap to do a matrix-Jacobian product.
    # Here, the matrix is the Euclidean basis, so we get all
    # entries in the Jacobian at once.
    J, = vmap(vjp_fun, in_axes=0)(np.eye(len(y)))
    return J
  return jacfun
assert np.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reversemode Jacobian results!'
[24]:
from jax import jacfwd as builtin_jacfwd
def our_jacfwd(f):
  def jacfun(x):
    _jvp = lambda s: jvp(f, (x,), (s,))[1]
    Jt = vmap(_jvp, in_axes=1)(np.eye(len(x)))
    return np.transpose(Jt)
  return jacfun
assert np.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forwardmode Jacobian results!'
Interestingly, Autograd couldn't do this. Our implementation of reverse-mode jacobian in Autograd (https://github.com/HIPS/autograd/blob/96a03f44da43cd7044c61ac945c483955deba957/autograd/differential_operators.py#L60) had to pull back one vector at a time with an outer-loop map. Pushing one vector at a time through the computation is much less efficient than batching it all together with vmap.
Another thing that Autograd couldn't do is jit. Interestingly, no matter how much Python dynamism you use in the function to be differentiated, we can always use jit on the linear part of the computation. For example:
[25]:
def f(x):
  try:
    if x < 3:
      return 2 * x ** 3
    else:
      raise ValueError
  except ValueError:
    return np.pi * x
y, f_vjp = vjp(f, 4.)
print(jit(f_vjp)(1.))
(DeviceArray(3.1415927, dtype=float32),)
Complex numbers and differentiation¶
JAX is great at complex numbers and differentiation. To support both holomorphic and non-holomorphic differentiation, JAX follows Autograd's convention for encoding complex derivatives.
Consider a complex-to-complex function \(f: \mathbb{C} \to \mathbb{C}\) that we break down into its component real-to-real functions:
[26]:
def f(z):
  x, y = np.real(z), np.imag(z)
  return u(x, y) + v(x, y) * 1j
That is, we've decomposed \(f(z) = u(x, y) + v(x, y) i\) where \(z = x + y i\). We define grad(f) to correspond to
[27]:
def grad_f(z):
  x, y = np.real(z), np.imag(z)
  return grad(u, 0)(x, y) - grad(u, 1)(x, y) * 1j
In math symbols, that means we define \(\partial f(z) \triangleq \partial_0 u(x, y) - \partial_1 u(x, y) i\). So we throw out \(v\), ignoring the complex component function of \(f\) entirely!
This convention covers three important cases:

- If f evaluates a holomorphic function, then we get the usual complex derivative, since \(\partial_0 u = \partial_1 v\) and \(\partial_1 u = -\partial_0 v\).
- If f evaluates a real-valued loss function of a complex parameter x, then we get a result that we can use in gradient-based optimization by taking steps in the direction of the conjugate of grad(f)(x).
- If f evaluates a real-to-real function, but its implementation uses complex primitives internally (some of which must be non-holomorphic, e.g. FFTs used in convolutions), then we get the same result that an implementation using only real primitives would have given.
By throwing away v entirely, this convention does not handle the case where f evaluates a non-holomorphic function and you want to evaluate all of \(\partial_0 u\), \(\partial_1 u\), \(\partial_0 v\), and \(\partial_1 v\) at once. But in that case the answer would have to contain four real values, and so there's no way to express it as a single complex number.
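If you do need all four real partials, you can always drop down to the real decomposition yourself and take the full real Jacobian with jacfwd. A sketch, using hypothetical component functions u and v for \(f(z) = z^2\) (these names are for illustration only):

```python
import jax.numpy as np
from jax import jacfwd

# Hypothetical components of f(z) = z**2 = (x**2 - y**2) + (2*x*y) i
def u(x, y):
  return x ** 2 - y ** 2
def v(x, y):
  return 2 * x * y

# View f as a function R^2 -> R^2 and take its full real Jacobian:
F = lambda p: np.array([u(p[0], p[1]), v(p[0], p[1])])
J = jacfwd(F)(np.array([1., 2.]))
print(J)  # rows are [d0 u, d1 u] and [d0 v, d1 v], evaluated at z = 1 + 2i
```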
You should expect complex numbers to work everywhere in JAX. Here's differentiating through a Cholesky decomposition of a complex matrix:
[28]:
A = np.array([[5.,     2.+3j,  5j],
              [2.-3j,  7.,     1.+7j],
              [-5j,    1.-7j,  12.]])
def f(X):
  L = np.linalg.cholesky(X)
  return np.sum((L - np.sin(L))**2)
grad(f, holomorphic=True)(A)
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/test-docs/lib/python3.7/site-packages/jax/lax/lax.py:1976: ComplexWarning: Casting complex values to real discards the imaginary part
  lambda t, new_dtype, old_dtype: [convert_element_type(t, old_dtype)])
[28]:
DeviceArray([[0.7534186 +0.j       , 3.0509028 -10.940544j ,
              5.9896836 +3.5423026j],
             [3.0509028 +10.940544j , 8.9044895 +0.j       ,
              5.1351523 -6.559372j ],
             [5.9896836 -3.5423026j , 5.1351523 +6.559372j ,
              0.01320427+0.j       ]], dtype=complex64)
For primitives' JVP rules, writing the primals as \(z = a + bi\) and the tangents as \(t = c + di\), we define the Jacobian-vector product \(t \mapsto \partial f(z) \cdot t\) as

\(t \mapsto \begin{bmatrix} 1 & 1 \end{bmatrix} \begin{bmatrix} \partial_0 u(a, b) & \partial_0 v(a, b) \\ -\partial_1 u(a, b) i & \partial_1 v(a, b) i \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix}\)
See Chapter 4 of Dougal's PhD thesis for more details.
More advanced autodiff¶
In this notebook, we worked through some easy, and then progressively more complicated, applications of automatic differentiation in JAX. We hope you now feel that taking derivatives in JAX is easy and powerful.
There's a whole world of other autodiff tricks and functionality out there. Topics we didn't cover, but hope to in an "Advanced Autodiff Cookbook", include:

- Gauss-Newton Vector Products, linearizing once
- Custom VJPs and JVPs
- Efficient derivatives at fixed-points
- Estimating the trace of a Hessian using random Hessian-vector products
- Forward-mode autodiff using only reverse-mode autodiff
- Taking derivatives with respect to custom data types
- Checkpointing (binomial checkpointing for efficient reverse-mode, not model snapshotting)
- Optimizing VJPs with Jacobian pre-accumulation
Autobatching log-densities example¶
This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.
Inspired by a notebook by @davmre.
[1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import itertools
import re
import sys
import time
from matplotlib.pyplot import *
import jax
from jax import lax
from jax import numpy as np
from jax import scipy
from jax import random
import numpy as onp
import scipy as oscipy
Generate a fake binary classification dataset¶
[2]:
onp.random.seed(10009)
num_features = 10
num_points = 100
true_beta = onp.random.randn(num_features).astype(np.float32)
all_x = onp.random.randn(num_points, num_features).astype(np.float32)
y = (onp.random.rand(num_points) < oscipy.special.expit(all_x.dot(true_beta))).astype(np.int32)
[3]:
y
[3]:
array([0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0,
1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0,
1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1,
1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0], dtype=int32)
Write the log-joint function for the model¶
We'll write a non-batched version, a manually batched version, and an auto-batched version.
Non-batched¶
[4]:
def log_joint(beta):
  result = 0.
  # Note that no `axis` parameter is provided to `np.sum`.
  result = result + np.sum(scipy.stats.norm.logpdf(beta, loc=0., scale=1.))
  result = result + np.sum(-np.log(1 + np.exp(-(2*y - 1) * np.dot(all_x, beta))))
  return result
[5]:
log_joint(onp.random.randn(num_features))
[5]:
DeviceArray(-213.23558, dtype=float32)
[6]:
# This doesn't work, because we didn't write `log_joint()` to handle batching.
try:
  batch_size = 10
  batched_test_beta = onp.random.randn(batch_size, num_features)
  log_joint(onp.random.randn(batch_size, num_features))
except ValueError as e:
  print("Caught expected exception " + str(e))
Caught expected exception Incompatible shapes for broadcasting: ((100, 10), (1, 100))
Manually batched¶
[7]:
def batched_log_joint(beta):
  result = 0.
  # Here (and below) `sum` needs an `axis` parameter. At best, forgetting to set axis
  # or setting it incorrectly yields an error; at worst, it silently changes the
  # semantics of the model.
  result = result + np.sum(scipy.stats.norm.logpdf(beta, loc=0., scale=1.),
                           axis=-1)
  # Note the multiple transposes. Getting this right is not rocket science,
  # but it's also not totally mindless. (I didn't get it right on the first
  # try.)
  result = result + np.sum(-np.log(1 + np.exp(-(2*y - 1) * np.dot(all_x, beta.T).T)),
                           axis=-1)
  return result
[8]:
batch_size = 10
batched_test_beta = onp.random.randn(batch_size, num_features)
batched_log_joint(batched_test_beta)
[8]:
DeviceArray([-147.84032, -207.02205, -109.26076, -243.80833, -163.02908,
             -143.84848, -160.28772, -113.7717 , -126.60544, -190.81989], dtype=float32)
Auto-batched with vmap¶
It just works.
[9]:
vmap_batched_log_joint = jax.vmap(log_joint)
vmap_batched_log_joint(batched_test_beta)
[9]:
DeviceArray([-147.84032, -207.02205, -109.26076, -243.80833, -163.02908,
             -143.84848, -160.28772, -113.7717 , -126.60544, -190.81989], dtype=float32)
Self-contained variational inference example¶
A little code is copied from above.
Set up the (batched) log-joint function¶
[10]:
@jax.jit
def log_joint(beta):
  result = 0.
  # Note that no `axis` parameter is provided to `np.sum`.
  result = result + np.sum(scipy.stats.norm.logpdf(beta, loc=0., scale=10.))
  result = result + np.sum(-np.log(1 + np.exp(-(2*y - 1) * np.dot(all_x, beta))))
  return result
batched_log_joint = jax.jit(jax.vmap(log_joint))
Define the ELBO and its gradient¶
[11]:
def elbo(beta_loc, beta_log_scale, epsilon):
  beta_sample = beta_loc + np.exp(beta_log_scale) * epsilon
  return np.mean(batched_log_joint(beta_sample), 0) + np.sum(beta_log_scale - 0.5 * onp.log(2*onp.pi))
elbo = jax.jit(elbo)
elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))
Optimize the ELBO using SGD¶
[12]:
def normal_sample(key, shape):
  """Convenience function for quasi-stateful RNG."""
  new_key, sub_key = random.split(key)
  return new_key, random.normal(sub_key, shape)
normal_sample = jax.jit(normal_sample, static_argnums=(1,))
key = random.PRNGKey(10003)
beta_loc = np.zeros(num_features, np.float32)
beta_log_scale = np.zeros(num_features, np.float32)
step_size = 0.01
batch_size = 128
epsilon_shape = (batch_size, num_features)
for i in range(1000):
  key, epsilon = normal_sample(key, epsilon_shape)
  elbo_val, (beta_loc_grad, beta_log_scale_grad) = elbo_val_and_grad(
      beta_loc, beta_log_scale, epsilon)
  beta_loc += step_size * beta_loc_grad
  beta_log_scale += step_size * beta_log_scale_grad
  if i % 10 == 0:
    print('{}\t{}'.format(i, elbo_val))
0	-180.8538818359375
10	-113.06046295166016
20	-102.73725891113281
30	-99.787353515625
40	-98.90898895263672
50	-98.29745483398438
60	-98.18630981445312
70	-97.5797348022461
80	-97.28599548339844
90	-97.469970703125
100	-97.4771728515625
110	-97.58067321777344
120	-97.49435424804688
130	-97.50271606445312
140	-96.86395263671875
150	-97.44197845458984
160	-97.06939697265625
170	-96.84028625488281
180	-97.21336364746094
190	-97.56502532958984
200	-97.26398468017578
210	-97.11979675292969
220	-97.39593505859375
230	-97.16830444335938
240	-97.118408203125
250	-97.24345397949219
260	-97.2978744506836
270	-96.69285583496094
280	-96.96438598632812
290	-97.30055236816406
300	-96.63592529296875
310	-97.03518676757812
320	-97.52909851074219
330	-97.28812408447266
340	-97.07322692871094
350	-97.15620422363281
360	-97.25882720947266
370	-97.19514465332031
380	-97.13092803955078
390	-97.11727905273438
400	-96.93872833251953
410	-97.26676940917969
420	-97.35324096679688
430	-97.21007537841797
440	-97.28434753417969
450	-97.16310119628906
460	-97.2612533569336
470	-97.21343994140625
480	-97.23997497558594
490	-97.14913940429688
500	-97.23528289794922
510	-96.9342041015625
520	-97.21209716796875
530	-96.82577514648438
540	-97.01286315917969
550	-96.94176483154297
560	-97.16520690917969
570	-97.29165649414062
580	-97.42939758300781
590	-97.24371337890625
600	-97.15220642089844
610	-97.49844360351562
620	-96.99070739746094
630	-96.88957977294922
640	-96.89970397949219
650	-97.137939453125
660	-97.43707275390625
670	-96.99235534667969
680	-97.15623474121094
690	-97.18690490722656
700	-97.11161041259766
710	-97.78105163574219
720	-97.23226165771484
730	-97.16206359863281
740	-96.99581909179688
750	-96.66722106933594
760	-97.16796875
770	-97.51435089111328
780	-97.28901672363281
790	-96.91226196289062
800	-97.1709976196289
810	-97.29047393798828
820	-97.16242980957031
830	-97.1910629272461
840	-97.56382751464844
850	-97.00193786621094
860	-96.86555480957031
870	-96.76336669921875
880	-96.83660888671875
890	-97.12179565429688
900	-97.09554290771484
910	-97.0682373046875
920	-97.11947631835938
930	-96.8792953491211
940	-97.45625305175781
950	-96.69280242919922
960	-97.29376220703125
970	-97.33531188964844
980	-97.34962463378906
990	-97.09674835205078
Display the results¶
Coverage isn't quite as good as we might like, but it's not bad, and nobody said variational inference was exact.
[13]:
figure(figsize=(7, 7))
plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')
plot(true_beta, beta_loc + 2*np.exp(beta_log_scale), 'r.', label='Approximated Posterior $2\sigma$ Error Bars')
plot(true_beta, beta_loc - 2*np.exp(beta_log_scale), 'r.')
plot_scale = 3
plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')
xlabel('True beta')
ylabel('Estimated beta')
legend(loc='best')
[13]:
<matplotlib.legend.Legend at 0x7fab540cb7f0>
🔪 JAX - The Sharp Bits 🔪¶
levskaya@ mattjj@
When walking about the countryside of Italy, the people will not hesitate to tell you that JAX has "una anima di pura programmazione funzionale" (a soul of pure functional programming).
JAX is a language for expressing and composing transformations of numerical programs. JAX is also able to compile numerical programs for CPU or accelerators (GPU/TPU). JAX works great for many numerical and scientific programs, but only if they are written with certain constraints that we describe below.
[1]:
import numpy as onp
from jax import grad, jit
from jax import lax
from jax import random
import jax
import jax.numpy as np
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib import rcParams
rcParams['image.interpolation'] = 'nearest'
rcParams['image.cmap'] = 'viridis'
rcParams['axes.grid'] = False
🔪 Pure functions¶
JAX transformation and compilation are designed to work only on Python functions that are functionally pure: all the input data is passed through the function parameters, and all the results are output through the function results. A pure function will always return the same result if invoked with the same inputs.
Here are some examples of functions that are not functionally pure, for which JAX behaves differently than the Python interpreter. Note that these behaviors are not guaranteed by the JAX system; the proper way to use JAX is to use it only on functionally pure Python functions.
[2]:
def impure_print_side_effect(x):
  print("Executing function")  # This is a side-effect
  return x

# The side-effects appear during the first run
print("First call: ", jit(impure_print_side_effect)(4.))
# Subsequent runs with parameters of same type and shape may not show the side-effect
# This is because JAX now invokes a cached compilation of the function
print("Second call: ", jit(impure_print_side_effect)(5.))
# JAX re-runs the Python function when the type or shape of the argument changes
print("Third call, different type: ", jit(impure_print_side_effect)(np.array([5.])))
Executing function
First call: 4.0
Second call: 5.0
Executing function
Third call, different type: [5.]
[3]:
g = 0.
def impure_uses_globals(x):
  return x + g

# JAX captures the value of the global during the first run
print("First call: ", jit(impure_uses_globals)(4.))
g = 10.  # Update the global
# Subsequent runs may silently use the cached value of the globals
print("Second call: ", jit(impure_uses_globals)(5.))
# JAX re-runs the Python function when the type or shape of the argument changes
# This will end up reading the latest value of the global
print("Third call, different type: ", jit(impure_uses_globals)(np.array([4.])))
First call: 4.0
Second call: 5.0
Third call, different type: [14.]
[4]:
g = 0.
def impure_saves_global(x):
  global g
  g = x
  return x

# JAX runs once the transformed function with special Traced values for arguments
print("First call: ", jit(impure_saves_global)(4.))
print("Saved global: ", g)  # Saved global has an internal JAX value
First call: 4.0
Saved global: Traced<ShapedArray(float32[], weak_type=True):JaxprTrace(level=1/1)>
A Python function can be functionally pure even if it actually uses stateful objects internally, as long as it does not read or write external state:
[5]:
def pure_uses_internal_state(x):
  state = dict(even=0, odd=0)
  for i in range(10):
    state['even' if i % 2 == 0 else 'odd'] += x
  return state['even'] + state['odd']
print(jit(pure_uses_internal_state)(5.))
50.0
🔪 In-Place Updates¶
In NumPy you're used to doing this:
[6]:
numpy_array = onp.zeros((3,3), dtype=np.float32)
print("original array:")
print(numpy_array)
# In place, mutating update
numpy_array[1, :] = 1.0
print("updated array:")
print(numpy_array)
original array:
[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]
updated array:
[[0. 0. 0.]
[1. 1. 1.]
[0. 0. 0.]]
If we try to update a JAX device array in-place, however, we get an error!
[7]:
jax_array = np.zeros((3,3), dtype=np.float32)

# In-place update of JAX's array will yield an error!
try:
  jax_array[1, :] = 1.0
except Exception as e:
  print("Exception {}".format(e))
Exception '<class 'jax.interpreters.xla.DeviceArray'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?
What gives?!
Allowing mutation of variables inplace makes program analysis and transformation very difficult. JAX requires a pure functional expression of a numerical program.
Instead, JAX offers the functional update functions: **index_update**, **index_add**, **index_min**, **index_max**, and the **index** helper.
⚠️ Inside jit'd code and lax.while_loop or lax.fori_loop, the size of slices can't be functions of argument values but only functions of argument shapes; the slice start indices have no such restriction. See the Control Flow section below for more information on this limitation.
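To illustrate that distinction, dynamic start indices work fine under jit as long as the slice size is a Python constant. A minimal sketch using lax.dynamic_slice (the helper name `take3` is hypothetical, for illustration only):

```python
import jax.numpy as np
from jax import jit, lax

@jit
def take3(x, start):
  # The slice size (3,) is static; the start index may be a traced value.
  return lax.dynamic_slice(x, (start,), (3,))

print(take3(np.arange(10), 4))  # [4 5 6]
```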
[8]:
from jax.ops import index, index_add, index_update
index_update¶
If the input values of index_update aren't reused, jit-compiled code will perform these operations in-place.
[9]:
jax_array = np.zeros((3, 3))
print("original array:")
print(jax_array)
new_jax_array = index_update(jax_array, index[1, :], 1.)
print("old array unchanged:")
print(jax_array)
print("new array:")
print(new_jax_array)
original array:
[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]
old array unchanged:
[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]
new array:
[[0. 0. 0.]
[1. 1. 1.]
[0. 0. 0.]]
index_add¶
If the input values of index_add aren't reused, jit-compiled code will perform these operations in-place.
[10]:
print("original array:")
jax_array = np.ones((5, 6))
print(jax_array)
new_jax_array = index_add(jax_array, index[::2, 3:], 7.)
print("new array postaddition:")
print(new_jax_array)
original array:
[[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]]
new array postaddition:
[[1. 1. 1. 8. 8. 8.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 8. 8. 8.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 8. 8. 8.]]
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/test-docs/lib/python3.7/site-packages/jax/lax/lax.py:4671: UserWarning: Explicitly requested dtype <class 'jax.numpy.lax_numpy.int64'> requested in arange is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  warnings.warn(msg.format(dtype, fun_name, truncated_dtype))
🔪 Out-of-Bounds Indexing¶
In NumPy, you are used to errors being thrown when you index an array outside of its bounds, like this:
[11]:
try:
  onp.arange(10)[11]
except Exception as e:
  print("Exception {}".format(e))
Exception index 11 is out of bounds for axis 0 with size 10
However, raising an error from code running on an accelerator can be difficult. Therefore, JAX does not raise an error; instead the index is clamped, so the last value in the array is returned.
[12]:
np.arange(10)[11]
[12]:
DeviceArray(9, dtype=int32)
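If silent clamping would mask a bug in your program, one defensive pattern is to check index validity yourself and substitute a sentinel with np.where. This is a sketch of one possible approach (the `safe_get` helper and `fill` value are hypothetical, not a JAX API):

```python
import jax.numpy as np

def safe_get(x, i, fill=-1):
  # Return x[i] if i is in bounds, otherwise the sentinel `fill`.
  in_bounds = (i >= 0) & (i < x.shape[0])
  clamped = np.clip(i, 0, x.shape[0] - 1)
  return np.where(in_bounds, x[clamped], fill)

x = np.arange(10)
print(safe_get(x, 4))   # 4
print(safe_get(x, 11))  # -1
```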
🔪 Random Numbers¶
If all scientific papers whose results are in doubt because of bad ``rand()``s were to disappear from library shelves, there would be a gap on each shelf about as big as your fist. - Numerical Recipes
RNGs and State¶
You're used to stateful pseudorandom number generators (PRNGs) from NumPy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness:
[13]:
print(onp.random.random())
print(onp.random.random())
print(onp.random.random())
0.013746606380725335
0.7133536372684778
0.9504874492057808
Underneath the hood, NumPy uses the Mersenne Twister PRNG to power its pseudorandom functions. The PRNG has a period of \(2^{19937}-1\) and at any point can be described by 624 32-bit unsigned ints and a position indicating how much of this "entropy" has been used up.
[14]:
onp.random.seed(0)
rng_state = onp.random.get_state()
#print(rng_state)
# > ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,
# 2481403966, 4042607538, 337614300, ... 614 more numbers...,
# 3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)
This pseudorandom state vector is automagically updated behind the scenes every time a random number is needed, "consuming" 2 of the uint32s in the Mersenne Twister state vector:
[15]:
_ = onp.random.uniform()
rng_state = onp.random.get_state()
#print(rng_state)
# > ('MT19937', array([2443250962, 1093594115, 1878467924,
# ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)
# Let's exhaust the entropy in this PRNG state vector
for i in range(311):
_ = onp.random.uniform()
rng_state = onp.random.get_state()
#print(rng_state)
# > ('MT19937', array([2443250962, 1093594115, 1878467924,
# ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)
# Next call iterates the RNG state for a new batch of fake "entropy".
_ = onp.random.uniform()
rng_state = onp.random.get_state()
# print(rng_state)
# > ('MT19937', array([1499117434, 2949980591, 2242547484,
# 4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)
The problem with magic PRNG state is that it's hard to reason about how it's being used and updated across different threads, processes, and devices, and it's very easy to screw up when the details of entropy production and consumption are hidden from the end user.
The Mersenne Twister PRNG is also known to have a number of problems: it has a large 2.5KB state size, which leads to problematic initialization issues; it fails modern BigCrush tests; and it is generally slow.
JAX PRNG¶
JAX instead implements an explicit PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern Threefry counter-based PRNG that's splittable. That is, its design allows us to fork the PRNG state into new PRNGs for use with parallel stochastic generation.
The random state is described by two unsigned 32-bit ints that we call a key:
[16]:
from jax import random
key = random.PRNGKey(0)
key
[16]:
DeviceArray([0, 0], dtype=uint32)
JAX's random functions produce pseudorandom numbers from the PRNG state, but do not change the state!
Reusing the same state will cause sadness and monotony, depriving the end-user of life-giving chaos:
[17]:
print(random.normal(key, shape=(1,)))
print(key)
# No no no!
print(random.normal(key, shape=(1,)))
print(key)
[-0.20584235]
[0 0]
[-0.20584235]
[0 0]
Instead, we split the PRNG to get usable subkeys every time we need a new pseudorandom number:
[18]:
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print("    \---SPLIT --> new key   ", key)
print("             \--> new subkey", subkey, "--> normal", normal_pseudorandom)
old key [0 0]
    \---SPLIT --> new key    [4146024105  967050713]
             \--> new subkey [2718843009 1272950319] --> normal [-1.2515389]
We propagate the key and make new subkeys whenever we need a new random number:
[19]:
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print("    \---SPLIT --> new key   ", key)
print("             \--> new subkey", subkey, "--> normal", normal_pseudorandom)
old key [4146024105  967050713]
    \---SPLIT --> new key    [2384771982 3928867769]
             \--> new subkey [1278412471 2182328957] --> normal [-0.58665067]
We can generate more than one subkey at a time:
[20]:
key, *subkeys = random.split(key, 4)
for subkey in subkeys:
  print(random.normal(subkey, shape=(1,)))
[0.3753345]
[0.9864503]
[0.1455319]
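Splitting also composes nicely with vmap: you can draw a whole batch of independent samples by mapping a drawing function over a batch of keys. A minimal sketch:

```python
import jax
from jax import random

key = random.PRNGKey(0)
subkeys = random.split(key, 8)  # a batch of 8 independent keys
# One independent standard-normal draw per key, in a single vectorized call:
samples = jax.vmap(lambda k: random.normal(k, shape=()))(subkeys)
print(samples.shape)  # (8,)
```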
🔪 Control Flow¶
✔ python control_flow + autodiff ✔¶
If you just want to apply grad to your python functions, you can use regular python control-flow constructs with no problems, as if you were using Autograd (or PyTorch or TF Eager).
[21]:
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x
print(grad(f)(2.))  # ok!
print(grad(f)(4.))  # ok!
12.0
-4.0
python control flow + JIT¶
Using control flow with jit is more complicated, and by default it has more constraints.
This works:
[22]:
@jit
def f(x):
  for i in range(3):
    x = 2 * x
  return x
print(f(3))
24
So does this:
[23]:
@jit
def g(x):
  y = 0.
  for i in range(x.shape[0]):
    y = y + x[i]
  return y
print(g(np.array([1., 2., 3.])))
6.0
But this doesn't, at least by default:
[24]:
@jit
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

# This will fail!
try:
  f(2)
except Exception as e:
  print("Exception {}".format(e))
Exception Abstract value passed to `bool`, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.
What gives!?
When we jit-compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don't have to re-compile on each function evaluation.
For example, if we evaluate an @jit function on the array np.array([1., 2., 3.], np.float32), we might want to compile code that we can reuse to evaluate the function on np.array([4., 5., 6.], np.float32) to save on compile time.
To get a view of your Python code that is valid for many different argument values, JAX traces it on abstract values that represent sets of possible inputs. There are multiple different levels of abstraction, and different transformations use different abstraction levels.
By default, jit traces your code on the ShapedArray abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value ShapedArray((3,), np.float32), we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.
But there's a tradeoff here: if we trace a Python function on a ShapedArray((), np.float32) that isn't committed to a specific concrete value, when we hit a line like if x < 3, the expression x < 3 evaluates to an abstract ShapedArray((), np.bool_) that represents the set {True, False}. When Python attempts to coerce that to a concrete True or False, we get an error: we don't know which branch to take, and can't continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace.
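One way to see exactly what gets traced is jax.make_jaxpr, which prints the sequence of primitive operations recorded on abstract values rather than concrete data. A quick sketch:

```python
import jax
import jax.numpy as np

# The jaxpr records operations on abstract ShapedArray values, not concrete data:
print(jax.make_jaxpr(lambda x: 2 * x + 1)(np.float32(3.)))
```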
The good news is that you can control this tradeoff yourself. By having jit trace on more refined abstract values, you can relax the traceability constraints. For example, using the static_argnums argument to jit, we can specify to trace on concrete values of some arguments. Here's that example function again:
[25]:
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x
f = jit(f, static_argnums=(0,))
print(f(2.))
12.0
Hereâ€™s another example, this time involving a loop:
[26]:
def f(x, n):
  y = 0.
  for i in range(n):
    y = y + x[i]
  return y
f = jit(f, static_argnums=(1,))
f(np.array([2., 3., 4.]), 2)
[26]:
DeviceArray(5., dtype=float32)
In effect, the loop gets statically unrolled. JAX can also trace at higher levels of abstraction, like Unshaped, but that's not currently the default for any transformation.
⚠️ functions with argument-value dependent shapes
These control-flow issues also come up in a more subtle way: numerical functions we want to jit can't specialize the shapes of internal arrays on argument values (specializing on argument shapes is ok). As a trivial example, let's make a function whose output happens to depend on the input variable length.
[27]:
def example_fun(length, val):
  return np.ones((length,)) * val
# un-jit'd works fine
print(example_fun(5, 4))

bad_example_jit = jit(example_fun)
# this will fail:
try:
  print(bad_example_jit(10, 4))
except Exception as e:
  print("Exception {}".format(e))
# static_argnums tells JAX to recompile on changes at these argument positions:
good_example_jit = jit(example_fun, static_argnums=(0,))
# first compile
print(good_example_jit(10, 4))
# recompiles
print(good_example_jit(5, 4))
[4. 4. 4. 4. 4.]
Exception Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True):JaxprTrace(level=1/1)>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]
[4. 4. 4. 4. 4.]
static_argnums can be handy if length in our example rarely changes, but it would be disastrous if it changed a lot!
Lastly, if your function has global side-effects, JAX's tracer can cause weird things to happen. A common gotcha is trying to print arrays inside jit'd functions:
[28]:
@jit
def f(x):
  print(x)
  y = 2 * x
  print(y)
  return y
f(2)
Traced<ShapedArray(int32[], weak_type=True):JaxprTrace(level=1/1)>
Traced<ShapedArray(int32[]):JaxprTrace(level=1/1)>
[28]:
DeviceArray(4, dtype=int32)
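If you just want to inspect values, one side-effect-free workaround is to return the intermediates you would have printed; here's a sketch using a hypothetical variant f2 of the function above:

```python
from jax import jit

@jit
def f2(x):
    y = 2 * x
    return y, x  # expose the values you wanted to print as extra outputs

y, x_seen = f2(2)
print(x_seen)  # a concrete value, printed outside the trace
print(y)
```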
Structured control flow primitives¶
There are more options for control flow in JAX. Say you want to avoid recompilations but still want to use control flow that's traceable, and that avoids unrolling large loops. Then you can use these 4 structured control flow primitives:
- lax.cond: differentiable
- lax.while_loop: fwd-mode-differentiable
- lax.fori_loop: fwd-mode-differentiable
- lax.scan: differentiable
cond¶
python equivalent:
def cond(pred, true_operand, true_fun, false_operand, false_fun):
if pred:
return true_fun(true_operand)
else:
return false_fun(false_operand)
[29]:
from jax import lax
operand = np.array([0.])
lax.cond(True, operand, lambda x: x+1, operand, lambda x: x-1)
# --> array([1.], dtype=float32)
lax.cond(False, operand, lambda x: x+1, operand, lambda x: x-1)
# --> array([-1.], dtype=float32)
[29]:
DeviceArray([-1.], dtype=float32)
while_loop¶
python equivalent:
def while_loop(cond_fun, body_fun, init_val):
val = init_val
while cond_fun(val):
val = body_fun(val)
return val
[30]:
init_val = 0
cond_fun = lambda x: x<10
body_fun = lambda x: x+1
lax.while_loop(cond_fun, body_fun, init_val)
# --> array(10, dtype=int32)
[30]:
DeviceArray(10, dtype=int32)
fori_loop¶
python equivalent:
def fori_loop(start, stop, body_fun, init_val):
val = init_val
for i in range(start, stop):
val = body_fun(i, val)
return val
[31]:
init_val = 0
start = 0
stop = 10
body_fun = lambda i,x: x+i
lax.fori_loop(start, stop, body_fun, init_val)
# --> array(45, dtype=int32)
[31]:
DeviceArray(45, dtype=int32)
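scan¶
python equivalent:
def scan(f, init, xs):
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)
A cumulative-sum sketch (cumsum_step is our own example function, not part of JAX):

```python
from jax import lax
import jax.numpy as np

def cumsum_step(carry, x):
    carry = carry + x
    return carry, carry  # (new carry, per-step output)

total, partial_sums = lax.scan(cumsum_step, 0., np.arange(1., 5.))
print(total)         # --> 10.0
print(partial_sums)  # --> [ 1.  3.  6. 10.]
```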
Summary¶
\(\ast\) = argument-value-independent loop condition - unrolls the loop
🔪 Convolutions¶
JAX and XLA offer the very general N-dimensional conv_general_dilated function, but it's not very obvious how to use it. We'll give some examples of the common use-cases. There are also the convenience functions lax.conv and lax.conv_with_general_padding for the most common kinds of convolutions.
A survey of the family of convolutional operators, a guide to convolutional arithmetic, is highly recommended reading!
Let's define a simple diagonal edge kernel:
[32]:
# 2D kernel - HWIO layout
kernel = onp.zeros((3, 3, 3, 3), dtype=np.float32)
kernel += onp.array([[1, 1, 0],
                     [1, 0,-1],
                     [0,-1,-1]])[:, :, onp.newaxis, onp.newaxis]
print("Edge Conv kernel:")
plt.imshow(kernel[:, :, 0, 0]);
Edge Conv kernel:
And we'll make a simple synthetic image:
[33]:
# NHWC layout
img = onp.zeros((1, 200, 198, 3), dtype=np.float32)
for k in range(3):
x = 30 + 60*k
y = 20 + 60*k
img[0, x:x+10, y:y+10, k] = 1.0
print("Original Image:")
plt.imshow(img[0]);
Original Image:
lax.conv and lax.conv_with_general_padding¶
These are the simple convenience functions for convolutions.
⚠️ The convenience lax.conv and lax.conv_with_general_padding helper functions assume NCHW images and OIHW kernels.
[34]:
out = lax.conv(np.transpose(img,[0,3,1,2]), # lhs = NCHW image tensor
np.transpose(kernel,[3,2,0,1]), # rhs = OIHW conv kernel tensor
(1, 1), # window strides
'SAME') # padding mode
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(onp.array(out)[0,0,:,:]);
out shape: (1, 3, 200, 198)
First output channel:
[35]:
out = lax.conv_with_general_padding(
np.transpose(img,[0,3,1,2]), # lhs = NCHW image tensor
np.transpose(kernel,[2,3,0,1]), # rhs = IOHW conv kernel tensor
(1, 1), # window strides
((2,2),(2,2)), # general padding 2x2
(1,1), # lhs/image dilation
(1,1)) # rhs/kernel dilation
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(onp.array(out)[0,0,:,:]);
out shape: (1, 3, 202, 200)
First output channel:
Dimension Numbers define dimensional layout for conv_general_dilated¶
The important argument is the 3-tuple of axis layout arguments: (Input Layout, Kernel Layout, Output Layout)
- N: batch dimension
- H: spatial height
- W: spatial width
- C: channel dimension
- I: kernel input channel dimension
- O: kernel output channel dimension
⚠️ To demonstrate the flexibility of dimension numbers we choose a NHWC image and HWIO kernel convention for lax.conv_general_dilated below.
[36]:
dn = lax.conv_dimension_numbers(img.shape, # only ndim matters, not shape
kernel.shape, # only ndim matters, not shape
('NHWC', 'HWIO', 'NHWC')) # the important bit
print(dn)
ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))
SAME padding, no stride, no dilationÂ¶
[37]:
out = lax.conv_general_dilated(img, # lhs = image tensor
kernel, # rhs = conv kernel tensor
(1,1), # window strides
'SAME', # padding mode
(1,1), # lhs/image dilation
(1,1), # rhs/kernel dilation
dn) # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(onp.array(out)[0,:,:,0]);
out shape: (1, 200, 198, 3)
First output channel:
VALID padding, no stride, no dilationÂ¶
[38]:
out = lax.conv_general_dilated(img, # lhs = image tensor
kernel, # rhs = conv kernel tensor
(1,1), # window strides
'VALID', # padding mode
(1,1), # lhs/image dilation
(1,1), # rhs/kernel dilation
dn) # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, "DIFFERENT from above!")
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(onp.array(out)[0,:,:,0]);
out shape: (1, 198, 196, 3) DIFFERENT from above!
First output channel:
SAME padding, 2,2 stride, no dilationÂ¶
[39]:
out = lax.conv_general_dilated(img, # lhs = image tensor
kernel, # rhs = conv kernel tensor
(2,2), # window strides
'SAME', # padding mode
(1,1), # lhs/image dilation
(1,1), # rhs/kernel dilation
dn) # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, " <- half the size of above")
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(onp.array(out)[0,:,:,0]);
out shape: (1, 100, 99, 3) <- half the size of above
First output channel:
VALID padding, no stride, rhs kernel dilation ~ Atrous convolution (excessive to illustrate)¶
[40]:
out = lax.conv_general_dilated(img, # lhs = image tensor
kernel, # rhs = conv kernel tensor
(1,1), # window strides
'VALID', # padding mode
(1,1), # lhs/image dilation
(12,12), # rhs/kernel dilation
dn) # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape)
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(onp.array(out)[0,:,:,0]);
out shape: (1, 176, 174, 3)
First output channel:
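The output shape above follows from standard VALID-padding arithmetic with a dilated kernel; this small helper (ours, not part of JAX) shows the calculation:

```python
def valid_out_size(in_size, k, stride=1, rhs_dilation=1):
    # effective kernel size after dilation, then standard VALID arithmetic
    effective_k = (k - 1) * rhs_dilation + 1
    return (in_size - effective_k) // stride + 1

print(valid_out_size(200, 3, rhs_dilation=12))  # --> 176
print(valid_out_size(198, 3, rhs_dilation=12))  # --> 174
```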
VALID padding, no stride, lhs=input dilation ~ Transposed Convolution¶
[41]:
out = lax.conv_general_dilated(img, # lhs = image tensor
kernel, # rhs = conv kernel tensor
(1,1), # window strides
((0, 0), (0, 0)), # padding mode
(2,2), # lhs/image dilation
(1,1), # rhs/kernel dilation
dn) # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, "<- larger than original!")
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(onp.array(out)[0,:,:,0]);
out shape: (1, 397, 393, 3) <- larger than original!
First output channel:
We can use the last to, for instance, implement transposed convolutions:
[42]:
# The following is equivalent to tensorflow:
# N,H,W,C = img.shape
# out = tf.nn.conv2d_transpose(img, kernel, (N,2*H,2*W,C), (1,2,2,1))
# transposed conv = 180deg kernel rotation plus LHS dilation
# rotate kernel 180deg:
kernel_rot = np.rot90(np.rot90(kernel, axes=(0,1)), axes=(0,1))
# need a custom output padding:
padding = ((2, 1), (2, 1))
out = lax.conv_general_dilated(img, # lhs = image tensor
kernel_rot, # rhs = conv kernel tensor
(1,1), # window strides
padding, # padding mode
(2,2), # lhs/image dilation
(1,1), # rhs/kernel dilation
dn) # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, "<- transposed_conv")
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(onp.array(out)[0,:,:,0]);
out shape: (1, 400, 396, 3) <- transposed_conv
First output channel:
1D Convolutions¶
You aren't limited to 2D convolutions; a simple 1D demo is below:
[43]:
# 1D kernel - WIO layout
kernel = onp.array([[[1, 0, -1], [-1, 0, 1]],
                    [[1, 1, 1], [-1, -1, -1]]],
                    dtype=np.float32).transpose([2,1,0])
# 1D data - NWC layout
data = onp.zeros((1, 200, 2), dtype=np.float32)
for i in range(2):
for k in range(2):
x = 35*i + 30 + 60*k
data[0, x:x+30, k] = 1.0
print("in shapes:", data.shape, kernel.shape)
plt.figure(figsize=(10,5))
plt.plot(data[0]);
dn = lax.conv_dimension_numbers(data.shape, kernel.shape,
('NWC', 'WIO', 'NWC'))
print(dn)
out = lax.conv_general_dilated(data, # lhs = image tensor
kernel, # rhs = conv kernel tensor
(1,), # window strides
'SAME', # padding mode
(1,), # lhs/image dilation
(1,), # rhs/kernel dilation
dn) # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape)
plt.figure(figsize=(10,5))
plt.plot(out[0]);
in shapes: (1, 200, 2) (3, 2, 2)
ConvDimensionNumbers(lhs_spec=(0, 2, 1), rhs_spec=(2, 1, 0), out_spec=(0, 2, 1))
out shape: (1, 200, 2)
3D ConvolutionsÂ¶
[44]:
# Random 3D kernel - HWDIO layout
kernel = onp.array([
[[0, 0, 0], [0, 1, 0], [0, 0, 0]],
[[0, 1, 0], [1, 0, 1], [0, 1, 0]],
[[0, 0, 0], [0, 1, 0], [0, 0, 0]]],
dtype=np.float32)[:, :, :, onp.newaxis, onp.newaxis]
# 3D data - NHWDC layout
data = onp.zeros((1, 30, 30, 30, 1), dtype=np.float32)
x, y, z = onp.mgrid[0:1:30j, 0:1:30j, 0:1:30j]
data += (onp.sin(2*x*np.pi)*onp.cos(2*y*np.pi)*onp.cos(2*z*np.pi))[None,:,:,:,None]
print("in shapes:", data.shape, kernel.shape)
dn = lax.conv_dimension_numbers(data.shape, kernel.shape,
('NHWDC', 'HWDIO', 'NHWDC'))
print(dn)
out = lax.conv_general_dilated(data, # lhs = image tensor
kernel, # rhs = conv kernel tensor
(1,1,1), # window strides
'SAME', # padding mode
(1,1,1), # lhs/image dilation
(1,1,1), # rhs/kernel dilation
dn) # dimension_numbers
print("out shape: ", out.shape)
# Make some simple 3d density plots:
from mpl_toolkits.mplot3d import Axes3D
def make_alpha(cmap):
my_cmap = cmap(np.arange(cmap.N))
my_cmap[:,-1] = np.linspace(0, 1, cmap.N)**3
return mpl.colors.ListedColormap(my_cmap)
my_cmap = make_alpha(plt.cm.viridis)
fig = plt.figure()
ax = fig.gca(projection='3d')
ax.scatter(x.ravel(), y.ravel(), z.ravel(), c=data.ravel(), cmap=my_cmap)
ax.axis('off')
ax.set_title('input')
fig = plt.figure()
ax = fig.gca(projection='3d')
ax.scatter(x.ravel(), y.ravel(), z.ravel(), c=out.ravel(), cmap=my_cmap)
ax.axis('off')
ax.set_title('3D conv output');
in shapes: (1, 30, 30, 30, 1) (3, 3, 3, 1, 1)
ConvDimensionNumbers(lhs_spec=(0, 4, 1, 2, 3), rhs_spec=(4, 3, 0, 1, 2), out_spec=(0, 4, 1, 2, 3))
out shape: (1, 30, 30, 30, 1)
🔪 NaNs¶
Debugging NaNs¶
If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by:
- setting the JAX_DEBUG_NANS=True environment variable;
- adding from jax.config import config and config.update("jax_debug_nans", True) near the top of your main file;
- adding from jax.config import config and config.parse_flags_with_absl() to your main file, then setting the option using a command-line flag like --jax_debug_nans=True.
This will cause computations to error out immediately on production of a NaN.
⚠️ You shouldn't have the NaN-checker on if you're not debugging, as it can introduce lots of device-host round-trips and performance regressions!
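For a feel of what the checker catches, here is a small sketch (nan_maker is our own example function) of a NaN produced silently under the default settings:

```python
import jax.numpy as np

def nan_maker(x):
    return np.log(x) * 0.  # log(0.) is -inf, and -inf * 0. is nan

print(nan_maker(0.))  # nan -- propagates silently unless jax_debug_nans is on
```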
Double (64-bit) precision¶
At the moment, JAX by default enforces single-precision numbers to mitigate the NumPy API's tendency to aggressively promote operands to double. This is the desired behavior for many machine-learning applications, but it may catch you by surprise!
[45]:
x = random.uniform(random.PRNGKey(0), (1000,), dtype=np.float64)
x.dtype
[45]:
dtype('float32')
To use double-precision numbers, you need to set the jax_enable_x64 configuration variable at startup.
There are a few ways to do this:
1. You can enable 64-bit mode by setting the environment variable JAX_ENABLE_X64=True.
2. You can manually set the jax_enable_x64 configuration flag at startup:
# again, this only works on startup!
from jax.config import config
config.update("jax_enable_x64", True)
3. You can parse command-line flags with absl.app.run(main):
from jax.config import config
config.config_with_absl()
4. If you want JAX to run absl parsing for you, i.e. you don't want to do absl.app.run(main), you can instead use:
from jax.config import config
if __name__ == '__main__':
  # calls config.config_with_absl() *and* runs absl parsing
  config.parse_flags_with_absl()
Note that #2-#4 work for any of JAX's configuration options.
We can then confirm that x64 mode is enabled:
[46]:
from jax import numpy as np, random
x = random.uniform(random.PRNGKey(0), (1000,), dtype=np.float64)
x.dtype # --> dtype('float64')
[46]:
dtype('float32')
Caveats¶
⚠️ XLA doesn't support 64-bit convolutions on all backends!
Fin.¶
If something's not covered here that has caused you weeping and gnashing of teeth, please let us know and we'll extend these introductory advisories!
Custom derivative rules for JAX-transformable Python functions¶
mattjj@ Mar 19 2020, last updated Mar 30 2020
There are two ways to define differentiation rules in JAX:
1. using jax.custom_jvp and jax.custom_vjp to define custom differentiation rules for Python functions that are already JAX-transformable; and
2. defining new core.Primitive instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems.
This notebook is about #1. To read instead about #2, see the notebook on adding primitives.
For an introduction to JAX's automatic differentiation API, see The Autodiff Cookbook. This notebook assumes some familiarity with jax.jvp and jax.grad, and the mathematical meaning of JVPs and VJPs.
TL;DR¶
Custom JVPs with jax.custom_jvp¶
[1]:
import jax.numpy as np
from jax import custom_jvp
@custom_jvp
def f(x, y):
return np.sin(x) * y
@f.defjvp
def f_jvp(primals, tangents):
x, y = primals
x_dot, y_dot = tangents
primal_out = f(x, y)
tangent_out = np.cos(x) * x_dot * y + np.sin(x) * y_dot
return primal_out, tangent_out
[2]:
from jax import jvp, grad
print(f(2., 3.))
y, y_dot = jvp(f, (2., 3.), (1., 0.))
print(y)
print(y_dot)
print(grad(f)(2., 3.))
2.7278922
2.7278922
-1.2484405
-1.2484405
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/testdocs/lib/python3.7/sitepackages/jax/lib/xla_bridge.py:123: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
[3]:
# Equivalent alternative using the defjvps convenience wrapper
@custom_jvp
def f(x, y):
return np.sin(x) * y
f.defjvps(lambda x_dot, primal_out, x, y: np.cos(x) * x_dot * y,
lambda y_dot, primal_out, x, y: np.sin(x) * y_dot)
[4]:
print(f(2., 3.))
y, y_dot = jvp(f, (2., 3.), (1., 0.))
print(y)
print(y_dot)
print(grad(f)(2., 3.))
2.7278922
2.7278922
-1.2484405
-1.2484405
Custom VJPs with jax.custom_vjp¶
[5]:
from jax import custom_vjp
@custom_vjp
def f(x, y):
return np.sin(x) * y
def f_fwd(x, y):
return f(x, y), (np.cos(x), np.sin(x), y)
def f_bwd(res, g):
cos_x, sin_x, y = res
return (cos_x * g * y, sin_x * g)
f.defvjp(f_fwd, f_bwd)
[6]:
print(grad(f)(2., 3.))
-1.2484405
Example problems¶
To get an idea of what problems jax.custom_jvp and jax.custom_vjp are meant to solve, let's go over a few examples. A more thorough introduction to the jax.custom_jvp and jax.custom_vjp APIs is in the next section.
One application of jax.custom_jvp is to improve the numerical stability of differentiation.
Say we want to write a function called log1pexp, which computes \(x \mapsto \log ( 1 + e^x )\). We can write that using jax.numpy:
[7]:
import jax.numpy as np
def log1pexp(x):
return np.log(1. + np.exp(x))
log1pexp(3.)
[7]:
DeviceArray(3.0485873, dtype=float32)
Since it's written in terms of jax.numpy, it's JAX-transformable:
[8]:
from jax import jit, grad, vmap
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(np.arange(3.)))
3.0485873
0.95257413
[0.5 0.7310586 0.8807971]
But there's a numerical stability problem lurking here:
[9]:
print(grad(log1pexp)(100.))
nan
That doesn't seem right! After all, the derivative of \(x \mapsto \log (1 + e^x)\) is \(x \mapsto \frac{e^x}{1 + e^x}\), and so for large values of \(x\) we'd expect the value to be about 1.
We can get a bit more insight into what's going on by looking at the jaxpr for the gradient computation:
[10]:
from jax import make_jaxpr
make_jaxpr(grad(log1pexp))(100.)
[10]:
{ lambda ; a.
let b = exp a
c = add b 1.0
d = div 1.0 c
e = mul d b
in (e,) }
Stepping through how the jaxpr would be evaluated, we can see that the last line would involve multiplying values that floating point math will round to 0 and \(\infty\), respectively, which is never a good idea. That is, we're effectively evaluating lambda x: (1 / (1 + np.exp(x))) * np.exp(x) for large x, which effectively turns into 0. * np.inf.
Instead of generating such large and small values, hoping for a cancellation that floats can't always provide, we'd rather just express the derivative function as a more numerically stable program. In particular, we can write a program that more closely evaluates the equal mathematical expression \(1 - \frac{1}{1 + e^x}\), with no cancellation in sight.
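We can sanity-check the two expressions numerically in plain NumPy (a sketch, using float32 as JAX does by default):

```python
import numpy as onp

x = onp.float32(100.)
with onp.errstate(over='ignore'):  # exp(100.) overflows float32 to inf
    naive  = (1. / (1. + onp.exp(x))) * onp.exp(x)  # 0. * inf --> nan
    stable = 1. - 1. / (1. + onp.exp(x))            # 1. - 0.  --> 1.
print(naive, stable)  # --> nan 1.0
```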
This problem is interesting because even though our definition of log1pexp could already be JAX-differentiated (and transformed with jit, vmap, ...), we're not happy with the result of applying standard autodiff rules to the primitives comprising log1pexp and composing the result. Instead, we'd like to specify how the whole function log1pexp should be differentiated, as a unit, and thus arrange those exponentials better.
This is one application of custom derivative rules for Python functions that are already JAX-transformable: specifying how a composite function should be differentiated, while still using its original Python definition for other transformations (like jit, vmap, ...).
Here's a solution using jax.custom_jvp:
[11]:
from jax import custom_jvp
@custom_jvp
def log1pexp(x):
return np.log(1. + np.exp(x))
@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
x, = primals
x_dot, = tangents
ans = log1pexp(x)
ans_dot = (1 - 1/(1 + np.exp(x))) * x_dot
return ans, ans_dot
[12]:
print(grad(log1pexp)(100.))
1.0
[13]:
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(np.arange(3.)))
3.0485873
0.95257413
[0.5 0.7310586 0.8807971]
Here's a defjvps convenience wrapper to express the same thing:
[14]:
@custom_jvp
def log1pexp(x):
return np.log(1. + np.exp(x))
log1pexp.defjvps(lambda t, ans, x: (1 - 1/(1 + np.exp(x))) * t)
[15]:
print(grad(log1pexp)(100.))
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(np.arange(3.)))
1.0
3.0485873
0.95257413
[0.5 0.7310586 0.8807971]
A related application is to enforce a differentiation convention, perhaps at a boundary.
Consider the function \(f : \mathbb{R}_+ \to \mathbb{R}_+\) with \(f(x) = \frac{x}{1 + \sqrt{x}}\), where we take \(\mathbb{R}_+ = [0, \infty)\). We might implement \(f\) as a program like this:
[16]:
def f(x):
return x / (1 + np.sqrt(x))
As a mathematical function on \(\mathbb{R}\) (the full real line), \(f\) is not differentiable at zero (because the limit defining the derivative doesn't exist from the left). Correspondingly, autodiff produces a nan value:
[17]:
print(grad(f)(0.))
nan
But mathematically if we think of \(f\) as a function on \(\mathbb{R}_+\) then it is differentiable at 0 [Rudin's Principles of Mathematical Analysis Definition 5.1, or Tao's Analysis I 3rd ed. Definition 10.1.1 and Example 10.1.6]. Alternatively, we might say as a convention we want to consider the directional derivative from the right. So there is a sensible value for the Python function grad(f) to return at 0.0, namely 1.0. By default, JAX's machinery for differentiation assumes all functions are defined over \(\mathbb{R}\) and thus doesn't produce 1.0 here.
We can use a custom JVP rule! In particular, we can define the JVP rule in terms of the derivative function \(x \mapsto \frac{\sqrt{x} + 2}{2(\sqrt{x} + 1)^2}\) on \(\mathbb{R}_+\),
[18]:
@custom_jvp
def f(x):
return x / (1 + np.sqrt(x))
@f.defjvp
def f_jvp(primals, tangents):
x, = primals
x_dot, = tangents
ans = f(x)
ans_dot = ((np.sqrt(x) + 2) / (2 * (np.sqrt(x) + 1)**2)) * x_dot
return ans, ans_dot
[19]:
print(grad(f)(0.))
1.0
Here's the convenience wrapper version:
[20]:
@custom_jvp
def f(x):
return x / (1 + np.sqrt(x))
f.defjvps(lambda t, ans, x: ((np.sqrt(x) + 2) / (2 * (np.sqrt(x) + 1)**2)) * t)
[21]:
print(grad(f)(0.))
1.0
While in some cases we want to express a mathematical differentiation computation, in other cases we may even want to take a step away from mathematics to adjust the computation autodiff performs. One canonical example is reverse-mode gradient clipping.
For gradient clipping, we can use np.clip together with a jax.custom_vjp reverse-mode-only rule:
[22]:
from functools import partial
from jax import custom_vjp
@partial(custom_vjp, nondiff_argnums=(0, 1))
def clip_gradient(lo, hi, x):
return x # identity function
def clip_gradient_fwd(lo, hi, x):
return x, None # no residual values to save
def clip_gradient_bwd(lo, hi, _, g):
return (np.clip(g, lo, hi),)
clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
[23]:
import matplotlib.pyplot as plt
from jax import vmap
t = np.linspace(0, 10, 1000)
plt.plot(np.sin(t))
plt.plot(vmap(grad(np.sin))(t))
[23]:
[<matplotlib.lines.Line2D at 0x7f95881caeb8>]
[24]:
def clip_sin(x):
x = clip_gradient(-0.75, 0.75, x)
return np.sin(x)
plt.plot(clip_sin(t))
plt.plot(vmap(grad(clip_sin))(t))
[24]:
[<matplotlib.lines.Line2D at 0x7f957cc4b710>]
Another application that is motivated by development workflow rather than numerics is to set a pdb debugger trace in the backward pass of reverse-mode autodiff.
When trying to track down the source of a nan runtime error, or just to examine carefully the cotangent (gradient) values being propagated, it can be useful to insert a debugger at a point in the backward pass that corresponds to a specific point in the primal computation. You can do that with jax.custom_vjp.
We'll defer an example until the next section.
This example gets pretty deep in the mathematical weeds!
Another application for jax.custom_vjp is reverse-mode differentiation of functions that are JAX-transformable (by jit, vmap, ...) but not efficiently JAX-differentiable for some reason, perhaps because they involve lax.while_loop. (It's not possible to produce an XLA HLO program that efficiently computes the reverse-mode derivative of an XLA HLO While loop because that would require a program with unbounded memory use, which isn't possible to express in XLA HLO, at least without side-effecting interactions through infeed/outfeed.)
For example, consider this fixed_point routine which computes a fixed point by iteratively applying a function in a while_loop:
[25]:
from jax.lax import while_loop
def fixed_point(f, a, x_guess):
def cond_fun(carry):
x_prev, x = carry
return np.abs(x_prev - x) > 1e-6
def body_fun(carry):
_, x = carry
return x, f(a, x)
_, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
return x_star
This is an iterative procedure for numerically solving the equation \(x = f(a, x)\) for \(x\), by iterating \(x_{t+1} = f(a, x_t)\) until \(x_{t+1}\) is sufficiently close to \(x_t\). The result \(x^*\) depends on the parameters \(a\), and so we can think of there being a function \(a \mapsto x^*(a)\) that is implicitly defined by the equation \(x = f(a, x)\).
We can use fixed_point to run iterative procedures to convergence, for example running Newton's method to calculate square roots while only executing adds, multiplies, and divides:
[26]:
def newton_sqrt(a):
update = lambda a, x: 0.5 * (x + a / x)
return fixed_point(update, a, a)
[27]:
print(newton_sqrt(2.))
1.4142135
We can vmap or jit the function as well:
[28]:
print(jit(vmap(newton_sqrt))(np.array([1., 2., 3., 4.])))
[1. 1.4142135 1.7320509 2. ]
We can't apply reverse-mode automatic differentiation because of the while_loop, but it turns out we wouldn't want to anyway: instead of differentiating through the implementation of fixed_point and all its iterations, we can exploit the mathematical structure to do something that is much more memory-efficient (and FLOP-efficient in this case, too!). We can instead use the implicit function theorem [Prop A.25 of Bertsekas's Nonlinear Programming, 2nd ed.], which guarantees (under some conditions) the existence of the mathematical objects we're about to use. In essence, we linearize at the solution and solve those linear equations iteratively to compute the derivatives we want.
Consider again the equation \(x = f(a, x)\) and the function \(x^*\). We want to evaluate vector-Jacobian products like \(v^\mathsf{T} \mapsto v^\mathsf{T} \partial x^*(a_0)\).
At least in an open neighborhood around the point \(a_0\) at which we want to differentiate, let's assume that the equation \(x^*(a) = f(a, x^*(a))\) holds for all \(a\). Since the two sides are equal as functions of \(a\), their derivatives must be equal as well, so let's differentiate both sides:
\(\qquad \partial x^*(a) = \partial_0 f(a, x^*(a)) + \partial_1 f(a, x^*(a)) \partial x^*(a)\).
Setting \(A = \partial_1 f(a_0, x^*(a_0))\) and \(B = \partial_0 f(a_0, x^*(a_0))\), we can write the quantity weâ€™re after more simply as
\(\qquad \partial x^*(a_0) = B + A \partial x^*(a_0)\),
or, by rearranging,
\(\qquad \partial x^*(a_0) = (I - A)^{-1} B\).
That means we can evaluate vectorJacobian products like
\(\qquad v^\mathsf{T} \partial x^*(a_0) = v^\mathsf{T} (I - A)^{-1} B = w^\mathsf{T} B\),
where \(w^\mathsf{T} = v^\mathsf{T} (I - A)^{-1}\), or equivalently \(w^\mathsf{T} = v^\mathsf{T} + w^\mathsf{T} A\), or equivalently \(w^\mathsf{T}\) is the fixed point of the map \(u^\mathsf{T} \mapsto v^\mathsf{T} + u^\mathsf{T} A\). That last characterization gives us a way to write the VJP for fixed_point in terms of a call to fixed_point! Moreover, after expanding \(A\) and \(B\) back out, we can see we need only to evaluate VJPs of \(f\) at \((a_0, x^*(a_0))\).
Here's the upshot:
[29]:
from jax import vjp
@partial(custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, a, x_guess):
def cond_fun(carry):
x_prev, x = carry
return np.abs(x_prev - x) > 1e-6
def body_fun(carry):
_, x = carry
return x, f(a, x)
_, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
return x_star
def fixed_point_fwd(f, a, x_init):
x_star = fixed_point(f, a, x_init)
return x_star, (a, x_star)
def fixed_point_rev(f, res, x_star_bar):
a, x_star = res
_, vjp_a = vjp(lambda a: f(a, x_star), a)
a_bar, = vjp_a(fixed_point(partial(rev_iter, f),
(a, x_star, x_star_bar),
x_star_bar))
return a_bar, np.zeros_like(x_star)
def rev_iter(f, packed, u):
a, x_star, x_star_bar = packed
_, vjp_x = vjp(lambda x: f(a, x), x_star)
return x_star_bar + vjp_x(u)[0]
fixed_point.defvjp(fixed_point_fwd, fixed_point_rev)
[30]:
print(newton_sqrt(2.))
1.4142135
[31]:
print(grad(newton_sqrt)(2.))
print(grad(grad(newton_sqrt))(2.))
0.35355338
-0.088388346
We can check our answers by differentiating np.sqrt, which uses a totally different implementation:
[32]:
print(grad(np.sqrt)(2.))
print(grad(grad(np.sqrt))(2.))
0.35355338
-0.08838835
A limitation to this approach is that the argument f can't close over any values involved in differentiation. That is, you might notice that we kept the parameter a explicit in the argument list of fixed_point. While other JAX mechanisms can handle closed-over transformation-traced values in the arguments to higher-order functions (as is done for the control flow primitives like lax.cond, lax.scan, and lax.while_loop itself), jax.custom_vjp used as above cannot. A fixed_point routine that used a bit more of JAX's internals could have a more convenient and robust API.
Basic usage of jax.custom_jvp and jax.custom_vjp APIs¶
Here's a canonical basic example of using jax.custom_jvp:
[33]:
from jax import custom_jvp
import jax.numpy as np
# f :: a -> b
@custom_jvp
def f(x):
return np.sin(x)
# f_jvp :: (a, T a) -> (b, T b)
def f_jvp(primals, tangents):
x, = primals
t, = tangents
return f(x), np.cos(x) * t
f.defjvp(f_jvp)
[34]:
from jax import jvp
print(f(3.))
y, y_dot = jvp(f, (3.,), (1.,))
print(y)
print(y_dot)
0.14112
0.14112
-0.9899925
In words, we start with a primal function f that takes inputs of type a and produces outputs of type b. We associate with it a JVP rule function f_jvp that takes a pair of inputs representing the primal inputs of type a and the corresponding tangent inputs of type T a, and produces a pair of outputs representing the primal outputs of type b and tangent outputs of type T b. The tangent outputs should be a linear function of the tangent inputs.
You can also use f.defjvp
as a decorator, as in
@custom_jvp
def f(x):
...
@f.defjvp
def f_jvp(primals, tangents):
...
Even though we defined only a JVP rule and no VJP rule, we can use both forward- and reverse-mode differentiation on f. JAX will automatically transpose the linear computation on tangent values from our custom JVP rule, computing the VJP as efficiently as if we had written the rule by hand:
[35]:
from jax import grad
print(grad(f)(3.))
print(grad(grad(f))(3.))
-0.9899925
-0.14112
For automatic transposition to work, the JVP rule's output tangents must be linear as a function of the input tangents. Otherwise a transposition error is raised.
Multiple arguments work like this:
[36]:
@custom_jvp
def f(x, y):
return x ** 2 * y
@f.defjvp
def f_jvp(primals, tangents):
x, y = primals
x_dot, y_dot = tangents
primal_out = f(x, y)
tangent_out = 2 * x * y * x_dot + x ** 2 * y_dot
return primal_out, tangent_out
[37]:
print(grad(f)(2., 3.))
12.0
The defjvps convenience wrapper lets us define a JVP for each argument separately, and the results are computed separately then summed:
[38]:
@custom_jvp
def f(x):
return np.sin(x)
f.defjvps(lambda t, ans, x: np.cos(x) * t)
[39]:
print(grad(f)(3.))
-0.9899925
Here's the defjvps convenience wrapper version:
[40]:
@custom_jvp
def f(x, y):
return x ** 2 * y
f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot,
lambda y_dot, primal_out, x, y: x ** 2 * y_dot)
[41]:
print(grad(f)(2., 3.))
print(grad(f, 0)(2., 3.)) # same as above
print(grad(f, 1)(2., 3.))
12.0
12.0
4.0
As a shorthand, with defjvps you can pass a None value to indicate that the JVP for a particular argument is zero:
[42]:
@custom_jvp
def f(x, y):
return x ** 2 * y
f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot,
None)
[43]:
print(grad(f)(2., 3.))
print(grad(f, 0)(2., 3.)) # same as above
print(grad(f, 1)(2., 3.))
12.0
12.0
0.0
Calling a jax.custom_jvp function with keyword arguments, or writing a jax.custom_jvp function definition with default arguments, are both allowed so long as they can be unambiguously mapped to positional arguments based on the function signature retrieved by the standard library inspect.signature mechanism.
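For instance, here's a sketch with a hypothetical function g that has a default argument (mirroring the f(x, y) example above):

```python
from jax import custom_jvp, grad

@custom_jvp
def g(x, y=3.):  # default argument, resolved via inspect.signature
    return x ** 2 * y

@g.defjvp
def g_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    return g(x, y), 2 * x * y * x_dot + x ** 2 * y_dot

print(grad(g)(2.))        # default y=3.
print(grad(g)(2., y=1.))  # keyword argument
```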
When you're not performing differentiation, the function f is called just as if it weren't decorated by jax.custom_jvp:
[44]:
@custom_jvp
def f(x):
print('called f!') # a harmless side-effect
return np.sin(x)
@f.defjvp
def f_jvp(primals, tangents):
print('called f_jvp!') # a harmless side-effect
x, = primals
t, = tangents
return f(x), np.cos(x) * t
[45]:
from jax import vmap, jit
print(f(3.))
called f!
0.14112
[46]:
print(vmap(f)(np.arange(3.)))
print(jit(f)(3.))
called f!
[0. 0.84147096 0.9092974 ]
called f!
0.14112
The custom JVP rule is invoked during differentiation, whether forward or reverse:
[47]:
y, y_dot = jvp(f, (3.,), (1.,))
print(y_dot)
called f_jvp!
called f!
-0.9899925
[48]:
print(grad(f)(3.))
called f_jvp!
called f!
-0.9899925
Notice that f_jvp calls f to compute the primal outputs. In the context of higher-order differentiation, each application of a differentiation transform will use the custom JVP rule if and only if the rule calls the original f to compute the primal outputs. (This represents a kind of fundamental tradeoff, where we can't make use of intermediate values from the evaluation of f in our rule and also have the rule apply in all orders of higher-order differentiation.)
[49]:
grad(grad(f))(3.)
called f_jvp!
called f_jvp!
called f!
[49]:
DeviceArray(0.14112, dtype=float32)
You can use Python control flow with jax.custom_jvp
:
[50]:
@custom_jvp
def f(x):
    if x > 0:
        return np.sin(x)
    else:
        return np.cos(x)

@f.defjvp
def f_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    ans = f(x)
    if x > 0:
        return ans, 2 * x_dot
    else:
        return ans, 3 * x_dot
[51]:
print(grad(f)(1.))
print(grad(f)(-1.))
2.0
3.0
While jax.custom_jvp suffices for controlling both forward- and, via JAX's automatic transposition, reverse-mode differentiation behavior, in some cases we may want to directly control a VJP rule, for example in the latter two example problems presented above. We can do that with jax.custom_vjp
:
[52]:
from jax import custom_vjp
import jax.numpy as np
# f :: a -> b
@custom_vjp
def f(x):
    return np.sin(x)

# f_fwd :: a -> (b, c)
def f_fwd(x):
    return f(x), np.cos(x)

# f_bwd :: (c, CT b) -> CT a
def f_bwd(cos_x, y_bar):
    return (cos_x * y_bar,)

f.defvjp(f_fwd, f_bwd)
[53]:
from jax import grad
print(f(3.))
print(grad(f)(3.))
0.14112
-0.9899925
In words, we again start with a primal function f that takes inputs of type a and produces outputs of type b. We associate with it two functions, f_fwd and f_bwd, which describe how to perform the forward and backward passes of reverse-mode autodiff, respectively.
The function f_fwd
describes the forward pass, not only the primal computation but also what values to save for use on the backward pass. Its input signature is just like that of the primal function f
, in that it takes a primal input of type a
. But as output it produces a pair, where the first element is the primal output b
and the second element is any "residual" data of type c
to be stored for use by the backward pass. (This second output is analogous to PyTorch's save_for_backward mechanism; see https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html.)
The function f_bwd
describes the backward pass. It takes two inputs, where the first is the residual data of type c
produced by f_fwd
and the second is the output cotangents of type CT b
corresponding to the output of the primal function. It produces an output of type CT a
representing the cotangents corresponding to the input of the primal function. In particular, the output of f_bwd
must be a sequence (e.g. a tuple) of length equal to the number of arguments to the
primal function.
So multiple arguments work like this:
[54]:
from jax import custom_vjp
@custom_vjp
def f(x, y):
    return np.sin(x) * y

def f_fwd(x, y):
    return f(x, y), (np.cos(x), np.sin(x), y)

def f_bwd(res, g):
    cos_x, sin_x, y = res
    return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)
[55]:
print(grad(f)(2., 3.))
-1.2484405
Calling a jax.custom_vjp
function with keyword arguments, or writing a jax.custom_vjp
function definition with default arguments, are both allowed so long as they can be unambiguously mapped to positional arguments based on the function signature retrieved by the standard library inspect.signature
mechanism.
As with jax.custom_jvp, the custom VJP rule comprising f_fwd and f_bwd is not invoked if differentiation is not applied. If the function is evaluated, or transformed with jit, vmap, or other non-differentiation transformations, then only f is called.
[56]:
@custom_vjp
def f(x):
    print("called f!")
    return np.sin(x)

def f_fwd(x):
    print("called f_fwd!")
    return f(x), np.cos(x)

def f_bwd(cos_x, y_bar):
    print("called f_bwd!")
    return (cos_x * y_bar,)

f.defvjp(f_fwd, f_bwd)
[57]:
print(f(3.))
called f!
0.14112
[58]:
print(grad(f)(3.))
called f_fwd!
called f!
called f_bwd!
-0.9899925
[59]:
from jax import vjp
y, f_vjp = vjp(f, 3.)
print(y)
called f_fwd!
called f!
0.14112
[60]:
print(f_vjp(1.))
called f_bwd!
(DeviceArray(-0.9899925, dtype=float32),)
Forward-mode autodiff cannot be used on the jax.custom_vjp function and will raise an error:
[61]:
from jax import jvp
try:
    jvp(f, (3.,), (1.,))
except TypeError as e:
    print('ERROR! {}'.format(e))
called f_fwd!
called f!
ERROR! can't apply forward-mode autodiff (jvp) to a custom_vjp function.
If you want to use both forward- and reverse-mode, use jax.custom_jvp
instead.
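For contrast, a minimal sketch of the same sine function written with jax.custom_jvp, which supports jax.jvp directly and jax.grad via automatic transposition:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
    return jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
    x, = primals
    t, = tangents
    return f(x), jnp.cos(x) * t

y, y_dot = jax.jvp(f, (3.0,), (1.0,))  # forward mode works
g = jax.grad(f)(3.0)                   # reverse mode works too
```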
We can use jax.custom_vjp
together with pdb
to insert a debugger trace in the backward pass:
[62]:
import pdb
@custom_vjp
def debug(x):
    return x  # acts like identity

def debug_fwd(x):
    return x, x

def debug_bwd(x, g):
    import pdb; pdb.set_trace()
    return g

debug.defvjp(debug_fwd, debug_bwd)
[63]:
def foo(x):
    y = x ** 2
    y = debug(y)  # insert pdb in corresponding backward pass step
    return np.sin(y)

grad(foo)(3.)
> <ipython-input-11-3b19a2dc1abf7>(12)debug_bwd()
-> return g
(Pdb) p x
DeviceArray(9., dtype=float32)
(Pdb) p g
DeviceArray(-0.91113025, dtype=float32)
(Pdb) q
More features and details¶
You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any pytrees are permissible, so long as their structures are consistent according to the type constraints.
Here's a contrived example with jax.custom_jvp
:
[64]:
from collections import namedtuple
Point = namedtuple("Point", ["x", "y"])
@custom_jvp
def f(pt):
    x, y = pt.x, pt.y
    return {'a': x ** 2,
            'b': (np.sin(x), np.cos(y))}

@f.defjvp
def f_jvp(primals, tangents):
    pt, = primals
    pt_dot, = tangents
    ans = f(pt)
    ans_dot = {'a': 2 * pt.x * pt_dot.x,
               'b': (np.cos(pt.x) * pt_dot.x, -np.sin(pt.y) * pt_dot.y)}
    return ans, ans_dot

def fun(pt):
    dct = f(pt)
    return dct['a'] + dct['b'][0]
[65]:
pt = Point(1., 2.)
print(f(pt))
{'a': 1.0, 'b': (DeviceArray(0.84147096, dtype=float32), DeviceArray(-0.41614684, dtype=float32))}
[66]:
print(grad(fun)(pt))
Point(x=DeviceArray(2.5403023, dtype=float32), y=array(0.))
And an analogous contrived example with jax.custom_vjp
:
[67]:
@custom_vjp
def f(pt):
    x, y = pt.x, pt.y
    return {'a': x ** 2,
            'b': (np.sin(x), np.cos(y))}

def f_fwd(pt):
    return f(pt), pt

def f_bwd(pt, g):
    a_bar, (b0_bar, b1_bar) = g['a'], g['b']
    x_bar = 2 * pt.x * a_bar + np.cos(pt.x) * b0_bar
    y_bar = -np.sin(pt.y) * b1_bar
    return (Point(x_bar, y_bar),)

f.defvjp(f_fwd, f_bwd)

def fun(pt):
    dct = f(pt)
    return dct['a'] + dct['b'][0]
[68]:
pt = Point(1., 2.)
print(f(pt))
{'a': 1.0, 'b': (DeviceArray(0.84147096, dtype=float32), DeviceArray(-0.41614684, dtype=float32))}
[69]:
print(grad(fun)(pt))
Point(x=DeviceArray(2.5403023, dtype=float32), y=DeviceArray(0., dtype=float32))
Some use cases, like the final example problem, call for non-differentiable arguments to be passed to functions with custom differentiation rules, and for those arguments to also be passed to the rules themselves. In the case of fixed_point, the function argument f was such a non-differentiable argument. A similar situation arises with jax.experimental.odeint.
jax.custom_jvp with nondiff_argnums¶
Use the optional nondiff_argnums parameter to jax.custom_jvp to indicate arguments like these. Here's an example with jax.custom_jvp:
[70]:
from functools import partial
@partial(custom_jvp, nondiff_argnums=(0,))
def app(f, x):
    return f(x)

@app.defjvp
def app_jvp(f, primals, tangents):
    x, = primals
    x_dot, = tangents
    return f(x), 2. * x_dot
[71]:
print(app(lambda x: x ** 3, 3.))
27.0
[72]:
print(grad(app, 1)(lambda x: x ** 3, 3.))
2.0
Notice the gotcha here: no matter where in the argument list these parameters appear, they're placed at the start of the signature of the corresponding JVP rule. Here's another example:
[73]:
@partial(custom_jvp, nondiff_argnums=(0, 2))
def app2(f, x, g):
    return f(g(x))

@app2.defjvp
def app2_jvp(f, g, primals, tangents):
    x, = primals
    x_dot, = tangents
    return f(g(x)), 3. * x_dot
[74]:
print(app2(lambda x: x ** 3, 3., lambda y: 5 * y))
3375.0
[75]:
print(grad(app2, 1)(lambda x: x ** 3, 3., lambda y: 5 * y))
3.0
jax.custom_vjp with nondiff_argnums¶
A similar option exists for jax.custom_vjp, and similarly the convention is that the non-differentiable arguments are passed as the first arguments to the rules, no matter where they appear in the original function's signature. Here's an example:
[76]:
@partial(custom_vjp, nondiff_argnums=(0,))
def app(f, x):
    return f(x)

def app_fwd(f, x):
    return f(x), x

def app_bwd(f, x, g):
    return (5 * g,)

app.defvjp(app_fwd, app_bwd)
[77]:
print(app(lambda x: x ** 2, 4.))
16.0
[78]:
print(grad(app, 1)(lambda x: x ** 2, 4.))
5.0
See clip_gradient
and fixed_point
above for other usage examples.
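For reference, a gradient-clipping identity in the spirit of that clip_gradient example can be sketched with jax.custom_vjp and nondiff_argnums like this (our paraphrase, not necessarily the exact code from above):

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def clip_gradient(lo, hi, x):
    return x  # identity on the forward pass

def clip_gradient_fwd(lo, hi, x):
    return x, None  # no residual data needed

def clip_gradient_bwd(lo, hi, _, g):
    return (jnp.clip(g, lo, hi),)  # clip the cotangent on the way back

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

# The value is unchanged, but the gradient is clipped to [-0.5, 0.5]:
clipped = jax.grad(lambda x: clip_gradient(-0.5, 0.5, jnp.sin(x)))(0.0)
print(clipped)  # 0.5 rather than cos(0) = 1.0
```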
JAX pytrees¶
Date: October 2019
This is primarily JAX internal documentation; end-users should not need to understand this to use JAX, except when registering new user-defined container types with JAX. Some of these details may change.
Python has a lot of container data types (list, tuple, dict, namedtuple, etc.), and users sometimes define their own. To keep the JAX internals simpler while supporting lots of container types, we canonicalize nested containers into flat lists of numeric or array types at the api.py
boundary (and also in control flow primitives). That way grad
, jit
, vmap
etc., can handle user functions that accept and return these containers, while all the other parts of the system can operate on
functions that only take (multiple) array arguments and always return a flat list of arrays.
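For example (a small sketch of ours, with hypothetical names loss and params), grad accepts a dict of parameters and returns gradients with the same dict structure:

```python
import jax

def loss(params, x):
    # a tiny linear model; `params` is an ordinary Python dict
    return (params['w'] * x + params['b']) ** 2

params = {'w': 2.0, 'b': 1.0}
grads = jax.grad(loss)(params, 3.0)
print(grads)  # a dict with the same structure as params
```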
We refer to a recursive structured value whose leaves are basic types as a pytree
. When JAX flattens a pytree it will produce a list of leaves and a treedef
object that encodes the structure of the original value. The treedef
can then be used to construct a matching structured value after transforming the leaves. Pytrees are tree-like, rather than DAG-like or graph-like, in that we handle them assuming referential transparency and that they can't contain reference cycles.
Here is a simple example:
[1]:
from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node
from jax import numpy as np

# The structured value to be transformed
value_structured = [1., (2., 3.)]

# The leaves in value_flat correspond to the `*` markers in value_tree
value_flat, value_tree = tree_flatten(value_structured)
print("value_flat={}\nvalue_tree={}".format(value_flat, value_tree))

# Transform the flat value list using an element-wise numeric transformer
transformed_flat = list(map(lambda v: v * 2., value_flat))
print("transformed_flat={}".format(transformed_flat))

# Reconstruct the structured output, using the original structure
transformed_structured = tree_unflatten(value_tree, transformed_flat)
print("transformed_structured={}".format(transformed_structured))
value_flat=[1.0, 2.0, 3.0]
value_tree=PyTreeDef(list, [*,PyTreeDef(tuple, [*,*])])
transformed_flat=[2.0, 4.0, 6.0]
transformed_structured=[2.0, (4.0, 6.0)]
Pytree containers can be lists, tuples, dicts, namedtuples, None, or OrderedDicts. Other types of values, including numeric and ndarray values, are treated as leaves:
[2]:
from collections import namedtuple
Point = namedtuple('Point', ['x', 'y'])

example_containers = [
    (1., [2., 3.]),
    (1., {'b': 2., 'a': 3.}),
    1.,
    None,
    np.zeros(2),
    Point(1., 2.)
]

def show_example(structured):
    flat, tree = tree_flatten(structured)
    unflattened = tree_unflatten(tree, flat)
    print("structured={}\n  flat={}\n  tree={}\n  unflattened={}".format(
        structured, flat, tree, unflattened))

for structured in example_containers:
    show_example(structured)
structured=(1.0, [2.0, 3.0])
flat=[1.0, 2.0, 3.0]
tree=PyTreeDef(tuple, [*,PyTreeDef(list, [*,*])])
unflattened=(1.0, [2.0, 3.0])
structured=(1.0, {'b': 2.0, 'a': 3.0})
flat=[1.0, 3.0, 2.0]
tree=PyTreeDef(tuple, [*,PyTreeDef(dict[['a', 'b']], [*,*])])
unflattened=(1.0, {'a': 3.0, 'b': 2.0})
structured=1.0
flat=[1.0]
tree=*
unflattened=1.0
structured=None
flat=[]
tree=PyTreeDef(None, [])
unflattened=None
structured=[0. 0.]
flat=[DeviceArray([0., 0.], dtype=float32)]
tree=*
unflattened=[0. 0.]
structured=Point(x=1.0, y=2.0)
flat=[1.0, 2.0]
tree=PyTreeDef(namedtuple[<class '__main__.Point'>], [*,*])
unflattened=Point(x=1.0, y=2.0)
Pytrees are extensible¶
By default, any part of a structured value that is not recognized as an internal pytree node is treated as a leaf (and such containers cannot be passed to JAX-traceable functions):
[3]:
class Special(object):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return "Special(x={}, y={})".format(self.x, self.y)

show_example(Special(1., 2.))
structured=Special(x=1.0, y=2.0)
flat=[Special(x=1.0, y=2.0)]
tree=*
unflattened=Special(x=1.0, y=2.0)
The set of Python types that are considered internal pytree nodes is extensible, through a global registry of types. Values of registered types are traversed recursively:
[4]:
class RegisteredSpecial(Special):
    def __repr__(self):
        return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)

def special_flatten(v):
    """Specifies a flattening recipe.

    Params:
      v: the value of registered type to flatten.
    Returns:
      a pair of an iterable with the children to be flattened recursively,
      and some opaque auxiliary data to pass back to the unflattening recipe.
      The auxiliary data is stored in the treedef for use during unflattening.
      The auxiliary data could be used, e.g., for dictionary keys.
    """
    children = (v.x, v.y)
    aux_data = None
    return (children, aux_data)

def special_unflatten(aux_data, children):
    """Specifies an unflattening recipe.

    Params:
      aux_data: the opaque data that was specified during flattening of the
        current treedef.
      children: the unflattened children
    Returns:
      a reconstructed object of the registered type, using the specified
      children and auxiliary data.
    """
    return RegisteredSpecial(*children)

# Global registration
register_pytree_node(
    RegisteredSpecial,
    special_flatten,    # tell JAX what are the children nodes
    special_unflatten   # tell JAX how to pack them back into a RegisteredSpecial
)

show_example(RegisteredSpecial(1., 2.))
structured=RegisteredSpecial(x=1.0, y=2.0)
flat=[1.0, 2.0]
tree=PyTreeDef(<class '__main__.RegisteredSpecial'>[None], [*,*])
unflattened=RegisteredSpecial(x=1.0, y=2.0)
JAX sometimes needs to compare treedefs for equality. Therefore care must be taken to ensure that the auxiliary data specified in the flattening recipe supports a meaningful equality comparison.
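For example, treedefs compare equal exactly when the container structure matches, regardless of the leaf values; a quick check:

```python
import jax

td1 = jax.tree_util.tree_structure([1., (2., 3.)])
td2 = jax.tree_util.tree_structure([4., (5., 6.)])
td3 = jax.tree_util.tree_structure((1., (2., 3.)))

print(td1 == td2)  # True: same structure, different leaf values
print(td1 == td3)  # False: a list is not a tuple
```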
The full set of functions for operating on pytrees is in the jax.tree_util module.
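Among the most useful of those is tree_map, which applies a function to every leaf and rebuilds the original structure:

```python
import jax

# Double every leaf of a nested container in one call
doubled = jax.tree_util.tree_map(lambda v: v * 2., [1., (2., 3.)])
print(doubled)  # [2.0, (4.0, 6.0)]
```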
How JAX primitives work¶
necula@google.com, October 2019.
JAX implements certain transformations of Python functions, e.g., jit
, grad
, vmap
, or pmap
. The Python functions to be transformed must be JAX-traceable, which means that as the Python function executes the only operations it applies to the data are either inspections of data attributes such as shape or type, or special operations called JAX primitives. In particular, a JAX-traceable function is sometimes invoked by JAX with abstract arguments. An example of a JAX abstract value
is ShapedArray(float32[2,2])
, which captures the type and the shape of values, but not the concrete data values. JAX primitives know how to operate on both concrete data values and on the JAX abstract values.
The JAX-transformed functions must themselves be JAX-traceable functions, to ensure that these transformations can be composed, e.g., jit(jacfwd(grad(f)))
.
There are predefined JAX primitives corresponding to most XLA operations, e.g., add, matmul, sin, cos, indexing. JAX comes with an implementation of numpy functions in terms of JAX primitives, which means that Python programs using JAX's implementation of numpy are JAX-traceable and therefore transformable. Other libraries can be made JAX-traceable by implementing them in terms of JAX primitives.
The set of JAX primitives is extensible. Instead of reimplementing a function in terms of predefined JAX primitives, one can define a new primitive that encapsulates the behavior of the function.
The goal of this document is to explain the interface that a JAX primitive must support in order to allow JAX to perform all its transformations.
Consider that we want to add to JAX support for a multiply-add function with three arguments, defined mathematically as "multiply_add(x, y, z) = x * y + z". This function operates on 3 identically-shaped tensors of floating point values and performs the operations pointwise.
Using existing primitives¶
The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other functions that are themselves written using JAX primitives, e.g., those defined in the jax.lax
module:
[1]:
from jax import lax
from jax import api
def multiply_add_lax(x, y, z):
    """Implementation of multiply-add using the jax.lax primitives."""
    return lax.add(lax.mul(x, y), z)

def square_add_lax(a, b):
    """A square-add function using the newly defined multiply-add."""
    return multiply_add_lax(a, a, b)

print("square_add_lax = ", square_add_lax(2., 10.))
# Differentiate w.r.t. the first argument
print("grad(square_add_lax) = ", api.grad(square_add_lax, argnums=0)(2.0, 10.))
square_add_lax = 14.0
grad(square_add_lax) = 4.0
In order to understand how JAX uses these primitives internally, we add some helpers for tracing function calls.
[2]:
#@title Helper functions (execute this cell)
import functools
import traceback

_indentation = 0

def _trace(msg=None):
    """Print a message at current indentation."""
    if msg is not None:
        print(" " * _indentation + msg)

def _trace_indent(msg=None):
    """Print a message and then indent the rest."""
    global _indentation
    _trace(msg)
    _indentation = 1 + _indentation

def _trace_unindent(msg=None):
    """Unindent then print a message."""
    global _indentation
    _indentation = _indentation - 1
    _trace(msg)

def trace(name):
    """A decorator for functions to trace arguments and results."""
    def trace_func(func):  # pylint: disable=missing-docstring
        def pp(v):
            """Print certain values more succinctly."""
            vtype = str(type(v))
            if "jax.lib.xla_bridge._JaxComputationBuilder" in vtype:
                return "<JaxComputationBuilder>"
            elif "jaxlib.xla_extension.XlaOp" in vtype:
                return "<XlaOp at 0x{:x}>".format(id(v))
            elif ("partial_eval.JaxprTracer" in vtype or
                  "batching.BatchTracer" in vtype or
                  "ad.JVPTracer" in vtype):
                return "Traced<{}>".format(v.aval)
            elif isinstance(v, tuple):
                return "({})".format(pp_values(v))
            else:
                return str(v)

        def pp_values(args):
            return ", ".join([pp(arg) for arg in args])

        @functools.wraps(func)
        def func_wrapper(*args):
            _trace_indent("call {}({})".format(name, pp_values(args)))
            res = func(*args)
            _trace_unindent("< {} = {}".format(name, pp(res)))
            return res

        return func_wrapper
    return trace_func

class expectNotImplementedError(object):
    """Context manager to check for NotImplementedError."""
    def __enter__(self): pass
    def __exit__(self, type, value, tb):
        global _indentation
        _indentation = 0
        if type is NotImplementedError:
            print("\nFound expected exception:")
            traceback.print_exc(limit=3)
            return True
        elif type is None:  # No exception
            assert False, "Expected NotImplementedError"
        else:
            return False
Instead of using jax.lax
primitives directly, we can use other functions that are already written in terms of those primitives, such as those in jax.numpy
:
[3]:
import jax.numpy as jnp
import numpy as onp

@trace("multiply_add_numpy")
def multiply_add_numpy(x, y, z):
    return jnp.add(jnp.multiply(x, y), z)

@trace("square_add_numpy")
def square_add_numpy(a, b):
    return multiply_add_numpy(a, a, b)

print("\nNormal evaluation:")
print("square_add_numpy = ", square_add_numpy(2., 10.))
print("\nGradient evaluation:")
print("grad(square_add_numpy) = ", api.grad(square_add_numpy)(2.0, 10.))
Normal evaluation:
call square_add_numpy(2.0, 10.0)
call multiply_add_numpy(2.0, 2.0, 10.0)
< multiply_add_numpy = 14.0
< square_add_numpy = 14.0
square_add_numpy = 14.0
Gradient evaluation:
call square_add_numpy(Traced<ConcreteArray(2.0, weak_type=True)>, 10.0)
call multiply_add_numpy(Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ConcreteArray(2.0, weak_type=True)>, 10.0)
< multiply_add_numpy = Traced<ConcreteArray(14.0)>
< square_add_numpy = Traced<ConcreteArray(14.0)>
grad(square_add_numpy) = 4.0
Notice that in the process of computing grad, JAX invokes square_add_numpy and multiply_add_numpy with special arguments ConcreteArray(...) (described further below in this colab). It is important to remember that a JAX-traceable function must be able to operate not only on concrete arguments but also on special abstract arguments that JAX may use to abstract the function execution.
The JAX traceability property is satisfied as long as the function is written in terms of JAX primitives.
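Conversely, a function that forces a concrete value out of its argument is not JAX-traceable; under jit the abstract tracer cannot be concretized, as this sketch of ours shows (the exact exception type may vary across JAX versions):

```python
import jax
import jax.numpy as jnp

def not_traceable(x):
    # float() demands a concrete value, but under jit x is an abstract tracer
    return jnp.sin(float(x))

try:
    jax.jit(not_traceable)(2.0)
    raised = False
except Exception:
    raised = True
print(raised)  # True: tracing fails on the float() call
```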
Defining new JAX primitives¶
The right way to add support for multiply-add is in terms of existing JAX primitives, as shown above. However, in order to demonstrate how JAX primitives work let us pretend that we want to add a new primitive to JAX for the multiply-add functionality.
[4]:
from jax import core
multiply_add_p = core.Primitive("multiply_add") # Create the primitive
@trace("multiply_add_prim")
def multiply_add_prim(x, y, z):
    """The JAX-traceable way to use the JAX primitive.

    Note that the traced arguments must be passed as positional arguments
    to `bind`.
    """
    return multiply_add_p.bind(x, y, z)

@trace("square_add_prim")
def square_add_prim(a, b):
    """A square-add function implemented using the new JAX primitive."""
    return multiply_add_prim(a, a, b)
If we try to call the newly defined functions we get an error, because we have not yet told JAX anything about the semantics of the new primitive.
[5]:
with expectNotImplementedError():
    square_add_prim(2., 10.)
call square_add_prim(2.0, 10.0)
call multiply_add_prim(2.0, 2.0, 10.0)
Found expected exception:
Traceback (most recent call last):
  File "<ipython-input-5-acee329b29d0>", line 2, in <module>
    square_add_prim(2., 10.)
  File "<ipython-input-2-756fd2c18f40>", line 48, in func_wrapper
    res = func(*args)
  File "<ipython-input-4-c5402c1795f0>", line 16, in square_add_prim
    return multiply_add_prim(a, a, b)
NotImplementedError: Evaluation rule for 'multiply_add' not implemented
Primal evaluation rules¶
[6]:
@trace("multiply_add_impl")
def multiply_add_impl(x, y, z):
    """Concrete implementation of the primitive.

    This function does not need to be JAX traceable.
    Args:
      x, y, z: the concrete arguments of the primitive. Will only be called with
        concrete values.
    Returns:
      the concrete result of the primitive.
    """
    # Note that we can use the original numpy, which is not JAX traceable
    return onp.add(onp.multiply(x, y), z)

# Now we register the primal implementation with JAX
multiply_add_p.def_impl(multiply_add_impl)
[6]:
<function __main__.multiply_add_impl(x, y, z)>
[7]:
assert square_add_prim(2., 10.) == 14.
call square_add_prim(2.0, 10.0)
call multiply_add_prim(2.0, 2.0, 10.0)
call multiply_add_impl(2.0, 2.0, 10.0)
< multiply_add_impl = 14.0
< multiply_add_prim = 14.0
< square_add_prim = 14.0
JIT¶
If we now try to use jit
we get a NotImplementedError
:
[8]:
with expectNotImplementedError():
    api.jit(square_add_prim)(2., 10.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
Found expected exception:
Traceback (most recent call last):
  File "<ipython-input-8-d4853f4fcae2>", line 2, in <module>
    api.jit(square_add_prim)(2., 10.)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/test-docs/lib/python3.7/site-packages/jax/api.py", line 151, in f_jitted
    name=flat_fun.__name__)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/test-docs/lib/python3.7/site-packages/jax/core.py", line 951, in call_bind
    outs = primitive.impl(f, *args, **params)
NotImplementedError: Abstract evaluation for 'multiply_add' not implemented
Abstract evaluation rules¶
In order to JIT the function, and for other transformations as well, JAX first evaluates it abstractly using only the shape and type of the arguments. This abstract evaluation serves multiple purposes:
- Gets the sequence of JAX primitives that are used in the computation. This sequence will be compiled.
- Computes the shape and type of all vectors and operations used in the computation.
For example, the abstraction of a vector with 3 elements may be ShapedArray(float32[3])
, or ConcreteArray([1., 2., 3.])
. In the latter case, JAX uses the actual concrete value wrapped as an abstract value.
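You can observe abstract evaluation in isolation with jax.eval_shape, which propagates only shapes and dtypes through a function without doing any FLOPs (a small sketch of ours):

```python
import jax
import jax.numpy as jnp

def outer(x):
    return jnp.dot(x, x.T)

# Only shapes/dtypes flow through; no data is materialized
out = jax.eval_shape(outer, jax.ShapeDtypeStruct((3, 4), jnp.float32))
print(out.shape, out.dtype)  # (3, 3) float32
```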
[9]:
from jax import abstract_arrays

@trace("multiply_add_abstract_eval")
def multiply_add_abstract_eval(xs, ys, zs):
    """Abstract evaluation of the primitive.

    This function does not need to be JAX traceable. It will be invoked with
    abstractions of the actual arguments.
    Args:
      xs, ys, zs: abstractions of the arguments.
    Result:
      a ShapedArray for the result of the primitive.
    """
    assert xs.shape == ys.shape
    assert xs.shape == zs.shape
    return abstract_arrays.ShapedArray(xs.shape, xs.dtype)

# Now we register the abstract evaluation with JAX
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)
[9]:
<function __main__.multiply_add_abstract_eval(xs, ys, zs)>
If we re-attempt to JIT, we see how the abstract evaluation proceeds, but we get another error, about the missing XLA compilation rule:
[10]:
with expectNotImplementedError():
    api.jit(square_add_prim)(2., 10.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
< multiply_add_abstract_eval = ShapedArray(float32[])
< multiply_add_prim = Traced<ShapedArray(float32[])>
< square_add_prim = Traced<ShapedArray(float32[])>
Found expected exception:
Traceback (most recent call last):
  File "<ipython-input-10-d4853f4fcae2>", line 2, in <module>
    api.jit(square_add_prim)(2., 10.)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/test-docs/lib/python3.7/site-packages/jax/api.py", line 151, in f_jitted
    name=flat_fun.__name__)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/test-docs/lib/python3.7/site-packages/jax/core.py", line 951, in call_bind
    outs = primitive.impl(f, *args, **params)
NotImplementedError: XLA translation rule for primitive 'multiply_add' not found
XLA Compilation rules¶
JAX compilation works by compiling each primitive into a graph of XLA operations.
This is the biggest hurdle to adding new functionality to JAX, because the set of XLA operations is limited, and JAX already has predefined primitives for most of them. However, XLA includes a CustomCall operation that can be used to encapsulate arbitrary functionality defined using C++.
[11]:
@trace("multiply_add_xla_translation")
def multiply_add_xla_translation(c, xc, yc, zc):
    """The compilation to XLA of the primitive.

    Given an XlaBuilder and XlaOps for each argument, return the XlaOp for the
    result of the function.

    Does not need to be a JAX-traceable function.
    """
    return c.Add(c.Mul(xc, yc), zc)

# Now we register the XLA compilation rule with JAX
# TODO: for GPU? and TPU?
from jax import xla
xla.backend_specific_translations['cpu'][multiply_add_p] = multiply_add_xla_translation
Now we can JIT successfully. Notice below that JAX first evaluates the function abstractly, which triggers the multiply_add_abstract_eval function, and then compiles the set of primitives it has encountered, including multiply_add. At this point JAX invokes multiply_add_xla_translation.
[12]:
assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14.
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
< multiply_add_abstract_eval = ShapedArray(float32[])
< multiply_add_prim = Traced<ShapedArray(float32[])>
< square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_xla_translation(<JaxComputationBuilder>, <XlaOp at 0x7fd5087f0f80>, <XlaOp at 0x7fd5087f0f80>, <XlaOp at 0x7fd5087f0fb8>)
< multiply_add_xla_translation = <XlaOp at 0x7fd511061ce0>
Below is another use of jit where we compile only with respect to the first argument. Notice how the second argument to square_add_prim is concrete, which leads to the third argument of multiply_add_abstract_eval being a ConcreteArray. We see that multiply_add_abstract_eval may be used with both ShapedArray and ConcreteArray.
[13]:
assert api.jit(lambda x, y: square_add_prim(x, y),
               static_argnums=1)(2., 10.) == 14.
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, 10.0)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, 10.0)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ConcreteArray(10.0, weak_type=True))
< multiply_add_abstract_eval = ShapedArray(float32[])
< multiply_add_prim = Traced<ShapedArray(float32[])>
< square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_xla_translation(<JaxComputationBuilder>, <XlaOp at 0x7fd50874d618>, <XlaOp at 0x7fd50874d618>, <XlaOp at 0x7fd50874d6f8>)
< multiply_add_xla_translation = <XlaOp at 0x7fd50874d5e0>
Forward differentiation¶
JAX implements forward differentiation in the form of a Jacobian-vector product (see the JAX autodiff cookbook).
If we attempt now to compute the jvp
function we get an error because we have not yet told JAX how to differentiate the multiply_add
primitive.
[14]:
# The second argument `(2., 10.)` are the argument values
# where we evaluate the Jacobian, and the third `(1., 1.)`
# are the values of the tangents for the arguments.
with expectNotImplementedError():
    api.jvp(square_add_prim, (2., 10.), (1., 1.))
call square_add_prim(Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ConcreteArray(10.0, weak_type=True)>)
call multiply_add_prim(Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ConcreteArray(10.0, weak_type=True)>)
Found expected exception:
Traceback (most recent call last):
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/test-docs/lib/python3.7/site-packages/jax/interpreters/ad.py", line 297, in process_primitive
    jvp = primitive_jvps[primitive]
KeyError: multiply_add
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "<ipython-input-14-f07eb564206f>", line 5, in <module>
    api.jvp(square_add_prim, (2., 10.), (1., 1.))
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/test-docs/lib/python3.7/site-packages/jax/api.py", line 1160, in jvp
    return _jvp(lu.wrap_init(fun), primals, tangents)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/test-docs/lib/python3.7/site-packages/jax/api.py", line 1183, in _jvp
    out_primals, out_tangents = ad.jvp(flat_fun).call_wrapped(ps_flat, ts_flat)
NotImplementedError: Forward-mode differentiation rule for 'multiply_add' not implemented
[15]:
from jax import ad
@trace("multiply_add_value_and_jvp")
def multiply_add_value_and_jvp(arg_values, arg_tangents):
    """Evaluates the primal output and the tangents (Jacobian-vector product).

    Given values of the arguments and perturbation of the arguments (tangents),
    compute the output of the primitive and the perturbation of the output.

    This method must be JAX-traceable. JAX may invoke it with abstract values
    for the arguments and tangents.

    Args:
      arg_values: a tuple of arguments
      arg_tangents: a tuple with the tangents of the arguments. The tuple has
        the same length as the arg_values. Some of the tangents may also be the
        special value ad.zero to specify a zero tangent.
    Returns:
      a pair of the primal output and the tangent.
    """
    x, y, z = arg_values
    xt, yt, zt = arg_tangents
    _trace("Primal evaluation:")
    # Now we have a JAX-traceable computation of the output.
    # Normally, we can use the multiply-add primitive itself to compute the primal output.
    primal_out = multiply_add_prim(x, y, z)

    _trace("Tangent evaluation:")
    # We must use a JAX-traceable way to compute the tangent. It turns out that
    # the output tangent can be computed as (xt * y + x * yt + zt),
    # which we can implement in a JAX-traceable way using the same "multiply_add_prim" primitive.

    # We do need to deal specially with Zero. Here we just turn it into a
    # proper tensor of 0s (of the same shape as 'x').
    # An alternative would be to check for Zero and perform algebraic
    # simplification of the output tangent computation.
    def make_zero(tan):
        return lax.zeros_like_array(x) if tan is ad.zero else tan

    output_tangent = multiply_add_prim(make_zero(xt), y,
                                       multiply_add_prim(x, make_zero(yt), make_zero(zt)))
    return (primal_out, output_tangent)

# Register the forward differentiation rule with JAX
ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp
[16]:
# Tangent is: xt*y + x*yt + zt = 1.*2. + 2.*1. + 1. = 5.
assert api.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.)
call square_add_prim(Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ConcreteArray(10.0, weak_type=True)>)
call multiply_add_prim(Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ConcreteArray(10.0, weak_type=True)>)
call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (1.0, 1.0, 1.0))
Primal evaluation:
call multiply_add_prim(2.0, 2.0, 10.0)
call multiply_add_impl(2.0, 2.0, 10.0)
< multiply_add_impl = 14.0
< multiply_add_prim = 14.0
Tangent evaluation:
call multiply_add_prim(2.0, 1.0, 1.0)
call multiply_add_impl(2.0, 1.0, 1.0)
< multiply_add_impl = 3.0
< multiply_add_prim = 3.0
call multiply_add_prim(1.0, 2.0, 3.0)
call multiply_add_impl(1.0, 2.0, 3.0)
< multiply_add_impl = 5.0
< multiply_add_prim = 5.0
< multiply_add_value_and_jvp = (14.0, 5.0)
< multiply_add_prim = Traced<ConcreteArray(14.0)>
< square_add_prim = Traced<ConcreteArray(14.0)>
TO EXPLAIN:
- Why is JAX using ConcreteArray in square_add_prim? There is no abstract evaluation going on here.
- Not sure how to explain that multiply_add_prim is invoked with ConcreteValue, yet we do not call multiply_add_abstract_eval.
- I think it would be useful to show the jaxpr here.
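As a sanity check, the same primal and tangent values can be reproduced with stock jax.jvp on a plain-Python version of the function, with no custom primitive involved (a minimal sketch using only the public API):

```python
import jax

# Plain-Python equivalent of square_add_prim: x * x + y
def square_add(x, y):
    return x * x + y

# Tangent of x*x + y at (2., 10.) with tangents (1., 1.) is
# 2*x*xt + yt = 2*2*1 + 1 = 5, and the primal is 2*2 + 10 = 14.
primal, tangent = jax.jvp(square_add, (2.0, 10.0), (1.0, 1.0))
assert (float(primal), float(tangent)) == (14.0, 5.0)
```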
JIT of forward differentiation
We can apply JIT to the forward differentiation function:
[17]:
assert api.jit(lambda arg_values, arg_tangents:
                 api.jvp(square_add_prim, arg_values, arg_tangents))(
    (2., 10.), (1., 1.)) == (14., 5.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
call multiply_add_value_and_jvp((Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>))
Primal evaluation:
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
< multiply_add_abstract_eval = ShapedArray(float32[])
< multiply_add_prim = Traced<ShapedArray(float32[])>
Tangent evaluation:
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
< multiply_add_abstract_eval = ShapedArray(float32[])
< multiply_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[])>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
< multiply_add_abstract_eval = ShapedArray(float32[])
< multiply_add_prim = Traced<ShapedArray(float32[])>
< multiply_add_value_and_jvp = (Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
< multiply_add_prim = Traced<ShapedArray(float32[])>
< square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_xla_translation(<JaxComputationBuilder>, <XlaOp at 0x7fd50874d618>, <XlaOp at 0x7fd50874d618>, <XlaOp at 0x7fd50874d3b0>)
< multiply_add_xla_translation = <XlaOp at 0x7fd5087f02d0>
call multiply_add_xla_translation(<JaxComputationBuilder>, <XlaOp at 0x7fd50874d618>, <XlaOp at 0x7fd5087f0308>, <XlaOp at 0x7fd5087f0f48>)
< multiply_add_xla_translation = <XlaOp at 0x7fd5087f0fb8>
call multiply_add_xla_translation(<JaxComputationBuilder>, <XlaOp at 0x7fd5087f0308>, <XlaOp at 0x7fd50874d618>, <XlaOp at 0x7fd5087f0fb8>)
< multiply_add_xla_translation = <XlaOp at 0x7fd5087f0570>
Notice that first we evaluate multiply_add_value_and_jvp abstractly, which in turn evaluates abstractly both the primal and the tangent (a total of 3 invocations of the multiply_add primitive). Then we compile the 3 occurrences of the primitive.
Reverse differentiation
If we now attempt to use reverse differentiation, we see that JAX starts by using multiply_add_value_and_jvp to compute the forward differentiation for abstract values, but then runs into a NotImplementedError.
When computing the reverse differentiation, JAX first performs abstract evaluation of the forward differentiation code multiply_add_value_and_jvp to obtain a trace of primitives that compute the output tangent. Observe that JAX performs this abstract evaluation with concrete values for the differentiation point and abstract values for the tangents. Observe also that JAX uses the special abstract tangent value Zero for the tangent corresponding to the 3rd argument of multiply_add. This reflects the fact that we do not differentiate w.r.t. the 2nd argument of square_add_prim, which flows into the 3rd argument of multiply_add_prim.
Observe also that during the abstract evaluation of the tangent we pass the value 0.0 as the tangent for the 3rd argument. This is due to the use of the make_zero function in the definition of multiply_add_value_and_jvp.
[18]:
# This is reverse differentiation w.r.t. the first argument of square_add_prim
with expectNotImplementedError():
  api.grad(square_add_prim)(2., 10.)
call square_add_prim(Traced<ConcreteArray(2.0, weak_type=True)>, 10.0)
call multiply_add_prim(Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ConcreteArray(2.0, weak_type=True)>, 10.0)
call multiply_add_value_and_jvp((Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ConcreteArray(2.0, weak_type=True)>, 10.0), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero))
Primal evaluation:
call multiply_add_prim(Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ConcreteArray(2.0, weak_type=True)>, 10.0)
call multiply_add_impl(2.0, 2.0, 10.0)
< multiply_add_impl = 14.0
< multiply_add_prim = 14.0
Tangent evaluation:
call multiply_add_prim(Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, 0.0)
call multiply_add_abstract_eval(ConcreteArray(2.0, weak_type=True), ShapedArray(float32[], weak_type=True), ConcreteArray(0.0))
< multiply_add_abstract_eval = ShapedArray(float32[])
< multiply_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ShapedArray(float32[])>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ConcreteArray(2.0, weak_type=True), ShapedArray(float32[]))
< multiply_add_abstract_eval = ShapedArray(float32[])
< multiply_add_prim = Traced<ShapedArray(float32[])>
< multiply_add_value_and_jvp = (14.0, Traced<ShapedArray(float32[])>)
< multiply_add_prim = Traced<ConcreteArray(14.0)>
< square_add_prim = Traced<ConcreteArray(14.0)>
Found expected exception:
Traceback (most recent call last):
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/testdocs/lib/python3.7/site-packages/jax/interpreters/ad.py", line 277, in get_primitive_transpose
    return primitive_transposes[p]
KeyError: multiply_add
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "<ipython-input-18-a915b4bc91d2>", line 3, in <module>
    api.grad(square_add_prim)(2., 10.)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/testdocs/lib/python3.7/site-packages/jax/api.py", line 370, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/testdocs/lib/python3.7/site-packages/jax/api.py", line 436, in value_and_grad_f
    g = vjp_py(onp.ones((), dtype=dtype))
NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'multiply_add' not implemented
The above error is because there is a missing piece for JAX to be able to use the forward differentiation code to compute reverse differentiation.
Transposition
As explained above, when computing reverse differentiation JAX obtains a trace of primitives that compute the tangent using forward differentiation. Then, JAX interprets this trace abstractly backwards and for each primitive it applies a transposition rule.
To understand what is going on, consider for now the simpler function "f(x, y) = x * y + y". Assume we need to differentiate at the point (2., 4.). JAX will produce the following JVP tangent calculation of ft from the tangents xt and yt of the inputs:
a = xt * 4.
b = 2. * yt
c = a + b
ft = c + yt
By construction, the tangent calculation is always linear in the input tangents. The only nonlinear operator that may arise in the tangent calculation is multiplication, but then one of the operands is constant.
JAX will produce the reverse differentiation computation by processing the JVP computation backwards. For each operation in the tangent computation, it accumulates the cotangents of the variables used by the operation, using the cotangent of the result of the operation:
# Initialize cotangents of inputs and intermediate vars
xct = yct = act = bct = cct = 0.
# Initialize cotangent of the output
fct = 1.
# Process "ft = c + yt"
cct += fct
yct += fct
# Process "c = a + b"
act += cct
bct += cct
# Process "b = 2. * yt"
yct += 2. * bct
# Process "a = xt * 4."
xct += act * 4.
One can verify that this computation produces xct = 4. and yct = 3., which are the partial derivatives of the function f.
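The backward pass above is plain arithmetic, so we can execute it directly and check it against the hand-computed partials (a standalone sketch; no JAX required):

```python
# Backward processing of the tangent computation for f(x, y) = x * y + y
# at the point (2., 4.):  a = xt*4.;  b = 2.*yt;  c = a + b;  ft = c + yt
xct = yct = act = bct = cct = 0.0
fct = 1.0          # cotangent of the output
cct += fct         # process "ft = c + yt"
yct += fct
act += cct         # process "c = a + b"
bct += cct
yct += 2.0 * bct   # process "b = 2. * yt"
xct += act * 4.0   # process "a = xt * 4."
assert (xct, yct) == (4.0, 3.0)  # the partial derivatives of f at (2., 4.)
```

The same numbers fall out of jax.grad applied to f(x, y) = x * y + y at (2., 4.): df/dx = y = 4 and df/dy = x + 1 = 3.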
JAX knows for each primitive that may appear in a JVP calculation how to transpose it. Conceptually, if the primitive p(x, y, z) is linear in the arguments y and z for a constant value of x, e.g., p(x, y, z) = y*cy + z*cz, then the transposition of the primitive is:
p_transpose(out_ct, x, _, _) = (None, out_ct*cy, out_ct*cz)
Notice that p_transpose takes the cotangent of the output of the primitive and a value corresponding to each argument of the primitive. For the linear arguments, the transposition gets an undefined _ value, and for the other arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value None returned for the constant arguments.
In particular,
add_transpose(out_ct, _, _) = (out_ct, out_ct)
mult_transpose(out_ct, x, _) = (None, x * out_ct)
mult_transpose(out_ct, _, y) = (out_ct * y, None)
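In recent JAX versions this transposition machinery is exposed directly as jax.linear_transpose, which we can use to transpose a linear function by hand (a sketch under the assumption that this public API is available in your JAX version):

```python
import jax

# A function that is linear in its input, like a tangent computation:
# f(xt) = xt * 4. + 2. * xt, i.e. the linear map xt -> 6 * xt
def f(xt):
    return xt * 4.0 + 2.0 * xt

# linear_transpose needs example primals to fix shapes and dtypes
f_transpose = jax.linear_transpose(f, 1.0)
(ct,) = f_transpose(1.0)  # pull the output cotangent back to the input
assert float(ct) == 6.0   # the transpose of xt -> 6*xt is ct -> 6*ct
```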
[19]:
@trace("multiply_add_transpose")
def multiply_add_transpose(ct, x, y, z):
"""Evaluates the transpose of a linear primitive.
This method is only used when computing the backward gradient following
value_and_jvp, and is only needed for primitives that are used in the JVP
calculation for some other primitive. We need transposition for multiply_add_prim,
because we have used multiply_add_prim in the computation of the output_tangent in
multiply_add_value_and_jvp.
In our case, multiply_add is not a linear primitive. However, it is used linearly
w.r.t. tangents in multiply_add_value_and_jvp:
output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt)
Always one of the first two multiplicative arguments are constants.
Args:
ct: the cotangent of the output of the primitive.
x, y, z: values of the arguments. The arguments that are used linearly
get an ad.UndefinedPrimal value. The other arguments get a constant
value.
Returns:
a tuple with the cotangent of the inputs, with the value None
corresponding to the constant arguments.
"""
if not ad.is_undefined_primal(x):
# This use of multiply_add is with a constant "x"
assert ad.is_undefined_primal(y)
ct_y = ad.zero if ct is ad.zero else multiply_add_prim(x, ct, lax.zeros_like_array(x))
res = None, ct_y, ct
else:
# This use of multiply_add is with a constant "y"
assert ad.is_undefined_primal(x)
ct_x = ad.zero if ct is ad.zero else multiply_add_prim(ct, y, lax.zeros_like_array(y))
res = ct_x, None, ct
return res
ad.primitive_transposes[multiply_add_p] = multiply_add_transpose
Now we can complete the run of grad:
[20]:
assert api.grad(square_add_prim)(2., 10.) == 4.
call square_add_prim(Traced<ConcreteArray(2.0, weak_type=True)>, 10.0)
call multiply_add_prim(Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ConcreteArray(2.0, weak_type=True)>, 10.0)
call multiply_add_value_and_jvp((Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ConcreteArray(2.0, weak_type=True)>, 10.0), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero))
Primal evaluation:
call multiply_add_prim(Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ConcreteArray(2.0, weak_type=True)>, 10.0)
call multiply_add_impl(2.0, 2.0, 10.0)
< multiply_add_impl = 14.0
< multiply_add_prim = 14.0
Tangent evaluation:
call multiply_add_prim(Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, 0.0)
call multiply_add_abstract_eval(ConcreteArray(2.0, weak_type=True), ShapedArray(float32[], weak_type=True), ConcreteArray(0.0))
< multiply_add_abstract_eval = ShapedArray(float32[])
< multiply_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ConcreteArray(2.0, weak_type=True)>, Traced<ShapedArray(float32[])>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ConcreteArray(2.0, weak_type=True), ShapedArray(float32[]))
< multiply_add_abstract_eval = ShapedArray(float32[])
< multiply_add_prim = Traced<ShapedArray(float32[])>
< multiply_add_value_and_jvp = (14.0, Traced<ShapedArray(float32[])>)
< multiply_add_prim = Traced<ConcreteArray(14.0)>
< square_add_prim = Traced<ConcreteArray(14.0)>
call multiply_add_transpose(1.0, UndefinedPrimal(ShapedArray(float32[])), 2.0, UndefinedPrimal(ShapedArray(float32[])))
call multiply_add_prim(1.0, 2.0, 0.0)
call multiply_add_impl(1.0, 2.0, 0.0)
< multiply_add_impl = 2.0
< multiply_add_prim = 2.0
< multiply_add_transpose = (2.0, None, 1.0)
call multiply_add_transpose(1.0, 2.0, UndefinedPrimal(ShapedArray(float32[])), 0.0)
call multiply_add_prim(2.0, 1.0, 0.0)
call multiply_add_impl(2.0, 1.0, 0.0)
< multiply_add_impl = 2.0
< multiply_add_prim = 2.0
< multiply_add_transpose = (None, 2.0, 1.0)
Notice the two calls to multiply_add_transpose. They correspond to the two uses of multiply_add_prim in the computation of the output_tangent in multiply_add_value_and_jvp. The first call to transpose corresponds to the last use of multiply_add_prim: multiply_add_prim(xt, y, ...) where y is the constant 2.0.
JIT of reverse differentiation
Notice that the abstract evaluation of multiply_add_value_and_jvp uses only abstract values, while in the absence of JIT we used ConcreteArray.
[21]:
assert api.jit(api.grad(square_add_prim))(2., 10.) == 4.
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
call multiply_add_value_and_jvp((Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero))
Primal evaluation:
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
< multiply_add_abstract_eval = ShapedArray(float32[])
< multiply_add_prim = Traced<ShapedArray(float32[])>
Tangent evaluation:
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[])>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
< multiply_add_abstract_eval = ShapedArray(float32[])
< multiply_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[])>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
< multiply_add_abstract_eval = ShapedArray(float32[])
< multiply_add_prim = Traced<ShapedArray(float32[])>
< multiply_add_value_and_jvp = (Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
< multiply_add_prim = Traced<ShapedArray(float32[])>
< square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_transpose(1.0, UndefinedPrimal(ShapedArray(float32[])), Traced<ShapedArray(float32[], weak_type=True)>, UndefinedPrimal(ShapedArray(float32[])))
call multiply_add_prim(1.0, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[])>)
call multiply_add_abstract_eval(ConcreteArray(1.0), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
< multiply_add_abstract_eval = ShapedArray(float32[])
< multiply_add_prim = Traced<ShapedArray(float32[])>
< multiply_add_transpose = (Traced<ShapedArray(float32[])>, None, 1.0)
call multiply_add_transpose(1.0, Traced<ShapedArray(float32[], weak_type=True)>, UndefinedPrimal(ShapedArray(float32[])), Traced<ShapedArray(float32[])>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, 1.0, Traced<ShapedArray(float32[])>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ConcreteArray(1.0), ShapedArray(float32[]))
< multiply_add_abstract_eval = ShapedArray(float32[])
< multiply_add_prim = Traced<ShapedArray(float32[])>
< multiply_add_transpose = (None, Traced<ShapedArray(float32[])>, 1.0)
call multiply_add_xla_translation(<JaxComputationBuilder>, <XlaOp at 0x7fd50874dc38>, <XlaOp at 0x7fd5087f0730>, <XlaOp at 0x7fd50874dd88>)
< multiply_add_xla_translation = <XlaOp at 0x7fd50874d228>
call multiply_add_xla_translation(<JaxComputationBuilder>, <XlaOp at 0x7fd5087f0730>, <XlaOp at 0x7fd508765570>, <XlaOp at 0x7fd50874dc00>)
< multiply_add_xla_translation = <XlaOp at 0x7fd50874dc38>
Batching
The batching transformation takes a pointwise computation and turns it into a computation on vectors. If we try it right now, we get a NotImplementedError:
[22]:
# The arguments are two vectors instead of two scalars
with expectNotImplementedError():
  api.vmap(square_add_prim, in_axes=0, out_axes=0)(onp.array([2., 3.]),
                                                   onp.array([10., 20.]))
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
Found expected exception:
Traceback (most recent call last):
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/testdocs/lib/python3.7/site-packages/jax/interpreters/batching.py", line 203, in get_primitive_batcher
    return primitive_batchers[p]
KeyError: multiply_add
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "<ipython-input-22-c82d11ecb99c>", line 4, in <module>
    onp.array([10., 20.]))
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/testdocs/lib/python3.7/site-packages/jax/api.py", line 722, in batched_fun
    lambda: _flatten_axes(out_tree(), out_axes))
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/testdocs/lib/python3.7/site-packages/jax/interpreters/batching.py", line 34, in batch
    return batched_fun.call_wrapped(*in_vals)
NotImplementedError: Batching rule for 'multiply_add' not implemented
We need to tell JAX how to evaluate the batched version of the primitive. In this particular case, multiply_add_prim already operates pointwise on inputs of any dimension, so the batched version can use the same multiply_add_prim implementation.
[23]:
from jax.interpreters import batching

@trace("multiply_add_batch")
def multiply_add_batch(vector_arg_values, batch_axes):
  """Computes the batched version of the primitive.

  This must be a JAX-traceable function.

  Since the multiply_add primitive already operates pointwise on arbitrary
  dimension tensors, to batch it we can use the primitive itself. This works
  as long as all the inputs have the same dimensions and are batched along
  the same axes. The result is batched along the axis that the inputs are
  batched along.

  Args:
    vector_arg_values: a tuple of three arguments, each being a tensor of
      matching shape.
    batch_axes: the axes that are being batched. See vmap documentation.
  Returns:
    a tuple of the result, and the result axis that was batched.
  """
  assert batch_axes[0] == batch_axes[1]
  assert batch_axes[0] == batch_axes[2]
  _trace("Using multiply_add to compute the batch:")
  res = multiply_add_prim(*vector_arg_values)
  return res, batch_axes[0]

batching.primitive_batchers[multiply_add_p] = multiply_add_batch
[24]:
assert onp.allclose(api.vmap(square_add_prim, in_axes=0, out_axes=0)(
                      onp.array([2., 3.]),
                      onp.array([10., 20.])),
                    [14., 29.])
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
call multiply_add_batch(([2. 3.], [2. 3.], [10. 20.]), (0, 0, 0))
Using multiply_add to compute the batch:
call multiply_add_prim([2. 3.], [2. 3.], [10. 20.])
call multiply_add_impl([2. 3.], [2. 3.], [10. 20.])
< multiply_add_impl = [14. 29.]
< multiply_add_prim = [14. 29.]
< multiply_add_batch = ([14. 29.], 0)
< multiply_add_prim = Traced<ShapedArray(float32[])>
< square_add_prim = Traced<ShapedArray(float32[])>
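Again we can cross-check against stock JAX on a plain-Python version of the function, where vmap already knows how to batch every primitive involved:

```python
import jax
import jax.numpy as jnp

def square_add(x, y):
    return x * x + y

batched = jax.vmap(square_add, in_axes=0, out_axes=0)
out = batched(jnp.array([2., 3.]), jnp.array([10., 20.]))
assert jnp.allclose(out, jnp.array([14., 29.]))
```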
JIT of batching
[25]:
assert onp.allclose(api.jit(api.vmap(square_add_prim, in_axes=0, out_axes=0))
                      (onp.array([2., 3.]),
                       onp.array([10., 20.])),
                    [14., 29.])
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
call multiply_add_batch((Traced<ShapedArray(float32[2])>, Traced<ShapedArray(float32[2])>, Traced<ShapedArray(float32[2])>), (0, 0, 0))
Using multiply_add to compute the batch:
call multiply_add_prim(Traced<ShapedArray(float32[2])>, Traced<ShapedArray(float32[2])>, Traced<ShapedArray(float32[2])>)
call multiply_add_abstract_eval(ShapedArray(float32[2]), ShapedArray(float32[2]), ShapedArray(float32[2]))
< multiply_add_abstract_eval = ShapedArray(float32[2])
< multiply_add_prim = Traced<ShapedArray(float32[2])>
< multiply_add_batch = (Traced<ShapedArray(float32[2])>, 0)
< multiply_add_prim = Traced<ShapedArray(float32[])>
< square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_xla_translation(<JaxComputationBuilder>, <XlaOp at 0x7fd508765ca8>, <XlaOp at 0x7fd508765ca8>, <XlaOp at 0x7fd508765ae8>)
< multiply_add_xla_translation = <XlaOp at 0x7fd508765c70>
Writing custom Jaxpr interpreters in JAX
JAX offers several composable function transformations (jit, grad, vmap, etc.) that enable writing concise, accelerated code.
Here we show how to add your own function transformations to the system by writing a custom Jaxpr interpreter, and we'll get composability with all the other transformations for free.
[1]:
import numpy as onp
import jax
import jax.numpy as np
from jax import jit, grad, vmap
from jax import random
What is JAX doing?
JAX provides a NumPy-like API for numerical computing which can be used as is, but JAX's true power comes from composable function transformations. Take the jit function transformation, which takes a function and returns a semantically identical function that is lazily compiled by XLA for accelerators.
[2]:
x = random.normal(random.PRNGKey(0), (5000, 5000))

def f(w, b, x):
  return np.tanh(np.dot(x, w) + b)

fast_f = jit(f)
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/testdocs/lib/python3.7/site-packages/jax/lib/xla_bridge.py:123: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
When we call fast_f, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly: they first trace the function and then handle the output trace in some way. To learn more about JAX's tracing machinery, you can refer to the "How it works" section in the README.
Jaxpr tracer
A tracer of special importance in JAX is the Jaxpr tracer, which records ops into a Jaxpr (JAX expression). A Jaxpr is a data structure that can be evaluated like a mini functional programming language, and thus Jaxprs are a useful intermediate representation for function transformation.
To get a first look at Jaxprs, consider the make_jaxpr transformation. make_jaxpr is essentially a "pretty-printing" transformation: it transforms a function into one that, given example arguments, produces a Jaxpr representation of its computation. Although we can't generally use the Jaxprs that it returns, it is useful for debugging and introspection. Let's use it to look at how some example Jaxprs are structured.
[3]:
def examine_jaxpr(typed_jaxpr):
  jaxpr = typed_jaxpr.jaxpr
  print("invars:", jaxpr.invars)
  print("outvars:", jaxpr.outvars)
  print("constvars:", jaxpr.constvars)
  for eqn in jaxpr.eqns:
    print("equation:", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)
  print()
  print("jaxpr:", jaxpr)

def foo(x):
  return x + 1

print("foo")
print("=====")
examine_jaxpr(jax.make_jaxpr(foo)(5))

print()

def bar(w, b, x):
  return np.dot(w, x) + b + np.ones(5), x

print("bar")
print("=====")
examine_jaxpr(jax.make_jaxpr(bar)(np.ones((5, 10)), np.ones(5), np.ones(10)))
foo
=====
invars: [a]
outvars: [b]
constvars: []
equation: [a, 1] add [b] {}
jaxpr: { lambda ; a.
let b = add a 1
in (b,) }
bar
=====
invars: [a, b, c]
outvars: [g, c]
constvars: [f]
equation: [a, c] dot_general [d] {'dimension_numbers': (((1,), (0,)), ((), ())), 'precision': None}
equation: [d, b] add [e] {}
equation: [e, f] add [g] {}
jaxpr: { lambda f ; a b c.
let d = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))
precision=None ] a c
e = add d b
g = add e f
in (g, c) }
- jaxpr.invars - the invars of a Jaxpr are a list of the input variables to the Jaxpr, analogous to arguments in Python functions.
- jaxpr.outvars - the outvars of a Jaxpr are the variables that are returned by the Jaxpr. Every Jaxpr has multiple outputs.
- jaxpr.constvars - the constvars are a list of variables that are also inputs to the Jaxpr, but correspond to constants from the trace (we'll go over these in more detail later).
- jaxpr.eqns - a list of equations, which are essentially let-bindings. Each equation is a list of input variables, a list of output variables, and a primitive, which is used to evaluate inputs to produce outputs. Each equation also has params, a dictionary of parameters.
All together, a Jaxpr encapsulates a simple program that can be evaluated with inputs to produce an output. Weâ€™ll go over how exactly to do this later. The important thing to note now is that a Jaxpr is a data structure that can be manipulated and evaluated in whatever way we want.
Why are Jaxprs useful?
Jaxprs are simple program representations that are easy to transform. And because JAX lets us stage out Jaxprs from Python functions, it gives us a way to transform numerical programs written in Python.
Your first interpreter: invert
Let's try to implement a simple function "inverter", which takes in the output of the original function and returns the inputs that produced those outputs. For now, let's focus on simple, unary functions which are composed of other invertible unary functions.
Goal:
def f(x):
  return np.exp(np.tanh(x))

f_inv = inverse(f)
assert np.allclose(f_inv(f(1.0)), 1.0)
The way we'll implement this is by (1) tracing f into a Jaxpr, then (2) interpreting the Jaxpr backwards. While interpreting the Jaxpr backwards, for each equation we'll look up the primitive's inverse in a table and apply it.
1. Tracing a function
We can't use make_jaxpr for this, because we need to pull out constants created during the trace to pass into the Jaxpr. However, we can write a function that does something very similar to make_jaxpr.
[4]:
# Importing Jax functions useful for tracing/interpreting.
import numpy as onp
from functools import wraps
from jax import api_util
from jax import core
from jax import lax
from jax import linear_util as lu
from jax import tree_util
from jax.abstract_arrays import ShapedArray
from jax.interpreters import partial_eval as pe
from jax.util import safe_map
[5]:
def make_jaxpr2(fun):

  def pv_like(x):
    # ShapedArrays are abstract values that carry around
    # shape and dtype information
    aval = ShapedArray(onp.shape(x), onp.result_type(x))
    return pe.PartialVal((aval, core.unit))

  @wraps(fun)
  def jaxpr_const_maker(*args, **kwargs):
    # Set up fun for transformation
    wrapped = lu.wrap_init(fun)
    # Flatten input args
    jax_args, in_tree = tree_util.tree_flatten((args, kwargs))
    # Transform fun to accept flat args
    # and return a flat list result
    jaxtree_fun, out_tree = api_util.flatten_fun(wrapped, in_tree)
    # Abstract and partial-val's flat args
    pvals = safe_map(pv_like, jax_args)
    # Trace function into Jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals)
    return jaxpr, consts, (in_tree, out_tree())
  return jaxpr_const_maker
This function first flattens its arguments into a list, which are then abstracted and wrapped as partial values. The pe.trace_to_jaxpr function is then used to trace the function into a Jaxpr from a list of partial-value inputs.
[6]:
def f(x):
  return np.exp(np.tanh(x))

jaxpr, consts, _ = make_jaxpr2(f)(np.ones(5))
print(jaxpr)
print(consts)
{ lambda ; a.
let b = tanh a
c = exp b
in (c,) }
()
This particular function doesn't have any constants, but in general, this is how you both trace a function into a Jaxpr and extract the constants.
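For a function that does close over a constant, the extracted constants become extra inputs that are bound to the Jaxpr's constvars. With the modern public API the same idea looks roughly like this (a sketch using jax.make_jaxpr, whose ClosedJaxpr bundles a jaxpr together with its constants; internals vary across JAX versions):

```python
import numpy as onp
import jax
import jax.numpy as jnp

const = onp.array([1.0, 2.0, 3.0])

def g(x):
    return x + const  # `const` is captured at trace time, not passed as an argument

closed = jax.make_jaxpr(g)(jnp.ones(3))
print(closed.jaxpr.constvars, closed.consts)

# The constants are fed back in when evaluating the jaxpr:
out = jax.core.eval_jaxpr(closed.jaxpr, closed.consts, jnp.ones(3))
assert jnp.allclose(out[0], jnp.array([2.0, 3.0, 4.0]))
```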
2. Evaluating a Jaxpr¶
Before we write a custom Jaxpr interpreter, let's first implement the "default" interpreter, eval_jaxpr, which evaluates the Jaxpr as-is, computing the same values that the original, untransformed Python function would.
To do this, we first create an environment to store the values for each of the variables, and update the environment with each equation we evaluate in the Jaxpr.
[7]:
def eval_jaxpr(jaxpr, consts, *args):
  # Mapping from variable -> value
  env = {}

  def read(var):
    # Literals are values baked into the Jaxpr
    if type(var) is core.Literal:
      return var.val
    return env[var]

  def write(var, val):
    env[var] = val

  # Bind args and consts to environment
  write(core.unitvar, core.unit)
  safe_map(write, jaxpr.invars, args)
  safe_map(write, jaxpr.constvars, consts)

  # Loop through equations and evaluate primitives using `bind`
  for eqn in jaxpr.eqns:
    # Read inputs to equation from environment
    invals = safe_map(read, eqn.invars)
    # `bind` is how a primitive is called
    outvals = eqn.primitive.bind(*invals, **eqn.params)
    # Primitives may return multiple outputs or not
    if not eqn.primitive.multiple_results:
      outvals = [outvals]
    # Write the results of the primitive into the environment
    safe_map(write, eqn.outvars, outvals)
  # Read the final result of the Jaxpr from the environment
  return safe_map(read, jaxpr.outvars)
[8]:
jaxpr, consts, _ = make_jaxpr2(f)(np.ones(5))
eval_jaxpr(jaxpr, consts, np.ones(5))
[8]:
[DeviceArray([2.1416876, 2.1416876, 2.1416876, 2.1416876, 2.1416876], dtype=float32)]
Notice that eval_jaxpr will always return a list, even if the original function does not. To "unflatten" the list into what the function was originally supposed to return, we can use the out_tree object returned by make_jaxpr2.
Furthermore, this interpreter does not handle sub-jaxprs, which we will not cover in this guide. You can refer to core.eval_jaxpr (link) to see the edge cases that this interpreter does not cover.
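The environment-threading pattern in eval_jaxpr can be illustrated without any JAX internals at all: a tiny toy "equation" language interpreted with the same read/write discipline. The eqns triples and prim_rules table below are invented for illustration; they mirror the (invars, primitive, outvars) structure of jaxpr equations but are not JAX APIs.

```python
import math

# Table mapping primitive names to implementations, like eqn.primitive.bind.
prim_rules = {"tanh": math.tanh, "exp": math.exp}

def eval_toy_jaxpr(eqns, invar, outvar, arg):
    env = {invar: arg}           # environment mapping variable -> value
    for out, prim, inp in eqns:  # evaluate equations in order
        env[out] = prim_rules[prim](env[inp])
    return env[outvar]           # read the final result from the environment

# Equations for f(x) = exp(tanh(x)):  b = tanh a ; c = exp b
eqns = [("b", "tanh", "a"), ("c", "exp", "b")]
result = eval_toy_jaxpr(eqns, "a", "c", 1.0)
print(result)  # same as math.exp(math.tanh(1.0))
```

The real eval_jaxpr differs only in bookkeeping: variables are objects rather than strings, literals bypass the environment, and primitives may have multiple outputs.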
Custom inverse Jaxpr interpreter¶
An inverse interpreter doesn't look too different from eval_jaxpr. We'll first set up the registry, which will map primitives to their inverses. We'll then write a custom interpreter that looks up primitives in the registry.
It turns out that this interpreter will also look similar to the "transpose" interpreter used in reverse-mode auto-differentiation found here.
[9]:
inverse_registry = {}
We'll now register inverses for some of the primitives. By convention, primitives in Jax end in _p and a lot of the popular ones live in lax.
[10]:
inverse_registry[lax.exp_p] = np.log
inverse_registry[lax.tanh_p] = np.arctanh
inverse will first trace the function, then custom-interpret the Jaxpr. Let's set up a simple skeleton.
[11]:
def inverse(fun):
  @wraps(fun)
  def wrapped(*args, **kwargs):
    # Since we assume unary functions, we won't
    # worry about flattening and
    # unflattening arguments
    jaxpr, consts, _ = make_jaxpr2(fun)(*args, **kwargs)
    out = inverse_jaxpr(jaxpr, consts, *args)
    return out[0]
  return wrapped
Now we just need to define inverse_jaxpr, which will walk through the Jaxpr backward and invert primitives when it can.
[12]:
def inverse_jaxpr(jaxpr, consts, *args):
  env = {}

  def read(var):
    if type(var) is core.Literal:
      return var.val
    return env[var]

  def write(var, val):
    env[var] = val

  # Args now correspond to Jaxpr outvars
  write(core.unitvar, core.unit)
  safe_map(write, jaxpr.outvars, args)
  safe_map(write, jaxpr.constvars, consts)

  # Looping backward
  for eqn in jaxpr.eqns[::-1]:
    # outvars are now invars
    invals = safe_map(read, eqn.outvars)
    if eqn.primitive not in inverse_registry:
      raise NotImplementedError(
          "{} does not have registered inverse.".format(eqn.primitive))
    # Assuming a unary function
    outval = inverse_registry[eqn.primitive](*invals)
    safe_map(write, eqn.invars, [outval])
  return safe_map(read, jaxpr.invars)
That's it!
[13]:
def f(x):
  return np.exp(np.tanh(x))

f_inv = inverse(f)
assert np.allclose(f_inv(f(1.0)), 1.0)
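The table-driven inversion can be checked numerically without JAX at all: log undoes exp and atanh undoes tanh, applied in the reverse of the order they were traced. A stdlib-only sketch of the same round trip:

```python
import math

def f(x):
    # Forward equations, in order: b = tanh(a); c = exp(b)
    return math.exp(math.tanh(x))

def f_inv(y):
    # Invert the equations in reverse: first undo exp, then undo tanh.
    return math.atanh(math.log(y))

x = 1.0
roundtrip = f_inv(f(x))
print(roundtrip)  # approximately 1.0
```

This is exactly what inverse_jaxpr computes, except that the order and the choice of inverses are discovered from the Jaxpr and the registry rather than written by hand.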
Importantly, you can trace through a Jaxpr interpreter.
[14]:
jax.make_jaxpr(inverse(f))(f(1.))
[14]:
{ lambda ; a.
let b = log a
c = atanh b
in (c,) }
That's all it takes to add a new transformation to a system, and you get composition with all the others for free! For example, we can use jit, vmap, and grad with inverse!
[15]:
jit(vmap(grad(inverse(f))))((np.arange(5) + 1.) / 5.)
[15]:
DeviceArray([3.1440797, 15.584931 , 2.2551253, 1.3155028, 1. ], dtype=float32)
Exercises for the reader¶
- Handle primitives with multiple arguments where inputs are partially known, for example lax.add_p and lax.mul_p.
- Handle the xla_call and xla_pmap primitives, which will not work with both eval_jaxpr and inverse_jaxpr as written.
Change Log¶
These are the release notes for JAX.
jax 0.1.63 (unreleased)¶
jax 0.1.62 (March 21, 2020)¶
- JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
- Removed the internal function lax._safe_mul, which implemented the convention 0. * nan == 0.. This change means some programs when differentiated will produce nans when they previously produced correct values, though it ensures nans rather than silently incorrect results are produced for other programs. See #2447 and #1052 for details.
- Added an all_gather parallel convenience function.
- More type annotations in core code.
jaxlib 0.1.42 (March 19, 2020)¶
- jaxlib 0.1.41 broke cloud TPU support due to an API incompatibility. This release fixes it again.
- JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
jax 0.1.61 (March 17, 2020)¶
- Fixes Python 3.5 support. This will be the last JAX or jaxlib release that supports Python 3.5.
jax 0.1.60 (March 17, 2020)¶
- GitHub commits.
- New features:
  - jax.pmap() has a static_broadcast_argnums argument which allows the user to specify arguments that should be treated as compile-time constants and should be broadcasted to all devices. It works analogously to static_argnums in jax.jit().
  - Improved error messages for when tracers are mistakenly saved in global state.
  - Added jax.nn.one_hot() utility function.
  - Added jax.experimental.jet for exponentially faster higher-order automatic differentiation.
  - Added more sanity checking to arguments of jax.lax.broadcast_in_dim().
- The minimum jaxlib version is now 0.1.41.
jaxlib 0.1.40 (March 4, 2020)¶
- Adds experimental support in Jaxlib for TensorFlow profiler, which allows tracing of CPU and GPU computations from TensorBoard.
- Includes prototype support for multi-host GPU computations that communicate via NCCL.
- Improves performance of NCCL collectives on GPU.
- Adds TopK, CustomCallWithoutLayout, CustomCallWithLayout, IGammaGradA and RandomGamma implementations.
- Supports device assignments known at XLA compilation time.
jax 0.1.59 (February 11, 2020)¶
- GitHub commits.
- Breaking changes:
  - The minimum jaxlib version is now 0.1.38.
  - Simplified Jaxpr by removing the Jaxpr.freevars and Jaxpr.bound_subjaxprs fields. The call primitives (xla_call, xla_pmap, sharded_call, and remat_call) get a new parameter call_jaxpr with a fully-closed (no constvars) jaxpr. Also, added a new field call_primitive to primitives.
- New features:
  - Reverse-mode automatic differentiation (e.g. grad) of lax.cond, making it now differentiable in both modes (https://github.com/google/jax/pull/2091)
  - JAX now supports DLPack, which allows sharing CPU and GPU arrays in a zero-copy way with other libraries, such as PyTorch.
  - JAX GPU DeviceArrays now support __cuda_array_interface__, which is another zero-copy protocol for sharing GPU arrays with other libraries such as CuPy and Numba.
  - JAX CPU device buffers now implement the Python buffer protocol, which allows zero-copy buffer sharing between JAX and NumPy.
  - Added a JAX_SKIP_SLOW_TESTS environment variable to skip tests known as slow.
jaxlib 0.1.39 (February 11, 2020)¶
- Updates XLA.
jaxlib 0.1.38 (January 29, 2020)¶
- CUDA 9.0 is no longer supported.
- CUDA 10.2 wheels are now built by default.
jax 0.1.58 (January 28, 2020)¶
Breaking changes
- JAX has dropped Python 2 support, because Python 2 reached its end of life on January 1, 2020. Please update to Python 3.5 or newer.
New features
- Forward-mode automatic differentiation (jvp) of while loop (https://github.com/google/jax/pull/1980)
- New NumPy and SciPy functions:
  - jax.numpy.fft.fft2()
  - jax.numpy.fft.ifft2()
  - jax.numpy.fft.rfft()
  - jax.numpy.fft.irfft()
  - jax.numpy.fft.rfft2()
  - jax.numpy.fft.irfft2()
  - jax.numpy.fft.rfftn()
  - jax.numpy.fft.irfftn()
  - jax.numpy.fft.fftfreq()
  - jax.numpy.fft.rfftfreq()
  - jax.numpy.linalg.matrix_rank()
  - jax.numpy.linalg.matrix_power()
  - jax.scipy.special.betainc()
- Batched Cholesky decomposition on GPU now uses a more efficient batched kernel.
Notable bug fixes¶
- With the Python 3 upgrade, JAX no longer depends on fastcache, which should help with installation.
JAX Frequently Asked Questions¶
We are collecting here answers to frequently asked questions. Contributions welcome!
Creating arrays with jax.numpy.array is slower than with numpy.array¶
The following code is relatively fast when using NumPy, and slow when using JAX's NumPy:
import numpy as np
np.array([0] * int(1e6))
The reason is that in NumPy the numpy.array function is implemented in C, while jax.numpy.array is implemented in Python, and it needs to iterate over a long list to convert each list element to an array element.
An alternative would be to create the array with original NumPy and then convert it to a JAX array:
from jax import numpy as jnp
jnp.array(np.array([0] * int(1e6)))
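The asymmetry is visible with NumPy alone: building an array from a Python list pays a per-element cost, while converting an already-built ndarray is a single bulk copy. This is why routing the list through numpy.array first is the fast path. A sketch (the timings are illustrative host-side NumPy measurements, not JAX ones):

```python
import timeit
import numpy as np

big_list = [0] * int(1e5)

# Building from a Python list iterates over every element.
t_list = timeit.timeit(lambda: np.array(big_list), number=10)

# Converting an existing ndarray is a single bulk memory copy.
arr = np.array(big_list)
t_arr = timeit.timeit(lambda: np.array(arr), number=10)

print(t_list > t_arr)  # the list path is typically far slower
```

Handing the ready-made ndarray to jnp.array then avoids the Python-level loop entirely.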
jit changes the behavior of my function¶
If you have a Python function that changes behavior after using jit, perhaps your function uses global state or has side effects. In the following code, impure_func uses the global y and has a side effect due to print:
y = 0

# @jit   # Different behavior with jit
def impure_func(x):
  print("Inside:", y)
  return x + y

for y in range(3):
  print("Result:", impure_func(y))
Without jit the output is:
Inside: 0
Result: 0
Inside: 1
Result: 2
Inside: 2
Result: 4
and with jit it is:
Inside: 0
Result: 0
Result: 1
Result: 2
For jit the function is executed once using the Python interpreter, at which time the Inside printing happens, and the first value of y is observed. Then the function is compiled and cached, and executed multiple times with different values of x, but with the same first value of y.
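The trace-once-then-reuse behavior is analogous to any cache keyed only on a function's explicit arguments: once a result for a given x is stored, later changes to global state are invisible. A stdlib analogy using functools.lru_cache (impure_cached is an invented name; this models jit's caching, not its tracing machinery):

```python
from functools import lru_cache

y = 0

@lru_cache(maxsize=None)        # caches on x only, like jit's compiled function
def impure_cached(x):
    return x + y                # silently closes over the global y

first = impure_cached(0)        # y = 0 is observed and baked into the cache
y = 10
second = impure_cached(0)       # cache hit: the new value of y is never seen
print(first, second)            # both 0
```

As with jit, the fix is to make all inputs explicit arguments rather than global state.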
Additional reading:
- [JAX - The Sharp Bits: Pure Functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AAPurefunctions)
Gradients contain NaN when using where¶
If you define a function using where to avoid an undefined value, and you are not careful, you may obtain a NaN for reverse differentiation:
def my_log(x):
  return np.where(x > 0., np.log(x), 0.)

my_log(0.)  ==> 0.  # Ok
jax.grad(my_log)(0.)  ==> NaN
A short explanation is that during the grad computation the adjoint corresponding to the undefined np.log(x) is a NaN, and it gets accumulated into the adjoint of the np.where. The correct way to write such functions is to ensure that there is an np.where inside the partially-defined function, so that the adjoint is always finite:
def safe_for_grad_log(x):
  return np.log(np.where(x > 0., x, 1.))

safe_for_grad_log(0.)  ==> 0.  # Ok
jax.grad(safe_for_grad_log)(0.)  ==> 0.  # Ok
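The value-level mechanism behind this fix can be checked with NumPy alone (no gradients involved): where evaluates both branches eagerly, so the naive version really does compute log(0.) before discarding it, while the clamped version never produces a non-finite intermediate:

```python
import numpy as np

x = np.array([0.0, 1.0])

# np.where evaluates BOTH branches before selecting, so log(0.) is
# actually computed here even though its result is never selected:
with np.errstate(divide="ignore"):
    branch = np.log(x)                      # [-inf, 0.]
naive = np.where(x > 0.0, branch, 0.0)      # [0., 0.] -- -inf computed, then discarded

# With the inner where, the argument to log is clamped to 1. where x <= 0,
# so no non-finite intermediate is ever produced:
safe = np.log(np.where(x > 0.0, x, 1.0))    # [0., 0.]

print(branch, naive, safe)
```

In reverse-mode differentiation it is that discarded non-finite intermediate whose adjoint poisons the result, which is why only the inner-where form is safe for grad.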
The inner np.where may be needed in addition to the original one, e.g.:

def my_log_or_y(x, y):
  """Return log(x) if x > 0 else y."""
  return np.where(x > 0., np.log(np.where(x > 0., x, 1.)), y)
Additional reading:
- [Issue: gradients through np.where when one of branches is nan](https://github.com/google/jax/issues/1052#issuecomment514083352)
- [How to avoid NaN gradients when using where](https://github.com/tensorflow/probability/blob/master/discussion/wherenan.pdf)
Understanding jaxprs¶
Updated: February 14, 2020 (for commit 9e6fe64).
(Note: the code examples in this file can also be seen in jax/tests/api_test::JaxprTest.testExamplesJaxprDoc.)
Conceptually, one can think of JAX transformations as first tracing the Python function to be transformed into a small and wellbehaved intermediate form, the jaxpr, that is then transformed accordingly, and ultimately compiled and executed. One of the reasons JAX can pack so much power into such a small software package is that it starts with a familiar and flexible programming interface (Python with NumPy) and it uses the actual Python interpreter to do most of the heavy lifting to distill the essence of the computation into a simple staticallytyped expression language with limited higherorder features: the jaxpr language.
Not all Python programs can be processed this way, but it turns out that many scientific computing and machine learning programs do have this property.
Before we proceed, it is important to point out that not all JAX transformations materialize a jaxpr as described above; some, e.g., differentiation, will apply transformations incrementally during tracing. Nevertheless, if one wants to understand how JAX works internally, or to make use of the result of JAX tracing, it is useful to understand jaxpr.
A jaxpr instance represents a function with one or more typed parameters (input variables) and one or more typed results. The results depend only on the input variables; there are no free variables captured from enclosing scopes.
The inputs and outputs have types, which in JAX are represented as abstract values. There are two related representations in the code for jaxprs. The main one is jax.core.TypedJaxpr and is what you obtain when you use jax.make_jaxpr() to inspect jaxprs. It has the following fields:
- jaxpr: the actual computation content of the function (described below).
- literals: a list of constants. For various reasons, during tracing JAX will collect the non-scalar constants that arise and will replace them with variables, e.g., constants that appear in the Python program, or the result of constant folding such constants. The variables that stand for these constants are mentioned separately in the enclosed jaxpr. When applying a TypedJaxpr to some actual arguments, one must pass first the literals followed by the actual arguments.
- in_avals and out_avals: the types of the input variables (excluding the ones that correspond to the literals), and of the output values. These types are called abstract values in JAX, e.g., ShapedArray(float32[10,10]).
The most interesting part of the TypedJaxpr is the actual execution content, represented as a jax.core.Jaxpr, printed using the following grammar:
jaxpr ::= { lambda Var* ; Var+.
let Eqn*
in [Expr+] }
where:
- The parameters of the jaxpr are shown as two lists of variables separated by ;. The first set of variables are the ones that have been introduced to stand for constants that have been hoisted out. These are called the constvars. The second list of variables are the real input variables.
- Eqn* is a list of equations, defining intermediate variables referring to intermediate expressions. Each equation defines one or more variables as the result of applying a primitive on some atomic expressions. Each equation uses only input variables and intermediate variables defined by previous equations.
- Expr+: a list of output atomic expressions for the jaxpr.
Equations are printed as follows:
Eqn ::= let Var+ = Primitive [ Param* ] Expr+
where:
- Var+ are one or more intermediate variables to be defined as the output of a primitive invocation (some primitives can return multiple values).
- Expr+ are one or more atomic expressions, each either a variable or a literal constant. A special form of an atomic expression is the unit expression, printed as * and standing for a value that is not needed in the rest of the computation and has been elided.
- Param* are zero or more named parameters to the primitive, printed in square brackets. Each parameter is shown as Name = Value.
Most jaxpr primitives are first-order (they take just one or more Expr as arguments):
Primitive := add | sub | sin | mul | ...
The jaxpr primitives are documented in the jax.lax
module.
For example, here is the jaxpr produced for the function func1
below:
from jax import numpy as jnp
def func1(first, second):
temp = first + jnp.sin(second) * 3.
return jnp.sum(temp)
print(jax.make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a b.
let c = sin b
d = mul c 3.0
e = add a d
f = reduce_sum[ axes=(0,)
input_shape=(8,) ] e
in f }
Here there are no constvars; a and b are the input variables, and they correspond respectively to the first and second function parameters. The scalar literal 3.0 is kept inline.
The reduce_sum primitive has named parameters axes and input_shape, in addition to the operand e.
Note that JAX traces through Python-level control flow and higher-order functions when it extracts the jaxpr. This means that just because a Python program contains functions and control flow, the resulting jaxpr does not have to contain control-flow or higher-order features.
For example, when tracing the function func3 JAX will inline the call to inner and the conditional if second.shape[0] > 4, and will produce the same jaxpr as before:
def func2(inner, first, second):
  temp = first + inner(second) * 3.
  return jnp.sum(temp)

def inner(second):
  if second.shape[0] > 4:
    return jnp.sin(second)
  else:
    assert False

def func3(first, second):
  return func2(inner, first, second)

print(api.make_jaxpr(func3)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a b.
let c = sin b
d = mul c 3.0
e = add a d
f = reduce_sum[ axes=(0,)
input_shape=(8,) ] e
in f }
Handling PyTrees¶
In jaxpr there are no tuple types; instead primitives take multiple inputs and produce multiple outputs. When processing a function that has structured inputs or outputs, JAX will flatten those and in jaxpr they will appear as lists of inputs and outputs. For more details, please see the documentation for PyTrees (JAX pytrees).
For example, the following code produces an identical jaxpr to what we saw before (with two input vars, one for each element of the input tuple):
def func4(arg):  # Arg is a pair
  temp = arg[0] + jnp.sin(arg[1]) * 3.
  return jnp.sum(temp)

print(api.make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8))))
{ lambda ; a b.
let c = sin b
d = mul c 3.0
e = add a d
f = reduce_sum[ axes=(0,)
input_shape=(8,) ] e
in f }
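The flattening JAX performs on structured arguments can be sketched with a toy recursive flatten/unflatten pair for nested tuples. This is an illustration of the idea only, not JAX's actual pytree implementation (which lives in jax.tree_util and handles lists, dicts, and registered custom nodes as well):

```python
def flatten(tree):
    """Flatten nested tuples into a flat list of leaves plus a structure spec."""
    if isinstance(tree, tuple):
        leaves, spec = [], []
        for child in tree:
            child_leaves, child_spec = flatten(child)
            leaves.extend(child_leaves)
            spec.append(child_spec)
        return leaves, tuple(spec)
    return [tree], None            # None marks a leaf position

def unflatten(spec, leaves):
    """Rebuild the nested tuple from the structure spec and flat leaves."""
    leaves = iter(leaves)
    def build(s):
        if s is None:
            return next(leaves)
        return tuple(build(child) for child in s)
    return build(spec)

leaves, spec = flatten((1, (2, 3)))
print(leaves)                      # [1, 2, 3]
print(unflatten(spec, leaves))     # (1, (2, 3))
```

Traced functions see only the flat list of leaves (hence the two input vars above), and the spec is used to rebuild the structured result afterwards.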
Constant Vars¶
ConstVars arise when the computation contains array constants, either from the Python program, or from constant folding. For example, consider the function func6 below:
def func5(first, second):
  temp = first + jnp.sin(second) * 3. - jnp.ones(8)
  return temp

def func6(first):
  return func5(first, jnp.ones(8))

print(api.make_jaxpr(func6)(jnp.ones(8)))
JAX produces the following jaxpr:
{ lambda b d a.
let c = add a b
e = sub c d
in e }
When tracing func6, the function func5 is invoked with a constant value (jnp.ones(8)) for the second argument. As a result, the sub-expression jnp.sin(second) * 3. is constant-folded.
There are two ConstVars, b (standing for jnp.sin(second) * 3.) and d (standing for jnp.ones(8)). Unfortunately, it is not easy to tell from the jaxpr notation what constants the constant variables stand for.
Higher-order primitives¶
jaxpr includes several higher-order primitives. They are more complicated because they include sub-jaxprs.
Cond¶
JAX traces through normal Python conditionals. To capture a conditional expression for dynamic execution, one must use the jax.lax.cond() constructor, which has the following signature:
lax.cond(pred : bool, true_op: A, true_body: A -> B, false_op: C, false_body: C -> B) -> B
For example:
def func7(arg):
  return lax.cond(arg >= 0.,
                  arg,
                  lambda xtrue: xtrue + 3.,
                  arg,
                  lambda xfalse: xfalse - 3.)

print(api.make_jaxpr(func7)(5.))
{ lambda ; a.
let b = ge a 0.0
c = cond[ false_jaxpr={ lambda ; a.
let b = sub a 3.0
in b }
linear=(False, False)
true_jaxpr={ lambda ; a.
let b = add a 3.0
in b } ] b a a
in c }
The cond primitive has a number of parameters:
- true_jaxpr and false_jaxpr are jaxprs that correspond to the true and false branch functionals. In this example, those functionals each take one input variable, corresponding to xtrue and xfalse respectively.
- linear is a tuple of booleans that is used internally by the auto-differentiation machinery to encode which of the input parameters are used linearly in the conditional.
The above instance of the cond primitive takes 3 operands. The first one (b) is the predicate, then a is the true_op (arg, to be passed to true_jaxpr), and also a is the false_op (arg, to be passed to false_jaxpr).
The following example shows a more complicated situation, in which the input to the branch functionals is a tuple, and the false branch functional contains a constant jnp.ones(1) that is hoisted as a constvar:
def func8(arg1, arg2):  # arg2 is a pair
  return lax.cond(arg1 >= 0.,
                  arg2,
                  lambda xtrue: xtrue[0],
                  arg2,
                  lambda xfalse: jnp.ones(1) + xfalse[1])

print(api.make_jaxpr(func8)(5., (jnp.zeros(1), 2.)))
{ lambda e ; a b c.
let d = ge a 0.0
f = cond[ false_jaxpr={ lambda ; c a b.
let d = add c b
in d }
linear=(False, False, False, False, False)
true_jaxpr={ lambda ; a b.
let
in a } ] d b c e b c
in f }
The top-level jaxpr has one constvar e (corresponding to jnp.ones(1) from the body of the false_jaxpr) and three input variables a b c (corresponding to arg1 and the two elements of arg2; note that arg2 has been flattened).
The true_jaxpr has two input variables (corresponding to the two elements of arg2 that is passed to true_jaxpr).
The false_jaxpr has three input variables (c corresponding to the constant for jnp.ones(1), and a b for the two elements of arg2 that are passed to false_jaxpr).
The actual operands to the cond primitive are d b c e b c, which correspond in order to:
- 1 operand for the predicate,
- 2 operands for true_jaxpr, i.e., b and c, which are input vars corresponding to arg2 for the top-level jaxpr,
- 1 constant for false_jaxpr, i.e., e, which is a constvar for the top-level jaxpr,
- 2 operands for false_jaxpr, i.e., b and c, which are the input vars corresponding to arg2 for the top-level jaxpr.
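In terms of ordinary Python, the lax.cond signature above has the following reference semantics (a sketch that ignores tracing, staging, and the constvar plumbing; cond_reference is an invented name):

```python
def cond_reference(pred, true_op, true_body, false_op, false_body):
    # lax.cond applies exactly one branch functional to its operand.
    if pred:
        return true_body(true_op)
    return false_body(false_op)

# Mirrors func7 above:
taken = cond_reference(5.0 >= 0.0,
                       5.0, lambda x: x + 3.0,
                       5.0, lambda x: x - 3.0)
print(taken)  # 8.0
```

The primitive form exists so that both branches are staged out as sub-jaxprs and the choice is made at run time on the device, instead of being resolved during tracing like a plain Python if.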
While¶
Just like for conditionals, Python loops are inlined during tracing. If you want to capture a loop for dynamic execution, you must use one of several special operations, jax.lax.while_loop() (a primitive) and jax.lax.fori_loop() (a helper that generates a while_loop primitive):
lax.while_loop(cond_fun: (C -> bool), body_fun: (C -> C), init: C) -> C
lax.fori_loop(start: int, end: int, body: (int -> C -> C), init: C) -> C
In the above signatures, "C" stands for the type of the loop "carry" value. For example, here is an example fori loop:
def func10(arg, n):
  ones = jnp.ones(arg.shape)  # A constant
  return lax.fori_loop(0, n,
                       lambda i, carry: carry + ones * 3. + arg,
                       arg + ones)

print(api.make_jaxpr(func10)(onp.ones(16), 5))
{ lambda c d ; a b.
let e = add a d
f g h = while[ body_jaxpr={ lambda ; e g a b c.
let d = add a 1
f = add c e
h = add f g
in (d, b, h) }
body_nconsts=2
cond_jaxpr={ lambda ; a b c.
let d = lt a b
in d }
cond_nconsts=0 ] c a 0 b e
in h }
The top-level jaxpr has two constvars: c (corresponding to ones * 3. from the body of the loop) and d (corresponding to the use of ones in the initial carry). There are also two input variables (a corresponding to arg and b corresponding to n).
The loop carry consists of three values, as seen in the body of cond_jaxpr (corresponding to the iteration index, the iteration end, and the accumulated carry value).
Note that body_jaxpr takes 5 input variables. The first two are actually constvars: e corresponding to ones * 3. and g corresponding to the captured use of arg in the loop body. The parameter body_nconsts = 2 specifies that there are 2 constants for the body_jaxpr. The other 3 input variables for body_jaxpr correspond to the flattened carry values.
The while primitive takes 5 arguments, c a 0 b e, as follows:
- 0 constants for cond_jaxpr (since cond_nconsts is 0),
- 2 constants for body_jaxpr (c and a),
- 3 parameters for the initial value of the carry.
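The two loop signatures above have simple reference semantics in ordinary Python (a sketch; the real primitives stage the body out as a sub-jaxpr and run it on the device). In particular, fori_loop can be expressed as a while_loop whose carry also holds the index, which matches the three-value carry seen in the jaxpr:

```python
def while_loop_reference(cond_fun, body_fun, init):
    # Repeatedly apply body_fun to the carry while cond_fun holds.
    carry = init
    while cond_fun(carry):
        carry = body_fun(carry)
    return carry

def fori_loop_reference(start, end, body, init):
    # fori_loop is a while_loop whose carry is (index, value).
    def cond_fun(state):
        i, _ = state
        return i < end
    def body_fun(state):
        i, carry = state
        return i + 1, body(i, carry)
    _, carry = while_loop_reference(cond_fun, body_fun, (start, init))
    return carry

print(fori_loop_reference(0, 5, lambda i, c: c + i, 0))  # 0+1+2+3+4 = 10
```

(The staged version additionally threads the loop bound through the carry, which is why cond_jaxpr above sees three values.)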
Scan¶
JAX supports a special form of loop over the elements of an array (with statically known shape). The fact that there is a fixed number of iterations makes this form of looping easily reverse-differentiable. Such loops are constructed with the jax.lax.scan() operator:
lax.scan(body_fun: (C -> A -> (C, B)), init_carry: C, in_arr: Array[A]) -> (C, Array[B])
Here C is the type of the scan carry, A is the element type of the input array(s), and B is the element type of the output array(s).
For the example, consider the function func11 below:
def func11(arr, extra):
  ones = jnp.ones(arr.shape)  # A constant

  def body(carry, aelems):
    # carry: running dot-product of the two arrays
    # aelems: a pair with corresponding elements from the two arrays
    ae1, ae2 = aelems
    return (carry + ae1 * ae2 + extra, carry)

  return lax.scan(body, 0., (arr, ones))

print(api.make_jaxpr(func11)(onp.ones(16), 5.))
{ lambda c ; a b.
let d e = scan[ forward=True
jaxpr={ lambda ; a b c d e.
let f = mul c e
g = add b f
h = add g a
in (h, b) }
length=16
linear=(False, False, False, True, False)
num_carry=1
num_consts=1 ] b 0.0 a * c
in (d, e) }
The top-level jaxpr has one constvar c corresponding to the ones constant, and two input variables corresponding to the arguments arr and extra.
The body of the scan has 5 input variables, of which:
- one (a) is a constant (since num_consts = 1), and stands for the captured variable extra used in the loop body,
- one (b) is the value of the carry (since num_carry = 1),
- the remaining 3 are the input values. Notice that only c and e are used; they stand respectively for the array element from the first array passed to lax.scan (arr) and from the second array (ones). The input variable d seems to be an artifact of the translation.
The linear
parameter describes for each of the input variables whether they
are guaranteed to be used linearly in the body. Here, only the unused input
variable is marked linear. Once the scan goes through linearization, more arguments
will be linear.
The scan primitive takes 5 arguments, b 0.0 a * c, of which:
- one is the free variable for the body,
- one is the initial value of the carry,
- the next 3 are the arrays over which the scan operates; the middle one is not used (*).
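The scan signature above corresponds to the following pure-Python reference semantics, sketched over lists rather than arrays (scan_reference is an invented name; the real primitive stages the body out and stacks the outputs into an array):

```python
def scan_reference(body_fun, init_carry, xs):
    # Thread the carry through body_fun over the elements of xs,
    # collecting the per-element outputs.
    carry = init_carry
    ys = []
    for x in xs:
        carry, y = body_fun(carry, x)
        ys.append(y)
    return carry, ys

# Running dot product, as in func11 (with extra = 0 and ones all 1.):
carry, ys = scan_reference(lambda c, ab: (c + ab[0] * ab[1], c),
                           0.0,
                           list(zip([1.0, 2.0, 3.0], [1.0, 1.0, 1.0])))
print(carry, ys)  # 6.0 [0.0, 1.0, 3.0]
```

The final carry is the scan result and ys is the stacked per-iteration output, matching the (C, Array[B]) return type in the signature.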
XLA_call¶
The call primitive arises from JIT compilation, and it encapsulates a sub-jaxpr along with parameters that specify the backend and the device on which the computation should run. For example:
def func12(arg):
  @api.jit
  def inner(x):
    return x + arg * jnp.ones(1)  # Include a constant in the inner function
  return arg + inner(arg - 2.)

print(api.make_jaxpr(func12)(1.))
{ lambda b ; a.
let c = sub a 2.0
d = xla_call[ backend=None
call_jaxpr={ lambda ; c b a.
let d = mul b c
e = add a d
in e }
device=None
name=inner ] b a c
e = add a d
in e }
The top-level constvar b refers to the jnp.ones(1) constant, and the top-level input variable a refers to the arg parameter of func12.
The xla_call primitive stands for a call to the jitted inner function. The primitive has the function body in the call_jaxpr parameter, a jaxpr with 3 input parameters:
- c is a constvar and stands for the ones constant,
- b corresponds to the free variable arg captured in the inner function,
- a corresponds to the inner parameter x.
The primitive takes three arguments b a c.
XLA_pmap¶
If you use the jax.pmap() transformation, the function to be mapped is captured using the xla_pmap primitive. Consider this example:
def func13(arr, extra):
  def inner(x):
    # use a free variable "extra" and a constant jnp.ones(1)
    return (x + extra + jnp.ones(1)) / lax.psum(x, axis_name='rows')
  return api.pmap(inner, axis_name='rows')(arr)

print(api.make_jaxpr(func13)(jnp.ones((1, 3)), 5.))
{ lambda c ; a b.
let d = xla_pmap[ axis_name=rows
axis_size=1
backend=None
call_jaxpr={ lambda ; d b a.
let c = add a b
e = add c d
f = psum[ axis_name=rows ] a
g = div e f
in g }
devices=None
global_axis_size=None
mapped_invars=(True, False, True)
name=inner ] c b a
in d }
The top-level constvar c refers to the jnp.ones(1) constant.
The xla_pmap primitive specifies the name of the axis (parameter rows) and the body of the function to be mapped as the call_jaxpr parameter. The value of this parameter is a Jaxpr with 3 input variables:
- d stands for the constant jnp.ones(1),
- b stands for the free variable extra,
- a stands for the parameter x of inner.
The parameter mapped_invars specifies which of the input variables should be mapped and which should be broadcast. In our example, the value of extra is broadcast and the other input values are mapped.
Asynchronous dispatch¶
JAX uses asynchronous dispatch to hide Python overheads. Consider the following program:
>>> import numpy as onp
>>> from jax import numpy as np
>>> from jax import random
>>> x = random.uniform(random.PRNGKey(0), (1000, 1000))
>>> np.dot(x, x) + 3.
DeviceArray([[258.01971436, 249.64862061, 257.13372803, ...,
236.67948914, 250.68939209, 241.36853027],
[265.65979004, 256.28912354, 262.18252563, ...,
242.03181458, 256.16757202, 252.44122314],
[262.38916016, 255.72747803, 261.23059082, ...,
240.83563232, 255.41094971, 249.62471008],
...,
[259.15814209, 253.09197998, 257.72174072, ...,
242.23876953, 250.72680664, 247.16642761],
[271.22662354, 261.91204834, 265.33398438, ...,
248.26651001, 262.05389404, 261.33700562],
[257.16134644, 254.7543335, 259.08300781, ..., 241.59848022,
248.62597656, 243.22348022]], dtype=float32)
When an operation such as np.dot(x, x)
is executed, JAX does not wait
for the operation to complete before returning control to the Python program.
Instead, JAX returns a DeviceArray
value, which is a future,
i.e., a value that will be produced in the future on an accelerator device but
isn't necessarily available immediately. We can inspect the shape or type of a
DeviceArray
without waiting for the computation that produced it to
complete, and we can even pass it to another JAX computation, as we do with the
addition operation here. Only if we actually inspect the value of the array from
the host, for example by printing it or by converting it into a plain old
numpy.ndarray
will JAX force the Python code to wait for the
computation to complete.
Asynchronous dispatch is useful since it allows Python code to "run ahead" of an accelerator device, keeping Python code out of the critical path. Provided the Python code enqueues work on the device faster than it can be executed, and provided that the Python code does not actually need to inspect the output of a computation on the host, then a Python program can enqueue arbitrary amounts of work and avoid having the accelerator wait.
Asynchronous dispatch has a slightly surprising consequence for microbenchmarks.
>>> %time np.dot(x, x)
CPU times: user 267 µs, sys: 93 µs, total: 360 µs
Wall time: 269 µs
DeviceArray([[255.01972961, 246.64862061, 254.13371277, ...,
233.67948914, 247.68939209, 238.36853027],
[262.65979004, 253.28910828, 259.18252563, ...,
239.03181458, 253.16757202, 249.44122314],
[259.38916016, 252.72747803, 258.23059082, ...,
237.83563232, 252.41094971, 246.62471008],
...,
[256.15814209, 250.09197998, 254.72172546, ...,
239.23876953, 247.72680664, 244.16642761],
[268.22662354, 258.91204834, 262.33398438, ...,
245.26651001, 259.05389404, 258.33700562],
[254.16134644, 251.7543335, 256.08300781, ..., 238.59848022,
245.62597656, 240.22348022]], dtype=float32)
269 µs is a surprisingly small time for a 1000x1000 matrix multiplication on CPU! However, it turns out that asynchronous dispatch is misleading us: we are not timing the execution of the matrix multiplication, only the time to dispatch the work. To measure the true cost of the operation we must either read the value on the host (e.g., convert it to a plain old host-side numpy array), or use the block_until_ready() method on a DeviceArray value to wait for the computation that produced it to complete.
>>> %time onp.asarray(np.dot(x, x))
CPU times: user 61.1 ms, sys: 0 ns, total: 61.1 ms
Wall time: 8.09 ms
Out[16]:
array([[255.01973, 246.64862, 254.13371, ..., 233.67949, 247.68939,
238.36853],
[262.6598 , 253.28911, 259.18253, ..., 239.03181, 253.16757,
249.44122],
[259.38916, 252.72748, 258.2306 , ..., 237.83563, 252.41095,
246.62471],
...,
[256.15814, 250.09198, 254.72173, ..., 239.23877, 247.7268 ,
244.16643],
[268.22662, 258.91205, 262.33398, ..., 245.26651, 259.0539 ,
258.337 ],
[254.16135, 251.75433, 256.083 , ..., 238.59848, 245.62598,
240.22348]], dtype=float32)
>>> %time np.dot(x, x).block_until_ready()
CPU times: user 50.3 ms, sys: 928 µs, total: 51.2 ms
Wall time: 4.92 ms
DeviceArray([[255.01972961, 246.64862061, 254.13371277, ...,
233.67948914, 247.68939209, 238.36853027],
[262.65979004, 253.28910828, 259.18252563, ...,
239.03181458, 253.16757202, 249.44122314],
[259.38916016, 252.72747803, 258.23059082, ...,
237.83563232, 252.41094971, 246.62471008],
...,
[256.15814209, 250.09197998, 254.72172546, ...,
239.23876953, 247.72680664, 244.16642761],
[268.22662354, 258.91204834, 262.33398438, ...,
245.26651001, 259.05389404, 258.33700562],
[254.16134644, 251.7543335, 256.08300781, ..., 238.59848022,
245.62597656, 240.22348022]], dtype=float32)
Blocking without transferring the result back to Python is usually faster, and is often the best choice when writing microbenchmarks of computation times.
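For instance, a small timing helper along these lines (a sketch; the helper name and iteration count are ours, not part of JAX) measures a computation while forcing completion, without paying for a device-to-host transfer:

```python
import timeit

def benchmark(fn, *args, number=10):
    """Average wall-clock time of fn(*args), forcing the result to be
    computed via block_until_ready() rather than copying it to the host."""
    def run():
        fn(*args).block_until_ready()
    return timeit.timeit(run, number=number) / number

# Usage (assuming jax.numpy as np and an array x, as above):
#   benchmark(np.dot, x, x)
```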
Concurrency

JAX has some limited support for Python concurrency. Concurrency support is experimental and only lightly tested; please report any bugs.

Clients may call JAX APIs (e.g., jit() or grad()) concurrently from separate Python threads.

It is not permitted to manipulate JAX trace values concurrently from multiple threads. In other words, while it is permissible to call functions that use JAX tracing (e.g., jit()) from multiple threads, you must not use threading to manipulate JAX values inside the implementation of the function f that is passed to jit(). The most likely outcome if you do this is a mysterious error from JAX.
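As a sketch of the permitted pattern: each thread calls the already-transformed function with its own inputs, and no thread manipulates traced values inside the function body. Here a plain Python function stands in for a jit()-compiled one so the sketch runs without an accelerator:

```python
import concurrent.futures

# In real code this would be, e.g.: f = jit(lambda x: x * 2)
# Calling the transformed f from multiple threads is permitted;
# using threads *inside* f to manipulate traced values is not.
def f(x):
    return x * 2

with concurrent.futures.ThreadPoolExecutor(max_workers=4) as pool:
    results = list(pool.map(f, range(8)))
print(results)  # [0, 2, 4, 6, 8, 10, 12, 14]
```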
GPU memory allocation

JAX will preallocate 90% of currently-available GPU memory when the first JAX operation is run. Preallocating minimizes allocation overhead and memory fragmentation, but can sometimes cause out-of-memory (OOM) errors. If your JAX process fails with OOM, the following environment variables can be used to override the default behavior:

XLA_PYTHON_CLIENT_PREALLOCATE=false
This disables the preallocation behavior. JAX will instead allocate GPU memory as needed, potentially decreasing the overall memory usage. However, this behavior is more prone to GPU memory fragmentation, meaning a JAX program that uses most of the available GPU memory may OOM with preallocation disabled.

XLA_PYTHON_CLIENT_MEM_FRACTION=.XX
If preallocation is enabled, this makes JAX preallocate XX% of currently-available GPU memory, instead of the default 90%. Lowering the amount preallocated can fix OOMs that occur when the JAX program starts.

XLA_PYTHON_CLIENT_ALLOCATOR=platform
This makes JAX allocate exactly what is needed on demand, and deallocate memory that is no longer needed (note that this is the only configuration that will deallocate GPU memory, instead of reusing it). This is very slow, so it is not recommended for general use, but may be useful for running with the minimal possible GPU memory footprint or debugging OOM failures.
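These variables are set in the environment of the JAX process; for example (the script name here is a placeholder):

```shell
# Disable preallocation entirely:
XLA_PYTHON_CLIENT_PREALLOCATE=false python my_jax_program.py

# Or preallocate only 50% of currently-available GPU memory:
XLA_PYTHON_CLIENT_MEM_FRACTION=.50 python my_jax_program.py
```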
Common causes of OOM failures

- Running multiple JAX processes concurrently. Either use XLA_PYTHON_CLIENT_MEM_FRACTION to give each process an appropriate amount of memory, or set XLA_PYTHON_CLIENT_PREALLOCATE=false.
- Running JAX and GPU TensorFlow concurrently. TensorFlow also preallocates by default, so this is similar to running multiple JAX processes concurrently. One solution is to use CPU-only TensorFlow (e.g. if you're only doing data loading with TF). You can prevent TensorFlow from using the GPU with the command tf.config.experimental.set_visible_devices([], "GPU"). Alternatively, use XLA_PYTHON_CLIENT_MEM_FRACTION or XLA_PYTHON_CLIENT_PREALLOCATE. There are also similar options to configure TensorFlow's GPU memory allocation (gpu_memory_fraction and allow_growth in TF1, which should be set in a tf.ConfigProto passed to tf.Session; see "Using GPUs: Limiting GPU memory growth" for TF2).
- Running JAX on the display GPU. Use XLA_PYTHON_CLIENT_MEM_FRACTION or XLA_PYTHON_CLIENT_PREALLOCATE.
Profiling JAX programs

To profile JAX programs, there are currently two options: nvprof and XLA's profiling features.

nvprof

Nvidia's nvprof tool can be used to trace and profile JAX code on GPU. For details, see the nvprof documentation.

XLA profiling

XLA has some built-in support for profiling on both CPU and GPU. To use XLA's profiling features from JAX, set the environment variables TF_CPP_MIN_LOG_LEVEL=0 and XLA_FLAGS=--xla_hlo_profile. XLA will log profiling information about each computation JAX runs. For example:
$ TF_CPP_MIN_LOG_LEVEL=0 XLA_FLAGS=--xla_hlo_profile ipython
...
In [1]: from jax import lax
lax.add
In [2]: lax.add(1, 2)
2019-08-08 20:47:52.659030: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7fe2c719e200 executing computations on platform Host. Devices:
2019-08-08 20:47:52.659054: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:175] StreamExecutor device (0): Host, Default Version
/Users/phawkins/p/jax/jax/lib/xla_bridge.py:114: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
2019-08-08 20:47:52.674813: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] Execution profile for primitive_computation.4: (0.0324 us @ f_nom)
2019-08-08 20:47:52.674832: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] 94 cycles (100.% 100Σ) :: 0.0 usec ( 0.0 optimal) :: 30.85MFLOP/s :: :: 353.06MiB/s :: 0.128B/cycle :: [total] [entry]
2019-08-08 20:47:52.674838: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] 94 cycles (100.00% 100Σ) :: 0.0 usec ( 0.0 optimal) :: 30.85MFLOP/s :: :: 353.06MiB/s :: 0.128B/cycle :: %add.3 = s32[] add(s32[] %parameter.1, s32[] %parameter.2)
2019-08-08 20:47:52.674842: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174]
2019-08-08 20:47:52.674846: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] ********** microseconds report **********
2019-08-08 20:47:52.674909: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] There are 0 microseconds in total.
2019-08-08 20:47:52.674921: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] There are 0 microseconds ( 0.00%) not accounted for by the data.
2019-08-08 20:47:52.674925: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] There are 1 ops.
2019-08-08 20:47:52.674928: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174]
2019-08-08 20:47:52.674932: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] ********** categories table for microseconds **********
2019-08-08 20:47:52.674935: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174]
2019-08-08 20:47:52.674939: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] 0 (100.00% Σ100.00%) non-fusion elementwise (1 ops)
2019-08-08 20:47:52.674942: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] * 100.00% %add.3 = s32[] add(s32[], s32[])
2019-08-08 20:47:52.675673: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174]
2019-08-08 20:47:52.675682: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174]
2019-08-08 20:47:52.675688: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] ********** MiB read+written report **********
2019-08-08 20:47:52.675692: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] There are 0 MiB read+written in total.
2019-08-08 20:47:52.675697: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] There are 0 MiB read+written ( 0.00%) not accounted for by the data.
2019-08-08 20:47:52.675700: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] There are 3 ops.
2019-08-08 20:47:52.675703: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174]
2019-08-08 20:47:52.675812: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] ********** categories table for MiB read+written **********
2019-08-08 20:47:52.675823: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174]
2019-08-08 20:47:52.675827: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] 0 (100.00% Σ100.00%) non-fusion elementwise (1 ops)
2019-08-08 20:47:52.675832: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] * 100.00% %add.3 = s32[] add(s32[], s32[])
2019-08-08 20:47:52.675835: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174] 0 ( 0.00% Σ100.00%) ... (1 more categories)
2019-08-08 20:47:52.675839: I external/org_tensorflow/tensorflow/compiler/xla/service/executable.cc:174]
Out[2]: DeviceArray(3, dtype=int32)
Rank promotion warning
NumPy broadcasting rules allow automatic promotion of arguments from one rank (number of array axes) to another. This behavior can be convenient when intended but can also lead to surprising bugs where a silent rank promotion masks an underlying shape error.
Hereâ€™s an example of rank promotion:
>>> import numpy as onp
>>> x = onp.arange(12).reshape(4, 3)
>>> y = onp.array([0, 1, 0])
>>> x + y
array([[ 0, 2, 2],
[ 3, 5, 5],
[ 6, 8, 8],
[ 9, 11, 11]])
To avoid potential surprises, jax.numpy is configurable so that expressions requiring rank promotion lead to a warning, lead to an error, or are allowed just like regular NumPy. The configuration option is named jax_numpy_rank_promotion and it can take on the string values "allow", "warn", and "raise". The default setting is "warn", which raises a warning on the first occurrence of rank promotion. The "raise" setting raises an error on rank promotion, and "allow" allows rank promotion without warning or error.
As with most other JAX configuration options, you can set this option in
several ways. One is by using jax.config
in your code:
from jax.config import config
config.update("jax_numpy_rank_promotion", "allow")
You can also set the option using the environment variable JAX_NUMPY_RANK_PROMOTION, for example as JAX_NUMPY_RANK_PROMOTION='raise'. Finally, when using absl-py the option can be set with a command-line flag.
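Under the "raise" setting, one way to keep a computation like the x + y example above is to make the broadcast explicit, so that no rank promotion is needed. A sketch in plain NumPy (jax.numpy behaves analogously):

```python
import numpy as onp

x = onp.arange(12).reshape(4, 3)
y = onp.array([0, 1, 0])

# Broadcasting y to x's shape first makes both operands rank 2,
# so the addition itself no longer requires rank promotion.
y2 = onp.broadcast_to(y, x.shape)
print(x + y2)
```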
Type promotion semantics

JAX's type promotion rules (i.e., the result of jax.numpy.promote_types() for each pair of types) are given by the following table, where, for example:

- "b1" means np.bool_,
- "i2" means np.int16,
- "u4" means np.uint32,
- "bf" means np.bfloat16,
- "f2" means np.float16, and
- "c8" means np.complex128.
    b1  u1  u2  u4  u8  i1  i2  i4  i8  bf  f2  f4  f8  c4  c8
b1  b1  u1  u2  u4  u8  i1  i2  i4  i8  bf  f2  f4  f8  c4  c8
u1  u1  u1  u2  u4  u8  i2  i2  i4  i8  bf  f2  f4  f8  c4  c8
u2  u2  u2  u2  u4  u8  i4  i4  i4  i8  bf  f2  f4  f8  c4  c8
u4  u4  u4  u4  u4  u8  i8  i8  i8  i8  bf  f2  f4  f8  c4  c8
u8  u8  u8  u8  u8  u8  f8  f8  f8  f8  bf  f2  f4  f8  c4  c8
i1  i1  i2  i4  i8  f8  i1  i2  i4  i8  bf  f2  f4  f8  c4  c8
i2  i2  i2  i4  i8  f8  i2  i2  i4  i8  bf  f2  f4  f8  c4  c8
i4  i4  i4  i4  i8  f8  i4  i4  i4  i8  bf  f2  f4  f8  c4  c8
i8  i8  i8  i8  i8  f8  i8  i8  i8  i8  bf  f2  f4  f8  c4  c8
bf  bf  bf  bf  bf  bf  bf  bf  bf  bf  bf  f4  f4  f8  c4  c8
f2  f2  f2  f2  f2  f2  f2  f2  f2  f2  f4  f2  f4  f8  c4  c8
f4  f4  f4  f4  f4  f4  f4  f4  f4  f4  f4  f4  f4  f8  c4  c8
f8  f8  f8  f8  f8  f8  f8  f8  f8  f8  f8  f8  f8  f8  c8  c8
c4  c4  c4  c4  c4  c4  c4  c4  c4  c4  c4  c4  c4  c8  c4  c8
c8  c8  c8  c8  c8  c8  c8  c8  c8  c8  c8  c8  c8  c8  c8  c8
JAX's type promotion rules differ from those of NumPy, as given by numpy.promote_types(), in the cells highlighted with a green background in the table above. There are two key differences:

- When promoting an integer or boolean type against a floating-point or complex type, JAX always prefers the type of the floating-point or complex type. Accelerator devices, such as GPUs and TPUs, either pay a significant performance penalty to use 64-bit floating-point types (GPUs) or do not support 64-bit floating-point types at all (TPUs). Classic NumPy's promotion rules are too willing to over-promote to 64-bit types, which is problematic for a system designed to run on accelerators. JAX instead uses floating-point promotion rules that are more suited to modern accelerator devices and are less aggressive about promoting floating-point types; they are similar to those used by PyTorch.
- JAX supports the bfloat16 non-standard 16-bit floating-point type (jax.numpy.bfloat16), which is useful for neural network training. The only notable promotion behavior is with respect to IEEE-754 float16, with which bfloat16 promotes to a float32.
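The difference is easy to see for an int32/float16 pair: NumPy widens all the way to float64, whereas the table above (row "i4", column "f2") gives f2. The NumPy half of this sketch is runnable as-is:

```python
import numpy as onp

# Classic NumPy promotes int32 + float16 all the way to a 64-bit float:
print(onp.promote_types(onp.int32, onp.float16))  # float64

# By contrast, per the table above, jax.numpy.promote_types(np.int32,
# np.float16) gives float16, preferring the floating-point operand's type.
```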
Building from source
First, obtain the JAX source code:
git clone https://github.com/google/jax
cd jax
Building JAX involves two steps:

1. Building or installing jaxlib, the C++ support library for jax.
2. Installing the jax Python package.
Building or installing jaxlib

Installing jaxlib with pip

If you're only modifying Python portions of JAX, we recommend installing jaxlib from a prebuilt wheel using pip:
pip install jaxlib
See the JAX readme for full guidance on pip installation (e.g., for GPU support).
Building jaxlib from source

To build jaxlib from source, you must also install some prerequisites:

- a C++ compiler (g++ or clang)
- Numpy
- Scipy
- Cython
- six (required during the jaxlib build only, not required at install time)

On Ubuntu 18.04 or Debian you can install the necessary prerequisites with:

sudo apt-get install g++ python python3-dev python3-numpy python3-scipy cython3 python3-six

If you are building on a Mac, make sure XCode and the XCode command line tools are installed.

You can also install the necessary Python dependencies using pip:

pip install numpy scipy cython six
To build jaxlib with CUDA support, you can run:

python build/build.py --enable_cuda
pip install -e build  # installs jaxlib (includes XLA)

See python build/build.py --help for configuration options, including ways to specify the paths to CUDA and CUDNN, which you must have installed. Here python should be the name of your Python 3 interpreter; on some systems, you may need to use python3 instead.

To build jaxlib without CUDA GPU support (CPU only), drop the --enable_cuda:

python build/build.py
pip install -e build  # installs jaxlib (includes XLA)
Installing jax

Once jaxlib has been installed, you can install jax by running:

pip install -e .  # installs jax

To upgrade to the latest version from GitHub, just run git pull from the JAX repository root, and rebuild by running build.py or upgrading jaxlib if necessary. You shouldn't have to reinstall jax because pip install -e sets up symbolic links from site-packages into the repository.
Running the tests

To run all the JAX tests, we recommend using pytest-xdist, which can run tests in parallel. First, install pytest-xdist and pytest-benchmark by running pip install pytest-xdist pytest-benchmark.

Then, from the repository root directory run:

pytest -n auto tests

JAX generates test cases combinatorially, and you can control the number of cases that are generated and checked for each test (default is 10). The automated tests currently use 25:

JAX_NUM_GENERATED_CASES=25 pytest -n auto tests

The automated tests also run the tests with default 64-bit floats and ints:

JAX_ENABLE_X64=1 JAX_NUM_GENERATED_CASES=25 pytest -n auto tests

You can run a more specific set of tests using pytest's built-in selection mechanisms, or alternatively you can run a specific test file directly to see more detailed information about the cases being run:

python tests/lax_numpy_test.py --num_generated_cases=5

You can skip a number of known-slow tests by setting the environment variable JAX_SKIP_SLOW_TESTS=1.
The Colab notebooks are tested for errors as part of the documentation build.
Update documentation

To rebuild the documentation, install several packages:

pip install -r docs/requirements.txt

You must also install pandoc in order to regenerate the notebooks. See Install Pandoc (https://pandoc.org/installing.html), or use Miniconda (https://docs.conda.io/en/latest/miniconda.html), which I have used successfully on the Mac: conda install -c conda-forge pandoc.

If you do not want to install pandoc, you should regenerate the documentation without the notebooks. Run one of the following commands at top level:

sphinx-build -b html docs docs/build/html  # with the notebooks
sphinx-build -b html -D nbsphinx_execute=never docs docs/build/html  # without the notebooks

You can then see the generated documentation in docs/build/html/index.html.
Update notebooks

Open the notebook with http://colab.research.google.com (then Upload from your local repo), update it as needed, Run all cells, then Download ipynb. You may want to test that it executes properly, using sphinx-build as explained above.
Some of the notebooks are built automatically as part of the Travis presubmit checks and as part of the [Read the docs](https://jax.readthedocs.io/en/latest) build. The build will fail if cells raise errors. If the errors are intentional, you can either catch them, or tag the cell with raises-exceptions metadata ([example PR](https://github.com/google/jax/pull/2402/files)). You have to add this metadata by hand in the .ipynb file. It will be preserved when somebody else re-saves the notebook.
We exclude some notebooks from the build, e.g., because they contain long computations. See exclude_patterns in [conf.py](https://github.com/google/jax/blob/master/docs/conf.py).
Documentation building on readthedocs.io

JAX's auto-generated documentation is at jax.readthedocs.io.

The documentation building is controlled for the entire project by the readthedocs JAX settings. The current settings trigger a documentation build as soon as code is pushed to the GitHub master branch. For each code version, the building process is driven by the .readthedocs.yml and the docs/conf.py configuration files.

For each automated documentation build you can see the documentation build logs.

If you want to test the documentation generation on Readthedocs, you can push code to the test-docs branch. That branch is also built automatically, and you can see the generated documentation here.
For a local test, I was able to do it in a fresh directory by replaying the commands I saw in the Readthedocs logs:

mkvirtualenv jax-docs  # A new virtualenv
mkdir jax-docs  # A new directory
cd jax-docs
git clone --no-single-branch --depth 50 https://github.com/google/jax
cd jax
git checkout --force origin/test-docs
git clean -d -f -f
python -m pip install --upgrade --no-cache-dir pip
python -m pip install --upgrade --no-cache-dir -I Pygments==2.3.1 setuptools==41.0.1 docutils==0.14 mock==1.0.1 pillow==5.4.1 alabaster>=0.7,<0.8,!=0.7.5 commonmark==0.8.1 recommonmark==0.5.0 'sphinx<2' 'sphinx-rtd-theme<0.5' 'readthedocs-sphinx-ext<1.1'
python -m pip install --exists-action=w --no-cache-dir -r docs/requirements.txt
python `which sphinx-build` -T -E -b html -d _build/doctrees-readthedocs -D language=en . _build/html
Internal APIs

core

Jaxpr (constvars, invars, outvars, eqns)
TypedJaxpr (jaxpr, literals, in_avals, out_avals)

Public API: jax package

Subpackages

jax.numpy package
Implements the NumPy API, using the primitives in jax.lax.

While JAX tries to follow the NumPy API as closely as possible, sometimes JAX cannot follow NumPy exactly.

- Notably, since JAX arrays are immutable, NumPy APIs that mutate arrays in-place cannot be implemented in JAX. However, often JAX is able to provide an alternative API that is purely functional. For example, instead of in-place array updates (x[i] = y), JAX provides an alternative pure indexed update function jax.ops.index_update().
- NumPy is very aggressive at promoting values to float64 type. JAX sometimes is less aggressive about type promotion.
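Conceptually, such a pure indexed update behaves like the following plain-NumPy sketch (the helper name is ours; it only illustrates that the input array is never mutated):

```python
import numpy as onp

def index_update_sketch(x, idx, y):
    """Return a copy of x with x[idx] replaced by y; x itself is untouched."""
    out = x.copy()   # never mutate the input
    out[idx] = y
    return out

x = onp.zeros(3)
x2 = index_update_sketch(x, 1, 5.0)
print(x)   # [0. 0. 0.]  -- original unchanged
print(x2)  # [0. 5. 0.]  -- updated copy
```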
A small number of NumPy operations that have data-dependent output shapes are incompatible with jax.jit() compilation. The XLA compiler requires that shapes of arrays be known at compile time. While it would be possible to provide a JAX implementation of an API such as numpy.nonzero(), we would be unable to JIT-compile it because the shape of its output depends on the contents of the input data.
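A quick NumPy illustration of that data dependence: two inputs with identical shapes give outputs of different shapes, so no single compile-time output shape exists:

```python
import numpy as onp

a = onp.array([0, 1, 0, 2])
b = onp.array([1, 1, 1, 1])

# Same input shape (4,), but the output shape depends on the values:
print(onp.nonzero(a)[0].shape)  # (2,) -- two nonzero entries
print(onp.nonzero(b)[0].shape)  # (4,) -- four nonzero entries
```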
Not every function in NumPy is implemented; contributions are welcome!
abs (x) 
Calculate the absolute value elementwise. 
absolute (x) 
Calculate the absolute value elementwise. 
add (x1, x2) 
Add arguments elementwise. 
all (a[, axis, dtype, out, keepdims]) 
Test whether all array elements along a given axis evaluate to True. 
allclose (a, b[, rtol, atol]) 
Returns True if two arrays are elementwise equal within a tolerance. 
alltrue (a[, axis, dtype, out, keepdims]) 
Test whether all array elements along a given axis evaluate to True. 
amax (a[, axis, dtype, out, keepdims]) 
Return the maximum of an array or maximum along an axis. 
amin (a[, axis, dtype, out, keepdims]) 
Return the minimum of an array or minimum along an axis. 
angle (z) 
Return the angle of the complex argument. 
any (a[, axis, dtype, out, keepdims]) 
Test whether any array element along a given axis evaluates to True. 
append (arr, values[, axis]) 
Append values to the end of an array. 
arange (start[, stop, step, dtype]) 
Return evenly spaced values within a given interval. 
arccos (x) 
Trigonometric inverse cosine, elementwise. 
arccosh (x) 
Inverse hyperbolic cosine, elementwise. 
arcsin (x) 
Inverse sine, elementwise. 
arcsinh (x) 
Inverse hyperbolic sine elementwise. 
arctan (x) 
Trigonometric inverse tangent, elementwise. 
arctan2 (x1, x2) 
Elementwise arc tangent of x1/x2 choosing the quadrant correctly. 
arctanh (x) 
Inverse hyperbolic tangent elementwise. 
argmax (a[, axis]) 
Returns the indices of the maximum values along an axis. 
argmin (a[, axis]) 
Returns the indices of the minimum values along an axis. 
argsort (a[, axis, kind, order]) 
Returns the indices that would sort an array. 
around (a[, decimals]) 
Round an array to the given number of decimals. 
array (object[, dtype, copy, order, ndmin]) 
Create an array. 
array_repr (arr[, max_line_width, precision, â€¦]) 
Return the string representation of an array. 
array_str (a[, max_line_width, precision, â€¦]) 
Return a string representation of the data in an array. 
asarray (a[, dtype, order]) 
Convert the input to an array. 
atleast_1d (*arys) 
Convert inputs to arrays with at least one dimension. 
atleast_2d (*arys) 
View inputs as arrays with at least two dimensions. 
atleast_3d (*arys) 
View inputs as arrays with at least three dimensions. 
bartlett (*args, **kwargs) 
Return the Bartlett window. 
bitwise_and (x1, x2) 
Compute the bitwise AND of two arrays elementwise. 
bitwise_not (x) 
Compute bitwise inversion, or bitwise NOT, elementwise. 
bitwise_or (x1, x2) 
Compute the bitwise OR of two arrays elementwise. 
bitwise_xor (x1, x2) 
Compute the bitwise XOR of two arrays elementwise. 
blackman (*args, **kwargs) 
Return the Blackman window. 
block (arrays) 
Assemble an ndarray from nested lists of blocks. 
broadcast_arrays (*args) 
Like Numpyâ€™s broadcast_arrays but doesnâ€™t return views. 
broadcast_to (arr, shape) 
Like Numpyâ€™s broadcast_to but doesnâ€™t necessarily return views. 
can_cast (from_, to[, casting]) 
Returns True if cast between data types can occur according to the casting rule. 
ceil (x) 
Return the ceiling of the input, elementwise. 
clip (a[, a_min, a_max]) 
Clip (limit) the values in an array. 
column_stack (tup) 
Stack 1D arrays as columns into a 2D array. 
concatenate (arrays[, axis]) 
Join a sequence of arrays along an existing axis. 
conj (x) 
Return the complex conjugate, elementwise. 
conjugate (x) 
Return the complex conjugate, elementwise. 
corrcoef (x[, y, rowvar, bias, ddof]) 
Return Pearson productmoment correlation coefficients. 
cos (x) 
Cosine elementwise. 
cosh (x) 
Hyperbolic cosine, elementwise. 
count_nonzero (a[, axis]) 
Counts the number of nonzero values in the array a . 
cov (m[, y, rowvar, bias, ddof, fweights, â€¦]) 
Estimate a covariance matrix, given data and weights. 
cross (a, b[, axisa, axisb, axisc, axis]) 
Return the cross product of two (arrays of) vectors. 
cumsum (a[, axis, dtype]) 
Return the cumulative sum of the elements along a given axis. 
cumprod (a[, axis, dtype]) 
Return the cumulative product of elements along a given axis. 
cumproduct (a[, axis, dtype]) 
Return the cumulative product of elements along a given axis. 
deg2rad (x) 
Convert angles from degrees to radians. 
degrees (x) 
Convert angles from radians to degrees. 
diag (v[, k]) 
Extract a diagonal or construct a diagonal array. 
diag_indices (n[, ndim]) 
Return the indices to access the main diagonal of an array. 
diagonal (a[, offset, axis1, axis2]) 
Return specified diagonals. 
divide (x1, x2) 
Returns a true division of the inputs, elementwise. 
divmod (x1, x2) 
Return elementwise quotient and remainder simultaneously. 
dot (a, b[, precision]) 
Dot product of two arrays. 
dsplit (ary, indices_or_sections) 
Split array into multiple subarrays along the 3rd axis (depth). 
dstack (tup) 
Stack arrays in sequence depth wise (along third axis). 
einsum (*operands, **kwargs) 
Evaluates the Einstein summation convention on the operands. 
equal (x1, x2) 
Return (x1 == x2) elementwise. 
empty (shape[, dtype]) 
Return a new array of given shape and type, filled with zeros. 
empty_like (x[, dtype]) 
Return an array of zeros with the same shape and type as a given array. 
exp (x) 
Calculate the exponential of all elements in the input array. 
exp2 (x) 
Calculate 2**p for all p in the input array. 
expand_dims (a, axis) 
Expand the shape of an array. 
expm1 (x) 
Calculate exp(x)  1 for all elements in the array. 
eye (N[, M, k, dtype]) 
Return a 2D array with ones on the diagonal and zeros elsewhere. 
fabs (x) 
Compute the absolute values elementwise. 
fix (x[, out]) 
Round to nearest integer towards zero. 
flip (m[, axis]) 
Reverse the order of elements in an array along the given axis. 
fliplr (m) 
Flip array in the left/right direction. 
flipud (m) 
Flip array in the up/down direction. 
float_power (x1, x2) 
First array elements raised to powers from second array, elementwise. 
floor (x) 
Return the floor of the input, elementwise. 
floor_divide (x1, x2) 
Return the largest integer smaller or equal to the division of the inputs. 
fmod (x1, x2) 
Return the elementwise remainder of division. 
full (shape, fill_value[, dtype]) 
Return a new array of given shape and type, filled with fill_value. 
full_like (a, fill_value[, dtype]) 
Return a full array with the same shape and type as a given array. 
gcd (x1, x2) 
Returns the greatest common divisor of x1 and x2 
geomspace (start, stop[, num, endpoint, â€¦]) 
Return numbers spaced evenly on a log scale (a geometric progression). 
greater (x1, x2) 
Return the truth value of (x1 > x2) elementwise. 
greater_equal (x1, x2) 
Return the truth value of (x1 >= x2) elementwise. 
hamming (*args, **kwargs) 
Return the Hamming window. 
hanning (*args, **kwargs) 
Return the Hanning window. 
heaviside (x1, x2) 
Compute the Heaviside step function. 
hsplit (ary, indices_or_sections) 
Split an array into multiple subarrays horizontally (columnwise). 
hstack (tup) 
Stack arrays in sequence horizontally (column wise). 
hypot (x1, x2) 
Given the â€ślegsâ€ť of a right triangle, return its hypotenuse. 
identity (n[, dtype]) 
Return the identity array. 
imag (val) 
Return the imaginary part of the complex argument. 
inner (a, b[, precision]) 
Inner product of two arrays. 
isclose (a, b[, rtol, atol, equal_nan]) 
Returns a boolean array where two arrays are elementwise equal within a 
iscomplex (x) 
Returns a bool array, where True if input element is complex. 
isfinite (x) 
Test elementwise for finiteness (not infinity or not Not a Number). 
isinf (x) 
Test elementwise for positive or negative infinity. 
isnan (x) 
Test elementwise for NaN and return result as a boolean array. 
isneginf (infinity, x) 
Test elementwise for negative infinity, return result as bool array. 
isposinf (infinity, x) 
Test elementwise for positive infinity, return result as bool array. 
isreal (x) 
Returns a bool array, where True if input element is real. 
isscalar (num) 
Returns True if the type of element is a scalar type. 
issubdtype (arg1, arg2) 
Returns True if first argument is a typecode lower/equal in type hierarchy. 
issubsctype (arg1, arg2) 
Determine if the first argument is a subclass of the second argument. 
ix_ (*args) 
Construct an open mesh from multiple sequences. 
kaiser (*args, **kwargs) 
Return the Kaiser window. 
kron (a, b) 
Kronecker product of two arrays. 
lcm (x1, x2) 
Returns the lowest common multiple of x1 and x2 
left_shift (x1, x2) 
Shift the bits of an integer to the left. 
less (x1, x2) 
Return the truth value of (x1 < x2) elementwise. 
less_equal (x1, x2) 
Return the truth value of (x1 =< x2) elementwise. 
linspace (start, stop[, num, endpoint, â€¦]) 
Return evenly spaced numbers over a specified interval. 
log (x) 
Natural logarithm, elementwise. 
log10 (x) 
Return the base 10 logarithm of the input array, elementwise. 
log1p (x) 
Return the natural logarithm of one plus the input array, elementwise. 
log2 (x) 
Base2 logarithm of x. 
logaddexp (x1, x2) 
Logarithm of the sum of exponentiations of the inputs. 
logaddexp2 (x1, x2) 
Logarithm of the sum of exponentiations of the inputs in base2. 
logical_and (*args) 
Compute the truth value of x1 AND x2 elementwise. 
logical_not (*args) 
Compute the truth value of NOT x elementwise. 
logical_or (*args) 
Compute the truth value of x1 OR x2 elementwise. 
logical_xor (*args) 
Compute the truth value of x1 XOR x2, elementwise. 
logspace (start, stop[, num, endpoint, base, â€¦]) 
Return numbers spaced evenly on a log scale. 
matmul (a, b[, precision]) 
Matrix product of two arrays. 
max (a[, axis, dtype, out, keepdims]) 
Return the maximum of an array or maximum along an axis. 
maximum (x1, x2) 
Elementwise maximum of array elements. 
mean (a[, axis, dtype, out, keepdims]) 
Compute the arithmetic mean along the specified axis. 
median (a[, axis, out, overwrite_input, keepdims]) 
Compute the median along the specified axis. 
meshgrid (*args, **kwargs) 
Return coordinate matrices from coordinate vectors. 
min (a[, axis, dtype, out, keepdims]) 
Return the minimum of an array or minimum along an axis. 
minimum (x1, x2) 
Elementwise minimum of array elements. 
mod (x1, x2) 
Return elementwise remainder of division. 
moveaxis (a, source, destination) 
Move axes of an array to new positions. 
msort (a) 
Return a copy of an array sorted along the first axis. 
multiply (x1, x2) 
Multiply arguments elementwise. 
nan_to_num (x[, copy]) 
Replace NaN with zero and infinity with large finite numbers (default 
nancumprod (a[, axis, dtype]) 
Return the cumulative product of array elements over a given axis treating Not a 
nancumsum (a[, axis, dtype]) 
Return the cumulative sum of array elements over a given axis treating Not a 
nanmax (a[, axis, out, keepdims]) 
Return the maximum of an array or maximum along an axis, ignoring any 
nanmin (a[, axis, out, keepdims]) 
Return minimum of an array or minimum along an axis, ignoring any NaNs. 
nanprod (a[, axis, out, keepdims]) 
Return the product of array elements over a given axis treating Not a 
nansum (a[, axis, out, keepdims]) 
Return the sum of array elements over a given axis treating Not a 
negative (x) 
Numerical negative, elementwise. 
nextafter (x1, x2) 
Return the next floatingpoint value after x1 towards x2, elementwise. 
nonzero (a) 
Return the indices of the elements that are nonzero. 
not_equal (x1, x2) 
Return (x1 != x2) elementwise. 
ones (shape[, dtype]) 
Return a new array of given shape and type, filled with ones. 
ones_like (x[, dtype]) 
Return an array of ones with the same shape and type as a given array. 
outer (a, b[, out]) 
Compute the outer product of two vectors. 
pad (array, pad_width[, mode, constant_values]) 
Pad an array. 
percentile (a, q[, axis, out, …]) 
Compute the qth percentile of the data along the specified axis. 
polyval (p, x) 
Evaluate a polynomial at specific values. 
power (x1, x2) 
First array elements raised to powers from second array, elementwise. 
positive (x) 
Numerical positive, elementwise. 
prod (a[, axis, dtype, out, keepdims]) 
Return the product of array elements over a given axis. 
product (a[, axis, dtype, out, keepdims]) 
Return the product of array elements over a given axis. 
promote_types (a, b) 
Returns the type to which a binary operation should cast its arguments. 
ptp (a[, axis, out, keepdims]) 
Range of values (maximum - minimum) along an axis. 
quantile (a, q[, axis, out, overwrite_input, …]) 
Compute the qth quantile of the data along the specified axis. 
rad2deg (x) 
Convert angles from radians to degrees. 
radians (x) 
Convert angles from degrees to radians. 
ravel (a[, order]) 
Return a contiguous flattened array. 
real (val) 
Return the real part of the complex argument. 
reciprocal (x) 
Return the reciprocal of the argument, elementwise. 
remainder (x1, x2) 
Return elementwise remainder of division. 
repeat (a, repeats[, axis]) 
Repeat elements of an array. 
reshape (a, newshape[, order]) 
Gives a new shape to an array without changing its data. 
result_type (*args) 
Returns the type that results from applying the NumPy type promotion rules to the arguments. 
right_shift (x1, x2) 
Shift the bits of an integer to the right. 
roll (a, shift[, axis]) 
Roll array elements along a given axis. 
rot90 (m[, k, axes]) 
Rotate an array by 90 degrees in the plane specified by axes. 
round (a[, decimals]) 
Round an array to the given number of decimals. 
row_stack (tup) 
Stack arrays in sequence vertically (row wise). 
select (condlist, choicelist[, default]) 
Return an array drawn from elements in choicelist, depending on conditions. 
sign (x) 
Returns an elementwise indication of the sign of a number. 
signbit (x) 
Returns elementwise True where signbit is set (less than zero). 
sin (x) 
Trigonometric sine, elementwise. 
sinc (x) 
Return the sinc function. 
sinh (x) 
Hyperbolic sine, elementwise. 
sometrue (a[, axis, dtype, out, keepdims]) 
Test whether any array element along a given axis evaluates to True. 
sort (a[, axis, kind, order]) 
Return a sorted copy of an array. 
split (ary, indices_or_sections[, axis]) 
Split an array into multiple subarrays as views into ary. 
sqrt (x) 
Return the non-negative square-root of an array, elementwise. 
square (x) 
Return the elementwise square of the input. 
squeeze (a[, axis]) 
Remove single-dimensional entries from the shape of an array. 
stack (arrays[, axis]) 
Join a sequence of arrays along a new axis. 
std (a[, axis, dtype, out, ddof, keepdims]) 
Compute the standard deviation along the specified axis. 
subtract (x1, x2) 
Subtract arguments, elementwise. 
sum (a[, axis, dtype, out, keepdims]) 
Sum of array elements over a given axis. 
swapaxes (a, axis1, axis2) 
Interchange two axes of an array. 
take (a, indices[, axis, out, mode]) 
Take elements from an array along an axis. 
take_along_axis (arr, indices, axis) 
Take values from the input array by matching 1d index and data slices. 
tan (x) 
Compute tangent elementwise. 
tanh (x) 
Compute hyperbolic tangent elementwise. 
tensordot (a, b[, axes, precision]) 
Compute tensor dot product along specified axes. 
tile (a, reps) 
Construct an array by repeating A the number of times given by reps. 
trace (a[, offset, axis1, axis2, dtype, out]) 
Return the sum along diagonals of the array. 
transpose (a[, axes]) 
Permute the dimensions of an array. 
tri (N[, M, k, dtype]) 
An array with ones at and below the given diagonal and zeros elsewhere. 
tril (m[, k]) 
Lower triangle of an array. 
tril_indices (*args, **kwargs) 
Return the indices for the lower-triangle of an (n, m) array. 
triu (m[, k]) 
Upper triangle of an array. 
triu_indices (*args, **kwargs) 
Return the indices for the upper-triangle of an (n, m) array. 
true_divide (x1, x2) 
Returns a true division of the inputs, elementwise. 
vander (x[, N, increasing]) 
Generate a Vandermonde matrix. 
var (a[, axis, dtype, out, ddof, keepdims]) 
Compute the variance along the specified axis. 
vdot (a, b[, precision]) 
Return the dot product of two vectors. 
vsplit (ary, indices_or_sections) 
Split an array into multiple subarrays vertically (row-wise). 
vstack (tup) 
Stack arrays in sequence vertically (row wise). 
where (condition[, x, y]) 
Return elements chosen from x or y depending on condition. 
zeros (shape[, dtype]) 
Return a new array of given shape and type, filled with zeros. 
zeros_like (x[, dtype]) 
Return an array of zeros with the same shape and type as a given array. 
jax.numpy.fft
fft (a[, n, axis, norm]) 
Compute the one-dimensional discrete Fourier Transform. 
ifft (a[, n, axis, norm]) 
Compute the one-dimensional inverse discrete Fourier Transform. 
fft2 (a[, s, axes, norm]) 
Compute the 2-dimensional discrete Fourier Transform. 
ifft2 (a[, s, axes, norm]) 
Compute the 2-dimensional inverse discrete Fourier Transform. 
fftn (a[, s, axes, norm]) 
Compute the N-dimensional discrete Fourier Transform. 
ifftn (a[, s, axes, norm]) 
Compute the N-dimensional inverse discrete Fourier Transform. 
rfft (a[, n, axis, norm]) 
Compute the one-dimensional discrete Fourier Transform for real input. 
irfft (a[, n, axis, norm]) 
Compute the inverse of the n-point DFT for real input. 
rfft2 (a[, s, axes, norm]) 
Compute the 2-dimensional FFT of a real array. 
irfft2 (a[, s, axes, norm]) 
Compute the 2-dimensional inverse FFT of a real array. 
rfftn (a[, s, axes, norm]) 
Compute the N-dimensional discrete Fourier Transform for real input. 
irfftn (a[, s, axes, norm]) 
Compute the inverse of the N-dimensional FFT of real input. 
fftfreq (n[, d]) 
Return the Discrete Fourier Transform sample frequencies. 
rfftfreq (n[, d]) 
Return the Discrete Fourier Transform sample frequencies (for usage with rfft, irfft). 
fftshift (x[, axes]) 
Shift the zero-frequency component to the center of the spectrum. 
ifftshift (x[, axes]) 
The inverse of fftshift. Although identical for even-length x, the functions differ by one sample for odd-length x. 
jax.numpy.linalg
cholesky (a) 
Cholesky decomposition. 
det (a) 
Compute the determinant of an array. 
eig (a) 
Compute the eigenvalues and right eigenvectors of a square array. 
eigh (a[, UPLO, symmetrize_input]) 
Return the eigenvalues and eigenvectors of a complex Hermitian (conjugate symmetric) or a real symmetric matrix. 
eigvals (a) 
Compute the eigenvalues of a general matrix. 
eigvalsh (a[, UPLO]) 
Compute the eigenvalues of a complex Hermitian or real symmetric matrix. 
inv (a) 
Compute the (multiplicative) inverse of a matrix. 
matrix_power (a, n) 
Raise a square matrix to the (integer) power n. 
matrix_rank (M[, tol]) 
Return matrix rank of array using SVD method 
norm (x[, ord, axis, keepdims]) 
Matrix or vector norm. 
pinv (a[, rcond]) 
Compute the (Moore-Penrose) pseudo-inverse of a matrix. 
qr (a[, mode]) 
Compute the qr factorization of a matrix. 
slogdet 
Compute the sign and (natural) logarithm of the determinant of an array. 
solve (a, b) 
Solve a linear matrix equation, or system of linear scalar equations. 
svd (a[, full_matrices, compute_uv]) 
Singular Value Decomposition. 
jax.scipy package
jax.scipy.linalg
block_diag (*arrs) 
Create a block diagonal matrix from provided arrays. 
cho_factor (a[, lower, overwrite_a, check_finite]) 
Compute the Cholesky decomposition of a matrix, to use in cho_solve 
cho_solve (c_and_lower, b[, overwrite_b, …]) 
Solve the linear equations A x = b, given the Cholesky factorization of A. 
cholesky (a[, lower, overwrite_a, check_finite]) 
Compute the Cholesky decomposition of a matrix. 
det (a[, overwrite_a, check_finite]) 
Compute the determinant of a matrix 
eigh (a[, b, lower, eigvals_only, …]) 
Solve an ordinary or generalized eigenvalue problem for a complex Hermitian or real symmetric matrix. 
expm (A, *[, upper_triangular]) 
Compute the matrix exponential using Pade approximation. 
inv (a[, overwrite_a, check_finite]) 
Compute the inverse of a matrix. 
lu (a[, permute_l, overwrite_a, check_finite]) 
Compute pivoted LU decomposition of a matrix. 
lu_factor (a[, overwrite_a, check_finite]) 
Compute pivoted LU decomposition of a matrix. 
lu_solve (lu_and_piv, b[, trans, …]) 
Solve an equation system, a x = b, given the LU factorization of a. 
qr (a[, overwrite_a, lwork, mode, pivoting, …]) 
Compute QR decomposition of a matrix. 
solve (a, b[, sym_pos, lower, overwrite_a, …]) 
Solves the linear equation set a * x = b for the unknown x. 
solve_triangular (a, b[, trans, lower, â€¦]) 
Solve the equation a x = b for x, assuming a is a triangular matrix. 
svd (a[, full_matrices, compute_uv, …]) 
Singular Value Decomposition. 
tril (m[, k]) 
Make a copy of a matrix with elements above the kth diagonal zeroed. 
triu (m[, k]) 
Make a copy of a matrix with elements below the kth diagonal zeroed. 
jax.scipy.ndimage
map_coordinates (input, coordinates, order[, …]) 
Map the input array to new coordinates by interpolation. 
jax.scipy.special
betainc (a, b, x) 
Incomplete beta function. 
digamma (x) 
The digamma function. 
entr (x) 
Elementwise function for computing entropy. 
erf (x) 
Returns the error function of complex argument. 
erfc (x) 
Complementary error function, 1 - erf(x). 
erfinv (x) 
Inverse of the error function erf. 
expit 
Expit (a.k.a. logistic sigmoid) ufunc for ndarrays. 
gammainc (a, x) 
Regularized lower incomplete gamma function. 
gammaincc (a, x) 
Regularized upper incomplete gamma function. 
gammaln (x) 
Logarithm of the absolute value of the Gamma function. 
i0e (x) 
Exponentially scaled modified Bessel function of order 0. 
i1e (x) 
Exponentially scaled modified Bessel function of order 1. 
log_ndtr 
Log Normal distribution function. 
logit 
Logit ufunc for ndarrays. 
logsumexp (a[, axis, b, keepdims, return_sign]) 
Compute the log of the sum of exponentials of input elements. 
multigammaln (a, d) 
Returns the log of multivariate gamma, also sometimes called the generalized gamma. 
ndtr (x) 
Normal distribution function. 
ndtri (p) 
The inverse of the CDF of the Normal distribution function. 
xlog1py (x, y) 
Compute x*log1p(y) so that the result is 0 if x = 0 . 
xlogy (x, y) 
Compute x*log(y) so that the result is 0 if x = 0 . 
jax.scipy.stats
jax.scipy.stats.beta
logpdf (x, a, b[, loc, scale]) 
Log of the probability density function at x of the given RV. 
pdf (x, a, b[, loc, scale]) 
Probability density function at x of the given RV. 
jax.scipy.stats.expon
logpdf (x[, loc, scale]) 
Log of the probability density function at x of the given RV. 
pdf (x[, loc, scale]) 
Probability density function at x of the given RV. 
jax.scipy.stats.gamma
logpdf (x, a[, loc, scale]) 
Log of the probability density function at x of the given RV. 
pdf (x, a[, loc, scale]) 
Probability density function at x of the given RV. 
jax.scipy.stats.laplace
cdf (x[, loc, scale]) 
Cumulative distribution function of the given RV. 
logpdf (x[, loc, scale]) 
Log of the probability density function at x of the given RV. 
pdf (x[, loc, scale]) 
Probability density function at x of the given RV. 
jax.scipy.stats.logistic
cdf (x) 
Cumulative distribution function of the given RV. 
isf (x) 
Inverse survival function (inverse of sf) at q of the given RV. 
logpdf (x) 
Log of the probability density function at x of the given RV. 
pdf (x) 
Probability density function at x of the given RV. 
ppf (x) 
Percent point function (inverse of cdf) at q of the given RV. 
sf (x) 
Survival function (1 - cdf) at x of the given RV. 
jax.scipy.stats.norm
cdf (x[, loc, scale]) 
Cumulative distribution function of the given RV. 
logcdf (x[, loc, scale]) 
Log of the cumulative distribution function at x of the given RV. 
logpdf (x[, loc, scale]) 
Log of the probability density function at x of the given RV. 
pdf (x[, loc, scale]) 
Probability density function at x of the given RV. 
ppf (q[, loc, scale]) 
Percent point function (inverse of cdf) at q of the given RV. 
jax.experimental package
jax.experimental.loops module
Loops is an experimental module that provides syntactic sugar for loops and control-flow.
The current implementation should convert loops correctly to the JAX internal representation, and most transformations should work (see below), but we have not yet fine-tuned the performance of the resulting XLA compilation!
By default, loops and control-flow in JAX are executed and inlined during tracing. For example, in the following code the for loop is unrolled during JAX tracing:
arr = onp.zeros(5)
for i in range(arr.shape[0]):
    arr[i] += 2.
    if i % 2 == 0:
        arr[i] += 1.
In order to capture the structured control-flow one has to use the higher-order JAX operations, which require you to express the bodies of the loops and conditionals as functions, and the array updates using a functional style that returns an updated array, e.g.:
arr = onp.zeros(5)
def loop_body(i, acc_arr):
    arr1 = ops.index_update(acc_arr, i, acc_arr[i] + 2.)
    return lax.cond(i % 2 == 0,
                    arr1,
                    lambda arr1: ops.index_update(arr1, i, arr1[i] + 1.),
                    arr1,
                    lambda arr1: arr1)
arr = lax.fori_loop(0, arr.shape[0], loop_body, arr)
The default notation quickly gets unreadable with deeply nested loops. With the utilities in this module you can write loops and conditionals that look closer to plain Python, as long as you keep the loop-carried state in a special loops.scope object and use for loops over special scope.range iterators:
from jax.experimental import loops

with loops.Scope() as s:
    s.arr = np.zeros(5)  # Create the mutable state of the loop as `scope` fields.
    for i in s.range(s.arr.shape[0]):
        s.arr = ops.index_update(s.arr, i, s.arr[i] + 2.)
        for _ in s.cond_range(i % 2 == 0):  # Conditionals as loops with 0 or 1 iterations.
            s.arr = ops.index_update(s.arr, i, s.arr[i] + 1.)
Loops constructed with range must have literal constant bounds. If you need loops with dynamic bounds, you can use the more general while_range iterator. However, in that case the grad transformation is not supported:
s.idx = start
for _ in s.while_range(lambda: s.idx < end):
    s.idx += 1
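For intuition, a while_range loop corresponds to the functional form that lax.while_loop expects. The stand-in below is a plain-Python model of that functional form, not jax.lax.while_loop itself, and is only meant to show the correspondence with the s.idx example above:

```python
# Plain-Python model of the functional while-loop form: the loop-carried
# state is passed in and returned explicitly instead of being mutated.
def while_loop(cond_fun, body_fun, init_val):
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val

start, end = 0, 5
# Functional equivalent of:
#   s.idx = start
#   for _ in s.while_range(lambda: s.idx < end):
#       s.idx += 1
idx = while_loop(lambda i: i < end, lambda i: i + 1, start)
print(idx)  # 5
```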
Notes
- Loops and conditionals to be functionalized can appear only inside scopes constructed with loops.Scope and they must use one of the Scope.range iterators. All other loops are unrolled during tracing, as usual in JAX.
- Only scope data (stored in fields of the scope object) is functionalized. All other state, e.g., in other Python variables, will not be considered part of the loop output. All references to the mutable state should go through the scope: s.arr.
- Conceptually, this model is still "functional" in the sense that a loop over a Scope.range behaves as a function whose input and output is the scope data.
- Scopes should be passed down to callees that need to use loop functionalization, or they may be nested.
- The programming model is that the loop body over a scope.range is traced only once, using abstract shape values, similar to how JAX traces function bodies.
- Restrictions:
  - The tracing of the loop body should not exit prematurely with return, exception, or break. This would be detected and reported as an error when we encounter unnested scopes.
  - The loop index variable should not be used after the loop. Similarly, data computed in the loop body should not be used outside the loop, except for data stored in fields of the scope object.
  - No new mutable state can be created inside a loop to be functionalized. All mutable state must be created outside all loops and conditionals.
  - For a while loop, the conditional function is not allowed to modify the scope state. This is a checked error. Also, for while loops the grad transformation does not work. An alternative that allows grad is a bounded loop (range).
- Transformations:
  - All transformations are supported, except that grad is not supported for Scope.while_range loops.
  - vmap is very useful for such loops because it pushes more work into the inner loops, which should help performance on accelerators.
For a usage example, see tests/loops_test.py.

class jax.experimental.loops.Scope
Bases: object
A scope context manager to keep the state of loop bodies for functionalization.
Usage:
with Scope() as s:
    s.data = 0.
    for i in s.range(5):
        s.data += 1.
    return s.data

cond_range (pred)
Creates a conditional iterator with 0 or 1 iterations based on the boolean pred.
The body is converted to a lax.cond. All JAX transformations work.
Usage:
for _ in scope.cond_range(s.field < 0.):
    s.field = -s.field

range (first, second=None, third=None)
Creates an iterator for bounded iterations to be functionalized.
The body is converted to a lax.scan, for which all JAX transformations work. The first, second, and third arguments must be integer literals.
Usage:
range(5)        # start=0, end=5, step=1
range(1, 5)     # start=1, end=5, step=1
range(1, 5, 2)  # start=1, end=5, step=2

s.out = 1.
for i in scope.range(5):
    s.out += 1.

while_range (cond_func)
Creates an iterator that continues as long as cond_func returns true.
The body is converted to a lax.while_loop. The grad transformation does not work.
Usage:
for _ in scope.while_range(lambda: s.loss > 1.e-5):
    s.loss = loss(...)
Parameters: cond_func – a lambda with no arguments, the condition for the "while".

jax.experimental.optimizers module
Optimizers for use with JAX.
This module contains some convenient optimizer definitions, specifically initialization and update functions, which can be used with ndarrays or arbitrarily-nested tuple/list/dicts of ndarrays.
An optimizer is modeled as an (init_fun, update_fun, get_params) triple of functions, where the component functions have these signatures:
init_fun(params)

  Args:
    params: pytree representing the initial parameters.

  Returns:
    A pytree representing the initial optimizer state, which includes the
    initial parameters and may also include auxiliary values like initial
    momentum. The optimizer state pytree structure generally differs from that
    of `params`.

update_fun(step, grads, opt_state)

  Args:
    step: integer representing the step index.
    grads: a pytree with the same structure as `get_params(opt_state)`
      representing the gradients to be used in updating the optimizer state.
    opt_state: a pytree representing the optimizer state to be updated.

  Returns:
    A pytree with the same structure as the `opt_state` argument representing
    the updated optimizer state.

get_params(opt_state)

  Args:
    opt_state: pytree representing an optimizer state.

  Returns:
    A pytree representing the parameters extracted from `opt_state`, such that
    the invariant `params == get_params(init_fun(params))` holds true.
Notice that an optimizer implementation has a lot of flexibility in the form of opt_state: it just has to be a pytree of JaxTypes (so that it can be passed to the JAX transforms defined in api.py) and it has to be consumable by update_fun and get_params.
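The triple contract above can be illustrated with a toy sketch that uses plain Python floats in place of pytrees. The sgd below is a simplified stand-in for the module's sgd, not its implementation, and keeps no auxiliary state so opt_state coincides with the parameter itself:

```python
# Toy illustration of the (init_fun, update_fun, get_params) contract,
# using a single Python float instead of a pytree of ndarrays.
def sgd(step_size):
    def init_fun(params):
        return params              # plain SGD keeps no extra optimizer state
    def update_fun(step, grads, opt_state):
        return opt_state - step_size * grads
    def get_params(opt_state):
        return opt_state
    return init_fun, update_fun, get_params

init_fun, update_fun, get_params = sgd(step_size=0.1)
opt_state = init_fun(2.0)          # invariant: get_params(init_fun(p)) == p
for step in range(100):
    g = 2.0 * get_params(opt_state)    # gradient of f(x) = x**2
    opt_state = update_fun(step, g, opt_state)
print(get_params(opt_state))       # approaches the minimum at 0
```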

class jax.experimental.optimizers.JoinPoint (subtree)
Bases: object
Marks the boundary between two joined (nested) pytrees.

class jax.experimental.optimizers.OptimizerState (packed_state, tree_def, subtree_defs)
Bases: tuple
packed_state 
Alias for field number 0
subtree_defs 
Alias for field number 2
tree_def 
Alias for field number 1


jax.experimental.optimizers.adagrad (step_size, momentum=0.9)
Construct optimizer triple for Adagrad.
Adaptive Subgradient Methods for Online Learning and Stochastic Optimization: http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf
Parameters:
- step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.
- momentum – optional, a positive scalar value for momentum.
Returns: An (init_fun, update_fun, get_params) triple.

jax.experimental.optimizers.adam (step_size, b1=0.9, b2=0.999, eps=1e-08)
Construct optimizer triple for Adam.
Parameters:
- step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.
- b1 – optional, a positive scalar value for beta_1, the exponential decay rate for the first moment estimates (default 0.9).
- b2 – optional, a positive scalar value for beta_2, the exponential decay rate for the second moment estimates (default 0.999).
- eps – optional, a positive scalar value for epsilon, a small constant for numerical stability (default 1e-8).
Returns: An (init_fun, update_fun, get_params) triple.

jax.experimental.optimizers.clip_grads (grad_tree, max_norm)
Clip gradients stored as a pytree of arrays to maximum norm max_norm.

jax.experimental.optimizers.inverse_time_decay (step_size, decay_steps, decay_rate, staircase=False)

jax.experimental.optimizers.l2_norm (tree)
Compute the l2 norm of a pytree of arrays. Useful for weight decay.

jax.experimental.optimizers.momentum (step_size, mass)
Construct optimizer triple for SGD with momentum.
Parameters:
- step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.
- mass – positive scalar representing the momentum coefficient.
Returns: An (init_fun, update_fun, get_params) triple.

jax.experimental.optimizers.nesterov (step_size, mass)
Construct optimizer triple for SGD with Nesterov momentum.
Parameters:
- step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.
- mass – positive scalar representing the momentum coefficient.
Returns: An (init_fun, update_fun, get_params) triple.

jax.experimental.optimizers.optimizer (opt_maker)
Decorator to make an optimizer defined for arrays generalize to containers.
With this decorator, you can write init, update, and get_params functions that each operate only on single arrays, and convert them to corresponding functions that operate on pytrees of parameters. See the optimizers defined in optimizers.py for examples.
Parameters: opt_maker – a function that returns an (init_fun, update_fun, get_params) triple of functions that might only work with ndarrays, as per
  init_fun :: ndarray -> OptStatePytree ndarray
  update_fun :: OptStatePytree ndarray -> OptStatePytree ndarray
  get_params :: OptStatePytree ndarray -> ndarray
Returns: An (init_fun, update_fun, get_params) triple of functions that work on arbitrary pytrees, as per
  init_fun :: ParameterPytree ndarray -> OptimizerState
  update_fun :: OptimizerState -> OptimizerState
  get_params :: OptimizerState -> ParameterPytree ndarray
The OptimizerState pytree type used by the returned functions is isomorphic to ParameterPytree (OptStatePytree ndarray), but may store the state instead as e.g. a partially-flattened data structure for performance.
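The idea behind the decorator can be sketched in plain Python, using nested lists as a stand-in for pytrees and a hypothetical tree_map helper in place of JAX's tree utilities. This is an illustration of the lifting pattern, not the module's actual implementation:

```python
# Hypothetical minimal tree_map: maps a function over the leaves of
# nested lists (a stand-in for JAX pytrees).
def tree_map(f, *trees):
    if isinstance(trees[0], list):
        return [tree_map(f, *subtrees) for subtrees in zip(*trees)]
    return f(*trees)

def optimizer(opt_maker):
    """Lift per-array (init, update, get_params) functions to nested lists."""
    def tree_opt_maker(*args, **kwargs):
        init, update, get_params = opt_maker(*args, **kwargs)
        def tree_init(params):
            return tree_map(init, params)
        def tree_update(step, grads, state):
            return tree_map(lambda g, s: update(step, g, s), grads, state)
        def tree_get_params(state):
            return tree_map(get_params, state)
        return tree_init, tree_update, tree_get_params
    return tree_opt_maker

@optimizer
def sgd(step_size):
    # Each component function only ever sees a single "array" (here, a float).
    init = lambda x: x
    update = lambda step, g, s: s - step_size * g
    get_params = lambda s: s
    return init, update, get_params

init, update, get_params = sgd(0.5)
state = init([1.0, [2.0, 3.0]])                  # works on a whole "pytree"
state = update(0, [1.0, [1.0, 1.0]], state)
print(get_params(state))  # [0.5, [1.5, 2.5]]
```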

jax.experimental.optimizers.pack_optimizer_state (marked_pytree)
Converts a marked pytree to an OptimizerState.
The inverse of unpack_optimizer_state. Converts a marked pytree with the leaves of the outer pytree represented as JoinPoints back into an OptimizerState. This function is intended to be useful when deserializing optimizer states.
Parameters: marked_pytree – a pytree containing JoinPoint leaves that hold more pytrees.
Returns: An equivalent OptimizerState to the input argument.

jax.experimental.optimizers.polynomial_decay (step_size, decay_steps, final_step_size, power=1.0)

jax.experimental.optimizers.rmsprop (step_size, gamma=0.9, eps=1e-08)
Construct optimizer triple for RMSProp.
Parameters:
- step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.
- gamma – decay parameter.
- eps – epsilon parameter.
Returns: An (init_fun, update_fun, get_params) triple.

jax.experimental.optimizers.rmsprop_momentum (step_size, gamma=0.9, eps=1e-08, momentum=0.9)
Construct optimizer triple for RMSProp with momentum.
This optimizer is separate from the rmsprop optimizer because it needs to keep track of additional parameters.
Parameters:
- step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.
- gamma – decay parameter.
- eps – epsilon parameter.
- momentum – momentum parameter.
Returns: An (init_fun, update_fun, get_params) triple.

jax.experimental.optimizers.sgd (step_size)
Construct optimizer triple for stochastic gradient descent.
Parameters: step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.
Returns: An (init_fun, update_fun, get_params) triple.

jax.experimental.optimizers.sm3 (step_size, momentum=0.9)
Construct optimizer triple for SM3.
Memory-Efficient Adaptive Optimization for Large-Scale Learning: https://arxiv.org/abs/1901.11150
Parameters:
- step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.
- momentum – optional, a positive scalar value for momentum.
Returns: An (init_fun, update_fun, get_params) triple.

jax.experimental.optimizers.unpack_optimizer_state (opt_state)
Converts an OptimizerState to a marked pytree.
Converts an OptimizerState to a marked pytree with the leaves of the outer pytree represented as JoinPoints to avoid losing information. This function is intended to be useful when serializing optimizer states.
Parameters: opt_state – an OptimizerState.
Returns: A pytree with JoinPoint leaves that contain a second level of pytrees.
jax.experimental.optix module
A composable gradient processing and optimization library for JAX.
The optix module implements a number of composable gradient transformations, typically used in the context of optimizing neural nets.
Each transformation defines:
- an init_fn, to initialize a (possibly empty) set of statistics, or state.
- an update_fn, to transform an input gradient and update the state.
An (optional) chain utility can be used to build custom optimizers by chaining arbitrary sequences of transformations. For any sequence of transformations, chain returns a single init_fn and update_fn.
An (optional) apply_updates function can be used to eventually apply the transformed gradients to the set of parameters of interest.
Separating gradient transformations from the parameter update allows you to flexibly chain a sequence of transformations of the same gradients, as well as to combine multiple updates to the same parameters (e.g. in multi-task settings where the different tasks may benefit from different sets of gradient transformations).
Many popular optimizers can be implemented as one-liners using optix, and, for convenience, we provide aliases for some of the most popular ones.
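How chain composes transformations can be sketched in plain Python on scalar "gradients". The real optix transformations operate on pytrees of arrays; the clip and scale pairs below are toy stand-ins, not the library's versions:

```python
# Sketch of chain: compose (init_fn, update_fn) pairs so that updates flow
# through each transformation in turn, each with its own slice of state.
def chain(*transforms):
    init_fns, update_fns = zip(*transforms)
    def init_fn(params):
        return [init(params) for init in init_fns]
    def update_fn(updates, state):
        new_state = []
        for update, s in zip(update_fns, state):
            updates, s = update(updates, s)
            new_state.append(s)
        return updates, new_state
    return init_fn, update_fn

# Two stateless toy transformations: clip to [-1, 1], then scale by -0.1
# (negative so the update is a descent direction).
clip = (lambda params: None, lambda g, s: (max(-1.0, min(1.0, g)), s))
scale = (lambda params: None, lambda g, s: (-0.1 * g, s))

init_fn, update_fn = chain(clip, scale)
state = init_fn(0.0)
update, state = update_fn(5.0, state)
print(update)  # 5.0 clipped to 1.0, then scaled by -0.1: -0.1
```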

class jax.experimental.optix.AddNoiseState (count, rng_key)
Bases: tuple
count 
Alias for field number 0
rng_key 
Alias for field number 1

class jax.experimental.optix.ApplyEvery (count, grad_acc)
Bases: tuple
count 
Alias for field number 0
grad_acc 
Alias for field number 1

class jax.experimental.optix.InitUpdate (init, update)
Bases: tuple
init 
Alias for field number 0
update 
Alias for field number 1

class jax.experimental.optix.ScaleByAdamState (count, mu, nu)
Bases: tuple
count 
Alias for field number 0
mu 
Alias for field number 1
nu 
Alias for field number 2

class jax.experimental.optix.ScaleByRStdDevState (mu, nu)
Bases: tuple
mu 
Alias for field number 0
nu 
Alias for field number 1

class jax.experimental.optix.ScaleByScheduleState (count)
Bases: tuple
count 
Alias for field number 0


jax.experimental.optix.add_noise (eta, gamma, seed)
Add gradient noise.
References
[Neelakantan et al, 2014](https://arxiv.org/abs/1511.06807)
Parameters:
- eta – base variance of the gaussian noise added to the gradient.
- gamma – decay exponent for annealing of the variance.
- seed – seed for random number generation.
Returns: An (init_fn, update_fn) tuple.

jax.experimental.optix.apply_every (k=1)
Accumulate gradients and apply them every k steps.
Parameters: k – apply the update every k steps, otherwise accumulate the gradients.
Returns: An (init_fn, update_fn) tuple.

jax.experimental.optix.apply_updates (params, updates)
Applies an update to the corresponding parameters.
This is an (optional) utility function that applies an update and returns the updated parameters to the caller. The update itself is typically the result of applying any number of chainable transformations.
Parameters:
- params – a tree of parameters.
- updates – a tree of updates; the tree structure and the shape of the leaf nodes must match that of params.
Returns: Updated parameters, with the same structure and shape as params.

jax.experimental.optix.chain (*args)
Applies a list of chainable update transformations.
Given a sequence of chainable transforms, chain returns an init_fn that constructs a state by concatenating the states of the individual transforms, and returns an update_fn which chains the update transformations, feeding the appropriate state to each.
Parameters: *args – a sequence of chainable (init_fn, update_fn) tuples.
Returns: A single (init_fn, update_fn) tuple.

jax.experimental.optix.clip (max_delta)
Clip updates element-wise, to be between -max_delta and +max_delta.
Parameters: max_delta – the maximum absolute value for each element in the update.
Returns: An (init_fn, update_fn) tuple.

jax.experimental.optix.clip_by_global_norm (max_norm)
Clip updates using their global norm.
References
[Pascanu et al, 2012](https://arxiv.org/abs/1211.5063)
Parameters: max_norm – the maximum global norm for an update.
Returns: An (init_fn, update_fn) tuple.

jax.experimental.optix.scale (step_size)
Scale updates by some fixed scalar step_size.
Parameters: step_size – a scalar corresponding to a fixed scaling factor for updates.
Returns: An (init_fn, update_fn) tuple.

jax.experimental.optix.scale_by_adam (b1=0.9, b2=0.999, eps=1e-08)
Rescale updates according to the Adam algorithm.
References
[Kingma et al, 2014](https://arxiv.org/abs/1412.6980)
Parameters:
- b1 – decay rate for the exponentially weighted average of grads.
- b2 – decay rate for the exponentially weighted average of squared grads.
- eps – term added to the denominator to improve numerical stability.
Returns: An (init_fn, update_fn) tuple.

jax.experimental.optix.scale_by_rms (decay=0.9, eps=1e-08)
Rescale updates by the root of the exponential moving average of the square.
References
[Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
Parameters:
- decay – decay rate for the exponentially weighted average of squared grads.
- eps – term added to the denominator to improve numerical stability.
Returns: An (init_fn, update_fn) tuple.

jax.experimental.optix.scale_by_schedule (step_size_fn)
Scale updates using a custom schedule for the step_size.
Parameters: step_size_fn – a function that takes an update count as input and proposes the step_size to multiply the updates by.
Returns: An (init_fn, update_fn) tuple.

jax.experimental.optix.scale_by_stddev (decay=0.9, eps=1e-08)
Rescale updates by the root of the centered exponential moving average of squares.
References
[Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
Parameters:
- decay – decay rate for the exponentially weighted average of squared grads.
- eps – term added to the denominator to improve numerical stability.
Returns: An (init_fn, update_fn) tuple.
jax.experimental.stax module
Stax is a small but flexible neural net specification library from scratch.
For an example of its use, see examples/resnet50.py.
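Each stax layer is an (init_fun, apply_fun) pair: init_fun takes an input shape and returns an output shape plus initialized parameters, and apply_fun maps parameters and inputs to outputs. The sketch below illustrates that contract with a local, Dense-like layer written in plain numpy so it stands alone; real stax layers use jax.numpy and take a jax.random key for initialization:

```python
import numpy as np

# Numpy-only sketch of the stax layer contract: a layer is a pair
# (init_fun, apply_fun). `Dense` here is a local illustration, not
# jax.experimental.stax.Dense.
def Dense(out_dim):
    def init_fun(rng, input_shape):
        in_dim = input_shape[-1]
        W = rng.standard_normal((in_dim, out_dim)) * 0.01
        b = np.zeros(out_dim)
        output_shape = input_shape[:-1] + (out_dim,)
        return output_shape, (W, b)
    def apply_fun(params, inputs):
        W, b = params
        return inputs @ W + b
    return init_fun, apply_fun

rng = np.random.default_rng(0)
init_fun, apply_fun = Dense(10)
out_shape, params = init_fun(rng, (-1, 784))   # -1 marks the batch dimension
y = apply_fun(params, np.ones((2, 784)))
print(out_shape, y.shape)  # (-1, 10) (2, 10)
```

Composition utilities like stax.serial build a network by threading shapes and parameters through a sequence of such pairs.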

jax.experimental.stax.
AvgPool
(window_shape, strides=None, padding='VALID', spec=None)Â¶ Layer construction function for a pooling layer.

jax.experimental.stax.BatchNorm(axis=(0, 1, 2), epsilon=1e-05, center=True, scale=True, beta_init=<function zeros>, gamma_init=<function ones>)[source]¶ Layer construction function for a batch normalization layer.

jax.experimental.stax.Conv(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)¶ Layer construction function for a general convolution layer.

jax.experimental.stax.Conv1DTranspose(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)¶ Layer construction function for a general transposed-convolution layer.

jax.experimental.stax.ConvTranspose(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)¶ Layer construction function for a general transposed-convolution layer.

jax.experimental.stax.Dense(out_dim, W_init=<function variance_scaling.<locals>.init>, b_init=<function normal.<locals>.init>)[source]¶ Layer constructor function for a dense (fully-connected) layer.

jax.experimental.stax.Dropout(rate, mode='train')[source]¶ Layer construction function for a dropout layer with given rate.

jax.experimental.stax.FanInConcat(axis=-1)[source]¶ Layer construction function for a fan-in concatenation layer.

jax.experimental.stax.GeneralConv(dimension_numbers, out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)[source]¶ Layer construction function for a general convolution layer.

jax.experimental.stax.GeneralConvTranspose(dimension_numbers, out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)[source]¶ Layer construction function for a general transposed-convolution layer.

jax.experimental.stax.MaxPool(window_shape, strides=None, padding='VALID', spec=None)¶ Layer construction function for a pooling layer.

jax.experimental.stax.SumPool(window_shape, strides=None, padding='VALID', spec=None)¶ Layer construction function for a pooling layer.

jax.experimental.stax.elementwise(fun, **fun_kwargs)[source]¶ Layer that applies a scalar function elementwise on its inputs.

jax.experimental.stax.glorot(in_axis=-2, out_axis=-1, dtype=<class 'jax.numpy.lax_numpy.float32'>)¶

jax.experimental.stax.glorot_normal(in_axis=-2, out_axis=-1, dtype=<class 'jax.numpy.lax_numpy.float32'>)¶

jax.experimental.stax.parallel(*layers)[source]¶ Combinator for composing layers in parallel.
The layer resulting from this combinator is often used with the FanOut and FanInSum layers.
Parameters: *layers – a sequence of layers, each an (init_fun, apply_fun) pair. Returns: A new layer, meaning an (init_fun, apply_fun) pair, representing the parallel composition of the given sequence of layers. In particular, the returned layer takes a sequence of inputs and returns a sequence of outputs with the same length as the argument layers.

jax.experimental.stax.serial(*layers)[source]¶ Combinator for composing layers in serial.
Parameters: *layers – a sequence of layers, each an (init_fun, apply_fun) pair. Returns: A new layer, meaning an (init_fun, apply_fun) pair, representing the serial composition of the given sequence of layers.

jax.experimental.stax.shape_dependent(make_layer)[source]¶ Combinator to delay layer constructor pair until input shapes are known.
Parameters: make_layer – a one-argument function that takes an input shape as an argument (a tuple of positive integers) and returns an (init_fun, apply_fun) pair. Returns: A new layer, meaning an (init_fun, apply_fun) pair, representing the same layer as returned by make_layer but with its construction delayed until input shapes are known.
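Every stax layer and combinator hands back an (init_fun, apply_fun) pair: init_fun takes an RNG key and an input shape and returns (output_shape, params); apply_fun takes those params and inputs and computes the output. The toy plain-NumPy sketch below (names ToyDense/toy_serial are hypothetical, and the RNG key is ignored for simplicity) illustrates just that contract, not the real stax initializers:

```python
import numpy as np

def ToyDense(out_dim):
    """Sketch of a stax-style layer: an (init_fun, apply_fun) pair."""
    def init_fun(rng, input_shape):
        in_dim = input_shape[-1]
        # Real stax draws W and b from W_init/b_init using rng; fixed here.
        W, b = np.full((in_dim, out_dim), 0.1), np.zeros(out_dim)
        return input_shape[:-1] + (out_dim,), (W, b)
    def apply_fun(params, inputs):
        W, b = params
        return inputs @ W + b
    return init_fun, apply_fun

def toy_serial(*layers):
    """Sketch of stax.serial: threads shapes and params through each layer."""
    init_funs, apply_funs = zip(*layers)
    def init_fun(rng, input_shape):
        params = []
        for init in init_funs:
            input_shape, p = init(rng, input_shape)
            params.append(p)
        return input_shape, params
    def apply_fun(params, inputs):
        for f, p in zip(apply_funs, params):
            inputs = f(p, inputs)
        return inputs
    return init_fun, apply_fun

init_fun, apply_fun = toy_serial(ToyDense(4), ToyDense(2))
out_shape, params = init_fun(rng=None, input_shape=(3, 8))
out = apply_fun(params, np.ones((3, 8)))
```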
jax.lax package¶
jax.lax is a library of primitive operations that underpins libraries such as jax.numpy. Transformation rules, such as JVP and batching rules, are typically defined as transformations on jax.lax primitives.
Many of the primitives are thin wrappers around equivalent XLA operations, described by the XLA operation semantics documentation. In a few cases JAX diverges from XLA, usually to ensure that the set of operations is closed under the operation of JVP and transpose rules.
Where possible, prefer to use libraries such as jax.numpy instead of using jax.lax directly. The jax.numpy API follows NumPy, and is therefore more stable and less likely to change than the jax.lax API.
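The practical difference is visible even on a trivial addition (a minimal example assuming a working JAX install; jax.numpy is aliased jnp here to distinguish it from plain NumPy):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(3.0)

# jax.numpy follows NumPy semantics: broadcasting and dtype
# promotion happen implicitly.
a = jnp.add(x, 1)

# jax.lax primitives are stricter: operands should already agree in
# shape and dtype, so we build a matching array explicitly.
b = lax.add(x, jnp.ones_like(x))
```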
Operators¶
abs (x) 
Elementwise absolute value: \(|x|\). 
add (x, y) 
Elementwise addition: \(x + y\). 
acos (x) 
Elementwise arc cosine: \(\mathrm{acos}(x)\). 
asin (x) 
Elementwise arc sine: \(\mathrm{asin}(x)\). 
atan (x) 
Elementwise arc tangent: \(\mathrm{atan}(x)\). 
atan2 (x, y) 
Elementwise arc tangent of two variables: \(\mathrm{atan}({x \over y})\). 
batch_matmul (lhs, rhs[, precision]) 
Batch matrix multiplication. 
bessel_i0e (x) 
Exponentially scaled modified Bessel function of order 0: \(\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)\) 
bessel_i1e (x) 
Exponentially scaled modified Bessel function of order 1: \(\mathrm{i1e}(x) = e^{-|x|} \mathrm{i1}(x)\) 
betainc (a, b, x) 
Elementwise regularized incomplete beta integral. 
bitcast_convert_type (operand, new_dtype) 
Elementwise bitcast. 
bitwise_not (x) 
Elementwise NOT: \(\neg x\). 
bitwise_and (x, y) 
Elementwise AND: \(x \wedge y\). 
bitwise_or (x, y) 
Elementwise OR: \(x \vee y\). 
bitwise_xor (x, y) 
Elementwise exclusive OR: \(x \oplus y\). 
broadcast (operand, sizes) 
Broadcasts an array, adding new major dimensions. 
broadcasted_iota (dtype, shape, dimension) 
Convenience wrapper around iota . 
broadcast_in_dim (operand, shape, …) 
Wraps XLA’s BroadcastInDim operator. 
ceil (x) 
Elementwise ceiling: \(\left\lceil x \right\rceil\). 
clamp (min, x, max) 
Elementwise clamp. 
collapse (operand, start_dimension, …) 

complex (x, y) 
Elementwise make complex number: \(x + jy\). 
concatenate (operands, dimension) 
Concatenates a sequence of arrays along dimension. 
conj (x) 
Elementwise complex conjugate function: \(\overline{x}\). 
conv (lhs, rhs, window_strides, padding[, …]) 
Convenience wrapper around conv_general_dilated. 
convert_element_type (operand, new_dtype) 
Elementwise cast. 
conv_general_dilated (lhs, rhs, …[, …]) 
General n-dimensional convolution operator, with optional dilation. 
conv_with_general_padding (lhs, rhs, …[, …]) 
Convenience wrapper around conv_general_dilated. 
conv_transpose (lhs, rhs, strides, padding[, …]) 
Convenience wrapper for calculating the N-d convolution “transpose”. 
cos (x) 
Elementwise cosine: \(\mathrm{cos}(x)\). 
cosh (x) 
Elementwise hyperbolic cosine: \(\mathrm{cosh}(x)\). 
digamma (x) 
Elementwise digamma: \(\psi(x)\). 
div (x, y) 
Elementwise division: \(x \over y\). 
dot (lhs, rhs[, precision]) 
Vector/vector, matrix/vector, and matrix/matrix multiplication. 
dot_general (lhs, rhs, dimension_numbers[, …]) 
More general contraction operator. 
dynamic_index_in_dim (operand, index[, axis, …]) 
Convenience wrapper around dynamic_slice to perform int indexing. 
dynamic_slice (operand, start_indices, …) 
Wraps XLA’s DynamicSlice operator. 
dynamic_slice_in_dim (operand, start_index, …) 
Convenience wrapper around dynamic_slice applying to one dimension. 
dynamic_update_index_in_dim (operand, update, …) 

dynamic_update_slice_in_dim (operand, update, …) 

eq (x, y) 
Elementwise equals: \(x = y\). 
erf (x) 
Elementwise error function: \(\mathrm{erf}(x)\). 
erfc (x) 
Elementwise complementary error function: \(\mathrm{erfc}(x) = 1 - \mathrm{erf}(x)\). 
erf_inv (x) 
Elementwise inverse error function: \(\mathrm{erf}^{-1}(x)\). 
exp (x) 
Elementwise exponential: \(e^x\). 
expm1 (x) 
Elementwise \(e^{x} - 1\). 
fft (x, fft_type, fft_lengths) 

floor (x) 
Elementwise floor: \(\left\lfloor x \right\rfloor\). 
full (shape, fill_value[, dtype]) 
Returns an array of shape filled with fill_value. 
full_like (x, fill_value[, dtype, shape]) 
Create a full array like np.full based on the example array x. 
gather (operand, start_indices, …) 
Gather operator. 
ge (x, y) 
Elementwise greater-than-or-equals: \(x \geq y\). 
gt (x, y) 
Elementwise greater-than: \(x > y\). 
igamma (a, x) 
Elementwise regularized incomplete gamma function. 
igammac (a, x) 
Elementwise complementary regularized incomplete gamma function. 
imag (x) 
Elementwise extract imaginary part: \(\mathrm{Im}(x)\). 
index_in_dim (operand, index[, axis, keepdims]) 
Convenience wrapper around slice to perform int indexing. 
index_take (src, idxs, axes) 

iota (dtype, size) 
Wraps XLA’s Iota operator. 
is_finite (x) 
Elementwise \(\mathrm{isfinite}\). 
le (x, y) 
Elementwise less-than-or-equals: \(x \leq y\). 
lt (x, y) 
Elementwise less-than: \(x < y\). 
lgamma (x) 
Elementwise log gamma: \(\mathrm{log}(\Gamma(x))\). 
log (x) 
Elementwise natural logarithm: \(\mathrm{log}(x)\). 
log1p (x) 
Elementwise \(\mathrm{log}(1 + x)\). 
max (x, y) 
Elementwise maximum: \(\mathrm{max}(x, y)\) 
min (x, y) 
Elementwise minimum: \(\mathrm{min}(x, y)\) 
mul (x, y) 
Elementwise multiplication: \(x \times y\). 
ne (x, y) 
Elementwise not-equals: \(x \neq y\). 
neg (x) 
Elementwise negation: \(-x\). 
nextafter (x1, x2) 
Returns the next representable value after x1 in the direction of x2. 
pad (operand, padding_value, padding_config) 
Wraps XLA’s Pad operator. 
pow (x, y) 
Elementwise power: \(x^y\). 
real (x) 
Elementwise extract real part: \(\mathrm{Re}(x)\). 
reciprocal (x) 
Elementwise reciprocal: \(1 \over x\). 
reduce (operand, init_value, computation, …) 
Wraps XLA’s Reduce operator. 
reduce_window (operand, init_value, …) 
Wraps XLA’s ReduceWindow operator. 
reshape (operand, new_sizes[, dimensions]) 
Wraps XLA’s Reshape operator. 
rem (x, y) 
Elementwise remainder: \(x \bmod y\). 
rev (operand, dimensions) 
Wraps XLA’s Rev operator. 
round (x) 
Elementwise round. 
rsqrt (x) 
Elementwise reciprocal square root: \(1 \over \sqrt{x}\). 
scatter (operand, scatter_indices, updates, …) 
Scatter-update operator. 
scatter_add (operand, scatter_indices, …) 
Scatter-add operator. 
select (pred, on_true, on_false) 
Wraps XLA’s Select operator. 
shift_left (x, y) 
Elementwise left shift: \(x \ll y\). 
shift_right_arithmetic (x, y) 
Elementwise arithmetic right shift: \(x \gg y\). 
shift_right_logical (x, y) 
Elementwise logical right shift: \(x \gg y\). 
slice (operand, start_indices, limit_indices) 
Wraps XLA’s Slice operator. 
slice_in_dim (operand, start_index, limit_index) 
Convenience wrapper around slice applying to only one dimension. 
sign (x) 
Elementwise sign. 
sin (x) 
Elementwise sine: \(\mathrm{sin}(x)\). 
sinh (x) 
Elementwise hyperbolic sine: \(\mathrm{sinh}(x)\). 
sort (operand[, dimension]) 
Wraps XLA’s Sort operator. 
sort_key_val (keys, values[, dimension]) 

sqrt (x) 
Elementwise square root: \(\sqrt{x}\). 
square (x) 
Elementwise square: \(x^2\). 
sub (x, y) 
Elementwise subtraction: \(x - y\). 
tan (x) 
Elementwise tangent: \(\mathrm{tan}(x)\). 
tie_in (x, y) 
Gives y a fake data dependence on x . 
transpose (operand, permutation) 
Wraps XLA’s Transpose operator. 
Control flow operators¶
cond (pred, true_operand, true_fun, …) 
Conditionally apply true_fun or false_fun . 
fori_loop (lower, upper, body_fun, init_val) 
Loop from lower to upper by reduction to while_loop . 
map (f, xs) 
Map a function over leading array axes. 
scan (f, init, xs[, length]) 
Scan a function over leading array axes while carrying along state. 
while_loop (cond_fun, body_fun, init_val) 
Call body_fun repeatedly in a loop while cond_fun is True. 
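The structured control-flow operators all thread an explicit carry value through the loop body. A minimal example assuming a working JAX install, showing fori_loop and scan (the step function name is arbitrary):

```python
import jax.numpy as jnp
from jax import lax

# fori_loop: sum 0..9, carrying the accumulator through the loop state.
total = lax.fori_loop(0, 10, lambda i, acc: acc + i, 0)

# scan: running sum over a sequence; returns (final_carry, stacked_outputs).
def step(carry, x):
    carry = carry + x
    return carry, carry

final, cumsum = lax.scan(step, 0.0, jnp.arange(4.0))
```

Unlike a Python for loop, these operators trace the body once, so the compiled program stays small regardless of trip count.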
Custom gradient operators¶
stop_gradient (x) 
Stops gradient computation. 
custom_linear_solve (matvec, b, solve[, …]) 
Perform a matrix-free linear solve with implicitly defined gradients. 
custom_root (f, initial_guess, solve, …) 
Differentiably solve for a root of a function. 
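stop_gradient makes its argument behave as a constant under differentiation while leaving the forward value unchanged. A small example assuming a working JAX install:

```python
from jax import grad, lax

def f(x):
    # The stop_gradient factor is held constant by grad, so
    # f differentiates like x * c with c fixed at the value of x.
    return x * lax.stop_gradient(x)

g = grad(f)(3.0)                  # d/dx (x * c) = c, evaluated at x = 3.0
h = grad(lambda x: x * x)(3.0)    # without stop_gradient: 2x = 6.0
```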
Parallel operators¶
Parallelism support is experimental.
all_gather (x, axis_name) 
Gather values of x across all replicas. 
all_to_all (x, axis_name, split_axis, concat_axis) 
Materialize the mapped axis and map a different axis. 
psum (x, axis_name) 
Compute an all-reduce sum on x over the pmapped axis axis_name . 
pmax (x, axis_name) 
Compute an all-reduce max on x over the pmapped axis axis_name . 
pmin (x, axis_name) 
Compute an all-reduce min on x over the pmapped axis axis_name . 
ppermute (x, axis_name, perm) 
Perform a collective permutation according to the permutation perm . 
pswapaxes (x, axis_name, axis) 
Swap the pmapped axis axis_name with the unmapped axis axis . 
axis_index (axis_name) 
Return the index along the pmapped axis axis_name . 
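These collectives only have meaning inside a pmapped function, where the named axis refers to the mapped device axis. A minimal sketch assuming a working JAX install; on a CPU-only machine there is a single device, so the "collective" degenerates to one replica but the code path is the same:

```python
import jax
import jax.numpy as jnp
from jax import lax

n = jax.device_count()          # 1 on a CPU-only machine; pmap still works
x = jnp.arange(float(n))        # one value per replica along the mapped axis

# Each replica holds one slice of x; psum all-reduces over the axis
# named 'i', so every replica receives the full sum.
out = jax.pmap(lambda v: lax.psum(v, 'i'), axis_name='i')(x)
```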
jax.nn package¶
jax.nn.initializers package¶
Common neural network layer initializers, consistent with definitions used in Keras and Sonnet.
Initializers¶
zeros (key, shape[, dtype]) 

ones (key, shape[, dtype]) 

uniform ([scale, dtype]) 

normal ([stddev, dtype]) 

variance_scaling (scale, mode, distribution) 

glorot_uniform ([in_axis, out_axis, dtype]) 

glorot_normal ([in_axis, out_axis, dtype]) 

lecun_uniform ([in_axis, out_axis, dtype]) 

lecun_normal ([in_axis, out_axis, dtype]) 

he_uniform ([in_axis, out_axis, dtype]) 

he_normal ([in_axis, out_axis, dtype]) 
Common functions for neural network libraries.
Activation functions¶
relu 
Rectified linear unit activation function. 
sigmoid (x) 
Sigmoid activation function. 
softplus (x) 
Softplus activation function. 
soft_sign (x) 
Softsign activation function. 
swish (x) 
Swish activation function. 
log_sigmoid (x) 
Logsigmoid activation function. 
leaky_relu (x[, negative_slope]) 
Leaky rectified linear unit activation function. 
hard_tanh (x) 
Hard \(\mathrm{tanh}\) activation function. 
elu (x[, alpha]) 
Exponential linear unit activation function. 
celu (x[, alpha]) 
Continuously-differentiable exponential linear unit activation. 
selu (x) 
Scaled exponential linear unit activation. 
gelu (x) 
Gaussian error linear unit activation function. 
glu (x[, axis]) 
Gated linear unit activation function. 
Other functions¶
softmax (x[, axis]) 
Softmax function. 
log_softmax (x[, axis]) 
LogSoftmax function. 
normalize (x[, axis, mean, variance, epsilon]) 
Normalizes an array by subtracting mean and dividing by sqrt(var). 
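A quick example assuming a working JAX install: softmax normalizes along the last axis by default, and log_softmax computes the same quantity in log space, which is numerically safer than taking jnp.log of the probabilities:

```python
import jax.numpy as jnp
from jax import nn

logits = jnp.array([1.0, 2.0, 3.0])
p = nn.softmax(logits)        # probabilities; sums to 1 along the last axis
lp = nn.log_softmax(logits)   # log probabilities, computed stably
```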
jax.ops package¶
Indexed update operators¶
JAX is intended to be used with a functional style of programming, and hence does not support NumPy-style indexed assignment directly. Instead, JAX provides pure alternatives, namely jax.ops.index_update() and its relatives.
index 
Helper object for building indexes for indexed update functions. 
index_update (x, idx, y) 
Pure equivalent of x[idx] = y . 
index_add (x, idx, y) 
Pure equivalent of x[idx] += y . 
index_min (x, idx, y) 
Pure equivalent of x[idx] = minimum(x[idx], y) . 
index_max (x, idx, y) 
Pure equivalent of x[idx] = maximum(x[idx], y) . 
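"Pure equivalent" means the input array is never mutated; a fresh array carrying the update is returned instead. A plain-NumPy sketch of the semantics (the helper below illustrates the contract and is not the JAX implementation, which avoids the copy when it can):

```python
import numpy as np

def index_update(x, idx, y):
    """Sketch of the x[idx] = y semantics: copy, assign, return.
    The input x is left untouched."""
    out = np.array(x, copy=True)
    out[idx] = y
    return out

x = np.zeros(4)
y = index_update(x, 1, 7.0)   # y differs at index 1; x is unchanged
```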
Other operators¶
segment_sum (data, segment_ids[, num_segments]) 
Computes the sum within segments of an array. 
jax.random package¶
JAX pseudo-random number generators (PRNGs).
The JAX PRNG system is based on “Parallel random numbers: as easy as 1, 2, 3” (Salmon et al. 2011). For details on the design and its motivation, see:
https://github.com/google/jax/blob/master/design_notes/prng.md

jax.random.PRNGKey(seed)[source]¶ Create a pseudo-random number generator (PRNG) key given an integer seed.
Parameters: seed – a 64- or 32-bit integer used as the value of the key. Returns: A PRNG key, which is modeled as an array of shape (2,) and dtype uint32. The key is constructed from a 64-bit seed by effectively bitcasting to a pair of uint32 values (or from a 32-bit seed by first padding out with zeros).

jax.random.bernoulli(key, p=0.5, shape=None)[source]¶ Sample Bernoulli random values with given shape and mean.
Parameters:  key – a PRNGKey used as the random key.
 p – optional, a float or array of floats for the mean of the random variables. Must be broadcast-compatible with shape. Default 0.5.
 shape – optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with p.shape. The default (None) produces a result shape equal to p.shape.
Returns: A random array with boolean dtype and shape given by shape if shape is not None, or else p.shape.
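Keys are consumed rather than advanced in place, so the usual pattern is to split off a fresh subkey before each sampling call. A short example assuming a working JAX install:

```python
from jax import random

key = random.PRNGKey(0)
# Never reuse a key: split it to get an independent subkey per draw.
key, subkey = random.split(key)
draws = random.bernoulli(subkey, p=0.5, shape=(100,))
```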

jax.random.beta(key, a, b, shape=None, dtype=<class 'numpy.float64'>)[source]¶ Sample Beta random values with given shape and float dtype.
Parameters:  key – a PRNGKey used as the random key.
 a – a float or array of floats broadcast-compatible with shape representing the first parameter “alpha”.
 b – a float or array of floats broadcast-compatible with shape representing the second parameter “beta”.
 shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with a and b. The default (None) produces a result shape by broadcasting a and b.
 dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified dtype and shape given by shape if shape is not None, or else by broadcasting a and b.

jax.random.categorical(key, logits, axis=-1, shape=None)[source]¶ Sample random values from categorical distributions.
Parameters:  key – a PRNGKey used as the random key.
 logits – Unnormalized log probabilities of the categorical distribution(s) to sample from, so that softmax(logits, axis) gives the corresponding probabilities.
 axis – Axis along which logits belong to the same categorical distribution.
 shape – Optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with onp.delete(logits.shape, axis). The default (None) produces a result shape equal to onp.delete(logits.shape, axis).
Returns: A random array with int dtype and shape given by shape if shape is not None, or else onp.delete(logi