JAX: High-Performance Array Computing#
JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.
If you’re looking to train neural networks, use Flax and start with its documentation. Some associated tools are Optax and Orbax. For an end-to-end transformer library built on JAX, see MaxText.
JAX provides a familiar NumPy-style API for ease of adoption by researchers and engineers.
JAX includes composable function transformations for compilation, batching, automatic differentiation, and parallelization.
The same code executes on multiple backends, including CPU, GPU, and TPU.
Installing JAX#
Using JAX requires installing two packages: `jax`, which is pure Python and cross-platform, and `jaxlib`, which contains compiled binaries and requires different builds for different operating systems and accelerators.
TL;DR For most users, a typical JAX installation may look something like this:
CPU-only (Linux/macOS/Windows)
pip install -U "jax[cpu]"
GPU (NVIDIA, CUDA 12, x86_64)
pip install -U "jax[cuda12]"
GPU (NVIDIA, CUDA 12, x86_64) legacy
You should prefer `jax[cuda12]`, which uses the common CPU jaxlib and adds GPU support as a plugin. The monolithic `jax[cuda12_pip]` option will be removed in a future JAX release.
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Supported platforms#
The table below shows all supported platforms and installation options. Check whether your setup is supported; if it says "yes" or "experimental", then follow the corresponding instructions below to learn how to install JAX in greater detail.
|  | Linux, x86_64 | Linux, aarch64 | macOS, Intel x86_64, AMD GPU | macOS, Apple Silicon, ARM-based | Windows, x86_64 | Windows WSL2, x86_64 |
|---|---|---|---|---|---|---|
| CPU | yes | yes | yes | yes | yes | yes |
| NVIDIA GPU | yes | yes | no | n/a | no | experimental |
| Google Cloud TPU | yes | n/a | n/a | n/a | n/a | n/a |
| AMD GPU | experimental | no | no | n/a | no | no |
| Apple GPU | n/a | no | n/a | experimental | n/a | n/a |
CPU#
pip installation: CPU#
Currently, the JAX team releases `jaxlib` wheels for the following operating systems and architectures:
Linux, x86_64
macOS, Intel
macOS, Apple ARM-based
Windows, x86_64 (experimental)
To install a CPU-only version of JAX, which might be useful for doing local development on a laptop, you can run:
pip install --upgrade pip
pip install --upgrade "jax[cpu]"
On Windows, you may also need to install the Microsoft Visual Studio 2019 Redistributable if it is not already installed on your machine.
Other operating systems and architectures require building from source. Trying to pip install on other operating systems and architectures may lead to `jaxlib` not being installed alongside `jax`, although `jax` may successfully install (but fail at runtime).
NVIDIA GPU#
JAX supports NVIDIA GPUs that have SM version 5.2 (Maxwell) or newer. Note that Kepler-series GPUs are no longer supported by JAX since NVIDIA has dropped support for Kepler GPUs in its software.
You must first install the NVIDIA driver. We recommend installing the newest driver available from NVIDIA, but the driver version must be >= 525.60.13 for CUDA 12 on Linux.
If you need to use a newer CUDA toolkit with an older driver, for example on a cluster where you cannot update the NVIDIA driver easily, you may be able to use the CUDA forward compatibility packages that NVIDIA provides for this purpose.
pip installation: NVIDIA GPU (CUDA, installed via pip, easier)#
There are two ways to install JAX with NVIDIA GPU support:
Using NVIDIA CUDA and cuDNN installed from pip wheels
Using a self-installed CUDA/cuDNN
The JAX team strongly recommends installing CUDA and cuDNN using the pip wheels, since it is much easier!
This method is only supported on x86_64, because NVIDIA has not released aarch64 CUDA pip packages.
pip install --upgrade pip
# NVIDIA CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12]"
# Legacy way of NVIDIA CUDA 12 installation. You should prefer `jax[cuda12]`,
# which uses the common CPU jaxlib and adds GPU support as a plugin. The
# monolithic `jax[cuda12_pip]` option will be removed in a future JAX release.
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
If JAX detects the wrong version of the NVIDIA CUDA libraries, there are several things you need to check:
- Make sure that `LD_LIBRARY_PATH` is not set, since `LD_LIBRARY_PATH` can override the NVIDIA CUDA libraries.
- Make sure that the NVIDIA CUDA libraries installed are those requested by JAX. Rerunning the installation command above should work.
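As a quick sanity check, it can also help to ask JAX what it actually detects; here is a minimal Python sketch (not part of the official installation steps) using JAX's own reporting utilities:

import jax

# Prints the jax/jaxlib versions plus details of the detected backend and devices.
jax.print_environment_info()

# On a working CUDA install, this should list GPU devices rather than only a CPU device.
print(jax.devices())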
pip installation: NVIDIA GPU (CUDA, installed locally, harder)#
If you prefer to use a preinstalled copy of NVIDIA CUDA, you must first install NVIDIA CUDA and cuDNN.
JAX provides pre-built CUDA-compatible wheels for Linux x86_64 only. Other combinations of operating system and architecture are possible, but require building from source (refer to Building from source to learn more).
You should use an NVIDIA driver version that is at least as new as your NVIDIA CUDA toolkit’s corresponding driver version. If you need to use a newer CUDA toolkit with an older driver, for example on a cluster where you cannot update the NVIDIA driver easily, you may be able to use the CUDA forward compatibility packages that NVIDIA provides for this purpose.
JAX currently ships one CUDA wheel variant:
| Built with | Compatible with |
|---|---|
| CUDA 12.3 | CUDA >=12.1 |
| CUDNN 8.9 | CUDNN >=8.9, <9.0 |
| NCCL 2.19 | NCCL >=2.18 |

JAX checks the versions of your libraries, and will report an error if they are not sufficiently new. Setting the `JAX_SKIP_CUDA_CONSTRAINTS_CHECK` environment variable will disable the check, but using older versions of CUDA may lead to errors, or incorrect results.
NCCL is an optional dependency, required only if you are performing multi-GPU computations.
To install, run:
pip install --upgrade pip
# Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 8.9 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
These `pip` installations do not work with Windows, and may fail silently; refer to the table above.
You can find your CUDA version with the command:
nvcc --version
JAX uses `LD_LIBRARY_PATH` to find CUDA libraries and `PATH` to find binaries (`ptxas`, `nvlink`). Please make sure that these paths point to the correct CUDA installation.
Please let the JAX team know on the GitHub issue tracker if you run into any errors or problems with the pre-built wheels.
NVIDIA GPU Docker containers#
NVIDIA provides the JAX Toolbox containers, which are bleeding edge containers containing nightly releases of jax and some models/frameworks.
JAX nightly installation#
Nightly releases reflect the state of the main JAX repository at the time they are built, and may not pass the full test suite.
`jax`:
pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
`jaxlib` CPU:
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
`jaxlib` Google Cloud TPU:
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
pip install -U libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
`jaxlib` NVIDIA GPU (CUDA 12):
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
pip install -U --pre jax-cuda12-pjrt jax-cuda12-plugin -f https://storage.googleapis.com/jax-releases/jax_cuda_plugin_nightly_releases.html
`jaxlib` NVIDIA GPU (CUDA 12) legacy:
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html
Google Cloud TPU#
pip installation: Google Cloud TPU#
JAX provides pre-built wheels for Google Cloud TPU. To install JAX along with appropriate versions of `jaxlib` and `libtpu`, you can run the following in your cloud TPU VM:
pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
For users of Colab (https://colab.research.google.com/), be sure you are using TPU v2 and not the older, deprecated TPU runtime.
Apple Silicon GPU (ARM-based)#
pip installation: Apple ARM-based Silicon GPUs#
Apple provides an experimental Metal plugin for Apple ARM-based GPU hardware. For details, refer to Apple’s JAX on Metal documentation.
Note: There are several caveats with the Metal plugin:
- The Metal plugin is new and experimental and has a number of known issues. Please report any issues on the JAX issue tracker.
- The Metal plugin currently requires very specific versions of `jax` and `jaxlib`. This restriction will be relaxed over time as the plugin API matures.
AMD GPU#
JAX has experimental ROCm support. There are two ways to install JAX:
- Use AMD's Docker container; or
- Build from source (refer to Building from source — a section called Additional notes for building a ROCM `jaxlib` for AMD GPUs).
Conda (community-supported)#
Conda installation#
There is a community-supported Conda build of `jax`. To install it using `conda`, simply run:
conda install jax -c conda-forge
To install it on a machine with an NVIDIA GPU, run:
conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia
Note the `cudatoolkit` distributed by `conda-forge` is missing `ptxas`, which JAX requires. You must therefore either install the `cuda-nvcc` package from the `nvidia` channel, or install CUDA on your machine separately so that `ptxas` is in your path. The channel order above is important (`conda-forge` before `nvidia`).
If you would like to override which release of CUDA is used by JAX, or to install the CUDA build on a machine without GPUs, follow the instructions in the Tips & tricks section of the conda-forge website.

Go to the conda-forge jaxlib and jax repositories for more details.
Building JAX from source#
Refer to Building from source.
Installing older `jaxlib` wheels#
Due to storage limitations on the Python package index, the JAX team periodically removes older `jaxlib` wheels from the releases on http://pypi.org/project/jax. These can still be installed directly via the URLs here. For example:
# Install jaxlib on CPU via the wheel archive
pip install jax[cpu]==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_releases.html
# Install the jaxlib 0.3.25 CPU wheel directly
pip install jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_releases.html
For specific older GPU wheels, be sure to use the `jax_cuda_releases.html` URL; for example:
pip install jaxlib==0.3.25+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Quickstart#
JAX is a library for array-oriented numerical computation (à la NumPy), with automatic differentiation and JIT compilation to enable high-performance machine learning research.
This document provides a quick overview of essential JAX features, so you can get started with JAX quickly:
- JAX provides a unified NumPy-like interface to computations that run on CPU, GPU, or TPU, in local or distributed settings.
- JAX features built-in Just-In-Time (JIT) compilation via Open XLA, an open-source machine learning compiler ecosystem.
- JAX functions support efficient evaluation of gradients via its automatic differentiation transformations.
- JAX functions can be automatically vectorized to efficiently map them over arrays representing batches of inputs.
Installation#
JAX can be installed for CPU on Linux, Windows, and macOS directly from the Python Package Index:
pip install "jax[cpu]"
or, for NVIDIA GPU:
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
For more detailed platform-specific installation information, check out Installing JAX.
JAX as NumPy#
Most JAX usage is through the familiar `jax.numpy` API, which is typically imported under the `jnp` alias:
import jax.numpy as jnp
With this import, you can immediately use JAX in a similar manner to typical NumPy programs, including using NumPy-style array creation functions, Python functions and operators, and array attributes and methods:
def selu(x, alpha=1.67, lmbda=1.05):
return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
x = jnp.arange(5.0)
print(selu(x))
[0. 1.05 2.1 3.1499999 4.2 ]
You'll find a few differences between JAX arrays and NumPy arrays once you begin digging in; these are explored in 🔪 JAX - The Sharp Bits 🔪.
Just-in-time compilation with jax.jit()#
JAX runs transparently on the GPU or TPU (falling back to CPU if you don't have one). However, in the above example, JAX is dispatching kernels to the chip one operation at a time. If we have a sequence of operations, we can use the `jax.jit()` function to compile this sequence of operations together using XLA.
We can use IPython's `%timeit` to quickly benchmark our `selu` function, using `block_until_ready()` to account for JAX's dynamic dispatch (see Asynchronous dispatch):
from jax import random
key = random.key(1701)
x = random.normal(key, (1_000_000,))
%timeit selu(x).block_until_ready()
2.94 ms ± 18.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
(notice we've used `jax.random` to generate some random numbers; for details on how to generate random numbers in JAX, check out Pseudorandom numbers).
We can speed up the execution of this function with the `jax.jit()` transformation, which will jit-compile `selu` the first time it is called and cache the compiled code thereafter.
from jax import jit
selu_jit = jit(selu)
_ = selu_jit(x) # compiles on first call
%timeit selu_jit(x).block_until_ready()
859 µs ± 7.13 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
The above timing represents execution on CPU, but the same code can be run on GPU or TPU, typically for an even greater speedup.
For more on JIT compilation in JAX, check out Just-in-time compilation.
Taking derivatives with jax.grad()#
In addition to transforming functions via JIT compilation, JAX also provides other transformations. One such transformation is `jax.grad()`, which performs automatic differentiation (autodiff):
from jax import grad
def sum_logistic(x):
return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))
x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
[0.25 0.19661197 0.10499357]
Let’s verify with finite differences that our result is correct.
def first_finite_differences(f, x, eps=1E-3):
return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
for v in jnp.eye(len(x))])
print(first_finite_differences(sum_logistic, x_small))
[0.24998187 0.1965761 0.10502338]
The `grad()` and `jit()` transformations compose and can be mixed arbitrarily. In the above example we jitted `sum_logistic` and then took its derivative. We can go further:
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
-0.0353256
Beyond scalar-valued functions, the `jax.jacobian()` transformation can be used to compute the full Jacobian matrix for vector-valued functions:
from jax import jacobian
print(jacobian(jnp.exp)(x_small))
[[1. 0. 0. ]
[0. 2.7182817 0. ]
[0. 0. 7.389056 ]]
For more advanced autodiff operations, you can use `jax.vjp()` for reverse-mode vector-Jacobian products, and `jax.jvp()` and `jax.linearize()` for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another, and with other JAX transformations. For example, `jax.jvp()` and `jax.vjp()` are used to define the forward-mode `jax.jacfwd()` and reverse-mode `jax.jacrev()` for computing Jacobians in forward- and reverse-mode, respectively.
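As a minimal sketch of calling these two primitives directly (the function `f` below is arbitrary and chosen only for illustration):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

x = jnp.float32(1.5)

# Forward mode: push a tangent vector through f at x.
y, y_dot = jax.jvp(f, (x,), (jnp.float32(1.0),))

# Reverse mode: get a function that pulls cotangents back to the inputs.
y_again, f_vjp = jax.vjp(f, x)
(x_bar,) = f_vjp(jnp.float32(1.0))

print(y, y_dot, x_bar)  # for a scalar function, y_dot and x_bar coincide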
Here’s one way to compose them to make a function that efficiently computes full Hessian matrices:
from jax import jacfwd, jacrev
def hessian(fun):
return jit(jacfwd(jacrev(fun)))
print(hessian(sum_logistic)(x_small))
[[-0. -0. -0. ]
[-0. -0.09085776 -0. ]
[-0. -0. -0.07996249]]
This kind of composition produces efficient code in practice; this is more-or-less how JAX's built-in `jax.hessian()` function is implemented.
For more on automatic differentiation in JAX, check out Automatic differentiation.
Auto-vectorization with jax.vmap()#
Another useful transformation is `vmap()`, the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of explicitly looping over function calls, it transforms the function into a natively vectorized version for better performance. When composed with `jit()`, it can be just as performant as manually rewriting your function to operate over an extra batch dimension.
We're going to work with a simple example, and promote matrix-vector products into matrix-matrix products using `vmap()`. Although this is easy to do by hand in this specific case, the same technique can apply to more complicated functions.
key1, key2 = random.split(key)
mat = random.normal(key1, (150, 100))
batched_x = random.normal(key2, (10, 100))
def apply_matrix(x):
return jnp.dot(mat, x)
The `apply_matrix` function maps a vector to a vector, but we may want to apply it row-wise across a matrix.
We could do this by looping over the batch dimension in Python, but this usually results in poor performance.
def naively_batched_apply_matrix(v_batched):
return jnp.stack([apply_matrix(v) for v in v_batched])
print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
Naively batched
1.07 ms ± 2.76 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
A programmer familiar with the `jnp.dot` function might recognize that `apply_matrix` can be rewritten to avoid explicit looping, using the built-in batching semantics of `jnp.dot`:
import numpy as np
@jit
def batched_apply_matrix(batched_x):
return jnp.dot(batched_x, mat.T)
np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()
Manually batched
18.5 µs ± 11.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
However, as functions become more complicated, this kind of manual batching becomes more difficult and error-prone.
The `vmap()` transformation is designed to automatically transform a function into a batch-aware version:
from jax import vmap
@jit
def vmap_batched_apply_matrix(batched_x):
return vmap(apply_matrix)(batched_x)
np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
vmap_batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
Auto-vectorized with vmap
26.2 µs ± 130 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
As you would expect, `vmap()` can be arbitrarily composed with `jit()`, `grad()`, and any other JAX transformation.
For more on automatic vectorization in JAX, check out Automatic vectorization.
This is just a taste of what JAX can do. We’re really excited to see what you do with it!
🔪 JAX - The Sharp Bits 🔪#
levskaya@ mattjj@
When walking about the countryside of Italy, the people will not hesitate to tell you that JAX has “una anima di pura programmazione funzionale”.
JAX is a language for expressing and composing transformations of numerical programs. JAX is also able to compile numerical programs for CPU or accelerators (GPU/TPU). JAX works great for many numerical and scientific programs, but only if they are written with certain constraints that we describe below.
import numpy as np
from jax import grad, jit
from jax import lax
from jax import random
import jax
import jax.numpy as jnp
🔪 Pure functions#
JAX transformation and compilation are designed to work only on Python functions that are functionally pure: all the input data is passed through the function parameters, all the results are output through the function results. A pure function will always return the same result if invoked with the same inputs.
Here are some examples of functions that are not functionally pure for which JAX behaves differently than the Python interpreter. Note that these behaviors are not guaranteed by the JAX system; the proper way to use JAX is to use it only on functionally pure Python functions.
def impure_print_side_effect(x):
print("Executing function") # This is a side-effect
return x
# The side-effects appear during the first run
print ("First call: ", jit(impure_print_side_effect)(4.))
# Subsequent runs with parameters of same type and shape may not show the side-effect
# This is because JAX now invokes a cached compilation of the function
print ("Second call: ", jit(impure_print_side_effect)(5.))
# JAX re-runs the Python function when the type or shape of the argument changes
print ("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([5.])))
Executing function
First call: 4.0
Second call: 5.0
Executing function
Third call, different type: [5.]
g = 0.
def impure_uses_globals(x):
return x + g
# JAX captures the value of the global during the first run
print ("First call: ", jit(impure_uses_globals)(4.))
g = 10. # Update the global
# Subsequent runs may silently use the cached value of the globals
print ("Second call: ", jit(impure_uses_globals)(5.))
# JAX re-runs the Python function when the type or shape of the argument changes
# This will end up reading the latest value of the global
print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))
First call: 4.0
Second call: 5.0
Third call, different type: [14.]
g = 0.
def impure_saves_global(x):
global g
g = x
return x
# JAX runs once the transformed function with special Traced values for arguments
print ("First call: ", jit(impure_saves_global)(4.))
print ("Saved global: ", g) # Saved global has an internal JAX value
First call: 4.0
Saved global: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
A Python function can be functionally pure even if it actually uses stateful objects internally, as long as it does not read or write external state:
def pure_uses_internal_state(x):
state = dict(even=0, odd=0)
for i in range(10):
state['even' if i % 2 == 0 else 'odd'] += x
return state['even'] + state['odd']
print(jit(pure_uses_internal_state)(5.))
50.0
It is not recommended to use iterators in any JAX function you want to `jit` or in any control-flow primitive. The reason is that an iterator is a Python object which introduces state to retrieve the next element. Therefore, it is incompatible with JAX's functional programming model. In the code below, there are some examples of incorrect attempts to use iterators with JAX. Most of them return an error, but some give unexpected results.
import jax.numpy as jnp
import jax.lax as lax
from jax import make_jaxpr
# lax.fori_loop
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0
# lax.scan
def func11(arr, extra):
ones = jnp.ones(arr.shape)
def body(carry, aelems):
ae1, ae2 = aelems
return (carry + ae1 * ae2 + extra, carry)
return lax.scan(body, 0., (arr, ones))
make_jaxpr(func11)(jnp.arange(16), 5.)
# make_jaxpr(func11)(iter(range(16)), 5.) # throws error
# lax.cond
array_operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, array_operand)
iter_operand = iter(range(10))
# lax.cond(True, lambda x: next(x)+1, lambda x: next(x)-1, iter_operand) # throws error
45
0
🔪 In-Place Updates#
In Numpy you’re used to doing this:
numpy_array = np.zeros((3,3), dtype=np.float32)
print("original array:")
print(numpy_array)
# In place, mutating update
numpy_array[1, :] = 1.0
print("updated array:")
print(numpy_array)
original array:
[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]
updated array:
[[0. 0. 0.]
[1. 1. 1.]
[0. 0. 0.]]
If we try to update a JAX device array in-place, however, we get an error! (☉_☉)
%xmode Minimal
Exception reporting mode: Minimal
jax_array = jnp.zeros((3,3), dtype=jnp.float32)
# In place update of JAX's array will yield an error!
jax_array[1, :] = 1.0
TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
Allowing mutation of variables in-place makes program analysis and transformation difficult. JAX requires that programs are pure functions.
Instead, JAX offers a functional array update using the `.at` property on JAX arrays.
️⚠️ Inside `jit`'d code and `lax.while_loop` or `lax.fori_loop`, the size of slices can't be a function of argument values, only of argument shapes; the slice start indices have no such restriction. See the Control Flow section below for more information on this limitation.
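Here is a small sketch of that distinction (the function names are made up for illustration); `lax.dynamic_slice` accepts a traced start index but requires a static slice size:

from jax import jit, lax
import jax.numpy as jnp

@jit
def dynamic_start_static_size(x, start):
    # OK: the start index is a traced value, but the slice size (3)
    # is a Python constant known at trace time.
    return lax.dynamic_slice(x, (start,), (3,))

print(dynamic_start_static_size(jnp.arange(10), 4))  # [4 5 6]

@jit
def dynamic_size(x, size):
    # NOT OK under jit: the slice size depends on an argument value,
    # so tracing this raises an error.
    return x[:size]

# dynamic_size(jnp.arange(10), 3)  # raises a tracing error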
Array updates: x.at[idx].set(y)#
For example, the update above can be written as:
updated_array = jax_array.at[1, :].set(1.0)
print("updated array:\n", updated_array)
updated array:
[[0. 0. 0.]
[1. 1. 1.]
[0. 0. 0.]]
JAX’s array update functions, unlike their NumPy versions, operate out-of-place. That is, the updated array is returned as a new array and the original array is not modified by the update.
print("original array unchanged:\n", jax_array)
original array unchanged:
[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]
However, inside jit-compiled code, if the input value `x` of `x.at[idx].set(y)` is not reused, the compiler will optimize the array update to occur in-place.
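For instance, in a sketch like the following (the function name is illustrative), the input buffer can be reused because `x` is not referenced after the update:

from jax import jit
import jax.numpy as jnp

@jit
def zero_first_row(x):
    # x is not used again after this update, so XLA is free to
    # perform the .at[].set() as an in-place buffer update.
    return x.at[0, :].set(0.0)

print(zero_first_row(jnp.ones((3, 3))))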
Array updates with other operations#
Indexed array updates are not limited simply to overwriting values. For example, we can perform indexed addition as follows:
print("original array:")
jax_array = jnp.ones((5, 6))
print(jax_array)
new_jax_array = jax_array.at[::2, 3:].add(7.)
print("new array post-addition:")
print(new_jax_array)
original array:
[[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]]
new array post-addition:
[[1. 1. 1. 8. 8. 8.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 8. 8. 8.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 8. 8. 8.]]
For more details on indexed array updates, see the documentation for the `.at` property.
🔪 Out-of-Bounds Indexing#
In Numpy, you are used to errors being thrown when you index an array outside of its bounds, like this:
np.arange(10)[11]
IndexError: index 11 is out of bounds for axis 0 with size 10
However, raising an error from code running on an accelerator can be difficult or impossible. Therefore, JAX must choose some non-error behavior for out of bounds indexing (akin to how invalid floating point arithmetic results in `NaN`). When the indexing operation is an array index update (e.g. `index_add` or `scatter`-like primitives), updates at out-of-bounds indices will be skipped; when the operation is an array index retrieval (e.g. NumPy indexing or `gather`-like primitives) the index is clamped to the bounds of the array since something must be returned. For example, the last value of the array will be returned from this indexing operation:
jnp.arange(10)[11]
Array(9, dtype=int32)
If you would like finer-grained control over the behavior for out-of-bound indices, you can use the optional parameters of `ndarray.at`; for example:
jnp.arange(10.0).at[11].get()
Array(9., dtype=float32)
jnp.arange(10.0).at[11].get(mode='fill', fill_value=jnp.nan)
Array(nan, dtype=float32)
Note that due to this behavior for index retrieval, functions like `jnp.nanargmin` and `jnp.nanargmax` return -1 for slices consisting of NaNs whereas Numpy would throw an error.
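For example (a small illustration of that difference):

# NumPy raises an all-NaN slice error; JAX instead returns -1.
print(jnp.nanargmax(jnp.array([jnp.nan, jnp.nan])))  # -1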
Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) will not preserve the semantics of out of bounds indexing. Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of undefined behavior.
🔪 Non-array inputs: NumPy vs. JAX#
NumPy is generally happy accepting Python lists or tuples as inputs to its API functions:
np.sum([1, 2, 3])
6
JAX departs from this, generally returning a helpful error:
jnp.sum([1, 2, 3])
TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.
This is a deliberate design choice, because passing lists or tuples to traced functions can lead to silent performance degradation that might otherwise be difficult to detect.
For example, consider the following permissive version of `jnp.sum` that allows list inputs:
def permissive_sum(x):
return jnp.sum(jnp.array(x))
x = list(range(10))
permissive_sum(x)
Array(45, dtype=int32)
The output is what we would expect, but this hides potential performance issues under the hood. In JAX's tracing and JIT compilation model, each element in a Python list or tuple is treated as a separate JAX variable, and individually processed and pushed to device. This can be seen in the jaxpr for the `permissive_sum` function above:
make_jaxpr(permissive_sum)(x)
{ lambda ; a:i32[] b:i32[] c:i32[] d:i32[] e:i32[] f:i32[] g:i32[] h:i32[] i:i32[]
j:i32[]. let
k:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
l:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
m:i32[] = convert_element_type[new_dtype=int32 weak_type=False] c
n:i32[] = convert_element_type[new_dtype=int32 weak_type=False] d
o:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
p:i32[] = convert_element_type[new_dtype=int32 weak_type=False] f
q:i32[] = convert_element_type[new_dtype=int32 weak_type=False] g
r:i32[] = convert_element_type[new_dtype=int32 weak_type=False] h
s:i32[] = convert_element_type[new_dtype=int32 weak_type=False] i
t:i32[] = convert_element_type[new_dtype=int32 weak_type=False] j
u:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] k
v:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] l
w:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] m
x:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] n
y:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] o
z:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] p
ba:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] q
bb:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] r
bc:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] s
bd:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] t
be:i32[10] = concatenate[dimension=0] u v w x y z ba bb bc bd
bf:i32[] = reduce_sum[axes=(0,)] be
in (bf,) }
Each entry of the list is handled as a separate input, resulting in a tracing & compilation overhead that grows linearly with the size of the list. To prevent surprises like this, JAX avoids implicit conversions of lists and tuples to arrays.
If you would like to pass a tuple or list to a JAX function, you can do so by first explicitly converting it to an array:
jnp.sum(jnp.array(x))
Array(45, dtype=int32)
🔪 Random Numbers#
If all scientific papers whose results are in doubt because of bad `rand()`s were to disappear from library shelves, there would be a gap on each shelf about as big as your fist. - Numerical Recipes
RNGs and State#
You’re used to stateful pseudorandom number generators (PRNGs) from numpy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness:
print(np.random.random())
print(np.random.random())
print(np.random.random())
0.4157599213412506
0.09201806380722433
0.44769303473910294
Underneath the hood, numpy uses the Mersenne Twister PRNG to power its pseudorandom functions. The PRNG has a period of \(2^{19937}-1\) and at any point can be described by 624 32-bit unsigned ints and a position indicating how much of this “entropy” has been used up.
np.random.seed(0)
rng_state = np.random.get_state()
# print(rng_state)
# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,
# 2481403966, 4042607538, 337614300, ... 614 more numbers...,
# 3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)
This pseudorandom state vector is automagically updated behind the scenes every time a random number is needed, “consuming” 2 of the uint32s in the Mersenne twister state vector:
_ = np.random.uniform()
rng_state = np.random.get_state()
#print(rng_state)
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
# ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)
# Let's exhaust the entropy in this PRNG statevector
for i in range(311):
_ = np.random.uniform()
rng_state = np.random.get_state()
#print(rng_state)
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
# ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)
# Next call iterates the RNG state for a new batch of fake "entropy".
_ = np.random.uniform()
rng_state = np.random.get_state()
# print(rng_state)
# --> ('MT19937', array([1499117434, 2949980591, 2242547484,
# 4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)
The problem with magic PRNG state is that it’s hard to reason about how it’s being used and updated across different threads, processes, and devices, and it’s very easy to screw up when the details of entropy production and consumption are hidden from the end user.
The Mersenne Twister PRNG is also known to have a number of problems: it has a large 2.5kB state size, which leads to problematic initialization issues; it fails modern BigCrush tests; and it is generally slow.
JAX PRNG#
JAX instead implements an explicit PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern Threefry counter-based PRNG that’s splittable. That is, its design allows us to fork the PRNG state into new PRNGs for use with parallel stochastic generation.
The random state is described by a special array element that we call a key:
from jax import random
key = random.key(0)
key
Array((), dtype=key<fry>) overlaying:
[0 0]
JAX’s random functions produce pseudorandom numbers from the PRNG state, but do not change the state!
Reusing the same state will cause sadness and monotony, depriving the end user of lifegiving chaos:
print(random.normal(key, shape=(1,)))
print(key)
# No no no!
print(random.normal(key, shape=(1,)))
print(key)
[-0.20584226]
Array((), dtype=key<fry>) overlaying:
[0 0]
[-0.20584226]
Array((), dtype=key<fry>) overlaying:
[0 0]
Instead, we split the PRNG to get usable subkeys every time we need a new pseudorandom number:
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print(" \---SPLIT --> new key ", key)
print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
old key Array((), dtype=key<fry>) overlaying:
[0 0]
\---SPLIT --> new key Array((), dtype=key<fry>) overlaying:
[4146024105 967050713]
\--> new subkey Array((), dtype=key<fry>) overlaying:
[2718843009 1272950319] --> normal [-1.2515389]
We propagate the key and make new subkeys whenever we need a new random number:
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print(" \---SPLIT --> new key ", key)
print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
old key Array((), dtype=key<fry>) overlaying:
[4146024105 967050713]
\---SPLIT --> new key Array((), dtype=key<fry>) overlaying:
[2384771982 3928867769]
\--> new subkey Array((), dtype=key<fry>) overlaying:
[1278412471 2182328957] --> normal [-0.58665055]
We can generate more than one subkey at a time:
key, *subkeys = random.split(key, 4)
for subkey in subkeys:
print(random.normal(subkey, shape=(1,)))
[-0.37533438]
[0.98645043]
[0.14553197]
🔪 Control Flow#
✔ python control_flow + autodiff ✔#
If you just want to apply `grad` to your Python functions, you can use regular Python control-flow constructs with no problems, as if you were using Autograd (or PyTorch or TF Eager).
def f(x):
if x < 3:
return 3. * x ** 2
else:
return -4 * x
print(grad(f)(2.)) # ok!
print(grad(f)(4.)) # ok!
12.0
-4.0
python control flow + JIT#
Using control flow with `jit` is more complicated, and by default it has more constraints.
This works:
@jit
def f(x):
for i in range(3):
x = 2 * x
return x
print(f(3))
24
So does this:
@jit
def g(x):
y = 0.
for i in range(x.shape[0]):
y = y + x[i]
return y
print(g(jnp.array([1., 2., 3.])))
6.0
But this doesn’t, at least by default:
@jit
def f(x):
if x < 3:
return 3. * x ** 2
else:
return -4 * x
# This will fail!
f(2)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function f at /tmp/ipykernel_3575/3402096563.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
What gives!?
When we `jit`-compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don't have to re-compile on each function evaluation.
For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time.
To get a view of your Python code that is valid for many different argument values, JAX traces it on abstract values that represent sets of possible inputs. There are multiple different levels of abstraction, and different transformations use different abstraction levels.
By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.
But there's a tradeoff here: if we trace a Python function on a `ShapedArray((), jnp.float32)` that isn't committed to a specific concrete value, when we hit a line like `if x < 3`, the expression `x < 3` evaluates to an abstract `ShapedArray((), jnp.bool_)` that represents the set `{True, False}`. When Python attempts to coerce that to a concrete `True` or `False`, we get an error: we don't know which branch to take, and can't continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace.
The good news is that you can control this tradeoff yourself. By having `jit` trace on more refined abstract values, you can relax the traceability constraints. For example, using the `static_argnums` argument to `jit`, we can specify to trace on concrete values of some arguments. Here's that example function again:
def f(x):
if x < 3:
return 3. * x ** 2
else:
return -4 * x
f = jit(f, static_argnums=(0,))
print(f(2.))
12.0
Here’s another example, this time involving a loop:
def f(x, n):
y = 0.
for i in range(n):
y = y + x[i]
return y
f = jit(f, static_argnums=(1,))
f(jnp.array([2., 3., 4.]), 2)
Array(5., dtype=float32)
In effect, the loop gets statically unrolled. JAX can also trace at higher levels of abstraction, like `Unshaped`, but that's not currently the default for any transformation.
️⚠️ functions with argument-value dependent shapes
These control-flow issues also come up in a more subtle way: numerical functions we want to jit can't specialize the shapes of internal arrays on argument values (specializing on argument shapes is ok). As a trivial example, let's make a function whose output happens to depend on the input variable `length`.
def example_fun(length, val):
return jnp.ones((length,)) * val
# un-jit'd works fine
print(example_fun(5, 4))
[4. 4. 4. 4. 4.]
bad_example_jit = jit(example_fun)
# this will fail:
bad_example_jit(10, 4)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function example_fun at /tmp/ipykernel_3575/1210496444.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument length.
# static_argnums tells JAX to recompile on changes at these argument positions:
good_example_jit = jit(example_fun, static_argnums=(0,))
# first compile
print(good_example_jit(10, 4))
# recompiles
print(good_example_jit(5, 4))
[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]
[4. 4. 4. 4. 4.]
`static_argnums` can be handy if `length` in our example rarely changes, but it would be disastrous if it changed a lot!
Lastly, if your function has global side-effects, JAX’s tracer can cause weird things to happen. A common gotcha is trying to print arrays inside jit’d functions:
@jit
def f(x):
print(x)
y = 2 * x
print(y)
return y
f(2)
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
Array(4, dtype=int32, weak_type=True)
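If you do need to see runtime values from inside jit-compiled code, `jax.debug.print()` is designed for exactly this; a brief sketch:

import jax
from jax import jit

@jit
def f_debug(x):
    # Unlike Python's print, jax.debug.print understands traced values
    # and prints them when the compiled computation actually runs.
    jax.debug.print("x = {x}", x=x)
    y = 2 * x
    jax.debug.print("y = {y}", y=y)
    return y

f_debug(2)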
Structured control flow primitives#
There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that’s traceable, and that avoids un-rolling large loops. Then you can use these 4 structured control flow primitives:
- `lax.cond`: differentiable
- `lax.while_loop`: fwd-mode-differentiable
- `lax.fori_loop`: fwd-mode-differentiable in general; fwd and rev-mode differentiable if endpoints are static.
- `lax.scan`: differentiable
cond#
python equivalent:
def cond(pred, true_fun, false_fun, operand):
if pred:
return true_fun(operand)
else:
return false_fun(operand)
from jax import lax
operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, operand)
# --> array([1.], dtype=float32)
lax.cond(False, lambda x: x+1, lambda x: x-1, operand)
# --> array([-1.], dtype=float32)
Array([-1.], dtype=float32)
`jax.lax` provides two other functions that allow branching on dynamic predicates:
- `lax.select` is like a batched version of `lax.cond`, with the choices expressed as pre-computed arrays rather than as functions.
- `lax.switch` is like `lax.cond`, but allows switching between any number of callable choices.
In addition, `jax.numpy` provides several numpy-style interfaces to these functions (a short sketch follows this list):
- `jnp.where` with three arguments is the numpy-style wrapper of `lax.select`.
- `jnp.piecewise` is a numpy-style wrapper of `lax.switch`, but switches on a list of boolean conditions rather than a single scalar index.
- `jnp.select` has an API similar to `jnp.piecewise`, but the choices are given as pre-computed arrays rather than as functions. It is implemented in terms of multiple calls to `lax.select`.
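Here is a short sketch of these alternatives (the values are chosen only for illustration):

# lax.select: elementwise choice between two precomputed arrays.
pred = jnp.array([True, False, True])
print(lax.select(pred, jnp.ones(3), jnp.zeros(3)))   # [1. 0. 1.]

# lax.switch: choose one of several callables by an integer index.
branches = [lambda x: x + 1.0, lambda x: x - 1.0, lambda x: x * 2.0]
print(lax.switch(2, branches, jnp.array([3.0])))     # [6.]

# jnp.where with three arguments wraps lax.select in numpy style.
print(jnp.where(pred, jnp.ones(3), jnp.zeros(3)))    # [1. 0. 1.]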
while_loop#
python equivalent:
def while_loop(cond_fun, body_fun, init_val):
val = init_val
while cond_fun(val):
val = body_fun(val)
return val
init_val = 0
cond_fun = lambda x: x<10
body_fun = lambda x: x+1
lax.while_loop(cond_fun, body_fun, init_val)
# --> array(10, dtype=int32)
Array(10, dtype=int32, weak_type=True)
fori_loop#
python equivalent:
def fori_loop(start, stop, body_fun, init_val):
val = init_val
for i in range(start, stop):
val = body_fun(i, val)
return val
init_val = 0
start = 0
stop = 10
body_fun = lambda i,x: x+i
lax.fori_loop(start, stop, body_fun, init_val)
# --> array(45, dtype=int32)
Array(45, dtype=int32, weak_type=True)
Summary#
\(\ast\) = argument-value-independent loop condition - unrolls the loop
🔪 Dynamic Shapes#
JAX code used within transforms like `jax.jit`, `jax.vmap`, `jax.grad`, etc. requires all output arrays and intermediate arrays to have static shape: that is, the shape cannot depend on values within other arrays.
For example, if you were implementing your own version of `jnp.nansum`, you might start with something like this:
def nansum(x):
mask = ~jnp.isnan(x) # boolean mask selecting non-nan values
x_without_nans = x[mask]
return x_without_nans.sum()
Outside JIT and other transforms, this works as expected:
x = jnp.array([1, 2, jnp.nan, 3, 4])
print(nansum(x))
10.0
If you attempt to apply `jax.jit` or another transform to this function, it will error:
jax.jit(nansum)(x)
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[5])
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError
The problem is that the size of `x_without_nans` is dependent on the values within `x`, which is another way of saying its size is dynamic.
Often in JAX it is possible to work around the need for dynamically-sized arrays via other means. For example, here it is possible to use the three-argument form of `jnp.where` to replace the NaN values with zeros, thus computing the same result while avoiding dynamic shapes:
@jax.jit
def nansum_2(x):
mask = ~jnp.isnan(x) # boolean mask selecting non-nan values
return jnp.where(mask, x, 0).sum()
print(nansum_2(x))
10.0
Similar tricks can be played in other situations where dynamically-shaped arrays occur.
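As one more illustration of the same masking pattern (`nanmean_2` is just an example name, not a JAX API), applied to the array `x` defined above:

@jax.jit
def nanmean_2(x):
    mask = ~jnp.isnan(x)                 # boolean mask selecting non-nan values
    total = jnp.where(mask, x, 0).sum()  # sum over the valid entries
    count = mask.sum()                   # number of valid entries
    return total / count

print(nanmean_2(x))  # 2.5 for the array defined above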
🔪 NaNs#
Debugging NaNs#
If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by:
- setting the `JAX_DEBUG_NANS=True` environment variable;
- adding `jax.config.update("jax_debug_nans", True)` near the top of your main file;
- adding `jax.config.parse_flags_with_absl()` to your main file, then setting the option using a command-line flag like `--jax_debug_nans=True`.
This will cause computations to error-out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jit`. For code under an `@jit`, the output of every `@jit` function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jit` at a time.
There could be tricky situations that arise, like nans that only occur under a `@jit` but don't get produced in de-optimized mode. In that case you'll see a warning message print out but your code will continue to execute.
If the nans are being produced in the backward pass of a gradient evaluation, then when the exception is raised, several frames up in the stack trace you will be in the backward_pass function, which is essentially a simple jaxpr interpreter that walks the sequence of primitive operations in reverse. In the example below, we started an ipython repl with the command line `env JAX_DEBUG_NANS=True ipython`, then ran this:
In [1]: import jax.numpy as jnp
In [2]: jnp.divide(0., 0.)
---------------------------------------------------------------------------
FloatingPointError Traceback (most recent call last)
<ipython-input-2-f2e2c413b437> in <module>()
----> 1 jnp.divide(0., 0.)
.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
343 return floor_divide(x1, x2)
344 else:
--> 345 return true_divide(x1, x2)
346
347
.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
332 x1, x2 = _promote_shapes(x1, x2)
333 return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334 lax.convert_element_type(x2, result_dtype))
335
336
.../jax/jax/lax.pyc in div(x, y)
244 def div(x, y):
245 r"""Elementwise division: :math:`x \over y`."""
--> 246 return div_p.bind(x, y)
247
248 def rem(x, y):
... stack trace ...
.../jax/jax/interpreters/xla.pyc in handle_result(device_buffer)
103 py_val = device_buffer.to_py()
104 if np.any(np.isnan(py_val)):
--> 105 raise FloatingPointError("invalid value")
106 else:
107 return Array(device_buffer, *result_shape)
FloatingPointError: invalid value
The nan generated was caught. By running `%debug`, we can get a post-mortem debugger. This also works with functions under `@jit`, as the example below shows.
In [4]: from jax import jit
In [5]: @jit
...: def f(x, y):
...: a = x * y
...: b = (x + y) / (x - y)
...: c = a + 2
...: return a + b * c
...:
In [6]: x = jnp.array([2., 0.])
In [7]: y = jnp.array([3., 0.])
In [8]: f(x, y)
Invalid value encountered in the output of a jit function. Calling the de-optimized version.
---------------------------------------------------------------------------
FloatingPointError Traceback (most recent call last)
<ipython-input-8-811b7ddb3300> in <module>()
----> 1 f(x, y)
... stack trace ...
<ipython-input-5-619b39acbaac> in f(x, y)
2 def f(x, y):
3 a = x * y
----> 4 b = (x + y) / (x - y)
5 c = a + 2
6 return a + b * c
.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
343 return floor_divide(x1, x2)
344 else:
--> 345 return true_divide(x1, x2)
346
347
.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
332 x1, x2 = _promote_shapes(x1, x2)
333 return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334 lax.convert_element_type(x2, result_dtype))
335
336
.../jax/jax/lax.pyc in div(x, y)
244 def div(x, y):
245 r"""Elementwise division: :math:`x \over y`."""
--> 246 return div_p.bind(x, y)
247
248 def rem(x, y):
... stack trace ...
When this code sees a nan in the output of an `@jit` function, it calls into the de-optimized code, so we still get a clear stack trace. And we can run a post-mortem debugger with `%debug` to inspect all the values to figure out the error.
⚠️ You shouldn’t have the NaN-checker on if you’re not debugging, as it can introduce lots of device-host round-trips and performance regressions!
⚠️ The NaN-checker doesn't work with `pmap`. To debug nans in `pmap` code, one thing to try is replacing `pmap` with `vmap`.
🔪 Double (64bit) precision#
At the moment, JAX by default enforces single-precision numbers to mitigate the Numpy API's tendency to aggressively promote operands to `double`. This is the desired behavior for many machine-learning applications, but it may catch you by surprise!
x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
x.dtype
/tmp/ipykernel_3575/1258726447.py:1: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
dtype('float32')
To use double-precision numbers, you need to set the `jax_enable_x64` configuration variable at startup. There are a few ways to do this:
1. You can enable 64-bit mode by setting the environment variable `JAX_ENABLE_X64=True`.
2. You can manually set the `jax_enable_x64` configuration flag at startup:
# again, this only works on startup!
import jax
jax.config.update("jax_enable_x64", True)
3. You can parse command-line flags with `absl.app.run(main)`:
import jax
jax.config.config_with_absl()
4. If you want JAX to run absl parsing for you, i.e. you don't want to do `absl.app.run(main)`, you can instead use:
import jax
if __name__ == '__main__':
  # calls jax.config.config_with_absl() *and* runs absl parsing
  jax.config.parse_flags_with_absl()
Note that #2-#4 work for any of JAX’s configuration options.
We can then confirm that x64 mode is enabled:
import jax.numpy as jnp
from jax import random
x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
x.dtype # --> dtype('float64')
/tmp/ipykernel_3575/2819792939.py:3: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
dtype('float32')
Caveats#
⚠️ XLA doesn’t support 64-bit convolutions on all backends!
🔪 Miscellaneous Divergences from NumPy#
While `jax.numpy` makes every attempt to replicate the behavior of numpy's API, there do exist corner cases where the behaviors differ. Many such cases are discussed in detail in the sections above; here we list several other known places where the APIs diverge.
- For binary operations, JAX's type promotion rules differ somewhat from those used by NumPy. See Type Promotion Semantics for more details.
- When performing unsafe type casts (i.e. casts in which the target dtype cannot represent the input value), JAX's behavior may be backend dependent, and in general may diverge from NumPy's behavior. NumPy allows control over the result in these scenarios via the `casting` argument (see `np.ndarray.astype`); JAX does not provide any such configuration, instead directly inheriting the behavior of XLA:ConvertElementType. Here is an example of an unsafe cast with differing results between NumPy and JAX:

>>> np.arange(254.0, 258.0).astype('uint8')
array([254, 255, 0, 1], dtype=uint8)
>>> jnp.arange(254.0, 258.0).astype('uint8')
Array([254, 255, 255, 255], dtype=uint8)
This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.
Fin.#
If something's not covered here that has caused you weeping and gnashing of teeth, please let us know and we'll extend these introductory advisories!
JAX Frequently Asked Questions (FAQ)#
We are collecting answers to frequently asked questions here. Contributions welcome!
jit changes the behavior of my function#
If you have a Python function that changes behavior after using `jax.jit()`, perhaps your function uses global state, or has side-effects. In the following code, the `impure_func` uses the global `y` and has a side-effect due to `print`:
y = 0
# @jit # Different behavior with jit
def impure_func(x):
print("Inside:", y)
return x + y
for y in range(3):
print("Result:", impure_func(y))
Without `jit` the output is:
Inside: 0
Result: 0
Inside: 1
Result: 2
Inside: 2
Result: 4
and with `jit` it is:
Inside: 0
Result: 0
Result: 1
Result: 2
For `jax.jit()`, the function is executed once using the Python interpreter, at which time the `Inside` printing happens, and the first value of `y` is observed. Then, the function is compiled and cached, and executed multiple times with different values of `x`, but with the same first value of `y`.
jit changes the exact numerics of outputs#
Sometimes users are surprised by the fact that wrapping a function with `jit()` can change the function's outputs. For example:
>>> from jax import jit
>>> import jax.numpy as jnp
>>> def f(x):
... return jnp.log(jnp.sqrt(x))
>>> x = jnp.pi
>>> print(f(x))
0.572365
>>> print(jit(f)(x))
0.5723649
This slight difference in output comes from optimizations within the XLA compiler: during compilation, XLA will sometimes rearrange or elide certain operations to make the overall computation more efficient.
In this case, XLA utilizes the properties of the logarithm to replace `log(sqrt(x))` with `0.5 * log(x)`, which is a mathematically identical expression that can be computed more efficiently than the original. The difference in output comes from the fact that floating point arithmetic is only a close approximation of real math, so different ways of computing the same expression may have subtly different results.
Other times, XLA’s optimizations may lead to even more drastic differences. Consider the following example:
>>> def f(x):
... return jnp.log(jnp.exp(x))
>>> x = 100.0
>>> print(f(x))
inf
>>> print(jit(f)(x))
100.0
In non-JIT-compiled op-by-op mode, the result is `inf` because `jnp.exp(x)` overflows and returns `inf`. Under JIT, however, XLA recognizes that `log` is the inverse of `exp`, and removes the operations from the compiled function, simply returning the input. In this case, JIT compilation produces a more accurate floating point approximation of the real result.
Unfortunately the full list of XLA’s algebraic simplifications is not well documented, but if you’re familiar with C++ and curious about what types of optimizations the XLA compiler makes, you can see them in the source code: algebraic_simplifier.cc.
jit decorated function is very slow to compile#
If your `jit` decorated function takes tens of seconds (or more!) to run the first time you call it, but executes quickly when called again, JAX is taking a long time to trace or compile your code.
This is usually a sign that calling your function generates a large amount of code in JAX's internal representation, typically because it makes heavy use of Python control flow such as `for` loops. For a handful of loop iterations, Python is OK, but if you need many loop iterations, you should rewrite your code to make use of JAX's structured control flow primitives (such as `lax.scan()`) or avoid wrapping the loop with `jit` (you can still use `jit` decorated functions inside the loop).
If you're not sure if this is the problem, you can try running `jax.make_jaxpr()` on your function. You can expect slow compilation if the output is many hundreds or thousands of lines long.
Sometimes it isn’t obvious how to rewrite your code to avoid Python loops
because your code makes use of many arrays with different shapes. The
recommended solution in this case is to make use of functions like
jax.numpy.where()
to do your computation on padded arrays with fixed
shape.
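As a rough sketch of that padding pattern (masked_sum is an illustrative name), you might pad ragged inputs to a single fixed length and mask with jax.numpy.where():
import jax
import jax.numpy as jnp

@jax.jit
def masked_sum(padded, length):
    # `padded` always has the same fixed shape, so only one compilation is
    # needed; `length` marks how many leading entries are real data.
    mask = jnp.arange(padded.shape[0]) < length
    return jnp.sum(jnp.where(mask, padded, 0.0))

print(masked_sum(jnp.array([1.0, 2.0, 3.0, 0.0, 0.0]), 3))  # 6.0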
If your functions are slow to compile for another reason, please open an issue on GitHub.
How to use jit
with methods?#
Most examples of jax.jit()
concern decorating stand-alone Python functions,
but decorating a method within a class introduces some complication. For example,
consider the following simple class, where we’ve used a standard jit()
annotation on a method:
>>> import jax.numpy as jnp
>>> from jax import jit
>>> class CustomClass:
... def __init__(self, x: jnp.ndarray, mul: bool):
... self.x = x
... self.mul = mul
...
... @jit # <---- How to do this correctly?
... def calc(self, y):
... if self.mul:
... return self.x * y
... return y
However, this approach will result in an error when you attempt to call this method:
>>> c = CustomClass(2, True)
>>> c.calc(3)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
File "<stdin>", line 1, in <module
TypeError: Argument '<CustomClass object at 0x7f7dd4125890>' of type <class 'CustomClass'> is not a valid JAX type.
The problem is that the first argument to the function is self
, which has type
CustomClass
, and JAX does not know how to handle this type.
There are three basic strategies we might use in this case, and we’ll discuss
them below.
Strategy 1: JIT-compiled helper function#
The most straightforward approach is to create a helper function external to the class that can be JIT-decorated in the normal way. For example:
>>> from functools import partial
>>> class CustomClass:
... def __init__(self, x: jnp.ndarray, mul: bool):
... self.x = x
... self.mul = mul
...
... def calc(self, y):
... return _calc(self.mul, self.x, y)
>>> @partial(jit, static_argnums=0)
... def _calc(mul, x, y):
... if mul:
... return x * y
... return y
The result will work as expected:
>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6
The benefit of such an approach is that it is simple, explicit, and it avoids the need
to teach JAX how to handle objects of type CustomClass
. However, you may wish to
keep all the method logic in the same place.
Strategy 2: Marking self
as static#
Another common pattern is to use static_argnums
to mark the self
argument as static.
But this must be done with care to avoid unexpected results.
You may be tempted to simply do this:
>>> class CustomClass:
... def __init__(self, x: jnp.ndarray, mul: bool):
... self.x = x
... self.mul = mul
...
... # WARNING: this example is broken, as we'll see below. Don't copy & paste!
... @partial(jit, static_argnums=0)
... def calc(self, y):
... if self.mul:
... return self.x * y
... return y
If you call the method, it will no longer raise an error:
>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6
However, there is a catch: if you mutate the object after the first method call, the subsequent method call may return an incorrect result:
>>> c.mul = False
>>> print(c.calc(3)) # Should print 3
6
Why is this? When you mark an object as static, it will effectively be used as a dictionary
key in JIT’s internal compilation cache, meaning that its hash (i.e. hash(obj)), equality (i.e. obj1 == obj2), and object identity (i.e. obj1 is obj2) are assumed to have
consistent behavior. The default __hash__
for a custom object is its object ID, and so
JAX has no way of knowing that a mutated object should trigger a re-compilation.
You can partially address this by defining appropriate __hash__ and __eq__ methods for your object; for example:
>>> class CustomClass:
... def __init__(self, x: jnp.ndarray, mul: bool):
... self.x = x
... self.mul = mul
...
... @partial(jit, static_argnums=0)
... def calc(self, y):
... if self.mul:
... return self.x * y
... return y
...
... def __hash__(self):
... return hash((self.x, self.mul))
...
... def __eq__(self, other):
... return (isinstance(other, CustomClass) and
... (self.x, self.mul) == (other.x, other.mul))
(see the object.__hash__()
documentation for more discussion of the requirements
when overriding __hash__
).
This should work correctly with JIT and other transforms so long as you never mutate
your object. Mutations of objects used as hash keys lead to several subtle problems,
which is why for example mutable Python containers (e.g. dict
, list
)
don’t define __hash__
, while their immutable counterparts (e.g. tuple
) do.
If your class relies on in-place mutations (such as setting self.attr = ...
within its
methods), then your object is not really “static” and marking it as such may lead to problems.
Fortunately, there’s another option for this case.
Strategy 3: Making CustomClass
a PyTree#
The most flexible approach to correctly JIT-compiling a class method is to register the type as a custom PyTree object; see Extending pytrees. This lets you specify exactly which components of the class should be treated as static and which should be treated as dynamic. Here’s how it might look:
>>> class CustomClass:
... def __init__(self, x: jnp.ndarray, mul: bool):
... self.x = x
... self.mul = mul
...
... @jit
... def calc(self, y):
... if self.mul:
... return self.x * y
... return y
...
... def _tree_flatten(self):
... children = (self.x,) # arrays / dynamic values
... aux_data = {'mul': self.mul} # static values
... return (children, aux_data)
...
... @classmethod
... def _tree_unflatten(cls, aux_data, children):
... return cls(*children, **aux_data)
>>> from jax import tree_util
>>> tree_util.register_pytree_node(CustomClass,
... CustomClass._tree_flatten,
... CustomClass._tree_unflatten)
This is certainly more involved, but it solves all the issues associated with the simpler approaches used above:
>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6
>>> c.mul = False # mutation is detected
>>> print(c.calc(3))
3
>>> c = CustomClass(jnp.array(2), True) # non-hashable x is supported
>>> print(c.calc(3))
6
So long as your tree_flatten
and tree_unflatten
functions correctly handle all
relevant attributes in the class, you should be able to use objects of this type directly
as arguments to JIT-compiled functions, without any special annotations.
Controlling data and computation placement on devices#
Let’s first look at the principles of data and computation placement in JAX.
In JAX, the computation follows data placement. JAX arrays have two placement properties: 1) the device where the data resides; and 2) whether it is committed to the device or not (the data is sometimes referred to as being sticky to the device).
By default, JAX arrays are placed uncommitted on the default device
(jax.devices()[0]
), which is the first GPU or TPU by default. If no GPU or
TPU is present, jax.devices()[0]
is the CPU. The default device can
be temporarily overridden with the jax.default_device()
context manager, or
set for the whole process by setting the environment variable JAX_PLATFORMS
or the absl flag --jax_platforms
to “cpu”, “gpu”, or “tpu”
(JAX_PLATFORMS
can also be a list of platforms, which determines which
platforms are available in priority order).
>>> from jax import numpy as jnp
>>> print(jnp.ones(3).devices())
{CudaDevice(id=0)}
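For example, here is a small sketch of temporarily overriding the default device with the jax.default_device() context manager:
import jax
import jax.numpy as jnp

# Temporarily make the (first) CPU device the default; arrays created inside
# the block are placed on that device.
with jax.default_device(jax.devices("cpu")[0]):
    x = jnp.ones(3)
    print(x.devices())   # {CpuDevice(id=0)}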
Computations involving uncommitted data are performed on the default device and the results are uncommitted on the default device.
Data can also be placed explicitly on a device using jax.device_put()
with a device
parameter, in which case the data becomes committed to the device:
>>> import jax
>>> from jax import device_put
>>> arr = device_put(1, jax.devices()[2])
>>> print(arr.devices())
{CudaDevice(id=2)}
Computations involving some committed inputs will happen on the committed device and the result will be committed on the same device. Invoking an operation on arguments that are committed to more than one device will raise an error.
You can also use jax.device_put()
without a device
parameter. If the data
is already on a device (committed or not), it’s left as-is. If the data isn’t on any
device—that is, it’s a regular Python or NumPy value—it’s placed uncommitted on the default
device.
Jitted functions behave like any other primitive operations—they will follow the data and will show errors if invoked on data committed on more than one device.
(Before PR #6002 in March 2021
there was some laziness in creation of array constants, so that
jax.device_put(jnp.zeros(...), jax.devices()[1])
or similar would actually
create the array of zeros on jax.devices()[1]
, instead of creating the
array on the default device then moving it. But this optimization was removed
so as to simplify the implementation.)
(As of April 2020, jax.jit()
has a device parameter that affects the device
placement. That parameter is experimental, is likely to be removed or changed,
and its use is not recommended.)
For a worked-out example, we recommend reading through
test_computation_follows_data
in
multi_device_test.py.
Benchmarking JAX code#
You just ported a tricky function from NumPy/SciPy to JAX. Did that actually speed things up?
Keep in mind these important differences from NumPy when measuring the speed of code using JAX:
JAX code is Just-In-Time (JIT) compiled. Most code written in JAX can be written in such a way that it supports JIT compilation, which can make it run much faster (see To JIT or not to JIT). To get maximum performance from JAX, you should apply
jax.jit()
on your outer-most function calls.Keep in mind that the first time you run JAX code, it will be slower because it is being compiled. This is true even if you don’t use
jit
in your own code, because JAX’s builtin functions are also JIT compiled.JAX has asynchronous dispatch. This means that you need to call
.block_until_ready()
to ensure that computation has actually happened (see Asynchronous dispatch).JAX by default only uses 32-bit dtypes. You may want to either explicitly use 32-bit dtypes in NumPy or enable 64-bit dtypes in JAX (see Double (64 bit) precision) for a fair comparison.
Transferring data between CPUs and accelerators takes time. If you only want to measure how long it takes to evaluate a function, you may want to transfer data to the device on which you want to run it first (see Controlling data and computation placement on devices).
Here’s an example of how to put together all these tricks into a microbenchmark for comparing JAX versus NumPy, making use of IPython’s convenient %time and %timeit magics:
import numpy as np
import jax.numpy as jnp
import jax
def f(x): # function we're benchmarking (works in both NumPy & JAX)
return x.T @ (x - x.mean(axis=0))
x_np = np.ones((1000, 1000), dtype=np.float32) # same as JAX default dtype
%timeit f(x_np) # measure NumPy runtime
%time x_jax = jax.device_put(x_np) # measure JAX device transfer time
f_jit = jax.jit(f)
%time f_jit(x_jax).block_until_ready() # measure JAX compilation time
%timeit f_jit(x_jax).block_until_ready() # measure JAX runtime
When run with a GPU in Colab, we see:
NumPy takes 16.2 ms per evaluation on the CPU
JAX takes 1.26 ms to copy the NumPy arrays onto the GPU
JAX takes 193 ms to compile the function
JAX takes 485 µs per evaluation on the GPU
In this case, we see that once the data is transferred and the function is compiled, JAX on the GPU is about 30x faster for repeated evaluations.
Is this a fair comparison? Maybe. The performance that ultimately matters is for
running full applications, which inevitably include some amount of both data
transfer and compilation. Also, we were careful to pick large enough arrays
(1000x1000) and an intensive enough computation (the @
operator is
performing matrix-matrix multiplication) to amortize the increased overhead of
JAX/accelerators vs NumPy/CPU. For example, if we switch this example to use
10x10 input instead, JAX/GPU runs 10x slower than NumPy/CPU (100 µs vs 10 µs).
Is JAX faster than NumPy?#
One question users frequently attempt to answer with such benchmarks is whether JAX is faster than NumPy; because of the differences between the two packages, there is no simple answer.
Broadly speaking:
NumPy operations are executed eagerly, synchronously, and only on CPU.
JAX operations may be executed eagerly or after compilation (if inside
jit()
); they are dispatched asynchronously (see Asynchronous dispatch); and they can be executed on CPU, GPU, or TPU, each of which has vastly different and continuously evolving performance characteristics.
These architectural differences make meaningful direct benchmark comparisons between NumPy and JAX difficult.
Additionally, these differences have led to different engineering focus between the packages: for example, NumPy has put significant effort into decreasing the per-call dispatch overhead for individual array operations, because in NumPy’s computational model that overhead cannot be avoided. JAX, on the other hand, has several ways to avoid dispatch overhead (e.g. JIT compilation, asynchronous dispatch, batching transforms, etc.), and so reducing per-call overhead has been less of a priority.
Keeping all that in mind, in summary: if you’re doing microbenchmarks of individual array operations on CPU, you can generally expect NumPy to outperform JAX due to its lower per-operation dispatch overhead. If you’re running your code on GPU or TPU, or are benchmarking more complicated JIT-compiled sequences of operations on CPU, you can generally expect JAX to outperform NumPy.
Different kinds of JAX values#
In the process of transforming functions, JAX replaces some function arguments with special tracer values.
You could see this if you use a print
statement:
def func(x):
print(x)
return jnp.cos(x)
res = jax.jit(func)(0.)
The above code does return the correct value 1.
but it also prints
Traced<ShapedArray(float32[])>
for the value of x
. Normally, JAX
handles these tracer values internally in a transparent way, e.g.,
in the numeric JAX primitives that are used to implement the
jax.numpy
functions. This is why jnp.cos
works in the example above.
More precisely, a tracer value is introduced for the argument of
a JAX-transformed function, except the arguments identified by special
parameters such as static_argnums
for jax.jit()
or
static_broadcasted_argnums
for jax.pmap()
. Typically, computations
that involve at least a tracer value will produce a tracer value. Besides tracer
values, there are regular Python values: values that are computed outside JAX
transformations, or arise from above-mentioned static arguments of certain JAX
transformations, or computed solely from other regular Python values.
These are the values that are used everywhere in absence of JAX transformations.
A tracer value carries an abstract value, e.g., ShapedArray
with information
about the shape and dtype of an array. We will refer here to such tracers as
abstract tracers. Some tracers, e.g., those that are
introduced for arguments of autodiff transformations, carry ConcreteArray
abstract values that actually include the regular array data, and are used,
e.g., for resolving conditionals. We will refer here to such tracers
as concrete tracers. Tracer values computed from these concrete tracers,
perhaps in combination with regular values, result in concrete tracers.
A concrete value is either a regular value or a concrete tracer.
Most often values computed from tracer values are themselves tracer values.
There are very few exceptions, when a computation can be entirely done
using the abstract value carried by a tracer, in which case the result
can be a regular value. For example, getting the shape of a tracer
with ShapedArray
abstract value. Another example is when explicitly
casting a concrete tracer value to a regular type, e.g., int(x)
or
x.astype(float)
.
Another such situation is for bool(x)
, which produces a Python bool when
concreteness makes it possible. That case is especially salient because
of how often it arises in control flow.
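As a small illustration of those exceptions, the shape of an abstract tracer is an ordinary Python tuple even inside jit. A minimal sketch (report_shape is an illustrative name):
import jax
import jax.numpy as jnp

@jax.jit
def report_shape(x):
    # x is an abstract tracer here, but x.shape is a regular Python tuple,
    # known from the ShapedArray abstract value.
    print("shape seen at trace time:", x.shape)
    return x.sum()

report_shape(jnp.ones((2, 3)))   # prints (2, 3) during tracing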
Here is how the transformations introduce abstract or concrete tracers:
jax.jit()
: introduces abstract tracers for all positional arguments except those denoted bystatic_argnums
, which remain regular values.jax.pmap()
: introduces abstract tracers for all positional arguments except those denoted bystatic_broadcasted_argnums
.jax.vmap()
,jax.make_jaxpr()
,xla_computation()
: introduce abstract tracers for all positional arguments.jax.jvp()
andjax.grad()
introduce concrete tracers for all positional arguments. An exception is when these transformations are within an outer transformation and the actual arguments are themselves abstract tracers; in that case, the tracers introduced by the autodiff transformations are also abstract tracers.All higher-order control-flow primitives (
lax.cond()
,lax.while_loop()
,lax.fori_loop()
,lax.scan()
) when they process the functionals introduce abstract tracers, whether or not there is a JAX transformation in progress.
All of this is relevant when you have code that can operate only on regular Python values, such as code that has conditional control-flow based on data:
def divide(x, y):
return x / y if y >= 1. else 0.
If we want to apply jax.jit()
, we must specify static_argnums=1
to ensure y
stays a regular value. This is due to the boolean expression
y >= 1.
, which requires concrete values (regular or tracers). The
same would happen if we write explicitly bool(y >= 1.)
, or int(y)
,
or float(y)
.
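A minimal sketch of that fix (jit_divide is an illustrative name):
import jax

def divide(x, y):
    return x / y if y >= 1. else 0.

# Mark y as static so it stays a regular Python value; note that jit will
# re-trace and re-compile for each distinct value of y.
jit_divide = jax.jit(divide, static_argnums=1)
print(jit_divide(3., 2.))   # 1.5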
Interestingly, jax.grad(divide)(3., 2.) works because jax.grad()
uses concrete tracers, and resolves the conditional using the concrete
value of y
.
Buffer donation#
When JAX executes a computation it uses buffers on the device for all inputs and outputs. If you know that one of the inputs is not needed after the computation, and if it matches the shape and element type of one of the outputs, you can specify that you want the corresponding input buffer to be donated to hold an output. This will reduce the memory required for the execution by the size of the donated buffer.
If you have something like the following pattern, you can use buffer donation:
params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params, state)
You can think of this as a way to do a memory-efficient functional update on your immutable JAX arrays. Within the boundaries of a computation XLA can make this optimization for you, but at the jit/pmap boundary you need to guarantee to XLA that you will not use the donated input buffer after calling the donating function.
You achieve this by using the donate_argnums parameter to the functions jax.jit()
,
jax.pjit()
, and jax.pmap()
. This parameter is a sequence of indices (0 based) into
the positional argument list:
def add(x, y):
return x + y
x = jax.device_put(np.ones((2, 3)))
y = jax.device_put(np.ones((2, 3)))
# Execute `add` with donation of the buffer for `y`. The result has
# the same shape and type as `y`, so it will share its buffer.
z = jax.jit(add, donate_argnums=(1,))(x, y)
Note that this currently does not work when calling your function with keyword arguments! The following code will not donate any buffers:
params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params=params, state=state)
If an argument whose buffer is donated is a pytree, then all the buffers for its components are donated:
def add_ones(xs: list[jax.Array]):
return [x + 1 for x in xs]
xs = [jax.device_put(np.ones((2, 3))), jax.device_put(np.ones((3, 4)))]
# Execute `add_ones` with donation of all the buffers for `xs`.
# The outputs have the same shape and type as the elements of `xs`,
# so they will share those buffers.
z = jax.jit(add_ones, donate_argnums=0)(xs)
You are not allowed to donate a buffer that is used subsequently in the computation; JAX will give an error, because the buffer for y becomes invalid after it is donated:
# Donate the buffer for `y`
z = jax.jit(add, donate_argnums=(1,))(x, y)
w = y + 1 # Reuses `y` whose buffer was donated above
# >> RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer
You will get a warning if the donated buffer is not used, e.g., because there are more donated buffers than can be used for the outputs:
# Execute `add` with donation of the buffers for both `x` and `y`.
# One of those buffers will be used for the result, but the other will
# not be used.
z = jax.jit(add, donate_argnums=(0, 1))(x, y)
# >> UserWarning: Some donated buffers were not usable: f32[2,3]{1,0}
The donation may also be unused if there is no output whose shape matches the donation:
y = jax.device_put(np.ones((1, 3))) # `y` has different shape than the output
# Execute `add` with donation of the buffer for `y`.
z = jax.jit(add, donate_argnums=(1,))(x, y)
# >> UserWarning: Some donated buffers were not usable: f32[1,3]{1,0}
Gradients contain NaN when using where#
If you define a function using where to avoid an undefined value, and you are not careful, you may obtain a NaN in reverse-mode differentiation:
def my_log(x):
return jnp.where(x > 0., jnp.log(x), 0.)
my_log(0.) ==> 0. # Ok
jax.grad(my_log)(0.) ==> NaN
A short explanation is that during grad
computation the adjoint corresponding
to the undefined jnp.log(x)
is a NaN
and it gets accumulated to the
adjoint of the jnp.where
. The correct way to write such functions is to ensure
that there is a jnp.where
inside the partially-defined function, to ensure
that the adjoint is always finite:
def safe_for_grad_log(x):
return jnp.log(jnp.where(x > 0., x, 1.))
safe_for_grad_log(0.) ==> 0. # Ok
jax.grad(safe_for_grad_log)(0.) ==> 0. # Ok
The inner jnp.where
may be needed in addition to the original one, e.g.:
def my_log_or_y(x, y):
    """Return log(x) if x > 0, else y."""
    return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.)), y)
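As a quick sanity check, here is a sketch (assuming the imports below) showing that the doubly-guarded version has a finite gradient at the boundary:
import jax
import jax.numpy as jnp

def my_log_or_y(x, y):
    """Return log(x) if x > 0, else y."""
    return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.)), y)

print(jax.grad(my_log_or_y)(0., 7.))   # 0.0 -- no NaN leaks through the where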
Why are gradients zero for functions based on sort order?#
If you define a function that processes the input using operations that depend on
the relative ordering of inputs (e.g. max
, greater
, argsort
, etc.) then
you may be surprised to find that the gradient is everywhere zero.
Here is an example, where we define f(x) to be a step function that returns
0 when x is negative, and 1 when x is positive:
import jax
import numpy as np
import jax.numpy as jnp
def f(x):
return (x > 0).astype(float)
df = jax.vmap(jax.grad(f))
x = jnp.array([-1.0, -0.5, 0.0, 0.5, 1.0])
print(f"f(x) = {f(x)}")
# f(x) = [0. 0. 0. 1. 1.]
print(f"df(x) = {df(x)}")
# df(x) = [0. 0. 0. 0. 0.]
The fact that the gradient is everywhere zero may be confusing at first glance: after all, the output does change in response to the input, so how can the gradient be zero? However, zero turns out to be the correct result in this case.
Why is this? Remember that differentiation measures the change in f
given an infinitesimal change in x
. For x=1.0
, f
returns 1.0
.
If we perturb x
to make it slightly larger or smaller, this does not change
the output, so by definition, grad(f)(1.0)
should be zero.
This same logic holds for all values of f
greater than zero: infinitesimally
perturbing the input does not change the output, so the gradient is zero.
Similarly, for all values of x
less than zero, the output is zero.
Perturbing x
does not change this output, so the gradient is zero.
That leaves us with the tricky case of x=0
. Surely, if you perturb x
upward,
it will change the output, but this is problematic: an infinitesimal change in x
produces a finite change in the function value, which implies the gradient is
undefined.
Fortunately, there’s another way for us to measure the gradient in this case: we
perturb the function downward, in which case the output does not change, and so the
gradient is zero.
JAX and other autodiff systems tend to handle discontinuities in this way: if the
positive gradient and negative gradient disagree, but one is defined and the other is
not, we use the one that is defined.
Under this definition of the gradient, mathematically and numerically the gradient of
this function is everywhere zero.
The problem stems from the fact that our function has a discontinuity at x = 0
.
Our f
here is essentially a Heaviside Step Function, and we can use a
Sigmoid Function as a smoothed replacement.
The sigmoid is approximately equal to the Heaviside function when x is far from zero,
but replaces the discontinuity at x = 0
with a smooth, differentiable curve.
As a result of using jax.nn.sigmoid()
, we get a similar computation with
well-defined gradients:
def g(x):
return jax.nn.sigmoid(x)
dg = jax.vmap(jax.grad(g))
x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])
with np.printoptions(suppress=True, precision=2):
print(f"g(x) = {g(x)}")
# g(x) = [0. 0.27 0.5 0.73 1. ]
print(f"dg(x) = {dg(x)}")
# dg(x) = [0. 0.2 0.25 0.2 0. ]
The jax.nn
submodule also has smooth versions of other common rank-based
functions, for example jax.nn.softmax()
can replace uses of
jax.numpy.argmax()
, jax.nn.soft_sign()
can replace uses of
jax.numpy.sign()
, jax.nn.softplus()
or jax.nn.squareplus()
can replace uses of jax.nn.relu()
, etc.
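For instance, here is a small sketch comparing the gradient of relu with its smooth softplus counterpart at the kink:
import jax

print(jax.grad(jax.nn.relu)(0.0))      # 0.0 (relu uses a one-sided convention at 0)
print(jax.grad(jax.nn.softplus)(0.0))  # 0.5 (smooth everywhere)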
How can I convert a JAX Tracer to a NumPy array?#
When inspecting a transformed JAX function at runtime, you’ll find that array
values are replaced by Tracer
objects:
@jax.jit
def f(x):
print(type(x))
return x
f(jnp.arange(5))
This prints the following:
<class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>
A frequent question is how such a tracer can be converted back to a normal NumPy array. In short, it is impossible to convert a Tracer to a NumPy array, because a tracer is an abstract representation of every possible value with a given shape and dtype, while a numpy array is a concrete member of that abstract class. For more discussion of how tracers work within the context of JAX transformations, see JIT mechanics.
The question of converting Tracers back to arrays usually comes up within the context of another goal, related to accessing intermediate values in a computation at runtime. For example:
If you wish to print a traced value at runtime for debugging purposes, you might consider using
jax.debug.print()
.If you wish to call non-JAX code within a transformed JAX function, you might consider using
jax.pure_callback()
, an example of which is available at Pure callback example.If you wish to input or output array buffers at runtime (for example, load data from file, or log the contents of the array to disk), you might consider using
jax.experimental.io_callback()
, an example of which can be found at IO callback example.
For more information on runtime callbacks and examples of their use, see External callbacks in JAX.
Why do some CUDA libraries fail to load/initialize?#
When resolving dynamic libraries, JAX uses the usual dynamic linker search pattern.
JAX sets RPATH
to point to the JAX-relative location of the
pip-installed NVIDIA CUDA packages, preferring them if installed. If ld.so
cannot find your CUDA runtime libraries along its usual search path, then you
must include the paths to those libraries explicitly in LD_LIBRARY_PATH
.
The easiest way to ensure your CUDA files are discoverable is to simply install
the nvidia-*-cu12
pip packages, which are included in the standard
jax[cuda12]
install option.
Occasionally, even when you have ensured that your runtime libraries are discoverable, there may still be some issues with loading or initializing them. A common cause of such issues is simply having insufficient memory for CUDA library initialization at runtime. This sometimes occurs because JAX will pre-allocate too large of a chunk of currently available device memory for faster execution, occasionally resulting in insufficient memory being left available for runtime CUDA library initialization.
This is especially likely when running multiple JAX instances, running JAX in
tandem with TensorFlow which performs its own pre-allocation, or when running
JAX on a system where the GPU is being heavily utilized by other processes. When
in doubt, try running the program again with reduced pre-allocation, either by
reducing XLA_PYTHON_CLIENT_MEM_FRACTION
from the default of .75
,
or setting XLA_PYTHON_CLIENT_PREALLOCATE=false
. For more details, please
see the page on JAX GPU memory allocation.
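A minimal sketch of setting these variables from Python (they must be set before JAX is first imported):
import os

# Either disable preallocation entirely...
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# ...or lower the preallocated fraction from its default of .75:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.50"

import jax  # import only after the environment is configured
print(jax.devices())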
JAX tutorials#
Key Concepts#
This section briefly introduces some key concepts of the JAX package.
JAX arrays (jax.Array
)#
The default array implementation in JAX is jax.Array
. In many ways it is similar to
the numpy.ndarray
type that you may be familiar with from the NumPy package, but it
has some important differences.
Array creation#
We typically don’t call the jax.Array
constructor directly, but rather create arrays via JAX API functions.
For example, jax.numpy
provides familiar NumPy-style array construction functionality
such as jax.numpy.zeros()
, jax.numpy.linspace()
, jax.numpy.arange()
, etc.
import jax
import jax.numpy as jnp
x = jnp.arange(5)
isinstance(x, jax.Array)
True
If you use Python type annotations in your code, jax.Array
is the appropriate
annotation for jax array objects (see jax.typing
for more discussion).
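For example (a sketch; scale is just an illustrative name):
import jax
import jax.numpy as jnp

def scale(x: jax.Array, factor: float) -> jax.Array:
    return factor * x

print(scale(jnp.arange(3.0), 2.0))   # [0. 2. 4.]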
Array devices and sharding#
JAX Array objects have a devices
method that lets you inspect where the contents of the array are stored. In the simplest cases, this will be a single CPU device:
x.devices()
{CpuDevice(id=0)}
In general, an array may be sharded across multiple devices, in a manner that can be inspected via the sharding
attribute:
x.sharding
SingleDeviceSharding(device=CpuDevice(id=0))
Here the array is on a single device, but in general a JAX array can be sharded across multiple devices, or even multiple hosts. To read more about sharded arrays and parallel computation, refer to Introduction to sharded computation.
Transformations#
Along with functions to operate on arrays, JAX includes a number of transformations which operate on JAX functions. These include
jax.jit()
: Just-in-time (JIT) compilation; see Just-in-time compilationjax.vmap()
: Vectorizing transform; see Automatic vectorizationjax.grad()
: Gradient transform; see Automatic differentiation
as well as several others. Transformations accept a function as an argument, and return a new transformed function. For example, here’s how you might JIT-compile a simple SELU function:
def selu(x, alpha=1.67, lambda_=1.05):
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
selu_jit = jax.jit(selu)
print(selu_jit(1.0))
1.05
Often you’ll see transformations applied using Python’s decorator syntax for convenience:
@jax.jit
def selu(x, alpha=1.67, lambda_=1.05):
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
Transformations like jit()
, vmap()
, grad()
, and others are
key to using JAX effectively, and we’ll cover them in detail in later sections.
Tracing#
The magic behind transformations is the notion of a Tracer. Tracers are abstract stand-ins for array objects, and are passed to JAX functions in order to extract the sequence of operations that the function encodes.
You can see this by printing any array value within transformed JAX code; for example:
@jax.jit
def f(x):
print(x)
return x + 1
x = jnp.arange(5)
result = f(x)
Traced<ShapedArray(int32[5])>with<DynamicJaxprTrace(level=1/0)>
The value printed is not the array x
, but a Tracer
instance that
represents essential attributes of x
, such as its shape
and dtype
. By executing
the function with traced values, JAX can determine the sequence of operations encoded
by the function before those operations are actually executed: transformations like
jit()
, vmap()
, and grad()
can then map this sequence
of input operations to a transformed sequence of operations.
Jaxprs#
JAX has its own intermediate representation for sequences of operations, known as a jaxpr. A jaxpr (short for JAX exPRession) is a simple representation of a functional program, comprising a sequence of primitive operations.
For example, consider the selu
function we defined above:
def selu(x, alpha=1.67, lambda_=1.05):
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
We can use the jax.make_jaxpr()
utility to convert this function into a jaxpr
given a particular input:
x = jnp.arange(5.0)
jax.make_jaxpr(selu)(x)
{ lambda ; a:f32[5]. let
b:bool[5] = gt a 0.0
c:f32[5] = exp a
d:f32[5] = mul 1.6699999570846558 c
e:f32[5] = sub d 1.6699999570846558
f:f32[5] = pjit[
name=_where
jaxpr={ lambda ; g:bool[5] h:f32[5] i:f32[5]. let
j:f32[5] = select_n g i h
in (j,) }
] b a e
k:f32[5] = mul 1.0499999523162842 f
in (k,) }
Comparing this to the Python function definition, we see that it encodes the precise sequence of operations that the function represents. We’ll go into more depth about jaxprs later in JAX internals: The jaxpr language.
Pytrees#
JAX functions and transformations fundamentally operate on arrays, but in practice it is convenient to write code that works with collections of arrays: for example, a neural network might organize its parameters in a dictionary of arrays with meaningful keys. Rather than handle such structures on a case-by-case basis, JAX relies on the pytree abstraction to treat such collections in a uniform manner.
Here are some examples of objects that can be treated as pytrees:
# (nested) list of parameters
params = [1, 2, (jnp.arange(3), jnp.ones(2))]
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef([*, *, (*, *)])
[1, 2, Array([0, 1, 2], dtype=int32), Array([1., 1.], dtype=float32)]
# Dictionary of parameters
params = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef({'W': *, 'b': *, 'n': *})
[Array([[1., 1.],
[1., 1.]], dtype=float32), Array([0., 0.], dtype=float32), 5]
# Named tuple of parameters
from typing import NamedTuple
class Params(NamedTuple):
a: int
b: float
params = Params(1, 5.0)
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef(CustomNode(namedtuple[Params], [*, *]))
[1, 5.0]
JAX has a number of general-purpose utilities for working with PyTrees; for example
the functions jax.tree.map()
can be used to map a function to every leaf in a
tree, and jax.tree.reduce()
can be used to apply a reduction across the leaves
in a tree.
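For instance, a short sketch using the dictionary of parameters from above (doubled and total are illustrative names):
import jax
import jax.numpy as jnp

params = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}

# Apply a function to every leaf...
doubled = jax.tree.map(lambda leaf: leaf * 2, params)
# ...or fold the leaves down to a single value.
total = jax.tree.reduce(lambda acc, leaf: acc + jnp.sum(leaf), params, 0.0)

print(doubled['n'], total)   # 10 9.0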
You can learn more in the Working with pytrees tutorial.
Just-in-time compilation#
In this section, we will further explore how JAX works, and how we can make it performant.
We will discuss the jax.jit()
transformation, which will perform Just In Time (JIT)
compilation of a JAX Python function so it can be executed efficiently in XLA.
How JAX transformations work#
In the previous section, we discussed that JAX allows us to transform Python functions. JAX accomplishes this by reducing each function into a sequence of primitive operations, each representing one fundamental unit of computation.
One way to see the sequence of primitives behind a function is using jax.make_jaxpr()
:
import jax
import jax.numpy as jnp
global_list = []
def log2(x):
global_list.append(x)
ln_x = jnp.log(x)
ln_2 = jnp.log(2.0)
return ln_x / ln_2
print(jax.make_jaxpr(log2)(3.0))
{ lambda ; a:f32[]. let
b:f32[] = log a
c:f32[] = log 2.0
d:f32[] = div b c
in (d,) }
The Understanding Jaxprs section of the documentation provides more information on the meaning of the above output.
Importantly, notice that the jaxpr does not capture the side-effect present in the function: there is nothing in it corresponding to global_list.append(x)
.
This is a feature, not a bug: JAX transformations are designed to understand side-effect-free (a.k.a. functionally pure) code.
If pure function and side-effect are unfamiliar terms, this is explained in a little more detail in 🔪 JAX - The Sharp Bits 🔪: Pure Functions.
Impure functions are dangerous because under JAX transformations they are likely not to behave as intended; they might fail silently, or produce surprising downstream errors like leaked Tracers.
Moreover, JAX often can’t detect when side effects are present.
(If you want debug printing, use jax.debug.print()
. To express general side-effects at the cost of performance, see jax.experimental.io_callback()
.
To check for tracer leaks at the cost of performance, use with jax.check_tracer_leaks()
).
When tracing, JAX wraps each argument by a tracer object. These tracers then record all JAX operations performed on them during the function call (which happens in regular Python). Then, JAX uses the tracer records to reconstruct the entire function. The output of that reconstruction is the jaxpr. Since the tracers do not record the Python side-effects, they do not appear in the jaxpr. However, the side-effects still happen during the trace itself.
Note: the Python print()
function is not pure: the text output is a side-effect of the function. Therefore, any print()
calls will only happen during tracing, and will not appear in the jaxpr:
def log2_with_print(x):
print("printed x:", x)
ln_x = jnp.log(x)
ln_2 = jnp.log(2.0)
return ln_x / ln_2
print(jax.make_jaxpr(log2_with_print)(3.))
printed x: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
{ lambda ; a:f32[]. let
b:f32[] = log a
c:f32[] = log 2.0
d:f32[] = div b c
in (d,) }
See how the printed x
is a Traced
object? That’s the JAX internals at work.
The fact that the Python code runs at least once is strictly an implementation detail, and so shouldn’t be relied upon. However, it’s useful to understand as you can use it when debugging to print out intermediate values of a computation.
A key thing to understand is that a jaxpr captures the function as executed on the parameters given to it. For example, if we have a Python conditional, the jaxpr will only know about the branch we take:
def log2_if_rank_2(x):
if x.ndim == 2:
ln_x = jnp.log(x)
ln_2 = jnp.log(2.0)
return ln_x / ln_2
else:
return x
print(jax.make_jaxpr(log2_if_rank_2)(jax.numpy.array([1, 2, 3])))
{ lambda ; a:i32[3]. let in (a,) }
JIT compiling a function#
As explained before, JAX enables operations to execute on CPU/GPU/TPU using the same code. Let’s look at an example of computing a Scaled Exponential Linear Unit (SELU), an operation commonly used in deep learning:
import jax
import jax.numpy as jnp
def selu(x, alpha=1.67, lambda_=1.05):
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
x = jnp.arange(1000000)
%timeit selu(x).block_until_ready()
2.88 ms ± 15.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
The code above is sending one operation at a time to the accelerator. This limits the ability of the XLA compiler to optimize our functions.
Naturally, what we want to do is give the XLA compiler as much code as possible, so it can fully optimize it. For this purpose, JAX provides the jax.jit()
transformation, which will JIT compile a JAX-compatible function. The example below shows how to use JIT to speed up the previous function.
selu_jit = jax.jit(selu)
# Pre-compile the function before timing...
selu_jit(x).block_until_ready()
%timeit selu_jit(x).block_until_ready()
1.02 ms ± 3.34 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Here’s what just happened:
We defined
selu_jit
as the compiled version ofselu
.We called
selu_jit
once onx
. This is where JAX does its tracing – it needs to have some inputs to wrap in tracers, after all. The jaxpr is then compiled using XLA into very efficient code optimized for your GPU or TPU. Finally, the compiled code is executed to satisfy the call. Subsequent calls toselu_jit
will use the compiled code directly, skipping the python implementation entirely. (If we didn’t include the warm-up call separately, everything would still work, but then the compilation time would be included in the benchmark. It would still be faster, because we run many loops in the benchmark, but it wouldn’t be a fair comparison.)We timed the execution speed of the compiled version. (Note the use of
block_until_ready()
, which is required due to JAX’s Asynchronous dispatch).
Why can’t we just JIT everything?#
After going through the example above, you might be wondering whether we should simply apply jax.jit()
to every function. To understand why this is not the case, and when we should/shouldn’t apply jit
, let’s first check some cases where JIT doesn’t work.
# Condition on value of x.
def f(x):
if x > 0:
return x
else:
return 2 * x
jax.jit(f)(10) # Raises an error
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function f at /tmp/ipykernel_2831/2956679937.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
# While loop conditioned on x and n.
def g(x, n):
i = 0
while i < n:
i += 1
return x + i
jax.jit(g)(10, 20) # Raises an error
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function g at /tmp/ipykernel_2831/722961019.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument n.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
The problem in both cases is that we tried to condition the trace-time flow of the program using runtime values.
Traced values within JIT, like x
and n
here, can only affect control flow via their static attributes, such as
shape
or dtype
, and not via their values.
For more detail on the interaction between Python control flow and JAX, see 🔪 JAX - The Sharp Bits 🔪: Control Flow.
One way to deal with this problem is to rewrite the code to avoid conditionals on value. Another is to use special Control flow operators like jax.lax.cond(), as sketched below.
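Here is a minimal sketch of the value-dependent branch from above rewritten with jax.lax.cond() so that it traces under jit (f_cond is an illustrative name):
import jax
from jax import lax

@jax.jit
def f_cond(x):
    # Both branches are traced; the selection happens at run time.
    return lax.cond(x > 0, lambda x: x, lambda x: 2 * x, x)

print(f_cond(10))    # 10
print(f_cond(-10))   # -20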
However, sometimes such a rewrite is not possible or practical. In that case, you can consider JIT-compiling only part of the function.
For example, if the most computationally expensive part of the function is inside the loop, we can JIT-compile just that inner part (though make sure to check the next section on caching to avoid shooting yourself in the foot):
# While loop conditioned on x and n with a jitted body.
@jax.jit
def loop_body(prev_i):
return prev_i + 1
def g_inner_jitted(x, n):
i = 0
while i < n:
i = loop_body(i)
return x + i
g_inner_jitted(10, 20)
Array(30, dtype=int32, weak_type=True)
Marking arguments as static#
If we really need to JIT-compile a function that has a condition on the value of an input, we can tell JAX to help itself to a less abstract tracer for a particular input by specifying static_argnums
or static_argnames
.
The cost of this is that the resulting jaxpr and compiled artifact depend on the particular value passed, and so JAX will have to re-compile the function for every new value of the specified static input.
It is only a good strategy if the function is guaranteed to see a limited set of static values.
f_jit_correct = jax.jit(f, static_argnums=0)
print(f_jit_correct(10))
10
g_jit_correct = jax.jit(g, static_argnames=['n'])
print(g_jit_correct(10, 20))
30
To specify such arguments when using jit
as a decorator, a common pattern is to use python’s functools.partial()
:
from functools import partial
@partial(jax.jit, static_argnames=['n'])
def g_jit_decorated(x, n):
i = 0
while i < n:
i += 1
return x + i
print(g_jit_decorated(10, 20))
30
JIT and caching#
With the compilation overhead of the first JIT call, understanding how and when jax.jit()
caches previous compilations is key to using it effectively.
Suppose we define f = jax.jit(g)
. When we first invoke f
, it will get compiled, and the resulting XLA code will get cached. Subsequent calls of f
will reuse the cached code.
This is how jax.jit
makes up for the up-front cost of compilation.
If we specify static_argnums
, then the cached code will be used only for the same values of arguments labelled as static. If any of them change, recompilation occurs.
If there are many values, then your program might spend more time compiling than it would have executing ops one-by-one.
Avoid calling jax.jit()
on temporary functions defined inside loops or other Python scopes.
For most cases, JAX will be able to use the compiled, cached function in subsequent calls to jax.jit()
.
However, because the cache relies on the hash of the function, it becomes problematic when equivalent functions are redefined.
This will cause unnecessary compilation each time in the loop:
from functools import partial
def unjitted_loop_body(prev_i):
return prev_i + 1
def g_inner_jitted_partial(x, n):
i = 0
while i < n:
# Don't do this! Each time, the partial returns
# a function with a different hash
i = jax.jit(partial(unjitted_loop_body))(i)
return x + i
def g_inner_jitted_lambda(x, n):
i = 0
while i < n:
# Don't do this! The lambda will also return
# a function with a different hash
i = jax.jit(lambda x: unjitted_loop_body(x))(i)
return x + i
def g_inner_jitted_normal(x, n):
i = 0
while i < n:
# this is OK, since JAX can find the
# cached, compiled function
i = jax.jit(unjitted_loop_body)(i)
return x + i
print("jit called in a loop with partials:")
%timeit g_inner_jitted_partial(10, 20).block_until_ready()
print("jit called in a loop with lambdas:")
%timeit g_inner_jitted_lambda(10, 20).block_until_ready()
print("jit called in a loop with caching:")
%timeit g_inner_jitted_normal(10, 20).block_until_ready()
jit called in a loop with partials:
222 ms ± 3.29 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with lambdas:
224 ms ± 6.07 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with caching:
2.61 ms ± 21 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Automatic vectorization#
In the previous section we discussed JIT compilation via the jax.jit()
function.
This notebook discusses another of JAX’s transforms: vectorization via jax.vmap()
.
Manual vectorization#
Consider the following simple code that computes the convolution of two one-dimensional vectors:
import jax
import jax.numpy as jnp
x = jnp.arange(5)
w = jnp.array([2., 3., 4.])
def convolve(x, w):
output = []
for i in range(1, len(x)-1):
output.append(jnp.dot(x[i-1:i+2], w))
return jnp.array(output)
convolve(x, w)
Array([11., 20., 29.], dtype=float32)
Suppose we would like to apply this function with a batch of weights w to a batch of vectors x.
xs = jnp.stack([x, x])
ws = jnp.stack([w, w])
The most naive option would be to simply loop over the batch in Python:
def manually_batched_convolve(xs, ws):
output = []
for i in range(xs.shape[0]):
output.append(convolve(xs[i], ws[i]))
return jnp.stack(output)
manually_batched_convolve(xs, ws)
Array([[11., 20., 29.],
[11., 20., 29.]], dtype=float32)
This produces the correct result; however, it is not very efficient.
In order to batch the computation efficiently, you would normally have to rewrite the function manually to ensure it is done in vectorized form. This is not particularly difficult to implement, but does involve changing how the function treats indices, axes, and other parts of the input.
For example, we could manually rewrite convolve()
to support vectorized computation across the batch dimension as follows:
def manually_vectorized_convolve(xs, ws):
output = []
for i in range(1, xs.shape[-1] -1):
output.append(jnp.sum(xs[:, i-1:i+2] * ws, axis=1))
return jnp.stack(output, axis=1)
manually_vectorized_convolve(xs, ws)
Array([[11., 20., 29.],
[11., 20., 29.]], dtype=float32)
Such re-implementation can be messy and error-prone as the complexity of a function increases; fortunately JAX provides another way.
Automatic vectorization#
In JAX, the jax.vmap()
transformation is designed to generate such a vectorized implementation of a function automatically:
auto_batch_convolve = jax.vmap(convolve)
auto_batch_convolve(xs, ws)
Array([[11., 20., 29.],
[11., 20., 29.]], dtype=float32)
It does this by tracing the function similarly to jax.jit()
, and automatically adding batch axes at the beginning of each input.
If the batch dimension is not the first, you may use the in_axes
and out_axes
arguments to specify the location of the batch dimension in inputs and outputs. These may be an integer if the batch axis is the same for all inputs and outputs, or lists otherwise.
auto_batch_convolve_v2 = jax.vmap(convolve, in_axes=1, out_axes=1)
xst = jnp.transpose(xs)
wst = jnp.transpose(ws)
auto_batch_convolve_v2(xst, wst)
Array([[11., 11.],
[20., 20.],
[29., 29.]], dtype=float32)
jax.vmap()
also supports the case where only one of the arguments is batched: for example, if you would like to convolve a single set of weights w
with a batch of vectors x
; in this case the in_axes
argument can be set to None
:
batch_convolve_v3 = jax.vmap(convolve, in_axes=[0, None])
batch_convolve_v3(xs, w)
Array([[11., 20., 29.],
[11., 20., 29.]], dtype=float32)
Combining transformations#
As with all JAX transformations, jax.jit()
and jax.vmap()
are designed to be composable, which means you can wrap a vmapped function with jit
, or a jitted function with vmap
, and everything will work correctly:
jitted_batch_convolve = jax.jit(auto_batch_convolve)
jitted_batch_convolve(xs, ws)
Array([[11., 20., 29.],
[11., 20., 29.]], dtype=float32)
Automatic differentiation#
In this section, you will learn about fundamental applications of automatic differentiation (autodiff) in JAX. JAX has a pretty general autodiff system. Computing gradients is a critical part of modern machine learning methods, and this tutorial will walk you through a few introductory autodiff topics, such as:
1. Taking gradients with jax.grad
2. Computing gradients in a linear logistic regression
3. Differentiating with respect to nested lists, tuples, and dicts
4. Evaluating a function and its gradient using jax.value_and_grad
Make sure to also check out the Advanced automatic differentiation tutorial for more advanced topics.
While understanding how automatic differentiation works “under the hood” isn’t crucial for using JAX in most contexts, you are encouraged to check out this quite accessible video to get a deeper sense of what’s going on.
1. Taking gradients with jax.grad
#
In JAX, you can differentiate a scalar-valued function with the jax.grad()
transformation:
import jax
import jax.numpy as jnp
from jax import grad
grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0))
0.070650816
jax.grad()
takes a function and returns a function. If you have a Python function f
that evaluates the mathematical function \(f\), then jax.grad(f)
is a Python function that evaluates the mathematical function \(\nabla f\). That means grad(f)(x)
represents the value \(\nabla f(x)\).
Since jax.grad()
operates on functions, you can apply it to its own output to differentiate as many times as you like:
print(grad(grad(jnp.tanh))(2.0))
print(grad(grad(grad(jnp.tanh)))(2.0))
-0.13621868
0.25265405
JAX’s autodiff makes it easy to compute higher-order derivatives, because the functions that compute derivatives are themselves differentiable. Thus, higher-order derivatives are as easy as stacking transformations. This can be illustrated in the single-variable case:
The derivative of \(f(x) = x^3 + 2x^2 - 3x + 1\) can be computed as:
f = lambda x: x**3 + 2*x**2 - 3*x + 1
dfdx = jax.grad(f)
The higher-order derivatives of \(f\) are:
\(f'(x) = 3x^2 + 4x - 3\)
\(f''(x) = 6x + 4\)
\(f'''(x) = 6\)
\(f''''(x) = 0\)
Computing any of these in JAX is as easy as chaining the jax.grad()
function:
d2fdx = jax.grad(dfdx)
d3fdx = jax.grad(d2fdx)
d4fdx = jax.grad(d3fdx)
Evaluating the above at \(x=1\) gives \(f'(1)=4\), \(f''(1)=10\), \(f'''(1)=6\), and \(f''''(1)=0\).
Using JAX:
print(dfdx(1.))
print(d2fdx(1.))
print(d3fdx(1.))
print(d4fdx(1.))
4.0
10.0
6.0
0.0
2. Computing gradients in a linear logistic regression#
The next example shows how to compute gradients with jax.grad()
in a linear logistic regression model. First, the setup:
key = jax.random.key(0)
def sigmoid(x):
return 0.5 * (jnp.tanh(x / 2) + 1)
# Outputs probability of a label being true.
def predict(W, b, inputs):
return sigmoid(jnp.dot(inputs, W) + b)
# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12, 0.77],
[0.88, -1.08, 0.15],
[0.52, 0.06, -1.30],
[0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])
# Training loss is the negative log-likelihood of the training examples.
def loss(W, b):
preds = predict(W, b, inputs)
label_probs = preds * targets + (1 - preds) * (1 - targets)
return -jnp.sum(jnp.log(label_probs))
# Initialize random model coefficients
key, W_key, b_key = jax.random.split(key, 3)
W = jax.random.normal(W_key, (3,))
b = jax.random.normal(b_key, ())
Use the jax.grad()
function with its argnums
argument to differentiate a function with respect to positional arguments.
# Differentiate `loss` with respect to the first positional argument:
W_grad = grad(loss, argnums=0)(W, b)
print(f'{W_grad=}')
# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print(f'{W_grad=}')
# But you can choose different values too, and drop the keyword:
b_grad = grad(loss, 1)(W, b)
print(f'{b_grad=}')
# Including tuple values
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print(f'{W_grad=}')
print(f'{b_grad=}')
W_grad=Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32)
W_grad=Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32)
b_grad=Array(-0.29227245, dtype=float32)
W_grad=Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32)
b_grad=Array(-0.29227245, dtype=float32)
The jax.grad()
API has a direct correspondence to the excellent notation in Spivak’s classic Calculus on Manifolds (1965), also used in Sussman and Wisdom’s Structure and Interpretation of Classical Mechanics (2015) and their Functional Differential Geometry (2013). Both books are open-access. See in particular the “Prologue” section of Functional Differential Geometry for a defense of this notation.
Essentially, when using the argnums
argument, if f
is a Python function for evaluating the mathematical function \(f\), then the Python expression jax.grad(f, i)
evaluates to a Python function for evaluating \(\partial_i f\).
3. Differentiating with respect to nested lists, tuples, and dicts#
Due to JAX’s PyTree abstraction (see Working with pytrees), differentiating with respect to standard Python containers just works, so use tuples, lists, and dicts (and arbitrary nesting) however you like.
Continuing the previous example:
def loss2(params_dict):
preds = predict(params_dict['W'], params_dict['b'], inputs)
label_probs = preds * targets + (1 - preds) * (1 - targets)
return -jnp.sum(jnp.log(label_probs))
print(grad(loss2)({'W': W, 'b': b}))
{'W': Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32), 'b': Array(-0.29227245, dtype=float32)}
You can create Custom pytree nodes to work with not just jax.grad()
but other JAX transformations (jax.jit()
, jax.vmap()
, and so on).
4. Evaluating a function and its gradient using jax.value_and_grad
#
Another convenient function is jax.value_and_grad()
for efficiently computing both a function’s value as well as its gradient’s value in one pass.
Continuing the previous examples:
loss_value, Wb_grad = jax.value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)
print('loss value', loss(W, b))
loss value 3.0519385
loss value 3.0519385
5. Checking against numerical differences#
A great thing about derivatives is that they’re straightforward to check with finite differences.
Continuing the previous examples:
# Set a step size for finite differences calculations
eps = 1e-4
# Check b_grad with scalar finite differences
b_grad_numerical = (loss(W, b + eps / 2.) - loss(W, b - eps / 2.)) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))
# Check W_grad with finite differences in a random direction
key, subkey = jax.random.split(key)
vec = jax.random.normal(subkey, W.shape)
unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))
W_grad_numerical = (loss(W + eps / 2. * unitvec, b) - loss(W - eps / 2. * unitvec, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(W, b), unitvec))
b_grad_numerical -0.29325485
b_grad_autodiff -0.29227245
W_dirderiv_numerical -0.2002716
W_dirderiv_autodiff -0.19909117
JAX provides a simple convenience function that does essentially the same thing, but checks up to any order of differentiation that you like:
from jax.test_util import check_grads
check_grads(loss, (W, b), order=2) # check up to 2nd order derivatives
Next steps#
The Advanced automatic differentiation tutorial provides more advanced and detailed explanations of how the ideas covered in this document are implemented in the JAX backend. Some features, such as Custom derivative rules for JAX-transformable Python functions, depend on understanding advanced automatic differentiation, so do check out that section in the Advanced automatic differentiation tutorial if you are interested.
Introduction to debugging#
This section introduces you to a set of built-in JAX debugging methods — jax.debug.print()
, jax.debug.breakpoint()
, and jax.debug.callback()
— that you can use with various JAX transformations.
Let’s begin with jax.debug.print()
.
JAX debug.print for high-level debugging#
TL;DR Here is a rule of thumb:
Use jax.debug.print() for traced (dynamic) array values with jax.jit(), jax.vmap(), and others.
Use Python print() for static values, such as dtypes and array shapes (see the sketch below).
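For example, here is a minimal sketch of the second rule (the function describe is hypothetical): shapes and dtypes are static attributes available at trace time, so plain Python print() shows their real values even inside jax.jit():
import jax
import jax.numpy as jnp

@jax.jit
def describe(x):
  # Shapes and dtypes are static, so an ordinary print() reports them at trace time.
  print("static shape:", x.shape, "static dtype:", x.dtype)
  return x * 2

describe(jnp.arange(4.0))  # prints: static shape: (4,) static dtype: float32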
Recall from Just-in-time compilation that when transforming a function with jax.jit()
,
the Python code is executed with abstract tracers in place of your arrays. Because of this,
the Python print()
function will only print this tracer value:
import jax
import jax.numpy as jnp
@jax.jit
def f(x):
print("print(x) ->", x)
y = jnp.sin(x)
print("print(y) ->", y)
return y
result = f(2.)
print(x) -> Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
print(y) -> Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
Python’s print
executes at trace-time, before the runtime values exist.
If you want to print the actual runtime values, you can use jax.debug.print()
:
@jax.jit
def f(x):
jax.debug.print("jax.debug.print(x) -> {x}", x=x)
y = jnp.sin(x)
jax.debug.print("jax.debug.print(y) -> {y}", y=y)
return y
result = f(2.)
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.9092974066734314
Similarly, within jax.vmap()
, using Python’s print
will only print the tracer;
to print the values being mapped over, use jax.debug.print()
:
def f(x):
jax.debug.print("jax.debug.print(x) -> {}", x)
y = jnp.sin(x)
jax.debug.print("jax.debug.print(y) -> {}", y)
return y
xs = jnp.arange(3.)
result = jax.vmap(f)(xs)
jax.debug.print(x) -> 0.0
jax.debug.print(x) -> 1.0
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.0
jax.debug.print(y) -> 0.8414709568023682
jax.debug.print(y) -> 0.9092974066734314
Here’s the result with jax.lax.map()
, which is a sequential map rather than a
vectorization:
result = jax.lax.map(f, xs)
jax.debug.print(x) -> 0.0
jax.debug.print(y) -> 0.0
jax.debug.print(x) -> 1.0
jax.debug.print(y) -> 0.8414709568023682
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.9092974066734314
Notice the order is different, as jax.vmap()
and jax.lax.map()
compute the same results in different ways. When debugging, the evaluation order details are exactly what you may need to inspect.
Below is an example with jax.grad()
, where jax.debug.print()
only prints the forward pass. In this case, the behavior is similar to Python’s print()
, but it’s consistent if you apply jax.jit()
during the call.
def f(x):
jax.debug.print("jax.debug.print(x) -> {}", x)
return x ** 2
result = jax.grad(f)(1.)
jax.debug.print(x) -> 1.0
Sometimes, when the arguments don’t depend on one another, calls to jax.debug.print()
may print them in a different order when staged out with a JAX transformation. If you need the original order, such as x: ...
first and then y: ...
second, add the ordered=True
parameter.
For example:
@jax.jit
def f(x, y):
jax.debug.print("jax.debug.print(x) -> {}", x, ordered=True)
jax.debug.print("jax.debug.print(y) -> {}", y, ordered=True)
return x + y
f(1, 2)
jax.debug.print(x) -> 1
jax.debug.print(y) -> 2
Array(3, dtype=int32, weak_type=True)
To learn more about jax.debug.print()
and its Sharp Bits, refer to Advanced debugging.
JAX debug.breakpoint for pdb-like debugging#
TL;DR Use jax.debug.breakpoint()
to pause the execution of your JAX program to inspect values.
To pause your compiled JAX program during certain points during debugging, you can use jax.debug.breakpoint()
. The prompt is similar to Python pdb
, and it allows you to inspect the values in the call stack. In fact, jax.debug.breakpoint()
is an application of jax.debug.callback()
that captures information about the call stack.
To print all available commands during a breakpoint
debugging session, use the help
command. (Full debugger commands, the Sharp Bits, its strengths and limitations are covered in Advanced debugging.)
Here is an example of what a debugger session might look like:
@jax.jit
def f(x):
  y, z = jnp.sin(x), jnp.cos(x)
jax.debug.breakpoint()
return y * z
f(2.) # ==> Pauses during execution
For value-dependent breakpointing, you can use runtime conditionals like jax.lax.cond()
:
def breakpoint_if_nonfinite(x):
is_finite = jnp.isfinite(x).all()
def true_fn(x):
pass
def false_fn(x):
jax.debug.breakpoint()
jax.lax.cond(is_finite, true_fn, false_fn, x)
@jax.jit
def f(x, y):
z = x / y
breakpoint_if_nonfinite(z)
return z
f(2., 1.) # ==> No breakpoint
Array(2., dtype=float32, weak_type=True)
f(2., 0.) # ==> Pauses during execution
JAX debug.callback for more control during debugging#
Both jax.debug.print()
and jax.debug.breakpoint()
are implemented using
the more flexible jax.debug.callback()
, which gives greater control over the
host-side logic executed via a Python callback.
It is compatible with jax.jit()
, jax.vmap()
, jax.grad()
and other
transformations (refer to the Flavors of callback table in
External callbacks for more information).
For example:
import logging
def log_value(x):
logging.warning(f'Logged value: {x}')
@jax.jit
def f(x):
jax.debug.callback(log_value, x)
return x
f(1.0);
WARNING:root:Logged value: 1.0
This callback is compatible with other transformations, including jax.vmap()
and jax.grad()
:
x = jnp.arange(5.0)
jax.vmap(f)(x);
WARNING:root:Logged value: 0.0
WARNING:root:Logged value: 1.0
WARNING:root:Logged value: 2.0
WARNING:root:Logged value: 3.0
WARNING:root:Logged value: 4.0
jax.grad(f)(1.0);
WARNING:root:Logged value: 1.0
This can make jax.debug.callback()
useful for general-purpose debugging.
You can learn more about jax.debug.callback()
and other kinds of JAX callbacks in External callbacks.
Next steps#
Check out Advanced debugging to learn more about debugging in JAX.
Pseudorandom numbers#
In this section we focus on jax.random
and pseudo random number generation (PRNG); that is, the process of algorithmically generating sequences of numbers whose properties approximate the properties of sequences of random numbers sampled from an appropriate distribution.
PRNG-generated sequences are not truly random because they are actually determined by their initial value, which is typically referred to as the seed
, and each step of random sampling is a deterministic function of some state
that is carried over from one sample to the next.
Pseudo random number generation is an essential component of any machine learning or scientific computing framework. Generally, JAX strives to be compatible with NumPy, but pseudo random number generation is a notable exception.
To better understand the difference between the approaches taken by JAX and NumPy when it comes to random number generation we will discuss both approaches in this section.
Random numbers in NumPy#
Pseudo random number generation is natively supported in NumPy by the numpy.random
module.
In NumPy, pseudo random number generation is based on a global state
, which can be set to a deterministic initial condition using numpy.random.seed()
.
import numpy as np
np.random.seed(0)
You can inspect the content of the state using the following command.
def print_truncated_random_state():
"""To avoid spamming the outputs, print only part of the state."""
full_random_state = np.random.get_state()
print(str(full_random_state)[:460], '...')
print_truncated_random_state()
('MT19937', array([ 0, 1, 1812433255, 1900727105, 1208447044,
2481403966, 4042607538, 337614300, 3232553940, 1018809052,
3202401494, 1775180719, 3192392114, 594215549, 184016991,
829906058, 610491522, 3879932251, 3139825610, 297902587,
4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
2891506774, 1066338622, 135451537, 933040465, 2759011858,
2273819758, 3545703099, 2516396728, 127 ...
The state
is updated by each call to a random function:
np.random.seed(0)
print_truncated_random_state()
('MT19937', array([ 0, 1, 1812433255, 1900727105, 1208447044,
2481403966, 4042607538, 337614300, 3232553940, 1018809052,
3202401494, 1775180719, 3192392114, 594215549, 184016991,
829906058, 610491522, 3879932251, 3139825610, 297902587,
4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
2891506774, 1066338622, 135451537, 933040465, 2759011858,
2273819758, 3545703099, 2516396728, 127 ...
_ = np.random.uniform()
print_truncated_random_state()
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
3904844661, 676747479, 2085143622, 1056793272, 3812477442,
2168787041, 275552121, 2696932952, 3432054210, 1657102335,
3518946594, 962584079, 1051271004, 3806145045, 1414436097,
2032348584, 1661738718, 1116708477, 2562755208, 3176189976,
696824676, 2399811678, 3992505346, 569184356, 2626558620,
136797809, 4273176064, 296167901, 343 ...
NumPy allows you to sample either individual numbers or entire vectors of numbers in a single function call. For instance, you may sample a vector of 3 scalars from a uniform distribution by doing:
np.random.seed(0)
print(np.random.uniform(size=3))
[0.5488135 0.71518937 0.60276338]
NumPy provides a sequential equivalence guarantee, meaning that sampling N numbers one at a time or sampling a vector of N numbers at once results in the same pseudo-random sequence:
np.random.seed(0)
print("individually:", np.stack([np.random.uniform() for _ in range(3)]))
np.random.seed(0)
print("all at once: ", np.random.uniform(size=3))
individually: [0.5488135 0.71518937 0.60276338]
all at once: [0.5488135 0.71518937 0.60276338]
Random numbers in JAX#
JAX’s random number generation differs from NumPy’s in important ways, because NumPy’s PRNG design makes it hard to simultaneously guarantee a number of desirable properties. Specifically, in JAX we want PRNG generation to be:
reproducible,
parallelizable,
vectorizable.
We will discuss why in the following. First, we will focus on the implications of a PRNG design based on a global state. Consider the code:
import numpy as np
np.random.seed(0)
def bar(): return np.random.uniform()
def baz(): return np.random.uniform()
def foo(): return bar() + 2 * baz()
print(foo())
1.9791922366721637
The function foo
sums two scalars sampled from a uniform distribution.
The output of this code can only satisfy requirement #1 if we assume a predictable order of execution for bar()
and baz()
.
This is not a problem in NumPy, which always evaluates code in the order defined by the Python interpreter.
In JAX, however, this is more problematic: for efficient execution, we want the JIT compiler to be free to reorder, elide, and fuse various operations in the function we define.
Further, when executing in multi-device environments, execution efficiency would be hampered by the need for each process to synchronize a global state.
Explicit random state#
To avoid this issue, JAX avoids implicit global random state, and instead tracks state explicitly via a random key
:
from jax import random
key = random.key(42)
print(key)
Array((), dtype=key<fry>) overlaying:
[ 0 42]
Note
This section uses the new-style typed PRNG keys produced by jax.random.key()
, rather than the
old-style raw PRNG keys produced by jax.random.PRNGKey()
. For details, see JEP 9263: Typed keys & pluggable RNGs.
A key is an array with a special dtype corresponding to the particular PRNG implementation being used; in the default implementation each key is backed by a pair of uint32
values.
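If you want to inspect the raw values behind a typed key, recent JAX versions provide jax.random.key_data(); a minimal sketch (the exact contents depend on the PRNG implementation):
from jax import random

key = random.key(42)
raw = random.key_data(key)   # the uint32 buffer backing the key
print(raw.shape, raw.dtype)  # (2,) uint32 for the default threefry implementation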
The key is effectively a stand-in for NumPy’s hidden state object, but we pass it explicitly to jax.random functions.
Importantly, random functions consume the key, but do not modify it: feeding the same key object to a random function will always result in the same sample being generated.
print(random.normal(key))
print(random.normal(key))
-0.18471177
-0.18471177
Re-using the same key, even with different random
APIs, can result in correlated outputs, which is generally undesirable.
The rule of thumb is: never reuse keys (unless you want identical outputs).
In order to generate different and independent samples, you must split()
the key explicitly before passing it to a random function:
for i in range(3):
new_key, subkey = random.split(key)
del key # The old key is consumed by split() -- we must never use it again.
val = random.normal(subkey)
del subkey # The subkey is consumed by normal().
print(f"draw {i}: {val}")
key = new_key # new_key is safe to use in the next iteration.
draw 0: 1.369469404220581
draw 1: -0.19947023689746857
draw 2: -2.298278331756592
(Calling del
here is not required, but we do so to emphasize that the key should not be reused once consumed.)
jax.random.split()
is a deterministic function that converts one key
into several independent (in the pseudorandomness sense) keys.
We keep one of the outputs as the new_key
, and can safely use the unique extra key (called subkey
) as input into a random function, and then discard it forever.
If you wanted to get another sample from the normal distribution, you would split key
again, and so on: the crucial point is that you never use the same key twice.
It doesn’t matter which part of the output of split(key)
we call key
, and which we call subkey
.
They are all independent keys with equal status.
The key/subkey naming convention is a typical usage pattern that helps track how keys are consumed:
subkeys are destined for immediate consumption by random functions, while the key is retained to generate more randomness later.
Usually, the above example would be written concisely as
key, subkey = random.split(key)
which discards the old key automatically.
It’s worth noting that split()
can create as many keys as you need, not just 2:
key, *forty_two_subkeys = random.split(key, num=43)
Lack of sequential equivalence#
Another difference between NumPy’s and JAX’s random modules relates to the sequential equivalence guarantee mentioned above.
As in NumPy, JAX’s random module also allows sampling of vectors of numbers. However, JAX does not provide a sequential equivalence guarantee, because doing so would interfere with the vectorization on SIMD hardware (requirement #3 above).
In the example below, sampling 3 values out of a normal distribution individually using three subkeys gives a different result than using a single key and specifying shape=(3,)
:
key = random.key(42)
subkeys = random.split(key, 3)
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
print("individually:", sequence)
key = random.key(42)
print("all at once: ", random.normal(key, shape=(3,)))
individually: [-0.04838832 0.10796154 -1.2226542 ]
all at once: [ 0.18693547 -1.2806505 -1.5593132 ]
The lack of sequential equivalence gives us freedom to write code more efficiently; for example,
instead of generating sequence
above via a sequential loop, we can use jax.vmap()
to
compute the same result in a vectorized manner:
import jax
print("vectorized:", jax.vmap(random.normal)(subkeys))
vectorized: [-0.04838832 0.10796154 -1.2226542 ]
Next Steps#
For more information on JAX random numbers, refer to the documentation of the jax.random
module. If you’re interested in the details of the design of JAX’s random number generator,
see JAX PRNG Design.
Working with pytrees#
JAX has built-in support for objects that look like dictionaries (dicts) of arrays, or lists of lists of dicts, or other nested structures — in JAX these are called pytrees. This section will explain how to use them, provide useful code examples, and point out common “gotchas” and patterns.
What is a pytree?#
A pytree is a container-like structure built out of container-like Python objects — “leaf” pytrees and/or more pytrees. A pytree can include lists, tuples, and dicts. A leaf is anything that’s not a pytree, such as an array, but a single leaf is also a pytree.
In the context of machine learning (ML), a pytree can contain:
Model parameters
Dataset entries
Reinforcement learning agent observations
When working with datasets, you can often come across pytrees (such as lists of lists of dicts).
Below is an example of a simple pytree. In JAX, you can use jax.tree.leaves()
, to extract the flattened leaves from the trees, as demonstrated here:
import jax
import jax.numpy as jnp
example_trees = [
[1, 'a', object()],
(1, (2, 3), ()),
[1, {'k1': 2, 'k2': (3, 4)}, 5],
{'a': 2, 'b': (2, 3)},
jnp.array([1, 2, 3]),
]
# Print how many leaves the pytrees have.
for pytree in example_trees:
# This `jax.tree.leaves()` method extracts the flattened leaves from the pytrees.
leaves = jax.tree.leaves(pytree)
print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")
[1, 'a', <object object at 0x7f62b3463730>] has 3 leaves: [1, 'a', <object object at 0x7f62b3463730>]
(1, (2, 3), ()) has 3 leaves: [1, 2, 3]
[1, {'k1': 2, 'k2': (3, 4)}, 5] has 5 leaves: [1, 2, 3, 4, 5]
{'a': 2, 'b': (2, 3)} has 3 leaves: [2, 2, 3]
Array([1, 2, 3], dtype=int32) has 1 leaves: [Array([1, 2, 3], dtype=int32)]
Any tree-like structure built out of container-like Python objects can be treated as a pytree in JAX. Classes are considered container-like if they are in the pytree registry, which by default includes lists, tuples, and dicts. Any object whose type is not in the pytree container registry will be treated as a leaf node in the tree.
The pytree registry can be extended to include user-defined container classes by registering the class with functions that specify how to flatten the tree; see Custom pytree nodes below.
Common pytree functions#
JAX provides a number of utilities to operate over pytrees. These can be found in the jax.tree_util
subpackage;
for convenience many of these have aliases in the jax.tree
module.
Common function: jax.tree.map#
The most commonly used pytree function is jax.tree.map()
. It works analogously to Python’s native map
, but transparently operates over entire pytrees.
Here’s an example:
list_of_lists = [
[1, 2, 3],
[1, 2],
[1, 2, 3, 4]
]
jax.tree.map(lambda x: x*2, list_of_lists)
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]
jax.tree.map()
also allows mapping a N-ary function over multiple arguments. For example:
another_list_of_lists = list_of_lists
jax.tree.map(lambda x, y: x+y, list_of_lists, another_list_of_lists)
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]
When using multiple arguments with jax.tree.map()
, the structure of the inputs must exactly match. That is, lists must have the same number of elements, dicts must have the same keys, etc.
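A minimal sketch of what happens when the structures do not match (the exact error message may vary between JAX versions):
import jax

try:
  jax.tree.map(lambda x, y: x + y, [1, 2, 3], [1, 2])  # mismatched list lengths
except ValueError as err:
  print("structure mismatch:", err)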
Example of jax.tree.map with ML model parameters#
This example demonstrates how pytree operations can be useful when training a simple multi-layer perceptron (MLP).
Begin with defining the initial model parameters:
import numpy as np
def init_mlp_params(layer_widths):
params = []
for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
params.append(
dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),
biases=np.ones(shape=(n_out,))
)
)
return params
params = init_mlp_params([1, 128, 128, 1])
Use jax.tree.map()
to check the shapes of the initial parameters:
jax.tree.map(lambda x: x.shape, params)
[{'biases': (128,), 'weights': (1, 128)},
{'biases': (128,), 'weights': (128, 128)},
{'biases': (1,), 'weights': (128, 1)}]
Next, define the functions for training the MLP model:
# Define the forward pass.
def forward(params, x):
*hidden, last = params
for layer in hidden:
x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
return x @ last['weights'] + last['biases']
# Define the loss function.
def loss_fn(params, x, y):
return jnp.mean((forward(params, x) - y) ** 2)
# Set the learning rate.
LEARNING_RATE = 0.0001
# Using the stochastic gradient descent, define the parameter update function.
# Apply `@jax.jit` for JIT compilation (speed).
@jax.jit
def update(params, x, y):
# Calculate the gradients with `jax.grad`.
grads = jax.grad(loss_fn)(params, x, y)
# Note that `grads` is a pytree with the same structure as `params`.
# `jax.grad` is one of many JAX functions that has
# built-in support for pytrees.
# This is useful - you can apply the SGD update using JAX pytree utilities.
return jax.tree.map(
lambda p, g: p - LEARNING_RATE * g, params, grads
)
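To show how these pieces fit together, here is a brief usage sketch; the toy data below is illustrative and not part of the original example, and it reuses the params, update, and loss_fn defined above:
# Toy 1D regression data (illustrative only): y = x**2 plus noise.
xs = np.random.normal(size=(128, 1))
ys = xs ** 2 + 0.1 * np.random.normal(size=(128, 1))

# Each call returns a new params pytree with the same structure as the input.
for _ in range(1000):
  params = update(params, xs, ys)

print(loss_fn(params, xs, ys))  # the loss should have decreased after training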
Custom pytree nodes#
This section explains how in JAX you can extend the set of Python types that will be considered internal nodes in pytrees (pytree nodes) by using jax.tree_util.register_pytree_node()
with jax.tree.map()
.
Why would you need this? In the previous examples, pytrees were shown as lists, tuples, and dicts, with everything else as pytree leaves. This is because if you define your own container class, it will be considered to be a pytree leaf unless you register it with JAX. This is also the case even if your container class has trees inside it. For example:
class Special(object):
def __init__(self, x, y):
self.x = x
self.y = y
jax.tree.leaves([
Special(0, 1),
Special(2, 4),
])
[<__main__.Special at 0x7f62b3d55060>, <__main__.Special at 0x7f62b3d57100>]
Accordingly, if you try to use a jax.tree.map()
expecting the leaves to be elements inside the container, you will get an error:
jax.tree.map(lambda x: x + 1,
[
Special(0, 1),
Special(2, 4)
])
TypeError: unsupported operand type(s) for +: 'Special' and 'int'
As a solution, JAX allows you to extend the set of types considered internal pytree nodes through a global registry of types. The values of registered types are then traversed recursively.
First, register a new type using jax.tree_util.register_pytree_node()
:
from jax.tree_util import register_pytree_node
class RegisteredSpecial(Special):
def __repr__(self):
return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)
def special_flatten(v):
"""Specifies a flattening recipe.
Params:
v: The value of the registered type to flatten.
Returns:
A pair of an iterable with the children to be flattened recursively,
and some opaque auxiliary data to pass back to the unflattening recipe.
The auxiliary data is stored in the treedef for use during unflattening.
The auxiliary data could be used, for example, for dictionary keys.
"""
children = (v.x, v.y)
aux_data = None
return (children, aux_data)
def special_unflatten(aux_data, children):
"""Specifies an unflattening recipe.
Params:
aux_data: The opaque data that was specified during flattening of the
current tree definition.
children: The unflattened children
Returns:
A reconstructed object of the registered type, using the specified
children and auxiliary data.
"""
return RegisteredSpecial(*children)
# Global registration
register_pytree_node(
RegisteredSpecial,
special_flatten, # Instruct JAX what are the children nodes.
special_unflatten # Instruct JAX how to pack back into a `RegisteredSpecial`.
)
Now you can traverse the special container structure:
jax.tree.map(lambda x: x + 1,
[
RegisteredSpecial(0, 1),
RegisteredSpecial(2, 4),
])
[RegisteredSpecial(x=1, y=2), RegisteredSpecial(x=3, y=5)]
Modern Python comes equipped with helpful tools to make defining containers easier. Some will work with JAX out-of-the-box, but others require more care.
For instance, a Python NamedTuple
subclass doesn’t need to be registered to be considered a pytree node type:
from typing import NamedTuple, Any
class MyOtherContainer(NamedTuple):
name: str
a: Any
b: Any
c: Any
# NamedTuple subclasses are handled as pytree nodes, so
# this will work out-of-the-box.
jax.tree.leaves([
MyOtherContainer('Alice', 1, 2, 3),
MyOtherContainer('Bob', 4, 5, 6)
])
['Alice', 1, 2, 3, 'Bob', 4, 5, 6]
Notice that the name field now appears as a leaf, because all tuple elements are treated as children. This is the tradeoff of not having to register the class explicitly: every field, including metadata like name, becomes a leaf.
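If you would rather keep a field like name out of the leaves, one option is to register a custom class whose flatten rule stores that field as auxiliary data rather than as a child; a minimal sketch (the LabeledTriple class is hypothetical):
import jax
from jax.tree_util import register_pytree_node

class LabeledTriple:
  # Hypothetical container: `name` is metadata, while a, b, c are the actual leaves.
  def __init__(self, name, a, b, c):
    self.name, self.a, self.b, self.c = name, a, b, c

register_pytree_node(
  LabeledTriple,
  lambda v: ((v.a, v.b, v.c), v.name),                # name goes into aux_data, not children
  lambda name, children: LabeledTriple(name, *children),
)

print(jax.tree.leaves([LabeledTriple('Alice', 1, 2, 3)]))  # [1, 2, 3] -- 'Alice' is not a leaf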
Pytrees and JAX transformations#
Many JAX functions, like jax.lax.scan()
, operate over pytrees of arrays. In addition, all JAX function transformations can be applied to functions that accept as input and produce as output pytrees of arrays.
Some JAX function transformations take optional parameters that specify how certain input or output values should be treated (such as the in_axes
and out_axes
arguments to jax.vmap()
). These parameters can also be pytrees, and their structure must correspond to the pytree structure of the corresponding arguments. In particular, to be able to “match up” leaves in these parameter pytrees with values in the argument pytrees, the parameter pytrees are often constrained to be tree prefixes of the argument pytrees.
For example, if you pass the following input to jax.vmap()
(note that the input arguments to a function are considered a tuple):
vmap(f, in_axes=(a1, {"k1": a2, "k2": a3}))
then you can use the following in_axes
pytree to specify that only the k2
argument is mapped (axis=0
), and the rest aren’t mapped over (axis=None
):
vmap(f, in_axes=(None, {"k1": None, "k2": 0}))
The optional parameter pytree structure must match that of the main input pytree. However, the optional parameters can optionally be specified as a “prefix” pytree, meaning that a single leaf value can be applied to an entire sub-pytree.
For example, if you have the same jax.vmap()
input as above, but wish to only map over the dictionary argument, you can use:
vmap(f, in_axes=(None, 0)) # equivalent to (None, {"k1": 0, "k2": 0})
Alternatively, if you want every argument to be mapped, you can write a single leaf value that is applied over the entire argument tuple pytree:
vmap(f, in_axes=0) # equivalent to (0, {"k1": 0, "k2": 0})
This happens to be the default in_axes
value for jax.vmap()
.
The same logic applies to other optional parameters that refer to specific input or output values of a transformed function, such as out_axes
in jax.vmap()
.
Explicit key paths#
In a pytree each leaf has a key path. A key path for a leaf is a list
of keys, where the length of the list is equal to the depth of the leaf in the pytree. Each key is a hashable object that represents an index into the corresponding pytree node type. The type of the key depends on the pytree node type; for example, the type of keys for dict
s is different from the type of keys for tuple
s.
For built-in pytree node types, the set of keys for any pytree node instance is unique. For a pytree comprising nodes with this property, the key path for each leaf is unique.
JAX has the following jax.tree_util.*
methods for working with key paths:
jax.tree_util.tree_flatten_with_path(): Works similarly to jax.tree.flatten(), but returns key paths.
jax.tree_util.tree_map_with_path(): Works similarly to jax.tree.map(), but the function also takes key paths as arguments.
jax.tree_util.keystr(): Given a general key path, returns a reader-friendly string expression.
For example, one use case is to print debugging information related to a certain leaf value:
import collections
ATuple = collections.namedtuple("ATuple", ('name'))
tree = [1, {'k1': 2, 'k2': (3, 4)}, ATuple('foo')]
flattened, _ = jax.tree_util.tree_flatten_with_path(tree)
for key_path, value in flattened:
print(f'Value of tree{jax.tree_util.keystr(key_path)}: {value}')
Value of tree[0]: 1
Value of tree[1]['k1']: 2
Value of tree[1]['k2'][0]: 3
Value of tree[1]['k2'][1]: 4
Value of tree[2].name: foo
To express key paths, JAX provides a few default key types for the built-in pytree node types, namely:
SequenceKey(idx: int): For lists and tuples.
DictKey(key: Hashable): For dictionaries.
GetAttrKey(name: str): For namedtuples and preferably custom pytree nodes (more in the next section).
You are free to define your own key types for your custom nodes. They will work with jax.tree_util.keystr()
as long as their __str__()
method is also overridden with a reader-friendly expression.
for key_path, _ in flattened:
print(f'Key path of tree{jax.tree_util.keystr(key_path)}: {repr(key_path)}')
Key path of tree[0]: (SequenceKey(idx=0),)
Key path of tree[1]['k1']: (SequenceKey(idx=1), DictKey(key='k1'))
Key path of tree[1]['k2'][0]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=0))
Key path of tree[1]['k2'][1]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=1))
Key path of tree[2].name: (SequenceKey(idx=2), GetAttrKey(name='name'))
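As a minimal sketch of such a custom key type (RowKey is hypothetical and only illustrates the __str__ requirement):
from dataclasses import dataclass

@dataclass(frozen=True)
class RowKey:
  # Hypothetical key for a container indexed by row number.
  row: int

  def __str__(self):
    return f"[row {self.row}]"  # reader-friendly form used by jax.tree_util.keystr()

print(str(RowKey(3)))  # [row 3]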
Common pytree gotchas#
This section covers some of the most common problems (“gotchas”) encountered when using JAX pytrees.
Mistaking pytree nodes for leaves#
A common gotcha to look out for is accidentally introducing tree nodes instead of leaves:
a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]
# Try to make another pytree with ones instead of zeros.
shapes = jax.tree.map(lambda x: x.shape, a_tree)
jax.tree.map(jnp.ones, shapes)
[(Array([1., 1.], dtype=float32), Array([1., 1., 1.], dtype=float32)),
(Array([1., 1., 1.], dtype=float32), Array([1., 1., 1., 1.], dtype=float32))]
What happened here is that the shape
of an array is a tuple, which is a pytree node, with its elements as leaves. Thus, in the map, instead of calling jnp.ones
on e.g. (2, 3)
, it’s called on 2
and 3
.
The solution will depend on the specifics, but there are two broadly applicable options:
Rewrite the code to avoid the intermediate jax.tree.map() (see the sketch below).
Convert the tuple into a NumPy array (np.array) or a JAX NumPy array (jnp.array), which makes the entire sequence a leaf.
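A minimal sketch of the first option, reusing the a_tree defined above: build the tree of ones directly from each array’s shape, with no intermediate pytree of shape tuples.
jax.tree.map(lambda x: jnp.ones(x.shape), a_tree)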
Handling of None by jax.tree_util#
jax.tree_util
functions treat None
as the absence of a pytree node, not as a leaf:
jax.tree.leaves([None, None, None])
[]
To treat None
as a leaf, you can use the is_leaf
argument:
jax.tree.leaves([None, None, None], is_leaf=lambda x: x is None)
[None, None, None]
Custom pytrees and initialization with unexpected values#
Another common gotcha with user-defined pytree objects is that JAX transformations occasionally initialize them with unexpected values, so that any input validation done at initialization may fail. For example:
class MyTree:
def __init__(self, a):
self.a = jnp.asarray(a)
register_pytree_node(MyTree, lambda tree: ((tree.a,), None),
lambda _, args: MyTree(*args))
tree = MyTree(jnp.arange(5.0))
jax.vmap(lambda x: x)(tree) # Error because object() is passed to `MyTree`.
TypeError: Cannot interpret '<object object at 0x7f62b3463df0>' as a data type
The above exception was the direct cause of the following exception:
TypeError: Cannot determine dtype of <object object at 0x7f62b3463df0>
During handling of the above exception, another exception occurred:
TypeError: Value '<object object at 0x7f62b3463df0>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
jax.jacobian(lambda x: x)(tree) # Error because MyTree(...) is passed to `MyTree`.
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:2662: FutureWarning: None encountered in jnp.array(); this is currently treated as NaN. In the future this will result in an error.
return array(a, dtype=dtype, copy=bool(copy), order=order) # type: ignore
TypeError: Cannot interpret '<object object at 0x7f629c32c360>' as a data type
The above exception was the direct cause of the following exception:
TypeError: Cannot determine dtype of <object object at 0x7f629c32c360>
During handling of the above exception, another exception occurred:
TypeError: Value '<object object at 0x7f629c32c360>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
In the first case with jax.vmap(...)(tree), JAX’s internals use arrays of object() values to infer the structure of the tree.
In the second case with jax.jacobian(...)(tree), the Jacobian of a function mapping a tree to a tree is defined as a tree of trees.
Potential solution 1:
The
__init__
and__new__
methods of custom pytree classes should generally avoid doing any array conversion or other input validation, or else anticipate and handle these special cases. For example:
class MyTree:
def __init__(self, a):
if not (type(a) is object or a is None or isinstance(a, MyTree)):
a = jnp.asarray(a)
self.a = a
Potential solution 2:
Structure your custom
tree_unflatten
function so that it avoids calling__init__
. If you choose this route, make sure that yourtree_unflatten
function stays in sync with__init__
if and when the code is updated. Example:
def tree_unflatten(aux_data, children):
del aux_data # Unused in this class.
obj = object.__new__(MyTree)
  obj.a, = children  # unpack the single child produced by the flatten rule
return obj
Common pytree patterns#
This section covers some of the most common patterns with JAX pytrees.
Transposing pytrees with jax.tree.map and jax.tree.transpose#
To transpose a pytree (turn a list of trees into a tree of lists), JAX has two functions: jax.tree.map() (more basic) and jax.tree.transpose() (more flexible, complex and verbose).
Option 1: Use jax.tree.map()
. Here’s an example:
def tree_transpose(list_of_trees):
"""
Converts a list of trees of identical structure into a single tree of lists.
"""
return jax.tree.map(lambda *xs: list(xs), *list_of_trees)
# Convert a dataset from row-major to column-major.
episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]
tree_transpose(episode_steps)
{'obs': [3, 4], 't': [1, 2]}
Option 2: For more complex transposes, use jax.tree.transpose()
, which is more verbose, but allows you specify the structure of the inner and outer pytree for more flexibility. For example:
jax.tree.transpose(
outer_treedef = jax.tree.structure([0 for e in episode_steps]),
inner_treedef = jax.tree.structure(episode_steps[0]),
pytree_to_transpose = episode_steps
)
{'obs': [3, 4], 't': [1, 2]}
Introduction to sharded computation#
JAX’s jax.Array
object is designed with distributed data and computation in mind.
This section will cover three modes of parallel computation:
Automatic parallelism via jax.jit(), in which we let the compiler choose the optimal computation strategy.
Semi-automatic parallelism using jax.jit() and jax.lax.with_sharding_constraint().
Fully manual parallelism using jax.experimental.shard_map.shard_map().
These examples will be run on Colab’s free TPU runtime, which provides eight devices to work with:
import jax
jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
Key concept: data sharding#
Key to all of the distributed computation approaches below is the concept of data sharding, which describes how data is laid out on the available devices.
Each concrete jax.Array
object has a sharding
attribute and a devices()
method that can give you insight into how the underlying data are stored. In the simplest cases, arrays are sharded on a single device:
import jax.numpy as jnp
arr = jnp.arange(32.0).reshape(4, 8)
arr.devices()
{TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}
arr.sharding
SingleDeviceSharding(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0))
For a more visual representation of the storage layout, the jax.debug
module provides some helpers to visualize the sharding of an array:
jax.debug.visualize_array_sharding(arr)
TPU 0
To create an array with a non-trivial sharding, we can define a sharding
specification for the array and pass this to jax.device_put()
.
Here we’ll define a NamedSharding
, which specifies an N-dimensional grid of devices with named axes:
# Pardon the boilerplate; constructing a sharding will become easier soon!
from jax.sharding import Mesh
from jax.sharding import PartitionSpec
from jax.sharding import NamedSharding
from jax.experimental import mesh_utils
P = jax.sharding.PartitionSpec
devices = mesh_utils.create_device_mesh((2, 4))
mesh = jax.sharding.Mesh(devices, ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
print(sharding)
NamedSharding(mesh=Mesh('x': 2, 'y': 4), spec=PartitionSpec('x', 'y'))
Passing this sharding
to jax.device_put()
, we obtain a sharded array:
arr_sharded = jax.device_put(arr, sharding)
print(arr_sharded)
jax.debug.visualize_array_sharding(arr_sharded)
[[ 0. 1. 2. 3. 4. 5. 6. 7.]
[ 8. 9. 10. 11. 12. 13. 14. 15.]
[16. 17. 18. 19. 20. 21. 22. 23.]
[24. 25. 26. 27. 28. 29. 30. 31.]]
TPU 0 TPU 1 TPU 2 TPU 3 TPU 6 TPU 7 TPU 4 TPU 5
The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device.
Automatic parallelism via jit#
Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a JIT-compiled function!
The XLA compiler behind jit
includes heuristics for optimizing computations across multiple devices.
In the simplest of cases, those heuristics boil down to computation follows data.
For example, here’s a simple element-wise function: the computation for each shard will be performed on the device associated with that shard, and the output is sharded in the same way:
@jax.jit
def f_elementwise(x):
return 2 * jnp.sin(x) + 1
result = f_elementwise(arr_sharded)
print("shardings match:", result.sharding == arr_sharded.sharding)
shardings match: True
As computations get more complex, the compiler makes decisions about how to best propagate the sharding of the data.
Here we sum along the leading axis of x
:
@jax.jit
def f_contract(x):
return x.sum(axis=0)
result = f_contract(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)
TPU 0,6 TPU 1,7 TPU 2,4 TPU 3,5
[48. 52. 56. 60. 64. 68. 72. 76.]
The result is partially replicated: that is, the first two elements of the array are replicated on devices 0 and 6, the next two on devices 1 and 7, and so on.
Semi-automated sharding with constraints#
If you’d like to have some control over the sharding used within a particular computation, JAX offers the with_sharding_constraint()
function.
For example, suppose that within f_contract
above, you’d prefer the output not to be partially-replicated, but rather to be fully sharded across the eight devices:
@jax.jit
def f_contract_2(x):
out = x.sum(axis=0)
# mesh = jax.create_mesh((8,), 'x')
devices = mesh_utils.create_device_mesh(8)
mesh = jax.sharding.Mesh(devices, 'x')
sharding = jax.sharding.NamedSharding(mesh, P('x'))
return jax.lax.with_sharding_constraint(out, sharding)
result = f_contract_2(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)
TPU 0 TPU 1 TPU 2 TPU 3 TPU 6 TPU 7 TPU 4 TPU 5
[48. 52. 56. 60. 64. 68. 72. 76.]
This gives you a function with the particular output sharding you’d like.
Manual parallelism with shard_map#
In the automatic parallelism methods explored above, you can write a function as if you’re operating on the full dataset, and jit
will split that computation across multiple devices.
By contrast, with shard_map
you write the function that will handle a single shard of data, and shard_map
will construct the full function.
shard_map
works by mapping a function across a particular mesh of devices:
from jax.experimental.shard_map import shard_map
P = jax.sharding.PartitionSpec
mesh = jax.sharding.Mesh(jax.devices(), 'x')
f_elementwise_sharded = shard_map(
f_elementwise,
mesh=mesh,
in_specs=P('x'),
out_specs=P('x'))
arr = jnp.arange(32)
f_elementwise_sharded(arr)
Array([ 1. , 2.682942 , 2.818595 , 1.28224 , -0.513605 ,
-0.9178486 , 0.44116896, 2.3139732 , 2.9787164 , 1.824237 ,
-0.08804226, -0.99998045, -0.07314599, 1.8403342 , 2.9812148 ,
2.3005757 , 0.42419332, -0.92279506, -0.50197446, 1.2997544 ,
2.8258905 , 2.6733112 , 0.98229736, -0.69244075, -0.81115675,
0.7352965 , 2.525117 , 2.912752 , 1.5418116 , -0.32726777,
-0.97606325, 0.19192469], dtype=float32)
The function you write only “sees” a single batch of the data, which we can see by printing the device local shape:
x = jnp.arange(32)
print(f"global shape: {x.shape=}")
def f(x):
print(f"device local shape: {x.shape=}")
return x * 2
y = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
global shape: x.shape=(32,)
device local shape: x.shape=(4,)
Because your function only sees the device-local part of the data, aggregation-like operations require some extra thought.
For example, here’s what a shard_map
of a sum
looks like:
def f(x):
return jnp.sum(x, keepdims=True)
shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
Array([ 6, 22, 38, 54, 70, 86, 102, 118], dtype=int32)
Our function f
operates separately on each shard, and the resulting summation reflects this.
If we want to sum across shards, we need to explicitly request it using collective operations like jax.lax.psum()
:
def f(x):
sum_in_shard = x.sum()
return jax.lax.psum(sum_in_shard, 'x')
shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x)
Array(496, dtype=int32)
Because the output no longer has a sharded dimension, we set out_specs=P()
.
Comparing the three approaches#
With these concepts fresh in our mind, let’s compare the three approaches for a simple neural network layer. We’ll define our canonical function like this:
@jax.jit
def layer(x, weights, bias):
return jax.nn.sigmoid(x @ weights + bias)
import numpy as np
rng = np.random.default_rng(0)
x = rng.normal(size=(32,))
weights = rng.normal(size=(32, 4))
bias = rng.normal(size=(4,))
layer(x, weights, bias)
Array([0.02138912, 0.893112 , 0.59892005, 0.97742504], dtype=float32)
We can automatically run this in a distributed manner using jax.jit()
and passing appropriately sharded data.
If we shard the leading axis of both x
and weights
in the same way, then the matrix multiplication will automatically happen in parallel:
P = jax.sharding.PartitionSpec
mesh = jax.sharding.Mesh(jax.devices(), 'x')
sharding = jax.sharding.NamedSharding(mesh, P('x'))
x_sharded = jax.device_put(x, sharding)
weights_sharded = jax.device_put(weights, sharding)
layer(x_sharded, weights_sharded, bias)
Array([0.02138912, 0.893112 , 0.59892005, 0.97742504], dtype=float32)
Alternatively, we can use jax.lax.with_sharding_constraint()
in the function to automatically distribute unsharded inputs:
@jax.jit
def layer_auto(x, weights, bias):
x = jax.lax.with_sharding_constraint(x, sharding)
weights = jax.lax.with_sharding_constraint(weights, sharding)
return layer(x, weights, bias)
layer_auto(x, weights, bias) # pass in unsharded inputs
Array([0.02138914, 0.89311206, 0.5989201 , 0.97742516], dtype=float32)
Finally, we can do the same thing with shard_map
, using psum
to indicate the cross-shard collective required for the matrix product:
from functools import partial
@jax.jit
@partial(shard_map, mesh=mesh,
in_specs=(P('x'), P('x', None), P(None)),
out_specs=P(None))
def layer_sharded(x, weights, bias):
return jax.nn.sigmoid(jax.lax.psum(x @ weights, 'x') + bias)
layer_sharded(x, weights, bias)
Array([0.02138914, 0.89311206, 0.5989201 , 0.97742516], dtype=float32)
This section has been a brief introduction of sharded and parallel computation;
for more discussion of shard_map
, see SPMD multi-device parallelism with shard_map.
Stateful Computations#
JAX transformations like jit(), vmap(), and grad() require the functions they wrap to be pure: that is, functions whose outputs depend solely on the inputs, and which have no side effects such as updating global state.
You can find a discussion of this in JAX sharp bits: Pure functions.
This constraint can pose some challenges in the context of machine learning, where state may exist in many forms. For example:
model parameters,
optimizer state, and
stateful layers, such as BatchNorm.
This section offers some advice on how to properly handle state in a JAX program.
A simple example: Counter#
Let’s start by looking at a simple stateful program: a counter.
import jax
import jax.numpy as jnp
class Counter:
"""A simple counter."""
def __init__(self):
self.n = 0
def count(self) -> int:
"""Increments the counter and returns the new value."""
self.n += 1
return self.n
def reset(self):
"""Resets the counter to zero."""
self.n = 0
counter = Counter()
for _ in range(3):
print(counter.count())
1
2
3
The counter’s n
attribute maintains the counter’s state between successive calls of count
. It is modified as a side effect of calling count
.
Let’s say we want to count fast, so we JIT-compile the count
method.
(In this example, this wouldn’t actually help speed anyway, for many reasons, but treat this as a toy model of JIT-compiling the update of model parameters, where jit()
makes an enormous difference).
counter.reset()
fast_count = jax.jit(counter.count)
for _ in range(3):
print(fast_count())
1
1
1
Oh no! Our counter isn’t working. This is because the line
self.n += 1
in count
involves a side effect: it modifies the input counter in-place, and so this function is not supported by jit
.
Such side effects are executed only once when the function is first traced, and subsequent calls will not repeat the side effect.
So, how do we fix it?
The solution: explicit state#
Part of the problem with our counter was that the returned value didn’t depend on the arguments, meaning a constant was “baked into” the compiled output. But it shouldn’t be a constant – it should depend on the state. Well, then why don’t we make the state into an argument?
CounterState = int
class CounterV2:
def count(self, n: CounterState) -> tuple[int, CounterState]:
# You could just return n+1, but here we separate its role as
# the output and as the counter state for didactic purposes.
return n+1, n+1
def reset(self) -> CounterState:
return 0
counter = CounterV2()
state = counter.reset()
for _ in range(3):
value, state = counter.count(state)
print(value)
1
2
3
In this new version of Counter
, we moved n
to be an argument of count
, and added another return value that represents the new, updated, state. To use this counter, we now need to keep track of the state explicitly. But in return, we can now safely jax.jit
this counter:
state = counter.reset()
fast_count = jax.jit(counter.count)
for _ in range(3):
value, state = fast_count(state)
print(value)
1
2
3
A general strategy#
We can apply the same process to any stateful method to convert it into a stateless one. We took a class of the form
class StatefulClass
state: State
def stateful_method(*args, **kwargs) -> Output:
and turned it into a class of the form
class StatelessClass
def stateless_method(state: State, *args, **kwargs) -> (Output, State):
This is a common functional programming pattern, and, essentially, is the way that state is handled in all JAX programs.
Notice that the need for a class becomes less clear once we have rewritten it this way. We could just keep stateless_method
, since the class is no longer doing any work.
This is because, like the strategy we just applied, object-oriented programming (OOP) is a way to help programmers understand program state.
In our case, the CounterV2
class is nothing more than a namespace bringing all the functions that use CounterState
into one location. Exercise for the reader: do you think it makes sense to keep it as a class?
Incidentally, you’ve already seen an example of this strategy in the JAX pseudo-randomness API, jax.random, shown in the Pseudorandom numbers section.
Unlike NumPy, which manages random state via implicitly updated stateful classes, JAX requires the programmer to work directly with the random generator state: the PRNG key.
Simple worked example: Linear Regression#
Let’s apply this strategy to a simple machine learning model: linear regression via gradient descent.
Here, we only deal with one kind of state: the model parameters. But generally, you’ll see many kinds of state being threaded in and out of JAX functions, like optimizer state, layer statistics for batchnorm, and others.
The function to look at carefully is update
.
from typing import NamedTuple
class Params(NamedTuple):
weight: jnp.ndarray
bias: jnp.ndarray
def init(rng) -> Params:
"""Returns the initial model params."""
weights_key, bias_key = jax.random.split(rng)
weight = jax.random.normal(weights_key, ())
bias = jax.random.normal(bias_key, ())
return Params(weight, bias)
def loss(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
"""Computes the least squares error of the model's predictions on x against y."""
pred = params.weight * x + params.bias
return jnp.mean((pred - y) ** 2)
LEARNING_RATE = 0.005
@jax.jit
def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> Params:
"""Performs one SGD update step on params using the given data."""
grad = jax.grad(loss)(params, x, y)
# If we were using Adam or another stateful optimizer,
# we would also do something like
#
# updates, new_optimizer_state = optimizer(grad, optimizer_state)
#
# and then use `updates` instead of `grad` to actually update the params.
# (And we'd include `new_optimizer_state` in the output, naturally.)
  new_params = jax.tree.map(
      lambda param, g: param - g * LEARNING_RATE, params, grad)
return new_params
Notice that we manually pipe the params in and out of the update function.
import matplotlib.pyplot as plt
rng = jax.random.key(42)
# Generate true data from y = w*x + b + noise
true_w, true_b = 2, -1
x_rng, noise_rng = jax.random.split(rng)
xs = jax.random.normal(x_rng, (128, 1))
noise = jax.random.normal(noise_rng, (128, 1)) * 0.5
ys = xs * true_w + true_b + noise
# Fit regression
params = init(rng)
for _ in range(1000):
params = update(params, xs, ys)
plt.scatter(xs, ys)
plt.plot(xs, params.weight * xs + params.bias, c='red', label='Model Prediction')
plt.legend();

Taking it further#
The strategy described above is how any JAX program must handle state when using transformations like jit
, vmap
, grad
, etc.
Handling parameters manually seems fine if you’re dealing with two parameters, but what if it’s a neural net with dozens of layers? You might already be getting worried about two things:
Are we supposed to initialize them all manually, essentially repeating what we already write in the forward pass definition?
Are we supposed to pipe all these things around manually?
The details can be tricky to handle, but there are examples of libraries that take care of this for you. See JAX Neural Network Libraries for some examples.
User Guides#
User guides are deeper dives into particular topics within JAX that become relevant as your JAX project matures into larger or deployed codebases.
How to Think in JAX#
JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively.
JAX vs. NumPy#
Key Concepts:
JAX provides a NumPy-inspired interface for convenience.
Through duck-typing, JAX arrays can often be used as drop-in replacements of NumPy arrays.
Unlike NumPy arrays, JAX arrays are always immutable.
NumPy provides a well-known, powerful API for working with numerical data. For convenience, JAX provides jax.numpy
which closely mirrors the numpy API and provides easy entry into JAX. Almost anything that can be done with numpy
can be done with jax.numpy
:
import matplotlib.pyplot as plt
import numpy as np
x_np = np.linspace(0, 10, 1000)
y_np = 2 * np.sin(x_np) * np.cos(x_np)
plt.plot(x_np, y_np);

import jax.numpy as jnp
x_jnp = jnp.linspace(0, 10, 1000)
y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)
plt.plot(x_jnp, y_jnp);

The code blocks are identical aside from replacing np
with jnp
, and the results are the same. As we can see, JAX arrays can often be used directly in place of NumPy arrays for things like plotting.
The arrays themselves are implemented as different Python types:
type(x_np)
numpy.ndarray
type(x_jnp)
jaxlib.xla_extension.ArrayImpl
Python’s duck-typing allows JAX arrays and NumPy arrays to be used interchangeably in many places.
However, there is one important difference between JAX and NumPy arrays: JAX arrays are immutable, meaning that once created their contents cannot be changed.
Here is an example of mutating an array in NumPy:
# NumPy: mutable arrays
x = np.arange(10)
x[0] = 10
print(x)
[10 1 2 3 4 5 6 7 8 9]
The equivalent in JAX results in an error, as JAX arrays are immutable:
%xmode minimal
Exception reporting mode: Minimal
# JAX: immutable arrays
x = jnp.arange(10)
x[0] = 10
TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
For updating individual elements, JAX provides an indexed update syntax that returns an updated copy:
y = x.at[0].set(10)
print(x)
print(y)
[0 1 2 3 4 5 6 7 8 9]
[10 1 2 3 4 5 6 7 8 9]
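The .at[] syntax supports other out-of-place updates besides set(); a short sketch (the values in the comments assume x = jnp.arange(10) as above):
x = jnp.arange(10)
print(x.at[0].add(5))        # adds 5 to element 0: [5 1 2 3 4 5 6 7 8 9]
print(x.at[:3].multiply(2))  # doubles the first three elements: [0 2 4 3 4 5 6 7 8 9]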
NumPy, lax & XLA: JAX API layering#
Key Concepts:
jax.numpy is a high-level wrapper that provides a familiar interface.
jax.lax is a lower-level API that is stricter and often more powerful.
All JAX operations are implemented in terms of operations in XLA – the Accelerated Linear Algebra compiler.
If you look at the source of jax.numpy
, you’ll see that all the operations are eventually expressed in terms of functions defined in jax.lax
. You can think of jax.lax
as a stricter, but often more powerful, API for working with multi-dimensional arrays.
For example, while jax.numpy
will implicitly promote arguments to allow operations between mixed data types, jax.lax
will not:
import jax.numpy as jnp
jnp.add(1, 1.0) # jax.numpy API implicitly promotes mixed types.
Array(2., dtype=float32, weak_type=True)
from jax import lax
lax.add(1, 1.0) # jax.lax API requires explicit type promotion.
MLIRError: Verification failed:
error: "jit(add)/jit(main)/add": op requires the same element type for all operands and results
The above exception was the direct cause of the following exception:
ValueError: Cannot lower jaxpr with verifier errors:
op requires the same element type for all operands and results
at loc("jit(add)/jit(main)/add"). Define JAX_DUMP_IR_TO to dump the module.
If using jax.lax
directly, you’ll have to do type promotion explicitly in such cases:
lax.add(jnp.float32(1), 1.0)
Array(2., dtype=float32)
Along with this strictness, jax.lax
also provides efficient APIs for some more general operations than are supported by NumPy.
For example, consider a 1D convolution, which can be expressed in NumPy this way:
x = jnp.array([1, 2, 1])
y = jnp.ones(10)
jnp.convolve(x, y)
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)
Under the hood, this NumPy operation is translated to a much more general convolution implemented by lax.conv_general_dilated
:
from jax import lax
result = lax.conv_general_dilated(
x.reshape(1, 1, 3).astype(float), # note: explicit promotion
y.reshape(1, 1, 10),
window_strides=(1,),
padding=[(len(y) - 1, len(y) - 1)]) # equivalent of padding='full' in NumPy
result[0, 0]
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)
This is a batched convolution operation designed to be efficient for the types of convolutions often used in deep neural nets. It requires much more boilerplate, but is far more flexible and scalable than the convolution provided by NumPy (See Convolutions in JAX for more detail on JAX convolutions).
At their heart, all jax.lax
operations are Python wrappers for operations in XLA; here, for example, the convolution implementation is provided by XLA:ConvWithGeneralPadding.
Every JAX operation is eventually expressed in terms of these fundamental XLA operations, which is what enables just-in-time (JIT) compilation.
To JIT or not to JIT#
Key Concepts:
By default JAX executes operations one at a time, in sequence.
Using a just-in-time (JIT) compilation decorator, sequences of operations can be optimized together and run at once.
Not all JAX code can be JIT compiled, as it requires array shapes to be static & known at compile time.
The fact that all JAX operations are expressed in terms of XLA allows JAX to use the XLA compiler to execute blocks of code very efficiently.
For example, consider this function that normalizes the rows of a 2D matrix, expressed in terms of jax.numpy
operations:
import jax.numpy as jnp
def norm(X):
X = X - X.mean(0)
return X / X.std(0)
A just-in-time compiled version of the function can be created using the jax.jit
transform:
from jax import jit
norm_compiled = jit(norm)
This function returns the same results as the original, up to standard floating-point accuracy:
np.random.seed(1701)
X = jnp.array(np.random.rand(10000, 10))
np.allclose(norm(X), norm_compiled(X), atol=1E-6)
True
But due to the compilation (which includes fusing of operations, avoidance of allocating temporary arrays, and a host of other tricks), execution times can be orders of magnitude faster in the JIT-compiled case (note the use of block_until_ready()
to account for JAX’s asynchronous dispatch):
%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()
347 µs ± 2.11 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
278 µs ± 1.77 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
That said, jax.jit
does have limitations: in particular, it requires all arrays to have static shapes. That means that some JAX operations are incompatible with JIT compilation.
For example, this operation can be executed in op-by-op mode:
def get_negatives(x):
return x[x < 0]
x = jnp.array(np.random.randn(10))
get_negatives(x)
Array([-0.10570311, -0.59403396, -0.8680282 , -0.23489487], dtype=float32)
But it returns an error if you attempt to execute it in jit mode:
jit(get_negatives)(x)
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[10])
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError
This is because the function generates an array whose shape is not known at compile time: the size of the output depends on the values of the input array, and so it is not compatible with JIT.
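A common jit-compatible alternative is to keep the output shape fixed and mask values instead of dropping them. Here is a minimal sketch (the function name mask_nonnegatives is ours, not a JAX API): it returns an array with the same shape as the input, with the entries that would have been dropped replaced by zero.
import numpy as np
import jax.numpy as jnp
from jax import jit

@jit
def mask_nonnegatives(x):
    # The output shape equals the input shape, so it is static under jit.
    return jnp.where(x < 0, x, 0.0)

x = jnp.array(np.random.randn(10))
mask_nonnegatives(x)  # negative entries kept, non-negative entries zeroed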
JIT mechanics: tracing and static variables#
Key Concepts:
JIT and other JAX transforms work by tracing a function to determine its effect on inputs of a specific shape and type.
Variables that you don’t want to be traced can be marked as static
To use jax.jit
effectively, it is useful to understand how it works. Let’s put a few print()
statements within a JIT-compiled function and then call the function:
@jit
def f(x, y):
print("Running f():")
print(f" x = {x}")
print(f" y = {y}")
result = jnp.dot(x + 1, y + 1)
print(f" result = {result}")
return result
x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)
Running f():
x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=1/0)>
y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>
Array([0.25773212, 5.3623195 , 5.403243 ], dtype=float32)
Notice that the print statements execute, but rather than printing the data we passed to the function, they print tracer objects that stand in for them.
These tracer objects are what jax.jit
uses to extract the sequence of operations specified by the function. Basic tracers are stand-ins that encode the shape and dtype of the arrays, but are agnostic to the values. This recorded sequence of computations can then be efficiently applied within XLA to new inputs with the same shape and dtype, without having to re-execute the Python code.
When we call the compiled function again on matching inputs, no re-compilation is required and nothing is printed because the result is computed in compiled XLA rather than in Python:
x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
f(x2, y2)
Array([1.4344584, 4.3004413, 7.9897013], dtype=float32)
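By contrast, calling the same compiled function with inputs of a different shape triggers a fresh trace (and compilation), so the print statements run again. A minimal sketch reusing the f defined above (the exact tracer text will vary):
x3 = np.random.randn(2, 4)
y3 = np.random.randn(4)
f(x3, y3)
# Running f():
#   x = Traced<ShapedArray(float32[2,4])>...
#   ...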
The extracted sequence of operations is encoded in a JAX expression, or jaxpr for short. You can view the jaxpr using the jax.make_jaxpr
transformation:
from jax import make_jaxpr
def f(x, y):
return jnp.dot(x + 1, y + 1)
make_jaxpr(f)(x, y)
{ lambda ; a:f32[3,4] b:f32[4]. let
c:f32[3,4] = add a 1.0
d:f32[4] = add b 1.0
e:f32[3] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] c d
in (e,) }
Note one consequence of this: because JIT compilation is done without information on the content of the array, control flow statements in the function cannot depend on traced values. For example, this fails:
@jit
def f(x, neg):
return -x if neg else x
f(1, True)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function f at /tmp/ipykernel_8133/2422663986.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument neg.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
If there are variables that you would not like to be traced, they can be marked as static for the purposes of JIT compilation:
from functools import partial
@partial(jit, static_argnums=(1,))
def f(x, neg):
return -x if neg else x
f(1, True)
Array(-1, dtype=int32, weak_type=True)
Note that calling a JIT-compiled function with a different static argument results in re-compilation, so the function still works as expected:
f(1, False)
Array(1, dtype=int32, weak_type=True)
Understanding which values and operations will be static and which will be traced is a key part of using jax.jit
effectively.
Static vs Traced Operations#
Key Concepts:
Just as values can be either static or traced, operations can be static or traced.
Static operations are evaluated at compile-time in Python; traced operations are compiled & evaluated at run-time in XLA.
Use
numpy
for operations that you want to be static; usejax.numpy
for operations that you want to be traced.
This distinction between static and traced values makes it important to think about how to keep a static value static. Consider this function:
import jax.numpy as jnp
from jax import jit
@jit
def f(x):
return x.reshape(jnp.array(x.shape).prod())
x = jnp.ones((2, 3))
f(x)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function f at /tmp/ipykernel_8133/1983583872.py:4 for jit. This value became a tracer due to JAX operations on these lines:
operation a:i32[2] = convert_element_type[new_dtype=int32 weak_type=False] b
from line /tmp/ipykernel_8133/1983583872.py:6 (f)
This fails with an error specifying that a tracer was found instead of a 1D sequence of concrete values of integer type. Let’s add some print statements to the function to understand why this is happening:
@jit
def f(x):
print(f"x = {x}")
print(f"x.shape = {x.shape}")
print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}")
# comment this out to avoid the error:
# return x.reshape(jnp.array(x.shape).prod())
f(x)
x = Traced<ShapedArray(float32[2,3])>with<DynamicJaxprTrace(level=1/0)>
x.shape = (2, 3)
jnp.array(x.shape).prod() = Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
Notice that although x
is traced, x.shape
is a static value. However, when we use jnp.array
and jnp.prod
on this static value, it becomes a traced value, at which point it cannot be used in a function like reshape()
that requires a static input (recall: array shapes must be static).
A useful pattern is to use numpy
for operations that should be static (i.e. done at compile-time), and use jax.numpy
for operations that should be traced (i.e. compiled and executed at run-time). For this function, it might look like this:
from jax import jit
import jax.numpy as jnp
import numpy as np
@jit
def f(x):
return x.reshape((np.prod(x.shape),))
f(x)
Array([1., 1., 1., 1., 1., 1.], dtype=float32)
For this reason, a standard convention in JAX programs is to import numpy as np
and import jax.numpy as jnp
so that both interfaces are available for finer control over whether operations are performed in a static manner (with numpy
, once at compile-time) or a traced manner (with jax.numpy
, optimized at run-time).
Profiling JAX programs#
Viewing program traces with Perfetto#
We can use the JAX profiler to generate traces of a JAX program that can be visualized using the Perfetto visualizer. Currently, this method blocks the program until a link is clicked and the Perfetto UI loads the trace. If you wish to get profiling information without any interaction, check out the Tensorboard profiler below.
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
After this computation is done, the program will prompt you to open a link to
ui.perfetto.dev
. When you open the link, the Perfetto UI will load the trace
file and open a visualizer.
Program execution will continue after loading the link. The link is no longer valid after opening once, but it will redirect to a new URL that remains valid. You can then click the “Share” button in the Perfetto UI to create a permalink to the trace that can be shared with others.
Remote profiling#
When profiling code that is running remotely (for example on a hosted VM), you need to establish an SSH tunnel on port 9001 for the link to work. You can do that with this command:
$ ssh -L 9001:127.0.0.1:9001 <user>@<host>
or if you’re using Google Cloud:
$ gcloud compute ssh <machine-name> -- -L 9001:127.0.0.1:9001
Manual capture#
Instead of capturing traces programmatically using jax.profiler.trace
, you can
instead start a profiling server in the script of interest by calling
jax.profiler.start_server(<port>)
. If you only need the profiler server to be
active for a portion of your script, you can shut it down by calling
jax.profiler.stop_server()
.
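For example, a minimal sketch of wrapping a script's region of interest (the port number 9999 is an arbitrary choice):
import jax

# Start the profiler server near the top of the script; it must still be
# running when you trigger the capture in the next step.
jax.profiler.start_server(9999)

# ... the part of the script you want to be able to capture ...

jax.profiler.stop_server()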
Once the script is running and after the profiler server has started, we can manually capture and trace by running:
$ python -m jax.collect_profile <port> <duration_in_ms>
By default, the resulting trace information is dumped into a temporary directory
but this can be overridden by passing in --log_dir=<directory of choice>
.
Also, by default, the program will prompt you to open a link to
ui.perfetto.dev
. When you open the link, the Perfetto UI will load the trace
file and open a visualizer. This feature is disabled by passing in
--no_perfetto_link
into the command. Alternatively, you can also point
Tensorboard to the log_dir
to analyze the trace (see the
“Tensorboard Profiling” section below).
TensorBoard profiling#
TensorBoard’s profiler can be used to profile JAX programs. Tensorboard is a great way to acquire and visualize performance traces and profiles of your program, including activity on GPU and TPU. The end result looks something like this:
Installation#
The TensorBoard profiler is only available with the version of TensorBoard bundled with TensorFlow.
pip install tensorflow tensorboard-plugin-profile
If you already have TensorFlow installed, you only need to install the
tensorboard-plugin-profile
pip package. Be careful to only install one version
of TensorFlow or TensorBoard, otherwise you may encounter the “duplicate
plugins” error described below. See
https://www.tensorflow.org/guide/profiler for more information on installing
TensorBoard.
Programmatic capture#
You can instrument your code to capture a profiler trace via the
jax.profiler.start_trace()
and jax.profiler.stop_trace()
methods. Call start_trace()
with the directory to write
trace files to. This should be the same --logdir
directory used to start
TensorBoard. Then, you can use TensorBoard to view the traces.
For example, to take a profiler trace:
import jax
jax.profiler.start_trace("/tmp/tensorboard")
# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
jax.profiler.stop_trace()
Note the block_until_ready()
call. We use this to make sure on-device
execution is captured by the trace. See Asynchronous dispatch for details on why
this is necessary.
You can also use the jax.profiler.trace()
context manager as an
alternative to start_trace
and stop_trace
:
import jax
with jax.profiler.trace("/tmp/tensorboard"):
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
To view the trace, first start TensorBoard if you haven’t already:
$ tensorboard --logdir=/tmp/tensorboard
[...]
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.5.0 at http://localhost:6006/ (Press CTRL+C to quit)
You should be able to load TensorBoard at http://localhost:6006/ in this
example. You can specify a different port with the --port
flag. See
Profiling on a remote machine below if running JAX on a remote server.
Then, either select “Profile” in the upper-right dropdown menu, or go directly
to http://localhost:6006/#profile. Available traces appear in the “Runs”
dropdown menu on the left. Select the run you’re interested in, and then under
“Tools”, select trace_viewer
. You should now see a timeline of the
execution. You can use the WASD keys to navigate the trace, and click or drag to
select events to see more details at the bottom. See these TensorFlow
docs
for more details on using the trace viewer.
You can also use the memory_viewer
, op_profile
, and graph_viewer
tools.
Manual capture via TensorBoard#
The following are instructions for capturing a manually-triggered N-second trace from a running program.
1. Start a TensorBoard server:
tensorboard --logdir /tmp/tensorboard/
You should be able to load TensorBoard at http://localhost:6006/. You can specify a different port with the --port flag. See Profiling on a remote machine below if running JAX on a remote server.
2. In the Python program or process you’d like to profile, add the following somewhere near the beginning:
import jax.profiler
jax.profiler.start_server(9999)
This starts the profiler server that TensorBoard connects to. The profiler server must be running before you move on to the next step. When you’re done using the server, you can call jax.profiler.stop_server() to shut it down.
If you’d like to profile a snippet of a long-running program (e.g. a long training loop), you can put this at the beginning of the program and start your program as usual. If you’d like to profile a short program (e.g. a microbenchmark), one option is to start the profiler server in an IPython shell, and run the short program with %run after starting the capture in the next step. Another option is to start the profiler server at the beginning of the program and use time.sleep() to give you enough time to start the capture.
3. Open http://localhost:6006/#profile, and click the “CAPTURE PROFILE” button in the upper left. Enter “localhost:9999” as the profile service URL (this is the address of the profiler server you started in the previous step). Enter the number of milliseconds you’d like to profile for, and click “CAPTURE”.
4. If the code you’d like to profile isn’t already running (e.g. if you started the profiler server in a Python shell), run it while the capture is running.
5. After the capture finishes, TensorBoard should automatically refresh. (Not all of the TensorBoard profiling features are hooked up with JAX, so it may initially look like nothing was captured.) On the left under “Tools”, select trace_viewer.
You should now see a timeline of the execution. You can use the WASD keys to navigate the trace, and click or drag to select events to see more details at the bottom. See these TensorFlow docs for more details on using the trace viewer.
You can also use the
memory_viewer
,op_profile
, andgraph_viewer
tools.
Adding custom trace events#
By default, the events in the trace viewer are mostly low-level internal JAX
functions. You can add your own events and functions by using
jax.profiler.TraceAnnotation
and jax.profiler.annotate_function()
in
your code.
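For example, a minimal sketch (the annotation and function names here are arbitrary, not part of the JAX API):
import jax
import jax.numpy as jnp

@jax.profiler.annotate_function
def dense_layer(x, w):
    # This function shows up as its own event in the trace viewer.
    return jnp.dot(x, w)

def forward(x, w):
    # Everything inside this block is grouped under "my_forward_pass" in the trace.
    with jax.profiler.TraceAnnotation("my_forward_pass"):
        return dense_layer(x, w)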
Troubleshooting#
GPU profiling#
Programs running on GPU should produce traces for the GPU streams near the top of the trace viewer. If you’re only seeing the host traces, check your program logs and/or output for the following error messages.
If you get an error like: Could not load dynamic library 'libcupti.so.10.1'
Full error:
W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libcupti.so.10.1'; dlerror: libcupti.so.10.1: cannot open shared object file: No such file or directory
2020-06-12 13:19:59.822799: E external/org_tensorflow/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1422] function cupti_interface_->Subscribe( &subscriber_, (CUpti_CallbackFunc)ApiCallback, this)failed with error CUPTI could not be loaded or symbol could not be found.
Add the path to libcupti.so
to the environment variable LD_LIBRARY_PATH
.
(Try locate libcupti.so
to find the path.) For example:
export LD_LIBRARY_PATH=/usr/local/cuda-10.1/extras/CUPTI/lib64/:$LD_LIBRARY_PATH
If you still get the Could not load dynamic library
message after doing this,
check if the GPU trace shows up in the trace viewer anyway. This message
sometimes occurs even when everything is working, since it looks for the
libcupti
library in multiple places.
If you get an error like: failed with error CUPTI_ERROR_INSUFFICIENT_PRIVILEGES
Full error:
E external/org_tensorflow/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1445] function cupti_interface_->EnableCallback( 0 , subscriber_, CUPTI_CB_DOMAIN_DRIVER_API, cbid)failed with error CUPTI_ERROR_INSUFFICIENT_PRIVILEGES
2020-06-12 14:31:54.097791: E external/org_tensorflow/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1487] function cupti_interface_->ActivityDisable(activity)failed with error CUPTI_ERROR_NOT_INITIALIZED
Run the following commands (note this requires a reboot):
echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"' | sudo tee -a /etc/modprobe.d/nvidia-kernel-common.conf
sudo update-initramfs -u
sudo reboot now
See NVIDIA’s documentation on this error for more information.
Profiling on a remote machine#
If the JAX program you’d like to profile is running on a remote machine, one option is to run all the instructions above on the remote machine (in particular, start the TensorBoard server on the remote machine), then use SSH local port forwarding to access the TensorBoard web UI from your local machine. Use the following SSH command to forward the default TensorBoard port 6006 from the local to the remote machine:
ssh -L 6006:localhost:6006 <remote server address>
or if you’re using Google Cloud:
$ gcloud compute ssh <machine-name> -- -L 6006:localhost:6006
Multiple TensorBoard installs#
If starting TensorBoard fails with an error like: ValueError: Duplicate plugins for name projector
It’s often because there are two versions of TensorBoard and/or TensorFlow
installed (e.g. the tensorflow
, tf-nightly
, tensorboard
, and tb-nightly
pip packages all include TensorBoard). Uninstalling a single pip package can
result in the tensorboard
executable being removed which is then hard to
replace, so it may be necessary to uninstall everything and reinstall a single
version:
pip uninstall tensorflow tf-nightly tensorboard tb-nightly
pip install tensorflow
Nsight#
NVIDIA’s Nsight
tools can be used to trace and profile JAX code on GPU. For
details, see the Nsight
documentation.
Device Memory Profiling#
Note
May 2023 update: we recommend using Tensorboard
profiling for device memory analysis. After taking a
profile, open the memory_viewer
tab of the Tensorboard profiler for more
detailed and understandable device memory usage.
The JAX Device Memory Profiler allows us to explore how and why JAX programs are using GPU or TPU memory. For example, it can be used to:
Figure out which arrays and executables are in GPU memory at a given time, or
Track down memory leaks.
Installation#
The JAX device memory profiler emits output that can be interpreted using
pprof (google/pprof). Start by installing pprof
,
by following its
installation instructions.
At the time of writing, installing pprof
requires first installing
Go version 1.16+,
Graphviz, and then running
go install github.com/google/pprof@latest
which installs pprof
as $GOPATH/bin/pprof
, where GOPATH
defaults to
~/go
.
Note
The version of pprof
from google/pprof is not the same as
the older tool of the same name distributed as part of the gperftools
package.
The gperftools
version of pprof
will not work with JAX.
Understanding how a JAX program is using GPU or TPU memory#
A common use of the device memory profiler is to figure out why a JAX program is using a large amount of GPU or TPU memory, for example if trying to debug an out-of-memory problem.
To capture a device memory profile to disk, use
jax.profiler.save_device_memory_profile()
. For example, consider the
following Python program:
import jax
import jax.numpy as jnp
import jax.profiler
def func1(x):
return jnp.tile(x, 10) * 0.5
def func2(x):
y = func1(x)
return y, jnp.tile(x, 10) + 1
x = jax.random.normal(jax.random.key(42), (1000, 1000))
y, z = func2(x)
z.block_until_ready()
jax.profiler.save_device_memory_profile("memory.prof")
If we first run the program above and then execute
pprof --web memory.prof
pprof
opens a web browser containing the following visualization of the device
memory profile in callgraph format:
The callgraph is a visualization of
the Python stack at the point the allocation of each live buffer was made.
For example, in this specific case, the visualization shows that
func2
and its callees were responsible for allocating 76.30MB, of which
38.15MB was allocated inside the call from func2 to func1.
For more information about how to interpret callgraph visualizations, see the
pprof documentation.
Functions compiled with jax.jit()
are opaque to the device memory profiler.
That is, any memory allocated inside a jit
-compiled function will be
attributed to the function as a whole.
In the example, the call to block_until_ready()
is to ensure that func2
completes before the device memory profile is collected. See
Asynchronous dispatch for more details.
Debugging memory leaks#
We can also use the JAX device memory profiler to track down memory leaks by using
pprof
to visualize the change in memory usage between two device memory profiles
taken at different times. For example, consider the following program which
accumulates JAX arrays into a constantly-growing Python list.
import jax
import jax.numpy as jnp
import jax.profiler
def afunction():
return jax.random.normal(jax.random.key(77), (1000000,))
z = afunction()
def anotherfunc():
arrays = []
for i in range(1, 10):
x = jax.random.normal(jax.random.key(42), (i, 10000))
arrays.append(x)
x.block_until_ready()
jax.profiler.save_device_memory_profile(f"memory{i}.prof")
anotherfunc()
If we simply visualize the device memory profile at the end of execution
(memory9.prof
), it may not be obvious that each iteration of the loop in
anotherfunc
accumulates more device memory allocations:
pprof --web memory9.prof
The large but fixed allocation inside afunction
dominates the profile but does
not grow over time.
By using pprof
’s
--diff_base
feature to visualize the change in memory usage
across loop iterations, we can identify why the memory usage of the
program increases over time:
pprof --web --diff_base memory1.prof memory9.prof
The visualization shows that the memory growth can be attributed to the call to
normal
inside anotherfunc
.
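If only host copies of the accumulated arrays are needed later, one way to act on such a finding is to move each result off the device so its buffer is not kept alive by the Python list. This is a hedged sketch of a possible fix, not part of the original example:
import jax

host_arrays = []
for i in range(1, 10):
    x = jax.random.normal(jax.random.key(42), (i, 10000))
    # jax.device_get copies the array to host memory as a NumPy array, so the
    # growing list no longer holds on to device buffers.
    host_arrays.append(jax.device_get(x))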
Runtime value debugging in JAX#
Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page has TL;DR summaries and you can click the “Read more” links at the bottom to learn more.
Table of contents:
Interactive inspection with jax.debug
#
TL;DR Use jax.debug.print() to print values to stdout in jax.jit-, jax.pmap-, and pjit-decorated functions, and jax.debug.breakpoint() to pause execution of your compiled function to inspect values in the call stack:
import jax
import jax.numpy as jnp
@jax.jit
def f(x):
jax.debug.print("🤯 {x} 🤯", x=x)
y = jnp.sin(x)
jax.debug.breakpoint()
jax.debug.print("🤯 {y} 🤯", y=y)
return y
f(2.)
# Prints:
# 🤯 2.0 🤯
# Enters breakpoint to inspect values!
# 🤯 0.9092974662780762 🤯
Click here to learn more!
Functional error checks with jax.experimental.checkify
#
TL;DR Checkify lets you add jit
-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the checkify.checkify
transformation together with the assert-like checkify.check
function to add runtime checks to JAX code:
from jax.experimental import checkify
import jax
import jax.numpy as jnp
def f(x, i):
checkify.check(i >= 0, "index needs to be non-negative!")
y = x[i]
z = jnp.sin(y)
return z
jittable_f = checkify.checkify(f)
err, z = jax.jit(jittable_f)(jnp.ones((5,)), -1)
print(err.get())
# >> index needs to be non-negative! (check failed at <...>:6 (f))
You can also use checkify to automatically add common checks:
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)
err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)
err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <…>:6 (f))
err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
Click here to learn more!
Throwing Python errors with JAX’s debug flags#
TL;DR Enable the jax_debug_nans
flag to automatically detect when NaNs are produced in jax.jit
-compiled code (but not in jax.pmap
or jax.pjit
-compiled code) and enable the jax_disable_jit
flag to disable JIT-compilation, enabling use of traditional Python debugging tools like print
and pdb
.
import jax
jax.config.update("jax_debug_nans", True)
def f(x, y):
return x / y
jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception!
Click here to learn more!
jax.debug.print
and jax.debug.breakpoint
#
The jax.debug
package offers some useful tools for inspecting values
inside of JIT-ted functions.
Debugging with jax.debug.print
and other debugging callbacks#
TL;DR Use jax.debug.print()
to print traced array values to stdout in jit
- and pmap
-decorated functions:
import jax
import jax.numpy as jnp
@jax.jit
def f(x):
jax.debug.print("🤯 {x} 🤯", x=x)
y = jnp.sin(x)
jax.debug.print("🤯 {y} 🤯", y=y)
return y
f(2.)
# Prints:
# 🤯 2.0 🤯
# 🤯 0.9092974662780762 🤯
With some transformations, like jax.grad
and jax.vmap
, you can use Python’s builtin print
function to print out numerical values. But print
won’t work with jax.jit
or jax.pmap
because those transformations delay numerical evaluation. So use jax.debug.print
instead!
Semantically, jax.debug.print
is roughly equivalent to the following Python function
def debug.print(fmt: str, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
print(fmt.format(*args, **kwargs))
except that it can be staged out and transformed by JAX. See the API reference
for more details.
Note that fmt
cannot be an f-string because f-strings are formatted immediately, whereas for jax.debug.print
, we’d like to delay formatting until later.
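A small sketch of the difference: inside a jit-decorated function an f-string is formatted at trace time and would show the tracer, while jax.debug.print defers formatting until the value is available at run time.
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # print(f"x = {x}") here would show a tracer, because the f-string is
    # formatted immediately during tracing.
    jax.debug.print("x = {}", x)  # formatting is deferred to run time
    return x + 1

f(jnp.float32(1.0))
# Prints: x = 1.0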
When to use “debug” print?#
You should use jax.debug.print
for dynamic (i.e. traced) array values within JAX transformations
like jit
, vmap
, and others.
For printing of static values (like array shapes or dtypes), you can use a normal Python print
statement.
Why “debug” print?#
In the name of debugging, jax.debug.print
can reveal information about how computations are evaluated:
xs = jnp.arange(3.)
def f(x):
jax.debug.print("x: {}", x)
y = jnp.sin(x)
jax.debug.print("y: {}", y)
return y
jax.vmap(f)(xs)
# Prints: x: 0.0
# x: 1.0
# x: 2.0
# y: 0.0
# y: 0.841471
# y: 0.9092974
jax.lax.map(f, xs)
# Prints: x: 0.0
# y: 0.0
# x: 1.0
# y: 0.841471
# x: 2.0
# y: 0.9092974
Notice that the printed results are in different orders!
By revealing these inner-workings, the output of jax.debug.print
doesn’t respect JAX’s usual semantics guarantees, like that jax.vmap(f)(xs)
and jax.lax.map(f, xs)
compute the same thing (in different ways). Yet these evaluation order details are exactly what we might want to see when debugging!
So use jax.debug.print
for debugging, and not when semantics guarantees are important.
More examples of jax.debug.print
#
In addition to the above examples using jit
and vmap
, here are a few more to have in mind.
Printing under jax.pmap
#
When jax.pmap
-ed, jax.debug.print
s might be reordered!
xs = jnp.arange(2.)
def f(x):
jax.debug.print("x: {}", x)
return x
jax.pmap(f)(xs)
# Prints: x: 1.0
# x: 0.0
# OR
# Prints: x: 0.0
# x: 1.0
Printing under jax.grad
#
Under a jax.grad
, jax.debug.print
s will only print on the forward pass:
def f(x):
jax.debug.print("x: {}", x)
return x * 2.
jax.grad(f)(1.)
# Prints: x: 1.0
This behavior is similar to how Python’s builtin print
works under a jax.grad
. But by using jax.debug.print
here, the behavior is the same even if the caller applies a jax.jit
.
To print on the backward pass, just use a jax.custom_vjp
:
@jax.custom_vjp
def print_grad(x):
return x
def print_grad_fwd(x):
return x, None
def print_grad_bwd(_, x_grad):
jax.debug.print("x_grad: {}", x_grad)
return (x_grad,)
print_grad.defvjp(print_grad_fwd, print_grad_bwd)
def f(x):
x = print_grad(x)
return x * 2.
jax.grad(f)(1.)
# Prints: x_grad: 2.0
Printing in other transformations#
jax.debug.print
also works in other transformations like xmap
and pjit
.
More control with jax.debug.callback
#
In fact, jax.debug.print
is a thin convenience wrapper around jax.debug.callback
, which can be used directly for greater control over string formatting, or even the kind of output.
Semantically, jax.debug.callback
is roughly equivalent to the following Python function
def callback(fun: Callable, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
fun(*args, **kwargs)
return None
As with jax.debug.print
, these callbacks should only be used for debugging output, like printing or plotting. Printing and plotting are pretty harmless, but if you use it for anything else its behavior might surprise you under transformations. For example, it’s not safe to use jax.debug.callback
for timing operations, since callbacks might be reordered and asynchronous (see below).
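For example, here is a minimal sketch that routes a traced value to Python's logging module instead of stdout (the helper name log_value is ours):
import logging
import jax
import jax.numpy as jnp

logging.basicConfig(level=logging.INFO)

def log_value(x):
    # Runs on the host with a concrete array value.
    logging.info("intermediate value: %s", x)

@jax.jit
def f(x):
    y = jnp.sin(x)
    jax.debug.callback(log_value, y)
    return y

f(2.0)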
Strengths and limitations of jax.debug.print
#
Strengths#
Print debugging is simple and intuitive
jax.debug.callback
can be used for other innocuous side-effects
Limitations#
Adding print statements is a manual process
Can have performance impacts
Interactive inspection with jax.debug.breakpoint()
#
TL;DR Use jax.debug.breakpoint()
to pause the execution of your JAX program to inspect values:
@jax.jit
def f(x):
y, z = jnp.sin(x), jnp.cos(x)
jax.debug.breakpoint()
return y * z
f(2.) # ==> Pauses during execution!
jax.debug.breakpoint()
is actually just an application of jax.debug.callback(...)
that captures information about the call stack. It has the same transformation behaviors as jax.debug.print
as a result (e.g. vmap
-ing jax.debug.breakpoint()
unrolls it across the mapped axis).
Usage#
Calling jax.debug.breakpoint()
in a compiled JAX function will pause your program when it hits the breakpoint. You’ll be presented with a pdb
-like prompt that allows you to inspect the values in the call stack. Unlike pdb
, you will not be able to step through the execution, but you are allowed to resume it.
Debugger commands:
help - prints out available commands
p - evaluates an expression and prints its result
pp - evaluates an expression and pretty-prints its result
u(p) - go up a stack frame
d(own) - go down a stack frame
w(here)/bt - print out a backtrace
l(ist) - print out code context
c(ont(inue)) - resumes the execution of the program
q(uit)/exit - exits the program (does not work on TPU)
Examples#
Usage with jax.lax.cond
#
When combined with jax.lax.cond
, the debugger can become a useful tool for detecting nan
s or inf
s.
def breakpoint_if_nonfinite(x):
is_finite = jnp.isfinite(x).all()
def true_fn(x):
pass
def false_fn(x):
jax.debug.breakpoint()
lax.cond(is_finite, true_fn, false_fn, x)
@jax.jit
def f(x, y):
z = x / y
breakpoint_if_nonfinite(z)
return z
f(2., 0.) # ==> Pauses during execution!
Sharp bits#
Because jax.debug.breakpoint
is just an application of jax.debug.callback
, it has the same sharp bits as jax.debug.print
, with a few more caveats:
jax.debug.breakpoint materializes even more intermediates than jax.debug.print because it forces materialization of all values in the call stack.
jax.debug.breakpoint has more runtime overhead than a jax.debug.print because it has to potentially copy all the intermediate values in a JAX program from device to host.
Strengths and limitations of jax.debug.breakpoint()
#
Strengths#
Simple, intuitive and (somewhat) standard
Can inspect many values at the same time, up and down the call stack
Limitations#
Need to potentially use many breakpoints to pinpoint the source of an error
Materializes many intermediates
The checkify
transformation#
TL;DR Checkify lets you add jit
-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the checkify.checkify
transformation together with the assert-like checkify.check
function to add runtime checks to JAX code:
from jax.experimental import checkify
import jax
import jax.numpy as jnp
def f(x, i):
checkify.check(i >= 0, "index needs to be non-negative, got {i}", i=i)
y = x[i]
z = jnp.sin(y)
return z
jittable_f = checkify.checkify(f)
err, z = jax.jit(jittable_f)(jnp.ones((5,)), -2)
print(err.get())
# >> index needs to be non-negative, got -2! (check failed at <...>:6 (f))
You can also use checkify to automatically add common checks:
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)
err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)
err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <…>:6 (f))
err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
err, z = checked_f(jnp.array([5, 1]), 0)
err.throw() # if no error occurred, throw does nothing!
Functionalizing checks#
The assert-like check API by itself is not functionally pure: it can raise a Python Exception as a side-effect, just like assert. So it can’t be staged out with jit
, pmap
, pjit
, or scan
:
jax.jit(f)(jnp.ones((5,)), -1) # checkify transformation not used
# ValueError: Cannot abstractly evaluate a checkify.check which was not functionalized.
But the checkify transformation functionalizes (or discharges) these effects. A checkify-transformed function returns an error value as a new output and remains functionally pure. That functionalization means checkify-transformed functions can be composed with staging/transforms however we like:
err, z = jax.pmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
err.throw()
"""
ValueError:
.. at mapped index 0: index needs to be non-negative! (check failed at :6 (f))
.. at mapped index 2: out-of-bounds indexing at <..>:7 (f)
"""
Why does JAX need checkify?#
Under some JAX transformations you can express runtime error checks with ordinary Python assertions, for example when only using jax.grad
and jax.numpy
:
def f(x):
assert x > 0., "must be positive!"
return jnp.log(x)
jax.grad(f)(0.)
# ValueError: "must be positive!"
But ordinary assertions don’t work inside jit
, pmap
, pjit
, or scan
. In those cases, numeric computations are staged out rather than evaluated eagerly during Python execution, and as a result numeric values aren’t available:
jax.jit(f)(0.)
# ConcretizationTypeError: "Abstract tracer value encountered ..."
JAX transformation semantics rely on functional purity, especially when composing multiple transformations, so how can we provide an error mechanism without disrupting all that? Beyond needing a new API, the situation is trickier still: XLA HLO doesn’t support assertions or throwing errors, so even if we had a JAX API which was able to stage out assertions, how would we lower these assertions to XLA?
You could imagine manually adding run-time checks to your function and plumbing out values representing errors:
def f_checked(x):
error = x <= 0.
result = jnp.log(x)
return error, result
err, y = jax.jit(f_checked)(0.)
if err:
raise ValueError("must be positive!")
# ValueError: "must be positive!"
The error is a regular value computed by the function, and the error is raised outside of f_checked
. f_checked
is functionally pure, so we know by construction that it’ll already work with jit
, pmap, pjit, scan, and all of JAX’s transformations. The only problem is that this plumbing can be a pain!
checkify
does this rewrite for you: that includes plumbing the error value through the function, rewriting checks to boolean operations and merging the result with the tracked error value, and returning the final error value as an output to the checkified function:
def f(x):
checkify.check(x > 0., "{} must be positive!", x) # convenient but effectful API
return jnp.log(x)
f_checked = checkify(f)
err, x = jax.jit(f_checked)(-1.)
err.throw()
# ValueError: -1. must be positive! (check failed at <...>:2 (f))
We call this functionalizing or discharging the effect introduced by calling check. (In the “manual” example above the error value is just a boolean. checkify’s error values are conceptually similar but also track error messages and expose throw and get methods; see jax.experimental.checkify
). checkify.check
also allows you to add run-time values to your error message by providing them as format arguments to the error message.
You could now manually instrument your code with run-time checks, but checkify
can also automatically add checks for common errors!
Consider these error cases:
jnp.arange(3)[5] # out of bounds
jnp.sin(jnp.inf) # NaN generated
jnp.ones((5,)) / jnp.arange(5) # division by zero
By default checkify
only discharges checkify.check
s, and won’t do anything to catch errors like the above. But if you ask it to, checkify
will also instrument your code with checks automatically.
def f(x, i):
y = x[i] # i could be out of bounds.
z = jnp.sin(y) # z could become NaN
return z
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)
err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)
err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
The API for selecting which automatic checks to enable is based on Sets. See jax.experimental.checkify
for more details.
checkify
under JAX transformations.#
As demonstrated in the examples above, a checkified function can be happily
jitted. Here are a few more examples of checkify
with other JAX
transformations. Note that checkified functions are functionally pure, and
should trivially compose with all JAX transformations!
jit
#
You can safely add jax.jit
to a checkified function, or checkify
a jitted
function; both will work.
def f(x, i):
return x[i]
checkify_of_jit = checkify.checkify(jax.jit(f))
jit_of_checkify = jax.jit(checkify.checkify(f))
err, _ = checkify_of_jit(jnp.ones((5,)), 100)
err.get()
# out-of-bounds indexing at <..>:2 (f)
err, _ = jit_of_checkify(jnp.ones((5,)), 100)
# out-of-bounds indexing at <..>:2 (f)
vmap
/pmap
#
You can vmap
and pmap
checkified functions (or checkify
mapped functions).
Mapping a checkified function will give you a mapped error, which can contain
different errors for every element of the mapped dimension.
def f(x, i):
checkify.check(i >= 0, "index needs to be non-negative!")
return x[i]
checked_f = checkify.checkify(f, errors=checkify.all_checks)
errs, out = jax.vmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
errs.throw()
"""
ValueError:
at mapped index 0: index needs to be non-negative! (check failed at <...>:2 (f))
at mapped index 2: out-of-bounds indexing at <...>:3 (f)
"""
However, a checkify-of-vmap will produce a single (unmapped) error!
@jax.vmap
def f(x, i):
checkify.check(i >= 0, "index needs to be non-negative!")
return x[i]
checked_f = checkify.checkify(f, errors=checkify.all_checks)
err, out = checked_f(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
err.throw()
# ValueError: index needs to be non-negative! (check failed at <...>:2 (f))
pjit
#
pjit
of a checkified function just works; you only need to specify an
additional out_axis_resources
of None
for the error value output.
def f(x):
return x / x
f = checkify.checkify(f, errors=checkify.float_checks)
f = pjit(
f,
in_shardings=PartitionSpec('x', None),
out_shardings=(None, PartitionSpec('x', None)))
with jax.sharding.Mesh(mesh.devices, mesh.axis_names):
err, data = f(input_data)
err.throw()
# ValueError: divided by zero at <...>:4 (f)
grad
#
Your gradient computation will also be instrumented if you checkify-of-grad:
def f(x):
return x / (1 + jnp.sqrt(x))
grad_f = jax.grad(f)
err, _ = checkify.checkify(grad_f, errors=checkify.nan_checks)(0.)
print(err.get())
>> nan generated by primitive mul at <...>:3 (f)
Note that there’s no multiply in f
, but there is a multiply in its gradient computation (and this is where the NaN is generated!). So use checkify-of-grad to add automatic checks to both forward and backward pass operations.
checkify.check
s will only be applied to the primal value of your function. If
you want to use a check
on a gradient value, use a custom_vjp
:
@jax.custom_vjp
def assert_gradient_negative(x):
return x
def fwd(x):
return assert_gradient_negative(x), None
def bwd(_, grad):
checkify.check(grad < 0, "gradient needs to be negative!")
return (grad,)
assert_gradient_negative.defvjp(fwd, bwd)
jax.grad(assert_gradient_negative)(-1.)
# ValueError: gradient needs to be negative!
Strengths and limitations of jax.experimental.checkify
#
Strengths#
You can use it everywhere (errors are “just values” and behave intuitively under transformations like other values)
Automatic instrumentation: you don’t need to make local modifications to your code. Instead,
checkify
can instrument all of it!
Limitations#
Adding a lot of runtime checks can be expensive (eg. adding a NaN check to every primitive will add a lot of operations to your computation)
Requires threading error values out of functions and manually throwing the error. If the error is not explicitly thrown, you might miss out on errors!
Throwing an error value will materialize that error value on the host, meaning it’s a blocking operation which defeats JAX’s async run-ahead.
JAX debugging flags#
JAX offers flags and context managers that enable catching errors more easily.
jax_debug_nans
configuration option and context manager#
TL;DR Enable the jax_debug_nans
flag to automatically detect when NaNs are produced in jax.jit
-compiled code (but not in jax.pmap
or jax.pjit
-compiled code).
jax_debug_nans
is a JAX flag that, when enabled, automatically raises an error when a NaN is detected. It has special handling for JIT-compiled code: when a NaN output is detected from a JIT-ted function, the function is re-run eagerly (i.e. without compilation) and will throw an error at the specific primitive that produced the NaN.
Usage#
If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by:
setting the JAX_DEBUG_NANS=True environment variable;
adding jax.config.update("jax_debug_nans", True) near the top of your main file;
adding jax.config.parse_flags_with_absl() to your main file, then setting the option using a command-line flag like --jax_debug_nans=True.
Example(s)#
import jax
jax.config.update("jax_debug_nans", True)
def f(x, y):
return x / y
jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception!
Strengths and limitations of jax_debug_nans
#
Strengths#
Easy to apply
Precisely detects where NaNs were produced
Throws a standard Python exception and is compatible with PDB postmortem
Limitations#
Not compatible with
jax.pmap
orjax.pjit
Re-running functions eagerly can be slow
Errors on false positives (e.g. intentionally created NaNs)
jax_disable_jit
configuration option and context manager#
TL;DR Enable the jax_disable_jit
flag to disable JIT-compilation, enabling use of traditional Python debugging tools like print
and pdb
jax_disable_jit
is a JAX flag that, when enabled, disables JIT-compilation throughout JAX (including in control flow functions like jax.lax.cond
and jax.lax.scan
).
Usage#
You can disable JIT-compilation by:
setting the JAX_DISABLE_JIT=True environment variable;
adding jax.config.update("jax_disable_jit", True) near the top of your main file;
adding jax.config.parse_flags_with_absl() to your main file, then setting the option using a command-line flag like --jax_disable_jit=True.
Examples#
import jax
jax.config.update("jax_disable_jit", True)
def f(x):
y = jnp.log(x)
if jnp.isnan(y):
breakpoint()
return y
jax.jit(f)(-2.) # ==> Enters PDB breakpoint!
Strengths and limitations of jax_disable_jit
#
Strengths#
Easy to apply
Enables use of Python’s built-in
breakpoint
andprint
Throws standard Python exceptions and is compatible with PDB postmortem
Limitations#
Not compatible with
jax.pmap
orjax.pjit
Running functions without JIT-compilation can be slow
GPU performance tips#
This document focuses on performance tips for neural network workloads.
Matmul precision#
On recent GPU generations, such as the Nvidia A100 generation or later, it can
be a good idea to perform most computations in bfloat16
precision. For
example, if using Flax, instantiate Dense
layers using flax.linen.Dense(..., dtype=jax.numpy.bfloat16)
. Here are some
code examples:
In the Flax LM1B example, Dense modules are instantiated with a configurable dtype which defaults to bfloat16.
In MaxText, DenseGeneral modules are also instantiated with a configurable dtype that defaults to bfloat16.
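Outside of those frameworks, a minimal pure-JAX sketch of the same idea is to create or cast the matmul operands in bfloat16 yourself:
import jax
import jax.numpy as jnp

key = jax.random.key(0)
a = jax.random.normal(key, (1024, 1024), dtype=jnp.bfloat16)
b = jax.random.normal(key, (1024, 1024), dtype=jnp.bfloat16)
c = a @ b  # the matmul is performed in bfloat16 precision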
XLA performance flags#
The existence and exact behavior of XLA flags may be jaxlib
-version dependent.
As of jaxlib==0.4.18
(released Oct 6
2023), setting these XLA flags can
improve performance. Some are related to communication between GPUs, and so are
only relevant when running computations on multiple devices, while others are
related to code generation on each device.
Some of these may be set by default in future releases.
These flags can be set via the XLA_FLAGS
shell environment variable. For
example, we can add this to the top of a Python file:
import os
os.environ['XLA_FLAGS'] = (
'--xla_gpu_enable_triton_softmax_fusion=true '
'--xla_gpu_triton_gemm_any=True '
'--xla_gpu_enable_async_collectives=true '
'--xla_gpu_enable_latency_hiding_scheduler=true '
'--xla_gpu_enable_highest_priority_async_stream=true '
)
For more examples, see also XLA Flags recommended for Pax training on Nvidia GPUs.
Code generation flags#
--xla_gpu_enable_triton_softmax_fusion This flag enables an automatic softmax fusion, based on pattern-matching backed by Triton code generation. The default value is False.
--xla_gpu_triton_gemm_any Use the Triton-based GEMM (matmul) emitter for any GEMM that it supports. The default value is False.
Communication flags#
--xla_gpu_enable_async_collectives This flag enables the collective ops such as AllReduce, AllGather, ReduceScatter and CollectivePermute to be asynchronous. Asynchronous communication can overlap cross-core communication with computation. The default value is False.
--xla_gpu_enable_latency_hiding_scheduler This flag enables latency hiding schedulers to overlap asynchronous communication with computation efficiently. The default value is False.
--xla_gpu_enable_pipelined_collectives When using pipeline parallelism, this flag enables overlapping the (i+1)-th layer weight AllGather with the i-th layer computation. It also enables overlapping (i+1)-th layer weight Reduce/ReduceScatter with i-th layer’s computation. The default value is False. There are some bugs when this flag is turned on.
--xla_gpu_collective_permute_decomposer_threshold This flag is useful when performing GSPMD pipelining. Setting a nonzero threshold decomposes CollectivePermutes into CollectivePermuteReceiveDone and CollectivePermuteSendDone pairs, so that computation can be performed between each corresponding ReceiveDone/SendDone pair and hence achieve more overlap. By default the threshold is 0 and there is no decomposition. Setting it to threshold > 0 such as --xla_gpu_collective_permute_decomposer_threshold=1024 can enable this feature.
--xla_gpu_all_gather_combine_threshold_bytes, --xla_gpu_reduce_scatter_combine_threshold_bytes, --xla_gpu_all_reduce_combine_threshold_bytes These flags tune when to combine multiple small AllGather/ReduceScatter/AllReduce into one big AllGather/ReduceScatter/AllReduce to reduce time spent on cross-device communication. For example, for the AllGather/ReduceScatter thresholds on a Transformer-based workload, consider tuning them high enough so as to combine at least a Transformer Layer’s weight AllGather/ReduceScatter. By default, the combine_threshold_bytes is set to 256.
NCCL flags#
These Nvidia NCCL flag values may be useful for single-host multi-device computations on Nvidia GPUs:
os.environ.update({
"NCCL_LL128_BUFFSIZE": "-2",
"NCCL_LL_BUFFSIZE": "-2",
"NCCL_PROTO": "SIMPLE,LL,LL128",
})
These NCCL flags could improve single-host communication speed. These flags don’t seem useful for multi-host communication yet.
Persistent Compilation Cache#
JAX has an optional disk cache for compiled programs. If enabled, JAX will store copies of compiled programs on disk, which can save recompilation time when running the same or similar tasks repeatedly.
Usage#
The compilation cache is enabled when the cache-location is set. This should be done prior to the first compilation. Set the location as follows:
import jax
# Make sure this is called before jax runs any operations!
jax.config.update("jax_compilation_cache_dir", "cache-location")
See the sections below for more detail on cache-location
.
set_cache_dir()
is an alternate way of setting cache-location
.
Local filesystem#
cache-location
can be a directory on the local filesystem. For example:
import jax
jax.config.update("jax_compilation_cache_dir", "/tmp/jax-cache")
Note: the cache does not have an eviction mechanism implemented. If the cache-location is a directory in the local filesystem, its size will continue to grow unless files are manually deleted.
Google Cloud#
When running on Google Cloud, the compilation cache can be placed on a Google Cloud Storage (GCS) bucket. We recommend the following configuration:
Create the bucket in the same region as where the workload will run.
Create the bucket in the same project as the workload’s VM(s). Ensure that permissions are set so that the VM(s) can write to the bucket.
There is no need for replication for smaller workloads. Larger workloads could benefit from replication.
Use “Standard” for the default storage class for the bucket.
Set the soft delete policy to its shortest: 7 days.
Set the object lifecycle to the expected duration of the workload run. For example, if the workload is expected to run for 10 days, set the object lifecycle to 10 days. That should cover restarts that occur during the entire run. Use age for the lifecycle condition and Delete for the action. See Object Lifecycle Management for details. If the object lifecycle is not set, the cache will continue to grow since there is no eviction mechanism implemented.
All encryption policies are supported.
Assuming that gs://jax-cache
is the GCS bucket, set cache-location
as
follows:
import jax
jax.config.update("jax_compilation_cache_dir", "gs://jax-cache")
Understanding Jaxprs#
Updated: May 3, 2020 (for commit f1a46fe).
Conceptually, one can think of JAX transformations as first trace-specializing the Python function to be transformed into a small and well-behaved intermediate form that is then interpreted with transformation-specific interpretation rules. One of the reasons JAX can pack so much power into such a small software package is that it starts with a familiar and flexible programming interface (Python with NumPy) and it uses the actual Python interpreter to do most of the heavy lifting to distill the essence of the computation into a simple statically-typed expression language with limited higher-order features. That language is the jaxpr language.
Not all Python programs can be processed this way, but it turns out that many scientific computing and machine learning programs can.
Before we proceed, it is important to point out that not all JAX transformations literally materialize a jaxpr as described above; some, e.g., differentiation or batching, will apply transformations incrementally during tracing. Nevertheless, if one wants to understand how JAX works internally, or to make use of the result of JAX tracing, it is useful to understand jaxprs.
A jaxpr instance represents a function with one or more typed parameters (input
variables) and one or more typed results. The results depend only on the input
variables; there are no free variables captured from enclosing scopes. The
inputs and outputs have types, which in JAX are represented as abstract values.
There are two related representations in the code for jaxprs,
jax.core.Jaxpr
and jax.core.ClosedJaxpr
. A
jax.core.ClosedJaxpr
represents a partially-applied
jax.core.Jaxpr
, and is what you obtain when you use
jax.make_jaxpr()
to inspect jaxprs. It has the following fields:
jaxpr is a jax.core.Jaxpr representing the actual computation content of the function (described below).
consts is a list of constants.
The most interesting part of the ClosedJaxpr is the actual execution content,
represented as a jax.core.Jaxpr
as printed using the following
grammar:
Jaxpr ::= { lambda Var* ; Var+. let
Eqn*
in [Expr+] }
- where:
The parameters of the jaxpr are shown as two lists of variables separated by ;. The first set of variables are the ones that have been introduced to stand for constants that have been hoisted out. These are called the constvars, and in a jax.core.ClosedJaxpr the consts field holds corresponding values. The second list of variables, called invars, correspond to the inputs of the traced Python function.
Eqn* is a list of equations, defining intermediate variables referring to intermediate expressions. Each equation defines one or more variables as the result of applying a primitive on some atomic expressions. Each equation uses only input variables and intermediate variables defined by previous equations.
Expr+ is a list of output atomic expressions (literals or variables) for the jaxpr.
Equations are printed as follows:
Eqn ::= Var+ = Primitive [ Param* ] Expr+
- where:
Var+ are one or more intermediate variables to be defined as the output of a primitive invocation (some primitives can return multiple values).
Expr+ are one or more atomic expressions, each either a variable or a literal constant. A special variable unitvar or literal unit, printed as *, represents a value that is not needed in the rest of the computation and has been elided. That is, units are just placeholders.
Param* are zero or more named parameters to the primitive, printed in square brackets. Each parameter is shown as Name = Value.
Most jaxpr primitives are first-order (they take just one or more Expr
as arguments):
Primitive := add | sub | sin | mul | ...
The jaxpr primitives are documented in the jax.lax
module.
For example, here is the jaxpr produced for the function func1
below
>>> from jax import make_jaxpr
>>> import jax.numpy as jnp
>>> def func1(first, second):
... temp = first + jnp.sin(second) * 3.
... return jnp.sum(temp)
...
>>> print(make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a:f32[8] b:f32[8]. let
c:f32[8] = sin b
d:f32[8] = mul c 3.0
e:f32[8] = add a d
f:f32[] = reduce_sum[axes=(0,)] e
in (f,) }
Here there are no constvars, a
and b
are the input variables
and they correspond respectively to
first
and second
function parameters. The scalar literal 3.0
is kept
inline.
The reduce_sum
primitive has named parameter axes
, in addition to the
operand e
.
Note that even though execution of a program that calls into JAX builds a jaxpr, Python-level control-flow and Python-level functions execute normally. This means that just because a Python program contains functions and control-flow, the resulting jaxpr does not have to contain control-flow or higher-order features.
For example, when tracing the function func3
JAX will inline the call to
inner
and the conditional if second.shape[0] > 4
, and will produce the same
jaxpr as before
>>> def func2(inner, first, second):
... temp = first + inner(second) * 3.
... return jnp.sum(temp)
...
>>> def inner(second):
... if second.shape[0] > 4:
... return jnp.sin(second)
... else:
... assert False
...
>>> def func3(first, second):
... return func2(inner, first, second)
...
>>> print(make_jaxpr(func3)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a:f32[8] b:f32[8]. let
c:f32[8] = sin b
d:f32[8] = mul c 3.0
e:f32[8] = add a d
f:f32[] = reduce_sum[axes=(0,)] e
in (f,) }
Handling PyTrees#
In jaxpr there are no tuple types; instead primitives take multiple inputs and produce multiple outputs. When processing a function that has structured inputs or outputs, JAX will flatten those and in jaxpr they will appear as lists of inputs and outputs. For more details, please see the documentation for PyTrees (Pytrees).
For example, the following code produces an identical jaxpr to what we saw before (with two input vars, one for each element of the input tuple)
>>> def func4(arg): # Arg is a pair
... temp = arg[0] + jnp.sin(arg[1]) * 3.
... return jnp.sum(temp)
...
>>> print(make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8))))
{ lambda ; a:f32[8] b:f32[8]. let
c:f32[8] = sin b
d:f32[8] = mul c 3.0
e:f32[8] = add a d
f:f32[] = reduce_sum[axes=(0,)] e
in (f,) }
Constant Vars#
Some values in jaxprs are constants, in that their value does not depend on the jaxpr’s arguments. When these values are scalars they are represented directly in the jaxpr equations; non-scalar array constants are instead hoisted out to the top-level jaxpr, where they correspond to constant variables (“constvars”). These constvars differ from the other jaxpr parameters (“invars”) only as a bookkeeping convention.
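As a rough illustration (reusing the imports from the examples above; the function name is arbitrary), a scalar constant generally stays inline in an equation, while an array constant is typically hoisted into the constvars printed before the ; in the jaxpr:
def func_consts(x):
    # 2.0 is a scalar constant; jnp.array([3.0, 1.0]) is a non-scalar constant
    return x * 2.0 + jnp.array([3.0, 1.0])

print(make_jaxpr(func_consts)(jnp.ones(2)))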
Higher-order primitives#
jaxpr includes several higher-order primitives. They are more complicated because they include sub-jaxprs.
Conditionals#
JAX traces through normal Python conditionals. To capture a
conditional expression for dynamic execution, one must use the
jax.lax.switch()
and jax.lax.cond()
constructors,
which have the signatures:
lax.switch(index: int, branches: Sequence[A -> B], operand: A) -> B
lax.cond(pred: bool, true_body: A -> B, false_body: A -> B, operand: A) -> B
Both of these will bind a primitive called cond
internally. The
cond
primitive in jaxprs reflects the more general signature of
lax.switch()
: it takes an integer denoting the index of the branch
to execute (clamped into valid indexing range).
For example:
>>> from jax import lax
>>>
>>> def one_of_three(index, arg):
... return lax.switch(index, [lambda x: x + 1.,
... lambda x: x - 2.,
... lambda x: x + 3.],
... arg)
...
>>> print(make_jaxpr(one_of_three)(1, 5.))
{ lambda ; a:i32[] b:f32[]. let
c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
d:i32[] = clamp 0 c 2
e:f32[] = cond[
branches=(
{ lambda ; f:f32[]. let g:f32[] = add f 1.0 in (g,) }
{ lambda ; h:f32[]. let i:f32[] = sub h 2.0 in (i,) }
{ lambda ; j:f32[]. let k:f32[] = add j 3.0 in (k,) }
)
linear=(False,)
] d b
in (e,) }
The cond primitive has a number of parameters:
branches are jaxprs that correspond to the branch functionals. In this example, those functionals each take one input variable, corresponding to x.
linear is a tuple of booleans that is used internally by the auto-differentiation machinery to encode which of the input parameters are used linearly in the conditional.
The above instance of the cond primitive takes two operands. The first
one (d
) is the branch index, then b
is the operand (arg
) to
be passed to whichever jaxpr in branches
is selected by the branch
index.
Another example, using lax.cond()
:
>>> from jax import lax
>>>
>>> def func7(arg):
... return lax.cond(arg >= 0.,
... lambda xtrue: xtrue + 3.,
... lambda xfalse: xfalse - 3.,
... arg)
...
>>> print(make_jaxpr(func7)(5.))
{ lambda ; a:f32[]. let
b:bool[] = ge a 0.0
c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
d:f32[] = cond[
branches=(
{ lambda ; e:f32[]. let f:f32[] = sub e 3.0 in (f,) }
{ lambda ; g:f32[]. let h:f32[] = add g 3.0 in (h,) }
)
linear=(False,)
] c a
in (d,) }
In this case, the boolean predicate is converted to an integer index
(0 or 1), and branches
are jaxprs that correspond to the false and
true branch functionals, in that order. Again, each functional takes
one input variable, corresponding to xfalse
and xtrue
respectively.
The following example shows a more complicated situation when the input
to the branch functionals is a tuple, and the false branch functional
contains a constant jnp.ones(1)
that is hoisted as a constvar
>>> def func8(arg1, arg2): # arg2 is a pair
... return lax.cond(arg1 >= 0.,
... lambda xtrue: xtrue[0],
... lambda xfalse: jnp.array([1]) + xfalse[1],
... arg2)
...
>>> print(make_jaxpr(func8)(5., (jnp.zeros(1), 2.)))
{ lambda a:i32[1]; b:f32[] c:f32[1] d:f32[]. let
e:bool[] = ge b 0.0
f:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
g:f32[1] = cond[
branches=(
{ lambda ; h:i32[1] i:f32[1] j:f32[]. let
k:f32[1] = convert_element_type[new_dtype=float32 weak_type=True] h
l:f32[1] = add k j
in (l,) }
{ lambda ; m_:i32[1] n:f32[1] o:f32[]. let in (n,) }
)
linear=(False, False, False)
] f a c d
in (g,) }
While#
Just like for conditionals, Python loops are inlined during tracing.
If you want to capture a loop for dynamic execution, you must use one of several
special operations, jax.lax.while_loop()
(a primitive)
and jax.lax.fori_loop()
(a helper that generates a while_loop primitive):
lax.while_loop(cond_fun: (C -> bool), body_fun: (C -> C), init: C) -> C
lax.fori_loop(start: int, end: int, body: (int -> C -> C), init: C) -> C
In the above signature, “C” stands for the type of the loop “carry” value. For example, here is an example fori loop
>>> import numpy as np
>>>
>>> def func10(arg, n):
... ones = jnp.ones(arg.shape) # A constant
... return lax.fori_loop(0, n,
... lambda i, carry: carry + ones * 3. + arg,
... arg + ones)
...
>>> print(make_jaxpr(func10)(np.ones(16), 5))
{ lambda ; a:f32[16] b:i32[]. let
c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
d:f32[16] = add a c
_:i32[] _:i32[] e:f32[16] = while[
body_jaxpr={ lambda ; f:f32[16] g:f32[16] h:i32[] i:i32[] j:f32[16]. let
k:i32[] = add h 1
l:f32[16] = mul f 3.0
m:f32[16] = add j l
n:f32[16] = add m g
in (k, i, n) }
body_nconsts=2
cond_jaxpr={ lambda ; o:i32[] p:i32[] q:f32[16]. let
r:bool[] = lt o p
in (r,) }
cond_nconsts=0
] c a 0 b d
in (e,) }
The while primitive takes 5 arguments: c a 0 b d, as follows:
0 constants for cond_jaxpr (since cond_nconsts is 0)
2 constants for body_jaxpr (c, and a)
3 parameters for the initial value of carry
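For comparison with the fori_loop example above, here is a minimal sketch that calls lax.while_loop directly (reusing lax and jnp from the examples above; the function name and carry layout are just for illustration):
def count_to_five(total):
    # carry is a pair (i, total); loop while i < 5
    cond = lambda carry: carry[0] < 5
    body = lambda carry: (carry[0] + 1, carry[1] + carry[0])
    return lax.while_loop(cond, body, (jnp.int32(0), jnp.float32(total)))[1]

print(count_to_five(0.0))  # 0 + 1 + 2 + 3 + 4 = 10.0
# printing make_jaxpr(count_to_five)(0.0) would again show a single while primitive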
Scan#
JAX supports a special form of loop over the elements of an array (with
statically known shape). The fact that there are a fixed number of iterations
makes this form of looping easily reverse-differentiable. Such loops are
constructed with the jax.lax.scan()
function:
lax.scan(body_fun: (C -> A -> (C, B)), init_carry: C, in_arr: Array[A]) -> (C, Array[B])
This is written in terms of a Haskell Type Signature:
C
is the type of the scan carry, A
is the element type of the
input array(s), and B
is the element type of the output array(s).
For the example consider the function func11
below
>>> def func11(arr, extra):
... ones = jnp.ones(arr.shape) # A constant
... def body(carry, aelems):
... # carry: running dot-product of the two arrays
... # aelems: a pair with corresponding elements from the two arrays
... ae1, ae2 = aelems
... return (carry + ae1 * ae2 + extra, carry)
... return lax.scan(body, 0., (arr, ones))
...
>>> print(make_jaxpr(func11)(np.ones(16), 5.))
{ lambda ; a:f32[16] b:f32[]. let
c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
d:f32[] e:f32[16] = scan[
_split_transpose=False
jaxpr={ lambda ; f:f32[] g:f32[] h:f32[] i:f32[]. let
j:f32[] = mul h i
k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] g
l:f32[] = add k j
m:f32[] = convert_element_type[new_dtype=float32 weak_type=False] f
n:f32[] = add l m
in (n, g) }
length=16
linear=(False, False, False, False)
num_carry=1
num_consts=1
reverse=False
unroll=1
] b 0.0 a c
in (d, e) }
The linear
parameter describes for each of the input variables whether they
are guaranteed to be used linearly in the body. Once the scan goes through
linearization, more arguments will be linear.
The scan primitive takes 4 arguments: b 0.0 a c
, of which:
one is the free variable for the body
one is the initial value of the carry
The next 2 are the arrays over which the scan operates.
XLA_call#
The call primitive arises from JIT compilation, and it encapsulates a sub-jaxpr along with parameters that specify the backend and the device on which the computation should run. For example
>>> from jax import jit
>>>
>>> def func12(arg):
... @jit
... def inner(x):
... return x + arg * jnp.ones(1) # Include a constant in the inner function
... return arg + inner(arg - 2.)
...
>>> print(make_jaxpr(func12)(1.))
{ lambda ; a:f32[]. let
b:f32[] = sub a 2.0
c:f32[1] = pjit[
name=inner
jaxpr={ lambda ; d:f32[] e:f32[]. let
f:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0
g:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
h:f32[1] = mul g f
i:f32[] = convert_element_type[new_dtype=float32 weak_type=False] e
j:f32[1] = add i h
in (j,) }
] a b
k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a
l:f32[1] = add k c
in (l,) }
XLA_pmap#
If you use the jax.pmap()
transformation, the function to be mapped is
captured using the xla_pmap
primitive. Consider this example
>>> from jax import pmap
>>>
>>> def func13(arr, extra):
... def inner(x):
... # use a free variable "extra" and a constant jnp.ones(1)
... return (x + extra + jnp.ones(1)) / lax.psum(x, axis_name='rows')
... return pmap(inner, axis_name='rows')(arr)
...
>>> print(make_jaxpr(func13)(jnp.ones((1, 3)), 5.))
{ lambda ; a:f32[1,3] b:f32[]. let
c:f32[1,3] = xla_pmap[
axis_name=rows
axis_size=1
backend=None
call_jaxpr={ lambda ; d:f32[] e:f32[3]. let
f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
g:f32[3] = add e f
h:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0
i:f32[3] = add g h
j:f32[3] = psum[axes=('rows',) axis_index_groups=None] e
k:f32[3] = div i j
in (k,) }
devices=None
donated_invars=(False, False)
global_axis_size=1
in_axes=(None, 0)
is_explicit_global_axis_size=False
name=inner
out_axes=(0,)
] b a
in (c,) }
The xla_pmap
primitive specifies the name of the axis (parameter
axis_name
) and the body of the function to be mapped as the call_jaxpr
parameter. The value of this parameter is a Jaxpr with 2 input variables.
The parameter in_axes
specifies which of the input variables should be
mapped and which should be broadcast. In our example, the value of extra
is broadcast and the value of arr
is mapped.
External Callbacks in JAX#
This guide outlines the uses of various callback functions, which allow JAX runtimes to execute Python code on the host, even while running under jit
, vmap
, grad
, or another transformation.
Why callbacks?#
A callback routine is a way to perform host-side execution of code at runtime.
As a simple example, suppose you’d like to print the value of some variable during the course of a computation.
Using a simple Python print
statement, it looks like this:
import jax
@jax.jit
def f(x):
y = x + 1
print("intermediate value: {}".format(y))
return y * 2
result = f(2)
intermediate value: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
What is printed is not the runtime value, but the trace-time abstract value (if you’re not familiar with tracing in JAX, a good primer can be found in How To Think In JAX).
To print the value at runtime we need a callback, for example jax.debug.print
:
@jax.jit
def f(x):
y = x + 1
jax.debug.print("intermediate value: {}", y)
return y * 2
result = f(2)
intermediate value: 3
This works by passing the runtime value represented by y
back to the host process, where the host can print the value.
Flavors of Callback#
In earlier versions of JAX, there was only one kind of callback available, implemented in jax.experimental.host_callback
. The host_callback
routines had some deficiencies, and are now deprecated in favor of several callbacks designed for different situations:
jax.pure_callback()
: appropriate for pure functions: i.e. functions with no side effect.jax.experimental.io_callback()
: appropriate for impure functions: e.g. functions which read or write data to disk.jax.debug.callback()
: appropriate for functions that should reflect the execution behavior of the compiler.
(The jax.debug.print()
function we used above is a wrapper around jax.debug.callback()
).
From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow.
callback function | supports return value | jit | vmap | grad | scan/while_loop | guaranteed execution
---|---|---|---|---|---|---
jax.pure_callback | ✅ | ✅ | ✅ | ❌¹ | ✅ | ❌
jax.experimental.io_callback | ✅ | ✅ | ✅/❌² | ❌ | ✅³ | ✅
jax.debug.callback | ❌ | ✅ | ✅ | ✅ | ✅ | ❌
¹ jax.pure_callback
can be used with custom_jvp
to make it compatible with autodiff
² jax.experimental.io_callback
is compatible with vmap
only if ordered=False
.
³ Note that vmap
of scan
/while_loop
of io_callback
has complicated semantics, and its behavior may change in future releases.
Exploring jax.pure_callback
#
jax.pure_callback
is generally the callback function you should reach for when you want host-side execution of a pure function: i.e. a function that has no side-effects (such as printing values, reading data from disk, updating a global state, etc.).
The function you pass to jax.pure_callback
need not actually be pure, but it will be assumed pure by JAX’s transformations and higher-order functions, which means that it may be silently elided or called multiple times.
import jax
import jax.numpy as jnp
import numpy as np
def f_host(x):
# call a numpy (not jax.numpy) operation:
return np.sin(x).astype(x.dtype)
def f(x):
result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
return jax.pure_callback(f_host, result_shape, x)
x = jnp.arange(5.0)
f(x)
Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)
Because pure_callback
can be elided or duplicated, it is compatible out-of-the-box with transformations like jit
and vmap
, as well as higher-order primitives like scan
and while_loop
:
jax.jit(f)(x)
Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)
jax.vmap(f)(x)
Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)
def body_fun(_, x):
return _, f(x)
jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]
Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)
However, because there is no way for JAX to introspect the content of the callback, pure_callback
has undefined autodiff semantics:
%xmode minimal
Exception reporting mode: Minimal
jax.grad(f)(x)
ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.
For an example of using pure_callback
with jax.custom_jvp
, see Example: pure_callback
with custom_jvp
below.
By design functions passed to pure_callback
are treated as if they have no side-effects: one consequence of this is that if the output of the function is not used, the compiler may eliminate the callback entirely:
def print_something():
print('printing something')
return np.int32(0)
@jax.jit
def f1():
return jax.pure_callback(print_something, np.int32(0))
f1();
printing something
@jax.jit
def f2():
jax.pure_callback(print_something, np.int32(0))
return 1.0
f2();
In f1
, the output of the callback is used in the return value of the function, so the callback is executed and we see the printed output.
In f2
on the other hand, the output of the callback is unused, and so the compiler notices this and eliminates the function call. These are the correct semantics for a callback to a function with no side-effects.
Exploring jax.experimental.io_callback
#
In contrast to jax.pure_callback()
, jax.experimental.io_callback()
is explicitly meant to be used with impure functions, i.e. functions that do have side-effects.
As an example, here is a callback to a global host-side numpy random generator. This is an impure operation because a side-effect of generating a random number in numpy is that the random state is updated (Please note that this is meant as a toy example of io_callback
and not necessarily a recommended way of generating random numbers in JAX!).
from jax.experimental import io_callback
from functools import partial
global_rng = np.random.default_rng(0)
def host_side_random_like(x):
"""Generate a random array like x using the global_rng state"""
# We have two side-effects here:
# - printing the shape and dtype
# - calling global_rng, thus updating its state
print(f'generating {x.dtype}{list(x.shape)}')
return global_rng.uniform(size=x.shape).astype(x.dtype)
@jax.jit
def numpy_random_like(x):
return io_callback(host_side_random_like, x, x)
x = jnp.zeros(5)
numpy_random_like(x)
generating float32[5]
Array([0.6369617 , 0.26978672, 0.04097353, 0.01652764, 0.8132702 ], dtype=float32)
The io_callback
is compatible with vmap
by default:
jax.vmap(numpy_random_like)(x)
generating float32[]
generating float32[]
generating float32[]
generating float32[]
generating float32[]
Array([0.91275555, 0.60663575, 0.72949654, 0.543625 , 0.9350724 ], dtype=float32)
Note, however, that this may execute the mapped callbacks in any order. So, for example, if you ran this on a GPU, the order of the mapped outputs might differ from run to run.
If it is important that the order of callbacks be preserved, you can set ordered=True
, in which case attempting to vmap
will raise an error:
@jax.jit
def numpy_random_like_ordered(x):
return io_callback(host_side_random_like, x, x, ordered=True)
jax.vmap(numpy_random_like_ordered)(x)
JaxStackTraceBeforeTransformation: ValueError: Cannot `vmap` ordered IO callback.
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
ValueError: Cannot `vmap` ordered IO callback.
On the other hand, scan
and while_loop
work with io_callback
regardless of whether ordering is enforced:
def body_fun(_, x):
return _, numpy_random_like_ordered(x)
jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]
generating float32[]
generating float32[]
generating float32[]
generating float32[]
generating float32[]
Array([0.81585354, 0.0027385 , 0.8574043 , 0.03358557, 0.72965544], dtype=float32)
Like pure_callback
, io_callback
fails under automatic differentiation if it is passed a differentiated variable:
jax.grad(numpy_random_like)(x)
JaxStackTraceBeforeTransformation: ValueError: IO callbacks do not support JVP.
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
ValueError: IO callbacks do not support JVP.
However, if the callback is not dependent on a differentiated variable, it will execute:
@jax.jit
def f(x):
io_callback(lambda: print('hello'), None)
return x
jax.grad(f)(1.0);
hello
Unlike pure_callback
, the compiler will not remove the callback execution in this case, even though the output of the callback is unused in the subsequent computation.
Exploring debug.callback
#
Both pure_callback
and io_callback
enforce some assumptions about the purity of the function they’re calling, and limit in various ways what JAX transforms and compilation machinery may do. debug.callback
essentially assumes nothing about the callback function, such that the action of the callback reflects exactly what JAX is doing during the course of a program. Further, debug.callback
cannot return any value to the program.
from jax import debug
def log_value(x):
# This could be an actual logging call; we'll use
# print() for demonstration
print("log:", x)
@jax.jit
def f(x):
debug.callback(log_value, x)
return x
f(1.0);
log: 1.0
The debug callback is compatible with vmap
:
x = jnp.arange(5.0)
jax.vmap(f)(x);
log: 0.0
log: 1.0
log: 2.0
log: 3.0
log: 4.0
And is also compatible with grad
and other autodiff transformations
jax.grad(f)(1.0);
log: 1.0
This can make debug.callback
more useful for general-purpose debugging than either pure_callback
or io_callback
.
Example: pure_callback
with custom_jvp
#
One powerful way to take advantage of jax.pure_callback()
is to combine it with jax.custom_jvp
(see Custom derivative rules for more details on custom_jvp
).
Suppose we want to create a JAX-compatible wrapper for a scipy or numpy function that is not yet available in the jax.scipy
or jax.numpy
wrappers.
Here, we’ll consider creating a wrapper for the Bessel function of the first kind, implemented in scipy.special.jv
.
We can start by defining a straightforward pure_callback
:
import jax
import jax.numpy as jnp
import scipy.special
def jv(v, z):
v, z = jnp.asarray(v), jnp.asarray(z)
# Require the order v to be integer type: this simplifies
# the JVP rule below.
assert jnp.issubdtype(v.dtype, jnp.integer)
# Promote the input to inexact (float/complex).
# Note that jnp.result_type() accounts for the enable_x64 flag.
z = z.astype(jnp.result_type(float, z.dtype))
# Wrap scipy function to return the expected dtype.
_scipy_jv = lambda v, z: scipy.special.jv(v, z).astype(z.dtype)
# Define the expected shape & dtype of output.
result_shape_dtype = jax.ShapeDtypeStruct(
shape=jnp.broadcast_shapes(v.shape, z.shape),
dtype=z.dtype)
# We use vectorized=True because scipy.special.jv handles broadcasted inputs.
return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)
This lets us call into scipy.special.jv
from transformed JAX code, including when transformed by jit
and vmap
:
from functools import partial
j1 = partial(jv, 1)
z = jnp.arange(5.0)
print(j1(z))
[ 0. 0.44005057 0.5767248 0.33905897 -0.06604332]
Here is the same result with jit
:
print(jax.jit(j1)(z))
[ 0. 0.44005057 0.5767248 0.33905897 -0.06604332]
And here is the same result again with vmap
:
print(jax.vmap(j1)(z))
[ 0. 0.44005057 0.5767248 0.33905897 -0.06604332]
However, if we call jax.grad
, we see an error because there is no autodiff rule defined for this function:
jax.grad(j1)(z)
ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.
Let’s define a custom gradient rule for this. Looking at the definition of the Bessel Function of the First Kind, we find that there is a relatively straightforward recurrence relationship for the derivative with respect to the argument z:
\[
\frac{d}{dz} J_\nu(z) = \begin{cases} -J_1(z), & \nu = 0 \\ \tfrac{1}{2}\left[J_{\nu - 1}(z) - J_{\nu + 1}(z)\right], & \nu \ne 0 \end{cases}
\]
The gradient with respect to \(\nu\) is more complicated, but since we’ve restricted the v
argument to integer types we don’t need to worry about its gradient for the sake of this example.
We can use jax.custom_jvp
to define this automatic differentiation rule for our callback function:
jv = jax.custom_jvp(jv)
@jv.defjvp
def _jv_jvp(primals, tangents):
v, z = primals
_, z_dot = tangents # Note: v_dot is always 0 because v is integer.
jv_minus_1, jv_plus_1 = jv(v - 1, z), jv(v + 1, z)
djv_dz = jnp.where(v == 0, -jv_plus_1, 0.5 * (jv_minus_1 - jv_plus_1))
return jv(v, z), z_dot * djv_dz
Now computing the gradient of our function will work correctly:
j1 = partial(jv, 1)
print(jax.grad(j1)(2.0))
-0.06447162
Further, since we’ve defined our gradient in terms of jv
itself, JAX’s architecture means that we get second-order and higher derivatives for free:
jax.hessian(j1)(2.0)
Array(-0.4003078, dtype=float32, weak_type=True)
Keep in mind that although this all works correctly with JAX, each call to our callback-based jv
function will result in passing the input data from the device to the host, and passing the output of scipy.special.jv
from the host back to the device.
When running on accelerators like GPU or TPU, this data movement and host synchronization can lead to significant overhead each time jv
is called.
However, if you are running JAX on a single CPU (where the “host” and “device” are on the same hardware), JAX will generally do this data transfer in a fast, zero-copy fashion, making this pattern a relatively straightforward way to extend JAX’s capabilities.
Type promotion semantics#
This document describes JAX’s type promotion rules–i.e., the result of jax.numpy.promote_types()
for each pair of types.
For some background on the considerations that went into the design of what is described below, see Design of Type Promotion Semantics for JAX.
JAX’s type promotion behavior is determined via the following type promotion lattice:
where, for example:
b1 means np.bool_,
i2 means np.int16,
u4 means np.uint32,
bf means np.bfloat16,
f2 means np.float16,
c8 means np.complex64,
i* means Python int or weakly-typed int,
f* means Python float or weakly-typed float, and
c* means Python complex or weakly-typed complex.
(for more about weak types, see Weakly-typed values in JAX below).
Promotion between any two types is given by their join on this lattice, which generates the following binary promotion table:
b1 | u1 | u2 | u4 | u8 | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i* | f* | c* | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
b1 | b1 | u1 | u2 | u4 | u8 | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i* | f* | c* |
u1 | u1 | u1 | u2 | u4 | u8 | i2 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | u1 | f* | c* |
u2 | u2 | u2 | u2 | u4 | u8 | i4 | i4 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | u2 | f* | c* |
u4 | u4 | u4 | u4 | u4 | u8 | i8 | i8 | i8 | i8 | bf | f2 | f4 | f8 | c8 | c16 | u4 | f* | c* |
u8 | u8 | u8 | u8 | u8 | u8 | f* | f* | f* | f* | bf | f2 | f4 | f8 | c8 | c16 | u8 | f* | c* |
i1 | i1 | i2 | i4 | i8 | f* | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i1 | f* | c* |
i2 | i2 | i2 | i4 | i8 | f* | i2 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i2 | f* | c* |
i4 | i4 | i4 | i4 | i8 | f* | i4 | i4 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i4 | f* | c* |
i8 | i8 | i8 | i8 | i8 | f* | i8 | i8 | i8 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i8 | f* | c* |
bf | bf | bf | bf | bf | bf | bf | bf | bf | bf | bf | f4 | f4 | f8 | c8 | c16 | bf | bf | c8 |
f2 | f2 | f2 | f2 | f2 | f2 | f2 | f2 | f2 | f2 | f4 | f2 | f4 | f8 | c8 | c16 | f2 | f2 | c8 |
f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f8 | c8 | c16 | f4 | f4 | c8 |
f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | c16 | c16 | f8 | f8 | c16 |
c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c16 | c8 | c16 | c8 | c8 | c8 |
c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 |
i* | i* | u1 | u2 | u4 | u8 | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i* | f* | c* |
f* | f* | f* | f* | f* | f* | f* | f* | f* | f* | bf | f2 | f4 | f8 | c8 | c16 | f* | f* | c* |
c* | c* | c* | c* | c* | c* | c* | c* | c* | c* | c8 | c8 | c8 | c16 | c8 | c16 | c* | c* | c* |
JAX’s type promotion rules differ from those of NumPy, as given by numpy.promote_types(), in those cells highlighted with a green background in the table above. There are three key classes of differences:
When promoting a weakly typed value against a typed JAX value of the same category, JAX always prefers the precision of the JAX value. For example, jnp.int16(1) + 1 will return int16 rather than promoting to int64 as in NumPy. Note that this applies only to Python scalar values; if the constant is a NumPy array then the above lattice is used for type promotion. For example, jnp.int16(1) + np.array(1) will return int64.
When promoting an integer or boolean type against a floating-point or complex type, JAX always prefers the type of the floating-point or complex type.
JAX supports the bfloat16 non-standard 16-bit floating point type (jax.numpy.bfloat16), which is useful for neural network training. The only notable promotion behavior is with respect to IEEE-754 float16, with which bfloat16 promotes to a float32 (see the short check after this list).
The differences between NumPy and JAX are motivated by the fact that accelerator devices, such as GPUs and TPUs, either pay a significant performance penalty to use 64-bit floating point types (GPUs) or do not support 64-bit floating point types at all (TPUs). Classic NumPy’s promotion rules are too willing to overpromote to 64-bit types, which is problematic for a system designed to run on accelerators.
JAX uses floating point promotion rules that are more suited to modern accelerator devices and are less aggressive about promoting floating point types. The promotion rules used by JAX for floating-point types are similar to those used by PyTorch.
Effects of Python operator dispatch#
Keep in mind that Python operators like + will dispatch based on the Python type of
the two values being added. This means that, for example, np.int16(1) + 1
will
promote using NumPy rules, whereas jnp.int16(1) + 1
will promote using JAX rules.
This can lead to potentially confusing non-associative promotion semantics when
the two types of promotion are combined;
for example with np.int16(1) + 1 + jnp.int16(1)
.
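The sketch below only illustrates which library’s rules are in play; the resulting types are printed rather than asserted, since they depend on the NumPy version and configuration:
import numpy as np
import jax.numpy as jnp

a = np.int16(1) + 1   # left operand is NumPy, so NumPy's promotion rules apply
b = jnp.int16(1) + 1  # left operand is JAX, so JAX's promotion rules apply
print(type(a), np.asarray(a).dtype)
print(type(b), b.dtype)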
Weakly-typed values in JAX#
Weakly-typed values in JAX can in most cases be thought of as having promotion behavior
equivalent to that of Python scalars, such as the integer scalar 2
in the following:
>>> x = jnp.arange(5, dtype='int8')
>>> 2 * x
Array([0, 2, 4, 6, 8], dtype=int8)
JAX’s weak type framework is designed to prevent unwanted type promotion within
binary operations between JAX values and values with no explicitly user-specified type,
such as Python scalar literals. For example, if 2
were not treated as weakly-typed,
the expression above would lead to an implicit type promotion:
>>> jnp.int32(2) * x
Array([0, 2, 4, 6, 8], dtype=int32)
When used in JAX, Python scalars are sometimes promoted to DeviceArray
objects, for example during JIT compilation. To maintain the desired promotion
semantics in this case, DeviceArray
objects carry a weak_type
flag
that can be seen in an array’s string representation:
>>> jnp.asarray(2)
Array(2, dtype=int32, weak_type=True)
If the dtype
is specified explicitly, it will instead result in a standard
strongly-typed array value:
>>> jnp.asarray(2, dtype='int32')
Array(2, dtype=int32)
Strict dtype promotion#
In some contexts it can be useful to disable implicit type promotion behavior, and
instead require all promotions to be explicit. This can be done in JAX by setting the
jax_numpy_dtype_promotion flag to 'strict'. Locally, it can be done with a context manager:
>>> x = jnp.float32(1)
>>> y = jnp.int32(1)
>>> with jax.numpy_dtype_promotion('strict'):
... z = x + y
...
Traceback (most recent call last):
TypePromotionError: Input dtypes ('float32', 'int32') have no available implicit
dtype promotion path when jax_numpy_dtype_promotion=strict. Try explicitly casting
inputs to the desired output type, or set jax_numpy_dtype_promotion=standard.
For convenience, strict promotion mode will still allow safe weakly-typed promotions, so you can still write code that mixes JAX arrays and Python scalars:
>>> with jax.numpy_dtype_promotion('strict'):
... z = x + 1
>>> print(z)
2.0
If you would prefer to set the configuration globally, you can do so using the standard configuration update:
jax.config.update('jax_numpy_dtype_promotion', 'strict')
To restore the default standard type promotion, set this configuration to 'standard'
:
jax.config.update('jax_numpy_dtype_promotion', 'standard')
Pytrees#
What is a pytree?#
In JAX, we use the term pytree to refer to a tree-like structure built out of container-like Python objects. Classes are considered container-like if they are in the pytree registry, which by default includes lists, tuples, and dicts. That is:
any object whose type is not in the pytree container registry is considered a leaf pytree;
any object whose type is in the pytree container registry, and which contains pytrees, is considered a pytree.
For each entry in the pytree container registry, a container-like type is
registered with a pair of functions that specify how to convert an instance of
the container type to a (children, metadata)
pair and how to convert such a
pair back to an instance of the container type. Using these functions, JAX can
canonicalize any tree of registered container objects into tuples.
Example pytrees:
[1, "a", object()] # 3 leaves
(1, (2, 3), ()) # 3 leaves
[1, {"k1": 2, "k2": (3, 4)}, 5] # 5 leaves
JAX can be extended to consider other container types as pytrees; see Extending pytrees below.
Pytrees and JAX functions#
Many JAX functions, like jax.lax.scan()
, operate over pytrees of arrays.
JAX function transformations can be applied to functions that accept as input
and produce as output pytrees of arrays.
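As a minimal sketch (the function and parameter names here are only illustrative), a jit-compiled function can accept and return a dict of arrays, with jax.tree_util.tree_map applying an update to every leaf:
import jax
import jax.numpy as jnp

@jax.jit
def sgd_step(params, grads):
    # params and grads are pytrees (dicts of arrays) with matching structure
    return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
grads = {"w": jnp.full((2, 2), 0.5), "b": jnp.ones(2)}
print(sgd_step(params, grads))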
Applying optional parameters to pytrees#
Some JAX function transformations take optional parameters that specify how
certain input or output values should be treated (e.g. the in_axes
and
out_axes
arguments to vmap()
). These parameters can also be pytrees,
and their structure must correspond to the pytree structure of the corresponding
arguments. In particular, to be able to “match up” leaves in these parameter
pytrees with values in the argument pytrees, the parameter pytrees are often
constrained to be tree prefixes of the argument pytrees.
For example, if we pass the following input to vmap()
(note that the input
arguments to a function are considered a tuple):
(a1, {"k1": a2, "k2": a3})
We can use the following in_axes
pytree to specify that only the k2
argument is mapped (axis=0
) and the rest aren’t mapped over
(axis=None
):
(None, {"k1": None, "k2": 0})
The optional parameter pytree structure must match that of the main input
pytree. However, the optional parameters can optionally be specified as a
“prefix” pytree, meaning that a single leaf value can be applied to an entire
sub-pytree. For example, if we have the same vmap()
input as above,
but wish to only map over the dictionary argument, we can use:
(None, 0) # equivalent to (None, {"k1": 0, "k2": 0})
Or, if we want every argument to be mapped, we can simply write a single leaf value that is applied over the entire argument tuple pytree:
0
This happens to be the default in_axes
value for vmap()
!
The same logic applies to other optional parameters that refer to specific input
or output values of a transformed function, e.g. vmap
’s out_axes
.
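Here is a runnable sketch of the first in_axes example above; the shapes are chosen only for illustration, with only the "k2" entry carrying a mapped leading axis of size 4:
import jax
import jax.numpy as jnp

def f(a1, d):
    return a1 + d["k1"] + d["k2"]

a1 = jnp.zeros(3)                           # not mapped
d = {"k1": jnp.ones(3),                     # not mapped
     "k2": jnp.arange(12.0).reshape(4, 3)}  # mapped over axis 0

out = jax.vmap(f, in_axes=(None, {"k1": None, "k2": 0}))(a1, d)
print(out.shape)  # (4, 3)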
Viewing the pytree definition of an object#
To view the pytree definition of an arbitrary object
for debugging purposes, you can use:
from jax.tree_util import tree_structure
print(tree_structure(object))
Developer information#
This is primarily JAX internal documentation, end-users are not supposed to need to understand this to use JAX, except when registering new user-defined container types with JAX. Some of these details may change.
Internal pytree handling#
JAX flattens pytrees into lists of leaves at the api.py
boundary (and also
in control flow primitives). This keeps downstream JAX internals simpler:
transformations like grad()
, jit()
, and vmap()
can handle user functions that accept and return the myriad different Python
containers, while all the other parts of the system can operate on functions
that only take (multiple) array arguments and always return a flat list of arrays.
When JAX flattens a pytree it will produce a list of leaves and a treedef
object that encodes the structure of the original value. The treedef
can
then be used to construct a matching structured value after transforming the
leaves. Pytrees are tree-like, rather than DAG-like or graph-like, in that we
handle them assuming referential transparency and that they can’t contain
reference cycles.
Here is a simple example:
from jax.tree_util import tree_flatten, tree_unflatten
import jax.numpy as jnp
# The structured value to be transformed
value_structured = [1., (2., 3.)]
# The leaves in value_flat correspond to the `*` markers in value_tree
value_flat, value_tree = tree_flatten(value_structured)
print(f"{value_flat=}\n{value_tree=}")
# Transform the flat value list using an element-wise numeric transformer
transformed_flat = list(map(lambda v: v * 2., value_flat))
print(f"{transformed_flat=}")
# Reconstruct the structured output, using the original
transformed_structured = tree_unflatten(value_tree, transformed_flat)
print(f"{transformed_structured=}")
value_flat=[1.0, 2.0, 3.0]
value_tree=PyTreeDef([*, (*, *)])
transformed_flat=[2.0, 4.0, 6.0]
transformed_structured=[2.0, (4.0, 6.0)]
By default, pytree containers can be lists, tuples, dicts, namedtuple, None, OrderedDict. Other types of values, including numeric and ndarray values, are treated as leaves:
from collections import namedtuple
Point = namedtuple('Point', ['x', 'y'])
example_containers = [
(1., [2., 3.]),
(1., {'b': 2., 'a': 3.}),
1.,
None,
jnp.zeros(2),
Point(1., 2.)
]
def show_example(structured):
flat, tree = tree_flatten(structured)
unflattened = tree_unflatten(tree, flat)
print(f"{structured=}\n {flat=}\n {tree=}\n {unflattened=}")
for structured in example_containers:
show_example(structured)
structured=(1.0, [2.0, 3.0])
flat=[1.0, 2.0, 3.0]
tree=PyTreeDef((*, [*, *]))
unflattened=(1.0, [2.0, 3.0])
structured=(1.0, {'b': 2.0, 'a': 3.0})
flat=[1.0, 3.0, 2.0]
tree=PyTreeDef((*, {'a': *, 'b': *}))
unflattened=(1.0, {'a': 3.0, 'b': 2.0})
structured=1.0
flat=[1.0]
tree=PyTreeDef(*)
unflattened=1.0
structured=None
flat=[]
tree=PyTreeDef(None)
unflattened=None
structured=Array([0., 0.], dtype=float32)
flat=[Array([0., 0.], dtype=float32)]
tree=PyTreeDef(*)
unflattened=Array([0., 0.], dtype=float32)
structured=Point(x=1.0, y=2.0)
flat=[1.0, 2.0]
tree=PyTreeDef(CustomNode(namedtuple[Point], [*, *]))
unflattened=Point(x=1.0, y=2.0)
Extending pytrees#
By default, any part of a structured value that is not recognized as an internal pytree node (i.e. container-like) is treated as a leaf:
class Special(object):
def __init__(self, x, y):
self.x = x
self.y = y
def __repr__(self):
return "Special(x={}, y={})".format(self.x, self.y)
show_example(Special(1., 2.))
structured=Special(x=1.0, y=2.0)
flat=[Special(x=1.0, y=2.0)]
tree=PyTreeDef(*)
unflattened=Special(x=1.0, y=2.0)
The set of Python types that are considered internal pytree nodes is extensible,
through a global registry of types, and values of registered types are traversed
recursively. To register a new type, you can use
register_pytree_node()
:
from jax.tree_util import register_pytree_node
class RegisteredSpecial(Special):
def __repr__(self):
return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)
def special_flatten(v):
"""Specifies a flattening recipe.
Params:
v: the value of registered type to flatten.
Returns:
a pair of an iterable with the children to be flattened recursively,
and some opaque auxiliary data to pass back to the unflattening recipe.
The auxiliary data is stored in the treedef for use during unflattening.
The auxiliary data could be used, e.g., for dictionary keys.
"""
children = (v.x, v.y)
aux_data = None
return (children, aux_data)
def special_unflatten(aux_data, children):
"""Specifies an unflattening recipe.
Params:
aux_data: the opaque data that was specified during flattening of the
current treedef.
children: the unflattened children
Returns:
a re-constructed object of the registered type, using the specified
children and auxiliary data.
"""
return RegisteredSpecial(*children)
# Global registration
register_pytree_node(
RegisteredSpecial,
special_flatten, # tell JAX what are the children nodes
special_unflatten # tell JAX how to pack back into a RegisteredSpecial
)
show_example(RegisteredSpecial(1., 2.))
structured=RegisteredSpecial(x=1.0, y=2.0)
flat=[1.0, 2.0]
tree=PyTreeDef(CustomNode(RegisteredSpecial[None], [*, *]))
unflattened=RegisteredSpecial(x=1.0, y=2.0)
Alternatively, you can define appropriate tree_flatten
and tree_unflatten
methods
on your class and decorate it with register_pytree_node_class()
:
from jax.tree_util import register_pytree_node_class
@register_pytree_node_class
class RegisteredSpecial2(Special):
def __repr__(self):
return "RegisteredSpecial2(x={}, y={})".format(self.x, self.y)
def tree_flatten(self):
children = (self.x, self.y)
aux_data = None
return (children, aux_data)
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(*children)
show_example(RegisteredSpecial2(1., 2.))
structured=RegisteredSpecial2(x=1.0, y=2.0)
flat=[1.0, 2.0]
tree=PyTreeDef(CustomNode(RegisteredSpecial2[None], [*, *]))
unflattened=RegisteredSpecial2(x=1.0, y=2.0)
When defining unflattening functions, in general children
should contain all the
dynamic elements of the data structure (arrays, dynamic scalars, and pytrees), while
aux_data
should contain all the static elements that will be rolled into the treedef
structure. JAX sometimes needs to compare treedef
for equality, or compute its hash
for use in the JIT cache, and so care must be taken to ensure that the auxiliary data
specified in the flattening recipe supports meaningful hashing and equality comparisons.
The whole set of functions for operating on pytrees is in jax.tree_util.
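For example (a sketch; NamedPair is a hypothetical container), static metadata such as a string name can be placed in aux_data, where it becomes part of the treedef and participates in its hashing and equality:
from jax.tree_util import register_pytree_node, tree_flatten

class NamedPair:
    def __init__(self, name, x, y):
        self.name = name  # static metadata
        self.x = x        # dynamic leaf
        self.y = y        # dynamic leaf

register_pytree_node(
    NamedPair,
    lambda p: ((p.x, p.y), p.name),                 # children, aux_data (a hashable str)
    lambda name, children: NamedPair(name, *children),
)

leaves, treedef = tree_flatten(NamedPair("weights", 1.0, 2.0))
print(leaves)   # [1.0, 2.0]
print(treedef)  # the name lives in the treedef; str supports hash and ==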
Custom PyTrees and Initialization#
One common gotcha with user-defined PyTree objects is that JAX transformations occasionally initialize them with unexpected values, so that any input validation done at initialization may fail. For example:
class MyTree:
def __init__(self, a):
self.a = jnp.asarray(a)
register_pytree_node(MyTree, lambda tree: ((tree.a,), None),
lambda _, args: MyTree(*args))
tree = MyTree(jnp.arange(5.0))
jax.vmap(lambda x: x)(tree) # Error because object() is passed to MyTree.
jax.jacobian(lambda x: x)(tree) # Error because MyTree(...) is passed to MyTree
In the first case, JAX’s internals use arrays of object()
values to infer the structure
of the tree; in the second case, the jacobian of a function mapping a tree to a tree
is defined as a tree of trees.
For this reason, the __init__
and __new__
methods of custom PyTree classes should
generally avoid doing any array conversion or other input validation, or else
anticipate and handle these special cases. For example:
class MyTree:
def __init__(self, a):
if not (type(a) is object or a is None or isinstance(a, MyTree)):
a = jnp.asarray(a)
self.a = a
Another possibility is to structure your tree_unflatten
function so that it avoids
calling __init__
; for example:
def tree_unflatten(aux_data, children):
del aux_data # unused in this class
obj = object.__new__(MyTree)
obj.a, = children
return obj
If you go this route, make sure that your tree_unflatten
function stays in-sync with
__init__
if and when the code is updated.
Ahead-of-time lowering and compilation#
JAX offers several transformations, such as jax.jit
and jax.pmap
, returning
a function that is compiled and runs on accelerators or the CPU. As the JIT
acronym indicates, all compilation happens just-in-time for execution.
Some situations call for ahead-of-time (AOT) compilation instead. When you want to fully compile prior to execution time, or you want control over when different parts of the compilation process take place, JAX has some options for you.
First, let’s review the stages of compilation. Suppose that f
is a
function/callable output by jax.jit()
, say f = jax.jit(F)
for some input
callable F
. When it is invoked with arguments, say f(x, y)
where x
and y
are arrays, JAX does the following in order:
Stage out a specialized version of the original Python callable
F
to an internal representation. The specialization reflects a restriction ofF
to input types inferred from properties of the argumentsx
andy
(usually their shape and element type).Lower this specialized, staged-out computation to the XLA compiler’s input language, StableHLO.
Compile the lowered HLO program to produce an optimized executable for the target device (CPU, GPU, or TPU).
Execute the compiled executable with the arrays
x
andy
as arguments.
JAX’s AOT API gives you direct control over steps #2, #3, and #4 (but not #1), plus some other features along the way. An example:
>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> def f(x, y): return 2 * x + y
>>> x, y = 3, 4
>>> lowered = jax.jit(f).lower(x, y)
>>> # Print lowered HLO
>>> print(lowered.as_text())
module @jit_f.0 {
func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
%0 = stablehlo.constant dense<2> : tensor<i32>
%1 = stablehlo.multiply %0, %arg0 : tensor<i32>
%2 = stablehlo.add %1, %arg1 : tensor<i32>
return %2 : tensor<i32>
}
}
>>> compiled = lowered.compile()
>>> # Query for cost analysis, print FLOP estimate
>>> compiled.cost_analysis()[0]['flops']
2.0
>>> # Execute the compiled function!
>>> compiled(x, y)
DeviceArray(10, dtype=int32)
See the jax.stages
documentation for more details on what functionality
the lowering and compiled functions provide.
In place of jax.jit
above, you can also lower(...)
the result of
jax.pmap()
, as well as pjit
and xmap
(from
jax.experimental.pjit
and jax.experimental.maps
respectively). In
each case, you can compile()
the result similarly.
All optional arguments to jit
—such as static_argnums
—are respected in
the corresponding lowering, compilation, and execution. Again the same goes for
pmap
, pjit
, and xmap
.
In the example above, we can replace the arguments to lower
with any objects
that have shape
and dtype
attributes:
>>> i32_scalar = jax.ShapeDtypeStruct((), jnp.dtype('int32'))
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x, y)
DeviceArray(10, dtype=int32)
More generally, lower
only needs its arguments to structurally supply what JAX
must know for specialization and lowering. For typical array arguments like the
ones above, this means shape
and dtype
fields. For static arguments, by
contrast, JAX needs actual array values (more on this
below).
Invoking an AOT-compiled function with arguments that are incompatible with its lowering raises an error:
>>> x_1d = y_1d = jnp.arange(3)
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_1d, y_1d)
...
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with int32[3]
Argument 'y' compiled with int32[] and called with int32[3]
>>> x_f = y_f = jnp.float32(72.)
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_f, y_f)
...
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with float32[]
Argument 'y' compiled with int32[] and called with float32[]
Relatedly, AOT-compiled functions cannot be transformed by JAX’s just-in-time
transformations such as
jax.jit
, jax.grad()
, and jax.vmap()
.
Lowering with static arguments#
Lowering with static arguments underscores the interaction between options
passed to jax.jit
, the arguments passed to lower
, and the arguments needed
to invoke the resulting compiled function. Continuing with our example above:
>>> lowered_with_x = jax.jit(f, static_argnums=0).lower(7, 8)
>>> # Lowered HLO, specialized to the *value* of the first argument (7)
>>> print(lowered_with_x.as_text())
module @jit_f.1 {
func.func public @main(%arg0: tensor<i32>) -> tensor<i32> {
%0 = stablehlo.constant dense<14> : tensor<i32>
%1 = stablehlo.add %0, %arg0 : tensor<i32>
return %1 : tensor<i32>
}
}
>>> lowered_with_x.compile()(5)
DeviceArray(19, dtype=int32)
Note that lower
here takes two arguments as usual, but the subsequent compiled
function accepts only the remaining non-static second argument. The static first
argument (value 7) is taken as a constant at lowering time and built into the
lowered computation, where it is possibly folded in with other constants. In
this case, its multiplication by 2 is simplified, resulting in the constant 14.
Although the second argument to lower
above can be replaced by a hollow
shape/dtype structure, it is necessary that the static first argument be a
concrete value. Otherwise, lowering would err:
>>> jax.jit(f, static_argnums=0).lower(i32_scalar, i32_scalar)
TypeError: unsupported operand type(s) for *: 'int' and 'ShapeDtypeStruct'
>>> jax.jit(f, static_argnums=0).lower(10, i32_scalar).compile()(5)
DeviceArray(25, dtype=int32)
AOT-compiled functions cannot be transformed#
Compiled functions are specialized to a particular set of argument “types,” such
as arrays with a specific shape and element type in our running example. From
JAX’s internal point of view, transformations such as jax.vmap()
alter the
type signature of functions in a way that invalidates the compiled-for type
signature. As a policy, JAX simply disallows compiled functions to be involved
in transformations. Example:
>>> def g(x):
... assert x.shape == (3, 2)
... return x @ jnp.ones(2)
>>> def make_z(*shape):
... return jnp.arange(np.prod(shape)).reshape(shape)
>>> z, zs = make_z(3, 2), make_z(4, 3, 2)
>>> g_jit = jax.jit(g)
>>> g_aot = jax.jit(g).lower(z).compile()
>>> jax.vmap(g_jit)(zs)
DeviceArray([[ 1., 5., 9.],
[13., 17., 21.],
[25., 29., 33.],
[37., 41., 45.]], dtype=float32)
>>> jax.vmap(g_aot)(zs)
TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type <class 'jax.interpreters.batching.BatchTracer'>.
A similar error is raised when g_aot
is involved in autodiff
(e.g. jax.grad()
). For consistency, transformation by jax.jit
is
disallowed as well, even though jit
does not meaningfully modify its
argument’s type signature.
Debug information and analyses, when available#
In addition to the primary AOT functionality (separate and explicit lowering, compilation, and execution), JAX’s various AOT stages also offer some additional features to help with debugging and gathering compiler feedback.
For instance, as the initial example above shows, lowered functions often offer
a text representation. Compiled functions do the same, and also offer cost and
memory analyses from the compiler. All of these are provided via methods on the
jax.stages.Lowered
and jax.stages.Compiled
objects (e.g.,
lowered.as_text()
and compiled.cost_analysis()
above).
These methods are meant as an aid for manual inspection and debugging, not as a reliably programmable API. Their availability and output vary by compiler, platform, and runtime. This makes for two important caveats:
If some functionality is unavailable on JAX’s current backend, then the method for it returns something trivial (and
False
-like). For example, if the compiler underlying JAX does not provide a cost analysis, thencompiled.cost_analysis()
will beNone
.If some functionality is available, there are still very limited guarantees on what the corresponding method provides. The return value is not required to be consistent—in type, structure, or value—across JAX configurations, backends/platforms, versions, or even invocations of the method. JAX cannot guarantee that the output of
compiled.cost_analysis()
on one day will remain the same on the following day.
When in doubt, see the package API documentation for jax.stages
.
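A defensive sketch of using these methods (the exact structure of the return value varies by backend and JAX version, so it is printed rather than relied upon):
import jax
import jax.numpy as jnp

compiled = jax.jit(lambda x: 2 * x + 1).lower(jnp.arange(4)).compile()

cost = compiled.cost_analysis()
if cost:
    # in some versions this is a list of dicts, as in the earlier example
    entry = cost[0] if isinstance(cost, (list, tuple)) else cost
    print("flops estimate:", entry.get("flops"))
else:
    print("cost analysis unavailable on this backend")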
Inspecting staged-out computations#
Stage #1 in the list at the top of this note mentions specialization and
staging, prior to lowering. JAX’s internal notion of a function specialized to
the types of its arguments is not always a reified data structure in memory. To
explicitly construct a view of JAX’s specialization of a function in the
internal Jaxpr intermediate
language, see
jax.make_jaxpr()
.
JAX Errors#
This page lists a few of the errors you might encounter when using JAX, along with representative examples of how one might fix them.
- class jax.errors.ConcretizationTypeError(tracer, context='')#
This error occurs when a JAX Tracer object is used in a context where a concrete value is required (see Different kinds of JAX values for more on what a Tracer is). In some situations, it can be easily fixed by marking problematic values as static; in others, it may indicate that your program is doing operations that are not directly supported by JAX’s JIT compilation model.
Examples:
- Traced value where static value is expected
One common cause of this error is using a traced value where a static value is required. For example:
>>> from functools import partial
>>> from jax import jit
>>> import jax.numpy as jnp
>>> @jit
... def func(x, axis):
...   return x.min(axis)
>>> func(jnp.arange(4), 0)
Traceback (most recent call last):
  ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: axis argument to jnp.min().
This can often be fixed by marking the problematic argument as static:
>>> @partial(jit, static_argnums=1)
... def func(x, axis):
...   return x.min(axis)
>>> func(jnp.arange(4), 0)
Array(0, dtype=int32)
- Shape depends on Traced Value
Such an error may also arise when a shape in your JIT-compiled computation depends on the values within a traced quantity. For example:
>>> @jit
... def func(x):
...   return jnp.where(x < 0)
>>> func(jnp.arange(4))
Traceback (most recent call last):
  ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: The error arose in jnp.nonzero.
This is an example of an operation that is incompatible with JAX’s JIT compilation model, which requires array sizes to be known at compile-time. Here the size of the returned array depends on the contents of x, and such code cannot be JIT compiled.
In many cases it is possible to work around this by modifying the logic used in the function; for example here is code with a similar issue:
>>> @jit
... def func(x):
...   indices = jnp.where(x > 1)
...   return x[indices].sum()
>>> func(jnp.arange(4))
Traceback (most recent call last):
  ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: The error arose in jnp.nonzero.
And here is how you might express the same operation in a way that avoids creation of a dynamically-sized index array:
>>> @jit
... def func(x):
...   return jnp.where(x > 1, x, 0).sum()
>>> func(jnp.arange(4))
Array(5, dtype=int32)
To understand more subtleties having to do with tracers vs. regular values, and concrete vs. abstract values, you may want to read Different kinds of JAX values.
- Parameters:
tracer (core.Tracer)
context (str)
- class jax.errors.KeyReuseError(message)#
This error occurs when a PRNG key is reused in an unsafe manner. Key reuse is checked only when jax_debug_key_reuse is set to True.
Here is a simple example of code that would lead to such an error:
>>> with jax.debug_key_reuse(True):
...   key = jax.random.key(0)
...   value = jax.random.uniform(key)
...   new_value = jax.random.uniform(key)
...
---------------------------------------------------------------------------
KeyReuseError                             Traceback (most recent call last)
...
KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0
This sort of key reuse is problematic because the JAX PRNG is stateless, and keys must be manually split; for more information on this see Sharp Bits: Random Numbers.
- Parameters:
message (str)
- class jax.errors.NonConcreteBooleanIndexError(tracer)#
This error occurs when a program attempts to use non-concrete boolean indices in a traced indexing operation. Under JIT compilation, JAX arrays must have static shapes (i.e. shapes that are known at compile-time) and so boolean masks must be used carefully. Some logic implemented via boolean masking is simply not possible in a jax.jit() function; in other cases, the logic can be re-expressed in a JIT-compatible way, often using the three-argument version of where().
Following are a few examples of when this error might arise.
- Constructing arrays via boolean masking
This most commonly arises when attempting to create an array via a boolean mask within a JIT context. For example:
>>> import jax
>>> import jax.numpy as jnp
>>> @jax.jit
... def positive_values(x):
...   return x[x > 0]
>>> positive_values(jnp.arange(-5, 5))
Traceback (most recent call last):
    ...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])
This function is attempting to return only the positive values in the input array; the size of this returned array cannot be determined at compile-time unless x is marked as static, and so operations like this cannot be performed under JIT compilation.
- Reexpressible Boolean Logic
Although creating dynamically sized arrays is not supported directly, in many cases it is possible to re-express the logic of the computation in terms of a JIT-compatible operation. For example, here is another function that fails under JIT for the same reason:
>>> @jax.jit
... def sum_of_positive(x):
...   return x[x > 0].sum()
>>> sum_of_positive(jnp.arange(-5, 5))
Traceback (most recent call last):
    ...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])
In this case, however, the problematic array is only an intermediate value, and we can instead express the same logic in terms of the JIT-compatible three-argument version of jax.numpy.where():
>>> @jax.jit
... def sum_of_positive(x):
...   return jnp.where(x > 0, x, 0).sum()
>>> sum_of_positive(jnp.arange(-5, 5))
Array(10, dtype=int32)
This pattern of replacing boolean masking with three-argument where() is a common solution to this sort of problem.
- Boolean indexing into JAX arrays
The other situation where this error often arises is when using boolean indices, such as with .at[...].set(...). Here is a simple example:
>>> @jax.jit
... def manual_clip(x):
...   return x.at[x < 0].set(0)
>>> manual_clip(jnp.arange(-2, 2))
Traceback (most recent call last):
    ...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[4])
This function is attempting to set values smaller than zero to a scalar fill value. As above, this can be addressed by re-expressing the logic in terms of where():
>>> @jax.jit
... def manual_clip(x):
...   return jnp.where(x < 0, 0, x)
>>> manual_clip(jnp.arange(-2, 2))
Array([0, 0, 0, 1], dtype=int32)
- Parameters:
tracer (core.Tracer)
- class jax.errors.TracerArrayConversionError(tracer)#
This error occurs when a program attempts to convert a JAX Tracer object into a standard NumPy array (see Different kinds of JAX values for more on what a Tracer is). It typically occurs in one of a few situations.
- Using non-JAX functions in JAX transformations
This error can occur if you attempt to use a non-JAX library like numpy or scipy inside a JAX transformation (jit(), grad(), jax.vmap(), etc.). For example:
>>> from jax import jit
>>> import numpy as np
>>> @jit
... def func(x):
...   return np.sin(x)
>>> func(np.arange(4))
Traceback (most recent call last):
    ...
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int32[4]
In this case, you can fix the issue by using jax.numpy.sin() in place of numpy.sin():
>>> import jax.numpy as jnp
>>> @jit
... def func(x):
...   return jnp.sin(x)
>>> func(jnp.arange(4))
Array([0.        , 0.84147096, 0.9092974 , 0.14112   ], dtype=float32)
See also External Callbacks for options for calling back to host-side computations from transformed JAX code.
- Indexing a numpy array with a tracer
If this error arises on a line that involves array indexing, it may be that the array being indexed x is a standard numpy.ndarray while the indices idx are traced JAX arrays. For example:
>>> x = np.arange(10)
>>> @jit
... def func(i):
...   return x[i]
>>> func(0)
Traceback (most recent call last):
    ...
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int32[0]
Depending on the context, you may fix this by converting the numpy array into a JAX array:
>>> @jit
... def func(i):
...   return jnp.asarray(x)[i]
>>> func(0)
Array(0, dtype=int32)
or by declaring the index as a static argument:
>>> from functools import partial
>>> @partial(jit, static_argnums=(0,))
... def func(i):
...   return x[i]
>>> func(0)
Array(0, dtype=int32)
To understand more subtleties having to do with tracers vs. regular values, and concrete vs. abstract values, you may want to read Different kinds of JAX values.
- Parameters:
tracer (core.Tracer)
- class jax.errors.TracerBoolConversionError(tracer)#
This error occurs when a traced value in JAX is used in a context where a boolean value is expected (see Different kinds of JAX values for more on what a Tracer is).
The boolean cast may be explicit (e.g. bool(x)) or implicit, through use of control flow (e.g. if x > 0 or while x), use of Python boolean operators (e.g. z = x and y, z = x or y, z = not x), or functions that use them (e.g. z = max(x, y), z = min(x, y), etc.).
In some situations, this problem can be easily fixed by marking traced values as static; in others, it may indicate that your program is doing operations that are not directly supported by JAX’s JIT compilation model.
Examples:
- Traced value used in control flow
One case where this often arises is when a traced value is used in Python control flow. For example:
>>> from jax import jit
>>> import jax.numpy as jnp
>>> @jit
... def func(x, y):
...   return x if x.sum() < y.sum() else y
>>> func(jnp.ones(4), jnp.zeros(4))
Traceback (most recent call last):
    ...
TracerBoolConversionError: Attempted boolean conversion of JAX Tracer [...]
We could mark both inputs x and y as static, but that would defeat the purpose of using jax.jit() here. Another option is to re-express the if statement in terms of the three-term jax.numpy.where():
>>> @jit
... def func(x, y):
...   return jnp.where(x.sum() < y.sum(), x, y)
>>> func(jnp.ones(4), jnp.zeros(4))
Array([0., 0., 0., 0.], dtype=float32)
For more complicated control flow including loops, see Control flow operators.
- Control flow on traced values
Another common cause of this error is if you inadvertently trace over a boolean flag. For example:
>>> @jit
... def func(x, normalize=True):
...   if normalize:
...     return x / x.sum()
...   return x
>>> func(jnp.arange(5), True)
Traceback (most recent call last):
    ...
TracerBoolConversionError: Attempted boolean conversion of JAX Tracer ...
Here, because the flag normalize is traced, it cannot be used in Python control flow. In this situation, the best solution is probably to mark this value as static:
>>> from functools import partial
>>> @partial(jit, static_argnames=['normalize'])
... def func(x, normalize=True):
...   if normalize:
...     return x / x.sum()
...   return x
>>> func(jnp.arange(5), True)
Array([0. , 0.1, 0.2, 0.3, 0.4], dtype=float32)
For more on static_argnums, see the documentation of jax.jit().
- Using non-JAX aware functions
Another common cause of this error is using non-JAX aware functions within JAX code. For example:
>>> @jit
... def func(x):
...   return min(x, 0)
>>> func(2)
Traceback (most recent call last):
    ...
TracerBoolConversionError: Attempted boolean conversion of JAX Tracer ...
In this case, the error occurs because Python’s built-in min function is not compatible with JAX transforms. This can be fixed by replacing it with jnp.minimum:
>>> @jit
... def func(x):
...   return jnp.minimum(x, 0)
>>> print(func(2))
0
To understand more subtleties having to do with tracers vs. regular values, and concrete vs. abstract values, you may want to read Different kinds of JAX values.
- Parameters:
tracer (core.Tracer)
- class jax.errors.TracerIntegerConversionError(tracer)#
This error can occur when a JAX Tracer object is used in a context where a Python integer is expected (see Different kinds of JAX values for more on what a Tracer is). It typically occurs in a few situations.
- Passing a tracer in place of an integer
This error can occur if you attempt to pass a traced value to a function that requires a static integer argument; for example:
>>> from jax import jit
>>> import numpy as np
>>> @jit
... def func(x, axis):
...   return np.split(x, 2, axis)
>>> func(np.arange(4), 0)
Traceback (most recent call last):
    ...
TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[0]
When this happens, the solution is often to mark the problematic argument as static:
>>> from functools import partial
>>> @partial(jit, static_argnums=1)
... def func(x, axis):
...   return np.split(x, 2, axis)
>>> func(np.arange(10), 0)
[Array([0, 1, 2, 3, 4], dtype=int32), Array([5, 6, 7, 8, 9], dtype=int32)]
An alternative is to apply the transformation to a closure that encapsulates the arguments to be protected, either manually as below or by using functools.partial():
>>> jit(lambda arr: np.split(arr, 2, 0))(np.arange(4))
[Array([0, 1], dtype=int32), Array([2, 3], dtype=int32)]
Note that a new closure is created at every invocation, which defeats the compilation caching mechanism; this is why static_argnums is preferred.
- Indexing a list with a Tracer
This error can occur if you attempt to index a Python list with a traced quantity. For example:
>>> import jax.numpy as jnp
>>> from jax import jit
>>> L = [1, 2, 3]
>>> @jit
... def func(i):
...   return L[i]
>>> func(0)
Traceback (most recent call last):
    ...
TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[0]
Depending on the context, you can generally fix this either by converting the list to a JAX array:
>>> @jit
... def func(i):
...   return jnp.array(L)[i]
>>> func(0)
Array(1, dtype=int32)
or by declaring the index as a static argument:
>>> from functools import partial
>>> @partial(jit, static_argnums=0)
... def func(i):
...   return L[i]
>>> func(0)
Array(1, dtype=int32, weak_type=True)
To understand more subtleties having to do with tracers vs. regular values, and concrete vs. abstract values, you may want to read Different kinds of JAX values.
- Parameters:
tracer (core.Tracer)
- class jax.errors.UnexpectedTracerError(msg)#
This error occurs when you use a JAX value that has leaked out of a function. What does it mean to leak a value? If you use a JAX transformation on a function f that stores, in some scope outside of f, a reference to an intermediate value, that value is considered to have been leaked. Leaking values is a side effect. (Read more about avoiding side effects in Pure Functions.)
JAX detects leaks when you then use the leaked value in another operation later on, at which point it raises an UnexpectedTracerError. To fix this, avoid side effects: if a function computes a value needed in an outer scope, return that value from the transformed function explicitly.
Specifically, a Tracer is JAX’s internal representation of a function’s intermediate values during transformations, e.g. within jit(), pmap(), vmap(), etc. Encountering a Tracer outside of a transformation implies a leak.
- Life-cycle of a leaked value
Consider the following example of a transformed function which leaks a value to an outer scope:
>>> from jax import jit
>>> import jax.numpy as jnp
>>> outs = []
>>> @jit                   # 1
... def side_effecting(x):
...   y = x + 1            # 3
...   outs.append(y)       # 4
>>> x = 1
>>> side_effecting(x)      # 2
>>> outs[0] + 1            # 5
Traceback (most recent call last):
    ...
UnexpectedTracerError: Encountered an unexpected tracer.
In this example we leak a traced value from an inner transformed scope to an outer scope. We get an UnexpectedTracerError when the leaked value is used, not when the value is leaked.
This example also demonstrates the life-cycle of a leaked value:
1. A function is transformed (in this case, by jit())
2. The transformed function is called (initiating an abstract trace of the function and turning x into a Tracer)
3. The intermediate value y, which will later be leaked, is created (an intermediate value of a traced function is also a Tracer)
4. The value is leaked (appended to a list in an outer scope, escaping the function through a side-channel)
5. The leaked value is used, and an UnexpectedTracerError is raised.
The UnexpectedTracerError message tries to point to these locations in your code by including information about each stage. Respectively:
1. The name of the transformed function (side_effecting) and which transform kicked off the trace (jit()).
2. A reconstructed stack trace of where the leaked Tracer was created, which includes where the transformed function was called (“When the Tracer was created, the final 5 stack frames were...”).
3. From the reconstructed stack trace, the line of code that created the leaked Tracer.
4. The leak location is not included in the error message because it is difficult to pin down! JAX can only tell you what the leaked value looks like (what shape it has and where it was created) and what boundary it was leaked over (the name of the transformation and the name of the transformed function).
5. The current error’s stack trace points to where the value is used.
The error can be fixed by returning the value out of the transformed function:
>>> from jax import jit
>>> import jax.numpy as jnp
>>> outs = []
>>> @jit
... def not_side_effecting(x):
...   y = x + 1
...   return y
>>> x = 1
>>> y = not_side_effecting(x)
>>> outs.append(y)
>>> outs[0] + 1  # all good! no longer a leaked value.
Array(3, dtype=int32, weak_type=True)
- Leak checker
As discussed in points 2 and 3 above, JAX shows a reconstructed stack trace which points to where the leaked value was created. This is because JAX only raises an error when the leaked value is used, not when the value is leaked. This is not the most useful place to raise the error, because you need to know the location where the Tracer was leaked in order to fix it.
To make this location easier to track down, you can use the leak checker. When the leak checker is enabled, an error is raised as soon as a Tracer is leaked. (To be more exact, it will raise an error when the transformed function from which the Tracer is leaked returns.)
To enable the leak checker you can use the JAX_CHECK_TRACER_LEAKS environment variable or the with jax.checking_leaks() context manager.
Note
Note that this tool is experimental and may report false positives. It works by disabling some JAX caches, so it will have a negative effect on performance and should only be used when debugging.
Example usage:
>>> from jax import jit
>>> import jax.numpy as jnp
>>> outs = []
>>> @jit
... def side_effecting(x):
...   y = x + 1
...   outs.append(y)
>>> x = 1
>>> with jax.checking_leaks():
...   y = side_effecting(x)
Traceback (most recent call last):
    ...
Exception: Leaked Trace
- Parameters:
msg (str)
Transfer guard#
JAX may transfer data between the host and devices and between devices during type conversion and input sharding. To log or disallow any unintended transfers, the user may configure a JAX transfer guard.
JAX transfer guards distinguish between two types of transfers:
Explicit transfers: jax.device_put*() and jax.device_get() calls.
Implicit transfers: Other transfers (e.g., printing a DeviceArray).
A transfer guard can take an action based on its guard level:
"allow"
: Silently allow all transfers (default)."log"
: Log and allow implicit transfers. Silently allow explicit transfers."disallow"
: Disallow implicit transfers. Silently allow explicit transfers."log_explicit"
: Log and allow all transfers."disallow_explicit"
: Disallow all transfers.
JAX will raise a RuntimeError
when disallowing a transfer.
The transfer guards use the standard JAX configuration system:
A --jax_transfer_guard=GUARD_LEVEL command-line flag and jax.config.update("jax_transfer_guard", GUARD_LEVEL) will set the global option.
A with jax.transfer_guard(GUARD_LEVEL): ... context manager will set the thread-local option within the scope of the context manager.
Note that similar to other JAX configuration options, a newly spawned thread will use the global option instead of any active thread-local option of the scope where the thread was spawned.
The transfer guards can also be applied more selectively, based on the direction of transfer. The flag and context manager names are suffixed with a corresponding transfer direction (e.g., --jax_transfer_guard_host_to_device and jax.config.transfer_guard_host_to_device):
"host_to_device"
: Converting a Python value or NumPy array into a JAX on-device buffer."device_to_device"
: Copying a JAX on-device buffer to a different device."device_to_host"
: Fetching a JAX on-device buffer.
Fetching a buffer on a CPU device is always allowed regardless of the transfer guard level.
The following shows an example of using the transfer guard.
>>> jax.config.update("jax_transfer_guard", "allow") # This is default.
>>>
>>> x = jnp.array(1)
>>> y = jnp.array(2)
>>> z = jnp.array(3)
>>>
>>> print("x", x) # All transfers are allowed.
x 1
>>> with jax.transfer_guard("disallow"):
... print("x", x) # x has already been fetched into the host.
... print("y", jax.device_get(y)) # Explicit transfers are allowed.
... try:
... print("z", z) # Implicit transfers are disallowed.
... assert False, "This line is expected to be unreachable."
... except:
... print("z could not be fetched")
x 1
y 2
z could not be fetched
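A direction-specific guard follows the same pattern. A minimal sketch (the option name below is formed from the suffixing scheme described above and is assumed rather than shown on this page):

>>> jax.config.update("jax_transfer_guard_device_to_host", "disallow")
>>> print("y", jax.device_get(y))  # Explicit transfers are still allowed.
y 2
>>> jax.config.update("jax_transfer_guard_device_to_host", "allow")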
Pallas: a JAX kernel language#
Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU. This section contains tutorials, guides and examples for using Pallas.
Pallas Design#
In this document, we explain the initial Pallas design. This is a snapshot of some of the earlier design decisions, and Pallas’s specific APIs might have changed since then.
Introduction#
JAX is being used for a diverse set of workloads, from large scale machine learning to scientific computing. JAX’s success story is as much a success story for XLA, the primary compiler that JAX targets – XLA compiles JAX programs for accelerators and has enabled JAX to scale to the largest ML models. JAX describes logical computations in XLA’s representation, HLO. HLO describes how computations happen logically but not physically. Given a logical HLO computation, XLA decides how that computation is to be executed physically. For a wide variety of ML applications, XLA does a good job of compiling user programs but inevitably some users hit XLA’s limitations. In these cases, we need to provide an “escape hatch” to allow experts to write hand-tuned kernels that outperform XLA at that point in time. Furthermore, advances in ML systems research take some time to be incorporated into XLA and users often want to run ahead with them. Over time, the compiler can incorporate the optimizations that were proven out experimentally through hand-tuned kernels.
XLA does offer the CustomCall
mechanism as an escape hatch, but it requires users to write C++ and on GPU it requires users to learn the CUDA programming model. The CUDA programming model is arguably too low-level for many machine learning GPU kernels, like matrix multiplication, and even expert users will have trouble using CUDA to implement efficient matrix multiplication or multi-headed attention. Not only this, JAX users are usually familiar with Python and NumPy-style array programming which doesn’t involve writing any C++ or thinking about GPU parallelism. All popular machine learning frameworks share this idea: manipulating (usually) arrays with high level operations like matmul
or convolution
. Unfortunately, this means implementing a custom operation via CustomCall
is a big investment, involving potentially learning C++ and/or GPU programming.
Triton, a GPU compiler built and maintained by OpenAI, has taken the ML compiler world by storm. Triton offers the best of both worlds: an array-based programming model for GPU kernels. Triton is the primary code generation route for torch.compile
in PyTorch 2.0, via the Torch Inductor library. Triton actively hides some aspects of GPU programming in the name of a more accessible programming model that can be used from Python to generate optimized code from a higher-level representation. While GPUs are more flexible than what Triton offers, in the ML domain, Triton seems to be expressive enough for many applications.
In this document, we describe Pallas, an extension to JAX that enables kernel programming for both GPUs and TPUs using a Triton-like model. A JAX-based kernel language offers several advantages:
Although Triton exposes a TPU-like programming model to users, i.e. writing programs for tiles of arrays in L1-cache, it is specialized enough to GPU that we cannot directly compile Triton for TPU. For example, Triton offers atomic operations specifically meant to handle parallel writes that don’t necessarily make sense on TPU. A higher level front end can abstract away details of the platform while surfacing just that tile-based programming model. The kernels will thus be portable across different hardware platforms.
JAX as a tracing-based frontend for numerical computing is both mature and well-used. By embedding the kernel programming language in JAX itself, we can re-use JAX’s tracing infrastructure and provide a NumPy-like frontend that’s already familiar to users.
JAX transformations are key to its success, allowing users to express simple programs but transform them to achieve complex functionality. We can leverage the same transformations (vmap, jvp, etc.) to transform user-written kernels.
The open question is: is JAX a good fit for a kernel language at all? We think so. Triton demonstrates that an array programming language can be practical for writing GPU kernels and JAX is just that. JAX has also proven to be a flexible front-end for compilers and for program transformations.
We describe Pallas as follows: we first describe the ways in which we extend JAX to support writing custom kernels. We then show how we can lower Pallas to both Triton and Mosaic. We conclude by describing existing and potential ways to transform Pallas kernels via JAX transformations.
Visualization of Pallas lowering paths
Pallas: Extending JAX for kernels#
The key point we’d like to make is that Pallas is just JAX, with some extensions:
Users now use reference types called Refs in their JAX code. This gives users more precise control over memory access, and array layout in JAX will more closely resemble physical layout.
Users write their JAX programs using a subset of JAX primitives, along with a set of Pallas-specific primitives.
Users embed their Pallas kernels in an outer JAX program via a special pallas_call higher-order function, that executes the kernel in a map. It is analogous to pmap or shard_map, except with references to shared memory.
We’ll go over these three extensions one at a time, by example.
Note that these APIs are still experimental and subject to change.
Reference types#
Let’s look at an example Pallas program for adding two vectors:
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
def add_kernel(x_ref, y_ref, o_ref):
# In this code, `x_ref`, `y_ref` and `o_ref` are (8,)-shaped `Ref`s
x = x_ref[:]
y = y_ref[:]
o_ref[:] = x + y
x, y = jnp.arange(8), jnp.arange(8, 16)
add = pl.pallas_call(add_kernel, out_shape=jax.ShapeDtypeStruct((8,), jnp.int32))
add(x, y)
Unlike a regular JAX program, add_kernel
does not receive immutable array arguments. Instead, it’s provided with references that can be read from and updated in-place using NumPy-like syntax. Ref
s are not a Pallas-specific concept – they were introduced to JAX to represent stateful computations. However, we can leverage them when writing kernels that operate on mutable memory too.
Pallas kernels not only receive Ref
s corresponding to the inputs to the kernel, but also receive Ref
s for the outputs as well (specified in pallas_call
via out_shape
). Ref
s are special types that cannot be passed into the usual set of JAX primitives without being read from first. When you read from a Ref
you get a JAX Array
type out, and you must write an Array
into a Ref
.
Reading from/writing into Refs#
Reading from a Ref
corresponds to loading an array into the lowest level of the memory hierarchy (L1-cache on GPU and vector registers on TPU). Writing into a Ref
is analogous.
def f(x_ref, o_ref):
# Using vanilla Python indexing
x = x_ref[0, 2:5, :]
# Or via Numpy advanced int indexing
o_ref[jnp.arange(3), :] = x
# Note that in order to use NumPy advanced int indexing, you need to broadcast the indices against each other into the desired multidimensional shape:
def f(x_ref):
# Assume x_ref is (8, 4) and we want to read out a (2, 3) slice
x = x_ref[jnp.arange(2)[..., None], jnp.arange(3)[None, ...]]
Writing to Ref
s can be done via analogous __setitem__
style indexing.
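For instance, a hypothetical kernel writing into an output Ref with basic indexing might look like this (the shapes are illustrative):

def g(o_ref):
    # __setitem__-style write: assign an array into a slice of the Ref
    o_ref[0, 2:5, :] = jnp.zeros((3, o_ref.shape[-1]), o_ref.dtype)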
Other forms of indexing (for example, dynamic slicing) can be done via pallas.load
and pallas.store
, new JAX primitives designed to make loading from/storing into memory easier. We’ll discuss these new primitives later.
Extending JAX with new Pallas primitives#
Because JAX was designed with HLO in mind, the set of JAX primitives closely mirrors the set of HLO operations. Targeting a new compiler (e.g. Triton or Mosaic) means we might need to supplement JAX’s primitives with new ones specific to the new compiler. At the same time, we may not be able to lower all JAX primitives, so we need to restrict it to a subset.
Because Pallas was initially designed with Triton in mind, we offer a set of new primitives targeting the Triton programming model. As we’ll show later, we can lower these primitives to Mosaic as well.
pallas.load
and pallas.store
#
pallas.load
and pallas.store
are primitives that allow loading from memory and storing into memory. Unlike __getitem__
and __setitem__
they are more flexible at the cost of being more verbose. Specifically, you can use the pallas.dynamic_slice
(pallas.ds
for short) construct (which should maybe be upstreamed into JAX to be used with Ref __getitem__
and __setitem__
).
def f(x_ref, o_ref):
# Reading from memory via pallas.load
x = pl.load(x_ref, (0, slice(2, 5), slice(None)))
# Using integer indexing automatically broadcasts
x = pl.load(x_ref, (0, 2 + jnp.arange(3), slice(None)))
# You can also use `pl.dynamic_slice` (`pl.ds` for short) objects as well
pl.store(o_ref, (0, pl.ds(start=2, size=3), slice(None)), x)
pallas.load
and pallas.store
also support masking via the mask argument.
def f(x_ref, o_ref):
# Reading from memory via pallas.load
idx = jnp.arange(8)
mask = idx < 5
x = pl.load(x_ref, (idx,), mask=mask, other=float('-inf'))
Masking is important when doing out-of-bounds loads/stores. The operational semantics of masking can be compiler-determined (if we understand the documentation properly, Triton avoids the read from/write to memory if it’s masked).
pallas.program_id
and pallas.num_programs
#
As we’ll soon see, we’ll be executing the same Pallas kernels many times (either in parallel or in a pipeline depending on the backend). These new primitives tell us “where” we are in the execution of the kernel.
pallas.program_id
takes in an axis argument, which tells us which index in an axis of a multidimensional grid this kernel is currently executing in (analogous to threadIdx
from CUDA programming or lax.axis_index
in jax.pmap
). Note that we are currently borrowing the “program” terminology from Triton and in the future we might want to change it to something more familiar to JAX users.
def f(x_ref, o_ref):
i = pl.program_id(axis=0) # execution index in the first axis of the grid
o_ref[i] = jnp.exp(x_ref[i])
pallas.num_programs
also takes in an axis and returns the grid size for that axis.
Note that while program_id
and num_programs
are Triton-specific terminology they are easily generalized to make sense on TPU as well.
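As an illustrative (hypothetical) kernel, num_programs can be combined with program_id, for example to normalize by the grid size:

def scaled_copy_kernel(x_ref, o_ref):
    i = pl.program_id(axis=0)      # this program's index along grid axis 0
    n = pl.num_programs(axis=0)    # total number of programs along grid axis 0
    o_ref[i] = x_ref[i] / n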
Using a subset of JAX primitives in Pallas#
Because we’re writing kernels, not high-level HLO programs, some JAX primitives may not be able to be represented in our underlying substrate efficiently. However, we know we can support most elementwise operations, simple dot products, and JAX control flow.
While we haven’t yet mapped out exactly all the JAX primitives that we can support in Pallas kernels, we can certainly identify some that are not easy to lower or are unlikely to be useful:
conv_general
- convolution usually isn’t offered as primitive in the underlying hardware.gather/scatter
- the underlying compiler may not support noncontiguous memory reads and writes
Executing Pallas kernels with pallas_call
#
Now that we’ve written our Pallas kernels (a.k.a. JAX with Ref
s and the extra Pallas primitives), how do we execute them on a GPU or TPU? We use pallas_call
, a higher order function (akin to jax.jit
and jax.pmap
) that executes the kernel.
The signature of pallas_call
is as follows:
def pallas_call(
kernel: Callable,
in_specs: Sequence[Spec],
out_specs: Sequence[Spec],
out_shapes: Sequence[jax.ShapeDtypeStruct],
grid: Optional[Tuple[int, ...]] = None) -> Callable:
...
When we provide a kernel to pallas_call
we provide additional information. The first is out_shape
which tells the kernel what the outputs look like (pallas_call
will pass a Ref
corresponding to these into the kernel to be written to). The rest of the information (in_specs
, out_specs
, and grid
) are information about how the kernel will be scheduled on the accelerator.
The (rough) semantics for pallas_call
are as follows:
def pallas_call(kernel, in_specs, out_specs, out_shapes, grid):
def execute(*args):
outputs = map(empty_ref, out_shapes)
grid_indices = map(range, grid)
for indices in itertools.product(*grid_indices): # Could run in parallel!
local_inputs = [in_spec.transform(arg, indices) for arg, in_spec in
zip(args, in_specs)]
local_outputs = [out_spec.transform(arg, indices) for arg, out_spec in
zip(outputs, out_specs)]
kernel(*local_inputs, *local_outputs) # writes to outputs
return execute
Specifically, pallas_call
will “loop” over grid iteration space, applying a transformation to the inputs and outputs specified via the in_specs
and out_specs
. In each iteration, the kernel will be called on the transformed inputs and outputs. Note that the “loop” over the iteration space could be executed in parallel (e.g. on GPU). pallas_call
also provides no guarantees on the order of loop iterations over the iteration space, just that every member of the iteration space will be looped over. Compilers like Triton and Mosaic will have more specific operational semantics associated with the grid.
Transformation functions#
The in_specs
and out_specs
arguments to pallas_call
allow inputs and outputs to be transformed in some way. The two options that Pallas offers right now are an identity transformation (where inputs and outputs are left unchanged), and BlockSpecs, which take fixed-size slices of Refs determined by the loop index.
A BlockSpec
takes an index_map
function and a block_shape
. Logically, it takes an array and slices it along each axis into block_shape
sizes blocks. The index_map
function takes loop indices (from the grid index set) and maps them to block indices. The transform function converts Ref
s into logical views of the Ref
at the corresponding block. When we specify None
in an entry in block_shape, that corresponds to “mapping” over that dimension, removing it from the block within the kernel.
class BlockSpec:
index_map: Callable[[Tuple[Int, ...]], Tuple[Int, ...]]
block_shape: Tuple[Optional[int], ...]
def transform(self, ref, *loop_indices):
block_indices = self.index_map(*loop_indices)
# Returns a view of `ref` starting at `block_indices` of shape self.block_shape
...
We could also imagine other Spec
s that are used with pallas_call
, for example a Spec
that corresponds to overlapping windows to, say, implement convolutions.
Immediate benefits of Pallas as a front-end#
By offering a JAX front-end for kernel writing, we can immediately reap some benefits.
More flexible front end#
The first is that JAX users are already accustomed to the benefits (and limitations) of programming with JAX and its tracing-based transformations. This means users can use closures and other familiar Python constructs when writing Pallas kernels. This is unlike the existing AST-parsing-based Triton front end or the MLIR builders for Mosaic. For example, this makes Pallas far more amenable to templating than Triton.
See this example of how we can use higher-order functions in Python to template a kernel.
def make_kernel(eltwise_kernel):
def add(x_ref, y_ref, o_ref):
x = pl.load(x_ref, ())
y = pl.load(y_ref, ())
pl.store(o_ref, (), eltwise_kernel(x + y))
return add
kernel1 = make_kernel(lambda x: x * 2)
kernel2 = make_kernel(jnp.exp)
pl.pallas_call(kernel1, out_shape=x, grid=1)(1., 1.)
pl.pallas_call(kernel2, out_shape=x, grid=1)(1., 1.)
Emulation mode#
By representing kernels as programs with JAX primitives and some new Pallas primitives, we can also lower Pallas programs to StableHLO directly and compile/execute them with XLA. Specifically, a pallas_call
can be implemented as a lax.scan
over the grid. This enables us to develop GPU or TPU kernels on any XLA-supported platform (even CPU!) and debug them using JAX/XLA debugging tools (like jax.debug.print
). We can also use the more reliable and better tested XLA numerics to verify the correctness of the Triton and Mosaic compilers. One could also imagine perturbing the scan
ordering to simulate the parallel reads and writes that happen on GPU.
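A minimal sketch of what this can look like in practice, assuming the interpret flag of pallas_call (mentioned later in the TPU guide) is used to select this emulation path, and reusing add_kernel from above:

add_interpreted = pl.pallas_call(
    add_kernel,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
    interpret=True)  # lower via JAX/XLA instead of Triton/Mosaic, e.g. for debugging
add_interpreted(jnp.arange(8), jnp.arange(8, 16))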
Examples#
add
#
We modify our add_kernel
example to operate over (2,)-sized blocks using BlockSpec
s.
def add_kernel(x_ref, y_ref, o_ref):
# In this code, `x_ref`, `y_ref` and `o_ref` are (2,)-shaped `Ref`s
x = x_ref[:]
y = y_ref[:]
o_ref[:] = x + y
x, y = jnp.arange(8), jnp.arange(8, 16)
add = pl.pallas_call(
add_kernel,
out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
in_specs=[
pl.BlockSpec(lambda i: i, (2,)),
pl.BlockSpec(lambda i: i, (2,))
],
out_specs=pl.BlockSpec(lambda i: i, (2,)),
grid=(4,))
add(x, y)
Templated matmul#
In this example, we compute tiles of the output by doing an unrolled accumulation over blocks of rows and columns from our input arrays. We inline an activation function into the body of the kernel using a higher order function so we can emit a fused kernel.
def matmul_kernel(x_ref, y_ref, o_ref, *, activation, block_k):
acc = jnp.zeros((x_ref.shape[0], y_ref.shape[1]), jnp.float32)
for k in range(x_ref.shape[1] // block_k):
x = x_ref[:, k*block_k:(k+1)*block_k]
y = y_ref[k*block_k:(k+1)*block_k, :]
acc += x @ y
o_ref[:, :] = activation(acc).astype(o_ref.dtype)
x, y = jnp.ones((512, 256)), jnp.ones((256, 1024))
block_shape = 128, 256, 128
@partial(jax.jit, static_argnames=["block_shape", "activation"])
def matmul(x, y, *, block_shape, activation):
block_m, block_n, block_k = block_shape
fused_matmul = pl.pallas_call(
partial(matmul_kernel, block_k=block_k, activation=activation),
out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1],), jnp.float32),
in_specs=[
pl.BlockSpec(lambda i, j: (i, 0), (block_m, x.shape[1])),
pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], block_n))
],
out_specs=pl.BlockSpec(lambda i, j: (i, j), (block_m, block_n)),
grid=(4, 4),
)
return fused_matmul(x, y)
z = matmul(x, y, block_shape=block_shape, activation=jax.nn.gelu)
Lowering Pallas#
After users express their Pallas kernels, we lower them to different representations depending on the target backend. On GPUs, we lower Pallas to Triton IR, and on TPU we lower Pallas to Mosaic.
Lowering Pallas to Triton for GPU#
Lowering Pallas to Triton is easy because Pallas was designed with Triton as a target language in mind. The main differences between Pallas and Triton is that Triton doesn’t have a notion of BlockSpec
s and also uses pointers when doing memory loads and stores as opposed to indices.
Triton supports pointers as an array element type in its language and in Triton you can load from and store to arrays of pointers. In Pallas, when given a (4, 5)
-shaped Ref
, x_ref
, and then do like x_ref[3, 2]
, we need to lower this to computing a Triton pointer to the appropriate row-major position in x_ref
(that is, doing 5 * 3 + 2 * 1). Similarly, when we lower slices to Triton, e.g. x_ref[4, :]
we need to produce an array of pointers 5 * 4 + jnp.arange(5)
.
Other than that, lowering to Triton is fairly straightforward. JAX dot products can be lowered to Triton dot products and JAX unary primitives are lowered to their Triton equivalents. Triton’s atomic operations are lowered via new Pallas atomic primitives.
Lowering Pallas to Mosaic for TPU#
Mosaic consumes (mostly) standard dialect MLIR and emits LLO to be compiled for TPU. Pallas can be lowered to Mosaic via translating JAX primitives to MLIR (mostly the vector
and arith
dialects). The BlockSpec
s can be converted into pipeline schedules (i.e. the transform_func
s in Mosaic).
Transforming Pallas#
A natural question is how do JAX transformations interact with Pallas kernels? There are two main ways: transformations inside Pallas kernels and transformations outside Pallas kernels.
Transformation inside Pallas kernels should actually “just work”, so long as we are able to lower the transformed code. For example, we could use jax.grad(jnp.sin)(...)
inside of a JAX kernel because we can lower a cos
to both Triton and Mosaic. However, we might not be able to lower a jax.vmap(lax.dynamic_slice)
because it could turn into a gather that we cannot lower.
Transformations of Pallas kernels from the outer JAX programs is perhaps the more interesting case. How do we handle things like vmap(pallas_call)
and grad(pallas_call)
?
vmap-of-pallas_call
#
vmap automatically vectorizes JAX programs. While kernel writers might want precise control over how a batched kernel will behave differently from its unbatched variant, we can offer a reasonable default vmap
rule for pallas_call
while offering the jax.custom_vmap
customization mechanism. When pallas_call
is vmap
-ed, we augment the pallas_call
to have an extra grid dimension corresponding to the new batch dimension and transform the BlockSpec
s to handle indexing along that dimension.
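As a hedged sketch of what this looks like from the user's side (reusing add_kernel and the imports from the examples above):

batched_add = jax.vmap(
    pl.pallas_call(add_kernel, out_shape=jax.ShapeDtypeStruct((8,), jnp.int32)))
xs = jnp.arange(24, dtype=jnp.int32).reshape(3, 8)
ys = jnp.arange(24, 48, dtype=jnp.int32).reshape(3, 8)
batched_add(xs, ys)  # shape (3, 8); the batch dimension becomes an extra grid axis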
grad-of-pallas_call
#
grad
of pallas_call
enables automatic differentiation of kernels. jax.grad
breaks down into applications of three distinct transforms: jvp
, partial_eval
and transpose
. In principle, we can re-use most of JAX’s infrastructure when implementing these rules for pallas_call
(since it behaves much like existing JAX higher order primitives).
However, automatic differentiation of kernels can result in a performance hit due to how memory access is transposed. If we write a GPU kernel with overlapping-and-parallel reads and disjoint-but-parallel writes, we automatically transpose it into a kernel that has overlapping-but-parallel writes (which are slow when done atomically) and disjoint-and-parallel reads. To emit a kernel that better uses parallelism with shared memory, we would need to reorder loops and change how the kernel is vectorized. Unfortunately, we do not have a program representation amenable to that in Pallas. A potential direction to automatically differentiating kernels efficiently is to explore a different representation, perhaps one like that in Dex. We could also look at how Enzyme approaches this problem. However, AD of Pallas kernels may still be useful for a class of kernels that does transpose efficiently (for example elementwise kernels).
In general, though, jax.custom_vjp
is a viable escape hatch to express Pallas kernels that work with jax.grad
.
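For example, here is a minimal sketch (the kernel and names are illustrative, not taken from the Pallas codebase) of wrapping an elementwise Pallas kernel in jax.custom_vjp so that jax.grad works without relying on automatic transposition of the kernel:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def exp_kernel(x_ref, o_ref):
    o_ref[...] = jnp.exp(x_ref[...])

@jax.custom_vjp
def pallas_exp(x):
    return pl.pallas_call(
        exp_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x)

def pallas_exp_fwd(x):
    y = pallas_exp(x)
    return y, y          # save the primal output as the residual

def pallas_exp_bwd(y, g):
    return (y * g,)      # d/dx exp(x) = exp(x)

pallas_exp.defvjp(pallas_exp_fwd, pallas_exp_bwd)

jax.grad(lambda x: pallas_exp(x).sum())(jnp.arange(4.0))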
Other transformations#
We could imagine other JAX transformations applying to Pallas kernels that we haven’t explicitly explored yet. For example, checkify
is a JAX transformation that does functional error handling. We could imagine using checkify
with pallas_call to allow plumbing out error codes from GPU kernels that indicate if OOB access or NaNs were produced.
Another potential transformation to integrate with is custom_partitioning to enable automatically partitionable kernels to be used with pjit.
Pallas Quickstart#
Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU. Pallas allows you to use the same JAX functions and APIs but operates at a lower level of abstraction.
Specifically, Pallas requires users to think about memory access and how to divide up computations across multiple compute units in a hardware accelerator. On GPUs, Pallas lowers to Triton and on TPUs, Pallas lowers to Mosaic.
Let’s dive into some examples.
Note: Pallas is still an experimental API and you may be broken by changes!
Hello world in Pallas#
from functools import partial
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np
We’ll first write the “hello world” in Pallas, a kernel that adds two vectors.
def add_vectors_kernel(x_ref, y_ref, o_ref):
x, y = x_ref[...], y_ref[...]
o_ref[...] = x + y
Ref
types
Let’s dissect this function a bit. Unlike most JAX functions you’ve probably written, it does not take in jax.Array
s as inputs and doesn’t return any values. Instead it takes in Ref
objects as inputs. Note that we also don’t have any outputs but we are given an o_ref
, which corresponds to the desired output.
Reading from Ref
s
In the body, we are first reading from x_ref
and y_ref
, indicated by the [...]
(the ellipsis means we are reading the whole Ref
; alternatively we also could have used x_ref[:]
). Reading from a Ref
like this returns a jax.Array
.
Writing to Ref
s
We then write x + y
to o_ref
. Mutation has not historically been supported in JAX – jax.Array
s are immutable! Ref
s are new (experimental) types that allow mutation under certain circumstances. We can interpret writing to a Ref
as mutating its underlying buffer.
So we’ve written what we call a “kernel”, which we define as a program that will run as an atomic unit of execution on an accelerator, without any interaction with the host. How do we invoke it from a JAX computation? We use the pallas_call
higher-order function.
@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
return pl.pallas_call(add_vectors_kernel,
out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
)(x, y)
add_vectors(jnp.arange(8), jnp.arange(8))
Array([ 0, 2, 4, 6, 8, 10, 12, 14], dtype=int32)
pallas_call
lifts the Pallas kernel function into an operation that can be called as part of a larger JAX program. But, to do so, it needs a few more details. Here we specify out_shape
, an object that has a .shape
and .dtype
(or a list thereof).
out_shape
determines the shape/dtype of o_ref
in our add_vector_kernel
.
pallas_call
returns a function that takes in and returns jax.Array
s.
What’s actually happening here?
Thus far we’ve described how to think about Pallas kernels but what we’ve actually accomplished is we’re writing a function that’s executed very close to the compute units.
On GPU, x_ref
corresponds to a value in high-bandwidth memory (HBM) and when we do x_ref[...]
we are copying the value from HBM into static RAM (SRAM) (this is a costly operation generally speaking!). We then use GPU vector compute to execute the addition, then copy the resulting value in SRAM back to HBM.
On TPU, we do something slightly different. Before the kernel is ever executed, we fetch the value from HBM into SRAM. x_ref
therefore corresponds to a value in SRAM and when we do x_ref[...]
we are copying the value from SRAM into a register. We then use TPU vector compute to execute the addition, then copy the resulting value back into SRAM. After the kernel is executed, the SRAM value is copied back into HBM.
We are in the process of writing backend-specific Pallas guides. Coming soon!
Pallas programming model#
In our “hello world” example, we wrote a very simple kernel. It takes advantage of the fact that our 8-sized arrays can comfortably fit inside the SRAM of hardware accelerators. In most real-world applications, this will not be the case!
Part of writing Pallas kernels is thinking about how to take big arrays that live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations that operate on “blocks” of those arrays that can fit in SRAM.
Grids#
To automatically “carve” up the inputs and outputs, you provide a grid
and BlockSpec
s to pallas_call
.
A grid
is a tuple of integers (e.g. ()
, (2, 3, 4)
, or (8,)
) that specifies an iteration space.
For example, a grid (4, 5)
would have 20 elements: (0, 0), (0, 1), ..., (0, 4), (1, 0), ..., (3, 4)
.
We run the kernel function once for each element, a style of single-program multiple-data (SPMD) programming.
A 2D grid
When we provide a grid
to pallas_call
, the kernel is executed as many times as prod(grid)
. Each of these invocations is referred to as a “program”. To access which program (i.e. which element of the grid) the kernel is currently executing, we use program_id(axis=...)
. For example, for invocation (1, 2)
, program_id(axis=0)
returns 1
and program_id(axis=1)
returns 2
.
Here’s an example kernel that uses a grid
and program_id
.
def iota_kernel(o_ref):
i = pl.program_id(0)
o_ref[i] = i
We now execute it using pallas_call
with an additional grid
argument.
def iota(len: int):
return pl.pallas_call(iota_kernel,
out_shape=jax.ShapeDtypeStruct((len,), jnp.int32),
grid=(len,))()
iota(8)
Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
On GPUs, each program is executed in parallel on separate threads. Thus, we need to think about race conditions on writes to HBM. A reasonable approach is to write our kernels in such a way that different programs write to disjoint places in HBM to avoid these parallel writes. On the other hand, parallelizing the computation is how we can execute operations like matrix multiplications really quickly.
On TPUs, programs are executed in a combination of parallel and sequential (depending on the architecture) so there are slightly different considerations.
Block specs#
With grid
and program_id
in mind, Pallas provides an abstraction that takes care of some common indexing patterns seen in a lot of kernels.
To build intuition, let’s try to implement a matrix multiplication.
A simple strategy for implementing a matrix multiplication in Pallas is to implement it recursively. We know our underlying hardware has support for small matrix multiplications (using GPU and TPU tensorcores), so we just express a big matrix multiplication in terms of smaller ones.
Suppose we have input matrices \(X\) and \(Y\) and are computing \(Z = XY\). We first express \(X\) and \(Y\) as block matrices. \(X\) will have “row” blocks and \(Y\) will have “column” blocks.
Our strategy is that because \(Z\) is also a block matrix, we can assign each of the programs in our Pallas kernel one of the output blocks. Computing each output block corresponds to doing a smaller matrix multiply between a “row” block of \(X\) and a “column” block of \(Y\).
To express this pattern, we use BlockSpec
s. A BlockSpec
specifies a block shape for each input and output, and an “index map” function, that maps a set of program indices to a block index.
A visualization of a BlockSpec
For a concrete example, let’s say we’d like to multiply two (1024, 1024)
matrices x
and y
together to produce z
, and would like to parallelize the computation 4 ways. We split up z
into 4 (512, 512)
blocks where each block is computed with a (512, 1024) x (1024, 512)
matrix multiplication. To express this, we’d first use a (2, 2)
grid (one block for each program).
For x
, we use BlockSpec(lambda i, j: (i, 0), (512, 1024))
– this carves x
up into “row” blocks. To see this see how both program instances (1, 0)
and (1, 1)
pick the (1, 0)
block in x
. For y
, we use a transposed version BlockSpec(lambda i, j: (0, j), (1024, 512))
. Finally, for z
we use BlockSpec(lambda i, j: (i, j), (512, 512))
.
These BlockSpec
s are passed into pallas_call
via in_specs
and out_specs
.
Underneath the hood, pallas_call
will automatically carve up your inputs and outputs into Ref
s for each block that will be passed into the kernel.
def matmul_kernel(x_ref, y_ref, z_ref):
z_ref[...] = x_ref[...] @ y_ref[...]
def matmul(x: jax.Array, y: jax.Array):
return pl.pallas_call(
matmul_kernel,
out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),
grid=(2, 2),
in_specs=[
pl.BlockSpec(lambda i, j: (i, 0), (x.shape[0] // 2, x.shape[1])),
pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], y.shape[1] // 2))
],
out_specs=pl.BlockSpec(
lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2)
)
)(x, y)
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (1024, 1024))
y = jax.random.normal(k2, (1024, 1024))
z = matmul(x, y)
np.testing.assert_allclose(z, x @ y)
Note that this is a very naive implementation of a matrix multiplication but consider it a starting point for various types of optimizations. Let’s add an additional feature to our matrix multiply: fused activation. It’s actually really easy! Just pass a higher-order activation function into the kernel.
def matmul_kernel(x_ref, y_ref, z_ref, *, activation):
z_ref[...] = activation(x_ref[...] @ y_ref[...])
def matmul(x: jax.Array, y: jax.Array, *, activation):
return pl.pallas_call(
partial(matmul_kernel, activation=activation),
out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),
grid=(2, 2),
in_specs=[
pl.BlockSpec(lambda i, j: (i, 0), (x.shape[0] // 2, x.shape[1])),
pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], y.shape[1] // 2))
],
out_specs=pl.BlockSpec(
lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2)
),
)(x, y)
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (1024, 1024))
y = jax.random.normal(k2, (1024, 1024))
z = matmul(x, y, activation=jax.nn.relu)
np.testing.assert_allclose(z, jax.nn.relu(x @ y))
To conclude, let’s highlight a cool feature of Pallas: it composes with jax.vmap
! To turn this matrix multiplication into a batched version, we just need to vmap
it.
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (4, 1024, 1024))
y = jax.random.normal(k2, (4, 1024, 1024))
z = jax.vmap(partial(matmul, activation=jax.nn.relu))(x, y)
np.testing.assert_allclose(z, jax.nn.relu(jax.vmap(jnp.matmul)(x, y)))
Pallas TPU#
TPU specific documentation.
Writing TPU kernels with Pallas#
This page focuses on the details that are important when attempting to run Pallas kernels on Google TPUs. For one, the TPU backend is still in an experimental phase, and only a subset of JAX NumPy will be accepted. Furthermore, writing performant code for TPUs might require thinking carefully about the native capabilities of the hardware. While many patterns that are unnatural to the hardware will be accepted, they might end up requiring software emulation, and can slow down the computation.
Warning
This feature should still be considered experimental as work is still in progress (in particular on improving the error messages).
Note
While all the features described here are experimental, we remain very serious about maintaining their correctness. As such, it might not be uncommon to see a “not implemented” error while attempting to write TPU kernels. But, if a kernel is accepted by the compiler, it must return the expected results.
If you see unexpected outputs, please compare them against a kernel run with
interpret=True
passed in to pallas_call
. If the results diverge,
please file a bug report.
What is a TPU?#
TPU is a hardware accelerator developed at Google. You can think of TPUs as GPUs, but specialized for machine learning workloads specifically. As such, their architecture differs quite significantly. However, we believe that Pallas can make it easy to start writing TPU kernels, even without having a full understanding of the underlying hardware. Having said that, understanding the hardware well will certainly make it easier to write performant kernels.
In a nutshell, the main difference between TPUs and GPUs is that TPUs are sequential machines with a very wide vector register (kind of like a CPU!). At the same time, they allow the software to schedule certain operations in the background, making them execute asynchronously with respect to the main instruction stream. This includes things like HBM memory accesses (which cannot be issued directly, but instead have to be prefetched to lower levels of the memory hierarchy by the DMA subunits), matrix multiplies (supported by the MXU unit) or matrix transpositions and permutes (supported by the XLU unit).
If you’re interested in learning more about the TPU architecture in detail, we recommend reading a collection of papers published over the years. While many of them talk about specific TPU generations, many of the ideas described transfer to later generations as well.
Noteworthy properties and restrictions#
BlockSpec
s and grid iteration#
BlockSpec
s generally behave as expected in Pallas — every invocation of
the kernel body gets access to slices of the inputs and is meant to initialize a slice
of the output.
Warning
Not all window shapes are supported. If the last two dimensions of your input are larger than 8 and 128 respectively, the window shape in those dimensions must be a multiple of the respective factor. If the input dimension is smaller, the window should span the full dimension.
One interesting aspect of Pallas TPU kernels is the way they handle memory spaces:
While the inputs to pallas_call
will often reside in HBM (the main TPU
memory), the references passed in to the kernel body will point to buffers in
lower levels of memory hierarchy (VMEM or SMEM). This enables the kernel body
to write and read them at very high speeds, while all the communication with
HBM (which has very high latency) is handled by the compiler and overlapped
with compute.
What’s more, compared to GPUs, TPUs are actually highly sequential machines. Ergo, the grid is generally not processed in parallel, but sequentially, in lexicographic order (though see the Multicore TPU configurations section for exceptions). This unlocks some interesting capabilities:
When two (lexicographically) consecutive grid indices use the same slice of an input, the HBM transfer for the second iteration is skipped, as the data is already available.
Multiple invocations of the kernel body can write to the same slice of the output, without any risk of race conditions. However, we do require that all invocations that write to a particular slice are consecutive.
The “consecutive” restriction on the output usually means that some prefix of the grid dimensions always varies the slice of the output an invocation needs to access, while the output window remains constant for the remaining suffix.
For example, when implementing a Pallas TPU kernel for matrix multiplication, one would generally use a 3 dimensional grid: the first two dimensions would correspond to slicing along the first axis of the left operand and the second axis of the second operand. The third and last grid axis would tile the reduction dimension. The grid axis corresponding to the reduction dimension has to be the last one, since the output window does not vary along this axis. The output reference can be then used as an accumulator for partial results.
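A hedged sketch of this accumulation pattern (block sizes and names are illustrative; pl.when is used to zero the output block on the first step along the reduction axis):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def matmul_kernel(x_ref, y_ref, o_ref):
    @pl.when(pl.program_id(2) == 0)
    def _():
        # First grid step along the reduction axis: initialize the accumulator block.
        o_ref[...] = jnp.zeros(o_ref.shape, o_ref.dtype)
    o_ref[...] += x_ref[...] @ y_ref[...]

def matmul(x, y, *, bm=256, bn=256, bk=256):
    m, k = x.shape
    _, n = y.shape
    return pl.pallas_call(
        matmul_kernel,
        out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
        in_specs=[
            pl.BlockSpec(lambda i, j, h: (i, h), (bm, bk)),
            pl.BlockSpec(lambda i, j, h: (h, j), (bk, bn)),
        ],
        out_specs=pl.BlockSpec(lambda i, j, h: (i, j), (bm, bn)),
        grid=(m // bm, n // bn, k // bk),  # reduction axis last
    )(x, y)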
Note
VMEM is fairly large for such a low-level memory hierarchy (16MB+), making it possible to use large window sizes. And, oftentimes, the larger the window size, the better the eventual hardware utilization will be. However, it is possible to specify a window size that (together with space necessary to hold spilled vector registers) exceeds the size of VMEM. In this case, you will likely see a low-level compiler error message complaining about an out-of-memory error.
Dimension ordering is meaningful#
In JAX programs, the ordering of intermediate arrays inside jax.jit
usually
has no impact on performance, as the compiler is free to rearrange them.
However, as Pallas is meant to expose lower-level capabilities, the dimension
order can have great impact on the quality of generated code.
Recall that the TPUs perform the bulk of the computation on 2D vector registers.
Pallas TPU will only ever consider mapping the last two dimensions of
intermediate arrays to those vector register dimensions (sublanes and lanes
respectively). An array of shape (n, 1, 1)
is guaranteed to require at least
n
vector registers to represent. If n
becomes too large, this can lead
to spills, and potential VMEM OOM errors due to an overly large memory footprint.
But it also might not — the low-level compiler is free to rearrange the
instructions to lower the register pressure, and is in fact very good at it.
Still, it is a good rule of thumb to keep the last two dimensions large
(especially the last dimension), while keeping the leading dimensions small.
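As a rough illustration of this rule of thumb (exact register counts depend on dtype and hardware generation; for float32 a vector register holds an (8, 128) tile):

```python
import jax.numpy as jnp

# Two logically equivalent intermediates holding 1024 float32 values:
a = jnp.zeros((1024, 1, 1), jnp.float32)  # needs at least 1024 vector registers
b = jnp.zeros((1, 8, 128), jnp.float32)   # fits a single (8, 128) register tile
```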
Multicore TPU configurations#
In newer TPU generations, the two cores on a chip are often abstracted as a
single device. To take advantage of multiple cores, Pallas has to break the
sequential grid execution guarantees, and will need to parallelize one of the
grid axes over cores. This is an opt-in procedure. To allow that,
pallas_call
requires an extra parameter named dimension_semantics
:
That parameter is a list with as many entries as there are axes in the grid. Only parallel
dimensions can be partitioned over cores. As a rule of thumb, a dimension is parallel unless the output window does not vary along it.
As such, dimension_semantics
is always a number of parallel
axes
followed by a number of arbitrary
axes.
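A minimal sketch of what this can look like, with the annotation passed through compiler_params in the same way as the Megacore example later in this document (the copy kernel, block shape, and grid here are illustrative):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def copy_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...]

def copy(x: jax.Array) -> jax.Array:
  return pl.pallas_call(
      copy_kernel,
      grid=(2,),
      in_specs=[pl.BlockSpec(lambda i: (i, 0), (256, 512))],
      out_specs=pl.BlockSpec(lambda i: (i, 0), (256, 512)),
      out_shape=jax.ShapeDtypeStruct((512, 512), x.dtype),
      # The single grid axis is safe to partition across cores.
      compiler_params=dict(mosaic=dict(dimension_semantics=("parallel",))),
  )(x)
```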
While partitioning a kernel over a 2-core TPU device often leads to a 2x speedup, the speedup can in fact be significantly smaller. This is especially true if different instances of the body have highly varying cost: if all of the expensive steps get mapped to one core while all the cheap steps are assigned to the other, the second core will sit idle until the first one completes its tasks.
Pallas TPU generally favors partitioning axes of a size that is a multiple of the number of TPU cores, and prefers to partition leading grid axes.
Placing operands in SMEM#
Most of the compute on the TPU will happen on the vector unit. Still, there are many cases where it is useful to perform a number of scalar operations, e.g., to carry out control-flow. For that reason, TPUs come with a separate scalar unit, and a separate scalar memory (SMEM) attached to it. As a rule of thumb, any data used to perform control-flow decisions should be placed in SMEM.
SMEM is a low-latency memory that supports random access, but only lets you read and write 32-bit values with a single instruction (very small compared to the 4 KiB granularity of VMEM transactions, but much more flexible thanks to the lack of alignment requirements!).
The scalar memory is also very useful when implementing kernels that do not
access the tiles of inputs in a regular pattern, such as when writing
block-sparse kernels. In Pallas, this can be achieved by replacing the
grid
argument to pallas_call
with a grid_spec
of
PrefetchScalarGridSpec
with a non-zero num_scalar_prefetch
argument.
If num_scalar_prefetch
is n
, then the first n
arguments to
pallas_call
will be placed in SMEM. No BlockSpec
s should be specified
for those arguments. But, the BlockSpec
s for all subsequent arguments will
receive not only the grid indices, but also the SMEM references to the leading
operands.
Note
We are working on implementing examples for this feature. Stay tuned!
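In the meantime, here is a rough, untested sketch of the shape of such a call; the gather-style kernel, the names, and the block sizes are purely illustrative assumptions. The scalar operand is placed in SMEM and its reference is visible both to the index_map functions and to the kernel body.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def gather_blocks_kernel(block_ids_ref, x_ref, o_ref):
  # block_ids_ref lives in SMEM; the irregular indexing already happened in the
  # index_map below, so the body is a plain copy of the selected block.
  o_ref[...] = x_ref[...]

def gather_blocks(block_ids: jax.Array, x: jax.Array) -> jax.Array:
  num_blocks = block_ids.shape[0]
  _, n = x.shape
  grid_spec = pltpu.PrefetchScalarGridSpec(
      num_scalar_prefetch=1,  # `block_ids` is prefetched into SMEM
      grid=(num_blocks,),
      in_specs=[
          # index_maps receive the grid index and the SMEM reference.
          pl.BlockSpec(lambda i, block_ids_ref: (block_ids_ref[i], 0), (8, n)),
      ],
      out_specs=pl.BlockSpec(lambda i, block_ids_ref: (i, 0), (8, n)),
  )
  return pl.pallas_call(
      gather_blocks_kernel,
      grid_spec=grid_spec,
      out_shape=jax.ShapeDtypeStruct((num_blocks * 8, n), x.dtype),
  )(block_ids, x)
```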
Supported data types#
At the moment Pallas TPU only supports the following data types:
jnp.float32
jnp.bfloat16
jnp.int*
(all precisions, except forjnp.int4
)jnp.uint*
(all precisions)
Computation placement#
All scalar (i.e. 0D) arrays will be stored in scalar registers, and operations on them will be executed on the scalar core. All other operations (even those on single-element arrays, as long as they are 1D or higher) will be executed on the vector core.
Supported operations#
Matrix multiplication#
Matrix multiplication always produces results in the float32 format.
If your inputs are not float32, we recommend using lax.dot
with
preferred_element_type
set to jnp.float32
.
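For example, inside a kernel body this might look like the following (a minimal sketch with bfloat16 inputs):

```python
import jax
import jax.numpy as jnp

def matmul_bf16_kernel(x_ref, y_ref, o_ref):
  # Inputs may be bfloat16; explicitly request a float32 result/accumulation.
  o_ref[...] = jax.lax.dot(
      x_ref[...], y_ref[...], preferred_element_type=jnp.float32)
```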
When using lax.dot_general
, it is possible to fuse transpositions of
the last two dimensions of matrix multiplication operands into the operation,
which can improve overall kernel performance.
Precision control#
Pallas TPU lowering is aware of jax.default_matmul_precision
. For best
performance (and lowest precision), use bfloat16
. If you care about
numerical accuracy, you might want to set the precision to float32
.
Warning
Even if you pass in 32-bit operands to a matrix multiplication, they will be
rounded to bfloat16
unless float32
precision is requested.
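For instance, wrapping the call to a Pallas kernel in the jax.default_matmul_precision context manager requests the full-precision path (my_pallas_matmul, x, and y are placeholders for your own kernel and inputs):

```python
import jax

with jax.default_matmul_precision("float32"):
    z = my_pallas_matmul(x, y)  # matmuls traced here use float32 precision
```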
Transposition#
If the value has at least 4 dimensions, arbitrary transpositions of all but the last two axes are free. Otherwise, only the transposition of the last two axes is implemented. Note that some transpositions of the last two dimensions can be fused into matrix multiplication.
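For example (shapes are illustrative):

```python
import jax.numpy as jnp

x = jnp.ones((2, 3, 128, 128))
a = jnp.transpose(x, (1, 0, 2, 3))  # free: only the leading axes are permuted
b = jnp.transpose(x, (0, 1, 3, 2))  # supported: swaps the last two axes
```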
Accessing memory#
Arbitrary slices of references can be read or updated, subject to implementation constraints. Currently, no restrictions are placed on inputs that are 32-bit wide, but only some slicing patterns are supported for narrower types. Reads and writes whose offsets and lengths in the last two dimensions are multiples of 8 and 128 respectively are always supported.
Reads and writes to vector memory generally happen on tiles of shape (8, 128)
.
As such, when reading or writing to references that have at least two dimensions,
the best performance is achieved when the base offset of the memory access
has indices divisible by the tiling, and the size of the read region is a
multiple of the tile size.
Elementwise operations#
Many elementwise operations are supported. Note that the hardware generally only supports elementwise computation using 32-bit types, so operands loaded in lower-precision types should generally be upcast to a 32-bit type before elementwise ops are applied.
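A small sketch of this pattern inside a kernel body:

```python
import jax.numpy as jnp

def tanh_kernel(x_ref, o_ref):
  # Upcast (e.g. bfloat16) inputs to float32 before the elementwise op,
  # then cast back to the output dtype.
  x = x_ref[...].astype(jnp.float32)
  o_ref[...] = jnp.tanh(x).astype(o_ref.dtype)
```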
Elementwise operations can vary significantly in cost, so we outline three categories of supported operations: cheap (🟢), medium (🌕) and expensive (🔴).

| Operation | Cost |
|---|---|
| `jnp.add`, `+` | 🟢 |
| `jnp.sub`, `-` | 🟢 |
| `jnp.mul`, `*` | 🟢 |
| `/`, `//`, `%` | 🌕 |
| `jnp.max`, `jnp.min` | 🟢 |
| `jnp.where` (select) | 🟢 |
| `jnp.abs` | 🟢 |
| `&`, `^`, `\|`, `~` | 🟢 |
| `<<`, `>>` | 🟢 |
| Comparisons (`==`, `<`, ...) | 🟢 |
| Type casts (`.astype`) | 🟢 |
| `jnp.exp` | 🌕 |
| `jnp.tanh` | 🌕 |
| `jnp.pow` | 🌕 |
| `jnp.sin` | 🔴 |
| `jnp.cos` | 🔴 |
Many JAX functions are implemented in terms of other JAX primitives, so this
list might not be comprehensive. For example, jax.nn.relu
is implemented
in terms of comparisons and jnp.where
will work in Pallas kernels too.
Array constructors#
All constant array constructors are supported (jnp.ones
, jnp.zeros
,
jnp.full
). Notably, the jax.random
module is not compatible with
Pallas as of today.
Reductions#
Sum, maximum and minimum reductions are supported, but only on a single array axis at a time.
Reductions over the last array dimension are generally the slowest. Reductions over the second last dimension are faster, but still slower than over the leading dimensions.
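For example (the relative ordering, not the exact ratios, is the point here):

```python
import jax.numpy as jnp

x = jnp.ones((16, 256, 512))
fast = jnp.sum(x, axis=0)   # reduction over a leading dimension: fastest
mid  = jnp.sum(x, axis=1)   # second-to-last dimension: slower
slow = jnp.sum(x, axis=2)   # last dimension: generally the slowest
```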
Broadcasting#
The performance characteristics of broadcasting are very similar to those of reductions. Broadcasting along all but the two trailing dimensions is always supported and free. Broadcasting along the second to last dimension is slower, while broadcasting along the last dimension is the slowest.
Reshapes#
As usual, reshapes in all dimensions but the last two dimensions are supported and free.
The only two supported cases in which a reshape can modify the last two dimensions of an array are when (1) some leading dimensions are flattened onto the second-to-last dimension, or (2) it adds a dimension that was just removed by a reduction.
Control flow#
The TPU backend features limited support for control flow at the moment. The
currently supported functions are cond
, fori_loop
and for_loop
.
However, loop primitives get fully unrolled during the compilation at the
moment, so try to keep the loop trip count reasonably small.
Overusing control flow can lead to significant regressions in low-level code generation, and it is recommended to try to squeeze as many computationally expensive operations into a single basic block as possible.
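A small sketch of a supported pattern, using jax.lax.fori_loop with a small static trip count (remember that the loop is fully unrolled at compile time):

```python
import jax
import jax.numpy as jnp

def small_loop_kernel(x_ref, o_ref):
  # Sum four copies of the input block; the loop unrolls during compilation.
  def body(i, acc):
    return acc + x_ref[...]
  o_ref[...] = jax.lax.fori_loop(0, 4, body, jnp.zeros_like(x_ref[...]))
```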
Pipelining and BlockSpec
s#
In this guide we’ll cover how memory spaces in TPU work and how to write pipelines in Pallas that overlap memory I/O with compute.
#@title Imports
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np
TPU and its memory spaces#
A TPU and its TensorCore consist of memory spaces (where arrays can reside), registers (which temporarily store scalar and array values) and compute units (that do computation with values in registers). Below is a diagram of a TPU in which x
and y
are arrays that live in high-bandwidth memory (HBM):
Let’s talk about the components of this diagram in more detail:
Memory spaces: A TPU has high-bandwidth memory (HBM) which is what we often think of as “device memory”. There is also vector memory (VMEM), a cache meant for storing vector and array values, and scalar memory (SMEM), a cache designed to store scalar values.
Registers: A TensorCore has two main types of registers: vector registers (VREGs) store array values, and scalar registers (SREGs) store scalar values. Values can be loaded into memory from their respective caches (VMEM for VREGs and SMEM for SREGs).
Compute units: A TensorCore has a scalar unit, vector unit (VPU) and matrix unit (MXU) that can do numerical computation. Compute units operate on values that live in SREGs and VREGs and output values into those registers as well.
In order to do a vectorized computation on our values x
and y
that live in HBM, we need to:
Copy the values
x
andy
into VMEM.Load the values from VMEM into VREGs.
Execute the computation using the VPU or MXU, storing the output in VREGs.
Store the values in the output VREGs into VMEM.
Copy the output values in VMEM back to HBM.
Let’s implement a Pallas function that does just that!
def add_matrices_kernel(x_vmem_ref, y_vmem_ref, z_vmem_ref):
# Load x and y from VMEM into VREGs
x_vregs = x_vmem_ref[:, :]
y_vregs = y_vmem_ref[:, :]
# Execute a vectorized add
z_vregs = x_vregs + y_vregs
# Store the output values in VREGs back into VMEM
z_vmem_ref[:, :] = z_vregs
def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array:
# pallas_call will first allocate scratch buffers for `x` and `y` in VMEM.
# It will then copy `x` and `y` from HBM into VMEM.
z = pl.pallas_call(
add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
)(x, y)
# pallas_call will also copy the output from VMEM back into HBM.
return z
x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
...,
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.]], dtype=float32)
We’ve written two functions: add_matrices_kernel
and add_matrices
.
add_matrices_kernel
operates using Ref
s that live in VMEM. Loading from a VMEM Ref
produces a value that lives in VREGs. Values in VREGs behave like jax.Array
s in that we can use jnp
and jax.lax
operations on them to produce new values that live in VREGs. When we produce the values we’d like to return, we store them in the output VMEM Ref
.
The add_matrices
function acts on jax.Array
s and returns a jax.Array
. Inside it, we pass x
and y
into pallas_call
. pallas_call
is responsible for copying x
and y
into VMEM and for allocating the VMEM buffers that the kernel operates on (including allocating z_vmem_ref
, the output VMEM buffer). After the kernel function is finished running, pallas_call
will also copy the value in z_vmem_ref
to HBM, resulting in an output jax.Array
.
Constraints of using VMEM/SMEM#
Pallas exposes access to lower level memory spaces like VMEM and SMEM but writing kernels utilizing them adds some considerations.
Memory capacity. VMEM and SMEM are small! VMEM on v4 TPUs is only 16MiB and SMEM ranges in the tens to hundreds of KiB. If our arrays are too big, we won’t even be able to fit them into VMEM at all. For reference, a
f32[2048, 2048]
array is 16MiB, so our above kernel won’t scale beyond moderately sized arrays.Memory bandwidth. Copying to/from HBM and VMEM takes a long time, at least compared to most compute instructions. The
add_matrices
function above will likely spend more time copying between HBM and VMEM than actually performing the addition itself.
With these two constraints in mind, we’ll have to rethink our strategy for getting performance out of our TPUs.
Primer: Pipelining#
Pipelining our computation offers a way of dealing with both the memory capacity and bandwidth constraints in one fell swoop. What do we mean by pipelining?
The goal is: in parallel copy to/from HBM and VMEM while utilizing our compute units. Naively this is difficult because in our program above we copy all of x
and y
before we start doing any compute with them, creating a dependence between the copy and the compute.
However, if we can chunk up our computation into several subcomputations (e.g. when we add two matrices, we can express that as addition of “blocks” of the original matrices together), we can now overlap the copies of one of those subcomputations with the compute of the other. Let’s walk through a simple example:
Let’s say we split our arrays x
and y
into x1, x2
and y1, y2
(for example, split along the leading axis, resulting in two (256, 512)
arrays for each input). We can now execute the following pipelined computation:
Copy
x1
andy1
into VMEM.Start copying
x2
andy2
into VMEMLoad
x1, y1
from VMEM into VREGs.Execute the
z1 = x1 + y1
using the compute units.Store
z1
into VMEM.Start copying
z1
from VMEM back into HBM.Wait until
x2, y2
have been copied into VMEM.Load
x2, y2
from VMEM into VREGs.Execute the
z2 = x2 + y2
using the compute units.Store
z2
into VMEM.Wait until
z1
is copied into HBM.Start copying
z2
from VMEM back into HBM.Wait until
z2
is copied into HBM.
Any time we are doing compute here, we are asynchronously copying something. This means that some of the time spent copying is not wasted.
The two most important numbers for determining how efficient a pipelined computation will be are a) how many floating point operations (FLOPs) we need to execute and b) how many bytes we need to copy to execute that computation. The ratio of these two (FLOPs per byte copied) is called the arithmetic intensity of an operation and determines whether our pipeline will be compute bound or memory bound.
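As a rough worked example for the add_matrices pipeline above (float32 inputs of shape (512, 512)):

```python
# One add per output element; read x, read y, write z, 4 bytes per element.
flops = 512 * 512
bytes_moved = 3 * 512 * 512 * 4
intensity = flops / bytes_moved  # ~0.083 FLOPs/byte: heavily memory bound
```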
Pipelining in Pallas#
How do we implement a pipeline like the one above in Pallas? It seems like a complex sequence of asynchronous data operations and executing kernels that would be a pain to implement manually. Fear not! Pallas offers an API for expressing pipelines without too much boilerplate, namely through grid
s and BlockSpec
s.
grid
, a.k.a. kernels in a loop#
See how in the above pipelined example, we are executing the same logic multiple times: steps 3-5 and 8-10 both execute the same operations, only on different inputs. The generalized version of this is a loop in which the same kernel is executed multiple times. pallas_call
provides an option to do exactly that.
The number of iterations in the loop is specified via the grid
argument to pallas_call
. Conceptually:
pl.pallas_call(some_kernel, grid=n)(...)
maps to
for i in range(n):
# do HBM -> VMEM copies
some_kernel(...)
# do VMEM -> HBM copies
Grids can be generalized to be multi-dimensional, corresponding to nested loops. For example,
pl.pallas_call(some_kernel, grid=(n, m))(...)
is equivalent to
for i in range(n):
for j in range(m):
# do HBM -> VMEM copies
some_kernel(...)
# do VMEM -> HBM copies
This generalizes to any tuple of integers (a length d
grid will correspond to d
nested loops).
BlockSpec
, a.k.a. how to chunk up inputs#
The next piece of information we need to provide Pallas in order to automatically pipeline our computation is how to chunk it up. Specifically, we need to provide a block shape and a mapping from each iteration of the loop to the blocks of our inputs and outputs that it operates on. A BlockSpec
is exactly these two pieces of information.
First we pick a block_shape
for our inputs. In the pipelining example above, we had (512, 512)
-shaped arrays and split them along the leading dimension into two (256, 512)
-shaped arrays. In this pipeline, our block_shape
would be (256, 512)
.
We then provide an index_map
function that maps the iteration space to the blocks. Specifically, in the aforementioned pipeline, on the first iteration we’d like to select x1
and on the second iteration we’d like to use x2
. This can be expressed with the following index_map
:
def x_index_map(i):
return (i, 0)
We’d then construct the BlockSpec
:
block_spec = pl.BlockSpec(x_index_map, (256, 512))
The BlockSpec
s for y
and z
will be the same as the one for x
.
Putting it together#
We provide these arguments to pallas_call
via grid
, in_specs
and out_specs
(in_specs
corresponds to the tuple of positional arguments, and out_specs
corresponds to the output).
def add_matrices_pipelined(x: jax.Array, y: jax.Array) -> jax.Array:
block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512))
return pl.pallas_call(
add_matrices_kernel,
out_shape=x,
in_specs=[block_spec, block_spec],
out_specs=block_spec,
grid=(2,))(x, y)
add_matrices_pipelined(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
...,
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.]], dtype=float32)
We’ve only added a little bit of code to our original function to add automatic pipelining but the BlockSpec
s and grid
do a lot of heavy lifting!
How does it work? Well, the BlockSpec
s provide enough information to start prefetching blocks of our input from HBM into VMEM. For example, if we are starting iteration i
of our grid
, we can pass i + 1
into the index_map
functions to obtain the blocks needed for the next iteration. We can then start an asynchronous copy for those blocks. Similarly for outputs, we can wait for the outputs of the previous iteration to be copied before starting the copy for the current iteration’s outputs.
Parameterizing a pipeline#
It’s common to parameterize the block shapes in our kernel. Block sizes are perhaps the most important parameter to tune when optimizing the performance of Pallas kernels! They give us control over the pipeline (for example, picking smaller blocks adds more iterations to our pipelined loop where each iteration has less work to do).
Furthermore, we could also carve up the inputs and outputs along the 2nd dimension (we are only splitting along the first right now). Let’s write a more general kernel that handles both of these features.
def add_matrices_pipelined_2d(
x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256
) -> jax.Array:
m, n = x.shape
block_spec = pl.BlockSpec(lambda i, j: (i, j), (bm, bn))
return pl.pallas_call(
add_matrices_kernel,
out_shape=x,
in_specs=[block_spec, block_spec],
out_specs=block_spec,
grid=(m // bm, n // bn),
)(x, y)
np.testing.assert_array_equal(
add_matrices_pipelined_2d(x, y, bm=256, bn=256), x + y
)
np.testing.assert_array_equal(
add_matrices_pipelined_2d(x, y, bm=128, bn=128), x + y
)
np.testing.assert_array_equal(
add_matrices_pipelined_2d(x, y, bm=512, bn=512), x + y
)
Handling reductions#
How would you implement something like jnp.sum
using pallas_call
? Specifically, we’d like to pipeline across the reduction dimension.
Take the example of reducing a (8, 512, 512)
-shaped array to a (512, 512)
-shaped one.
x = jnp.ones((8, 512, 512))
jnp.sum(x, axis=0)
Array([[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
...,
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.]], dtype=float32)
To do this using pallas_call
, we could use a grid of size (8,)
and in each iteration i
load x[i]
into VMEM. Then we could add x[i]
to an output VMEM buffer. Let’s implement this naively first.
# Warning: this implementation is incorrect!
def naive_sum_kernel(x_ref, o_ref):
o_ref[...] += x_ref[...]
def naive_sum(x: jax.Array) -> jax.Array:
grid, *out_shape = x.shape
return pl.pallas_call(
naive_sum_kernel,
grid=grid,
# None in `block_shape` means we pick a size of 1 and squeeze it away
in_specs=[pl.BlockSpec(lambda i: (i, 0, 0), (None, *out_shape))],
out_specs=pl.BlockSpec(lambda i: (0, 0), out_shape),
out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)
)(x)
naive_sum(x)
Array([[9., 9., 9., ..., 9., 9., 9.],
[9., 9., 9., ..., 9., 9., 9.],
[9., 9., 9., ..., 9., 9., 9.],
...,
[9., 9., 9., ..., 9., 9., 9.],
[9., 9., 9., ..., 9., 9., 9.],
[9., 9., 9., ..., 9., 9., 9.]], dtype=float32)
Notice how we’ve set up the BlockSpec
s: we’re loading the entirety of the (512, 512)
dimension into VMEM (no pipelining there) but selecting the i
-th dimension of x
each iteration in the index_map
. We are using a None
for that dimension in the block shape, which indicates that we are selecting a singleton dimension from x
that we would like to squeeze away in the kernel. Therefore, x_ref
is (512, 512)
-shaped in VMEM as well.
out_spec
uses lambda i: (0, 0)
as its index_map
, indicating that o_ref
is unchanged over the course of the pipeline. This means that we can update its value each iteration by reading from and writing to it. Or can we? Actually there is one catch: o_ref
is initially garbage, meaning we’ll be accumulating into garbage. This will result in the overall function outputting the incorrect value!
Therefore, whenever we do a reduction in a kernel, we need to make sure to initialize the Ref
that is storing the reduced value. We can accomplish this by conditionally writing a value to out_ref
when we’re on iteration 0. We can do this with the helper function pl.when
, a convenience wrapper around jax.lax.cond
, and pl.program_id
, which queries which iteration in a grid axis we are in.
def sum_kernel(x_ref, o_ref):
@pl.when(pl.program_id(axis=0) == 0)
def _():
o_ref[...] = jnp.zeros_like(o_ref)
o_ref[...] += x_ref[...]
def sum(x: jax.Array) -> jax.Array:
grid, *out_shape = x.shape
return pl.pallas_call(
sum_kernel,
grid=grid,
# None in `block_shape` means we pick a size of 1 and squeeze it away
in_specs=[pl.BlockSpec(lambda i: (i, 0, 0), (None, *out_shape))],
out_specs=pl.BlockSpec(lambda i: (0, 0), out_shape),
out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)
)(x)
sum(x)
Array([[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
...,
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.]], dtype=float32)
This sum
function now outputs the correct values!
One last thing to note about reductions in Pallas is that they must be done in the minormost (rightmost) dimensions of our grid (our grid is 1-dimensional in the above example, so we are reducing over its minormost dimension). This is because the pipeline that Pallas generates using the BlockSpec
s, grid
and kernel function does not read outputs back from HBM. Once you’ve written an output value back to HBM you cannot revisit it, so a reduction cannot span a grid dimension whose output block gets revisited; all reductions need to happen in the rightmost grid dimensions.
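For example, here is an untested sketch of a sum kernel that also pipelines over rows of the output, building on the imports and the pl.when helper used above (the block size bm is illustrative and assumed to divide the row dimension). Because the output window depends only on the first grid axis, the reduction axis must come last:

```python
def sum_rows_kernel(x_ref, o_ref):
  # Initialize the accumulator on the first step of the reduction axis (axis 1).
  @pl.when(pl.program_id(axis=1) == 0)
  def _():
    o_ref[...] = jnp.zeros_like(o_ref)
  o_ref[...] += x_ref[...]

def sum_rows(x: jax.Array, *, bm: int = 256) -> jax.Array:
  r, m, n = x.shape
  return pl.pallas_call(
      sum_rows_kernel,
      # The output window only depends on `i`, so the reduction axis `j`
      # must be the last (minormost) grid axis.
      grid=(m // bm, r),
      in_specs=[pl.BlockSpec(lambda i, j: (j, i, 0), (None, bm, n))],
      out_specs=pl.BlockSpec(lambda i, j: (i, 0), (bm, n)),
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
  )(x)
```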
TPUs in Megacore configuration#
Some TPU chips have two TensorCores but appear as one device to JAX users. This is called “megacore”. The separate TensorCores have their own separate VMEM, VREGs, SMEM, SREGs and compute units but share HBM.
Conceptually, TPUs in Megacore behave like very simple GPUs, i.e. they have only two threads. How do we modify our kernels to utilize both TensorCores simultaneously?
The basic idea is that if we have embarrassingly parallel dimensions in our computation, we can split up those dimensions across the TensorCores. We can indicate which dimensions are parallelizable by providing an annotation to pallas_call
called dimension_semantics
.
def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:
block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512))
return pl.pallas_call(
add_matrices_kernel,
out_shape=x,
in_specs=[block_spec, block_spec],
out_specs=block_spec,
grid=(2,),
compiler_params=dict(mosaic=dict(dimension_semantics=("parallel",))))(
x, y)
x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices_pipelined_megacore(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
...,
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.]], dtype=float32)
dimension_semantics
should be a tuple of same length as grid
where each entry is either "parallel"
or "arbitrary"
. "parallel"
indicates to Pallas that the iterations of the for loop corresponding to that dimension can be executed independently without affecting the correctness of the program. "arbitrary"
indicates to Pallas that there can be no assumptions made about this grid dimension and it therefore cannot be parallelized.
By specifying dimension_semantics
, we now execute the kernel simultaneously on each TensorCore. Pallas will handle splitting up the grid automatically.
Note that Megacore is only currently available on TPU
v4
and TPU v5p
. Supplying dimension_semantics
annotations is a no-op on other platforms, but not specifying it will result in only one TensorCore being used (even if there are more than one available).
Conclusion#
In this guide we covered how to express TPU pipelines using pallas_call
, grid
and BlockSpec
s. We covered how to express nested loops via a multi-dimensional grid and how to handle reductions by initializing our accumulators at the beginning of the reduction. We also learned how to handle Megacore by adding annotations to the kernel.
Exercises left to the reader:
Try implementing a
sum
kernel that pipelines the other dimensions as wellAdd megacore support to the
add
kernel and thesum
kernel as well.
Advanced Tutorials#
This section contains examples and tutorials on more advanced topics, such as multi-core computation, custom operations, and more in-depth applications.
Copyright 2018 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
Training a Simple Neural Network, with tensorflow/datasets Data Loading#
Forked from neural_network_and_data_loading.ipynb
Let’s combine everything we showed in the quickstart to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use tensorflow/datasets
data loading API to load images and labels (because it’s pretty great, and the world doesn’t need yet another data loading library :P).
Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won’t use any neural network libraries or special APIs for building our model.
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
Hyperparameters#
Let’s get a few bookkeeping items out of the way.
# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
w_key, b_key = random.split(key)
return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))
# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
keys = random.split(key, len(sizes))
return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]
layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 10
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.key(0))
Auto-batching predictions#
Let us first define our prediction function. Note that we’re defining this for a single image example. We’re going to use JAX’s vmap
function to automatically handle mini-batches, with no performance penalty.
from jax.scipy.special import logsumexp
def relu(x):
return jnp.maximum(0, x)
def predict(params, image):
# per-example predictions
activations = image
for w, b in params[:-1]:
outputs = jnp.dot(w, activations) + b
activations = relu(outputs)
final_w, final_b = params[-1]
logits = jnp.dot(final_w, activations) + final_b
return logits - logsumexp(logits)
Let’s check that our prediction function only works on single images.
# This works on single examples
random_flattened_image = random.normal(random.key(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)
(10,)
# Doesn't work with a batch
random_flattened_images = random.normal(random.key(1), (10, 28 * 28))
try:
preds = predict(params, random_flattened_images)
except TypeError:
print('Invalid shapes!')
Invalid shapes!
# Let's upgrade it to handle batches using `vmap`
# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None, 0))
# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)
(10, 10)
At this point, we have all the ingredients we need to define our neural network and train it. We’ve built an auto-batched version of predict
, which we should be able to use in a loss function. We should be able to use grad
to take the derivative of the loss with respect to the neural network parameters. Last, we should be able to use jit
to speed up everything.
Utility and loss functions#
def one_hot(x, k, dtype=jnp.float32):
"""Create a one-hot encoding of x of size k."""
return jnp.array(x[:, None] == jnp.arange(k), dtype)
def accuracy(params, images, targets):
target_class = jnp.argmax(targets, axis=1)
predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
return jnp.mean(predicted_class == target_class)
def loss(params, images, targets):
preds = batched_predict(params, images)
return -jnp.mean(preds * targets)
@jit
def update(params, x, y):
grads = grad(loss)(params, x, y)
return [(w - step_size * dw, b - step_size * db)
for (w, b), (dw, db) in zip(params, grads)]
Data Loading with tensorflow/datasets
#
JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don’t include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let’s just use them instead of reinventing anything. We’ll use the tensorflow/datasets
data loader.
import tensorflow as tf
# Ensure TF does not see GPU and grab all GPU memory.
tf.config.set_visible_devices([], device_type='GPU')
import tensorflow_datasets as tfds
data_dir = '/tmp/tfds'
# Fetch full datasets for evaluation
# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
# You can convert them to NumPy arrays (or iterables of NumPy arrays) with tfds.dataset_as_numpy
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data['train'], mnist_data['test']
num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c
# Full train set
train_images, train_labels = train_data['image'], train_data['label']
train_images = jnp.reshape(train_images, (len(train_images), num_pixels))
train_labels = one_hot(train_labels, num_labels)
# Full test set
test_images, test_labels = test_data['image'], test_data['label']
test_images = jnp.reshape(test_images, (len(test_images), num_pixels))
test_labels = one_hot(test_labels, num_labels)
print('Train:', train_images.shape, train_labels.shape)
print('Test:', test_images.shape, test_labels.shape)
Train: (60000, 784) (60000, 10)
Test: (10000, 784) (10000, 10)
Training Loop#
import time
def get_train_batches():
# as_supervised=True gives us the (image, label) as a tuple instead of a dict
ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
# You can build up an arbitrary tf.data input pipeline
ds = ds.batch(batch_size).prefetch(1)
# tfds.dataset_as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
return tfds.as_numpy(ds)
for epoch in range(num_epochs):
start_time = time.time()
for x, y in get_train_batches():
x = jnp.reshape(x, (len(x), num_pixels))
y = one_hot(y, num_labels)
params = update(params, x, y)
epoch_time = time.time() - start_time
train_acc = accuracy(params, train_images, train_labels)
test_acc = accuracy(params, test_images, test_labels)
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
print("Training set accuracy {}".format(train_acc))
print("Test set accuracy {}".format(test_acc))
Epoch 0 in 28.30 sec
Training set accuracy 0.8400499820709229
Test set accuracy 0.8469000458717346
Epoch 1 in 14.74 sec
Training set accuracy 0.8743667006492615
Test set accuracy 0.8803000450134277
Epoch 2 in 14.57 sec
Training set accuracy 0.8901500105857849
Test set accuracy 0.8957000374794006
Epoch 3 in 14.36 sec
Training set accuracy 0.8991333246231079
Test set accuracy 0.903700053691864
Epoch 4 in 14.20 sec
Training set accuracy 0.9061833620071411
Test set accuracy 0.9087000489234924
Epoch 5 in 14.89 sec
Training set accuracy 0.9113333225250244
Test set accuracy 0.912600040435791
Epoch 6 in 13.95 sec
Training set accuracy 0.9156833291053772
Test set accuracy 0.9176000356674194
Epoch 7 in 13.32 sec
Training set accuracy 0.9192000031471252
Test set accuracy 0.9214000701904297
Epoch 8 in 13.55 sec
Training set accuracy 0.9222500324249268
Test set accuracy 0.9241000413894653
Epoch 9 in 13.40 sec
Training set accuracy 0.9253666996955872
Test set accuracy 0.9269000291824341
We’ve now used most of the JAX API: grad
for derivatives, jit
for speedups and vmap
for auto-vectorization.
We used NumPy to specify all of our computation, and borrowed the great data loaders from tensorflow/datasets
, and ran the whole thing on the GPU.
Training a Simple Neural Network, with PyTorch Data Loading#
Copyright 2018 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
Let’s combine everything we showed in the quickstart to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch’s data loading API to load images and labels (because it’s pretty great, and the world doesn’t need yet another data loading library).
Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won’t use any neural network libraries or special APIs for building our model.
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
Hyperparameters#
Let’s get a few bookkeeping items out of the way.
# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
w_key, b_key = random.split(key)
return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))
# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
keys = random.split(key, len(sizes))
return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]
layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 8
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.key(0))
Auto-batching predictions#
Let us first define our prediction function. Note that we’re defining this for a single image example. We’re going to use JAX’s vmap
function to automatically handle mini-batches, with no performance penalty.
from jax.scipy.special import logsumexp
def relu(x):
return jnp.maximum(0, x)
def predict(params, image):
# per-example predictions
activations = image
for w, b in params[:-1]:
outputs = jnp.dot(w, activations) + b
activations = relu(outputs)
final_w, final_b = params[-1]
logits = jnp.dot(final_w, activations) + final_b
return logits - logsumexp(logits)
Let’s check that our prediction function only works on single images.
# This works on single examples
random_flattened_image = random.normal(random.key(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)
(10,)
# Doesn't work with a batch
random_flattened_images = random.normal(random.key(1), (10, 28 * 28))
try:
preds = predict(params, random_flattened_images)
except TypeError:
print('Invalid shapes!')
Invalid shapes!
# Let's upgrade it to handle batches using `vmap`
# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None, 0))
# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)
(10, 10)
At this point, we have all the ingredients we need to define our neural network and train it. We’ve built an auto-batched version of predict
, which we should be able to use in a loss function. We should be able to use grad
to take the derivative of the loss with respect to the neural network parameters. Last, we should be able to use jit
to speed up everything.
Utility and loss functions#
def one_hot(x, k, dtype=jnp.float32):
"""Create a one-hot encoding of x of size k."""
return jnp.array(x[:, None] == jnp.arange(k), dtype)
def accuracy(params, images, targets):
target_class = jnp.argmax(targets, axis=1)
predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
return jnp.mean(predicted_class == target_class)
def loss(params, images, targets):
preds = batched_predict(params, images)
return -jnp.mean(preds * targets)
@jit
def update(params, x, y):
grads = grad(loss)(params, x, y)
return [(w - step_size * dw, b - step_size * db)
for (w, b), (dw, db) in zip(params, grads)]
Data Loading with PyTorch#
JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don’t include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let’s just use them instead of reinventing anything. We’ll grab PyTorch’s data loader, and make a tiny shim to make it work with NumPy arrays.
!pip install torch torchvision
Requirement already satisfied: torch in /opt/anaconda3/lib/python3.7/site-packages (1.4.0)
Requirement already satisfied: torchvision in /opt/anaconda3/lib/python3.7/site-packages (0.5.0)
Requirement already satisfied: numpy in /opt/anaconda3/lib/python3.7/site-packages (from torchvision) (1.17.2)
Requirement already satisfied: six in /opt/anaconda3/lib/python3.7/site-packages (from torchvision) (1.12.0)
Requirement already satisfied: pillow>=4.1.1 in /opt/anaconda3/lib/python3.7/site-packages (from torchvision) (6.2.0)
import numpy as np
from jax.tree_util import tree_map
from torch.utils import data
from torchvision.datasets import MNIST
def numpy_collate(batch):
return tree_map(np.asarray, data.default_collate(batch))
class NumpyLoader(data.DataLoader):
def __init__(self, dataset, batch_size=1,
shuffle=False, sampler=None,
batch_sampler=None, num_workers=0,
pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None):
super(self.__class__, self).__init__(dataset,
batch_size=batch_size,
shuffle=shuffle,
sampler=sampler,
batch_sampler=batch_sampler,
num_workers=num_workers,
collate_fn=numpy_collate,
pin_memory=pin_memory,
drop_last=drop_last,
timeout=timeout,
worker_init_fn=worker_init_fn)
class FlattenAndCast(object):
def __call__(self, pic):
return np.ravel(np.array(pic, dtype=jnp.float32))
# Define our dataset, using torch datasets
mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz
Extracting /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting /tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw
Processing...
Done!
# Get the full train dataset (for checking accuracy while training)
train_images = np.array(mnist_dataset.train_data).reshape(len(mnist_dataset.train_data), -1)
train_labels = one_hot(np.array(mnist_dataset.train_labels), n_targets)
# Get full test dataset
mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False)
test_images = jnp.array(mnist_dataset_test.test_data.numpy().reshape(len(mnist_dataset_test.test_data), -1), dtype=jnp.float32)
test_labels = one_hot(np.array(mnist_dataset_test.test_labels), n_targets)
/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:55: UserWarning: train_data has been renamed data
warnings.warn("train_data has been renamed data")
/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:45: UserWarning: train_labels has been renamed targets
warnings.warn("train_labels has been renamed targets")
/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:60: UserWarning: test_data has been renamed data
warnings.warn("test_data has been renamed data")
/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:50: UserWarning: test_labels has been renamed targets
warnings.warn("test_labels has been renamed targets")
Training Loop#
import time
for epoch in range(num_epochs):
start_time = time.time()
for x, y in training_generator:
y = one_hot(y, n_targets)
params = update(params, x, y)
epoch_time = time.time() - start_time
train_acc = accuracy(params, train_images, train_labels)
test_acc = accuracy(params, test_images, test_labels)
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
print("Training set accuracy {}".format(train_acc))
print("Test set accuracy {}".format(test_acc))
Epoch 0 in 55.15 sec
Training set accuracy 0.9157500267028809
Test set accuracy 0.9195000529289246
Epoch 1 in 42.26 sec
Training set accuracy 0.9372166991233826
Test set accuracy 0.9384000301361084
Epoch 2 in 44.37 sec
Training set accuracy 0.9491666555404663
Test set accuracy 0.9469000697135925
Epoch 3 in 41.75 sec
Training set accuracy 0.9568166732788086
Test set accuracy 0.9534000158309937
Epoch 4 in 41.16 sec
Training set accuracy 0.9631333351135254
Test set accuracy 0.9577000737190247
Epoch 5 in 38.89 sec
Training set accuracy 0.9675000309944153
Test set accuracy 0.9616000652313232
Epoch 6 in 40.68 sec
Training set accuracy 0.9708333611488342
Test set accuracy 0.9650000333786011
Epoch 7 in 41.50 sec
Training set accuracy 0.973716676235199
Test set accuracy 0.9672000408172607
We’ve now used the whole of the JAX API: grad
for derivatives, jit
for speedups and vmap
for auto-vectorization.
We used NumPy to specify all of our computation, and borrowed the great data loaders from PyTorch, and ran the whole thing on the GPU.
Autobatching for Bayesian Inference#
This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.
Inspired by a notebook by @davmre.
import functools
import itertools
import re
import sys
import time
from matplotlib.pyplot import *
import jax
from jax import lax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import random
import numpy as np
import scipy as sp
Generate a fake binary classification dataset#
np.random.seed(10009)
num_features = 10
num_points = 100
true_beta = np.random.randn(num_features).astype(jnp.float32)
all_x = np.random.randn(num_points, num_features).astype(jnp.float32)
y = (np.random.rand(num_points) < sp.special.expit(all_x.dot(true_beta))).astype(jnp.int32)
y
array([0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0,
1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0,
1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1,
1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0], dtype=int32)
Write the log-joint function for the model#
We’ll write a non-batched version, a manually batched version, and an autobatched version.
Non-batched#
def log_joint(beta):
result = 0.
# Note that no `axis` parameter is provided to `jnp.sum`.
result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.))
result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta))))
return result
log_joint(np.random.randn(num_features))
Array(-213.2356, dtype=float32)
# This doesn't work, because we didn't write `log_prob()` to handle batching.
try:
batch_size = 10
batched_test_beta = np.random.randn(batch_size, num_features)
log_joint(np.random.randn(batch_size, num_features))
except ValueError as e:
print("Caught expected exception " + str(e))
Caught expected exception Incompatible shapes for broadcasting: shapes=[(100,), (100, 10)]
Manually batched#
def batched_log_joint(beta):
result = 0.
# Here (and below) `sum` needs an `axis` parameter. At best, forgetting to set axis
# or setting it incorrectly yields an error; at worst, it silently changes the
# semantics of the model.
result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.),
axis=-1)
# Note the multiple transposes. Getting this right is not rocket science,
# but it's also not totally mindless. (I didn't get it right on the first
# try.)
result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta.T).T)),
axis=-1)
return result
batch_size = 10
batched_test_beta = np.random.randn(batch_size, num_features)
batched_log_joint(batched_test_beta)
Array([-147.84033 , -207.02205 , -109.26075 , -243.80833 , -163.0291 ,
-143.84848 , -160.28773 , -113.771706, -126.60544 , -190.81992 ], dtype=float32)
Autobatched with vmap#
It just works.
vmap_batched_log_joint = jax.vmap(log_joint)
vmap_batched_log_joint(batched_test_beta)
Array([-147.84033 , -207.02205 , -109.26075 , -243.80833 , -163.0291 ,
-143.84848 , -160.28773 , -113.771706, -126.60544 , -190.81992 ], dtype=float32)
Self-contained variational inference example#
A little code is copied from above.
Set up the (batched) log-joint function#
@jax.jit
def log_joint(beta):
result = 0.
# Note that no `axis` parameter is provided to `jnp.sum`.
result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=10.))
result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta))))
return result
batched_log_joint = jax.jit(jax.vmap(log_joint))
Define the ELBO and its gradient#
def elbo(beta_loc, beta_log_scale, epsilon):
beta_sample = beta_loc + jnp.exp(beta_log_scale) * epsilon
return jnp.mean(batched_log_joint(beta_sample), 0) + jnp.sum(beta_log_scale - 0.5 * np.log(2*np.pi))
elbo = jax.jit(elbo)
elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))
Optimize the ELBO using SGD#
def normal_sample(key, shape):
"""Convenience function for quasi-stateful RNG."""
new_key, sub_key = random.split(key)
return new_key, random.normal(sub_key, shape)
normal_sample = jax.jit(normal_sample, static_argnums=(1,))
key = random.key(10003)
beta_loc = jnp.zeros(num_features, jnp.float32)
beta_log_scale = jnp.zeros(num_features, jnp.float32)
step_size = 0.01
batch_size = 128
epsilon_shape = (batch_size, num_features)
for i in range(1000):
key, epsilon = normal_sample(key, epsilon_shape)
elbo_val, (beta_loc_grad, beta_log_scale_grad) = elbo_val_and_grad(
beta_loc, beta_log_scale, epsilon)
beta_loc += step_size * beta_loc_grad
beta_log_scale += step_size * beta_log_scale_grad
if i % 10 == 0:
print('{}\t{}'.format(i, elbo_val))
0 -180.8538818359375
10 -113.06045532226562
20 -102.73727416992188
30 -99.787353515625
40 -98.90898132324219
50 -98.29745483398438
60 -98.18632507324219
70 -97.57972717285156
80 -97.28599548339844
90 -97.46996307373047
100 -97.4771728515625
110 -97.5806655883789
120 -97.4943618774414
130 -97.50271606445312
140 -96.86396026611328
150 -97.44197845458984
160 -97.06941223144531
170 -96.84028625488281
180 -97.21336364746094
190 -97.56503295898438
200 -97.26397705078125
210 -97.11979675292969
220 -97.39595031738281
230 -97.16831970214844
240 -97.118408203125
250 -97.24345397949219
260 -97.29788970947266
270 -96.69286346435547
280 -96.96438598632812
290 -97.30055236816406
300 -96.63591766357422
310 -97.0351791381836
320 -97.52909088134766
330 -97.28811645507812
340 -97.07321166992188
350 -97.15619659423828
360 -97.25881958007812
370 -97.19515228271484
380 -97.13092041015625
390 -97.11726379394531
400 -96.938720703125
410 -97.26676940917969
420 -97.35322570800781
430 -97.21007537841797
440 -97.28434753417969
450 -97.1630859375
460 -97.2612533569336
470 -97.21343994140625
480 -97.23997497558594
490 -97.14913940429688
500 -97.23527526855469
510 -96.93419647216797
520 -97.21209716796875
530 -96.82575988769531
540 -97.01284790039062
550 -96.94175720214844
560 -97.16520690917969
570 -97.29165649414062
580 -97.42941284179688
590 -97.24370574951172
600 -97.15222930908203
610 -97.49844360351562
620 -96.9906997680664
630 -96.88956451416016
640 -96.89968872070312
650 -97.13793182373047
660 -97.43705749511719
670 -96.99235534667969
680 -97.15623474121094
690 -97.1869125366211
700 -97.11160278320312
710 -97.78105163574219
720 -97.23226165771484
730 -97.16206359863281
740 -96.99581909179688
750 -96.6672134399414
760 -97.16795349121094
770 -97.51435089111328
780 -97.28900146484375
790 -96.91226196289062
800 -97.17100524902344
810 -97.29047393798828
820 -97.16242980957031
830 -97.19107055664062
840 -97.56382751464844
850 -97.00194549560547
860 -96.86555480957031
870 -96.76338195800781
880 -96.83660888671875
890 -97.12178039550781
900 -97.09554290771484
910 -97.0682373046875
920 -97.11947631835938
930 -96.87930297851562
940 -97.45624542236328
950 -96.69279479980469
960 -97.29376220703125
970 -97.3353042602539
980 -97.34962463378906
990 -97.09675598144531
Display the results#
Coverage isn’t quite as good as we might like, but it’s not bad, and nobody said variational inference was exact.
figure(figsize=(7, 7))
plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')
plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label='Approximated Posterior $2\sigma$ Error Bars')
plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')
plot_scale = 3
plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')
xlabel('True beta')
ylabel('Estimated beta')
legend(loc='best')
<matplotlib.legend.Legend at 0x7fd32443e410>
[Figure: estimated beta vs. true beta, showing the approximated posterior means and 2σ error bars around the y = x line.]
Using JAX in multi-host and multi-process environments#
Introduction#
This guide explains how to use JAX in environments such as GPU clusters and Cloud TPU pods where accelerators are spread across multiple CPU hosts or JAX processes. We’ll refer to these as “multi-process” environments.
This guide specifically focuses on how to use collective communication
operations (e.g. jax.lax.psum()
) in multi-process settings, although
other communication methods may be useful too depending on your use case (e.g.
RPC, mpi4jax). If you’re not already
familiar with JAX’s collective operations, we recommend starting with the
Introduction to sharded computation section. An important requirement of
multi-process environments in JAX is direct communication links between
accelerators, e.g. the high-speed interconnects for Cloud TPUs or
NCCL for GPUs. These links allow
collective operations to run across multiple processes’ worth of accelerators
with high performance.
Multi-process programming model#
Key concepts:
You must run at least one JAX process per host.
You should initialize the cluster with
jax.distributed.initialize()
.Each process has a distinct set of local devices it can address. The global devices are the set of all devices across all processes.
Use standard JAX parallelism APIs like
pmap()
andxmap()
. Each process “sees” local input and output to parallelized functions, but communication inside the computations is global.Make sure all processes run the same parallel computations in the same order.
Launching JAX processes#
Unlike other distributed systems where a single controller node manages many worker nodes, JAX uses a “multi-controller” programming model where each JAX Python process runs independently, sometimes referred to as a Single Program, Multiple Data (SPMD) model. Generally, the same JAX Python program is run in each process, with only slight differences between each process’s execution (e.g. different processes will load different input data). Furthermore, you must manually run your JAX program on each host! JAX doesn’t automatically start multiple processes from a single program invocation.
(The requirement for multiple processes is why this guide isn’t offered as a notebook – we don’t currently have a good way to manage multiple Python processes from a single notebook.)
Initializing the cluster#
To initialize the cluster, you should call jax.distributed.initialize()
at
the start of each process. jax.distributed.initialize()
must be called
early in the program, before any JAX computations are executed.
The API jax.distributed.initialize()
takes several arguments, namely:
coordinator_address
: the IP address of process 0 in your cluster, together with a port available on that process. Process 0 will start a JAX service exposed via that IP address and port, to which the other processes in the cluster will connect.coordinator_bind_address
: the IP address and port to which the JAX service on process 0 in your cluster will bind. By default, it will bind to all available interfaces using the same port ascoordinator_address
.num_processes
: the number of processes in the clusterprocess_id
: the ID number of this process, in the range[0 .. num_processes)
.local_device_ids
: Restricts the visible devices of the current process tolocal_device_ids
.
For example on GPU, a typical usage is:
import jax
jax.distributed.initialize(coordinator_address="192.168.0.1:1234",
num_processes=2,
process_id=0)
On Cloud TPU, Slurm and Open MPI environments, you can simply call jax.distributed.initialize()
with no
arguments. Default values for the arguments will be chosen automatically.
When running on GPUs with Slurm and Open MPI, it is assumed that one process is started per GPU, i.e. each process will
be assigned only one visible local device. Otherwise it is assumed that one process is started per host,
i.e. each process will be assigned all local devices.
The Open MPI auto-initialization is only used when the JAX processes are launched via mpirun
/mpiexec
.
import jax
jax.distributed.initialize()
On TPU at present calling jax.distributed.initialize()
is optional, but
recommended since it enables additional checkpointing and health checking features.
Local vs. global devices#
Before we get to running multi-process computations from your program, it’s important to understand the distinction between local and global devices.
A process’s local devices are those that it can directly address and launch
computations on. For example, on a GPU cluster, each host can only launch
computations on the directly attached GPUs. On a Cloud TPU pod, each host can
only launch computations on the 8 TPU cores attached directly to that host (see
the
Cloud TPU System Architecture
documentation for more details). You can see a process’s local devices via
jax.local_devices()
.
The global devices are the devices across all processes. A computation can
span devices across processes and perform collective operations via the direct
communication links between devices, as long as each process launches the
computation on its local devices. You can see all available global devices via
jax.devices()
. A process’s local devices are always a subset of the
global devices.
Running multi-process computations#
So how do you actually run a computation involving cross-process communication? Use the same parallel evaluation APIs that you would in a single process!
For example, pmap() or shard_map() can be used to run a parallel computation across multiple processes. (If you’re not already familiar with running across multiple devices within a single process, check out the Introduction to sharded computation tutorial.) Each process should call the same mapped function and pass in arguments to be mapped across its local devices (i.e., the mapped axis size is equal to the number of local devices). Similarly, the function will return outputs sharded across local devices only.
Inside the function, however, collective communication operations are run across
all global devices, across all processes. Conceptually, this can be thought of
as running a pmap over a single array sharded across hosts, where each host
“sees” only its local shard of the input and output.
Here’s an example of multi-process pmap in action:
# The following is run in parallel on each host on a GPU cluster or TPU pod slice.
>>> import jax
>>> jax.distributed.initialize() # On GPU, see above for the necessary arguments.
>>> jax.device_count() # total number of accelerator devices in the cluster
32
>>> jax.local_device_count() # number of accelerator devices attached to this host
8
# The psum is performed over all mapped devices across the pod slice
>>> xs = jax.numpy.ones(jax.local_device_count())
>>> jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
ShardedDeviceArray([32., 32., 32., 32., 32., 32., 32., 32.], dtype=float32)
xmap()
works similarly when using a physical
hardware mesh (see the xmap tutorial if you’re
not familiar with the single-process version). Like pmap()
, the
inputs and outputs are local and any parallel communication inside the xmapped
function is global. The mesh is also global.
It’s very important that all processes run the same cross-process computations in the same order. Running the same JAX Python program in each process is usually sufficient. Some common pitfalls to look out for that may cause differently-ordered computations despite running the same program:
Processes passing differently-shaped inputs to the same parallel function can cause hangs or incorrect return values. Differently-shaped inputs are safe so long as they result in identically-shaped per-device data shards across processes; e.g. passing in different leading batch sizes in order to run on different numbers of local devices per process is ok, but having each process pad its batch to a different max example length is not.
“Last batch” issues where a parallel function is called in a (training) loop, and one or more processes exit the loop earlier than the rest. This will cause the rest to hang waiting for the already-finished processes to start the computation.
Conditions based on non-deterministic ordering of collections can cause processes to hang. For example, iterating over a set on current Python versions, or over a dict before Python 3.7, may result in a different ordering on different processes, even with the same insertion order.
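For instance, one simple workaround for that last pitfall is to iterate in an explicitly deterministic order; a minimal sketch (the collection and its contents are hypothetical):
# Iterating a set (or a pre-3.7 dict) directly may visit elements in different
# orders on different processes; sorting first gives every process the same order.
names = {"layer2", "layer0", "layer1"}   # hypothetical collection built on each process
for name in sorted(names):               # deterministic, identical on every process
    print(name)                          # launch any cross-process work here, in order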
Distributed arrays and automatic parallelization#
This tutorial discusses parallelism via jax.Array
, the unified array object model available in JAX v0.4.1 and newer.
import os
import functools
from typing import Optional
import numpy as np
import jax
import jax.numpy as jnp
⚠️ WARNING: The notebook requires 8 devices to run.
if len(jax.local_devices()) < 8:
raise Exception("Notebook requires 8 devices to run")
Intro and a quick example#
By reading this tutorial notebook, you’ll learn about jax.Array
, a unified
datatype for representing arrays, even with physical storage spanning multiple
devices. You’ll also learn about how using jax.Array
s together with jax.jit
can provide automatic compiler-based parallelization.
Before we think step by step, here’s a quick example.
First, we’ll create a jax.Array
sharded across multiple devices:
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding
# Create a Sharding object to distribute a value across devices:
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
# Create an array of random values:
x = jax.random.normal(jax.random.key(0), (8192, 8192))
# and use jax.device_put to distribute it across devices:
y = jax.device_put(x, sharding.reshape(4, 2))
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐
│ TPU 0 │ TPU 1 │
├──────────┼──────────┤
│ TPU 2 │ TPU 3 │
├──────────┼──────────┤
│ TPU 6 │ TPU 7 │
├──────────┼──────────┤
│ TPU 4 │ TPU 5 │
└──────────┴──────────┘
Next, we’ll apply a computation to it and visualize how the result values are stored across multiple devices too:
z = jnp.sin(y)
jax.debug.visualize_array_sharding(z)
┌──────────┬──────────┐
│ TPU 0 │ TPU 1 │
├──────────┼──────────┤
│ TPU 2 │ TPU 3 │
├──────────┼──────────┤
│ TPU 6 │ TPU 7 │
├──────────┼──────────┤
│ TPU 4 │ TPU 5 │
└──────────┴──────────┘
The evaluation of the jnp.sin
application was automatically parallelized
across the devices on which the input values (and output values) are stored:
# `x` is present on a single device
%timeit -n 5 -r 5 jnp.sin(x).block_until_ready()
The slowest run took 13.32 times longer than the fastest. This could mean that an intermediate result is being cached
5 loops, best of 5: 9.69 ms per loop
# `y` is sharded across 8 devices.
%timeit -n 5 -r 5 jnp.sin(y).block_until_ready()
5 loops, best of 5: 1.86 ms per loop
Now let’s look at each of these pieces in more detail!
Computation follows data sharding and is automatically parallelized#
With sharded input data, the compiler can give us parallel computation. In particular, functions decorated with jax.jit
can operate over sharded arrays without copying data onto a single device. Instead, computation follows sharding: based on the sharding of the input data, the compiler decides shardings for intermediates and output values, and parallelizes their evaluation, even inserting communication operations as necessary.
For example, the simplest computation is an elementwise one:
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
x = jax.device_put(x, sharding.reshape(4, 2))
print('input sharding:')
jax.debug.visualize_array_sharding(x)
y = jnp.sin(x)
print('output sharding:')
jax.debug.visualize_array_sharding(y)
input sharding:
┌──────────┬──────────┐
│ TPU 0 │ TPU 1 │
├──────────┼──────────┤
│ TPU 2 │ TPU 3 │
├──────────┼──────────┤
│ TPU 6 │ TPU 7 │
├──────────┼──────────┤
│ TPU 4 │ TPU 5 │
└──────────┴──────────┘
output sharding:
┌──────────┬──────────┐
│ TPU 0 │ TPU 1 │
├──────────┼──────────┤
│ TPU 2 │ TPU 3 │
├──────────┼──────────┤
│ TPU 6 │ TPU 7 │
├──────────┼──────────┤
│ TPU 4 │ TPU 5 │
└──────────┴──────────┘
Here for the elementwise operation jnp.sin
the compiler chose the output sharding to be the same as the input. Moreover, the compiler automatically parallelized the computation, so that each device computed its output shard from its input shard in parallel.
In other words, even though we wrote the jnp.sin
computation as if a single machine were to execute it, the compiler splits up the computation for us and executes it on multiple devices.
We can do the same for more than just elementwise operations too. Consider a matrix multiplication with sharded inputs:
y = jax.device_put(x, sharding.reshape(4, 2).replicate(1))
z = jax.device_put(x, sharding.reshape(4, 2).replicate(0))
print('lhs sharding:')
jax.debug.visualize_array_sharding(y)
print('rhs sharding:')
jax.debug.visualize_array_sharding(z)
w = jnp.dot(y, z)
print('out sharding:')
jax.debug.visualize_array_sharding(w)
lhs sharding:
┌───────────────────────┐
│ TPU 0,1 │
├───────────────────────┤
│ TPU 2,3 │
├───────────────────────┤
│ TPU 6,7 │
├───────────────────────┤
│ TPU 4,5 │
└───────────────────────┘
rhs sharding:
┌───────────┬───────────┐
│ │ │
│ │ │
│ │ │
│ │ │
│TPU 0,2,4,6│TPU 1,3,5,7│
│ │ │
│ │ │
│ │ │
│ │ │
└───────────┴───────────┘
out sharding:
┌──────────┬──────────┐
│ TPU 0 │ TPU 1 │
├──────────┼──────────┤
│ TPU 2 │ TPU 3 │
├──────────┼──────────┤
│ TPU 6 │ TPU 7 │
├──────────┼──────────┤
│ TPU 4 │ TPU 5 │
└──────────┴──────────┘
Here the compiler chose the output sharding so that it could maximally parallelize the computation: without needing communication, each device already has the input shards it needs to compute its output shard.
How can we be sure it’s actually running in parallel? We can do a simple timing experiment:
x_single = jax.device_put(x, jax.devices()[0])
jax.debug.visualize_array_sharding(x_single)
┌───────────────────────┐
│ │
│ │
│ │
│ │
│ TPU 0 │
│ │
│ │
│ │
│ │
└───────────────────────┘
np.allclose(jnp.dot(x_single, x_single),
jnp.dot(y, z))
True
%timeit -n 5 -r 5 jnp.dot(x_single, x_single).block_until_ready()
5 loops, best of 5: 19.3 ms per loop
%timeit -n 5 -r 5 jnp.dot(y, z).block_until_ready()
5 loops, best of 5: 3.25 ms per loop
Even copying a sharded Array
produces a result with the sharding of the input:
w_copy = jnp.copy(w)
jax.debug.visualize_array_sharding(w_copy)
┌──────────┬──────────┐
│ TPU 0 │ TPU 1 │
├──────────┼──────────┤
│ TPU 2 │ TPU 3 │
├──────────┼──────────┤
│ TPU 6 │ TPU 7 │
├──────────┼──────────┤
│ TPU 4 │ TPU 5 │
└──────────┴──────────┘
So computation follows data placement: when we explicitly shard data with jax.device_put
, and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding. This policy for sharded data is a generalization of JAX’s policy of following explicit device placement.
When explicit shardings disagree, JAX errors#
But what if two arguments to a computation are explicitly placed on different sets of devices, or with incompatible device orders? In these ambiguous cases, an error is raised:
import textwrap
from termcolor import colored
def print_exception(e):
name = colored(f'{type(e).__name__}', 'red')
print(textwrap.fill(f'{name}: {str(e)}'))
sharding1 = PositionalSharding(jax.devices()[:4])
sharding2 = PositionalSharding(jax.devices()[4:])
y = jax.device_put(x, sharding1.reshape(2, 2))
z = jax.device_put(x, sharding2.reshape(2, 2))
try: y + z
except ValueError as e: print_exception(e)
ValueError: Devices of all `Array` inputs and outputs should
be the same. Got array device ids [0, 1, 2, 3] on platform TPU and
another array's device ids [4, 5, 6, 7] on platform TPU
devices = jax.devices()
permuted_devices = [devices[i] for i in [0, 1, 2, 3, 6, 7, 4, 5]]
sharding1 = PositionalSharding(devices)
sharding2 = PositionalSharding(permuted_devices)
y = jax.device_put(x, sharding1.reshape(4, 2))
z = jax.device_put(x, sharding2.reshape(4, 2))
try: y + z
except ValueError as e: print_exception(e)
ValueError: Devices of all `Array` inputs and outputs should
be the same. Got array device ids [0, 1, 2, 3, 4, 5, 6, 7] on platform
TPU and another array's device ids [0, 1, 2, 3, 6, 7, 4, 5] on
platform TPU
We say arrays that have been explicitly placed or sharded with jax.device_put
are committed to their device(s), and so won’t be automatically moved. See the device placement FAQ for more information.
When arrays are not explicitly placed or sharded with jax.device_put
, they are placed uncommitted on the default device.
Unlike committed arrays, uncommitted arrays can be moved and resharded automatically: that is, uncommitted arrays can be arguments to a computation even if other arguments are explicitly placed on different devices.
For example, the output of jnp.zeros
, jnp.arange
, and jnp.array
are uncommitted:
y = jax.device_put(x, sharding1.reshape(4, 2))
y + jnp.ones_like(y)
y + jnp.arange(y.size).reshape(y.shape)
print('no error!')
no error!
Constraining shardings of intermediates in jitted code#
While the compiler will attempt to decide how a function’s intermediate values and outputs should be sharded, we can also give it hints using jax.lax.with_sharding_constraint
. Using jax.lax.with_sharding_constraint
is much like jax.device_put
, except we use it inside staged-out (i.e. jit
-decorated) functions:
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
x = jax.random.normal(jax.random.key(0), (8192, 8192))
x = jax.device_put(x, sharding.reshape(4, 2))
@jax.jit
def f(x):
x = x + 1
y = jax.lax.with_sharding_constraint(x, sharding.reshape(2, 4))
return y
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐
│ TPU 0 │ TPU 1 │
├──────────┼──────────┤
│ TPU 2 │ TPU 3 │
├──────────┼──────────┤
│ TPU 6 │ TPU 7 │
├──────────┼──────────┤
│ TPU 4 │ TPU 5 │
└──────────┴──────────┘
┌───────┬───────┬───────┬───────┐
│ │ │ │ │
│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │
│ │ │ │ │
│ │ │ │ │
├───────┼───────┼───────┼───────┤
│ │ │ │ │
│ TPU 6 │ TPU 7 │ TPU 4 │ TPU 5 │
│ │ │ │ │
│ │ │ │ │
└───────┴───────┴───────┴───────┘
@jax.jit
def f(x):
x = x + 1
y = jax.lax.with_sharding_constraint(x, sharding.replicate())
return y
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐
│ TPU 0 │ TPU 1 │
├──────────┼──────────┤
│ TPU 2 │ TPU 3 │
├──────────┼──────────┤
│ TPU 6 │ TPU 7 │
├──────────┼──────────┤
│ TPU 4 │ TPU 5 │
└──────────┴──────────┘
┌───────────────────────┐
│ │
│ │
│ │
│ │
│ TPU 0,1,2,3,4,5,6,7 │
│ │
│ │
│ │
│ │
└───────────────────────┘
By adding with_sharding_constraint
, we’ve constrained the sharding of the output. In addition to respecting the annotation on a particular intermediate, the compiler will use annotations to decide shardings for other values.
It’s often a good practice to annotate the outputs of computations, for example based on how the values are ultimately consumed.
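Relatedly, jax.jit itself accepts in_shardings and out_shardings arguments, which can play the same role for a jitted function’s inputs and outputs; a minimal sketch, assuming the sharding and x defined above:
f_constrained = jax.jit(lambda a: a + 1, out_shardings=sharding.reshape(2, 4))
y = f_constrained(x)
jax.debug.visualize_array_sharding(y)   # output laid out on a 2x4 grid of devices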
Examples: neural networks#
⚠️ WARNING: The following is meant to be a simple demonstration of automatic sharding propagation with jax.Array
, but it may not reflect best practices for real examples. For instance, real examples may require more use of with_sharding_constraint
.
We can use jax.device_put
and jax.jit
’s computation-follows-sharding features to parallelize computation in neural networks. Here are some simple examples, based on this basic neural network:
import jax
import jax.numpy as jnp
def predict(params, inputs):
for W, b in params:
outputs = jnp.dot(inputs, W) + b
inputs = jnp.maximum(outputs, 0)
return outputs
def loss(params, batch):
inputs, targets = batch
predictions = predict(params, inputs)
return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
loss_jit = jax.jit(loss)
gradfun = jax.jit(jax.grad(loss))
def init_layer(key, n_in, n_out):
k1, k2 = jax.random.split(key)
W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
b = jax.random.normal(k2, (n_out,))
return W, b
def init_model(key, layer_sizes, batch_size):
key, *keys = jax.random.split(key, len(layer_sizes))
params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))
key, *keys = jax.random.split(key, 3)
inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))
return params, (inputs, targets)
layer_sizes = [784, 8192, 8192, 8192, 10]
batch_size = 8192
params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)
8-way batch data parallelism#
sharding = PositionalSharding(jax.devices()).reshape(8, 1)
batch = jax.device_put(batch, sharding)
params = jax.device_put(params, sharding.replicate())
loss_jit(params, batch)
Array(23.469475, dtype=float32)
step_size = 1e-5
for _ in range(30):
grads = gradfun(params, batch)
params = [(W - step_size * dW, b - step_size * db)
for (W, b), (dW, db) in zip(params, grads)]
print(loss_jit(params, batch))
10.760101
%timeit -n 5 -r 5 gradfun(params, batch)[0][0].block_until_ready()
5 loops, best of 5: 26.3 ms per loop
batch_single = jax.device_put(batch, jax.devices()[0])
params_single = jax.device_put(params, jax.devices()[0])
%timeit -n 5 -r 5 gradfun(params_single, batch_single)[0][0].block_until_ready()
5 loops, best of 5: 122 ms per loop
4-way batch data parallelism and 2-way model tensor parallelism#
sharding = sharding.reshape(4, 2)
batch = jax.device_put(batch, sharding.replicate(1))
jax.debug.visualize_array_sharding(batch[0])
jax.debug.visualize_array_sharding(batch[1])
┌───────┐
│TPU 0,1│
├───────┤
│TPU 2,3│
├───────┤
│TPU 4,5│
├───────┤
│TPU 6,7│
└───────┘
┌───────┐
│TPU 0,1│
├───────┤
│TPU 2,3│
├───────┤
│TPU 4,5│
├───────┤
│TPU 6,7│
└───────┘
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params
W1 = jax.device_put(W1, sharding.replicate())
b1 = jax.device_put(b1, sharding.replicate())
W2 = jax.device_put(W2, sharding.replicate(0))
b2 = jax.device_put(b2, sharding.replicate(0))
W3 = jax.device_put(W3, sharding.replicate(0).T)
b3 = jax.device_put(b3, sharding.replicate())
W4 = jax.device_put(W4, sharding.replicate())
b4 = jax.device_put(b4, sharding.replicate())
params = (W1, b1), (W2, b2), (W3, b3), (W4, b4)
jax.debug.visualize_array_sharding(W2)
┌───────────┬───────────┐
│ │ │
│ │ │
│ │ │
│ │ │
│TPU 0,2,4,6│TPU 1,3,5,7│
│ │ │
│ │ │
│ │ │
│ │ │
└───────────┴───────────┘
jax.debug.visualize_array_sharding(W3)
┌───────────────────────┐
│ │
│ TPU 0,2,4,6 │
│ │
│ │
├───────────────────────┤
│ │
│ TPU 1,3,5,7 │
│ │
│ │
└───────────────────────┘
print(loss_jit(params, batch))
10.760103
step_size = 1e-5
for _ in range(30):
grads = gradfun(params, batch)
params = [(W - step_size * dW, b - step_size * db)
for (W, b), (dW, db) in zip(params, grads)]
print(loss_jit(params, batch))
10.752466
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params
jax.debug.visualize_array_sharding(W2)
jax.debug.visualize_array_sharding(W3)
┌───────────┬───────────┐
│ │ │
│ │ │
│ │ │
│ │ │
│TPU 0,2,4,6│TPU 1,3,5,7│
│ │ │
│ │ │
│ │ │
│ │ │
└───────────┴───────────┘
┌───────────────────────┐
│ │
│ TPU 0,2,4,6 │
│ │
│ │
├───────────────────────┤
│ │
│ TPU 1,3,5,7 │
│ │
│ │
└───────────────────────┘
%timeit -n 10 -r 10 gradfun(params, batch)[0][0].block_until_ready()
10 loops, best of 10: 30.5 ms per loop
SPMD multi-device parallelism with shard_map
#
shard_map
is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or instances, communicate with each other via explicit collective communication operations.
shard_map
is complementary to, and composable with, the automatic compiler-based parallelization built into jit
. With jit
you write code as if for a single device, and the compiler can automatically partition computation over multiple devices, generating per-device code and communication collectives behind the scenes. With shard_map
you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed.
If you’re familiar with pmap
, think of shard_map
as an evolution. It’s more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see a detailed comparison to pmap
.)
By reading this tutorial, you’ll learn how to use shard_map
to get full control over your multi-device code. You’ll see in detail how it composes with jax.jit
’s automatic parallelization and jax.grad
’s automatic differentiation. We’ll also give some basic examples of neural network parallelization strategies.
We’ll assume this tutorial is being run in an environment with eight devices:
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' # Use 8 CPU devices
So, let’s see a shard_map
!#
Without further ado, here’s a toy example:
from functools import partial
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('x', 'y'))
a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 4.).reshape(16, 4)
@partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),
out_specs=P('x', None))
def matmul_basic(a_block, b_block):
# a_block: f32[2, 8]
# b_block: f32[8, 4]
c_partialsum = jnp.dot(a_block, b_block)
c_block = jax.lax.psum(c_partialsum, 'y')
# c_block: f32[2, 4]
return c_block
c = matmul_basic(a, b) # c: f32[8, 4]
This function computes a matrix multiply in parallel by performing local block matrix multiplies followed by a collective sum operation. We can check the result is correct:
from jax.tree_util import tree_map, tree_all
def allclose(a, b):
return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))
allclose(c, jnp.dot(a, b))
True
The result is sharded along its rows:
jax.debug.visualize_array_sharding(c)
CPU 0,1 CPU 2,3 CPU 4,5 CPU 6,7
At a high level, shard_map
is kind of like vmap
or pmap
, in that we’re
mapping a function over pieces of array data, but notice that
shard_map slices up inputs into blocks (and the output is formed by concatenating result blocks), keeping the rank the same, whereas vmap would reduce the rank by mapping away an axis;
the mesh argument lets us control precise device placement of computation and results;
we’re mapping over multiple data axes at once, and setting up multiple axis names for collectives (both 'x' and 'y' here);
since we’re not using jax.jit yet, everything is eagerly evaluated, and we can even print intermediate values for debugging.
The above code is performing the same computation as this jax.jit
automatic parallelization code:
from jax.sharding import NamedSharding
a = jax.device_put(a, NamedSharding(mesh, P('x', 'y')))
b = jax.device_put(b, NamedSharding(mesh, P('y', None)))
@jax.jit
def matmul_reference(a, b):
c = jnp.dot(a, b)
return jax.lax.with_sharding_constraint(c, NamedSharding(mesh, P('x', None)))
c_ref = matmul_reference(a, b)
allclose(c_ref, jnp.dot(a, b))
True
We can think of shard_map
as performing a device_put
or
with_sharding_constraint
on its inputs according to its mesh
and in_specs
arguments, so the blocks over which matmul_basic
operates are the same as in
matmul_reference
:
print('a blocks:'); jax.debug.visualize_array_sharding(a)
print('b blocks:'); jax.debug.visualize_array_sharding(b)
print('c blocks:'); jax.debug.visualize_array_sharding(c)
a blocks:
b blocks:
c blocks:
CPU 0 CPU 1 CPU 2 CPU 3 CPU 4 CPU 5 CPU 6 CPU 7
CPU 0,2,4,6 CPU 1,3,5,7
CPU 0,1 CPU 2,3 CPU 4,5 CPU 6,7
Slow down, start with the basics!#
Rank-reducing vs rank-preserving maps#
We can think of vmap
and pmap
as unstacking each array input along an axis
(e.g. unpacking a 2D matrix into its 1D rows), applying its body function to
each piece, and stacking the results back together, at least when collectives
aren’t involved:
def check_vmap(f, xs):
ans = jax.vmap(f, in_axes=(0,), out_axes=0)(xs)
expected = jnp.stack([f(x) for x in xs]) # vmap reference semantics
print(allclose(ans, expected))
check_vmap(lambda x: x @ x, jnp.arange(12).reshape(4, 3))
True
For example, if xs
had shape f32[8,5]
then each x
would have shape
f32[5]
, and if each f(x)
had shape f32[3,7]
then the final stacked result
vmap(f)(xs)
would have shape f32[8,3,7]
. That is, each application of the
body function f
takes as argument inputs with one fewer axis than the
corresponding argument to vmap(f)
. We can say these are rank-reducing maps
with unstacking/stacking of inputs/outputs.
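To make that shape bookkeeping concrete, here’s a tiny check (the shapes are just illustrative and the function ignores its input):
xs = jnp.zeros((8, 5))
f = lambda x: jnp.zeros((3, 7))       # each application produces a (3, 7) result
print(jax.vmap(f)(xs).shape)          # (8, 3, 7): the mapped axis of size 8 is reinstated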
The number of logical applications of f
, or instances of f
, is determined
by the size of the input axis being mapped over: for example, if we map over an
input axis of size 8, semantically we get 8 logical applications of the
function.
In contrast, shard_map
does not have this rank-reducing behavior. Instead, we
can think of it as slicing (or “unconcatenating”) along input axes into blocks,
applying the body function, and concatenating the results back together (again
when collectives aren’t involved):
import numpy as np
devices = np.array(jax.devices()[:4])
mesh = Mesh(devices, ('i',)) # mesh.shape['i'] = 4
def check_shmap(f, y):
ans = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))(y)
expected = jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, mesh.shape['i'])])
print(allclose(ans, expected))
check_shmap(lambda x: x.T @ x, jnp.arange(32).reshape(8, 4))
True
Recall that jnp.split
slices its input into equally-sized blocks with the same
rank, so that if in the above example y
had shape f32[8,5]
then each
y_blk
would have shape f32[2,5]
, and if each f(y_blk)
had shape
f32[3,7]
then the final concatenated result shard_map(f, ...)(y)
would have
shape f32[12,7]
. So shard_map
maps over shards, or blocks, of its inputs.
We can say it’s a rank-preserving map with unconcatenating/concatenating of
its inputs/outputs.
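And here is the analogous check for shard_map, using the same four-device mesh defined just above (again, the shapes are only illustrative):
y = jnp.zeros((8, 5))
g = lambda y_blk: jnp.zeros((3, 7))   # each of the 4 blocks maps to a (3, 7) result
print(shard_map(g, mesh, in_specs=P('i'), out_specs=P('i'))(y).shape)   # (12, 7)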
The number of logical applications of f
is determined by the mesh size, not
by any input axis size: for example, if we have a mesh of total size 4 (i.e.
over 4 devices) then semantically we get 4 logical applications of the
function, corresponding to the 4 devices physically computing them.
Controlling how each input is split (unconcatenated) and tiled with in_specs
#
Each of the in_specs
identifies some of the corresponding input array’s axes
with mesh axes by name using PartitionSpec
s, representing how to split (or
unconcatenate) that input into the blocks to which the body function is
applied. That identification determines the shard sizes; when an input axis is
identified with a mesh axis, the input is split (unconcatenated) along that
logical axis into a number of pieces equal to the corresponding mesh axis size.
(It’s an error if the corresponding mesh axis size does not evenly divide the
input array axis size.) If an input’s pspec does not mention a mesh axis name,
then there’s no splitting over that mesh axis. For example:
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, ('i', 'j'))
@partial(shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))
def f1(x_block):
print(x_block.shape) # prints (3, 12)
return x_block
x1 = jnp.arange(12 * 12).reshape(12, 12)
y = f1(x1)
(3, 12)
Here, because the input pspec did not mention the mesh axis name 'j'
, no
input array axis is split over that mesh axis; similarly, because the second
axis of the input array is not identified with (and hence split over) any mesh
axis, application of f1
gets a full view of the input along that axis.
When a mesh axis is not mentioned in an input pspec, we can always rewrite to a
less efficient program where all mesh axes are mentioned but the caller
performs a jnp.tile
, for example:
@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', 'j'))
def f2(x_block):
print(x_block.shape)
return x_block
x = jnp.arange(12 * 12).reshape(12, 12)
x_ = jnp.tile(x, (1, mesh.shape['j'])) # x_ has shape (12, 24)
y = f2(x_) # prints (3,12), and f1(x) == f2(x_)
(3, 12)
In other words, because each input pspec can mention each mesh axis name zero
or one times, rather than having to mention each name exactly once, we can say
that in addition to the jnp.split
built into its input, shard_map
also has
a jnp.tile
built into its input, at least logically (though the tiling may
not need to be carried out physically, depending on the arguments’ physical
sharding layout). The tiling to use is not unique; we could also have tiled
along the first axis, and used the pspec P(('j', 'i'), None)
.
Physical data movement is possible on inputs, as each device needs to have a copy of the appropriate data.
Controlling how each output is assembled by concatenation, block transposition, and untiling using out_specs
#
Analogously to the input side, each of the out_specs
identifies some of the
corresponding output array’s axes with mesh axes by name, representing how the
output blocks (one for each application of the body function, or equivalently
one for each physical device) should be assembled back together to form the
final output value. For example, in both the f1
and f2
examples above the
out_specs
indicate we should form the final output by concatenating together
the block results along both axes, resulting in both cases an array y
of
shape (12, 24)
. (It’s an error if an output shape of the body function, i.e.
an output block shape, has a rank too small for the concatenation described by
the corresponding output pspec.)
When a mesh axis name is not mentioned in an output pspec, it represents an un-tiling: when the user writes an output pspec which does not mention one of the mesh axis names, they promise that the output blocks are equal along that mesh axis, and so only one block along that axis is used in the output (rather than concatenating all the blocks together along that mesh axis). For example, using the same mesh as above:
x = jnp.array([[3.]])
z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))()
print(z) # prints the same as jnp.tile(x, (4, 2))
z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', None))()
print(z) # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,))
z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P(None, None))()
print(z) # prints the same as jnp.tile(x, (1, 1)), or just x
[[3. 3.]
[3. 3.]
[3. 3.]
[3. 3.]]
[[3.]
[3.]
[3.]
[3.]]
[[3.]]
The body function closing over an array value is equivalent to passing it as an argument with a corresponding input pspec of P(None, None), as the sketch below illustrates.
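A minimal sketch of that equivalence, reusing the mesh and the x = jnp.array([[3.]]) defined above (the variable names are just for illustration):
z_closed = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P(None, None))()
z_passed = shard_map(lambda x_blk: x_blk, mesh=mesh,
                     in_specs=P(None, None), out_specs=P(None, None))(x)
print(jnp.allclose(z_closed, z_passed))   # True: both are just [[3.]]
As another example, following the earlier examples more closely: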
@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None))
def f3(x_block):
return jax.lax.psum(x_block, 'j')
x = jnp.arange(12 * 12).reshape(12, 12)
y3 = f3(x)
print(y3.shape)
(12, 6)
The result has a second axis size of 6, half the size of the input’s second
axis. In this case, the un-tile expressed by not mentioning the mesh axis name
'j'
in the output pspec was safe because of the collective psum
, which
ensures each output block is equal along the corresponding mesh axis. Here are
two more examples where we vary which mesh axes are mentioned in the output
pspec:
@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f4(x_block):
return jax.lax.psum(x_block, 'i')
x = jnp.arange(12 * 12).reshape(12, 12)
y4 = f4(x)
print(y4.shape) # (3,12)
@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, None))
def f5(x_block):
return jax.lax.psum(x_block, ('i', 'j'))
y5 = f5(x)
print(y5.shape) # (3,6)
(3, 12)
(3, 6)
On the physical side, not mentioning a mesh axis name in an output pspec
assembles an Array
from the output device buffers with replicated layout
along that mesh axis.
There is no runtime check that the output blocks are actually equal along a mesh axis to be un-tiled along, or equivalently that the corresponding physical buffers have equal values and thus can be interpreted as a replicated layout for a single logical array. But there is a static check mechanism, the check_rep option described in the API specification below, which raises an error on all potentially-incorrect programs.
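For example, reusing the mesh and x from above: returning an input block that is not actually replicated along 'j' while omitting 'j' from the out_specs should be rejected by that static check (with the default check_rep=True; the exact error type and message may vary across JAX versions):
def f_untile_bad(x_block):
    return x_block   # blocks differ along 'j', but the out_specs below claims they don't

try:
    shard_map(f_untile_bad, mesh, in_specs=P('i', 'j'), out_specs=P('i', None))(x)
except Exception as e:
    print(type(e).__name__)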
Because the out_specs
can mention mesh axis names zero or one times, and
because they can be mentioned in any order, we can say that in addition to the
jnp.concatenate
built into its output, shard_map
also has both an untile
and a block transpose built into its output.
Physical data movement is not possible on outputs, no matter the output pspec.
Instead, out_specs
just encodes how to assemble the block outputs into
Array
s, or physically how to interpret the buffers across devices as the
physical layout of a single logical Array
.
API Specification#
from jax.sharding import Mesh
Specs = PyTree[PartitionSpec]
def shard_map(
f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs,
auto: collections.abc.Set[AxisName] = frozenset([]),
check_rep: bool = True,
) -> Callable:
...
where:
communication collectives like psum in the body of f can mention the axis names of mesh;
mesh encodes devices arranged in an array and with associated axis names, just like it does for sharding.NamedSharding;
in_specs and out_specs are PartitionSpecs which can affinely mention axis names from mesh to express slicing/unconcatenation and concatenation of inputs and outputs, respectively, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively;
auto is an optional set of axis names corresponding to the subset of names of mesh to treat automatically in the body, as in the caller, rather than manually;
check_rep is an optional boolean indicating whether to check statically for any replication errors in out_specs, and also whether to enable a related automatic differentiation optimization (see the JEP).
The shapes of the arguments passed to f have the same ranks as the arguments passed to shard_map-of-f, and the shape of an argument to f is computed from the corresponding argument’s shape (call it shape) and the corresponding PartitionSpec (call it spec) as roughly tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec)).
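A small sketch of that block-shape rule (block_shape is a hypothetical helper, not a JAX API):
from jax.sharding import PartitionSpec as P

def block_shape(shape, spec, mesh_shape):
    # Mirrors the rule above: divide each axis by its mesh axis size (1 if unmapped).
    return tuple(sz // (1 if n is None else mesh_shape[n]) for sz, n in zip(shape, spec))

print(block_shape((12, 12), P('i', None), {'i': 4, 'j': 2}))   # (3, 12), as in the f1 example above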
Collectives tutorial#
A shard_map
need not be a pure map: function applications can communicate
with each other via collectives, using axis names defined in the mesh
argument.
Recall that shard_map
maps a function over shards, or blocks, of input data,
so that this:
mesh = Mesh(jax.devices(), ('i',))
x = jnp.arange(16.)
f_shmapped = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))
y = f_shmapped(x)
Computes the same values, evaluating applications of f
to the same argument
values, as this reference function:
def f_shmapped_ref(x):
x_blocks = jnp.array_split(x, mesh.shape[0])
y_blocks = [f(x_blk) for x_blk in x_blocks]
return jnp.concatenate(y_blocks)
We call these applications of f
to different argument shards function
instances. Each function instance is executed on a different device (or subset
of devices).
These reference semantics work when f
has no communication collectives in
it. But what if we want the function instances to communicate, corresponding
to having cross-device communication? That is, what are the reference
semantics when f
contains a collective? Say f
has just one collective, and
is of the form
def f(x_blk):
z_blk = f_part1(x_blk)
u_blk = collective(z_blk, axis_name)
v_blk = f_part2(x_blk, z_blk, u_blk)
return v_blk
where we’re assuming there’s only one mesh axis we’re mapping over, and
axis_name
is the corresponding name for it. Then the reference semantics
would look more like:
def f_shmapped_ref(x):
x_blocks = jnp.array_split(x, mesh.shape[0])
z_blocks = [f_part1(x_blk) for x_blk in x_blocks]
u_blocks = [collective_ref(i, z_blocks) for i in range(len(z_blocks))]
v_blocks = [f_part2(x_blk, z_blk, u_blk) for x_blk, z_blk, u_blk
in zip(x_blocks, z_blocks, u_blocks)]
return jnp.concatenate(v_blocks)
Notice that collective_ref
might depend on all the z_blocks
. That is,
while f_part1
and f_part2
are mapped over blocks independently, a
collective introduces some amount of cross-block dependence. Physically, that
means communication across devices. Exactly what communication happens, and
what values are computed, depend on the collective.
psum
#
The simplest collective may be jax.lax.psum
, which computes an
all-reduce-sum along a device mesh axis (or multiple axes).
Here’s a toy example:
import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map
mesh1d = Mesh(jax.devices()[:4], ('i',))
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P(None))
def f1(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.psum(x_block, 'i')
print('AFTER:\n', y_block)
return y_block
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f1(x)
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22 20 12 17]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[22 20 12 17]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[22 20 12 17]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[22 20 12 17]
FINAL RESULT:
[22 20 12 17]
The prints show that each function application starts with its own chunk of
the argument value x_block
. After the psum
, each function application has
the same value of y_block
, computed by summing the applications’ x_block
values together.
In the case where there’s a single axis name in the computation, we could say
that the collective_ref
reference implementation for psum
is
def psum_ref(_, x_blocks):
tot = sum(x_blocks)
return [tot] * len(x_blocks)
Notice also that because f1
returns y_block
, the result of a psum
over
'i'
, we can use out_specs=P()
so the caller gets a single logical copy of
the result value, rather than a tiled result.
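For contrast, if we instead keep out_specs=P('i'), the caller gets the concatenation of the (identical) blocks; a quick sketch reusing mesh1d and x from above:
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f1_tiled(x_block):
    return jax.lax.psum(x_block, 'i')     # every block holds the same summed values

print(f1_tiled(x))   # [22 20 12 17 22 20 12 17 22 20 12 17 22 20 12 17]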
When there is more than one mesh axis, we can perform a psum
over
each one separately, or over multiple axes at once:
mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j'))
@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f2(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.psum(x_block, 'i')
print('AFTER:\n', y_block)
return y_block
y = f2(jnp.arange(16).reshape(4, 4))
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[0 1]
[4 5]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[2 3]
[6 7]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8 9]
[12 13]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[10 11]
[14 15]]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[ 8 10]
[16 18]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[12 14]
[20 22]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8 10]
[16 18]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[12 14]
[20 22]]
FINAL RESULT:
[[ 8 10 12 14]
[16 18 20 22]]
By applying a psum
over mesh axis 'i'
, we get values of y_block
which
are equal along axis 'i'
, but not axis 'j'
. (So we can use
out_specs=P(None, 'j')
to get a single logical result along that axis.)
If we apply the psum
over both axes, the y_block
value is equal along both
axes:
@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None))
def f3(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.psum(x_block, ('i', 'j'))
print('AFTER:\n', y_block)
return y_block
y = f3(jnp.arange(16).reshape(4, 4))
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[0 1]
[4 5]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[2 3]
[6 7]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8 9]
[12 13]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[10 11]
[14 15]]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[20 24]
[36 40]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[20 24]
[36 40]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[20 24]
[36 40]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[20 24]
[36 40]]
FINAL RESULT:
[[20 24]
[36 40]]
In machine learning, we often use psum
to compute total losses or, when we
have a grad
inside the shard_map
ped function body, total gradients.
In the sequel, we’ll see how psum
can be implemented in terms of other
primitives, which gives some intuition about its communication cost.
all_gather
#
Another fundamental operation is gathering array shards along an axis, so that each function application has a full copy of the data along that axis:
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f4(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.all_gather(x_block, 'i', tiled=True)
print('AFTER:\n', y_block)
return y_block
x = jnp.array([3, 9, 5, 2])
y = f4(x)
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[9]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 9 5 2]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[3 9 5 2]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[3 9 5 2]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[3 9 5 2]
FINAL RESULT:
[3 9 5 2 3 9 5 2 3 9 5 2 3 9 5 2]
The prints show that each function application again starts with its own chunk
of the argument value x_block
. After the all_gather
, they have a common
value, computed by concatenating the values of x_block
.
(Notice that we actually can’t set out_specs=P()
here. For technical
reasons related to automatic differentiation, we consider the output of
all_gather
not to be guaranteed invariant across devices. If we wanted it to
be guaranteed invariant, we could use jax.lax.all_gather_invariant
, or in
this case we could just avoid doing the all_gather
in the function body and
instead just use out_specs=P('i')
to perform the concatenation.)
When tiled=False
(the default), results are stacked along a new axis instead
of concatenated:
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f5(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.all_gather(x_block, 'i', tiled=False)
print('AFTER:\n', y_block)
return y_block
y = f5(x)
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[9]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[[3]
[9]
[5]
[2]]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[[3]
[9]
[5]
[2]]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[[3]
[9]
[5]
[2]]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[[3]
[9]
[5]
[2]]
FINAL RESULT:
[[3]
[9]
[5]
[2]
[3]
[9]
[5]
[2]
[3]
[9]
[5]
[2]
[3]
[9]
[5]
[2]]
We could write the collective_ref
reference semantics function for
all_gather
as
def all_gather_ref(_, x_blocks, *, tiled=False):
combine = jnp.concatenate if tiled else jnp.stack
return [combine(x_blocks)] * len(x_blocks)
In deep learning, we might use all_gather
s on parameters in fully sharded
data parallelism (FSDP).
psum_scatter
#
The jax.lax.psum_scatter
collective is a bit less intuitive. It’s like
psum
except each function instance gets only one shard of the result:
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f6(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True)
print('AFTER:\n', y_block)
return y_block
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f6(x)
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[20]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[12]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[17]
FINAL RESULT:
[22 20 12 17]
As shown by the prints, each resulting y_block
has a smaller size than the
argument x_block
, unlike with psum
. Moreover, compared to psum
, here
each y_block
only represents a slice of the sum of the x_block
s across
function instances. (Even though each function instance gets only one shard of
the sum, the final output y
is the same as in the psum
example because
here we use out_specs=P('i')
to concatenate each function instance’s
output.)
In terms of what values are computed, a collective_ref
reference
implementation might look like:
def psum_scatter_ref(i, x_blocks, *, tiled=False):
axis_size = len(x_blocks)
tot = sum(x_blocks)
if tiled:
tot = tot.reshape(axis_size, -1, *tot.shape[1:]) # split leading axis
return [tot[i] for i in range(tot.shape[0])]
It’s not captured in the semantics reference implementation, but
psum_scatter
is useful because these results can be computed more
efficiently, with less communication, than a full psum
. In fact, one way to
think of psum_scatter
is as “the first half of a psum
, before an
all_gather
”. That is, one way to implement psum
is:
def psum(x, axis_name):
summed_chunk = jax.lax.psum_scatter(x, axis_name)
return jax.lax.all_gather(summed_chunk, axis_name)
Indeed, this implementation is often used on both TPU and GPU!
The reason psum_scatter
can require about half the communication as a full
psum
is illustrated in the ppermute section.
Another intuition is that we can use psum_scatter
to implement a distributed
matrix multiplication with inputs and outputs sharded over the same axis. In
machine learning, psum_scatter
can be used in tensor-parallel matrix
multiplies or fully-sharded data parallel gradient accumulation, as shown in
the examples to follow.
ppermute
#
The jax.lax.ppermute
collective provides the most direct way for
function instances to send data to one another. Given a mesh axis and a
list of (source_index, destination_index)
pairs representing indices along
that mesh axis, ppermute
sends its argument value from each source function
instance to each destination:
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f7(x_block):
sz = jax.lax.psum(1, 'i')
print('BEFORE:\n', x_block)
y_block = jax.lax.ppermute(x_block, 'i', [(i, (i + 1) % sz) for i in range(sz)])
print('AFTER:\n', y_block)
return y_block
y = f7(jnp.arange(8))
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[0 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[2 3]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[4 5]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[6 7]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[6 7]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[0 1]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[2 3]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[4 5]
FINAL RESULT:
[6 7 0 1 2 3 4 5]
In this case, with four function instances arranged in a ring along 'i', each instance’s value of y_block is the previous instance’s value of x_block: the blocks are rotated one position along the ring.
Source indices and destination indices can’t be repeated. If an index does not appear as a destination, then the value of the corresponding function instance’s result is an array of zeros.
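A quick sketch of that zero-filling behavior, reusing mesh1d: only instance 1 is named as a destination, so every other instance’s result block is zeros:
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f_partial_perm(x_block):
    return jax.lax.ppermute(x_block, 'i', perm=[(0, 1)])   # send block 0 to instance 1 only

print(f_partial_perm(jnp.arange(8.)))   # [0. 0. 0. 1. 0. 0. 0. 0.]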
A collective_ref
reference implementation could look like
def ppermute_ref(i, x_blocks, perm):
results = [jnp.zeros_like(x_blocks[0])] * len(x_blocks)
for src, dst in perm:
results[dst] = x_blocks[src]
return results
Other collectives can be implemented efficiently, in terms of total
communication, using ppermute
s where each function passes data only to its
neighbors. For example, we could implement psum_scatter
using a sequence of
ppermute
s and local additions. Intuitively, on each iteration each function instance sends ‘up’ the value it received on the previous iteration, and reduces (adds) the value it receives this iteration. In code, it might look like this:
def psum_scatter(x, axis_name, *, tiled=False):
size = jax.lax.psum(1, axis_name)
idx = jax.lax.axis_index(axis_name) # function instance index along axis_name
if tiled:
x = x.reshape(size, -1, *x.shape[1:]) # split leading axis
shift = partial(jax.lax.ppermute, axis_name=axis_name,
perm=[(i, (i - 1) % size) for i in range(size)])
for i in range(1, size):
update = shift(x[(idx + i) % size])
x = x.at[(idx + i + 1) % size].add(update)
return x[idx]
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f8(x_block):
print('BEFORE:\n', x_block)
y_block = psum_scatter(x_block, 'i', tiled=True)
print('AFTER:\n', y_block)
return y_block
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f8(x)
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[20]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[12]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[17]
FINAL RESULT:
[22 20 12 17]
On TPU, there are higher-dimensional variants of this algorithm to exploit multiple bidirectional physical mesh axes.
Notice that psum_scatter
is the transpose of all_gather
. Indeed, a way to
implement all_gather
in terms of ppermute
looks like the reverse of the above process.
In deep learning, we might use ppermute
when implementing SPMD pipeline
parallelism, where we divide our network along its depth into stages and
evaluate the applications of stages in parallel. Or we might use ppermute
in
parallelizing the evaluation of convolutional layers, where we shard over
spatial axes and thus devices must communicate “halos” to each other. Or it
may be used under-the-hood in tensor-parallel matrix multiplies.
all_to_all
#
A final collective is all_to_all
, which is essentially a block matrix
transpose operating along one positional axis and one cross-device axis:
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f9(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.all_to_all(x_block, 'i', split_axis=0, concat_axis=0,
tiled=True)
print('AFTER:\n', y_block)
return y_block
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f9(x)
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 5 5 9]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[1 9 3 7]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[4 2 5 1]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[1 6 8 2]
FINAL RESULT:
[3 5 5 9 1 9 3 7 4 2 5 1 1 6 8 2]
The split_axis
argument indicates which positional axis should be sharded
and partitioned across the mesh axis. The concat_axis
argument indicates the
axis along which the communicated results should be concatenated or stacked.
When tiled=False
(the default), the split_axis
axis size must equal the
size of the mesh axis named axis_name
, and a new axis of that size is
created at position concat_axis
for the stacked results. When tiled=True
,
the split_axis
axis size need only be evenly divisible by the size of the
mesh axis, and results are concatenated along the existing axis concat_axis
.
The collective_ref
reference semantics when split_axis=0
and
concat_axis=0
might look like:
def all_to_all_ref(_, x_blocks, *, tiled=False):
axis_size = len(x_blocks)
if tiled:
splits = [jnp.array_split(x, axis_size) for x in x_blocks]
return [jnp.concatenate(s) for s in zip(*splits)]
else:
splits = [list(x) for x in x_blocks]
return [jnp.stack(s) for s in zip(*splits)]
In deep learning, we might use all_to_all
in mixture-of-expert routing,
where we first sort our local batch of examples according to which expert they
should go to, then apply an all_to_all
to redistribute examples to experts.
Toy examples#
How might we use shard_map
and collective communication in practice? These
examples, while simple, give some idea.
Matrix multiplies#
Parallelizing matrix multiplication is central in scaling up deep learning
models, both for training and for inference. When jax.jit
automatically
parallelizes matrix multiplication, it can use one of several different
strategies, depending on matrix sizes, hardware details, and other factors. How
might we write some of those parallelized routines more explicitly using
shard_map
? And how can we optimize them to get better compute/communication
overlap and thus improve FLOP utilization?
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map
mesh = Mesh(jax.devices()[:4], ('i',))
def device_put(x, pspec):
return jax.device_put(x, NamedSharding(mesh, pspec))
Example 1: all-gather
on one side#
Consider performing a matrix multiplication where we shard the left-hand side argument (can think: parameters) on its leading (non-contracting) dimension:
lhs_spec = P('i', None)
lhs = device_put(jax.random.normal(jax.random.key(0), (8, 8)), lhs_spec)
And we shard the right-hand side argument (can think: activations) on its contracting dimension, with a similar sharding for the output:
rhs_spec = P('i', None)
rhs = device_put(jax.random.normal(jax.random.key(1), (8, 4)), rhs_spec)
To perform this matrix multiplication, we can first all-gather the right-hand side and then perform local matrix multiplies against the sharded left-hand side:
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
out_specs=rhs_spec)
def matmul_allgather(lhs_block, rhs_block):
rhs = jax.lax.all_gather(rhs_block, 'i', tiled=True)
return lhs_block @ rhs
out = matmul_allgather(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
That’s great, but we’re not getting any compute/communication overlap
here: before we can start the matmul, we need the all_gather
to complete.
This is easy to see by profiling the same code on larger example shapes ((8192, 8192) for lhs and (8192, 1024) for rhs).
We can get compute/communication overlap if instead of calling all_gather
we
basically inline our above implementation of all_gather
in terms of
ppermute
, then interleave steps of the gather permutation with local matrix
multiplies:
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
out_specs=rhs_spec)
def matmul_allgather_overlapped(lhs_block, rhs_block):
size = jax.lax.psum(1, 'i')
idx = jax.lax.axis_index('i')
shift = partial(jax.lax.ppermute, axis_name='i',
perm=[(i, (i + 1) % size) for i in range(size)])
B = lhs_block.shape[1] // size
lhs_blocks = lambda i: lax.dynamic_slice_in_dim(lhs_block, i * B, B, 1)
out_block = lhs_blocks(idx) @ rhs_block
for i in range(1, size):
rhs_block = shift(rhs_block)
out_block += lhs_blocks((idx - i) % size) @ rhs_block
return out_block
out = matmul_allgather_overlapped(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
This implementation allows overlap between communication and computation, and also avoids gathering a large intermediate onto each device. But on TPU it uses only half the interconnect bandwidth by permuting in only one direction along the ring. To permute bidirectionally, we just split the blocks in half and send each half in each direction:
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
out_specs=rhs_spec)
def matmul_allgather_overlapped_bidi(lhs_block, rhs_block):
size = jax.lax.psum(1, 'i')
idx = jax.lax.axis_index('i')
shift_up = partial(jax.lax.ppermute, axis_name='i',
perm=[(i, (i + 1) % size) for i in range(size)])
shift_dn = partial(jax.lax.ppermute, axis_name='i',
perm=[(i, (i - 1) % size) for i in range(size)])
B = lhs_block.shape[1] // size // 2 # half-size blocks
lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 1)
rhs_block_lo, rhs_block_hi = jnp.split(rhs_block, 2, axis=0)
out_block = lhs_blocks(idx, 0) @ rhs_block_lo
out_block += lhs_blocks(idx, 1) @ rhs_block_hi
for i in range(1, size):
rhs_block_lo = shift_up(rhs_block_lo)
rhs_block_hi = shift_dn(rhs_block_hi)
out_block += lhs_blocks((idx - i) % size, 0) @ rhs_block_lo
out_block += lhs_blocks((idx + i) % size, 1) @ rhs_block_hi
return out_block
out = matmul_allgather_overlapped_bidi(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
In practice, to reduce compile times we would probably roll this into a
jax.lax.fori_loop
. We might also have additional axes of parallelism
involved.
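Here’s a minimal sketch of what that might look like, reusing the mesh, lhs_spec, rhs_spec, lhs, and rhs from this example (this is just one way to phrase the loop, not a tuned implementation):
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), out_specs=rhs_spec)
def matmul_allgather_overlapped_fori(lhs_block, rhs_block):
    size = jax.lax.psum(1, 'i')
    idx = jax.lax.axis_index('i')
    shift = partial(jax.lax.ppermute, axis_name='i',
                    perm=[(j, (j + 1) % size) for j in range(size)])
    B = lhs_block.shape[1] // size
    lhs_blocks = lambda j: jax.lax.dynamic_slice_in_dim(lhs_block, j * B, B, 1)

    def body(i, carry):
        rhs_blk, out_blk = carry
        rhs_blk = shift(rhs_blk)                                   # pass rhs one step around the ring
        out_blk = out_blk + lhs_blocks((idx - i) % size) @ rhs_blk  # accumulate the matching local matmul
        return rhs_blk, out_blk

    out_block = lhs_blocks(idx) @ rhs_block
    _, out_block = jax.lax.fori_loop(1, size, body, (rhs_block, out_block))
    return out_block

out = matmul_allgather_overlapped_fori(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))   # True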
Example 2: psum_scatter
the result#
Another sharding we might start with has both lhs
and rhs
sharded along
their contracting dimensions, with the output sharded like rhs
again:
lhs_spec = P(None, 'i')
lhs = device_put(lhs, lhs_spec)
rhs_spec = P('i', None)
rhs = device_put(rhs, rhs_spec)
Here we can use a reduce_scatter
to perform the contraction sum over shards:
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
out_specs=rhs_spec)
def matmul_psumscatter(lhs_block, rhs_block):
out_summand = lhs_block @ rhs_block
return jax.lax.psum_scatter(out_summand, 'i', tiled=True)
out = matmul_psumscatter(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
But the scattering communication must wait for the entire local matrix multiply
to finish before it can start. To get communication/computation overlap, we can
inline an implementation of psum_scatter
in terms of ppermute
, then
interleave the communication steps with local matrix multiplies:
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
out_specs=rhs_spec)
def matmul_psumscatter_overlapped(lhs_block, rhs_block):
size = jax.lax.psum(1, 'i')
idx = jax.lax.axis_index('i')
shift = partial(jax.lax.ppermute, axis_name='i',
perm=[(i, (i - 1) % size) for i in range(size)])
lhs_block = lhs_block.reshape(size, -1, lhs_block.shape[1]) # split 1st axis
out_summand = lhs_block[(idx + 1) % size] @ rhs_block
for i in range(1, size):
out_summand = shift(out_summand)
out_summand += lhs_block[(idx + i + 1) % size] @ rhs_block
return out_summand
out = matmul_psumscatter_overlapped(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
As in the previous example, to fully utilize interconnects on TPU, we’d run a bidirectional version:
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
out_specs=rhs_spec)
def matmul_psumscatter_overlapped_bidi(lhs_block, rhs_block):
size = jax.lax.psum(1, 'i')
idx = jax.lax.axis_index('i')
shift_up = partial(jax.lax.ppermute, axis_name='i',
perm=[(i, (i + 1) % size) for i in range(size)])
shift_dn = partial(jax.lax.ppermute, axis_name='i',
perm=[(i, (i - 1) % size) for i in range(size)])
B = lhs_block.shape[0] // size // 2 # half-size blocks
lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 0)
out_summand_lo = lhs_blocks((idx - 1) % size, 0) @ rhs_block
out_summand_hi = lhs_blocks((idx + 1) % size, 1) @ rhs_block
for i in range(1, size):
out_summand_lo = shift_up(out_summand_lo)
out_summand_hi = shift_dn(out_summand_hi)
out_summand_lo += lhs_blocks((idx - i - 1) % size, 0) @ rhs_block
out_summand_hi += lhs_blocks((idx + i + 1) % size, 1) @ rhs_block
return jnp.concatenate([out_summand_lo, out_summand_hi])
out = matmul_psumscatter_overlapped_bidi(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
Neural networks#
We can use shard_map
to parallelize computation in neural networks, either by
itself or in combination with the automatic partitioning in jax.jit
. This
section has a few examples based on this toy neural network and random data:
import jax
import jax.numpy as jnp
def predict(params, inputs):
for W, b in params:
outputs = jnp.dot(inputs, W) + b
inputs = jax.nn.relu(outputs)
return outputs
def loss(params, batch):
inputs, targets = batch
predictions = predict(params, inputs)
return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
def init_layer(key, n_in, n_out):
k1, k2 = jax.random.split(key)
W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
b = jax.random.normal(k2, (n_out,))
return W, b
def init(key, layer_sizes, batch_size):
key, *keys = jax.random.split(key, len(layer_sizes))
params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))
key, *keys = jax.random.split(key, 3)
inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))
return params, (inputs, targets)
layer_sizes = [784, 128, 128, 128, 128, 128, 8]
batch_size = 32
params, batch = init(jax.random.PRNGKey(0), layer_sizes, batch_size)
Compare these examples with the purely automatic partitioning examples in the
“Distributed arrays and automatic partitioning”
doc.
While in those automatic partitioning examples we don’t need to edit the model
functions to use different parallelization strategies, with shard_map
we
often do.
8-way batch data parallelism#
The simplest multi-device parallelism strategy is to shard the batch of inputs and targets over multiple devices, replicate the parameters over those devices, and apply the model in parallel to those shards of data. To evaluate the total loss, the devices need only communicate with a scalar-sized all-reduce-sum at the end. (To evaluate the gradient of the loss, the devices must perform all-reduce-sums of parameter gradients in the backward pass.)
from functools import partial
from jax.sharding import NamedSharding, Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
from jax.experimental import mesh_utils
devices = mesh_utils.create_device_mesh((8,))
# replicate initial params on all devices, shard data batch over devices
mesh = Mesh(devices, ('batch',))
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P()))
# adapt the loss function to sum the losses across devices
def loss_dp(params, batch):
@partial(shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P())
def loss_spmd(local_batch):
inputs, targets = local_batch
predictions = predict(params, inputs) # use the reference predict
local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
return jax.lax.pmean(local_loss, 'batch')
return loss_spmd(batch)
We can check that the loss and its gradients match the reference (base) model:
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_dp)(params, batch))
22.779888
22.779888
from jax.tree_util import tree_all, tree_map
def allclose(a, b):
  return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))
print(allclose(jax.jit(jax.grad(loss))(params, batch),
jax.jit(jax.grad(loss_dp))(params, batch)))
True
We can print the compiler IR to inspect the gradient computation and verify that the collective all-reduce-sum operations happen where we’d expect: at the end of the forward pass to compute the loss value, and in the backward pass to compute the total parameter gradients.
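For example, one way to peek at the optimized IR is through the ahead-of-time lowering API; the exact text is backend- and version-dependent, so treat this as an illustrative sketch:
# Dump the compiled HLO for the gradient of loss_dp and look for all-reduce
# collectives. The precise output varies by backend and JAX version.
lowered = jax.jit(jax.grad(loss_dp)).lower(params, batch)
hlo_text = lowered.compile().as_text()
print('all-reduce' in hlo_text)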
8-way fully sharded data parallelism (FSDP)#
Another strategy is to additionally shard the parameters over the devices,
all-gathering each one when the full value is needed for the jnp.dot
or bias
addition. Since we only have one full parameter in local device memory at a
time, rather than keeping all parameters in all device memories as in the
preceding DP example, we free up significant memory that we can use for larger
models or larger batch sizes. And because XLA will overlap computation and
inter-device communication, the wall-clock time doesn’t suffer.
So now we need collectives in two places: the model prediction function
predict
needs to all-gather the parameters before they’re used, and as in the
DP case the loss function needs to sum the local losses to compute the total
loss.
There’s one other ingredient we need: we don’t want to store the fully gathered
parameters from the forward pass for use on the backward pass. Instead, we want
to gather them again on the backward pass. We can express that by using
jax.remat
with a custom
policy
(or a custom_vjp
), though XLA typically does that rematerialization
automatically.
This general FSDP approach is similar to weight update sharding (WUS) and ZeRO-3.
# shard data batch *and params* over devices
mesh = Mesh(devices, ('batch',))
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P('batch')))
# adapt the prediction function to gather weights just before their use,
# and to re-gather them on the backward pass (rather than saving them)
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict_fsdp(params_frag, inputs):
for W_frag, b_frag in params_frag:
W = jax.lax.all_gather(W_frag, 'batch', tiled=True)
b = jax.lax.all_gather(b_frag, 'batch', tiled=True)
outputs = jnp.dot(inputs, W) + b
inputs = jax.nn.relu(outputs)
return outputs
def loss_fsdp(params, batch):
@partial(shard_map, mesh=mesh, in_specs=P('batch'), out_specs=P())
def loss_spmd(local_params, local_batch):
inputs, targets = local_batch
predictions = predict_fsdp(local_params, inputs)
local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
return jax.lax.pmean(local_loss, 'batch')
return loss_spmd(params, batch)
Again we can check that the loss and its gradients match the reference model:
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_fsdp)(params, batch))
print(allclose(jax.jit(jax.grad(loss))(params, batch),
jax.jit(jax.grad(loss_fsdp))(params, batch)))
22.779888
22.779888
True
8-way tensor parallelism (TP)#
Usually we don’t use tensor model parallelism by itself, but seeing it in
isolation is a good warmup on parallel matrix multiplication. It’s also a good
example of using shard_map
in a library function, called in a larger
jit
-based computation.
The parallelization idea is that we’ll keep the data/activations sharded over
its feature axis (rather than its batch axis), and we’ll similarly shard weight
matrices over their input-feature axis (and biases over their feature axis).
Then to perform the parallel matrix multiplication, we’ll perform local matrix
multiplications followed by a psum_scatter
to sum the local results and
efficiently scatter the result’s shards.
devices = mesh_utils.create_device_mesh((8,))
mesh = Mesh(devices, ('feats',))
batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))
params = jax.device_put(params, NamedSharding(mesh, P('feats')))
def predict_tp(params, inputs):
for W, b in params:
outputs = gemm_tp(inputs, W, b)
inputs = jax.nn.relu(outputs)
return outputs
@partial(shard_map, mesh=mesh,
in_specs=(P(None, 'feats'), P('feats', None), P('feats')),
out_specs=P(None, 'feats'))
def gemm_tp(inputs, W, b):
block_result = jnp.dot(inputs, W)
return jax.lax.psum_scatter(block_result, 'feats',
scatter_dimension=1, tiled=True) + b
def loss_tp(params, batch):
inputs, targets = batch
predictions = predict_tp(params, inputs)
return jnp.mean(jnp.sum((predictions - targets) ** 2, axis=-1)) # NOTE psum!
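Before composing strategies, a quick sanity check of the tensor-parallel loss against the reference might look like this (illustrative, not part of the original walkthrough):
# The tensor-parallel loss should match the reference loss computed without shard_map.
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_tp)(params, batch))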
FSDP + TP, with shard_map at the top level#
We can compose these strategies together, using multiple axes of parallelism.
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, ('batch', 'feats'))
batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats')))
params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats'))))
# mostly same as previous predict_fsdp definition, except we call gemm_tp
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict_fsdp_tp(params_frag, inputs):
for W_frag, b_frag in params_frag:
W = jax.lax.all_gather(W_frag, 'batch', tiled=True)
b = jax.lax.all_gather(b_frag, 'batch', tiled=True)
block_result = jnp.dot(inputs, W)
outputs = jax.lax.psum_scatter(block_result, 'feats',
scatter_dimension=1, tiled=True) + b
inputs = jax.nn.relu(outputs)
return outputs
@partial(shard_map, mesh=mesh,
in_specs=(P(('feats', 'batch')), P('batch', 'feats')),
out_specs=P())
def loss_fsdp_tp(local_params, local_batch):
inputs, targets = local_batch
predictions = predict_fsdp_tp(local_params, inputs)
sq_err = jax.lax.psum(jnp.sum((predictions - targets)**2, axis=-1), 'feats')
return jax.lax.pmean(jnp.mean(sq_err), 'batch')
Notice how we have to do two collective reductions: one over 'feats'
and
one over 'batch'
. In the pure TP example, we didn’t write the 'feats'
reduction explicitly because we only used shard_map
within gemm_tp
; in the
caller loss_tp
, the compiler automatically translated our use of jnp.sum
to
perform a psum
as needed given the sharded result returned by predict_tp
.
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_fsdp_tp)(params_, batch_))
print(allclose(jax.jit(jax.grad(loss))(params, batch),
               jax.jit(jax.grad(loss_fsdp_tp))(params_, batch_)))
22.779886
22.779886
True
SPMD pipeline parallelism (PP)#
With pipeline parallelism we aim to parallelize the evaluation of layers at different depths in our network. For example, one device might compute the application of the first layer while another device computes the application of the second; when they finish, the first device passes its results to the second while the second passes its results to the device responsible for the third layer, and the process repeats. In general the number of pipeline stages may be different from the number of layers, as each stage may be responsible for multiple layers.
With SPMD pipelining, we exploit the fact that most layers in the network apply
the same computation, just with different parameter values. In particular, we can
stack together all the parameters except for those for the first and last
layers, then use a shard_map
to map over blocks of those layer parameters,
where each block of parameters corresponds to a pipeline stage. We then use the
jax.lax.ppermute
collective to shift data down the parallel pipeline.
This particular pipelining strategy is essentially the GPipe strategy. There are several variants, as well as quite different strategies, and which is appropriate can depend on the speed of the networking between stages and batch sizes. But for this tutorial we’ll focus on just one strategy.
First, we choose some pipeline parameters:
L = len(params) - 2 # num layers, excluding first and last
N = batch_size # batch size
F = params[0][0].shape[1] # num features
# choose some pipeline parameters
S = 2 # number of stages
B = 8 # size of each microbatch
assert L % S == 0, "S (number of stages) must divide L (number of inner layers)"
# compute some useful quantities
M, ragged = divmod(N, B) # M is number of microbatches
assert not ragged, "B (size of each microbatch) must divide total batch size"
K, ragged = divmod(M, S) # K is microbatches per stage
assert not ragged, "S (number of stages) must divide number of microbatches"
print(f'{S} stages, {L // S} layer(s) per stage, {L} pipelined layers total')
print(f'{B} examples per microbatch, {M} microbatches total')
2 stages, 2 layer(s) per stage, 4 pipelined layers total
8 examples per microbatch, 4 microbatches total
mesh = Mesh(jax.devices()[:S], ('stages',))
def predict_pp(params, inputs):
(W_first, b_first), inner_params, (W_last, b_last) = params
inputs = jax.nn.relu(jnp.dot(inputs, W_first) + b_first)
inputs = spmd_pipeline(lambda Wb, x: jax.nn.relu(x @ Wb[0] + Wb[1]),
inner_params, inputs)
outputs = jnp.dot(inputs, W_last) + b_last
return outputs
@partial(shard_map, mesh=mesh, in_specs=((P(), P('stages'), P()), P('stages')),
out_specs=P())
def loss_pp(params, batch):
inputs, targets = batch
predictions = predict_pp(params, inputs.reshape(K, B, -1)).reshape(K * B, -1)
local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
return jax.lax.pmean(local_loss, 'stages')
def spmd_pipeline(fn, stage_params, inputs):
stage = jax.lax.axis_index('stages')
outputs = jnp.zeros_like(inputs) * jnp.nan
state = jnp.zeros((L // S, B, F)) * jnp.nan
for i in range(M+L-1):
state = state.at[0].set(jnp.where(stage == 0, inputs[i % K], state[0]))
state = jax.vmap(fn)(stage_params, state)
outputs = outputs.at[(i-L+1) % K].set(jnp.where(stage == S-1, state[-1], outputs[(i-L+1) % K]))
state, inputs, outputs = shift(i, state, inputs, outputs)
outputs = jax.lax.ppermute(outputs, 'stages', [(i, (i+1) % S) for i in range(S)])
return outputs
def shift(i, state, inputs, outputs):
sh = lambda x, d: jax.lax.ppermute(x, 'stages', [(i, (i+d) % S) for i in range(S)])
state = jnp.roll(state, +1, axis=0).at[0].set(sh(state[-1], +1))
if (i % K) == (-1 % K):
inputs = sh(inputs, +1)
if ((i-L+1) % K) == (-1 % K):
outputs = sh(outputs, +1)
return state, inputs, outputs
first_params, *inner_params, last_params = params
Ws, bs = zip(*inner_params)
params_stacked = jnp.stack(Ws), jnp.stack(bs)
first_params = jax.device_put(first_params, NamedSharding(mesh, P()))
params_stacked = jax.device_put(params_stacked, NamedSharding(mesh, P('stages')))
last_params = jax.device_put(last_params, NamedSharding(mesh, P()))
params_ = first_params, params_stacked, last_params
batch_ = jax.device_put(batch, NamedSharding(mesh, P('stages')))
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_pp)(params_, batch_))
22.779886
22.779884
_ = jax.jit(jax.grad(loss_pp))(params_, batch_) # don't crash
Named axes and easy-to-revise parallelism with xmap#
UPDATE: xmap
is deprecated and will be removed in a future release. The recommended ways to do multi-device programming in JAX are using: 1) jit
(automatic partitioning of computation and jax.Array
sharding); and/or 2) shard_map
(manual data sharding). Learn more in Why don’t pmap
or xmap
already solve this? in the shard_map
JEP document.
This tutorial introduces JAX xmap
(jax.experimental.maps.xmap
) and the named-axis programming model that comes with it. By reading this, you’ll learn how to write error-avoiding, self-documenting functions using named axes, then control how they’re executed on hardware at any scale, from your laptop CPU to the largest TPU supercomputer.
We start with a toy neural network example.
From positions to names in a toy neural network#
Presentations on JAX often start with a simple neural network prediction function and loss, written in pure NumPy. Here’s a simple network with one hidden layer:
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' # Use 8 CPU devices
import jax.numpy as jnp
from jax import lax
from jax.nn import one_hot, relu
from jax.scipy.special import logsumexp
def predict(w1, w2, images):
hiddens = relu(jnp.dot(images, w1))
logits = jnp.dot(hiddens, w2)
return logits - logsumexp(logits, axis=1, keepdims=True)
def loss(w1, w2, images, labels):
predictions = predict(w1, w2, images)
targets = one_hot(labels, predictions.shape[-1])
losses = jnp.sum(targets * predictions, axis=1)
return -jnp.mean(losses, axis=0)
We can then initialize inputs with the right shapes and compute the loss value:
w1 = jnp.zeros((784, 512))
w2 = jnp.zeros((512, 10))
images = jnp.zeros((128, 784))
labels = jnp.zeros(128, dtype=jnp.int32)
print(loss(w1, w2, images, labels))
Here’s how we might write the same function using named axes. Don’t worry if you can’t follow the API details. They are not important now and we will explain everything step-by-step afterwards. This is just to show you what you can do with xmap before you learn the details!
def named_predict(w1, w2, image):
hidden = relu(lax.pdot(image, w1, 'inputs'))
logits = lax.pdot(hidden, w2, 'hidden')
return logits - logsumexp(logits, 'classes')
def named_loss(w1, w2, images, labels):
predictions = named_predict(w1, w2, images)
num_classes = lax.psum(1, 'classes')
targets = one_hot(labels, num_classes, axis='classes')
losses = lax.psum(targets * predictions, 'classes')
return -lax.pmean(losses, 'batch')
This code is simpler: we don’t need to worry about axis order when calling functions like jnp.dot
, or remember which axis position to reduce over with logsumexp
, jnp.sum
, or jnp.mean
.
But the real win is that names let us use xmap
to control our function’s execution. At its simplest, xmap
will just vectorize over all named axes, so that the function is executed just like its positional-axis counterpart:
from jax.experimental.maps import xmap
in_axes = [['inputs', 'hidden', ...],
['hidden', 'classes', ...],
['batch', 'inputs', ...],
['batch', ...]]
loss = xmap(named_loss, in_axes=in_axes, out_axes=[...])
print(loss(w1, w2, images, labels))
But on a whim we can decide to parallelize over the batch axis:
import jax
import numpy as np
from jax.sharding import Mesh
loss = xmap(named_loss, in_axes=in_axes, out_axes=[...],
axis_resources={'batch': 'x'})
devices = np.array(jax.local_devices())
with Mesh(devices, ('x',)):
print(loss(w1, w2, images, labels))
Or we might want to perform model parallelism over the hidden axis:
loss = xmap(named_loss, in_axes=in_axes, out_axes=[...],
axis_resources={'hidden': 'x'})
devices = np.array(jax.local_devices())
with Mesh(devices, ('x',)):
print(loss(w1, w2, images, labels))
Or we might want to do both model and batch data parallelism at once:
loss = xmap(named_loss, in_axes=in_axes, out_axes=[...],
axis_resources={'batch': 'x', 'hidden': 'y'})
devices = np.array(jax.local_devices()).reshape((4, 2))
with Mesh(devices, ('x', 'y')):
print(loss(w1, w2, images, labels))
With xmap
, we can revise our parallelism strategy on a dime, without needing to rewrite our neural network function.
Preliminaries#
import jax.numpy as jnp
from jax import lax
from functools import partial
import jax
import numpy as np
To better illustrate the new programming model, we make extensive use of custom type annotations in this notebook. The annotations have no effect on how the code evaluates and will be unchecked for now.
from typing import Any, Callable
class ArrayType:
def __getitem__(self, idx):
return Any
f32 = ArrayType()
i32 = ArrayType()
Tensors with named axes#
The NumPy programming model is based around nd-arrays. Each nd-array can be associated with a two-component type:
the element type (accessible via the .dtype attribute)
the shape (a tuple of integers given by .shape).
Using our little type annotation language, we will write these types as dtype[shape_tuple]
.
For example, a 5x7x4 array of 32-bit floating point numbers will be denoted as
f32[(5, 7, 4)]
.
Here is a small example that shows how the annotations can demonstrate the way shapes propagate through a simple NumPy program:
x: f32[(2, 3)] = np.ones((2, 3), dtype=np.float32)
y: f32[(3, 5)] = np.ones((3, 5), dtype=np.float32)
z: f32[(2, 5)] = x.dot(y) # matrix multiplication
w: f32[(7, 1, 5)] = np.ones((7, 1, 5), dtype=np.float32)
q: f32[(7, 2, 5)] = z + w # broadcasting
The extension we propose is to add another component of array type: a named_shape
, mapping axis names (arbitrary hashable objects, with strings being a common choice) to integer sizes. Most importantly, because each axis has a name, their order has no meaning. That is, a named shape of {'a': 2, 'b': 5}
is indistinguishable from a named shape of {'b': 5, 'a': 2}
.
This is not an entirely new idea. Some good examples of where using named axes has been proposed in the past are: Mesh TensorFlow, the Tensor Considered Harmful manifesto, as well as the xarray and einops packages. Keep in mind that many of those differ slightly in that they assign an order to the named axes, whereas named axes are unordered in JAX.
From now on we will allow the type annotations to have two components, the first one still being the value’s .shape
, while the second one will be the .named_shape
.
e: f32[(5, 7), {'batch': 20, 'sequence': 30}]
# e.shape == (5, 7)
# e.named_shape == {'batch': 20, 'sequence': 30} == {'sequence': 30, 'batch': 20}
We keep the meaning of .ndim (always equal to len(shape)) and .size (equal to the product of shape) unchanged, solely for backward-compatibility reasons. The true rank of an array with non-empty named axes is len(shape) + len(named_shape), and the true number of elements it stores is the product of the sizes of all dimensions, both positional and named.
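For example, for the annotated value e above (positional shape (5, 7), named shape {'batch': 20, 'sequence': 30}), the bookkeeping works out as in this small plain-Python sketch (named_shape is not an attribute of ordinary arrays, so we just write the numbers down):
import math

shape = (5, 7)
named_shape = {'batch': 20, 'sequence': 30}

ndim = len(shape)                                               # 2, unchanged for backward compatibility
true_rank = len(shape) + len(named_shape)                       # 4
true_size = math.prod(shape) * math.prod(named_shape.values())  # 5 * 7 * 20 * 30 = 21000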
Introducing and eliminating named axes#
But how does one create such arrays, if all top-level JAX operations work in the NumPy model with purely positional axes? While this constraint could be lifted at some point, for the time being the only way to introduce named axes is to use xmap
.
xmap
can be thought of as an adapter that takes in arrays with positional axes, makes some of them named (as specified by in_axes
), and calls the function that it wraps. Once the wrapped function returns arrays, all named axes appearing in those are converted back to positional axes (as specified by out_axes
).
in_axes
should have a structure that matches the signature of the xmap
ped function arguments, except with all places where array arguments would be replaced by an axis mapping. There are two ways in which axis mappings can be specified:
as dictionaries mapping positional axes to axis names (e.g. {0: 'x', 2: 'y'}); and
as lists of axis names terminated by the ellipsis object (e.g. ['a', 'b', ...]), indicating that a prefix of positional dimensions are to be mapped to the given names.
out_axes
are similar, except that their structure has to match the return signature of the xmap
ped function (but again, with all arrays replaced by axes mappings).
For each array argument, all positional axes mentioned in its respective in_axes
axis mapping are converted to named axes. For each array result, all named axes are inserted in the positions indicated by its respective out_axes
.
from jax.experimental.maps import xmap
def my_func(x: f32[(5,), {'batch': 20}]) -> f32[(5,), {'batch': 20}]:
assert x.shape == (5,)
# assert x.named_shape == {'batch': 20} # TODO: Implement named_shape
return x
x: f32[(20, 5)] = jnp.zeros((20, 5), dtype=np.float32)
f = xmap(my_func,
in_axes={0: 'batch'}, # Name the first axis of the only argument 'batch'
out_axes={1: 'batch'}) # Place the 'batch' named axis of the output as the second positional axis
y: f32[(5, 20)] = f(x)
assert (y == x.T).all() # The first dimension was removed from x and then re-inserted as the last dim
While this might seem like a handful at first, if you’ve seen code that uses jnp.einsum
you are already familiar with this approach. The einsum
function interprets an expression such as nk,km->nm
assigning names (each letter is considered a separate name) to positional axes, performing necessary broadcasts and reductions, and finally putting back the results in positional axes, according to the order given by the right-hand side of the ->
separator. While einsum
never lets you interact with named axes directly, they do appear naturally in its implementation. xmap
is a generalized einsum because named axes are now first-class and you get to implement the function that can manipulate them.
Continuing this analogy, xmap(my_func, ...)
from the above example is equivalent to jnp.einsum('bx->xb')
. But of course not every xmap
ped function will have an equivalent einsum
.
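We can check this equivalence directly, reusing f and x from the cell above (a quick illustrative assertion):
# The xmapped my_func above acts as a transpose, just like this einsum.
np.testing.assert_allclose(f(x), jnp.einsum('bx->xb', x))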
One more similarity with einsum
is that whenever a name is reused for multiple axes, they do have to have the same size:
x = jnp.arange(5)
y = jnp.arange(7)
try:
jnp.einsum('i,i->i', x, y)
except Exception as e:
print('einsum:', e)
try:
xmap(lambda x, y: x * y,
in_axes=(['i', ...], ['i', ...]),
out_axes=['i', ...])(x, y)
except Exception as e:
print('xmap:', e)
Named axis propagation#
We now know how named axes are introduced and eliminated, but what are they good for? How do they propagate throughout the program? Let’s explore a few examples.
Interactions with positional axes#
First rule: named axes never implicitly interact with positional axes. Any function that’s written without named axes in mind can always be invoked with inputs that have named dimensions. The result is the same as if vmap
was applied on a per-named-axis basis.
from jax.scipy.linalg import expm_frechet
# Any other function that does not assume existence of any named axes would do too,
# at least as long as it matches this type signature:
expm_frechet: Callable[[f32[(3, 3)], f32[(3, 3)]], f32[(3, 3)]]
f = partial(expm_frechet, compute_expm=False)
# Each A with each E
batch_A = jnp.ones((5, 3, 3), dtype=np.float32)
batch_E = jnp.ones((5, 3, 3), dtype=np.float32)
batch_AE = xmap(f,
in_axes=(['b', ...], ['b', ...]), # Map first axes of both inputs to 'b'
out_axes=['b', ...])(batch_A, batch_E) # Place 'b' as the first positional axis in the result
for i in range(5):
np.testing.assert_allclose(batch_AE[i], f(batch_A[i], batch_E[i]))
# All-pairs of As and Es
batch_A = jnp.ones((7, 3, 3), dtype=np.float32)
batch_E = jnp.ones((5, 3, 3), dtype=np.float32)
batch_AE = xmap(f,
in_axes=(['ba', ...], ['be', ...]), # Map first axes of inputs to 'ba' and 'be' respectively
out_axes=['ba', 'be', ...])(batch_A, batch_E) # Prefix all positional dimensions of output with 'ba' and 'be'
for i in range(7):
for j in range(5):
np.testing.assert_allclose(batch_AE[i,j], f(batch_A[i], batch_E[j]))
Broadcasting#
Secondly, named axes are broadcast by name, and every existing NumPy (and almost every JAX) operator implicitly broadcasts the named dimensions. Whenever a standard NumPy function is called with arrays with named axes, the NumPy function determines the positional shape of the result array, while the named shape becomes a union of all named shapes of its inputs. Analyze the following example to understand how the axes propagate:
def named_broadcasting(
x: f32[(2, 1, 1), {'a': 2}],
y: f32[(1, 3, 1), {'b': 3}],
z: f32[(1, 1, 5), {'c': 5}]) \
-> f32[(2, 3, 5), {'a': 2, 'b': 3, 'c': 5}]:
i: f32[(2, 3, 1), {'a': 2, 'b': 3}] = x + y
j: f32[(1, 3, 5), {'b': 3, 'c': 5}] = y + z
k: f32[(2, 3, 5), {'a': 2, 'b': 3, 'c': 5}] = i + j
return k
x = jnp.ones((2, 2, 1, 1), dtype=np.float32)
y = jnp.ones((3, 1, 3, 1), dtype=np.float32)
z = jnp.ones((5, 1, 1, 5), dtype=np.float32)
k = xmap(named_broadcasting,
in_axes=(['a', ...], ['b', ...], ['c', ...]),
out_axes=['a', 'b', 'c', ...])(x, y, z)
assert k.shape == (2, 3, 5, 2, 3, 5)
To recap, the named shape of the result of an expression such as i + j
with i
having a named shape of {'a': 2, 'b': 3}
and j
of {'b': 3, 'c': 5}
is {'a': 2, 'b': 3, 'c': 5}
. The 'b'
axis is present in both inputs, so no broadcasting is necessary, while 'a'
and 'c'
occur in only one of the two inputs, causing the other one to get broadcast along the axis missing in its named shape.
No shape errors can occur when operating over named axes, because xmap
enforces that a single name is associated with a single size inside its body.
While the rule for broadcasting named axes might seem like an arbitrary extension of the NumPy model, it is actually consistent with it.
Broadcasting first looks for pairs of dimensions it considers as equivalent in both operands. For all matched pairs, it asserts that both sizes are equal or one of them is 1. All unpaired dimensions are carried over to the result.
Now, in the positional world the way NumPy broadcasting chooses to form the pairs is by right-aligning the shapes. But our axes are named, so there is a straightforward way of finding equivalent axes: just check their names for equality!
Reductions#
But named axes are not only good for batching! In fact, our goal is that named axes should be equivalent to positional axes. In particular, every NumPy function that takes in positional axes as arguments should also accept named axes.
The paragraph above is aspirational and the set of NumPy functions that do accept named axes is relatively limited. At the moment named axes are only supported in:
jnp.sum
jnp.max
jnp.min
Reductions are a good example:
def named_broadcast_and_reduce(
x: f32[(), {'x': 2}],
y: f32[(5,), {'y': 4}]) \
-> f32[()]:
z: f32[(5,), {'x': 2, 'y': 4}] = x + y
w: f32[()] = jnp.sum(z, axis=(0, 'x', 'y'))
# We could also reduce in steps:
# w0 : f32[(), {'x': 2, 'y': 4}] = jnp.sum(z, 0) # eliminate the positional axis
# w0x: f32[(), {'y': 4}] = jnp.sum(w0, 'x') # eliminate the `x` axis
# w : f32[()] = jnp.sum(w0x, 'y') # eliminate the `y` axis
return w
positional_broadcast_and_reduce: Callable[[f32[(2,)], f32[(5, 4)]], f32[()]]
positional_broadcast_and_reduce = \
xmap(named_broadcast_and_reduce,
in_axes=({0: 'x'}, {1: 'y'}),
out_axes={})
positional_broadcast_and_reduce(jnp.arange(2, dtype=np.float32),
jnp.arange(20, dtype=np.float32).reshape((5, 4)))
einsum#
Similarly to how we have extended reductions with support for named axes, we’ve also made it possible to contract over named axes using jnp.einsum
.
Operands and results still use a convention of one letter per positional axis, but now it is also possible to mention named axes in curly braces. For example, n{b,k}
implies that a value will have a single positional dimension n
and named dimensions b
and k
(their order doesn’t matter). Following the usual einsum semantics, any named axes that appear in inputs, but do not appear in an output will be contracted (summed after all multiplications are performed).
It is acceptable to omit a named dimension from all arguments and the result, in which case it will be treated according to the usual broadcasting semantics. However, it is not acceptable to mention a named axis in one argument that has it in its named shape while skipping it in another argument that also has it in its named shape. Of course, skipping it in the arguments that don’t have it is required.
NOTE: This invariant is unchecked at the moment (it is still work-in-progress). Such axis skipping will result in undefined behavior.
At the moment
jnp.einsum
with named axes only supports two inputs and a single result.
def named_batch_matrix_single_matrix(
x: f32[(5,), {'b': 20, 'k': 7}],
y: f32[(), {'k': 7, 'm': 11}]) \
-> f32[(5,), {'b': 20, 'm': 11}]:
return jnp.einsum('n{b,k},{k,m}->n{b,m}', x, y)
x = jnp.ones((20, 5, 7))
y = jnp.ones((7, 11))
z = jnp.einsum('bnk,km->bnm', x, y)
zx = xmap(named_batch_matrix_single_matrix,
in_axes=[{0: 'b', 2: 'k'}, ['k', 'm', ...]],
out_axes={0: 'b', 2: 'm'})(x, y)
np.testing.assert_allclose(z, zx)
The example above is admittedly no clearer than using jnp.einsum
directly. But contractions over named axes are a crucial component of larger applications such as Transformer models and this is only meant to be an exercise to show you how the names propagate.
Collectives#
Finally, all collectives that could have been used with pmap
ped functions also work with named axes. As we’ll show later, xmap
can be used as a drop-in replacement for pmap
that makes programming for multi-dimensional hardware meshes much easier.
x = jnp.arange(8)
xmap(lambda x: lax.pshuffle(x, 'i', list(reversed(range(8)))),
in_axes=['i', ...], out_axes=['i', ...])(x)
Parallelism support#
While the new programming paradigm can be nice at times, the killer feature of xmap
is its ability to parallelize code over supercomputer-scale hardware meshes!
Named axes are the secret sauce that makes all this possible, thanks to the carefully tuned rules that describe their propagation. Good support for partitioning in a purely positional programming model is notoriously difficult. Positional axes are usually disposable and it is hard to keep track of the way axis partitioning propagates through the program. As you’ll see below, named axes enable us to define a straightforward correspondence between their names and hardware resources, making it easy to reason about the way different values end up partitioned.
In all the previous examples, we haven’t said a word about parallelism and for a good reason. By default xmap
doesn’t perform any parallelization and vectorizes the computation in the same way vmap
does (i.e. it still executes on a single device). To partition the computation over multiple accelerators we have to introduce one more concept: resource axes.
The basic idea is that logical axes (the ones that appear in named shapes) assume that we have abundant hardware and memory, but before the program is to be executed, they have to be placed somewhere. The default (vmap
-like) evaluation style pays a high memory cost on the default JAX device. By mapping logical axes to (one or more) resource axes through the axis_resources
argument, we can control how xmap
evaluates the computation.
x = jnp.ones((2048, 2048))
local_matmul = xmap(jnp.vdot,
in_axes=({0: 'left'}, {1: 'right'}),
out_axes=['left', 'right', ...])
distr_matmul = xmap(jnp.vdot,
in_axes=({0: 'left'}, {1: 'right'}),
out_axes=['left', 'right', ...],
axis_resources={'left': 'x', 'right': 'y'})
Both local_matmul
and distr_matmul
implement matrix multiplication, but distr_matmul
will additionally partition the left
and right
logical axes over the x
and y
resource axes.
But… where do those resource names come from?#
Well, it depends, but one good choice is… a hardware mesh!
For our purposes a mesh is an nd-array of devices with named axes. But, because NumPy doesn’t support named axes (that’s our extension!), the meshes are represented by a pair of an nd-array of JAX device objects (as obtained from jax.devices()
or jax.local_devices()
) and a tuple of resource axis names of length matching the rank of the array.
axis_names = ('x', 'y')
mesh_devices = np.array(jax.devices()).reshape((2, 4))
assert len(axis_names) == mesh_devices.ndim
mesh_def = (mesh_devices, axis_names)
mesh_def
The mesh axis names are exactly the names of resources that named axes can be mapped to. But just creating a mesh definition won’t make the resource names visible to distr_matmul
:
try:
distr_matmul(x, x)
except Exception as e:
print(e)
To introduce the resources in a scope, use the with Mesh
context manager:
from jax.sharding import Mesh
local = local_matmul(x, x) # The local function doesn't require the mesh definition
with Mesh(*mesh_def): # Makes the mesh axis names available as resources
distr = distr_matmul(x, x)
np.testing.assert_allclose(local, distr)
Anyway, the best part of it is that specifying axis_resources
never changes program semantics. You are free to experiment with different ways of partitioning your computation (just change the assignment of resources to named axes!) and even how the physical devices are organized in the mesh (by changing the construction of the NumPy array of devices). None of those things should have any significant influence on the results you get back (up to, for example, floating point inaccuracy), though of course some of them will achieve significantly better performance than the others.
xmap
doesn’t provide any automatic scheduling options at the moment, because the best schedule often has to be somewhat carefully matched to your program. We’re considering adding support for that in the future, but it will take time.
Once you map a logical axis to a mesh dimension, the size of that logical axis has to be divisible by the mesh dimension size.
Is my data replicated? Or partitioned? Where is it?#
Named axes also give us a neat way of reasoning about partitioning and replication. A value is partitioned over a mesh axis if and only if it has a named axis that has been mapped to that mesh axis in its shape. Otherwise, it will be replicated over all slices along that axis.
For example, assume that we’re in an xmap
that had axis_resources={'a': 'x', 'b': 'y'}
specified (i.e. we are running the computation over a 2D mesh with x
and y
axes with sizes 2 and 3 respectively). Then:
An array of type f32[(5, 5), {}] is completely replicated over the whole mesh. All devices store a local copy of the value.
An array of type f32[(6,), {'a': 8}] is partitioned over mesh axis x, because it has 'a' in its named shape, and 'a' is mapped to x. It is replicated over mesh axis y. To put it differently, all devices in a slice of the mesh with the same x coordinate will store a local copy of a chunk of this array. But, mesh slices with different x coordinates will store different chunks of the data.
An array of type f32[(), {'a': 8, 'c': 7}] is partitioned just like in the previous case: split over the x mesh axis and replicated over the y axis. Named dimensions with no resources specified are no different than positional dimensions when considering partitioning, so 'c' has no influence on it.
An array of type f32[(), {'a': 8, 'b': 12}] is completely partitioned over the whole mesh. Every device holds a distinct chunk of the data.
This also highlights one restriction: xmap
won’t complain if you specify axis_resources={'a': 'x', 'b': 'x'}
, but consider how an array with type f32[(2, 8), {'a': 4, 'b': 12}]
would be partitioned. If the size of the x
mesh axis is 2, then we only have 2 devices, but we have 4 chunks to place (2 along 'a'
and 2 along 'b'
)! Now we can state it in full: named axes mapped to the same resources can never both appear in the named shape of a single array. But they can appear in named shapes of two distinct arrays, such as in this program:
def sum_two_args(x: f32[(), {'a': 4}], y: f32[(), {'b': 12}]) -> f32[()]:
return jnp.sum(x, axis='a') + jnp.sum(y, axis='b')
q = jnp.ones((4,), dtype=np.float32)
u = jnp.ones((12,), dtype=np.float32)
with Mesh(np.array(jax.devices()[:4]), ('x',)):
v = xmap(sum_two_args,
in_axes=(['a', ...], ['b', ...]),
out_axes=[...],
axis_resources={'a': 'x', 'b': 'x'})(q, u)
print(v)
This program is valid, because jnp.sum
eliminates the axes that cannot co-occur before the values are added.
While the final release of
xmap
will ensure that you don’t accidentally end up doing so, the current implementation doesn’t verify it. Violating this restriction will result in undefined behavior.
Why axis_resources and not a more direct mapping to hardware?#
At this point you might wonder why go through the detour of introducing yet another concept of resource axes in the mix. For as long as you’re interested in partitioning your computations over hardware, there is no good reason, but this mental framework is more flexible than that!
For example, there is one additional resource we all deal with: time! Just like a computation can be partitioned over multiple hardware devices, e.g. to lower its memory usage, the same thing can be achieved with a single accelerator that evaluates a chunk of the computation in multiple steps.
So, while hardware meshes are the only source of resource axes in JAX programs at the moment, we are planning to extend the whole system with other sources.
Porting positional code to named code#
In this section we will go over a few more real examples to show how xmap
can help you implement and distribute various models.
This section is a work in progress
The Autodiff Cookbook#
alexbw@, mattjj@
JAX has a pretty general automatic differentiation system. In this notebook, we’ll go through a whole bunch of neat autodiff ideas that you can cherry pick for your own work, starting with the basics.
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
key = random.key(0)
Gradients#
Starting with grad#
You can differentiate a function with grad
:
grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0))
0.070650816
grad
takes a function and returns a function. If you have a Python function f
that evaluates the mathematical function \(f\), then grad(f)
is a Python function that evaluates the mathematical function \(\nabla f\). That means grad(f)(x)
represents the value \(\nabla f(x)\).
Since grad
operates on functions, you can apply it to its own output to differentiate as many times as you like:
print(grad(grad(jnp.tanh))(2.0))
print(grad(grad(grad(jnp.tanh)))(2.0))
-0.13621868
0.25265405
Let’s look at computing gradients with grad
in a linear logistic regression model. First, the setup:
def sigmoid(x):
return 0.5 * (jnp.tanh(x / 2) + 1)
# Outputs probability of a label being true.
def predict(W, b, inputs):
return sigmoid(jnp.dot(inputs, W) + b)
# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12, 0.77],
[0.88, -1.08, 0.15],
[0.52, 0.06, -1.30],
[0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])
# Training loss is the negative log-likelihood of the training examples.
def loss(W, b):
preds = predict(W, b, inputs)
label_probs = preds * targets + (1 - preds) * (1 - targets)
return -jnp.sum(jnp.log(label_probs))
# Initialize random model coefficients
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())
Use the grad
function with its argnums
argument to differentiate a function with respect to positional arguments.
# Differentiate `loss` with respect to the first positional argument:
W_grad = grad(loss, argnums=0)(W, b)
print('W_grad', W_grad)
# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print('W_grad', W_grad)
# But we can choose different values too, and drop the keyword:
b_grad = grad(loss, 1)(W, b)
print('b_grad', b_grad)
# Including tuple values
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print('W_grad', W_grad)
print('b_grad', b_grad)
W_grad [-0.16965583 -0.8774644 -1.4901346 ]
W_grad [-0.16965583 -0.8774644 -1.4901346 ]
b_grad -0.29227245
W_grad [-0.16965583 -0.8774644 -1.4901346 ]
b_grad -0.29227245
This grad
API has a direct correspondence to the excellent notation in Spivak’s classic Calculus on Manifolds (1965), also used in Sussman and Wisdom’s Structure and Interpretation of Classical Mechanics (2015) and their Functional Differential Geometry (2013). Both books are open-access. See in particular the “Prologue” section of Functional Differential Geometry for a defense of this notation.
Essentially, when using the argnums
argument, if f
is a Python function for evaluating the mathematical function \(f\), then the Python expression grad(f, i)
evaluates to a Python function for evaluating \(\partial_i f\).
Differentiating with respect to nested lists, tuples, and dicts#
Differentiating with respect to standard Python containers just works, so use tuples, lists, and dicts (and arbitrary nesting) however you like.
def loss2(params_dict):
preds = predict(params_dict['W'], params_dict['b'], inputs)
label_probs = preds * targets + (1 - preds) * (1 - targets)
return -jnp.sum(jnp.log(label_probs))
print(grad(loss2)({'W': W, 'b': b}))
{'W': Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32), 'b': Array(-0.29227245, dtype=float32)}
You can register your own container types to work with not just grad
but all the JAX transformations (jit
, vmap
, etc.).
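As an illustrative sketch, registration goes through jax.tree_util.register_pytree_node; the Point class and its flatten/unflatten helpers below are hypothetical, not part of JAX:
from jax import tree_util

class Point:
  def __init__(self, x, y):
    self.x = x
    self.y = y

def point_flatten(p):
  return (p.x, p.y), None          # (children, auxiliary data)

def point_unflatten(aux, children):
  return Point(*children)

tree_util.register_pytree_node(Point, point_flatten, point_unflatten)

# grad now returns a Point whose fields hold the per-field gradients.
g = grad(lambda p: p.x ** 2 + 3.0 * p.y)(Point(2.0, 1.0))
print(g.x, g.y)  # expect 4.0 and 3.0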
Evaluate a function and its gradient using value_and_grad#
Another convenient function is value_and_grad
for efficiently computing both a function’s value as well as its gradient’s value:
from jax import value_and_grad
loss_value, Wb_grad = value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)
print('loss value', loss(W, b))
loss value 3.0519385
loss value 3.0519385
Checking against numerical differences#
A great thing about derivatives is that they’re straightforward to check with finite differences:
# Set a step size for finite differences calculations
eps = 1e-4
# Check b_grad with scalar finite differences
b_grad_numerical = (loss(W, b + eps / 2.) - loss(W, b - eps / 2.)) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))
# Check W_grad with finite differences in a random direction
key, subkey = random.split(key)
vec = random.normal(subkey, W.shape)
unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))
W_grad_numerical = (loss(W + eps / 2. * unitvec, b) - loss(W - eps / 2. * unitvec, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(W, b), unitvec))
b_grad_numerical -0.29325485
b_grad_autodiff -0.29227245
W_dirderiv_numerical -0.2002716
W_dirderiv_autodiff -0.19909117
JAX provides a simple convenience function that does essentially the same thing, but checks up to any order of differentiation that you like:
from jax.test_util import check_grads
check_grads(loss, (W, b), order=2) # check up to 2nd order derivatives
Hessian-vector products with grad-of-grad#
One thing we can do with higher-order grad
is build a Hessian-vector product function. (Later on we’ll write an even more efficient implementation that mixes both forward- and reverse-mode, but this one will use pure reverse-mode.)
A Hessian-vector product function can be useful in a truncated Newton Conjugate-Gradient algorithm for minimizing smooth convex functions, or for studying the curvature of neural network training objectives (e.g. 1, 2, 3, 4).
For a scalar-valued function \(f : \mathbb{R}^n \to \mathbb{R}\) with continuous second derivatives (so that the Hessian matrix is symmetric), the Hessian at a point \(x \in \mathbb{R}^n\) is written as \(\partial^2 f(x)\). A Hessian-vector product function is then able to evaluate
\(\qquad v \mapsto \partial^2 f(x) \cdot v\)
for any \(v \in \mathbb{R}^n\).
The trick is not to instantiate the full Hessian matrix: if \(n\) is large, perhaps in the millions or billions in the context of neural networks, then that might be impossible to store.
Luckily, grad
already gives us a way to write an efficient Hessian-vector product function. We just have to use the identity
\(\qquad \partial^2 f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)\),
where \(g(x) = \partial f(x) \cdot v\) is a new scalar-valued function that dots the gradient of \(f\) at \(x\) with the vector \(v\). Notice that we’re only ever differentiating scalar-valued functions of vector-valued arguments, which is exactly where we know grad
is efficient.
In JAX code, we can just write this:
def hvp(f, x, v):
return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
This example shows that you can freely use lexical closure, and JAX will never get perturbed or confused.
We’ll check this implementation a few cells down, once we see how to compute dense Hessian matrices. We’ll also write an even better version that uses both forward-mode and reverse-mode.
Jacobians and Hessians using jacfwd and jacrev#
You can compute full Jacobian matrices using the jacfwd
and jacrev
functions:
from jax import jacfwd, jacrev
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
J = jacfwd(f)(W)
print("jacfwd result, with shape", J.shape)
print(J)
J = jacrev(f)(W)
print("jacrev result, with shape", J.shape)
print(J)
jacfwd result, with shape (4, 3)
[[ 0.05981758 0.12883787 0.08857603]
[ 0.04015916 -0.04928625 0.00684531]
[ 0.12188288 0.01406341 -0.3047072 ]
[ 0.00140431 -0.00472531 0.00263782]]
jacrev result, with shape (4, 3)
[[ 0.05981757 0.12883787 0.08857603]
[ 0.04015916 -0.04928625 0.00684531]
[ 0.12188289 0.01406341 -0.3047072 ]
[ 0.00140431 -0.00472531 0.00263782]]
These two functions compute the same values (up to machine numerics), but differ in their implementation: jacfwd
uses forward-mode automatic differentiation, which is more efficient for “tall” Jacobian matrices, while jacrev
uses reverse-mode, which is more efficient for “wide” Jacobian matrices. For matrices that are near-square, jacfwd
probably has an edge over jacrev
.
You can also use jacfwd
and jacrev
with container types:
def predict_dict(params, inputs):
return predict(params['W'], params['b'], inputs)
J_dict = jacrev(predict_dict)({'W': W, 'b': b}, inputs)
for k, v in J_dict.items():
print("Jacobian from {} to logits is".format(k))
print(v)
Jacobian from W to logits is
[[ 0.05981757 0.12883787 0.08857603]
[ 0.04015916 -0.04928625 0.00684531]
[ 0.12188289 0.01406341 -0.3047072 ]
[ 0.00140431 -0.00472531 0.00263782]]
Jacobian from b to logits is
[0.11503381 0.04563541 0.23439017 0.00189771]
For more details on forward- and reverse-mode, as well as how to implement jacfwd
and jacrev
as efficiently as possible, read on!
Using a composition of two of these functions gives us a way to compute dense Hessian matrices:
def hessian(f):
return jacfwd(jacrev(f))
H = hessian(f)(W)
print("hessian, with shape", H.shape)
print(H)
hessian, with shape (4, 3, 3)
[[[ 0.02285465 0.04922541 0.03384247]
[ 0.04922541 0.10602397 0.07289147]
[ 0.03384247 0.07289147 0.05011288]]
[[-0.03195215 0.03921401 -0.00544639]
[ 0.03921401 -0.04812629 0.00668421]
[-0.00544639 0.00668421 -0.00092836]]
[[-0.01583708 -0.00182736 0.03959271]
[-0.00182736 -0.00021085 0.00456839]
[ 0.03959271 0.00456839 -0.09898177]]
[[-0.00103524 0.00348343 -0.00194457]
[ 0.00348343 -0.01172127 0.0065432 ]
[-0.00194457 0.0065432 -0.00365263]]]
This shape makes sense: if we start with a function \(f : \mathbb{R}^n \to \mathbb{R}^m\), then at a point \(x \in \mathbb{R}^n\) we expect to get the shapes
\(f(x) \in \mathbb{R}^m\), the value of \(f\) at \(x\),
\(\partial f(x) \in \mathbb{R}^{m \times n}\), the Jacobian matrix at \(x\),
\(\partial^2 f(x) \in \mathbb{R}^{m \times n \times n}\), the Hessian at \(x\),
and so on.
To implement hessian
, we could have used jacfwd(jacrev(f))
or jacrev(jacfwd(f))
or any other composition of the two. But forward-over-reverse is typically the most efficient. That’s because in the inner Jacobian computation we’re often differentiating a function with a wide Jacobian (maybe like a loss function \(f : \mathbb{R}^n \to \mathbb{R}\)), while in the outer Jacobian computation we’re differentiating a function with a square Jacobian (since \(\nabla f : \mathbb{R}^n \to \mathbb{R}^n\)), which is where forward-mode wins out.
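As a quick illustrative check under the definitions above, both composition orders agree up to floating-point error:
import numpy as np

# Forward-over-reverse vs. reverse-over-forward: same dense Hessian values,
# though forward-over-reverse is usually the more efficient way to compute it.
H_fwd_over_rev = jacfwd(jacrev(f))(W)
H_rev_over_fwd = jacrev(jacfwd(f))(W)
np.testing.assert_allclose(H_fwd_over_rev, H_rev_over_fwd, rtol=1e-4, atol=1e-6)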
How it’s made: two foundational autodiff functions#
Jacobian-Vector products (JVPs, aka forward-mode autodiff)#
JAX includes efficient and general implementations of both forward- and reverse-mode automatic differentiation. The familiar grad
function is built on reverse-mode, but to explain the difference in the two modes, and when each can be useful, we need a bit of math background.
JVPs in math#
Mathematically, given a function \(f : \mathbb{R}^n \to \mathbb{R}^m\), the Jacobian of \(f\) evaluated at an input point \(x \in \mathbb{R}^n\), denoted \(\partial f(x)\), is often thought of as a matrix in \(\mathbb{R}^{m \times n}\):
\(\qquad \partial f(x) \in \mathbb{R}^{m \times n}\).
But we can also think of \(\partial f(x)\) as a linear map, which maps the tangent space of the domain of \(f\) at the point \(x\) (which is just another copy of \(\mathbb{R}^n\)) to the tangent space of the codomain of \(f\) at the point \(f(x)\) (a copy of \(\mathbb{R}^m\)):
\(\qquad \partial f(x) : \mathbb{R}^n \to \mathbb{R}^m\).
This map is called the pushforward map of \(f\) at \(x\). The Jacobian matrix is just the matrix for this linear map in a standard basis.
If we don’t commit to one specific input point \(x\), then we can think of the function \(\partial f\) as first taking an input point and returning the Jacobian linear map at that input point:
\(\qquad \partial f : \mathbb{R}^n \to \mathbb{R}^n \to \mathbb{R}^m\).
In particular, we can uncurry things so that given input point \(x \in \mathbb{R}^n\) and a tangent vector \(v \in \mathbb{R}^n\), we get back an output tangent vector in \(\mathbb{R}^m\). We call that mapping, from \((x, v)\) pairs to output tangent vectors, the Jacobian-vector product, and write it as
\(\qquad (x, v) \mapsto \partial f(x) v\)
JVPs in JAX code#
Back in Python code, JAX’s jvp
function models this transformation. Given a Python function that evaluates \(f\), JAX’s jvp
is a way to get a Python function for evaluating \((x, v) \mapsto (f(x), \partial f(x) v)\).
from jax import jvp
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
key, subkey = random.split(key)
v = random.normal(subkey, W.shape)
# Push forward the vector `v` along `f` evaluated at `W`
y, u = jvp(f, (W,), (v,))
In terms of Haskell-like type signatures, we could write
jvp :: (a -> b) -> a -> T a -> (b, T b)
where we use T a
to denote the type of the tangent space for a
. In words, jvp
takes as arguments a function of type a -> b
, a value of type a
, and a tangent vector value of type T a
. It gives back a pair consisting of a value of type b
and an output tangent vector of type T b
.
The jvp
-transformed function is evaluated much like the original function, but paired up with each primal value of type a
it pushes along tangent values of type T a
. For each primitive numerical operation that the original function would have applied, the jvp
-transformed function executes a “JVP rule” for that primitive that both evaluates the primitive on the primals and applies the primitive’s JVP at those primal values.
That evaluation strategy has some immediate implications about computational complexity: since we evaluate JVPs as we go, we don’t need to store anything for later, and so the memory cost is independent of the depth of the computation. In addition, the FLOP cost of the jvp
-transformed function is about 3x the cost of just evaluating the function (one unit of work for evaluating the original function, for example sin(x)
; one unit for linearizing, like cos(x)
; and one unit for applying the linearized function to a vector, like cos_x * v
). Put another way, for a fixed primal point \(x\), we can evaluate \(v \mapsto \partial f(x) \cdot v\) for about the same marginal cost as evaluating \(f\).
That memory complexity sounds pretty compelling! So why don’t we see forward-mode very often in machine learning?
To answer that, first think about how you could use a JVP to build a full Jacobian matrix. If we apply a JVP to a one-hot tangent vector, it reveals one column of the Jacobian matrix, corresponding to the nonzero entry we fed in. So we can build a full Jacobian one column at a time, and to get each column costs about the same as one function evaluation. That will be efficient for functions with “tall” Jacobians, but inefficient for “wide” Jacobians.
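Here is a small sketch of that column-by-column construction, reusing f and W from above; the helper name jacobian_by_columns is just for illustration:
import numpy as np
from jax import jvp, jacfwd

def jacobian_by_columns(fun, x):
  # Push one-hot tangent vectors through `fun` to recover the Jacobian
  # one column at a time; each column costs about one extra evaluation.
  def column(i):
    one_hot = jnp.zeros_like(x).at[i].set(1.0)
    _, tangent_out = jvp(fun, (x,), (one_hot,))
    return tangent_out
  return jnp.stack([column(i) for i in range(x.size)], axis=-1)

np.testing.assert_allclose(jacobian_by_columns(f, W), jacfwd(f)(W), atol=1e-6)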
If you’re doing gradient-based optimization in machine learning, you probably want to minimize a loss function from parameters in \(\mathbb{R}^n\) to a scalar loss value in \(\mathbb{R}\). That means the Jacobian of this function is a very wide matrix: \(\partial f(x) \in \mathbb{R}^{1 \times n}\), which we often identify with the gradient vector \(\nabla f(x) \in \mathbb{R}^n\). Building that matrix one column at a time, with each call taking a similar number of FLOPs to evaluate the original function, sure seems inefficient! In particular, for training neural networks, where \(f\) is a training loss function and \(n\) can be in the millions or billions, this approach just won’t scale.
To do better for functions like this, we just need to use reverse-mode.
Vector-Jacobian products (VJPs, aka reverse-mode autodiff)#
Where forward-mode gives us back a function for evaluating Jacobian-vector products, which we can then use to build Jacobian matrices one column at a time, reverse-mode is a way to get back a function for evaluating vector-Jacobian products (equivalently Jacobian-transpose-vector products), which we can use to build Jacobian matrices one row at a time.
VJPs in math#
Let’s again consider a function \(f : \mathbb{R}^n \to \mathbb{R}^m\). Starting from our notation for JVPs, the notation for VJPs is pretty simple:
\(\qquad (x, v) \mapsto v \partial f(x)\),
where \(v\) is an element of the cotangent space of \(f\) at \(x\) (isomorphic to another copy of \(\mathbb{R}^m\)). When being rigorous, we should think of \(v\) as a linear map \(v : \mathbb{R}^m \to \mathbb{R}\), and when we write \(v \partial f(x)\) we mean function composition \(v \circ \partial f(x)\), where the types work out because \(\partial f(x) : \mathbb{R}^n \to \mathbb{R}^m\). But in the common case we can identify \(v\) with a vector in \(\mathbb{R}^m\) and use the two almost interchangeably, just like we might sometimes flip between “column vectors” and “row vectors” without much comment.
With that identification, we can alternatively think of the linear part of a VJP as the transpose (or adjoint conjugate) of the linear part of a JVP:
\(\qquad (x, v) \mapsto \partial f(x)^\mathsf{T} v\).
For a given point \(x\), we can write the signature as
\(\qquad \partial f(x)^\mathsf{T} : \mathbb{R}^m \to \mathbb{R}^n\).
The corresponding map on cotangent spaces is often called the pullback of \(f\) at \(x\). The key for our purposes is that it goes from something that looks like the output of \(f\) to something that looks like the input of \(f\), just like we might expect from a transposed linear function.
VJPs in JAX code#
Switching from math back to Python, the JAX function vjp
can take a Python function for evaluating \(f\) and give us back a Python function for evaluating the VJP \((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))\).
from jax import vjp
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
y, vjp_fun = vjp(f, W)
key, subkey = random.split(key)
u = random.normal(subkey, y.shape)
# Pull back the covector `u` along `f` evaluated at `W`
v = vjp_fun(u)
In terms of Haskell-like type signatures, we could write
vjp :: (a -> b) -> a -> (b, CT b -> CT a)
where we use CT a
to denote the type for the cotangent space for a
. In words, vjp
takes as arguments a function of type a -> b
and a point of type a
, and gives back a pair consisting of a value of type b
and a linear map of type CT b -> CT a
.
This is great because it lets us build Jacobian matrices one row at a time, and the FLOP cost for evaluating \((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))\) is only about three times the cost of evaluating \(f\). In particular, if we want the gradient of a function \(f : \mathbb{R}^n \to \mathbb{R}\), we can do it in just one call. That’s how grad
is efficient for gradient-based optimization, even for objectives like neural network training loss functions on millions or billions of parameters.
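As a rough sketch of that point (added here for illustration; the real grad implementation has more to it), for a scalar-output function we can pull back the single cotangent 1.0 and read off the whole gradient in one VJP call:
from jax import vjp
import jax.numpy as jnp
def our_grad(f):
    # For f : R^n -> R, pulling back the cotangent 1.0 yields the full gradient.
    def gradfun(x):
        _, vjp_fun = vjp(f, x)
        g, = vjp_fun(1.0)
        return g
    return gradfun
loss = lambda x: jnp.sum(jnp.tanh(x) ** 2)
print(our_grad(loss)(jnp.arange(3.)))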
There’s a cost, though: while the FLOPs are friendly, memory scales with the depth of the computation. Also, the implementation is traditionally more complex than that of forward-mode, though JAX has some tricks up its sleeve (that’s a story for a future notebook!).
For more on how reverse-mode works, see this tutorial video from the Deep Learning Summer School in 2017.
Vector-valued gradients with VJPs#
If you’re interested in taking vector-valued gradients (like tf.gradients
):
from jax import vjp
def vgrad(f, x):
y, vjp_fn = vjp(f, x)
return vjp_fn(jnp.ones(y.shape))[0]
print(vgrad(lambda x: 3*x**2, jnp.ones((2, 2))))
[[6. 6.]
[6. 6.]]
Hessian-vector products using both forward- and reverse-mode#
In a previous section, we implemented a Hessian-vector product function just using reverse-mode (assuming continuous second derivatives):
def hvp(f, x, v):
return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
That’s efficient, but we can do even better and save some memory by using forward-mode together with reverse-mode.
Mathematically, given a function \(f : \mathbb{R}^n \to \mathbb{R}\) to differentiate, a point \(x \in \mathbb{R}^n\) at which to linearize the function, and a vector \(v \in \mathbb{R}^n\), the Hessian-vector product function we want is
\((x, v) \mapsto \partial^2 f(x) v\)
Consider the helper function \(g : \mathbb{R}^n \to \mathbb{R}^n\) defined to be the derivative (or gradient) of \(f\), namely \(g(x) = \partial f(x)\). All we need is its JVP, since that will give us
\((x, v) \mapsto \partial g(x) v = \partial^2 f(x) v\).
We can translate that almost directly into code:
from jax import jvp, grad
# forward-over-reverse
def hvp(f, primals, tangents):
return jvp(grad(f), primals, tangents)[1]
Even better, since we didn’t have to call jnp.dot
directly, this hvp
function works with arrays of any shape and with arbitrary container types (like vectors stored as nested lists/dicts/tuples), and doesn’t even have a dependence on jax.numpy
.
Here’s an example of how to use it:
def f(X):
return jnp.sum(jnp.tanh(X)**2)
key, subkey1, subkey2 = random.split(key, 3)
X = random.normal(subkey1, (30, 40))
V = random.normal(subkey2, (30, 40))
ans1 = hvp(f, (X,), (V,))
ans2 = jnp.tensordot(hessian(f)(X), V, 2)
print(jnp.allclose(ans1, ans2, 1e-4, 1e-4))
True
Another way you might consider writing this is using reverse-over-forward:
# reverse-over-forward
def hvp_revfwd(f, primals, tangents):
g = lambda primals: jvp(f, primals, tangents)[1]
return grad(g)(primals)
That’s not quite as good, though, because forward-mode has less overhead than reverse-mode, and since the outer differentiation operator here has to differentiate a larger computation than the inner one, keeping forward-mode on the outside works best:
# reverse-over-reverse, only works for single arguments
def hvp_revrev(f, primals, tangents):
x, = primals
v, = tangents
return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
print("Forward over reverse")
%timeit -n10 -r3 hvp(f, (X,), (V,))
print("Reverse over forward")
%timeit -n10 -r3 hvp_revfwd(f, (X,), (V,))
print("Reverse over reverse")
%timeit -n10 -r3 hvp_revrev(f, (X,), (V,))
print("Naive full Hessian materialization")
%timeit -n10 -r3 jnp.tensordot(hessian(f)(X), V, 2)
Forward over reverse
4.78 ms ± 202 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over forward
9.14 ms ± 5.06 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over reverse
13.7 ms ± 8.1 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Naive full Hessian materialization
55.8 ms ± 977 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Composing VJPs, JVPs, and vmap#
Jacobian-Matrix and Matrix-Jacobian products#
Now that we have jvp
and vjp
transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX’s vmap
transformation to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products.
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
# Pull back the covectors `m_i` along `f`, evaluated at `W`, for all `i`.
# First, use a list comprehension to loop over rows in the matrix M.
def loop_mjp(f, x, M):
y, vjp_fun = vjp(f, x)
return jnp.vstack([vjp_fun(mi)[0] for mi in M])
# Now, use vmap to build a computation that does a single fast matrix-matrix
# multiply, rather than an outer loop over vector-matrix multiplies.
def vmap_mjp(f, x, M):
y, vjp_fun = vjp(f, x)
outs, = vmap(vjp_fun)(M)
return outs
key = random.key(0)
num_covecs = 128
U = random.normal(key, (num_covecs,) + y.shape)
loop_vs = loop_mjp(f, W, M=U)
print('Non-vmapped Matrix-Jacobian product')
%timeit -n10 -r3 loop_mjp(f, W, M=U)
print('\nVmapped Matrix-Jacobian product')
vmap_vs = vmap_mjp(f, W, M=U)
%timeit -n10 -r3 vmap_mjp(f, W, M=U)
assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'
Non-vmapped Matrix-Jacobian product
146 ms ± 1.14 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Vmapped Matrix-Jacobian product
5.92 ms ± 44 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
def loop_jmp(f, W, M):
# jvp immediately returns the primal and tangent values as a tuple,
# so we'll compute and select the tangents in a list comprehension
return jnp.vstack([jvp(f, (W,), (mi,))[1] for mi in M])
def vmap_jmp(f, W, M):
_jvp = lambda s: jvp(f, (W,), (s,))[1]
return vmap(_jvp)(M)
num_vecs = 128
S = random.normal(key, (num_vecs,) + W.shape)
loop_vs = loop_jmp(f, W, M=S)
print('Non-vmapped Jacobian-Matrix product')
%timeit -n10 -r3 loop_jmp(f, W, M=S)
vmap_vs = vmap_jmp(f, W, M=S)
print('\nVmapped Jacobian-Matrix product')
%timeit -n10 -r3 vmap_jmp(f, W, M=S)
assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical'
Non-vmapped Jacobian-Matrix product
289 ms ± 403 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Vmapped Jacobian-Matrix product
3.38 ms ± 123 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
The implementation of jacfwd and jacrev#
Now that we’ve seen fast Jacobian-matrix and matrix-Jacobian products, it’s not hard to guess how to write jacfwd
and jacrev
. We just use the same technique to push-forward or pull-back an entire standard basis (isomorphic to an identity matrix) at once.
from jax import jacrev as builtin_jacrev
def our_jacrev(f):
def jacfun(x):
y, vjp_fun = vjp(f, x)
# Use vmap to do a matrix-Jacobian product.
# Here, the matrix is the Euclidean basis, so we get all
# entries in the Jacobian at once.
J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))
return J
return jacfun
assert jnp.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!'
from jax import jacfwd as builtin_jacfwd
def our_jacfwd(f):
def jacfun(x):
_jvp = lambda s: jvp(f, (x,), (s,))[1]
Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
return jnp.transpose(Jt)
return jacfun
assert jnp.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!'
Interestingly, Autograd couldn’t do this. Our implementation of reverse-mode jacobian
in Autograd had to pull back one vector at a time with an outer-loop map
. Pushing one vector at a time through the computation is much less efficient than batching it all together with vmap
.
Another thing that Autograd couldn’t do is jit
. Interestingly, no matter how much Python dynamism you use in your function to be differentiated, we could always use jit
on the linear part of the computation. For example:
def f(x):
try:
if x < 3:
return 2 * x ** 3
else:
raise ValueError
except ValueError:
return jnp.pi * x
y, f_vjp = vjp(f, 4.)
print(jit(f_vjp)(1.))
(Array(3.1415927, dtype=float32, weak_type=True),)
Complex numbers and differentiation#
JAX is great at complex numbers and differentiation. To support both holomorphic and non-holomorphic differentiation, it helps to think in terms of JVPs and VJPs.
Consider a complex-to-complex function \(f: \mathbb{C} \to \mathbb{C}\) and identify it with a corresponding function \(g: \mathbb{R}^2 \to \mathbb{R}^2\),
def f(z):
x, y = jnp.real(z), jnp.imag(z)
return u(x, y) + v(x, y) * 1j
def g(x, y):
return (u(x, y), v(x, y))
That is, we’ve decomposed \(f(z) = u(x, y) + v(x, y) i\) where \(z = x + y i\), and identified \(\mathbb{C}\) with \(\mathbb{R}^2\) to get \(g\).
Since \(g\) only involves real inputs and outputs, we already know how to write a Jacobian-vector product for it, say given a tangent vector \((c, d) \in \mathbb{R}^2\), namely
\(\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix}\).
To get a JVP for the original function \(f\) applied to a tangent vector \(c + di \in \mathbb{C}\), we just use the same definition and identify the result as another complex number,
\(\partial f(x + y i)(c + d i) = \begin{matrix} \begin{bmatrix} 1 & i \end{bmatrix} \\ ~ \end{matrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix}\).
That’s our definition of the JVP of a \(\mathbb{C} \to \mathbb{C}\) function! Notice it doesn’t matter whether or not \(f\) is holomorphic: the JVP is unambiguous.
Here’s a check:
def check(seed):
key = random.key(seed)
# random coeffs for u and v
key, subkey = random.split(key)
a, b, c, d = random.uniform(subkey, (4,))
def fun(z):
x, y = jnp.real(z), jnp.imag(z)
return u(x, y) + v(x, y) * 1j
def u(x, y):
return a * x + b * y
def v(x, y):
return c * x + d * y
# primal point
key, subkey = random.split(key)
x, y = random.uniform(subkey, (2,))
z = x + y * 1j
# tangent vector
key, subkey = random.split(key)
c, d = random.uniform(subkey, (2,))
z_dot = c + d * 1j
# check jvp
_, ans = jvp(fun, (z,), (z_dot,))
expected = (grad(u, 0)(x, y) * c +
grad(u, 1)(x, y) * d +
grad(v, 0)(x, y) * c * 1j+
grad(v, 1)(x, y) * d * 1j)
print(jnp.allclose(ans, expected))
check(0)
check(1)
check(2)
True
True
True
What about VJPs? We do something pretty similar: for a cotangent vector \(c + di \in \mathbb{C}\) we define the VJP of \(f\) as
\((c + di)^* \; \partial f(x + y i) = \begin{matrix} \begin{bmatrix} c & -d \end{bmatrix} \\ ~ \end{matrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} 1 \\ -i \end{bmatrix}\).
What’s with the negatives? They’re just to take care of complex conjugation, and the fact that we’re working with covectors.
Here’s a check of the VJP rules:
def check(seed):
key = random.key(seed)
# random coeffs for u and v
key, subkey = random.split(key)
a, b, c, d = random.uniform(subkey, (4,))
def fun(z):
x, y = jnp.real(z), jnp.imag(z)
return u(x, y) + v(x, y) * 1j
def u(x, y):
return a * x + b * y
def v(x, y):
return c * x + d * y
# primal point
key, subkey = random.split(key)
x, y = random.uniform(subkey, (2,))
z = x + y * 1j
# cotangent vector
key, subkey = random.split(key)
c, d = random.uniform(subkey, (2,))
z_bar = jnp.array(c + d * 1j) # for dtype control
# check vjp
_, fun_vjp = vjp(fun, z)
ans, = fun_vjp(z_bar)
expected = (grad(u, 0)(x, y) * c +
grad(v, 0)(x, y) * (-d) +
grad(u, 1)(x, y) * c * (-1j) +
grad(v, 1)(x, y) * (-d) * (-1j))
assert jnp.allclose(ans, expected, atol=1e-5, rtol=1e-5)
check(0)
check(1)
check(2)
What about convenience wrappers like grad
, jacfwd
, and jacrev
?
For \(\mathbb{R} \to \mathbb{R}\) functions, recall we defined grad(f)(x)
as being vjp(f, x)[1](1.0)
, which works because applying a VJP to a 1.0
value reveals the gradient (i.e. Jacobian, or derivative). We can do the same thing for \(\mathbb{C} \to \mathbb{R}\) functions: we can still use 1.0
as the cotangent vector, and we just get out a complex number result summarizing the full Jacobian:
def f(z):
x, y = jnp.real(z), jnp.imag(z)
return x**2 + y**2
z = 3. + 4j
grad(f)(z)
Array(6.-8.j, dtype=complex64)
For general \(\mathbb{C} \to \mathbb{C}\) functions, the Jacobian has 4 real-valued degrees of freedom (as in the 2x2 Jacobian matrices above), so we can’t hope to represent all of them within a complex number. But we can for holomorphic functions! A holomorphic function is precisely a \(\mathbb{C} \to \mathbb{C}\) function with the special property that its derivative can be represented as a single complex number. (The Cauchy-Riemann equations ensure that the above 2x2 Jacobians have the special form of a scale-and-rotate matrix in the complex plane, i.e. the action of a single complex number under multiplication.) And we can reveal that one complex number using a single call to vjp
with a covector of 1.0
.
Because this only works for holomorphic functions, to use this trick we need to promise JAX that our function is holomorphic; otherwise, JAX will raise an error when grad
is used for a complex-output function:
def f(z):
return jnp.sin(z)
z = 3. + 4j
grad(f, holomorphic=True)(z)
Array(-27.034945-3.8511531j, dtype=complex64, weak_type=True)
All the holomorphic=True
promise does is disable the error when the output is complex-valued. We can still write holomorphic=True
when the function isn’t holomorphic, but the answer we get out won’t represent the full Jacobian. Instead, it’ll be the Jacobian of the function where we just discard the imaginary part of the output:
def f(z):
return jnp.conjugate(z)
z = 3. + 4j
grad(f, holomorphic=True)(z) # f is not actually holomorphic!
Array(1.-0.j, dtype=complex64, weak_type=True)
There are some useful upshots for how grad
works here:
We can use grad on holomorphic \(\mathbb{C} \to \mathbb{C}\) functions.
We can use grad to optimize \(f : \mathbb{C} \to \mathbb{R}\) functions, like real-valued loss functions of complex parameters x, by taking steps in the direction of the conjugate of grad(f)(x) (see the sketch after this list).
If we have an \(\mathbb{R} \to \mathbb{R}\) function that just happens to use some complex-valued operations internally (some of which must be non-holomorphic, e.g. FFTs used in convolutions) then grad still works and we get the same result that an implementation using only real values would have given.
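Here is a small sketch of the second point above (an added illustration; the quadratic loss and step size are made up): gradient descent on a real-valued loss of a complex parameter, stepping along the conjugate of grad(f)(x):
import jax.numpy as jnp
from jax import grad
# Real-valued loss of a complex parameter z; its minimizer is z = 1 + 2j.
def loss(z):
    return (jnp.real(z) - 1.)**2 + (jnp.imag(z) - 2.)**2
z = jnp.array(0. + 0j)
for _ in range(100):
    # Step along the conjugate of the complex "gradient".
    z = z - 0.1 * jnp.conj(grad(loss)(z))
print(z)  # approaches (1+2j)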
In any case, JVPs and VJPs are always unambiguous. And if we wanted to compute the full Jacobian matrix of a non-holomorphic \(\mathbb{C} \to \mathbb{C}\) function, we can do it with JVPs or VJPs!
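For example, here is a hedged sketch (added for illustration, not from the original notebook) that recovers the full 2x2 real Jacobian of a non-holomorphic \(\mathbb{C} \to \mathbb{C}\) function by pushing forward the two tangent directions \(1\) and \(i\):
import jax.numpy as jnp
from jax import jvp
def complex_jacobian(f, z):
    # Push forward the real and imaginary input directions separately.
    _, dz_real = jvp(f, (z,), (jnp.array(1. + 0j),))
    _, dz_imag = jvp(f, (z,), (jnp.array(0. + 1j),))
    return jnp.array([[jnp.real(dz_real), jnp.real(dz_imag)],
                      [jnp.imag(dz_real), jnp.imag(dz_imag)]])
# jnp.conjugate is not holomorphic; its full real Jacobian is [[1, 0], [0, -1]].
print(complex_jacobian(jnp.conjugate, jnp.array(3. + 4j)))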
You should expect complex numbers to work everywhere in JAX. Here’s differentiating through a Cholesky decomposition of a complex matrix:
A = jnp.array([[5., 2.+3j, 5j],
[2.-3j, 7., 1.+7j],
[-5j, 1.-7j, 12.]])
def f(X):
L = jnp.linalg.cholesky(X)
return jnp.sum((L - jnp.sin(L))**2)
grad(f, holomorphic=True)(A)
Array([[-0.7534186 +0.j , -3.0509028 -10.940544j ,
5.9896846 +3.5423026j],
[-3.0509028 +10.940544j , -8.904491 +0.j ,
-5.1351523 -6.559373j ],
[ 5.9896846 -3.5423026j, -5.1351523 +6.559373j ,
0.01320427 +0.j ]], dtype=complex64)
More advanced autodiff#
In this notebook, we worked through some easy, and then progressively more complicated, applications of automatic differentiation in JAX. We hope you now feel that taking derivatives in JAX is easy and powerful.
There’s a whole world of other autodiff tricks and functionality out there. Topics we didn’t cover, but hope to in an “Advanced Autodiff Cookbook” include:
Gauss-Newton Vector Products, linearizing once
Custom VJPs and JVPs
Efficient derivatives at fixed-points
Estimating the trace of a Hessian using random Hessian-vector products.
Forward-mode autodiff using only reverse-mode autodiff.
Taking derivatives with respect to custom data types.
Checkpointing (binomial checkpointing for efficient reverse-mode, not model snapshotting).
Optimizing VJPs with Jacobian pre-accumulation.
Custom derivative rules for JAX-transformable Python functions#
mattjj@ Mar 19 2020, last updated Oct 14 2020
There are two ways to define differentiation rules in JAX:
using jax.custom_jvp and jax.custom_vjp to define custom differentiation rules for Python functions that are already JAX-transformable; and
defining new core.Primitive instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems.
This notebook is about #1. To read instead about #2, see the notebook on adding primitives.
For an introduction to JAX’s automatic differentiation API, see The Autodiff Cookbook. This notebook assumes some familiarity with jax.jvp and jax.grad, and the mathematical meaning of JVPs and VJPs.
TL;DR#
Custom JVPs with jax.custom_jvp#
import jax.numpy as jnp
from jax import custom_jvp
@custom_jvp
def f(x, y):
return jnp.sin(x) * y
@f.defjvp
def f_jvp(primals, tangents):
x, y = primals
x_dot, y_dot = tangents
primal_out = f(x, y)
tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
return primal_out, tangent_out
from jax import jvp, grad
print(f(2., 3.))
y, y_dot = jvp(f, (2., 3.), (1., 0.))
print(y)
print(y_dot)
print(grad(f)(2., 3.))
2.7278922
2.7278922
-1.2484405
-1.2484405
# Equivalent alternative using the defjvps convenience wrapper
@custom_jvp
def f(x, y):
return jnp.sin(x) * y
f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)
print(f(2., 3.))
y, y_dot = jvp(f, (2., 3.), (1., 0.))
print(y)
print(y_dot)
print(grad(f)(2., 3.))
2.7278922
2.7278922
-1.2484405
-1.2484405
Custom VJPs with jax.custom_vjp#
from jax import custom_vjp
@custom_vjp
def f(x, y):
return jnp.sin(x) * y
def f_fwd(x, y):
# Returns primal output and residuals to be used in backward pass by f_bwd.
return f(x, y), (jnp.cos(x), jnp.sin(x), y)
def f_bwd(res, g):
cos_x, sin_x, y = res # Gets residuals computed in f_fwd
return (cos_x * g * y, sin_x * g)
f.defvjp(f_fwd, f_bwd)
print(grad(f)(2., 3.))
-1.2484405
Example problems#
To get an idea of what problems jax.custom_jvp
and jax.custom_vjp
are meant to solve, let’s go over a few examples. A more thorough introduction to the jax.custom_jvp
and jax.custom_vjp
APIs is in the next section.
Numerical stability#
One application of jax.custom_jvp
is to improve the numerical stability of differentiation.
Say we want to write a function called log1pexp
, which computes \(x \mapsto \log ( 1 + e^x )\). We can write that using jax.numpy
:
import jax.numpy as jnp
def log1pexp(x):
return jnp.log(1. + jnp.exp(x))
log1pexp(3.)
Array(3.0485873, dtype=float32, weak_type=True)
Since it’s written in terms of jax.numpy
, it’s JAX-transformable:
from jax import jit, grad, vmap
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))
3.0485873
0.95257413
[0.5 0.7310586 0.8807971]
But there’s a numerical stability problem lurking here:
print(grad(log1pexp)(100.))
nan
That doesn’t seem right! After all, the derivative of \(x \mapsto \log (1 + e^x)\) is \(x \mapsto \frac{e^x}{1 + e^x}\), and so for large values of \(x\) we’d expect the value to be about 1.
We can get a bit more insight into what’s going on by looking at the jaxpr for the gradient computation:
from jax import make_jaxpr
make_jaxpr(grad(log1pexp))(100.)
{ lambda ; a:f32[]. let
b:f32[] = exp a
c:f32[] = add 1.0 b
_:f32[] = log c
d:f32[] = div 1.0 c
e:f32[] = mul d b
in (e,) }
Stepping through how the jaxpr would be evaluated, we can see that the last line would involve multiplying values that floating point math will round to 0 and \(\infty\), respectively, which is never a good idea. That is, we’re effectively evaluating lambda x: (1 / (1 + jnp.exp(x))) * jnp.exp(x)
for large x
, which effectively turns into 0. * jnp.inf
.
Instead of generating such large and small values, hoping for a cancellation that floats can’t always provide, we’d rather just express the derivative function as a more numerically stable program. In particular, we can write a program that more closely evaluates the equal mathematical expression \(1 - \frac{1}{1 + e^x}\), with no cancellation in sight.
This problem is interesting because even though our definition of log1pexp
could already be JAX-differentiated (and transformed with jit
, vmap
, …), we’re not happy with the result of applying standard autodiff rules to the primitives comprising log1pexp
and composing the result. Instead, we’d like to specify how the whole function log1pexp
should be differentiated, as a unit, and thus arrange those exponentials better.
This is one application of custom derivative rules for Python functions that are already JAX transformable: specifying how a composite function should be differentiated, while still using its original Python definition for other transformations (like jit
, vmap
, …).
Here’s a solution using jax.custom_jvp
:
from jax import custom_jvp
@custom_jvp
def log1pexp(x):
return jnp.log(1. + jnp.exp(x))
@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
x, = primals
x_dot, = tangents
ans = log1pexp(x)
ans_dot = (1 - 1/(1 + jnp.exp(x))) * x_dot
return ans, ans_dot
print(grad(log1pexp)(100.))
1.0
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))
3.0485873
0.95257413
[0.5 0.7310586 0.8807971]
Here’s a defjvps
convenience wrapper to express the same thing:
@custom_jvp
def log1pexp(x):
return jnp.log(1. + jnp.exp(x))
log1pexp.defjvps(lambda t, ans, x: (1 - 1/(1 + jnp.exp(x))) * t)
print(grad(log1pexp)(100.))
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))
1.0
3.0485873
0.95257413
[0.5 0.7310586 0.8807971]
Enforcing a differentiation convention#
A related application is to enforce a differentiation convention, perhaps at a boundary.
Consider the function \(f : \mathbb{R}_+ \to \mathbb{R}_+\) with \(f(x) = \frac{x}{1 + \sqrt{x}}\), where we take \(\mathbb{R}_+ = [0, \infty)\). We might implement \(f\) as a program like this:
def f(x):
return x / (1 + jnp.sqrt(x))
As a mathematical function on \(\mathbb{R}\) (the full real line), \(f\) is not differentiable at zero (because the limit defining the derivative doesn’t exist from the left). Correspondingly, autodiff produces a nan
value:
print(grad(f)(0.))
nan
But mathematically if we think of \(f\) as a function on \(\mathbb{R}_+\) then it is differentiable at 0 [Rudin’s Principles of Mathematical Analysis Definition 5.1, or Tao’s Analysis I 3rd ed. Definition 10.1.1 and Example 10.1.6]. Alternatively, we might say as a convention we want to consider the directional derivative from the right. So there is a sensible value for the Python function grad(f)
to return at 0.0
, namely 1.0
. By default, JAX’s machinery for differentiation assumes all functions are defined over \(\mathbb{R}\) and thus doesn’t produce 1.0
here.
We can use a custom JVP rule! In particular, we can define the JVP rule in terms of the derivative function \(x \mapsto \frac{\sqrt{x} + 2}{2(\sqrt{x} + 1)^2}\) on \(\mathbb{R}_+\),
@custom_jvp
def f(x):
return x / (1 + jnp.sqrt(x))
@f.defjvp
def f_jvp(primals, tangents):
x, = primals
x_dot, = tangents
ans = f(x)
ans_dot = ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * x_dot
return ans, ans_dot
print(grad(f)(0.))
1.0
Here’s the convenience wrapper version:
@custom_jvp
def f(x):
return x / (1 + jnp.sqrt(x))
f.defjvps(lambda t, ans, x: ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * t)
print(grad(f)(0.))
1.0
Gradient clipping#
While in some cases we want to express a mathematical differentiation computation, in other cases we may even want to take a step away from mathematics to adjust the computation autodiff performs. One canonical example is reverse-mode gradient clipping.
For gradient clipping, we can use jnp.clip
together with a jax.custom_vjp
reverse-mode-only rule:
from functools import partial
from jax import custom_vjp
@custom_vjp
def clip_gradient(lo, hi, x):
return x # identity function
def clip_gradient_fwd(lo, hi, x):
return x, (lo, hi) # save bounds as residuals
def clip_gradient_bwd(res, g):
lo, hi = res
return (None, None, jnp.clip(g, lo, hi)) # use None to indicate zero cotangents for lo and hi
clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
import matplotlib.pyplot as plt
from jax import vmap
t = jnp.linspace(0, 10, 1000)
plt.plot(jnp.sin(t))
plt.plot(vmap(grad(jnp.sin))(t))
[<matplotlib.lines.Line2D at 0x7fc8cf18fc10>]

def clip_sin(x):
x = clip_gradient(-0.75, 0.75, x)
return jnp.sin(x)
plt.plot(clip_sin(t))
plt.plot(vmap(grad(clip_sin))(t))
[<matplotlib.lines.Line2D at 0x7fc8cd05bca0>]

Python debugging#
Another application that is motivated by development workflow rather than numerics is to set a pdb
debugger trace in the backward pass of reverse-mode autodiff.
When trying to track down the source of a nan
runtime error, or just examine carefully the cotangent (gradient) values being propagated, it can be useful to insert a debugger at a point in the backward pass that corresponds to a specific point in the primal computation. You can do that with jax.custom_vjp
.
We’ll defer an example until the next section.
Implicit function differentiation of iterative implementations#
This example gets pretty deep in the mathematical weeds!
Another application for jax.custom_vjp
is reverse-mode differentiation of functions that are JAX-transformable (by jit
, vmap
, …) but not efficiently JAX-differentiable for some reason, perhaps because they involve lax.while_loop
. (It’s not possible to produce an XLA HLO program that efficiently computes the reverse-mode derivative of an XLA HLO While loop because that would require a program with unbounded memory use, which isn’t possible to express in XLA HLO, at least without side-effecting interactions through infeed/outfeed.)
For example, consider this fixed_point
routine which computes a fixed point by iteratively applying a function in a while_loop
:
from jax.lax import while_loop
def fixed_point(f, a, x_guess):
def cond_fun(carry):
x_prev, x = carry
return jnp.abs(x_prev - x) > 1e-6
def body_fun(carry):
_, x = carry
return x, f(a, x)
_, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
return x_star
This is an iterative procedure for numerically solving the equation \(x = f(a, x)\) for \(x\), by iterating \(x_{t+1} = f(a, x_t)\) until \(x_{t+1}\) is sufficiently close to \(x_t\). The result \(x^*\) depends on the parameters \(a\), and so we can think of there being a function \(a \mapsto x^*(a)\) that is implicitly defined by equation \(x = f(a, x)\).
We can use fixed_point
to run iterative procedures to convergence, for example running Newton’s method to calculate square roots while only executing adds, multiplies, and divides:
def newton_sqrt(a):
update = lambda a, x: 0.5 * (x + a / x)
return fixed_point(update, a, a)
print(newton_sqrt(2.))
1.4142135
We can vmap
or jit
the function as well:
print(jit(vmap(newton_sqrt))(jnp.array([1., 2., 3., 4.])))
[1. 1.4142135 1.7320509 2. ]
We can’t apply reverse-mode automatic differentiation because of the while_loop
, but it turns out we wouldn’t want to anyway: instead of differentiating through the implementation of fixed_point
and all its iterations, we can exploit the mathematical structure to do something that is much more memory-efficient (and FLOP-efficient in this case, too!). We can instead use the implicit function theorem [Prop A.25 of Bertsekas’s Nonlinear Programming, 2nd ed.], which guarantees (under some conditions) the existence of the mathematical objects we’re about to use. In essence, we linearize at the solution and solve those linear equations iteratively to compute the derivatives we want.
Consider again the equation \(x = f(a, x)\) and the function \(x^*\). We want to evaluate vector-Jacobian products like \(v^\mathsf{T} \mapsto v^\mathsf{T} \partial x^*(a_0)\).
At least in an open neighborhood around the point \(a_0\) at which we want to differentiate, let’s assume that the equation \(x^*(a) = f(a, x^*(a))\) holds for all \(a\). Since the two sides are equal as functions of \(a\), their derivatives must be equal as well, so let’s differentiate both sides:
\(\qquad \partial x^*(a) = \partial_0 f(a, x^*(a)) + \partial_1 f(a, x^*(a)) \partial x^*(a)\).
Setting \(A = \partial_1 f(a_0, x^*(a_0))\) and \(B = \partial_0 f(a_0, x^*(a_0))\), we can write the quantity we’re after more simply as
\(\qquad \partial x^*(a_0) = B + A \partial x^*(a_0)\),
or, by rearranging,
\(\qquad \partial x^*(a_0) = (I - A)^{-1} B\).
That means we can evaluate vector-Jacobian products like
\(\qquad v^\mathsf{T} \partial x^*(a_0) = v^\mathsf{T} (I - A)^{-1} B = w^\mathsf{T} B\),
where \(w^\mathsf{T} = v^\mathsf{T} (I - A)^{-1}\), or equivalently \(w^\mathsf{T} = v^\mathsf{T} + w^\mathsf{T} A\), or equivalently \(w^\mathsf{T}\) is the fixed point of the map \(u^\mathsf{T} \mapsto v^\mathsf{T} + u^\mathsf{T} A\). That last characterization gives us a way to write the VJP for fixed_point
in terms of a call to fixed_point
! Moreover, after expanding \(A\) and \(B\) back out, we can see we need only to evaluate VJPs of \(f\) at \((a_0, x^*(a_0))\).
Here’s the upshot:
from jax import vjp
@partial(custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, a, x_guess):
def cond_fun(carry):
x_prev, x = carry
return jnp.abs(x_prev - x) > 1e-6
def body_fun(carry):
_, x = carry
return x, f(a, x)
_, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
return x_star
def fixed_point_fwd(f, a, x_init):
x_star = fixed_point(f, a, x_init)
return x_star, (a, x_star)
def fixed_point_rev(f, res, x_star_bar):
a, x_star = res
_, vjp_a = vjp(lambda a: f(a, x_star), a)
a_bar, = vjp_a(fixed_point(partial(rev_iter, f),
(a, x_star, x_star_bar),
x_star_bar))
return a_bar, jnp.zeros_like(x_star)
def rev_iter(f, packed, u):
a, x_star, x_star_bar = packed
_, vjp_x = vjp(lambda x: f(a, x), x_star)
return x_star_bar + vjp_x(u)[0]
fixed_point.defvjp(fixed_point_fwd, fixed_point_rev)
print(newton_sqrt(2.))
1.4142135
print(grad(newton_sqrt)(2.))
print(grad(grad(newton_sqrt))(2.))
0.35355338
-0.088388346
We can check our answers by differentiating jnp.sqrt
, which uses a totally different implementation:
print(grad(jnp.sqrt)(2.))
print(grad(grad(jnp.sqrt))(2.))
0.35355338
-0.08838835
A limitation to this approach is that the argument f
can’t close over any values involved in differentiation. That is, you might notice that we kept the parameter a
explicit in the argument list of fixed_point
. For this use case, consider using the low-level primitive lax.custom_root
, which allows for derivatives with respect to closed-over variables with custom root-finding functions.
Basic usage of jax.custom_jvp and jax.custom_vjp APIs#
Use jax.custom_jvp to define forward-mode (and, indirectly, reverse-mode) rules#
Here’s a canonical basic example of using jax.custom_jvp
, where the comments use
Haskell-like type signatures:
from jax import custom_jvp
import jax.numpy as jnp
# f :: a -> b
@custom_jvp
def f(x):
return jnp.sin(x)
# f_jvp :: (a, T a) -> (b, T b)
def f_jvp(primals, tangents):
x, = primals
t, = tangents
return f(x), jnp.cos(x) * t
f.defjvp(f_jvp)
<function __main__.f_jvp(primals, tangents)>
from jax import jvp
print(f(3.))
y, y_dot = jvp(f, (3.,), (1.,))
print(y)
print(y_dot)
0.14112
0.14112
-0.9899925
In words, we start with a primal function f
that takes inputs of type a
and produces outputs of type b
. We associate with it a JVP rule function f_jvp
that takes a pair of inputs representing the primal inputs of type a
and the corresponding tangent inputs of type T a
, and produces a pair of outputs representing the primal outputs of type b
and tangent outputs of type T b
. The tangent outputs should be a linear function of the tangent inputs.
You can also use f.defjvp
as a decorator, as in
@custom_jvp
def f(x):
...
@f.defjvp
def f_jvp(primals, tangents):
...
Even though we defined only a JVP rule and no VJP rule, we can use both forward- and reverse-mode differentiation on f
. JAX will automatically transpose the linear computation on tangent values from our custom JVP rule, computing the VJP as efficiently as if we had written the rule by hand:
from jax import grad
print(grad(f)(3.))
print(grad(grad(f))(3.))
-0.9899925
-0.14112
For automatic transposition to work, the JVP rule’s output tangents must be linear as a function of the input tangents. Otherwise a transposition error is raised.
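As a quick illustration of that requirement (an added example, not from the original notebook): if a rule's output tangent is not linear in the input tangent, forward-mode jvp still runs, but reverse-mode differentiation, which must transpose the tangent computation, fails:
import jax.numpy as jnp
from jax import custom_jvp, jvp, grad
@custom_jvp
def f_nonlinear(x):
    return jnp.sin(x)
@f_nonlinear.defjvp
def f_nonlinear_jvp(primals, tangents):
    x, = primals
    t, = tangents
    return f_nonlinear(x), t * t  # not linear in the tangent t!
print(jvp(f_nonlinear, (1.,), (1.,)))  # forward mode still works
try:
    grad(f_nonlinear)(1.)  # reverse mode must transpose the rule, so it errors
except Exception as e:
    print('ERROR! {}'.format(e))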
Multiple arguments work like this:
@custom_jvp
def f(x, y):
return x ** 2 * y
@f.defjvp
def f_jvp(primals, tangents):
x, y = primals
x_dot, y_dot = tangents
primal_out = f(x, y)
tangent_out = 2 * x * y * x_dot + x ** 2 * y_dot
return primal_out, tangent_out
print(grad(f)(2., 3.))
12.0
The defjvps
convenience wrapper lets us define a JVP for each argument separately, and the results are computed separately then summed:
@custom_jvp
def f(x):
return jnp.sin(x)
f.defjvps(lambda t, ans, x: jnp.cos(x) * t)
print(grad(f)(3.))
-0.9899925
Here’s a defjvps
example with multiple arguments:
@custom_jvp
def f(x, y):
return x ** 2 * y
f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot,
lambda y_dot, primal_out, x, y: x ** 2 * y_dot)
print(grad(f)(2., 3.))
print(grad(f, 0)(2., 3.)) # same as above
print(grad(f, 1)(2., 3.))
12.0
12.0
4.0
As a shorthand, with defjvps
you can pass a None
value to indicate that the JVP for a particular argument is zero:
@custom_jvp
def f(x, y):
return x ** 2 * y
f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot,
None)
print(grad(f)(2., 3.))
print(grad(f, 0)(2., 3.)) # same as above
print(grad(f, 1)(2., 3.))
12.0
12.0
0.0
Calling a jax.custom_jvp
function with keyword arguments, or writing a jax.custom_jvp
function definition with default arguments, are both allowed so long as they can be unambiguously mapped to positional arguments based on the function signature retrieved by the standard library inspect.signature
mechanism.
When you’re not performing differentiation, the function f
is called just as if it weren’t decorated by jax.custom_jvp
:
@custom_jvp
def f(x):
print('called f!') # a harmless side-effect
return jnp.sin(x)
@f.defjvp
def f_jvp(primals, tangents):
print('called f_jvp!') # a harmless side-effect
x, = primals
t, = tangents
return f(x), jnp.cos(x) * t
from jax import vmap, jit
print(f(3.))
called f!
0.14112
print(vmap(f)(jnp.arange(3.)))
print(jit(f)(3.))
called f!
[0. 0.84147096 0.9092974 ]
called f!
0.14112
The custom JVP rule is invoked during differentiation, whether forward or reverse:
y, y_dot = jvp(f, (3.,), (1.,))
print(y_dot)
called f_jvp!
called f!
-0.9899925
print(grad(f)(3.))
called f_jvp!
called f!
-0.9899925
Notice that f_jvp
calls f
to compute the primal outputs. In the context of higher-order differentiation, each application of a differentiation transform will use the custom JVP rule if and only if the rule calls the original f
to compute the primal outputs. (This represents a kind of fundamental tradeoff, where we can’t make use of intermediate values from the evaluation of f
in our rule and also have the rule apply in all orders of higher-order differentiation.)
grad(grad(f))(3.)
called f_jvp!
called f_jvp!
called f!
Array(-0.14112, dtype=float32, weak_type=True)
You can use Python control flow with jax.custom_jvp
:
@custom_jvp
def f(x):
if x > 0:
return jnp.sin(x)
else:
return jnp.cos(x)
@f.defjvp
def f_jvp(primals, tangents):
x, = primals
x_dot, = tangents
ans = f(x)
if x > 0:
return ans, 2 * x_dot
else:
return ans, 3 * x_dot
print(grad(f)(1.))
print(grad(f)(-1.))
2.0
3.0
Use jax.custom_vjp to define custom reverse-mode-only rules#
While jax.custom_jvp
suffices for controlling both forward- and, via JAX’s automatic transposition, reverse-mode differentiation behavior, in some cases we may want to directly control a VJP rule, for example in the latter two example problems presented above. We can do that with jax.custom_vjp
:
from jax import custom_vjp
import jax.numpy as jnp
# f :: a -> b
@custom_vjp
def f(x):
return jnp.sin(x)
# f_fwd :: a -> (b, c)
def f_fwd(x):
return f(x), jnp.cos(x)
# f_bwd :: (c, CT b) -> CT a
def f_bwd(cos_x, y_bar):
return (cos_x * y_bar,)
f.defvjp(f_fwd, f_bwd)
from jax import grad
print(f(3.))
print(grad(f)(3.))
0.14112
-0.9899925
In words, we again start with a primal function f
that takes inputs of type a
and produces outputs of type b
. We associate with it two functions, f_fwd
and f_bwd
, which describe how to perform the forward- and backward-passes of reverse-mode autodiff, respectively.
The function f_fwd
describes the forward pass, not only the primal computation but also what values to save for use on the backward pass. Its input signature is just like that of the primal function f
, in that it takes a primal input of type a
. But as output it produces a pair, where the first element is the primal output b
and the second element is any “residual” data of type c
to be stored for use by the backward pass. (This second output is analogous to PyTorch’s save_for_backward mechanism.)
The function f_bwd
describes the backward pass. It takes two inputs, where the first is the residual data of type c
produced by f_fwd
and the second is the output cotangents of type CT b
corresponding to the output of the primal function. It produces an output of type CT a
representing the cotangents corresponding to the input of the primal function. In particular, the output of f_bwd
must be a sequence (e.g. a tuple) of length equal to the number of arguments to the primal function.
So multiple arguments work like this:
from jax import custom_vjp
@custom_vjp
def f(x, y):
return jnp.sin(x) * y
def f_fwd(x, y):
return f(x, y), (jnp.cos(x), jnp.sin(x), y)
def f_bwd(res, g):
cos_x, sin_x, y = res
return (cos_x * g * y, sin_x * g)
f.defvjp(f_fwd, f_bwd)
print(grad(f)(2., 3.))
-1.2484405
Calling a jax.custom_vjp
function with keyword arguments, or writing a jax.custom_vjp
function definition with default arguments, are both allowed so long as they can be unambiguously mapped to positional arguments based on the function signature retrieved by the standard library inspect.signature
mechanism.
As with jax.custom_jvp
, the custom VJP rule comprised by f_fwd
and f_bwd
is not invoked if differentiation is not applied. If the function is evaluated, or transformed with
, vmap
, or other non-differentiation transformations, then only f
is called.
@custom_vjp
def f(x):
print("called f!")
return jnp.sin(x)
def f_fwd(x):
print("called f_fwd!")
return f(x), jnp.cos(x)
def f_bwd(cos_x, y_bar):
print("called f_bwd!")
return (cos_x * y_bar,)
f.defvjp(f_fwd, f_bwd)
print(f(3.))
called f!
0.14112
print(grad(f)(3.))
called f_fwd!
called f!
called f_bwd!
-0.9899925
from jax import vjp
y, f_vjp = vjp(f, 3.)
print(y)
called f_fwd!
called f!
0.14112
print(f_vjp(1.))
called f_bwd!
(Array(-0.9899925, dtype=float32, weak_type=True),)
Forward-mode autodiff cannot be used on the jax.custom_vjp
function and will raise an error:
from jax import jvp
try:
jvp(f, (3.,), (1.,))
except TypeError as e:
print('ERROR! {}'.format(e))
called f_fwd!
called f!
ERROR! can't apply forward-mode autodiff (jvp) to a custom_vjp function.
If you want to use both forward- and reverse-mode, use jax.custom_jvp
instead.
We can use jax.custom_vjp
together with pdb
to insert a debugger trace in the backward pass:
import pdb
@custom_vjp
def debug(x):
return x # acts like identity
def debug_fwd(x):
return x, x
def debug_bwd(x, g):
import pdb; pdb.set_trace()
return g
debug.defvjp(debug_fwd, debug_bwd)
def foo(x):
y = x ** 2
y = debug(y) # insert pdb in corresponding backward pass step
return jnp.sin(y)
jax.grad(foo)(3.)
> <ipython-input-113-b19a2dc1abf7>(12)debug_bwd()
-> return g
(Pdb) p x
Array(9., dtype=float32)
(Pdb) p g
Array(-0.91113025, dtype=float32)
(Pdb) q
More features and details#
Working with list / tuple / dict containers (and other pytrees)#
You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any pytrees are permissible, so long as their structures are consistent according to the type constraints.
Here’s a contrived example with jax.custom_jvp
:
from collections import namedtuple
Point = namedtuple("Point", ["x", "y"])
@custom_jvp
def f(pt):
x, y = pt.x, pt.y
return {'a': x ** 2,
'b': (jnp.sin(x), jnp.cos(y))}
@f.defjvp
def f_jvp(primals, tangents):
pt, = primals
pt_dot, = tangents
ans = f(pt)
ans_dot = {'a': 2 * pt.x * pt_dot.x,
'b': (jnp.cos(pt.x) * pt_dot.x, -jnp.sin(pt.y) * pt_dot.y)}
return ans, ans_dot
def fun(pt):
dct = f(pt)
return dct['a'] + dct['b'][0]
pt = Point(1., 2.)
print(f(pt))
{'a': 1.0, 'b': (Array(0.84147096, dtype=float32, weak_type=True), Array(-0.41614684, dtype=float32, weak_type=True))}
print(grad(fun)(pt))
Point(x=Array(2.5403023, dtype=float32, weak_type=True), y=Array(0., dtype=float32, weak_type=True))
And an analogous contrived example with jax.custom_vjp
:
@custom_vjp
def f(pt):
x, y = pt.x, pt.y
return {'a': x ** 2,
'b': (jnp.sin(x), jnp.cos(y))}
def f_fwd(pt):
return f(pt), pt
def f_bwd(pt, g):
a_bar, (b0_bar, b1_bar) = g['a'], g['b']
x_bar = 2 * pt.x * a_bar + jnp.cos(pt.x) * b0_bar
y_bar = -jnp.sin(pt.y) * b1_bar
return (Point(x_bar, y_bar),)
f.defvjp(f_fwd, f_bwd)
def fun(pt):
dct = f(pt)
return dct['a'] + dct['b'][0]
pt = Point(1., 2.)
print(f(pt))
{'a': 1.0, 'b': (Array(0.84147096, dtype=float32, weak_type=True), Array(-0.41614684, dtype=float32, weak_type=True))}
print(grad(fun)(pt))
Point(x=Array(2.5403023, dtype=float32, weak_type=True), y=Array(-0., dtype=float32, weak_type=True))
Handling non-differentiable arguments#
Some use cases, like the final example problem, call for non-differentiable arguments like function-valued arguments to be passed to functions with custom differentiation rules, and for those arguments to also be passed to the rules themselves. In the case of fixed_point
, the function argument f
was such a non-differentiable argument. A similar situation arises with jax.experimental.odeint
.
jax.custom_jvp with nondiff_argnums#
Use the optional nondiff_argnums
parameter to jax.custom_jvp
to indicate arguments like these. Here’s an example with jax.custom_jvp
:
from functools import partial
@partial(custom_jvp, nondiff_argnums=(0,))
def app(f, x):
return f(x)
@app.defjvp
def app_jvp(f, primals, tangents):
x, = primals
x_dot, = tangents
return f(x), 2. * x_dot
print(app(lambda x: x ** 3, 3.))
27.0
print(grad(app, 1)(lambda x: x ** 3, 3.))
2.0
Notice the gotcha here: no matter where in the argument list these parameters appear, they’re placed at the start of the signature of the corresponding JVP rule. Here’s another example:
@partial(custom_jvp, nondiff_argnums=(0, 2))
def app2(f, x, g):
return f(g((x)))
@app2.defjvp
def app2_jvp(f, g, primals, tangents):
x, = primals
x_dot, = tangents
return f(g(x)), 3. * x_dot
print(app2(lambda x: x ** 3, 3., lambda y: 5 * y))
3375.0
print(grad(app2, 1)(lambda x: x ** 3, 3., lambda y: 5 * y))
3.0
jax.custom_vjp with nondiff_argnums#
A similar option exists for jax.custom_vjp
, and, similarly, the convention is that the non-differentiable arguments are passed as the first arguments to the _bwd
rule, no matter where they appear in the signature of the original function. The signature of the _fwd
rule remains unchanged - it is the same as the signature of the primal function. Here’s an example:
@partial(custom_vjp, nondiff_argnums=(0,))
def app(f, x):
return f(x)
def app_fwd(f, x):
return f(x), x
def app_bwd(f, x, g):
return (5 * g,)
app.defvjp(app_fwd, app_bwd)
print(app(lambda x: x ** 2, 4.))
16.0
print(grad(app, 1)(lambda x: x ** 2, 4.))
5.0
See fixed_point
above for another usage example.
You don’t need to use nondiff_argnums
with array-valued arguments, for example ones with integer dtype. Instead, nondiff_argnums
should only be used for argument values that don’t correspond to JAX types (essentially don’t correspond to array types), like Python callables or strings. If JAX detects that an argument indicated by nondiff_argnums
contains a JAX Tracer, then an error is raised. The clip_gradient
function above is a good example of not using nondiff_argnums
for integer-dtype array arguments.
Control autodiff’s saved values with jax.checkpoint (aka jax.remat)#
import jax
import jax.numpy as jnp
TL;DR#
Use the jax.checkpoint
decorator (aliased as jax.remat
) with jax.grad
to control which intermediates are saved on the forward pass versus recomputed on the backward pass, trading off memory and FLOPs.
Don’t miss the practical notes for a discussion about how jax.checkpoint
interacts with jax.jit
.
Without using jax.checkpoint
, the forward pass of jax.grad(f)(x)
saves, for use on the backward pass, the values of Jacobian coefficients and other intermediates. We call these saved values residuals:
def g(W, x):
y = jnp.dot(W, x)
return jnp.sin(y)
def f(W1, W2, W3, x):
x = g(W1, x)
x = g(W2, x)
x = g(W3, x)
return x
W1 = jnp.ones((5, 4))
W2 = jnp.ones((6, 5))
W3 = jnp.ones((7, 6))
x = jnp.ones(4)
# Inspect the 'residual' values to be saved on the forward pass
# if we were to evaluate `jax.grad(f)(W1, W2, W3, x)`
from jax.ad_checkpoint import print_saved_residuals
jax.ad_checkpoint.print_saved_residuals(f, W1, W2, W3, x)
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
f32[5] output of cos from <ipython-input-4-f510dde58e22>:3 (g)
f32[6] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
f32[6] output of cos from <ipython-input-4-f510dde58e22>:3 (g)
f32[7] output of cos from <ipython-input-4-f510dde58e22>:3 (g)
By applying jax.checkpoint
to sub-functions, as a decorator or at specific application sites, we force JAX not to save any of that sub-function’s residuals. Instead, only the inputs of a jax.checkpoint
-decorated function might be saved, and any residuals consumed on the backward pass are re-computed from those inputs as needed:
def f2(W1, W2, W3, x):
x = jax.checkpoint(g)(W1, x)
x = jax.checkpoint(g)(W2, x)
x = jax.checkpoint(g)(W3, x)
return x
jax.ad_checkpoint.print_saved_residuals(f2, W1, W2, W3, x)
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
f32[6] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
Here the values of two sin
applications are saved because they are arguments
in subsequent applications of the jax.checkpoint
-decorated g
function, and
inputs to a jax.checkpoint
-decorated function may be saved. But no values of
cos
applications are saved.
To control which values are saveable without having to edit the definition of the function to be differentiated, you can use a rematerialization policy. Here is an example that saves only the results of dot
operations with no batch dimensions (since they are often FLOP-bound, and hence worth saving rather than recomputing):
f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
jax.ad_checkpoint.print_saved_residuals(f3, W1, W2, W3, x)
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)
f32[6] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)
f32[7] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)
You can also use policies to refer to intermediate values you name using jax.ad_checkpoint.checkpoint_name
:
from jax.ad_checkpoint import checkpoint_name
def f4(W1, W2, W3, x):
x = checkpoint_name(g(W1, x), name='a')
x = checkpoint_name(g(W2, x), name='b')
x = checkpoint_name(g(W3, x), name='c')
return x
f4 = jax.checkpoint(f4, policy=jax.checkpoint_policies.save_only_these_names('a'))
jax.ad_checkpoint.print_saved_residuals(f4, W1, W2, W3, x)
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] named 'a' from <ipython-input-7-fc0ed1c14b8d>:4 (f4)
When playing around with these toy examples, we can get a closer look at what’s going on using the print_fwd_bwd
utility defined in this notebook:
from jax.tree_util import tree_flatten, tree_unflatten
from rich.console import Console
from rich.table import Table
import rich.text
def print_fwd_bwd(f, *args, **kwargs) -> None:
args, in_tree = tree_flatten((args, kwargs))
def f_(*args):
args, kwargs = tree_unflatten(in_tree, args)
return f(*args, **kwargs)
fwd = jax.make_jaxpr(lambda *args: jax.vjp(f_, *args))(*args).jaxpr
y, f_vjp = jax.vjp(f_, *args)
res, in_tree = tree_flatten(f_vjp)
def g_(*args):
*res, y = args
f_vjp = tree_unflatten(in_tree, res)
return f_vjp(y)
bwd = jax.make_jaxpr(g_)(*res, y).jaxpr
table = Table(show_header=False, show_lines=True, padding=(1, 2, 0, 2), box=None)
table.add_row("[bold green]forward computation:",
"[bold green]backward computation:")
table.add_row(rich.text.Text.from_ansi(str(fwd)),
rich.text.Text.from_ansi(str(bwd)))
console = Console(width=240, force_jupyter=True)
console.print(table)
def _renderable_repr(self):
return self.html
rich.jupyter.JupyterRenderable._repr_html_ = _renderable_repr
# no use of jax.checkpoint:
print_fwd_bwd(f, W1, W2, W3, x)
[two-column output: the forward-pass jaxpr alongside the backward-pass jaxpr. Without jax.checkpoint, the forward pass saves the sin and cos intermediates of every layer as residuals for the backward pass.]
# using jax.checkpoint with policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable:
print_fwd_bwd(f3, W1, W2, W3, x)
[two-column output: the forward-pass jaxpr alongside the backward-pass jaxpr. With the dots_with_no_batch_dims_saveable policy, only the dot_general outputs are saved; the backward pass contains a remat2 call (with prevent_cse=True) that recomputes the sin and cos values it needs.]
Let’s think step by step#
You might want to first (re)read the Autodiff Cookbook Part 1.
Fundamentals of jax.checkpoint#
In both jax.linearize
and jax.vjp
there is flexibility in how and when some values are computed. Different choices can trade off memory use against FLOPs. JAX provides control over these choices with jax.checkpoint
.
One such choice is whether to perform Jacobian coefficient computations on the forward pass, as soon as the inputs are available, or on the backward pass, just before the coefficients are needed. Consider the example of sin_vjp
:
def sin_vjp(x):
y = jnp.sin(x)
cos_x = jnp.cos(x)
return y, lambda y_bar: cos_x * y_bar
Another valid implementation would compute the value of jnp.cos(x)
on the backward pass rather than on the forward pass:
def sin_vjp2(x):
y = jnp.sin(x)
return y, lambda y_bar: jnp.cos(x) * y_bar
For this particular function, the amount of memory used by the two versions is the same, though we’ve reduced the FLOPs for the primal computation (i.e. the forward pass) and increased the FLOPs for the cotangent computation (i.e. the backward pass).
There’s another choice when it comes to function composition. Recall our VJP rule for a composition of two functions:
def f(x):
y = g(x)
z = h(y)
return z
def f_vjp(x):
y, g_vjp = jax.vjp(g, x)
z, h_vjp = jax.vjp(h, y)
def f_bwd(z_bar):
y_bar, = h_vjp(z_bar)
x_bar, = g_vjp(y_bar)
return x_bar
return z, f_bwd
An alternative is:
def f_vjp_checkpoint(x):
y = g(x)
z, h_vjp = jax.vjp(h, y)
def f_bwd2(z_bar):
y_bar, = h_vjp(z_bar)
_, g_vjp = jax.vjp(g, x)
x_bar, = g_vjp(y_bar)
return x_bar
return z, f_bwd2
In words, this alternative implementation doesn’t compute g_vjp
, or the residual values in its closure, on the forward pass. Instead it only computes them in the backward pass f_bwd2
. That means f_vjp_checkpoint
requires less memory: if g
and h
each required similar amounts of memory for their residuals, each much larger than x
, then the function produced by f_vjp_checkpoint(x)
requires half the memory as that of f_vjp(x)
!
The cost we pay is redundant work: in f_bwd2
we must re-evaluate g(x)
as part of jax.vjp(g, x)
just to discard its value (in the underscore variable on the line _, g_vjp = jax.vjp(g, x)
).
We can get this VJP behavior in autodiff, without having to write VJP functions directly, by instead using jax.checkpoint
in an alternative definition of the original function f
:
def f_checkpoint(x):
y = jax.checkpoint(g)(x)
z = h(y)
return z
In other words, we apply jax.checkpoint
to g
, the first stage of f
, rather than to f
itself. This way, when we evaluate jax.grad(f_checkpoint)(x)
, we’d get a computation like:
run the forward pass of g, discarding residual values;
run the forward pass of h, saving residuals;
run the backward pass of h, consuming residuals from step 2;
re-run the forward pass of g, saving residuals;
run the backward pass of g, consuming residuals from step 4.
That is, by evaluating jax.grad(f_checkpoint)(x)
we’d get the same computation as:
def f_checkpoint_grad(x):
y = g(x) # step 1
_, h_vjp = jax.vjp(h, y) # step 2
y_bar, = h_vjp(1.0) # step 3
_, g_vjp = jax.vjp(g, x) # step 4
x_bar, = g_vjp(y_bar) # step 5
return x_bar
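As a quick check, here is a minimal sketch that takes g = jnp.sin and h = jnp.cos as hypothetical stand-ins and reuses f_checkpoint and f_checkpoint_grad from above; the hand-written schedule matches jax.grad:
import jax
import jax.numpy as jnp

g, h = jnp.sin, jnp.cos   # hypothetical stand-ins for the two stages

x = 3.0
assert jnp.allclose(f_checkpoint_grad(x), jax.grad(f_checkpoint)(x))
assert jnp.allclose(f_checkpoint_grad(x), jax.grad(lambda x: h(g(x)))(x))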
In general, jax.checkpoint(foo) is a new function which has the same input-output behavior as foo, but behaves differently under autodiff, particularly under jax.linearize and jax.vjp (and their wrappers, like jax.grad) but not jax.jvp. When differentiated, only the input to a jax.checkpoint-decorated function is stored on the forward pass; on the backward pass, residuals (i.e. intermediates from foo and its Jacobian coefficient values needed for the backward pass) are recomputed.
Notice that if f = lambda x: h(g(x)) is the function we want to differentiate, i.e. if we want to apply jax.grad(f), we don't get any memory savings by applying jax.checkpoint to f itself. That's because evaluating jax.grad(jax.checkpoint(f))(x) would lead to a computation like:
1. run the forward pass, discarding all residuals;
2. immediately re-run the forward pass, saving residuals;
3. run the backward pass, consuming residuals from step 2.
That is, in code we’d have something like:
def f_grad_bad(x):
_ = f(x) # step 1
_, f_vjp = jax.vjp(f, x) # step 2
x_bar, = f_vjp(1.0) # step 3
return x_bar
We also wouldn't get any memory savings by applying jax.checkpoint to h, the second stage of f. That's because evaluating jax.grad(lambda x: jax.checkpoint(h)(g(x))) would lead to a computation like:
1. run the forward pass of g, saving residuals;
2. run the forward pass of h, discarding residuals;
3. immediately re-run the forward pass of h, saving residuals;
4. run the backward pass of h, consuming residuals from step 3;
5. run the backward pass of g, consuming residuals from step 1.
That is, in code we’d have something like:
def f_grad_bad2(x):
y, g_vjp = jax.vjp(g, x) # step 1
z = h(y) # step 2
_, h_vjp = jax.vjp(h, y) # step 3
y_bar, = h_vjp(1.0) # step 4
x_bar, = g_vjp(y_bar) # step 5
return x_bar
Slightly more generally, if we had a chain composition of functions, like f = lambda x: f3(f2(f1(x))), and we were interested in evaluating jax.grad(f), we could say that:
we shouldn't apply jax.checkpoint to the whole function f, since that wouldn't save any memory (and will perform wasteful recomputation);
we shouldn't apply jax.checkpoint to the last sub-function f3, since that wouldn't save any memory (and will perform wasteful recomputation);
we could apply jax.checkpoint to f1, f2, or their composition lambda x: f2(f1(x)), since any of those might save memory and would express different memory/recompute tradeoffs (see the sketch below).
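For instance, here is a minimal sketch of the last two options, taking f1, f2, f3 = jnp.sin, jnp.cos, jnp.tanh as hypothetical stand-ins; all variants compute the same gradient and differ only in what they save versus recompute:
import jax
import jax.numpy as jnp

f1, f2, f3 = jnp.sin, jnp.cos, jnp.tanh   # hypothetical stand-ins

# Option A: rematerialize f1's residuals on the backward pass.
f_a = lambda x: f3(f2(jax.checkpoint(f1)(x)))

# Option B: rematerialize the residuals of the composition lambda x: f2(f1(x)).
f_b = lambda x: f3(jax.checkpoint(lambda x: f2(f1(x)))(x))

x = 1.5
ref = jax.grad(lambda x: f3(f2(f1(x))))(x)
assert jnp.allclose(jax.grad(f_a)(x), ref)
assert jnp.allclose(jax.grad(f_b)(x), ref)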
Custom policies for what’s saveable#
As shown so far, using jax.checkpoint switches from one extreme to another:
without jax.checkpoint, JAX's autodiff tends to compute everything possible on the forward pass and store it for the backward pass;
with a jax.checkpoint decorator, we instead compute as little as possible on the forward pass and recompute values as needed on the backward pass (a small example follows below).
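For a concrete feel, here is a minimal sketch contrasting the two extremes on a small toy function (it reuses the print_saved_residuals helper used throughout this document):
import jax
import jax.numpy as jnp

f = lambda W, x: jnp.sin(jnp.dot(W, x))
W = jnp.ones((4, 4))
x = jnp.ones(4)

# Default: intermediates needed for the backward pass (e.g. the cos value) are saved.
print_saved_residuals(f, W, x)

# With jax.checkpoint and no policy: only the arguments are saved; the rest is recomputed.
print_saved_residuals(jax.checkpoint(f), W, x)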
To operate between these two extremes, saving some things and not others, we can carefully place jax.checkpoint decorators on sub-functions. But that requires editing the function to be differentiated, e.g. model code, which may be inconvenient. It can also be hard to experiment with variations.
So an alternative is to use the policy argument to jax.checkpoint. A policy is a callable (i.e. a function) which takes as input a type-level specification of a first-order primitive application and returns a boolean indicating whether the corresponding output value(s) are allowed to be saved as residuals (or instead must be recomputed in the (co)tangent computation as needed). To write robust code, a policy should be selected from the attributes on jax.checkpoint_policies, like jax.checkpoint_policies.dots_with_no_batch_dims_saveable, since the API for writing custom policy callables is considered internal.
For example, consider this function to be differentiated:
def loss(params, x, y):
return jnp.sum((predict(params, x) - y)**2)
def predict(params, x):
*Ws, Wlast = params
for W in Ws:
x = layer(W, x)
x = jnp.dot(Wlast, x)
return x
def layer(W, x):
return jnp.sin(jnp.dot(W, x))
W1 = W2 = W3 = jnp.ones((4, 4))
params = [W1, W2, W3]
x = jnp.ones(4)
y = jnp.ones(4)
print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] output of sin from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of sin from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of mul from <ipython-input-18-3808b5023c3d>:2 (loss)
Instead of saving so many values on the forward pass, perhaps we only want to save the results of matrix multiplications with no batch dimension (since they may be FLOP- rather than memory-bound). We can do that using the policy jax.checkpoint_policies.dots_with_no_batch_dims_saveable:
loss_checkpoint = jax.checkpoint(loss, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
print_saved_residuals(loss_checkpoint, params, x, y)
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] from the argument 'y'
f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:8 (predict)
Notice also that by providing a policy, we didn't need to edit the code defining loss, predict, or layer. That is particularly convenient if we want to experiment with policies in calling code (e.g. a training script) without changing library code (e.g. the neural network library).
Some policies can refer to values named with jax.ad_checkpoint.checkpoint_name:
from jax.ad_checkpoint import checkpoint_name
def predict(params, x):
*Ws, Wlast = params
for i, W in enumerate(Ws):
x = layer(W, x)
x = checkpoint_name(x, name=f'layer{i}_output')
x = jnp.dot(Wlast, x)
return x
By itself, checkpoint_name is just an identity function. But because some policy functions know to look for them, we can use the names to control whether certain values output by checkpoint_name are considered saveable:
print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] named 'layer0_output' from <ipython-input-22-e48aedf368ad>:7 (predict)
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] named 'layer1_output' from <ipython-input-22-e48aedf368ad>:7 (predict)
f32[4] output of mul from <ipython-input-18-3808b5023c3d>:2 (loss)
loss_checkpoint2 = jax.checkpoint(loss, policy=jax.checkpoint_policies.save_any_names_but_these('layer1_output'))
print_saved_residuals(loss_checkpoint2, params, x, y)
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] from the argument 'y'
Another policy which refers to names is jax.checkpoint_policies.save_only_these_names.
Some of the policies are (a usage sketch follows below):
everything_saveable (the default strategy, as if jax.checkpoint were not being used at all)
nothing_saveable (i.e. rematerialize everything, as if a custom policy were not being used at all)
dots_saveable or its alias checkpoint_dots
dots_with_no_batch_dims_saveable or its alias checkpoint_dots_with_no_batch_dims
save_anything_but_these_names (save any values except for the output of checkpoint_name with any of the names given)
save_any_names_but_these (save only named values, i.e. any outputs of checkpoint_name, except for those with the names given)
save_only_these_names (save only named values, and only among the names given)
Policies only indicate what is saveable; a value is only saved if it’s actually needed by the backward pass.
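For instance, a minimal usage sketch of save_only_these_names (reusing loss, params, x, y, and the print_saved_residuals helper from above):
loss_checkpoint3 = jax.checkpoint(loss, policy=jax.checkpoint_policies.save_only_these_names('layer1_output'))
print_saved_residuals(loss_checkpoint3, params, x, y)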
Advanced: recursive jax.checkpoint#
By applying jax.checkpoint in the right way, there are many tradeoffs between memory usage and (re)computation that can be expressed. One surprising example is recursive checkpointing, where we apply jax.checkpoint to a function which itself calls jax.checkpoint-decorated functions, so that memory usage from the chain composition of \(D\) functions scales like \(\mathcal{O}(\log_2 D)\) rather than \(\mathcal{O}(D)\).
As a toy example, consider the chain composition of multiple jnp.sin functions:
def chain_compose(funs):
def f(x):
for fun in funs:
x = fun(x)
return x
return f
f = chain_compose([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
In general, the number of stored residuals scales linearly with the length of the chain:
f = chain_compose([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
But we can apply jax.checkpoint recursively to improve the scaling:
def recursive_checkpoint(funs):
if len(funs) == 1:
return funs[0]
elif len(funs) == 2:
f1, f2 = funs
return lambda x: f1(f2(x))
else:
f1 = recursive_checkpoint(funs[:len(funs)//2])
f2 = recursive_checkpoint(funs[len(funs)//2:])
return lambda x: f1(jax.checkpoint(f2)(x))
f = recursive_checkpoint([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] from the argument 'x'
f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f = recursive_checkpoint([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] from the argument 'x'
f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
The cost here, as usual, is recomputation: in particular, we end up performing \(\mathcal{O}(\log_2 D)\) times as many FLOPs:
f = chain_compose([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
[Output: the forward and backward jaxprs printed side by side; the columns are interleaved in this text rendering. Without jax.checkpoint, the forward jaxpr computes the chain of sin values and saves all eight cos residuals, and the backward jaxpr is simply a chain of mul operations consuming them.]
f = recursive_checkpoint([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
[Output: the forward and backward jaxprs printed side by side; the columns are interleaved in this text rendering. With recursive_checkpoint, both jaxprs contain nested remat2 calls: the forward pass saves only a few values, and the backward pass re-runs the checkpointed sub-chains (recomputing the sin/cos values) before applying the mul chain.]
Practical notes#
When differentiated functions are staged out to XLA for compilation, for example by applying jax.jit to a function which contains a jax.grad call, XLA will automatically optimize the computation, including decisions about when to compute or rematerialize values. As a result, jax.checkpoint often isn't needed for differentiated functions under a jax.jit. XLA will optimize things for you.
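For instance (a minimal sketch reusing the loss function, params, x, and y from earlier), simply jitting the gradient computation leaves the rematerialization decisions to XLA:
grad_fn = jax.jit(jax.grad(loss))   # no explicit jax.checkpoint needed here
grads = grad_fn(params, x, y)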
One exception is when using staged-out control flow, like jax.lax.scan. Automatic compiler optimizations across multiple control flow primitives, e.g. across a forward-pass scan and the corresponding backward-pass scan, typically aren't as thorough. As a result, it's often a good idea to use jax.checkpoint on the body function passed to jax.lax.scan.
For example, one common pattern in large Transformer models is to express the architecture as a jax.lax.scan over layers so as to reduce compilation times. That is, using a simple fully-connected network as an analogy, instead of writing something like this:
LayerParam = tuple[jnp.ndarray, jnp.ndarray] # weights, bias pair for a layer
ParamsList = list[LayerParam]
def net(params: ParamsList, x: jnp.ndarray):
for W, b in params:
x = jnp.maximum(jnp.dot(x, W) + b, 0.)
return x
We would instead iterate over the layer application with jax.lax.scan:
StackedWeights = jnp.ndarray # all weight matrices stacked together
StackedBiases = jnp.ndarray # all bias vectors stacked together
all_weights = jnp.stack([W for W, _ in params])
all_biases = jnp.stack([b for _, b in params])
def layer(x, W_b_pair):
W, b = W_b_pair
out = jnp.maximum(jnp.dot(x, W) + b, 0.)
return out, None
def net(all_weights, all_biases, x):
x, _ = jax.lax.scan(layer, x, (all_weights, all_biases))
return x
This scan-over-layers version reduces compile times, but by foiling some compiler optimizations it can lead to inefficient computation of gradients. To mitigate the issue, we would use jax.checkpoint on the scanned function:
from functools import partial
@partial(jax.checkpoint,
policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def layer(x, W_b_pair):
W, b = W_b_pair
out = jnp.maximum(jnp.dot(x, W) + b, 0.)
return out, None
By using jax.checkpoint this way, we're manually controlling which values JAX's autodiff saves between the forward and backward passes, and hence not relying on XLA optimizations to choose for us.
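As a usage sketch (with hypothetical shapes, and assuming the scanned net and the checkpointed layer defined above are in scope), gradients flow through the scan as usual:
import jax
import jax.numpy as jnp

num_layers, width = 4, 8                      # hypothetical sizes
key = jax.random.PRNGKey(0)
all_weights = jax.random.normal(key, (num_layers, width, width))
all_biases = jnp.zeros((num_layers, width))
x = jnp.ones(width)

scan_loss = lambda w, b, x: jnp.sum(net(w, b, x) ** 2)
dW, db = jax.grad(scan_loss, argnums=(0, 1))(all_weights, all_biases, x)
print(dW.shape, db.shape)                     # (4, 8, 8) (4, 8)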
How JAX primitives work#
necula@google.com, October 2019.
JAX implements certain transformations of Python functions, e.g., jit, grad, vmap, or pmap. The Python functions to be transformed must be JAX-traceable, which means that as the Python function executes the only operations it applies to the data are either inspections of data attributes such as shape or type, or special operations called JAX primitives.
In particular, a JAX-traceable function is sometimes invoked by JAX with abstract arguments. An example of a JAX abstract value is ShapedArray(float32[2,2]), which captures the type and the shape of values, but not the concrete data values. JAX primitives know how to operate on both concrete data values and on the JAX abstract values.
The JAX-transformed functions must themselves be JAX-traceable functions, to ensure that these transformations can be composed, e.g., jit(jacfwd(grad(f))).
There are pre-defined JAX primitives corresponding to most XLA operations, e.g., add, matmul, sin, cos, indexing. JAX comes with an implementation of numpy functions in terms of JAX primitives, which means that Python programs using JAX’s implementation of numpy are JAX-traceable and therefore transformable. Other libraries can be made JAX-traceable by implementing them in terms of JAX primitives.
The set of JAX primitives is extensible. Instead of reimplementing a function in terms of pre-defined JAX primitives, one can define a new primitive that encapsulates the behavior of the function.
The goal of this document is to explain the interface that a JAX primitive must support in order to allow JAX to perform all its transformations.
Consider that we want to add to JAX support for a multiply-add function with three arguments, defined mathematically as “multiply_add(x, y, z) = x * y + z”. This function operates on 3 identically-shaped tensors of floating point values and performs the operations pointwise.
Using existing primitives#
The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other
functions that are themselves written using JAX primitives, e.g., those
defined in the jax.lax
module:
from jax import lax
from jax._src import api
def multiply_add_lax(x, y, z):
"""Implementation of multiply-add using the jax.lax primitives."""
return lax.add(lax.mul(x, y), z)
def square_add_lax(a, b):
"""A square-add function using the newly defined multiply-add."""
return multiply_add_lax(a, a, b)
print("square_add_lax = ", square_add_lax(2., 10.))
# Differentiate w.r.t. the first argument
print("grad(square_add_lax) = ", api.grad(square_add_lax, argnums=0)(2.0, 10.))
square_add_lax = 14.0
grad(square_add_lax) = 4.0
In order to understand how JAX is internally using the primitives, we add some helpers for tracing function calls.
#@title Helper functions (execute this cell)
import functools
import traceback
_indentation = 0
def _trace(msg=None):
"""Print a message at current indentation."""
if msg is not None:
print(" " * _indentation + msg)
def _trace_indent(msg=None):
"""Print a message and then indent the rest."""
global _indentation
_trace(msg)
_indentation = 1 + _indentation
def _trace_unindent(msg=None):
"""Unindent then print a message."""
global _indentation
_indentation = _indentation - 1
_trace(msg)
def trace(name):
"""A decorator for functions to trace arguments and results."""
def trace_func(func): # pylint: disable=missing-docstring
def pp(v):
"""Print certain values more succinctly"""
vtype = str(type(v))
if "jax._src.xla_bridge._JaxComputationBuilder" in vtype:
return "<JaxComputationBuilder>"
elif "jaxlib.xla_extension.XlaOp" in vtype:
return "<XlaOp at 0x{:x}>".format(id(v))
elif ("partial_eval.JaxprTracer" in vtype or
"batching.BatchTracer" in vtype or
"ad.JVPTracer" in vtype):
return "Traced<{}>".format(v.aval)
elif isinstance(v, tuple):
return "({})".format(pp_values(v))
else:
return str(v)
def pp_values(args):
return ", ".join([pp(arg) for arg in args])
@functools.wraps(func)
def func_wrapper(*args):
_trace_indent("call {}({})".format(name, pp_values(args)))
res = func(*args)
_trace_unindent("|<- {} = {}".format(name, pp(res)))
return res
return func_wrapper
return trace_func
class expectNotImplementedError(object):
"""Context manager to check for NotImplementedError."""
def __enter__(self): pass
def __exit__(self, type, value, tb):
global _indentation
_indentation = 0
if type is NotImplementedError:
print("\nFound expected exception:")
traceback.print_exc(limit=3)
return True
elif type is None: # No exception
assert False, "Expected NotImplementedError"
else:
return False
Instead of using jax.lax primitives directly, we can use other functions that are already written in terms of those primitives, such as those in jax.numpy:
import jax.numpy as jnp
import numpy as np
@trace("multiply_add_numpy")
def multiply_add_numpy(x, y, z):
return jnp.add(jnp.multiply(x, y), z)
@trace("square_add_numpy")
def square_add_numpy(a, b):
return multiply_add_numpy(a, a, b)
print("\nNormal evaluation:")
print("square_add_numpy = ", square_add_numpy(2., 10.))
print("\nGradient evaluation:")
print("grad(square_add_numpy) = ", api.grad(square_add_numpy)(2.0, 10.))
Normal evaluation:
call square_add_numpy(2.0, 10.0)
call multiply_add_numpy(2.0, 2.0, 10.0)
|<- multiply_add_numpy = 14.0
|<- square_add_numpy = 14.0
square_add_numpy = 14.0
Gradient evaluation:
call square_add_numpy(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
call multiply_add_numpy(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
|<- multiply_add_numpy = Traced<ConcreteArray(14.0, dtype=float32, weak_type=True)>
|<- square_add_numpy = Traced<ConcreteArray(14.0, dtype=float32, weak_type=True)>
grad(square_add_numpy) = 4.0
Notice that in the process of computing grad, JAX invokes square_add_numpy and multiply_add_numpy with special arguments ConcreteArray(...) (described further below in this colab).
It is important to remember that a JAX-traceable function must be able to
operate not only on concrete arguments but also on special abstract arguments
that JAX may use to abstract the function execution.
The JAX traceability property is satisfied as long as the function is written in terms of JAX primitives.
Defining new JAX primitives#
The right way to add support for multiply-add is in terms of existing JAX primitives, as shown above. However, in order to demonstrate how JAX primitives work let us pretend that we want to add a new primitive to JAX for the multiply-add functionality.
from jax import core
multiply_add_p = core.Primitive("multiply_add") # Create the primitive
@trace("multiply_add_prim")
def multiply_add_prim(x, y, z):
"""The JAX-traceable way to use the JAX primitive.
Note that the traced arguments must be passed as positional arguments
to `bind`.
"""
return multiply_add_p.bind(x, y, z)
@trace("square_add_prim")
def square_add_prim(a, b):
"""A square-add function implemented using the new JAX-primitive."""
return multiply_add_prim(a, a, b)
If we try to call the newly defined functions we get an error, because we have not yet told JAX anything about the semantics of the new primitive.
with expectNotImplementedError():
square_add_prim(2., 10.)
call square_add_prim(2.0, 10.0)
call multiply_add_prim(2.0, 2.0, 10.0)
Found expected exception:
Traceback (most recent call last):
File "/tmp/ipykernel_4031/2844449444.py", line 2, in <module>
square_add_prim(2., 10.)
File "/tmp/ipykernel_4031/1393342955.py", line 48, in func_wrapper
res = func(*args)
File "/tmp/ipykernel_4031/1308506715.py", line 16, in square_add_prim
return multiply_add_prim(a, a, b)
NotImplementedError: Evaluation rule for 'multiply_add' not implemented
Primal evaluation rules#
@trace("multiply_add_impl")
def multiply_add_impl(x, y, z):
"""Concrete implementation of the primitive.
This function does not need to be JAX traceable.
Args:
x, y, z: the concrete arguments of the primitive. Will only be called with
concrete values.
Returns:
the concrete result of the primitive.
"""
# Note that we can use the original numpy, which is not JAX traceable
return np.add(np.multiply(x, y), z)
# Now we register the primal implementation with JAX
multiply_add_p.def_impl(multiply_add_impl)
<function __main__.multiply_add_impl(x, y, z)>
assert square_add_prim(2., 10.) == 14.
call square_add_prim(2.0, 10.0)
call multiply_add_prim(2.0, 2.0, 10.0)
call multiply_add_impl(2.0, 2.0, 10.0)
|<- multiply_add_impl = 14.0
|<- multiply_add_prim = 14.0
|<- square_add_prim = 14.0
JIT#
If we now try to use jit we get a NotImplementedError:
with expectNotImplementedError():
api.jit(square_add_prim)(2., 10.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
Found expected exception:
Traceback (most recent call last):
File "/tmp/ipykernel_4031/1813425700.py", line 2, in <module>
api.jit(square_add_prim)(2., 10.)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py", line 304, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
NotImplementedError: Abstract evaluation for 'multiply_add' not implemented
Abstract evaluation rules#
In order to JIT the function, and for other transformations as well, JAX first evaluates it abstractly using only the shape and type of the arguments. This abstract evaluation serves multiple purposes:
Gets the sequence of JAX primitives that are used in the computation. This sequence will be compiled.
Computes the shape and type of all vectors and operations used in the computation.
For example, the abstraction of a vector with 3 elements may be ShapedArray(float32[3]), or ConcreteArray([1., 2., 3.]). In the latter case, JAX uses the actual concrete value wrapped as an abstract value.
from jax import core
@trace("multiply_add_abstract_eval")
def multiply_add_abstract_eval(xs, ys, zs):
"""Abstract evaluation of the primitive.
This function does not need to be JAX traceable. It will be invoked with
abstractions of the actual arguments.
Args:
xs, ys, zs: abstractions of the arguments.
Result:
a ShapedArray for the result of the primitive.
"""
assert xs.shape == ys.shape
assert xs.shape == zs.shape
return core.ShapedArray(xs.shape, xs.dtype)
# Now we register the abstract evaluation with JAX
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)
<function __main__.multiply_add_abstract_eval(xs, ys, zs)>
If we re-attempt to JIT, we see how the abstract evaluation proceeds, but we get another error, about missing the actual XLA compilation rule:
with expectNotImplementedError():
api.jit(square_add_prim)(2., 10.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
Found expected exception:
Traceback (most recent call last):
File "/home/docs/.asdf/installs/python/3.10.13/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/docs/.asdf/installs/python/3.10.13/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
app.launch_new_instance()
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: MLIR translation rule for primitive 'multiply_add' not found for platform cpu
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/tmp/ipykernel_4031/1813425700.py", line 2, in <module>
api.jit(square_add_prim)(2., 10.)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py", line 304, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
NotImplementedError: MLIR translation rule for primitive 'multiply_add' not found for platform cpu
XLA Compilation rules#
JAX compilation works by compiling each primitive into a graph of XLA operations.
This is the biggest hurdle to adding new functionality to JAX, because the
set of XLA operations is limited, and JAX already has pre-defined primitives
for most of them. However, XLA includes a CustomCall
operation that can be used to encapsulate arbitrary functionality defined using C++.
from jax._src.lib.mlir.dialects import hlo
@trace("multiply_add_lowering")
def multiply_add_lowering(ctx, xc, yc, zc):
"""The compilation to XLA of the primitive.
Given an mlir.ir.Value for each argument, return the mlir.ir.Values for
the results of the function.
Does not need to be a JAX-traceable function.
"""
return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result]
# Now we register the lowering rule with JAX
# For GPU see the [Custom operations for GPUs](https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html)
# TODO: TPU?
from jax.interpreters import mlir
mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu')
<function __main__.multiply_add_lowering(ctx, xc, yc, zc)>
Now JIT compilation succeeds. Notice below that JAX first evaluates the function abstractly, which triggers the multiply_add_abstract_eval function, and then compiles the set of primitives it has encountered, including multiply_add. At this point JAX invokes multiply_add_lowering.
assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14.
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7fbfd03fa890>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7fbfd0229430>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7fbfd02294b0>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7fbfd0229470>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7fbfd194aa40>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7fbfd0461120>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x55fbb990a290>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_4031/1570919344.py":1:0) at callsite("<module>"("/tmp/ipykernel_4031/1570919344.py":1:0) at callsite("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0) at callsite("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0) at callsite("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0) at "_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":129:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7fbfd0490870, file "/tmp/ipykernel_4031/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0)), (<code object func_wrapper at 0x7fc000582b80, file "/tmp/ipykernel_4031/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0)), (<code object square_add_prim at 0x7fbfd0490500, file "/tmp/ipykernel_4031/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0)), (<code object <lambda> at 0x7fbfd020a4a0, file "/tmp/ipykernel_4031/1570919344.py", line 1>, 6): loc("<lambda>"("/tmp/ipykernel_4031/1570919344.py":1:0)), (<code object <module> at 0x7fbfd0208920, file "/tmp/ipykernel_4031/1570919344.py", line 1>, 16): loc("<module>"("/tmp/ipykernel_4031/1570919344.py":1:0)), (<code object run_code at 0x7fc00a010920, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)), (<code object run_ast_nodes at 0x7fc00a0107c0, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3418>, 500): loc("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0)), (<code object run_cell_async at 0x7fc00a010450, file 
"/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3183>, 828): loc("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0)), (<code object _pseudo_sync_runner at 0x7fc009d76fa0, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 120>, 8): loc("_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":129:0))}, canonical_name_cache={'/tmp/ipykernel_4031/1308506715.py': '/tmp/ipykernel_4031/1308506715.py', '/tmp/ipykernel_4031/1393342955.py': '/tmp/ipykernel_4031/1393342955.py', '/tmp/ipykernel_4031/1570919344.py': '/tmp/ipykernel_4031/1570919344.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_4031/1308506715.py': True, '/tmp/ipykernel_4031/1393342955.py': True, '/tmp/ipykernel_4031/1570919344.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, platforms=None, global_constant_computation=False, replace_tokens_with_dummy=True)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7fbfd022c670>, tokens_out=None, axis_size_env=None, dim_var_values=[]), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7fbfd04557f0>]
Below is another use of jit where we compile only with respect to the first argument. Notice how the second argument to square_add_prim is concrete, which leads to the third argument to multiply_add_abstract_eval being ConcreteArray. We see that multiply_add_abstract_eval may be used with both ShapedArray and ConcreteArray.
assert api.jit(lambda x, y: square_add_prim(x, y),
static_argnums=1)(2., 10.) == 14.
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 10.0)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 10.0)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7fbfd03fb470>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7fbfd02311b0>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7fbfd0231230>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7fbfd02311f0>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7fbfd194aa40>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7fbfd022d690>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x55fbb9c975c0>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_4031/4165789807.py":1:0) at callsite("<module>"("/tmp/ipykernel_4031/4165789807.py":1:0) at callsite("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0) at callsite("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0) at callsite("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0) at "_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":129:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7fbfd0490870, file "/tmp/ipykernel_4031/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0)), (<code object func_wrapper at 0x7fc000582b80, file "/tmp/ipykernel_4031/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0)), (<code object square_add_prim at 0x7fbfd0490500, file "/tmp/ipykernel_4031/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0)), (<code object <lambda> at 0x7fbfd1967b50, file "/tmp/ipykernel_4031/4165789807.py", line 1>, 6): loc("<lambda>"("/tmp/ipykernel_4031/4165789807.py":1:0)), (<code object <module> at 0x7fbfd1965b00, file "/tmp/ipykernel_4031/4165789807.py", line 1>, 20): loc("<module>"("/tmp/ipykernel_4031/4165789807.py":1:0)), (<code object run_code at 0x7fc00a010920, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)), (<code object run_ast_nodes at 0x7fc00a0107c0, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3418>, 500): loc("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0)), (<code object run_cell_async at 0x7fc00a010450, file 
"/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3183>, 828): loc("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0)), (<code object _pseudo_sync_runner at 0x7fc009d76fa0, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 120>, 8): loc("_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":129:0))}, canonical_name_cache={'/tmp/ipykernel_4031/1308506715.py': '/tmp/ipykernel_4031/1308506715.py', '/tmp/ipykernel_4031/1393342955.py': '/tmp/ipykernel_4031/1393342955.py', '/tmp/ipykernel_4031/4165789807.py': '/tmp/ipykernel_4031/4165789807.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_4031/1308506715.py': True, '/tmp/ipykernel_4031/1393342955.py': True, '/tmp/ipykernel_4031/4165789807.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, platforms=None, global_constant_computation=False, replace_tokens_with_dummy=True)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7fbfd022dc60>, tokens_out=None, axis_size_env=None, dim_var_values=[]), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%0 = "stablehlo.constant"() <{value = dense<1.000000e+01> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7fc00a0cd7f0>]
Forward differentiation#
JAX implements forward differentiation in the form of a Jacobian-vector product (see the JAX autodiff cookbook).
If we attempt now to compute the jvp function we get an error because we have not yet told JAX how to differentiate the multiply_add primitive.
# The second argument `(2., 10.)` are the argument values
# where we evaluate the Jacobian, and the third `(1., 1.)`
# are the values of the tangents for the arguments.
with expectNotImplementedError():
api.jvp(square_add_prim, (2., 10.), (1., 1.))
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)
call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)
Found expected exception:
Traceback (most recent call last):
File "/tmp/ipykernel_4031/800067577.py", line 5, in <module>
api.jvp(square_add_prim, (2., 10.), (1., 1.))
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 1906, in jvp
return _jvp(lu.wrap_init(fun), primals, tangents, has_aux=has_aux)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 1935, in _jvp
out_primals, out_tangents = ad.jvp(flat_fun).call_wrapped(ps_flat, ts_flat)
NotImplementedError: Differentiation rule for 'multiply_add' not implemented
from jax.interpreters import ad
@trace("multiply_add_value_and_jvp")
def multiply_add_value_and_jvp(arg_values, arg_tangents):
"""Evaluates the primal output and the tangents (Jacobian-vector product).
Given values of the arguments and perturbation of the arguments (tangents),
compute the output of the primitive and the perturbation of the output.
This method must be JAX-traceable. JAX may invoke it with abstract values
for the arguments and tangents.
Args:
arg_values: a tuple of arguments
arg_tangents: a tuple with the tangents of the arguments. The tuple has
the same length as the arg_values. Some of the tangents may also be the
special value ad.Zero to specify a zero tangent.
Returns:
a pair of the primal output and the tangent.
"""
x, y, z = arg_values
xt, yt, zt = arg_tangents
_trace("Primal evaluation:")
# Now we have a JAX-traceable computation of the output.
# Normally, we can use the ma primitive itself to compute the primal output.
primal_out = multiply_add_prim(x, y, z)
_trace("Tangent evaluation:")
# We must use a JAX-traceable way to compute the tangent. It turns out that
# the output tangent can be computed as (xt * y + x * yt + zt),
# which we can implement in a JAX-traceable way using the same "multiply_add_prim" primitive.
# We do need to deal specially with Zero. Here we just turn it into a
# proper tensor of 0s (of the same shape as 'x').
# An alternative would be to check for Zero and perform algebraic
# simplification of the output tangent computation.
def make_zero(tan):
return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan
output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt)))
return (primal_out, output_tangent)
# Register the forward differentiation rule with JAX
ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp
# Tangent is: xt*y + x*yt + zt = 1.*2. + 2.*1. + 1. = 5.
assert api.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.)
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)
call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)
call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (1.0, 1.0, 1.0))
Primal evaluation:
call multiply_add_prim(2.0, 2.0, 10.0)
call multiply_add_impl(2.0, 2.0, 10.0)
|<- multiply_add_impl = 14.0
|<- multiply_add_prim = 14.0
Tangent evaluation:
call multiply_add_prim(2.0, 1.0, 1.0)
call multiply_add_impl(2.0, 1.0, 1.0)
|<- multiply_add_impl = 3.0
|<- multiply_add_prim = 3.0
call multiply_add_prim(1.0, 2.0, 3.0)
call multiply_add_impl(1.0, 2.0, 3.0)
|<- multiply_add_impl = 5.0
|<- multiply_add_prim = 5.0
|<- multiply_add_value_and_jvp = (14.0, 5.0)
|<- multiply_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
|<- square_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
TO EXPLAIN:
Why is JAX using ConcreteArray in square_add_prim? There is no abstract evaluation going on here.
Not sure how to explain that multiply_add_prim is invoked with ConcreteValue, yet we do not call the multiply_add_abstract_eval.
I think it would be useful to show the jaxpr here (see the sketch below).
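Following up on that last note, here is a minimal sketch (assuming the definitions above are in scope; api is the module imported earlier, and jax.jvp works the same way) that prints the jaxpr of the forward-differentiated function using jax.make_jaxpr:
from jax import make_jaxpr

# The jaxpr records the multiply_add applications emitted by the primal
# and tangent computations of the JVP rule.
print(make_jaxpr(lambda x, y: api.jvp(square_add_prim, (x, y), (1., 1.)))(2., 10.))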
JIT of forward differentiation#
We can apply JIT to the forward differentiation function:
assert api.jit(lambda arg_values, arg_tangents:
api.jvp(square_add_prim, arg_values, arg_tangents))(
(2., 10.), (1., 1.)) == (14., 5.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
call multiply_add_value_and_jvp((Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>), (Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>))
Primal evaluation:
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
Tangent evaluation:
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- multiply_add_value_and_jvp = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7fbfd027e520>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7fbfd0280230>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7fbfd0280730>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7fbfd02806f0>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7fbfd194aa40>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7fbfd022f0d0>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x55fbb9de4690>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":27:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_4031/2145028508.py":2:0) at "<module>"("/tmp/ipykernel_4031/2145028508.py":1:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7fbfd0490870, file "/tmp/ipykernel_4031/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0)), (<code object func_wrapper at 0x7fc000582b80, file "/tmp/ipykernel_4031/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0)), (<code object multiply_add_value_and_jvp at 0x7fbfd02090b0, file "/tmp/ipykernel_4031/3197095916.py", line 4>, 36): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":27:0)), (<code object square_add_prim at 0x7fbfd0490500, file "/tmp/ipykernel_4031/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0)), (<code object <lambda> at 0x7fbfd0209790, file "/tmp/ipykernel_4031/2145028508.py", line 1>, 10): loc("<lambda>"("/tmp/ipykernel_4031/2145028508.py":2:0)), (<code object <module> at 0x7fbfd0209370, file "/tmp/ipykernel_4031/2145028508.py", line 1>, 16): loc("<module>"("/tmp/ipykernel_4031/2145028508.py":1:0))}, canonical_name_cache={'/tmp/ipykernel_4031/1308506715.py': '/tmp/ipykernel_4031/1308506715.py', '/tmp/ipykernel_4031/1393342955.py': '/tmp/ipykernel_4031/1393342955.py', '/tmp/ipykernel_4031/3197095916.py': '/tmp/ipykernel_4031/3197095916.py', '/tmp/ipykernel_4031/2145028508.py': '/tmp/ipykernel_4031/2145028508.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_4031/1308506715.py': True, '/tmp/ipykernel_4031/1393342955.py': True, '/tmp/ipykernel_4031/3197095916.py': True, 
'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/tmp/ipykernel_4031/2145028508.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False}), lowering_parameters=LoweringParameters(override_lowering_rules=None, platforms=None, global_constant_computation=False, replace_tokens_with_dummy=True)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7fbfd022eda0>, tokens_out=None, axis_size_env=None, dim_var_values=[]), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7fbfd1cce630>]
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7fbfd027e520>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7fbfd0280230>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7fbfd0280730>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7fbfd02806f0>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7fbfd194aa40>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7fbfd022f0d0>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x55fbb9de4690>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":27:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_4031/2145028508.py":2:0) at "<module>"("/tmp/ipykernel_4031/2145028508.py":1:0))))))))))), <jaxlib.xla_extension.Traceback object at 0x55fbb9de48c0>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_4031/2145028508.py":2:0) at "<module>"("/tmp/ipykernel_4031/2145028508.py":1:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7fbfd0490870, file "/tmp/ipykernel_4031/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0)), (<code object func_wrapper at 0x7fc000582b80, file "/tmp/ipykernel_4031/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0)), (<code object multiply_add_value_and_jvp at 0x7fbfd02090b0, file "/tmp/ipykernel_4031/3197095916.py", line 4>, 36): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":27:0)), (<code object square_add_prim at 0x7fbfd0490500, file "/tmp/ipykernel_4031/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0)), (<code object <lambda> at 0x7fbfd0209790, file "/tmp/ipykernel_4031/2145028508.py", line 1>, 10): loc("<lambda>"("/tmp/ipykernel_4031/2145028508.py":2:0)), (<code object <module> at 0x7fbfd0209370, file "/tmp/ipykernel_4031/2145028508.py", line 1>, 16): loc("<module>"("/tmp/ipykernel_4031/2145028508.py":1:0)), (<code object multiply_add_value_and_jvp at 0x7fbfd02090b0, file "/tmp/ipykernel_4031/3197095916.py", line 4>, 86): 
loc("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":41:0))}, canonical_name_cache={'/tmp/ipykernel_4031/1308506715.py': '/tmp/ipykernel_4031/1308506715.py', '/tmp/ipykernel_4031/1393342955.py': '/tmp/ipykernel_4031/1393342955.py', '/tmp/ipykernel_4031/3197095916.py': '/tmp/ipykernel_4031/3197095916.py', '/tmp/ipykernel_4031/2145028508.py': '/tmp/ipykernel_4031/2145028508.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_4031/1308506715.py': True, '/tmp/ipykernel_4031/1393342955.py': True, '/tmp/ipykernel_4031/3197095916.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/tmp/ipykernel_4031/2145028508.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False}), lowering_parameters=LoweringParameters(override_lowering_rules=None, platforms=None, global_constant_computation=False, replace_tokens_with_dummy=True)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7fbfd022dfc0>, tokens_out=None, axis_size_env=None, dim_var_values=[]), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 2), Value(<block argument> of type 'tensor<f32>' at index: 3))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7fbfd04435b0>]
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7fbfd027e520>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7fbfd0280230>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7fbfd0280730>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7fbfd02806f0>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7fbfd194aa40>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7fbfd022f0d0>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x55fbb9de4690>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":27:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_4031/2145028508.py":2:0) at "<module>"("/tmp/ipykernel_4031/2145028508.py":1:0))))))))))), <jaxlib.xla_extension.Traceback object at 0x55fbb9de48c0>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_4031/2145028508.py":2:0) at "<module>"("/tmp/ipykernel_4031/2145028508.py":1:0))))))))))), <jaxlib.xla_extension.Traceback object at 0x55fbb9ee9d20>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_4031/2145028508.py":2:0) at "<module>"("/tmp/ipykernel_4031/2145028508.py":1:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7fbfd0490870, file "/tmp/ipykernel_4031/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0)), (<code object func_wrapper at 0x7fc000582b80, file "/tmp/ipykernel_4031/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0)), (<code object multiply_add_value_and_jvp at 0x7fbfd02090b0, file 
"/tmp/ipykernel_4031/3197095916.py", line 4>, 36): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":27:0)), (<code object square_add_prim at 0x7fbfd0490500, file "/tmp/ipykernel_4031/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0)), (<code object <lambda> at 0x7fbfd0209790, file "/tmp/ipykernel_4031/2145028508.py", line 1>, 10): loc("<lambda>"("/tmp/ipykernel_4031/2145028508.py":2:0)), (<code object <module> at 0x7fbfd0209370, file "/tmp/ipykernel_4031/2145028508.py", line 1>, 16): loc("<module>"("/tmp/ipykernel_4031/2145028508.py":1:0)), (<code object multiply_add_value_and_jvp at 0x7fbfd02090b0, file "/tmp/ipykernel_4031/3197095916.py", line 4>, 86): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":41:0)), (<code object multiply_add_value_and_jvp at 0x7fbfd02090b0, file "/tmp/ipykernel_4031/3197095916.py", line 4>, 88): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":41:0))}, canonical_name_cache={'/tmp/ipykernel_4031/1308506715.py': '/tmp/ipykernel_4031/1308506715.py', '/tmp/ipykernel_4031/1393342955.py': '/tmp/ipykernel_4031/1393342955.py', '/tmp/ipykernel_4031/3197095916.py': '/tmp/ipykernel_4031/3197095916.py', '/tmp/ipykernel_4031/2145028508.py': '/tmp/ipykernel_4031/2145028508.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_4031/1308506715.py': True, '/tmp/ipykernel_4031/1393342955.py': True, '/tmp/ipykernel_4031/3197095916.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/tmp/ipykernel_4031/2145028508.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False}), lowering_parameters=LoweringParameters(override_lowering_rules=None, platforms=None, global_constant_computation=False, replace_tokens_with_dummy=True)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[])], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7fbfd022ee60>, tokens_out=None, axis_size_env=None, dim_var_values=[]), Value(<block argument> of type 'tensor<f32>' at index: 2), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%3 = "stablehlo.add"(%2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7fbfd02323b0>]
Notice that first we evaluate multiply_add_value_and_jvp abstractly, which in turn evaluates abstractly both the primal and the tangent evaluation (a total of 3 invocations of the ma primitive). Then we compile the 3 occurrences of the primitive.
Reverse differentiation#
If we attempt now to use reverse differentiation we see that JAX starts by using multiply_add_value_and_jvp to compute the forward differentiation for abstract values, but then runs into a NotImplementedError.
When computing the reverse differentiation JAX first does abstract evaluation of the forward differentiation code multiply_add_value_and_jvp to obtain a trace of primitives that compute the output tangent.
Observe that JAX performs this abstract evaluation with concrete values for the differentiation point, and abstract values for the tangents.
Observe also that JAX uses the special abstract tangent value Zero for the tangent corresponding to the 3rd argument of ma. This reflects the fact that we do not differentiate w.r.t. the 2nd argument to square_add_prim, which flows to the 3rd argument to multiply_add_prim.
Observe also that during the abstract evaluation of the tangent we pass the value 0.0 as the tangent for the 3rd argument. This is due to the use of the make_zero function in the definition of multiply_add_value_and_jvp.
# This is reverse differentiation w.r.t. the first argument of square_add_prim
with expectNotImplementedError():
  api.grad(square_add_prim)(2., 10.)
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
Primal evaluation:
call multiply_add_prim(2.0, 2.0, 10.0)
call multiply_add_impl(2.0, 2.0, 10.0)
|<- multiply_add_impl = 14.0
|<- multiply_add_prim = 14.0
Tangent evaluation:
call multiply_add_prim(2.0, Traced<ShapedArray(float32[], weak_type=True)>, 0.0)
call multiply_add_abstract_eval(ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[], weak_type=True), ConcreteArray(0.0, dtype=float32, weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, 2.0, Traced<ShapedArray(float32[])>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[]))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- multiply_add_value_and_jvp = (14.0, Traced<ShapedArray(float32[])>)
|<- multiply_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
|<- square_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
Found expected exception:
Traceback (most recent call last):
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py", line 283, in get_primitive_transpose
return primitive_transposes[p]
KeyError: multiply_add
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/docs/.asdf/installs/python/3.10.13/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/docs/.asdf/installs/python/3.10.13/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
app.launch_new_instance()
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'multiply_add' not implemented
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/tmp/ipykernel_4031/339076514.py", line 3, in <module>
api.grad(square_add_prim)(2., 10.)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 621, in grad_f
_, g = value_and_grad_f(*args, **kwargs)
NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'multiply_add' not implemented
The above error arises because JAX is still missing the piece it needs to reuse the forward differentiation code for reverse differentiation: a transpose rule for the multiply_add primitive.
Transposition#
As explained above, when computing reverse differentiation JAX obtains a trace of primitives that compute the tangent using forward differentiation. Then, JAX interprets this trace abstractly backwards and for each primitive it applies a transposition rule.
To understand what is going on, consider for now a simpler example of the function “f(x, y) = x * y + y”. Assume we need to differentiate at the point (2., 4.). JAX will produce the following JVP tangent calculation of ft from the tangents of the inputs xt and yt:
a = xt * 4.
b = 2. * yt
c = a + b
ft = c + yt
By construction, the tangent calculation is always linear in the input tangents. The only non-linear operator that may arise in the tangent calculation is multiplication, but then one of the operands is constant.
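As a quick check (not part of the original notebook), we can ask JAX for the same JVP directly; f_simple below is a name chosen just for this illustration:
import jax

def f_simple(x, y):
  return x * y + y

# Pushing the tangent (1., 0.) through the JVP at (2., 4.) gives df/dx;
# pushing (0., 1.) gives df/dy, matching the hand-written calculation above.
_, ft_x = jax.jvp(f_simple, (2., 4.), (1., 0.))
_, ft_y = jax.jvp(f_simple, (2., 4.), (0., 1.))
assert ft_x == 4.
assert ft_y == 3.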
JAX will produce the reverse differentiation computation by processing the JVP computation backwards. For each operation in the tangent computation, it accumulates the cotangents of the variables used by the operation, using the cotangent of the result of the operation:
# Initialize cotangents of inputs and intermediate vars
xct = yct = act = bct = cct = 0.
# Initialize cotangent of the output
fct = 1.
# Process "ft = c + yt"
cct += fct
yct += fct
# Process "c = a + b"
act += cct
bct += cct
# Process "b = 2. * yt"
yct += 2. * bct
# Process "a = xt * 4."
xct += act * 4.
One can verify that this computation produces xct = 4. and yct = 3., which are the partial derivatives of the function f.
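The same partial derivatives can be cross-checked with reverse mode (again an illustrative aside, with the helper redefined so the snippet stands alone):
import jax

def f_simple(x, y):  # f(x, y) = x * y + y, as above
  return x * y + y

# grad with argnums=(0, 1) returns both partial derivatives at (2., 4.).
df_dx, df_dy = jax.grad(f_simple, argnums=(0, 1))(2., 4.)
assert df_dx == 4.
assert df_dy == 3.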
JAX knows for each primitive that may appear in a JVP calculation how to transpose it. Conceptually, if the primitive p(x, y, z) is linear in the arguments y and z for a constant value of x, e.g., p(x, y, z) = y*cy + z*cz, then the transposition of the primitive is:
p_transpose(out_ct, x, _, _) = (None, out_ct*cy, out_ct*cz)
Notice that p_transpose takes the cotangent of the output of the primitive and a value corresponding to each argument of the primitive. For the linear arguments, the transposition gets an undefined _ value, and for the other arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value None returned for the constant arguments.
In particular,
add_transpose(out_ct, _, _) = (out_ct, out_ct)
mult_transpose(out_ct, x, _) = (None, x * out_ct)
mult_transpose(out_ct, _, y) = (out_ct * y, None)
@trace("multiply_add_transpose")
def multiply_add_transpose(ct, x, y, z):
  """Evaluates the transpose of a linear primitive.

  This method is only used when computing the backward gradient following
  value_and_jvp, and is only needed for primitives that are used in the JVP
  calculation for some other primitive. We need transposition for multiply_add_prim,
  because we have used multiply_add_prim in the computation of the output_tangent in
  multiply_add_value_and_jvp.

  In our case, multiply_add is not a linear primitive. However, it is used linearly
  w.r.t. tangents in multiply_add_value_and_jvp:
      output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))

  Always one of the first two multiplicative arguments is a constant.

  Args:
    ct: the cotangent of the output of the primitive.
    x, y, z: values of the arguments. The arguments that are used linearly
      get an ad.UndefinedPrimal value. The other arguments get a constant
      value.

  Returns:
    a tuple with the cotangent of the inputs, with the value None
    corresponding to the constant arguments.
  """
  if not ad.is_undefined_primal(x):
    # This use of multiply_add is with a constant "x"
    assert ad.is_undefined_primal(y)
    ct_y = ad.Zero(y.aval) if type(ct) is ad.Zero else multiply_add_prim(x, ct, lax.zeros_like_array(x))
    res = None, ct_y, ct
  else:
    # This use of multiply_add is with a constant "y"
    assert ad.is_undefined_primal(x)
    ct_x = ad.Zero(x.aval) if type(ct) is ad.Zero else multiply_add_prim(ct, y, lax.zeros_like_array(y))
    res = ct_x, None, ct
  return res

ad.primitive_transposes[multiply_add_p] = multiply_add_transpose
Now we can complete the run of grad:
assert api.grad(square_add_prim)(2., 10.) == 4.
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
Primal evaluation:
call multiply_add_prim(2.0, 2.0, 10.0)
call multiply_add_impl(2.0, 2.0, 10.0)
|<- multiply_add_impl = 14.0
|<- multiply_add_prim = 14.0
Tangent evaluation:
call multiply_add_prim(2.0, Traced<ShapedArray(float32[], weak_type=True)>, 0.0)
call multiply_add_abstract_eval(ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[], weak_type=True), ConcreteArray(0.0, dtype=float32, weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, 2.0, Traced<ShapedArray(float32[])>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[]))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- multiply_add_value_and_jvp = (14.0, Traced<ShapedArray(float32[])>)
|<- multiply_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
|<- square_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
call multiply_add_transpose(1.0, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), 2.0, UndefinedPrimal(ShapedArray(float32[])))
call multiply_add_prim(1.0, 2.0, 0.0)
call multiply_add_impl(1.0, 2.0, 0.0)
|<- multiply_add_impl = 2.0
|<- multiply_add_prim = 2.0
|<- multiply_add_transpose = (2.0, None, 1.0)
call multiply_add_transpose(1.0, 2.0, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), 0.0)
call multiply_add_prim(2.0, 1.0, 0.0)
call multiply_add_impl(2.0, 1.0, 0.0)
|<- multiply_add_impl = 2.0
|<- multiply_add_prim = 2.0
|<- multiply_add_transpose = (None, 2.0, 1.0)
Notice the two calls to multiply_add_transpose. They correspond to the two uses of multiply_add_prim in the computation of the output_tangent in multiply_add_value_and_jvp. The first call to transpose corresponds to the last use of multiply_add_prim: multiply_add_prim(xt, y, ...) where y is the constant 2.0.
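As an aside (not in the original notebook), the same JVP-plus-transpose machinery also backs value_and_grad; here is a small illustrative check, reusing the notebook's api and square_add_prim:
# Returns the primal value and the gradient w.r.t. the first argument.
value, grad_x = api.value_and_grad(square_add_prim)(2., 10.)
assert value == 14.
assert grad_x == 4.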
JIT of reverse differentiation#
Notice that the abstract evaluation of multiply_add_value_and_jvp now uses only abstract values, while in the absence of JIT we used ConcreteArray.
assert api.jit(api.grad(square_add_prim))(2., 10.) == 4.
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_value_and_jvp((Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
Primal evaluation:
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
Tangent evaluation:
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- multiply_add_value_and_jvp = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>)
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_transpose(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, UndefinedPrimal(ShapedArray(float32[])))
call multiply_add_prim(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- multiply_add_transpose = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, None, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_transpose(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[]), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- multiply_add_transpose = (None, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7fbfd027e250>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7fbfd02812b0>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7fbfd02813b0>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7fbfd0282b30>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7fbfd194aa40>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7fbfd022fd30>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x55fbb9ee9d20>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("<module>"("/tmp/ipykernel_4031/3085343041.py":1:0) at "run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7fbfd0490870, file "/tmp/ipykernel_4031/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0)), (<code object func_wrapper at 0x7fc000582b80, file "/tmp/ipykernel_4031/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0)), (<code object multiply_add_value_and_jvp at 0x7fbfd02090b0, file "/tmp/ipykernel_4031/3197095916.py", line 4>, 88): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":41:0)), (<code object square_add_prim at 0x7fbfd0490500, file "/tmp/ipykernel_4031/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0)), (<code object <module> at 0x7fbfd02199a0, file "/tmp/ipykernel_4031/3085343041.py", line 1>, 18): loc("<module>"("/tmp/ipykernel_4031/3085343041.py":1:0)), (<code object run_code at 0x7fc00a010920, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0))}, canonical_name_cache={'/tmp/ipykernel_4031/1308506715.py': '/tmp/ipykernel_4031/1308506715.py', '/tmp/ipykernel_4031/1393342955.py': '/tmp/ipykernel_4031/1393342955.py', '/tmp/ipykernel_4031/3197095916.py': '/tmp/ipykernel_4031/3197095916.py', '/tmp/ipykernel_4031/3085343041.py': '/tmp/ipykernel_4031/3085343041.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py'}, 
is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_4031/1308506715.py': True, '/tmp/ipykernel_4031/1393342955.py': True, '/tmp/ipykernel_4031/3197095916.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/tmp/ipykernel_4031/3085343041.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, platforms=None, global_constant_computation=False, replace_tokens_with_dummy=True)), name_stack=NameStack(stack=(Scope(name='jit(square_add_prim)'), Scope(name='jit(main)'), Transform(name='transpose'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7fbfd0460ee0>, tokens_out=None, axis_size_env=None, dim_var_values=[]), Value(%0 = "stablehlo.constant"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%1 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7fbfd02812f0>]
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7fbfd027e250>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7fbfd02812b0>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7fbfd02813b0>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7fbfd0282b30>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7fbfd194aa40>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7fbfd022fd30>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x55fbb9ee9d20>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("<module>"("/tmp/ipykernel_4031/3085343041.py":1:0) at "run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0))))))))))), <jaxlib.xla_extension.Traceback object at 0x55fbb9e7b400>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("<module>"("/tmp/ipykernel_4031/3085343041.py":1:0) at "run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7fbfd0490870, file "/tmp/ipykernel_4031/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0)), (<code object func_wrapper at 0x7fc000582b80, file "/tmp/ipykernel_4031/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0)), (<code object multiply_add_value_and_jvp at 0x7fbfd02090b0, file "/tmp/ipykernel_4031/3197095916.py", line 4>, 88): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":41:0)), (<code object square_add_prim at 0x7fbfd0490500, file "/tmp/ipykernel_4031/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0)), (<code object <module> at 0x7fbfd02199a0, file "/tmp/ipykernel_4031/3085343041.py", line 1>, 18): loc("<module>"("/tmp/ipykernel_4031/3085343041.py":1:0)), (<code object run_code at 0x7fc00a010920, file 
"/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)), (<code object multiply_add_value_and_jvp at 0x7fbfd02090b0, file "/tmp/ipykernel_4031/3197095916.py", line 4>, 86): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":41:0))}, canonical_name_cache={'/tmp/ipykernel_4031/1308506715.py': '/tmp/ipykernel_4031/1308506715.py', '/tmp/ipykernel_4031/1393342955.py': '/tmp/ipykernel_4031/1393342955.py', '/tmp/ipykernel_4031/3197095916.py': '/tmp/ipykernel_4031/3197095916.py', '/tmp/ipykernel_4031/3085343041.py': '/tmp/ipykernel_4031/3085343041.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_4031/1308506715.py': True, '/tmp/ipykernel_4031/1393342955.py': True, '/tmp/ipykernel_4031/3197095916.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/tmp/ipykernel_4031/3085343041.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, platforms=None, global_constant_computation=False, replace_tokens_with_dummy=True)), name_stack=NameStack(stack=(Scope(name='jit(square_add_prim)'), Scope(name='jit(main)'), Transform(name='transpose'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[]), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7fbfd02bc580>, tokens_out=None, axis_size_env=None, dim_var_values=[]), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%4 = "stablehlo.constant"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>), Value(%5 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7fbfd2119c30>]
Batching#
The batching transformation takes a point-wise computation and turns it into a computation on vectors. If we try it right now, we get a NotImplementedError:
# The arguments are two vectors instead of two scalars
with expectNotImplementedError():
  api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),
                                                   np.array([10., 20.]))
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
Found expected exception:
Traceback (most recent call last):
File "/tmp/ipykernel_4031/2641678767.py", line 3, in <module>
api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 1214, in vmap_f
out_flat = batching.batch(
NotImplementedError: Batching rule for 'multiply_add' not implemented
We need to tell JAX how to evaluate the batched version of the primitive. In this particular case, multiply_add_prim already operates pointwise for any dimension of input vectors, so the batched version can use the same multiply_add_prim implementation.
from jax.interpreters import batching

@trace("multiply_add_batch")
def multiply_add_batch(vector_arg_values, batch_axes):
  """Computes the batched version of the primitive.

  This must be a JAX-traceable function.

  Since the multiply_add primitive already operates pointwise on arbitrary
  dimension tensors, to batch it we can use the primitive itself. This works as
  long as all the inputs have the same dimensions and are batched along the
  same axes. The result is batched along the axis that the inputs are batched.

  Args:
    vector_arg_values: a tuple of three arguments, each being a tensor of
      matching shape.
    batch_axes: the axes that are being batched. See vmap documentation.

  Returns:
    a tuple of the result, and the result axis that was batched.
  """
  assert batch_axes[0] == batch_axes[1]
  assert batch_axes[0] == batch_axes[2]
  _trace("Using multiply_add to compute the batch:")
  res = multiply_add_prim(*vector_arg_values)
  return res, batch_axes[0]

batching.primitive_batchers[multiply_add_p] = multiply_add_batch
assert np.allclose(api.vmap(square_add_prim, in_axes=0, out_axes=0)(
  np.array([2., 3.]),
  np.array([10., 20.])),
  [14., 29.])
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
call multiply_add_batch(([2. 3.], [2. 3.], [10. 20.]), (0, 0, 0))
Using multiply_add to compute the batch:
call multiply_add_prim([2. 3.], [2. 3.], [10. 20.])
call multiply_add_impl([2. 3.], [2. 3.], [10. 20.])
|<- multiply_add_impl = [14. 29.]
|<- multiply_add_prim = [14. 29.]
|<- multiply_add_batch = ([14. 29.], 0)
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
JIT of batching#
assert np.allclose(api.jit(api.vmap(square_add_prim, in_axes=0, out_axes=0))
  (np.array([2., 3.]),
   np.array([10., 20.])),
  [14., 29.])
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
call multiply_add_batch((Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>), (0, 0, 0))
Using multiply_add to compute the batch:
call multiply_add_prim(Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[2]), ShapedArray(float32[2]), ShapedArray(float32[2]))
|<- multiply_add_abstract_eval = ShapedArray(float32[2])
|<- multiply_add_prim = Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>
|<- multiply_add_batch = (Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, 0)
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7fbfd02a61b0>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7fbfd0477f70>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7fbfd02c00b0>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7fbfd02c0770>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7fbfd194aa40>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7fbfd02bc3a0>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x55fbb9ceef90>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_batch"("/tmp/ipykernel_4031/184469370.py":25:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("<module>"("/tmp/ipykernel_4031/1392464762.py":1:0) at "run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7fbfd0490870, file "/tmp/ipykernel_4031/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0)), (<code object func_wrapper at 0x7fc000582b80, file "/tmp/ipykernel_4031/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0)), (<code object multiply_add_batch at 0x7fbfd021b9f0, file "/tmp/ipykernel_4031/184469370.py", line 4>, 52): loc("multiply_add_batch"("/tmp/ipykernel_4031/184469370.py":25:0)), (<code object square_add_prim at 0x7fbfd0490500, file "/tmp/ipykernel_4031/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0)), (<code object <module> at 0x7fbfd02192c0, file "/tmp/ipykernel_4031/1392464762.py", line 1>, 48): loc("<module>"("/tmp/ipykernel_4031/1392464762.py":1:0)), (<code object run_code at 0x7fc00a010920, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0))}, canonical_name_cache={'/tmp/ipykernel_4031/1308506715.py': '/tmp/ipykernel_4031/1308506715.py', '/tmp/ipykernel_4031/1393342955.py': '/tmp/ipykernel_4031/1393342955.py', '/tmp/ipykernel_4031/184469370.py': '/tmp/ipykernel_4031/184469370.py', '/tmp/ipykernel_4031/1392464762.py': '/tmp/ipykernel_4031/1392464762.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py'}, 
is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_4031/1308506715.py': True, '/tmp/ipykernel_4031/1393342955.py': True, '/tmp/ipykernel_4031/184469370.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/batching.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/tmp/ipykernel_4031/1392464762.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, platforms=None, global_constant_computation=False, replace_tokens_with_dummy=True)), name_stack=NameStack(stack=(Scope(name='jit(square_add_prim)'), Scope(name='jit(main)'), Transform(name='vmap'))), primitive=multiply_add, avals_in=[ShapedArray(float32[2]), ShapedArray(float32[2]), ShapedArray(float32[2])], avals_out=[ShapedArray(float32[2])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7fbfd02bd240>, tokens_out=None, axis_size_env=None, dim_var_values=[]), Value(<block argument> of type 'tensor<2xf32>' at index: 0), Value(<block argument> of type 'tensor<2xf32>' at index: 0), Value(<block argument> of type 'tensor<2xf32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7fbfd02970b0>]
Writing custom Jaxpr interpreters in JAX#
JAX offers several composable function transformations (jit, grad, vmap, etc.) that enable writing concise, accelerated code.
Here we show how to add your own function transformations to the system by writing a custom Jaxpr interpreter, and we get composability with all the other transformations for free.
This example uses internal JAX APIs, which may break at any time. Anything not in the API Documentation should be assumed internal.
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, grad, vmap
from jax import random
What is JAX doing?#
JAX provides a NumPy-like API for numerical computing which can be used as is, but JAX's true power comes from composable function transformations. Take the jit function transformation, which takes in a function and returns a semantically identical function, but one that is lazily compiled by XLA for accelerators.
x = random.normal(random.key(0), (5000, 5000))
def f(w, b, x):
  return jnp.tanh(jnp.dot(x, w) + b)
fast_f = jit(f)
When we call fast_f, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about Jax's tracing machinery, you can refer to the "How it works" section in the README.
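For instance, we might exercise it as in the sketch below (w and b and their shapes are made up for this illustration and are not part of the original cell):
# First call: trace f, compile with XLA, then execute.
# Later calls with the same input shapes/dtypes reuse the cached executable.
key = random.key(1)
w = random.normal(key, (5000, 5000))
b = random.normal(key, (5000,))
out_first = fast_f(w, b, x).block_until_ready()
out_again = fast_f(w, b, x).block_until_ready()  # no re-trace, no re-compile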
Jaxpr tracer#
A tracer of special importance in Jax is the Jaxpr tracer, which records ops into a Jaxpr (Jax expression). A Jaxpr is a data structure that can be evaluated like a mini functional programming language and thus Jaxprs are a useful intermediate representation for function transformation.
To get a first look at Jaxprs, consider the make_jaxpr transformation. make_jaxpr is essentially a “pretty-printing” transformation: it transforms a function into one that, given example arguments, produces a Jaxpr representation of its computation. make_jaxpr is useful for debugging and introspection. Let’s use it to look at how some example Jaxprs are structured.
def examine_jaxpr(closed_jaxpr):
  jaxpr = closed_jaxpr.jaxpr
  print("invars:", jaxpr.invars)
  print("outvars:", jaxpr.outvars)
  print("constvars:", jaxpr.constvars)
  for eqn in jaxpr.eqns:
    print("equation:", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)
  print()
  print("jaxpr:", jaxpr)

def foo(x):
  return x + 1
print("foo")
print("=====")
examine_jaxpr(jax.make_jaxpr(foo)(5))

print()

def bar(w, b, x):
  return jnp.dot(w, x) + b + jnp.ones(5), x
print("bar")
print("=====")
examine_jaxpr(jax.make_jaxpr(bar)(jnp.ones((5, 10)), jnp.ones(5), jnp.ones(10)))
foo
=====
invars: [Var(id=140015952089984):int32[]]
outvars: [Var(id=140015952089792):int32[]]
constvars: []
equation: [Var(id=140015952089984):int32[], 1] add [Var(id=140015952089792):int32[]] {}
jaxpr: { lambda ; a:i32[]. let b:i32[] = add a 1 in (b,) }
bar
=====
invars: [Var(id=140015952459328):float32[5,10], Var(id=140015952459392):float32[5], Var(id=140015952459456):float32[10]]
outvars: [Var(id=140015952459712):float32[5], Var(id=140015952459456):float32[10]]
constvars: []
equation: [Var(id=140015952459328):float32[5,10], Var(id=140015952459456):float32[10]] dot_general [Var(id=140015952459520):float32[5]] {'dimension_numbers': (((1,), (0,)), ((), ())), 'precision': None, 'preferred_element_type': dtype('float32')}
equation: [Var(id=140015952459520):float32[5], Var(id=140015952459392):float32[5]] add [Var(id=140015952459584):float32[5]] {}
equation: [1.0] broadcast_in_dim [Var(id=140015952459648):float32[5]] {'shape': (5,), 'broadcast_dimensions': ()}
equation: [Var(id=140015952459584):float32[5], Var(id=140015952459648):float32[5]] add [Var(id=140015952459712):float32[5]] {}
jaxpr: { lambda ; a:f32[5,10] b:f32[5] c:f32[10]. let
d:f32[5] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] a c
e:f32[5] = add d b
f:f32[5] = broadcast_in_dim[broadcast_dimensions=() shape=(5,)] 1.0
g:f32[5] = add e f
in (g, c) }
jaxpr.invars - the invars of a Jaxpr are a list of the input variables to the Jaxpr, analogous to arguments in Python functions.
jaxpr.outvars - the outvars of a Jaxpr are the variables that are returned by the Jaxpr. Every Jaxpr has multiple outputs.
jaxpr.constvars - the constvars are a list of variables that are also inputs to the Jaxpr, but correspond to constants from the trace (we'll go over these in more detail later).
jaxpr.eqns - a list of equations, which are essentially let-bindings. Each equation is a list of input variables, a list of output variables, and a primitive, which is used to evaluate inputs to produce outputs. Each equation also has params, a dictionary of parameters.
Altogether, a Jaxpr encapsulates a simple program that can be evaluated with inputs to produce an output. We’ll go over how exactly to do this later. The important thing to note now is that a Jaxpr is a data structure that can be manipulated and evaluated in whatever way we want.
Why are Jaxprs useful?#
Jaxprs are simple program representations that are easy to transform. And because JAX lets us stage out Jaxprs from Python functions, it gives us a way to transform numerical programs written in Python.
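As a tiny, self-contained illustration of that idea (an extra sketch, not part of the original example), here is a trivial "transformation" that stages a function out to a Jaxpr and simply reports which primitives it uses:
# A minimal sketch (our own addition): stage a function out to a Jaxpr and
# count the primitives appearing in its equations.
from collections import Counter

def count_primitives(fun, *example_args):
    jaxpr = jax.make_jaxpr(fun)(*example_args).jaxpr
    return Counter(str(eqn.primitive) for eqn in jaxpr.eqns)

print(count_primitives(lambda x: jnp.exp(jnp.tanh(x)) + 1.0, jnp.ones(3)))
# e.g. Counter({'tanh': 1, 'exp': 1, 'add': 1})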
Your first interpreter: invert#
Let’s try to implement a simple function “inverter”, which takes in the output of the original function and returns the inputs that produced those outputs. For now, let’s focus on simple, unary functions which are composed of other invertible unary functions.
Goal:
def f(x):
    return jnp.exp(jnp.tanh(x))
f_inv = inverse(f)
assert jnp.allclose(f_inv(f(1.0)), 1.0)
The way we'll implement this is by (1) tracing f into a Jaxpr, then (2) interpreting the Jaxpr backwards. While interpreting the Jaxpr backwards, for each equation we'll look up the primitive's inverse in a table and apply it.
1. Tracing a function#
Let's use make_jaxpr to trace a function into a Jaxpr.
# Importing Jax functions useful for tracing/interpreting.
import numpy as np
from functools import wraps
from jax import core
from jax import lax
from jax._src.util import safe_map
jax.make_jaxpr returns a closed Jaxpr, which is a Jaxpr that has been bundled with the constants (literals) from the trace.
def f(x):
    return jnp.exp(jnp.tanh(x))
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
print(closed_jaxpr.jaxpr)
print(closed_jaxpr.literals)
{ lambda ; a:f32[5]. let b:f32[5] = tanh a; c:f32[5] = exp b in (c,) }
[]
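The literals list is empty here because f captures no constants. As an extra point of comparison (our own addition, not in the original), a function that closes over a non-scalar constant should show that constant as a constvar of the Jaxpr and as an entry in literals:
# Extra example (our own addition): a captured non-scalar constant appears
# as a constvar of the Jaxpr and as an entry in closed_jaxpr.literals.
const = np.arange(5.0, dtype=np.float32)

def g(x):
    return x + const

closed_g = jax.make_jaxpr(g)(jnp.ones(5))
print(closed_g.jaxpr)      # expect an extra constvar bound alongside the input
print(closed_g.literals)   # expect a non-empty list containing the captured array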
2. Evaluating a Jaxpr#
Before we write a custom Jaxpr interpreter, let's first implement the "default" interpreter, eval_jaxpr, which evaluates the Jaxpr as-is, computing the same values that the original, un-transformed Python function would.
To do this, we first create an environment to store the values for each of the variables, and update the environment with each equation we evaluate in the Jaxpr.
def eval_jaxpr(jaxpr, consts, *args):
    # Mapping from variable -> value
    env = {}

    def read(var):
        # Literals are values baked into the Jaxpr
        if type(var) is core.Literal:
            return var.val
        return env[var]

    def write(var, val):
        env[var] = val

    # Bind args and consts to environment
    safe_map(write, jaxpr.invars, args)
    safe_map(write, jaxpr.constvars, consts)

    # Loop through equations and evaluate primitives using `bind`
    for eqn in jaxpr.eqns:
        # Read inputs to equation from environment
        invals = safe_map(read, eqn.invars)
        # `bind` is how a primitive is called
        outvals = eqn.primitive.bind(*invals, **eqn.params)
        # Primitives may return multiple outputs or not
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        # Write the results of the primitive into the environment
        safe_map(write, eqn.outvars, outvals)
    # Read the final result of the Jaxpr from the environment
    return safe_map(read, jaxpr.outvars)
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, jnp.ones(5))
[Array([2.1416876, 2.1416876, 2.1416876, 2.1416876, 2.1416876], dtype=float32)]
Notice that eval_jaxpr will always return a flat list even if the original function does not.
Furthermore, this interpreter does not handle higher-order primitives (like jit and pmap), which we will not cover in this guide. You can refer to core.eval_jaxpr (link) to see the edge cases that this interpreter does not cover.
Custom inverse Jaxpr interpreter#
An inverse interpreter doesn't look too different from eval_jaxpr. We'll first set up the registry which will map primitives to their inverses. We'll then write a custom interpreter that looks up primitives in the registry.
It turns out that this interpreter will also look similar to the “transpose” interpreter used in reverse-mode autodifferentiation found here.
inverse_registry = {}
We'll now register inverses for some of the primitives. By convention, primitives in JAX end in _p, and a lot of the popular ones live in lax.
inverse_registry[lax.exp_p] = jnp.log
inverse_registry[lax.tanh_p] = jnp.arctanh
inverse will first trace the function, then custom-interpret the Jaxpr. Let's set up a simple skeleton.
def inverse(fun):
    @wraps(fun)
    def wrapped(*args, **kwargs):
        # Since we assume unary functions, we won't worry about flattening and
        # unflattening arguments.
        closed_jaxpr = jax.make_jaxpr(fun)(*args, **kwargs)
        out = inverse_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
        return out[0]
    return wrapped
Now we just need to define inverse_jaxpr, which will walk through the Jaxpr backward and invert primitives when it can.
def inverse_jaxpr(jaxpr, consts, *args):
    env = {}

    def read(var):
        if type(var) is core.Literal:
            return var.val
        return env[var]

    def write(var, val):
        env[var] = val

    # Args now correspond to Jaxpr outvars
    safe_map(write, jaxpr.outvars, args)
    safe_map(write, jaxpr.constvars, consts)

    # Looping backward
    for eqn in jaxpr.eqns[::-1]:
        # outvars are now invars
        invals = safe_map(read, eqn.outvars)
        if eqn.primitive not in inverse_registry:
            raise NotImplementedError(
                f"{eqn.primitive} does not have registered inverse.")
        # Assuming a unary function
        outval = inverse_registry[eqn.primitive](*invals)
        safe_map(write, eqn.invars, [outval])
    return safe_map(read, jaxpr.invars)
That’s it!
def f(x):
    return jnp.exp(jnp.tanh(x))
f_inv = inverse(f)
assert jnp.allclose(f_inv(f(1.0)), 1.0)
Importantly, you can trace through a Jaxpr interpreter.
jax.make_jaxpr(inverse(f))(f(1.))
{ lambda ; a:f32[]. let b:f32[] = log a; c:f32[] = atanh b in (c,) }
That's all it takes to add a new transformation to a system, and you get composition with all the others for free! For example, we can use jit, vmap, and grad with inverse!
jit(vmap(grad(inverse(f))))((jnp.arange(5) + 1.) / 5.)
Array([-3.1440797, 15.584931 , 2.2551253, 1.3155028, 1. ], dtype=float32, weak_type=True)
Exercises for the reader#
Handle primitives with multiple arguments where inputs are partially known, for example lax.add_p and lax.mul_p (see the sketch below).
Handle xla_call and xla_pmap primitives, which will not work with both eval_jaxpr and inverse_jaxpr as written.
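As a hint for the first exercise, here is a sketch (our own, not part of the original text) of what a registry of "partial inverses" might look like: each entry recovers the one unknown input of a two-argument primitive from the primitive's output and the input that is already known.
# Sketch for the first exercise (our own assumptions, not from the text):
# a partial inverse recovers the unknown operand of a binary primitive from
# the primitive's output and the operand that is already known. A modified
# inverse_jaxpr would consult this registry whenever exactly one of an
# equation's inputs is still missing from `env`.
partial_inverse_registry = {}
partial_inverse_registry[lax.add_p] = lambda out, known: out - known
partial_inverse_registry[lax.mul_p] = lambda out, known: out / known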
Custom operations for GPUs with C++ and CUDA#
JAX ships with a large number of built-in operations, but users occasionally run into a situation where they need a new operation that is not supported by JAX.
To accommodate such scenarios, JAX allows users to define custom operations. This tutorial explains how to define one for GPUs and use it in single-GPU and multi-GPU environments.
This tutorial contains information from Extending JAX with custom C++ and CUDA code and assumes that you are familiar with JAX primitives.
RMS normalization#
For this tutorial, we are going to add the RMS normalization as a custom operation in JAX.
Note that the RMS normalization can be expressed with jax.numpy directly. However, we are using it as an example to show the process of creating a custom operation for GPUs.
The CUDA code in gpu_ops/rms_norm_kernels.cu for this operation has been borrowed from Apex and adapted to eliminate any dependency on PyTorch.
High-level steps#
This tutorial shows how to write both a custom operation and its gradient.
In C: You need to follow these steps in C for each new JAX primitive:
Have CUDA kernel(s).
Create a C function that dispatches the CUDA kernel that will be called by XLA.
Create a descriptor to convey information needed for the computation.
The types, the shapes and other attributes.
Bind C functions to Python
To create the descriptor and to call the primitive during execution.
In Python: You need to follow these steps in Python:
Define a new JAX primitive (instruction/operation)
Write Python functions to build the graph nodes with the primitive.
Define its abstract evaluation.
Define its lowering to MLIR.
[Optional] Define the gradient.
[Optional] Use custom_partitioning or shard_map functions for fast multi-GPU.
C code#
See the gpu_ops code listing for a complete code listing of C++ and CUDA files.
gpu_ops/rms_norm_kernels.cu defines the following functions, which are declared with the XLA custom function signature.
These functions are responsible for launching RMS normalization kernels with the given buffers on the specified stream.
namespace gpu_ops {
void rms_forward_affine_mixed_dtypes(cudaStream_t stream, void **buffers,
const char *opaque,
std::size_t opaque_len);
void rms_backward_affine(cudaStream_t stream, void **buffers,
const char *opaque, std::size_t opaque_len);
} // namespace gpu_ops
stream is the CUDA stream to be used to execute any kernel on the GPU.
buffers has all pointers to input buffers followed by all pointers to output buffers.
opaque is a buffer for any extra information that is being passed to the custom functions, and opaque_len is the length of opaque.
For this tutorial, an RMSNormDescriptor object will be passed to these functions as opaque.
namespace gpu_ops {
enum ElementType { BF16, F16, F32, F64 };
struct RMSNormDescriptor {
int n1;
int n2;
double eps;
ElementType x_type;
ElementType w_type;
int part_grad_size;
};
} // namespace gpu_ops
Now, we need to expose these functions as well as ElementType and RMSNormDescriptor as a Python module, gpu_ops, through pybind11.
pybind11::dict RMSNormRegistrations() {
pybind11::dict dict;
dict["rms_forward_affine_mixed_dtype"] =
gpu_ops::EncapsulateFunction(gpu_ops::rms_forward_affine_mixed_dtypes);
dict["rms_backward_affine"] =
gpu_ops::EncapsulateFunction(gpu_ops::rms_backward_affine);
return dict;
}
PYBIND11_MODULE(gpu_ops, m) {
m.def("get_rms_norm_registrations", &RMSNormRegistrations);
m.def("create_rms_norm_descriptor",
[](int n1, int n2, double eps, gpu_ops::ElementType x_type,
gpu_ops::ElementType w_type, int part_grad_size) {
return gpu_ops::PackDescriptor(gpu_ops::RMSNormDescriptor{
n1, n2, eps, x_type, w_type, part_grad_size});
});
pybind11::enum_<gpu_ops::ElementType>(m, "ElementType")
.value("BF16", gpu_ops::ElementType::BF16)
.value("F16", gpu_ops::ElementType::F16)
.value("F32", gpu_ops::ElementType::F32)
.value("F64", gpu_ops::ElementType::F64);
}
Build gpu_ops extension module#
We build the gpu_ops Python extension module with the aforementioned code.
(See the gpu_ops code listing for a complete code listing of C++ and CUDA files.)
python -m pip install pybind11==2.10.1
mkdir -p build
pybind_include_path=$(python -c "import pybind11; print(pybind11.get_include())")
python_executable=$(python -c 'import sys; print(sys.executable)')
nvcc --threads 4 -Xcompiler -Wall -ldl --expt-relaxed-constexpr -O3 -DNDEBUG -Xcompiler -O3 --generate-code=arch=compute_70,code=[compute_70,sm_70] --generate-code=arch=compute_75,code=[compute_75,sm_75] --generate-code=arch=compute_80,code=[compute_80,sm_80] --generate-code=arch=compute_86,code=[compute_86,sm_86] -Xcompiler=-fPIC -Xcompiler=-fvisibility=hidden -x cu -c gpu_ops/rms_norm_kernels.cu -o build/rms_norm_kernels.cu.o
c++ -I/usr/local/cuda/include -I$pybind_include_path $(${python_executable}-config --cflags) -O3 -DNDEBUG -O3 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o build/gpu_ops.cpp.o -c gpu_ops/gpu_ops.cpp
c++ -fPIC -O3 -DNDEBUG -O3 -flto -shared -o build/gpu_ops$(${python_executable}-config --extension-suffix) build/gpu_ops.cpp.o build/rms_norm_kernels.cu.o -L/usr/local/cuda/lib64 -lcudadevrt -lcudart_static -lrt -lpthread -ldl
strip build/gpu_ops$(${python_executable}-config --extension-suffix)
Add RMS normalization to JAX as custom call#
gpu_ops is just a Python extension module, and we need more work to plug it into JAX.
Create primitives#
We first create primitives, _rms_norm_fwd_p and _rms_norm_bwd_p, which the custom functions can be mapped to.
We set the multiple_results attribute to True for these operations, which means that the operation produces multiple outputs as a tuple.
When it is set to False, the operation produces a single output without a tuple.
For more details, see How JAX primitives work.
from functools import partial
import jax
import jax.numpy as jnp
import jax._src.test_util as jtu
from build import gpu_ops
from jax import core, dtypes
from jax.interpreters import xla
from jax.lib import xla_client
# Create _rms_norm_fwd_p for forward operation.
_rms_norm_fwd_p = core.Primitive("rms_norm_fwd")
_rms_norm_fwd_p.multiple_results = True
_rms_norm_fwd_p.def_impl(partial(xla.apply_primitive, _rms_norm_fwd_p))
def rms_norm_fwd(x, weight, eps=1e-05):
output, invvar = _rms_norm_fwd_p.bind(x, weight, eps=eps)
return output
# Create _rms_norm_bwd_p for backward operation.
_rms_norm_bwd_p = core.Primitive("rms_norm_bwd")
_rms_norm_bwd_p.multiple_results = True
_rms_norm_bwd_p.def_impl(partial(xla.apply_primitive, _rms_norm_bwd_p))
def rms_norm_bwd(g, invvar, x, weight, eps):
grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(
g, invvar, x, weight, eps=eps
)
return grad_input, grad_weight
Lowering to MLIR custom call#
To map the custom functions to the new primitives, _rms_norm_fwd_p and _rms_norm_bwd_p, we need to:
Register the custom functions as custom call targets with xla_client.register_custom_call_target, and
Register lowering functions that lower the primitives to MLIR custom calls with the registered custom call targets.
The functions _rms_norm_fwd_cuda_lowering and _rms_norm_bwd_cuda_lowering below lower the primitives to MLIR custom call operations with the custom targets from gpu_ops. These functions are registered with jax.interpreters.mlir.register_lowering.
Note that an RMSNormDescriptor object is created in the lowering function, and passed to the custom call as opaque.
from functools import reduce
from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jaxlib.hlo_helpers import custom_call
# Register functions defined in gpu_ops as custom call target for GPUs
for _name, _value in gpu_ops.get_rms_norm_registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="gpu")
def element_type_to_descriptor_type_mapping(element_type):
_element_type_to_descriptor_type_mapping = {
ir.BF16Type.get(): gpu_ops.ElementType.BF16,
ir.F16Type.get(): gpu_ops.ElementType.F16,
ir.F32Type.get(): gpu_ops.ElementType.F32,
ir.F64Type.get(): gpu_ops.ElementType.F64,
}
return _element_type_to_descriptor_type_mapping.get(element_type)
def default_layouts(*shapes):
return [range(len(shape) - 1, -1, -1) for shape in shapes]
def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps):
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
w_type = ir.RankedTensorType(weight.type)
w_shape = w_type.shape
iv_element_type = (
ir.F32Type.get()
if x_type.element_type in [ir.F16Type.get(), ir.BF16Type.get()]
else x_type.element_type
)
n2 = reduce(lambda x, y: x * y, w_shape)
n1 = reduce(lambda x, y: x * y, x_shape) // n2
opaque = gpu_ops.create_rms_norm_descriptor(
n1,
n2,
eps,
element_type_to_descriptor_type_mapping(x_type.element_type),
element_type_to_descriptor_type_mapping(w_type.element_type),
0, # unused
)
out = custom_call(
b"rms_forward_affine_mixed_dtype",
result_types=[
ir.RankedTensorType.get(x_shape, w_type.element_type),
ir.RankedTensorType.get((n1,), iv_element_type),
],
operands=[x, weight],
backend_config=opaque,
operand_layouts=default_layouts(x_shape, w_shape),
result_layouts=default_layouts(x_shape, (n1,)),
).results
return out
mlir.register_lowering(
_rms_norm_fwd_p,
_rms_norm_fwd_cuda_lowering,
platform="gpu",
)
def _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps):
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
w_type = ir.RankedTensorType(weight.type)
w_shape = w_type.shape
iv_type = ir.RankedTensorType(invvar.type)
n2 = reduce(lambda x, y: x * y, w_shape)
n1 = reduce(lambda x, y: x * y, x_shape) // n2
part_grad_shape = ctx.avals_out[-1].shape
opaque = gpu_ops.create_rms_norm_descriptor(
n1,
n2,
eps,
element_type_to_descriptor_type_mapping(x_type.element_type),
element_type_to_descriptor_type_mapping(w_type.element_type),
part_grad_shape[0],
)
out = custom_call(
b"rms_backward_affine",
result_types=[
ir.RankedTensorType.get(x_shape, x_type.element_type),
ir.RankedTensorType.get(w_shape, w_type.element_type),
ir.RankedTensorType.get(part_grad_shape, iv_type.element_type),
],
operands=[grad_output, invvar, x, weight],
backend_config=opaque,
operand_layouts=default_layouts(x_shape, (n1,), x_shape, w_shape),
result_layouts=default_layouts(x_shape, w_shape, part_grad_shape),
).results
return out
mlir.register_lowering(
_rms_norm_bwd_p,
_rms_norm_bwd_cuda_lowering,
platform="gpu",
)
Let’s test it#
per_core_batch_size=4
seq_len=512
emb_dim=512
x = jax.random.normal(
jax.random.PRNGKey(0),
shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim),
dtype=jnp.bfloat16,
)
norm_shape = x.shape[-2:]
weight = jnp.ones(norm_shape, dtype=jnp.bfloat16)
Test forward function#
out = rms_norm_fwd(x, weight)
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
Cell In [5], line 1
----> 1 out = rms_norm_fwd(x, weight)
...
NotImplementedError: Abstract evaluation for 'rms_norm_fwd' not implemented
Abstract evaluation#
The test above failed with NotImplementedError: Abstract evaluation for 'rms_norm_fwd' not implemented. Why did the test fail? What does it mean?
As part of the execution, JAX performs abstract evaluation. As JAX has no knowledge about the new primitives, it doesn’t know how to compute the output shapes and output data types, thus can’t evaluate these operations abstractly.
We need to provide a function for abstract evaluation of each primitive. These abstract evaluation functions compute the shape and the data type of the outputs, but don’t compute actual values for the operations.
These functions are passed to the def_abstract_eval method to be registered with the corresponding primitives.
See How JAX primitives work for more information on abstract evaluation.
from functools import reduce
from operator import mul
from jax.core import ShapedArray
def _rms_norm_fwd_abstract(x, weight, eps):
w_dtype = dtypes.canonicalize_dtype(weight.dtype)
iv_dtype = dtypes.canonicalize_dtype(x.dtype)
if iv_dtype in [jnp.float16, jnp.bfloat16]:
iv_dtype = jnp.float32
n2 = reduce(mul, weight.shape)
n1 = reduce(mul, x.shape) // n2
return (
ShapedArray(x.shape, w_dtype, named_shape=x.named_shape), # output
ShapedArray((n1,), iv_dtype, named_shape=x.named_shape), # invvar
)
_rms_norm_fwd_p.def_abstract_eval(_rms_norm_fwd_abstract)
def _rms_norm_bwd_abstract(grad_output, invvar, x, weight, eps):
iv_dtype = dtypes.canonicalize_dtype(invvar.dtype)
w_dtype = dtypes.canonicalize_dtype(weight.dtype)
x_dtype = dtypes.canonicalize_dtype(x.dtype)
n2 = reduce(lambda x, y: x * y, weight.shape)
n1 = reduce(lambda x, y: x * y, x.shape) // n2
part_grad_shape = (16, n2)
assert dtypes.canonicalize_dtype(grad_output.dtype) == w_dtype
assert grad_output.shape == x.shape
assert invvar.shape == (n1,)
assert iv_dtype == (
jnp.float32 if x_dtype in [jnp.float16, jnp.bfloat16] else x_dtype
)
assert grad_output.named_shape == x.named_shape
weight_named_shape = (
weight.named_shape if weight.named_shape else x.named_shape
)
return (
ShapedArray(
x.shape, x_dtype, named_shape=x.named_shape
), # grad input
ShapedArray(
weight.shape, w_dtype, named_shape=weight_named_shape
), # grad weight
ShapedArray(
part_grad_shape, iv_dtype, named_shape=weight_named_shape
), # part grad
)
_rms_norm_bwd_p.def_abstract_eval(_rms_norm_bwd_abstract)
Let’s test it again#
Test the forward function#
out = rms_norm_fwd(x, weight)
Test the backward function#
Now let's test the backward operation using jax.grad and jtu.check_grads.
def loss(x, weight):
predictions = rms_norm_fwd(x, weight)
return -jnp.mean(predictions**2)
loss_grad = jax.grad(loss)
out = loss_grad(x, weight)
jtu.check_grads(loss, (x, weight), modes=["rev"], order=1)
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
Cell In [8], line 7
3 return -jnp.mean(predictions**2)
6 loss_grad = jax.grad(loss)
----> 7 out = loss_grad(x, weight)
...
NotImplementedError: Differentiation rule for 'rms_norm_fwd' not implemented
Differentiation rule#
The backward operation failed with the error NotImplementedError: Differentiation rule for 'rms_norm_fwd' not implemented. It means that, although we have defined rms_norm_fwd and rms_norm_bwd, JAX doesn't know the relationship between them.
We can teach JAX that rms_norm_bwd is the backward operation for rms_norm_fwd, using jax.custom_vjp and its convention. As the first step, we need to refine the definition of rms_norm_fwd and rms_norm_bwd.
# rms_norm_fwd was previously defined as
#
# def rms_norm_fwd(x, weight, eps=1e-05):
# output, invvar = _rms_norm_fwd_p.bind(x, weight, eps=eps)
# return output
#
def rms_norm_fwd(x, weight, eps=1e-05):
output, invvar = _rms_norm_fwd_p.bind(x, weight, eps=eps)
return output, (invvar, x, weight)
# rms_norm_bwd was previously defined as
#
# def rms_norm_bwd(g, invvar, x, weight, eps):
# grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(
# g, invvar, x, weight, eps=eps
# )
# return grad_input, grad_weight
#
def rms_norm_bwd(eps, res, g):
invvar, x, weight = res
grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(
g, invvar, x, weight, eps=eps
)
return grad_input, grad_weight
rms_norm_fwd now returns an extra output (invvar, x, weight) for the residual data, and rms_norm_bwd takes eps, res, and g as the parameters.
Once the relationship between rms_norm_fwd and rms_norm_bwd is established through jax.custom_vjp, JAX will ensure that the residual data from rms_norm_fwd is passed to rms_norm_bwd as res for the backward operation.
For non-differentiable parameters such as eps, JAX ensures that they are passed to the backward operation before the residual data. That's why eps precedes res in the parameter list of rms_norm_bwd.
Now that rms_norm_fwd returns the residual data, which is not needed for a simple RMS normalization operation, we define a wrapper around it, rms_norm. It simply calls rms_norm_fwd and returns only output. Note that rms_norm is annotated with @partial(jax.custom_vjp, nondiff_argnums=(2,)) and we are passing rms_norm_fwd and rms_norm_bwd to rms_norm.defvjp. This teaches JAX that, when rms_norm is differentiated, rms_norm_fwd is to be used for the forward operation, and rms_norm_bwd is to be used for the backward operation.
See Custom derivative rules for JAX-transformable Python functions for more information on jax.custom_vjp.
@partial(jax.custom_vjp, nondiff_argnums=(2,))
def rms_norm(x, weight, eps=1e-05):
output, _ = rms_norm_fwd(x, weight, eps=eps)
return output
rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd)
With the refinement we have made, the backward operation test works with a modification: loss now calls rms_norm instead of rms_norm_fwd.
def loss(x, weight):
predictions = rms_norm(x, weight)
return -jnp.mean(predictions**2)
loss_grad = jax.grad(loss)
out = loss_grad(x, weight)
jtu.check_grads(loss, (x, weight), modes=["rev"], order=1)
Let’s test it on multiple devices#
We are using jax.experimental.pjit.pjit for parallel execution on multiple devices, and we produce reference values with sequential execution on a single device.
Test the forward function#
Let's first test the forward operation on multiple devices. We are creating a simple 1D mesh and sharding x on all devices.
from jax.sharding import Mesh, PartitionSpec
from jax.experimental.pjit import pjit
mesh = Mesh(jax.local_devices(), ("x",))
ref = rms_norm(x, weight)
pjitted = pjit(
rms_norm,
# Shard x by batch dimension and replicate weight on all devices.
in_shardings=(PartitionSpec("x", None, None), PartitionSpec(None, None)),
# Shard the output by batch dimension.
out_shardings=PartitionSpec("x", None, None),
)
with mesh:
print(pjitted.lower(x, weight).compile().runtime_executable().hlo_modules()[0].to_string())
out = pjitted(x, weight)
jnp.allclose(ref, out, atol=1e-5, rtol=1e-5)
HloModule pjit_rms_norm, entry_computation_layout={(bf16[4,512,512]{2,1,0},bf16[512,512]{1,0})->bf16[4,512,512]{2,1,0}}
%fused_computation (param_1: bf16[32,512,512], param_1.3: u32[]) -> bf16[4,512,512] {
%param_1 = bf16[32,512,512]{2,1,0} parameter(0)
%param_1.3 = u32[] parameter(1)
%convert.2 = s32[] convert(u32[] %param_1.3), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
%constant_9 = s32[] constant(4), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
%multiply.3 = s32[] multiply(s32[] %convert.2, s32[] %constant_9), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
%constant_8 = s32[] constant(0), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
ROOT %dynamic-slice.2 = bf16[4,512,512]{2,1,0} dynamic-slice(bf16[32,512,512]{2,1,0} %param_1, s32[] %multiply.3, s32[] %constant_8, s32[] %constant_8), dynamic_slice_sizes={4,512,512}, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
}
ENTRY %main.7_spmd (param: bf16[4,512,512], param.1: bf16[512,512]) -> bf16[4,512,512] {
%param = bf16[4,512,512]{2,1,0} parameter(0), sharding={devices=[8,1,1]0,1,2,3,4,5,6,7}
%all-gather = bf16[32,512,512]{2,1,0} all-gather(bf16[4,512,512]{2,1,0} %param), channel_id=1, replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={0}, use_global_device_ids=true, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
%param.1 = bf16[512,512]{1,0} parameter(1), sharding={replicated}
%custom-call.0 = (bf16[32,512,512]{2,1,0}, f32[32]{0}) custom-call(bf16[32,512,512]{2,1,0} %all-gather, bf16[512,512]{1,0} %param.1), custom_call_target="rms_forward_affine_mixed_dtype", operand_layout_constraints={bf16[32,512,512]{2,1,0}, bf16[512,512]{1,0}}, api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}, backend_config=" \000\000\000\000\000\004\000\361h\343\210\265\370\344>\000\000\000\000\000\000\000\000\000\000\000\000\255\177\000\000"
%get-tuple-element = bf16[32,512,512]{2,1,0} get-tuple-element((bf16[32,512,512]{2,1,0}, f32[32]{0}) %custom-call.0), index=0, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
%partition-id = u32[] partition-id(), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
ROOT %fusion = bf16[4,512,512]{2,1,0} fusion(bf16[32,512,512]{2,1,0} %get-tuple-element, u32[] %partition-id), kind=kLoop, calls=%fused_computation, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
}
True
The values have been computed correctly for the forward operation; however, the generated HLO modules show an all-gather operation to replicate x on all devices, incurring a large communication overhead.
As XLA does not have enough knowledge about the custom functions to shard input tensors, it decides to replicate them to produce correct values before making the custom call.
To avoid this duplication, we can:
Use custom_partitioning, to make the operation behave like all native JAX operations (but it is more complicated), or
Use manual sharding.
This example demonstrates the use of custom_partitioning.
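For contrast, here is a rough sketch of the manual-sharding alternative (our own illustration under assumptions, not the code used in the rest of this tutorial): wrapping rms_norm in jax.experimental.shard_map.shard_map so that each device normalizes only its local batch shard, with no all-gather of x.
# Rough sketch of the manual-sharding alternative (our own illustration, not
# the tutorial's code): each device runs rms_norm on its local batch shard.
# check_rep is disabled because our custom primitives have no replication
# rules registered.
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.local_devices(), ("x",))
sharded_rms_norm = shard_map(
    rms_norm,
    mesh=mesh,
    in_specs=(P("x", None, None), P(None, None)),  # shard x by batch, replicate weight
    out_specs=P("x", None, None),                  # output stays sharded by batch
    check_rep=False,
)
out = sharded_rms_norm(x, weight)
Because the normalization axes (the last two dimensions) are not sharded, each per-device call sees complete rows, so the per-shard results should match the replicated computation.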
Check for correctness#
with Mesh(jax.local_devices(), ("x",)):
def run_and_verify(loss):
pjitted = pjit(
jax.grad(loss, argnums=(0, 1)),
# Shard x by batch dimension and replicate weight on all devices.
in_shardings=(
PartitionSpec("x", None, None),
PartitionSpec(None, None),
),
# Shard the output by batch dimension and replicate weight grad on all devices.
out_shardings=(
PartitionSpec("x", None, None),
PartitionSpec(None, None),
),
)
hlo = pjitted.lower(x, weight).compile().as_text()
out = pjitted(x, weight)
print(hlo)
assert "all-reduce-done" in hlo, "The gradient will produce wrong value!"
if "all-gather-start" in hlo:
print("NOT OPTIMIZED, ALL_GATHER in the graph!")
return out
custom_p_out = run_and_verify(custom_p_loss)
for r, o in zip(ref_out, custom_p_out):
print(jnp.allclose(r, o, atol=1e-6, rtol=1e-6))
HloModule pjit_custom_p_loss, is_scheduled=true, entry_computation_layout={(f16[4,512,512]{2,1,0}, f16[512,512]{1,0})->(f16[4,512,512]{2,1,0}, f16[512,512]{1,0})}, allow_spmd_sharding_propagation_to_parameters={false,false}, allow_spmd_sharding_propagation_to_output={false,false}, num_partitions=4, frontend_attributes={fingerprint_before_lhs="d7b9bc40de002332dd665ff2ab537b76"}
%fused_multiply (param_0: f16[4,512,512]) -> f16[4,512,512] {
%param_0 = f16[4,512,512]{2,1,0} parameter(0)
%constant_4_1 = f16[] constant(-4.7684e-07)
%broadcast.8.1 = f16[4,512,512]{2,1,0} broadcast(f16[] %constant_4_1), dimensions={}, metadata={op_name="pjit(custom_p_loss)/jit(main)/mul" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=484}
ROOT %multiply.5.1 = f16[4,512,512]{2,1,0} multiply(f16[4,512,512]{2,1,0} %param_0, f16[4,512,512]{2,1,0} %broadcast.8.1), metadata={op_name="pjit(custom_p_loss)/jit(main)/mul" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=484}
}
%region_0.9._custom_call_lowering_rule (Arg_0.10.0: f16[], Arg_1.11.0: f16[]) -> f16[] {
%Arg_1.11.0 = f16[] parameter(1)
%Arg_0.10.0 = f16[] parameter(0)
ROOT %add.2.0 = f16[] add(f16[] %Arg_0.10.0, f16[] %Arg_1.11.0), metadata={op_name="jit(main)/add" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=433}
}
ENTRY %main.23_spmd (param.2: f16[4,512,512], param.1.0: f16[512,512]) -> (f16[4,512,512], f16[512,512]) {
%param.1.0 = f16[512,512]{1,0} parameter(1), sharding={replicated}
%param.2 = f16[4,512,512]{2,1,0} parameter(0), sharding={devices=[4,1,1]<=[4]}
%custom-call.3.0 = (f16[4,512,512]{2,1,0}, f32[4]{0}) custom-call(f16[4,512,512]{2,1,0} %param.2, f16[512,512]{1,0} %param.1.0), custom_call_target="rms_forward_affine_mixed_dtype", operand_layout_constraints={f16[4,512,512]{2,1,0}, f16[512,512]{1,0}}, api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormFwdClass.partition at 0x7ff99e3980d0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormFwdClass.infer_sharding_from_operands at 0x7ff99e398040> decode_shardings=True in_tree=PyTreeDef((*, *)) out_tree=PyTreeDef((*, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=440}, backend_config="\004\000\000\000\000\000\004\000\361h\343\210\265\370\344>\001\000\000\000\001\000\000\000\000\000\000\000$V\000\000"
%get-tuple-element.14 = f16[4,512,512]{2,1,0} get-tuple-element((f16[4,512,512]{2,1,0}, f32[4]{0}) %custom-call.3.0), index=0, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormFwdClass.partition at 0x7ff99e3980d0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormFwdClass.infer_sharding_from_operands at 0x7ff99e398040> decode_shardings=True in_tree=PyTreeDef((*, *)) out_tree=PyTreeDef((*, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=440}
%loop_multiply_fusion = f16[4,512,512]{2,1,0} fusion(f16[4,512,512]{2,1,0} %get-tuple-element.14), kind=kLoop, calls=%fused_multiply, metadata={op_name="pjit(custom_p_loss)/jit(main)/mul" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=484}
%get-tuple-element.1.0 = f32[4]{0} get-tuple-element((f16[4,512,512]{2,1,0}, f32[4]{0}) %custom-call.3.0), index=1, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormFwdClass.partition at 0x7ff99e3980d0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormFwdClass.infer_sharding_from_operands at 0x7ff99e398040> decode_shardings=True in_tree=PyTreeDef((*, *)) out_tree=PyTreeDef((*, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=440}
%custom-call.5.0 = (f16[4,512,512]{2,1,0}, f16[512,512]{1,0}, f32[16,262144]{1,0}) custom-call(f16[4,512,512]{2,1,0} %loop_multiply_fusion, f32[4]{0} %get-tuple-element.1.0, f16[4,512,512]{2,1,0} %param.2, f16[512,512]{1,0} %param.1.0), custom_call_target="rms_backward_affine", operand_layout_constraints={f16[4,512,512]{2,1,0}, f32[4]{0}, f16[4,512,512]{2,1,0}, f16[512,512]{1,0}}, api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}, backend_config="\004\000\000\000\000\000\004\000\361h\343\210\265\370\344>\001\000\000\000\001\000\000\000\020\000\000\000$V\000\000"
%get-tuple-element.7.0 = f16[512,512]{1,0} get-tuple-element((f16[4,512,512]{2,1,0}, f16[512,512]{1,0}, f32[16,262144]{1,0}) %custom-call.5.0), index=1, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}
%all-reduce-start = f16[512,512]{1,0} all-reduce-start(f16[512,512]{1,0} %get-tuple-element.7.0), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%region_0.9._custom_call_lowering_rule, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"collective_backend_config":{"is_sync":true,"no_parallel_custom_call":false}}
%all-reduce-done = f16[512,512]{1,0} all-reduce-done(f16[512,512]{1,0} %all-reduce-start), metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}
%get-tuple-element.12.0 = f16[4,512,512]{2,1,0} get-tuple-element((f16[4,512,512]{2,1,0}, f16[512,512]{1,0}, f32[16,262144]{1,0}) %custom-call.5.0), index=0, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}
ROOT %tuple.1.0 = (f16[4,512,512]{2,1,0}, f16[512,512]{1,0}) tuple(f16[4,512,512]{2,1,0} %get-tuple-element.12.0, f16[512,512]{1,0} %all-reduce-done)
}
True
True
Now there are no all-gathers in the HLO: sharding is respected, and only the gradients are accumulated via an all-reduce.
Let’s put it together#
The complete definition of the primitives using custom_partitioning can be found in Custom_Operation_for_GPUs.py, and the corresponding C++ code, which defines the Python bindings in addition to the kernel implementations, can be found below:
gpu_ops
code listing#
gpu_ops/kernel_helpers.h
gpu_ops/kernels.h
gpu_ops/pybind11_kernel_helpers.h
gpu_ops/gpu_ops.cpp
gpu_ops/rms_norm_kernels.cu
Generalized Convolutions in JAX#
JAX provides a number of interfaces to compute convolutions across data, including jax.numpy.convolve(), jax.scipy.signal.convolve(), and jax.lax.conv_general_dilated().
For basic convolution operations, the jax.numpy and jax.scipy operations are usually sufficient. If you want to do more general batched multi-dimensional convolution, the jax.lax.conv_general_dilated() function is where you should start.
Basic One-dimensional Convolution#
Basic one-dimensional convolution is implemented by jax.numpy.convolve(), which provides a JAX interface for numpy.convolve(). Here is a simple example of 1D smoothing implemented via a convolution:
import matplotlib.pyplot as plt
from jax import random
import jax.numpy as jnp
import numpy as np
key = random.key(1701)
x = jnp.linspace(0, 10, 500)
y = jnp.sin(x) + 0.2 * random.normal(key, shape=(500,))
window = jnp.ones(10) / 10
y_smooth = jnp.convolve(y, window, mode='same')
plt.plot(x, y, 'lightgray')
plt.plot(x, y_smooth, 'black');

The mode parameter controls how boundary conditions are treated; here we use mode='same' to ensure that the output is the same size as the input.
For more information, see the jax.numpy.convolve() documentation, or the documentation associated with the original numpy.convolve() function.
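The three boundary modes differ only in how much of the sliding overlap is kept, which shows up directly in the output lengths (a small illustrative check):
# Small check of the boundary modes: for an input of length n and a window of
# length m, 'full' yields n + m - 1 samples, 'same' yields n, and 'valid'
# yields n - m + 1 (only positions of complete overlap).
print(jnp.convolve(y, window, mode='full').shape)   # (509,)
print(jnp.convolve(y, window, mode='same').shape)   # (500,)
print(jnp.convolve(y, window, mode='valid').shape)  # (491,)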
Basic N-dimensional Convolution#
For N-dimensional convolution, jax.scipy.signal.convolve() provides a similar interface to that of jax.numpy.convolve(), generalized to N dimensions.
For example, here is a simple approach to de-noising an image based on convolution with a Gaussian filter:
from scipy import misc
import jax.scipy as jsp
fig, ax = plt.subplots(1, 3, figsize=(12, 5))
# Load a sample image; compute mean() to convert from RGB to grayscale.
image = jnp.array(misc.face().mean(-1))
ax[0].imshow(image, cmap='binary_r')
ax[0].set_title('original')
# Create a noisy version by adding random Gaussian noise
key = random.key(1701)
noisy_image = image + 50 * random.normal(key, image.shape)
ax[1].imshow(noisy_image, cmap='binary_r')
ax[1].set_title('noisy')
# Smooth the noisy image with a 2D Gaussian smoothing kernel.
x = jnp.linspace(-3, 3, 7)
window = jsp.stats.norm.pdf(x) * jsp.stats.norm.pdf(x[:, None])
smooth_image = jsp.signal.convolve(noisy_image, window, mode='same')
ax[2].imshow(smooth_image, cmap='binary_r')
ax[2].set_title('smoothed');
/tmp/ipykernel_5100/4118182506.py:7: DeprecationWarning: scipy.misc.face has been deprecated in SciPy v1.10.0; and will be completely removed in SciPy v1.12.0. Dataset methods have moved into the scipy.datasets module. Use scipy.datasets.face instead.
image = jnp.array(misc.face().mean(-1))

Like in the one-dimensional case, we use mode='same' to specify how we would like edges to be handled. For more information on available options in N-dimensional convolutions, see the jax.scipy.signal.convolve() documentation.
General Convolutions#
For the more general types of batched convolutions often useful in the context of building deep neural networks, JAX and XLA offer the very general N-dimensional conv_general_dilated function, but it’s not very obvious how to use it. We’ll give some examples of the common use-cases.
A survey of the family of convolutional operators, a guide to convolutional arithmetic, is highly recommended reading!
Let’s define a simple diagonal edge kernel:
# 2D kernel - HWIO layout
kernel = jnp.zeros((3, 3, 3, 3), dtype=jnp.float32)
kernel += jnp.array([[1, 1, 0],
[1, 0,-1],
[0,-1,-1]])[:, :, jnp.newaxis, jnp.newaxis]
print("Edge Conv kernel:")
plt.imshow(kernel[:, :, 0, 0]);
Edge Conv kernel:

And we’ll make a simple synthetic image:
# NHWC layout
img = jnp.zeros((1, 200, 198, 3), dtype=jnp.float32)
for k in range(3):
    x = 30 + 60*k
    y = 20 + 60*k
    img = img.at[0, x:x+10, y:y+10, k].set(1.0)
print("Original Image:")
plt.imshow(img[0]);
Original Image:

lax.conv and lax.conv_with_general_padding#
These are the simple convenience functions for convolutions.
⚠️ The convenience lax.conv and lax.conv_with_general_padding helper functions assume NCHW images and OIHW kernels.
from jax import lax
out = lax.conv(jnp.transpose(img,[0,3,1,2]), # lhs = NCHW image tensor
jnp.transpose(kernel,[3,2,0,1]), # rhs = OIHW conv kernel tensor
(1, 1), # window strides
'SAME') # padding mode
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(np.array(out)[0,0,:,:]);
out shape: (1, 3, 200, 198)
First output channel:

out = lax.conv_with_general_padding(
jnp.transpose(img,[0,3,1,2]), # lhs = NCHW image tensor
jnp.transpose(kernel,[2,3,0,1]), # rhs = IOHW conv kernel tensor
(1, 1), # window strides
((2,2),(2,2)), # general padding 2x2
(1,1), # lhs/image dilation
(1,1)) # rhs/kernel dilation
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(np.array(out)[0,0,:,:]);
out shape: (1, 3, 202, 200)
First output channel:

Dimension Numbers define dimensional layout for conv_general_dilated#
The important argument is the 3-tuple of axis layout arguments: (Input Layout, Kernel Layout, Output Layout)
N - batch dimension
H - spatial height
W - spatial width
C - channel dimension
I - kernel input channel dimension
O - kernel output channel dimension
⚠️ To demonstrate the flexibility of dimension numbers, we choose an NHWC image and HWIO kernel convention for lax.conv_general_dilated below.
dn = lax.conv_dimension_numbers(img.shape, # only ndim matters, not shape
kernel.shape, # only ndim matters, not shape
('NHWC', 'HWIO', 'NHWC')) # the important bit
print(dn)
ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))
SAME padding, no stride, no dilation#
out = lax.conv_general_dilated(img, # lhs = image tensor
kernel, # rhs = conv kernel tensor
(1,1), # window strides
'SAME', # padding mode
(1,1), # lhs/image dilation
(1,1), # rhs/kernel dilation
dn) # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(np.array(out)[0,:,:,0]);
out shape: (1, 200, 198, 3)
First output channel:

VALID padding, no stride, no dilation#
out = lax.conv_general_dilated(img, # lhs = image tensor
kernel, # rhs = conv kernel tensor
(1,1), # window strides
'VALID', # padding mode
(1,1), # lhs/image dilation
(1,1), # rhs/kernel dilation
dn) # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, "DIFFERENT from above!")
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(np.array(out)[0,:,:,0]);
out shape: (1, 198, 196, 3) DIFFERENT from above!
First output channel:

SAME padding, 2,2 stride, no dilation#
out = lax.conv_general_dilated(img, # lhs = image tensor
kernel, # rhs = conv kernel tensor
(2,2), # window strides
'SAME', # padding mode
(1,1), # lhs/image dilation
(1,1), # rhs/kernel dilation
dn) # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, " <-- half the size of above")
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(np.array(out)[0,:,:,0]);
out shape: (1, 100, 99, 3) <-- half the size of above
First output channel:

VALID padding, no stride, rhs kernel dilation ~ Atrous convolution (excessive to illustrate)#
out = lax.conv_general_dilated(img, # lhs = image tensor
kernel, # rhs = conv kernel tensor
(1,1), # window strides
'VALID', # padding mode
(1,1), # lhs/image dilation
(12,12), # rhs/kernel dilation
dn) # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape)
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(np.array(out)[0,:,:,0]);
out shape: (1, 176, 174, 3)
First output channel:

VALID padding, no stride, lhs=input dilation ~ Transposed Convolution#
out = lax.conv_general_dilated(img, # lhs = image tensor
kernel, # rhs = conv kernel tensor
(1,1), # window strides
((0, 0), (0, 0)), # padding mode
(2,2), # lhs/image dilation
(1,1), # rhs/kernel dilation
dn) # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, "<-- larger than original!")
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(np.array(out)[0,:,:,0]);
out shape: (1, 397, 393, 3) <-- larger than original!
First output channel:

We can use the last of these to, for instance, implement transposed convolutions:
# The following is equivalent to tensorflow:
# N,H,W,C = img.shape
# out = tf.nn.conv2d_transpose(img, kernel, (N,2*H,2*W,C), (1,2,2,1))
# transposed conv = 180deg kernel rotation plus LHS dilation
# rotate kernel 180deg:
kernel_rot = jnp.rot90(jnp.rot90(kernel, axes=(0,1)), axes=(0,1))
# need a custom output padding:
padding = ((2, 1), (2, 1))
out = lax.conv_general_dilated(img, # lhs = image tensor
kernel_rot, # rhs = conv kernel tensor
(1,1), # window strides
padding, # padding mode
(2,2), # lhs/image dilation
(1,1), # rhs/kernel dilation
dn) # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, "<-- transposed_conv")
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(np.array(out)[0,:,:,0]);
out shape: (1, 400, 396, 3) <-- transposed_conv
First output channel:

1D Convolutions#
You aren't limited to 2D convolutions; a simple 1D demo is below:
# 1D kernel - WIO layout
kernel = jnp.array([[[1, 0, -1], [-1, 0, 1]],
[[1, 1, 1], [-1, -1, -1]]],
dtype=jnp.float32).transpose([2,1,0])
# 1D data - NWC layout
data = np.zeros((1, 200, 2), dtype=jnp.float32)
for i in range(2):
    for k in range(2):
        x = 35*i + 30 + 60*k
        data[0, x:x+30, k] = 1.0
print("in shapes:", data.shape, kernel.shape)
plt.figure(figsize=(10,5))
plt.plot(data[0]);
dn = lax.conv_dimension_numbers(data.shape, kernel.shape,
('NWC', 'WIO', 'NWC'))
print(dn)
out = lax.conv_general_dilated(data, # lhs = image tensor
kernel, # rhs = conv kernel tensor
(1,), # window strides
'SAME', # padding mode
(1,), # lhs/image dilation
(1,), # rhs/kernel dilation
dn) # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape)
plt.figure(figsize=(10,5))
plt.plot(out[0]);
in shapes: (1, 200, 2) (3, 2, 2)
ConvDimensionNumbers(lhs_spec=(0, 2, 1), rhs_spec=(2, 1, 0), out_spec=(0, 2, 1))
out shape: (1, 200, 2)


3D Convolutions#
import matplotlib as mpl
# Random 3D kernel - HWDIO layout
kernel = jnp.array([
[[0, 0, 0], [0, 1, 0], [0, 0, 0]],
[[0, -1, 0], [-1, 0, -1], [0, -1, 0]],
[[0, 0, 0], [0, 1, 0], [0, 0, 0]]],
dtype=jnp.float32)[:, :, :, jnp.newaxis, jnp.newaxis]
# 3D data - NHWDC layout
data = jnp.zeros((1, 30, 30, 30, 1), dtype=jnp.float32)
x, y, z = np.mgrid[0:1:30j, 0:1:30j, 0:1:30j]
data += (jnp.sin(2*x*jnp.pi)*jnp.cos(2*y*jnp.pi)*jnp.cos(2*z*jnp.pi))[None,:,:,:,None]
print("in shapes:", data.shape, kernel.shape)
dn = lax.conv_dimension_numbers(data.shape, kernel.shape,
('NHWDC', 'HWDIO', 'NHWDC'))
print(dn)
out = lax.conv_general_dilated(data, # lhs = image tensor
kernel, # rhs = conv kernel tensor
(1,1,1), # window strides
'SAME', # padding mode
(1,1,1), # lhs/image dilation
(1,1,1), # rhs/kernel dilation
dn) # dimension_numbers
print("out shape: ", out.shape)
# Make some simple 3d density plots:
from mpl_toolkits.mplot3d import Axes3D
def make_alpha(cmap):
    my_cmap = cmap(jnp.arange(cmap.N))
    my_cmap[:,-1] = jnp.linspace(0, 1, cmap.N)**3
    return mpl.colors.ListedColormap(my_cmap)
my_cmap = make_alpha(plt.cm.viridis)
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(x.ravel(), y.ravel(), z.ravel(), c=data.ravel(), cmap=my_cmap)
ax.axis('off')
ax.set_title('input')
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(x.ravel(), y.ravel(), z.ravel(), c=out.ravel(), cmap=my_cmap)
ax.axis('off')
ax.set_title('3D conv output');
in shapes: (1, 30, 30, 30, 1) (3, 3, 3, 1, 1)
ConvDimensionNumbers(lhs_spec=(0, 4, 1, 2, 3), rhs_spec=(4, 3, 0, 1, 2), out_spec=(0, 4, 1, 2, 3))
out shape: (1, 30, 30, 30, 1)


Developer Documentation#
JAX welcomes contributions from the community. See below for various install guides to get set up as a developer, as well as developer-focused resources such as JAX Enhancement Proposals.
Contributing to JAX#
Everyone can contribute to JAX, and we value everyone’s contributions. There are several ways to contribute, including:
Answering questions on JAX’s discussions page
Improving or expanding JAX’s documentation
Contributing to JAX’s code-base
Contributing in any of the above ways to the broader ecosystem of libraries built on JAX
The JAX project follows Google’s Open Source Community Guidelines.
Ways to contribute#
We welcome pull requests, in particular for those issues marked with contributions welcome or good first issue.
For other proposals, we ask that you first open a GitHub Issue or Discussion to seek feedback on your planned contribution.
Contributing code using pull requests#
We do all of our development using git, so basic knowledge is assumed.
Follow these steps to contribute code:
Sign the Google Contributor License Agreement (CLA). For more information, see the Pull Request Checklist below.
Fork the JAX repository by clicking the Fork button on the repository page. This creates a copy of the JAX repository in your own account.
Install Python >= 3.9 locally in order to run tests.
pip installing your fork from source. This allows you to modify the code and immediately test it out:
git clone https://github.com/YOUR_USERNAME/jax
cd jax
pip install -r build/test-requirements.txt  # Installs all testing requirements.
pip install -e ".[cpu]"  # Installs JAX from the current directory in editable mode.
Add the JAX repo as an upstream remote, so you can use it to sync your changes.
git remote add upstream https://www.github.com/google/jax
Create a branch where you will develop from:
git checkout -b name-of-change
And implement your changes using your favorite editor (we recommend Visual Studio Code).
Make sure your code passes JAX’s lint and type checks, by running the following from the top of the repository:
pip install pre-commit
pre-commit run --all
See Linting and Type-checking for more details.
Make sure the tests pass by running the following command from the top of the repository:
pytest -n auto tests/
JAX’s test suite is quite large, so if you know the specific test file that covers your changes, you can limit the tests to that; for example:
pytest -n auto tests/lax_scipy_test.py
You can narrow the tests further by using the pytest -k flag to match particular test names:
pytest -n auto tests/lax_scipy_test.py -k testLogSumExp
JAX also offers more fine-grained control over which particular tests are run; see Running the tests for more information.
Once you are satisfied with your change, create a commit as follows (how to write a commit message):
git add file1.py file2.py ...
git commit -m "Your commit message"
Then sync your code with the main repo:
git fetch upstream
git rebase upstream/main
Finally, push your commit on your development branch and create a remote branch in your fork that you can use to create a pull request from:
git push --set-upstream origin name-of-change
Please ensure your contribution is a single commit (see Single-change commits and pull requests)
Create a pull request from the JAX repository and send it for review. Check the JAX pull request checklist for considerations when preparing your PR, and consult GitHub Help if you need more information on using pull requests.
JAX pull request checklist#
As you prepare a JAX pull request, here are a few things to keep in mind:
Google contributor license agreement#
Contributions to this project must be accompanied by a Google Contributor License Agreement (CLA). You (or your employer) retain the copyright to your contribution; this simply gives us permission to use and redistribute your contributions as part of the project. Head over to https://cla.developers.google.com/ to see your current agreements on file or to sign a new one.
You generally only need to submit a CLA once, so if you’ve already submitted one (even if it was for a different project), you probably don’t need to do it again. If you’re not certain whether you’ve signed a CLA, you can open your PR and our friendly CI bot will check for you.
Single-change commits and pull requests#
A git commit ought to be a self-contained, single change with a descriptive message. This helps with review and with identifying or reverting changes if issues are uncovered later on.
Pull requests typically comprise a single git commit. (In some cases, for
instance for large refactors or internal rewrites, they may contain several.)
In preparing a pull request for review, you may need to squash together
multiple commits. We ask that you do this prior to sending the PR for review if
possible. The git rebase -i command might be useful to this end.
Linting and Type-checking#
JAX uses mypy and ruff to statically test code quality; the easiest way to run these checks locally is via the pre-commit framework:
pip install pre-commit
pre-commit run --all
If your pull request touches documentation notebooks, this will also run some checks on those (See Update notebooks for more details).
Full GitHub test suite#
Your PR will automatically be run through a full test suite on GitHub CI, which covers a range of Python versions, dependency versions, and configuration options. It’s normal for these tests to turn up failures that you didn’t catch locally; to fix the issues you can push new commits to your branch.
Restricted test suite#
Once your PR has been reviewed, a JAX maintainer will mark it as Pull Ready. This
will trigger a larger set of tests, including tests on GPU and TPU backends that are
not available via standard GitHub CI. Detailed results of these tests are not publicly
viewable, but the JAX maintainer assigned to your PR will communicate with you regarding
any failures these might uncover; it’s not uncommon, for example, that numerical tests
need different tolerances on TPU than on CPU.
Building from source#
First, obtain the JAX source code:
git clone https://github.com/google/jax
cd jax
Building JAX involves two steps:
Building or installing jaxlib, the C++ support library for jax.
Installing the jax Python package.
Building or installing jaxlib#
Installing jaxlib with pip#
If you're only modifying Python portions of JAX, we recommend installing jaxlib from a prebuilt wheel using pip:
pip install jaxlib
See the JAX readme for full guidance on pip installation (e.g., for GPU and TPU support).
Building jaxlib
from source#
To build jaxlib
from source, you must also install some prerequisites:
a C++ compiler (g++, clang, or MSVC)
On Ubuntu or Debian you can install the necessary prerequisites with:
sudo apt install g++ python python3-dev
If you are building on a Mac, make sure XCode and the XCode command line tools are installed.
See below for Windows build instructions.
Python packages:
numpy
,wheel
,build
.
You can install the necessary Python dependencies using pip
:
pip install numpy wheel build
To build jaxlib
for CPU or TPU, you can run:
python build/build.py
pip install dist/*.whl # installs jaxlib (includes XLA)
There are two ways to build jaxlib
with CUDA support: (1) use
python build/build.py --enable_cuda
to generate a jaxlib wheel with cuda
support, or (2) use
python build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12
to generate three wheels (jaxlib without cuda, jax-cuda-plugin,
and jax-cuda-pjrt). You can set gpu_plugin_cuda_version
to 11 or 12.
See python build/build.py --help
for configuration options, including ways to
specify the paths to CUDA and CUDNN, which you must have installed. Here
python
should be the name of your Python 3 interpreter; on some systems, you
may need to use python3
instead. By default, the wheel is written to the
dist/
subdirectory of the current directory.
Building jaxlib from source with a modified XLA repository.#
JAX depends on XLA, whose source code is in the XLA GitHub repository. By default JAX uses a pinned copy of the XLA repository, but we often want to use a locally-modified copy of XLA when working on JAX. There are two ways to do this:
use Bazel’s
override_repository
feature, which you can pass as a command line flag tobuild.py
as follows:python build/build.py --bazel_options=--override_repository=xla=/path/to/xla
modify the
WORKSPACE
file in the root of the JAX source tree to point to a different XLA tree.
To contribute changes back to XLA, send PRs to the XLA repository.
The version of XLA pinned by JAX is regularly updated, but is updated in
particular before each jaxlib
release.
Additional Notes for Building jaxlib
from source on Windows#
On Windows, follow Install Visual Studio to set up a C++ toolchain. Visual Studio 2019 version 16.5 or newer is required. If you need to build with CUDA enabled, follow the CUDA Installation Guide to set up a CUDA environment.
JAX builds use symbolic links, which require that you activate Developer Mode.
You can either install Python using its Windows installer, or if you prefer, you can use Anaconda or Miniconda to set up a Python environment.
Some targets of Bazel use bash utilities to do scripting, so MSYS2 is needed. See Installing Bazel on Windows for more details. Install the following packages:
pacman -S patch coreutils
Once coreutils is installed, the realpath command should be present in your shell’s path.
Once everything is installed, open PowerShell and make sure MSYS2 is on the
path of the current session. Ensure bazel
, patch
and realpath
are
accessible. Activate the conda environment. The following command builds with
CUDA enabled; adjust it to whatever is suitable for you:
python .\build\build.py `
--enable_cuda `
--cuda_path='C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1' `
--cudnn_path='C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1' `
--cuda_version='10.1' `
--cudnn_version='7.6.5'
To build with debug information, add the flag --bazel_options='--copt=/Z7'
.
Additional notes for building a ROCM jaxlib
for AMD GPUs#
You need several ROCM/HIP libraries installed to build for ROCM. For
example, on an Ubuntu machine with
AMD’s apt
repositories available,
you need a number of packages installed:
sudo apt install miopen-hip hipfft-dev rocrand-dev hipsparse-dev hipsolver-dev \
rccl-dev rccl hip-dev rocfft-dev roctracer-dev hipblas-dev rocm-device-libs
To build jaxlib with ROCM support, you can run the following build command, suitably adjusted for your paths and ROCM version.
python build/build.py --enable_rocm --rocm_path=/opt/rocm-5.7.0
AMD’s fork of the XLA repository may include fixes not present in the upstream XLA repository. If you experience problems with the upstream repository, you can try AMD’s fork, by cloning their repository:
git clone https://github.com/ROCmSoftwarePlatform/xla.git
and override the XLA repository with which JAX is built:
python build/build.py --enable_rocm --rocm_path=/opt/rocm-5.7.0 \
--bazel_options=--override_repository=xla=/path/to/xla-rocm
Installing jax
#
Once jaxlib
has been installed, you can install jax
by running:
pip install -e . # installs jax
To upgrade to the latest version from GitHub, just run git pull
from the JAX
repository root, and rebuild by running build.py
or upgrading jaxlib
if
necessary. You shouldn’t have to reinstall jax
because pip install -e
sets up symbolic links from site-packages into the repository.
Running the tests#
First, install the dependencies by running pip install -r build/test-requirements.txt
.
There are two supported mechanisms for running the JAX tests, either using Bazel or using pytest.
Using Bazel#
First, configure the JAX build by running:
python build/build.py --configure_only
You may pass additional options to build.py
to configure the build; see the
jaxlib
build documentation for details.
By default the Bazel build runs the JAX tests using jaxlib
built from source.
To run JAX tests, run:
bazel test //tests:cpu_tests //tests:backend_independent_tests
//tests:gpu_tests
and //tests:tpu_tests
are also available, if you have the necessary hardware.
To use a preinstalled jaxlib
instead of building jaxlib
from source, run
bazel test --//jax:build_jaxlib=false //tests:cpu_tests //tests:backend_independent_tests
A number of test behaviors can be controlled using environment variables (see
below). Environment variables may be passed to JAX tests using the
--test_env=FLAG=value
flag to Bazel.
Some of the JAX tests are for multiple accelerators (i.e., GPUs or TPUs). When JAX is already installed, you can run the GPU tests like this:
bazel test //tests:gpu_tests --local_test_jobs=4 --test_tag_filters=multiaccelerator --//jax:build_jaxlib=false --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform
You can speed up single accelerator tests by running them in parallel on multiple accelerators. This also triggers multiple concurrent tests per accelerator. For GPUs, you can do it like this:
NB_GPUS=2
JOBS_PER_ACC=4
J=$((NB_GPUS * JOBS_PER_ACC))
MULTI_GPU="--run_under $PWD/build/parallel_accelerator_execute.sh --test_env=JAX_ACCELERATOR_COUNT=${NB_GPUS} --test_env=JAX_TESTS_PER_ACCELERATOR=${JOBS_PER_ACC} --local_test_jobs=$J"
bazel test //tests:gpu_tests //tests:backend_independent_tests --test_env=XLA_PYTHON_CLIENT_PREALLOCATE=false --test_tag_filters=-multiaccelerator $MULTI_GPU
Some test targets, like //tests:logpcg_tests
,
optionally use matplotlib, so you may need to pip install matplotlib
to run tests via bazel.
Using pytest
#
To run all the JAX tests using pytest
, we recommend using pytest-xdist
,
which can run tests in parallel. It is installed as part of the
pip install -r build/test-requirements.txt
command.
From the repository root directory run:
pytest -n auto tests
Controlling test behavior#
JAX generates test cases combinatorially, and you can control the number of
cases that are generated and checked for each test (default is 10) using the
JAX_NUM_GENERATED_CASES
environment variable. The automated tests
currently use 25 by default.
For example, one might write
# Bazel
bazel test //tests/... --test_env=JAX_NUM_GENERATED_CASES=25
or
# pytest
JAX_NUM_GENERATED_CASES=25 pytest -n auto tests
The automated tests also run the tests with default 64-bit floats and ints
(JAX_ENABLE_X64
):
JAX_ENABLE_X64=1 JAX_NUM_GENERATED_CASES=25 pytest -n auto tests
You can run a more specific set of tests using pytest’s built-in selection mechanisms, or alternatively you can run a specific test file directly to see more detailed information about the cases being run:
JAX_NUM_GENERATED_CASES=5 python tests/lax_numpy_test.py
You can skip a few tests known to be slow by setting the environment variable JAX_SKIP_SLOW_TESTS=1.
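For example:
JAX_SKIP_SLOW_TESTS=1 pytest -n auto tests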
To specify a particular set of tests to run from a test file, you can pass a string
or regular expression via the --test_targets
flag. For example, you can run all
the tests of jax.numpy.pad
using:
python tests/lax_numpy_test.py --test_targets="testPad"
The Colab notebooks are tested for errors as part of the documentation build.
Doctests#
JAX uses pytest in doctest mode to test the code examples within the documentation. You can run this using
pytest docs
Additionally, JAX runs pytest in doctest-modules
mode to ensure code examples in
function docstrings will run correctly. You can run this locally using, for example:
pytest --doctest-modules jax/_src/numpy/lax_numpy.py
Keep in mind that there are several files that are marked to be skipped when the
doctest command is run on the full package; you can see the details in
ci-build.yaml.
Type checking#
We use mypy
to check the type hints. To check types locally the same way
as the CI checks them:
pip install mypy
mypy --config=pyproject.toml --show-error-codes jax
Alternatively, you can use the pre-commit framework to run this on all staged files in your git repository, automatically using the same mypy version as in the GitHub CI:
pre-commit run mypy
Linting#
JAX uses the ruff linter to ensure code quality. You can check your local changes by running:
pip install ruff
ruff jax
Alternatively, you can use the pre-commit framework to run this on all staged files in your git repository, automatically using the same ruff version as the GitHub tests:
pre-commit run ruff
Update documentation#
To rebuild the documentation, install several packages:
pip install -r docs/requirements.txt
And then run:
sphinx-build -b html docs docs/build/html -j auto
This can take a long time because it executes many of the notebooks in the documentation source; if you’d prefer to build the docs without executing the notebooks, you can run:
sphinx-build -b html -D nb_execution_mode=off docs docs/build/html -j auto
You can then see the generated documentation in docs/build/html/index.html
.
The -j auto
option controls the parallelism of the build. You can use a number
in place of auto
to control how many CPU cores to use.
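For example, to limit the build to eight cores:
sphinx-build -b html docs docs/build/html -j 8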
Update notebooks#
We use jupytext to maintain two synced copies of the notebooks
in docs/notebooks
: one in ipynb
format, and one in md
format. The advantage of the former
is that it can be opened and executed directly in Colab; the advantage of the latter is that
it makes it much easier to track diffs within version control.
Editing ipynb
#
For making large changes that substantially modify code and outputs, it is easiest to
edit the notebooks in Jupyter or in Colab. To edit notebooks in the Colab interface,
open http://colab.research.google.com and Upload
from your local repo.
Update it as needed, Run all cells
then Download ipynb
.
You may want to test that it executes properly, using sphinx-build
as explained above.
Editing md
#
For making smaller changes to the text content of the notebooks, it is easiest to edit the
.md
versions using a text editor.
Syncing notebooks#
After editing either the ipynb or md versions of the notebooks, you can sync the two versions
using jupytext by running jupytext --sync
on the updated
notebooks; for example:
pip install jupytext==1.16.0
jupytext --sync docs/notebooks/thinking_in_jax.ipynb
The jupytext version should match that specified in .pre-commit-config.yaml.
To check that the markdown and ipynb files are properly synced, you may use the pre-commit framework to perform the same check used by the github CI:
git add docs -u # pre-commit runs on files in git staging.
pre-commit run jupytext
Creating new notebooks#
If you are adding a new notebook to the documentation and would like to use the jupytext --sync
command discussed here, you can set up your notebook for jupytext by using the following command:
jupytext --set-formats ipynb,md:myst path/to/the/notebook.ipynb
This works by adding a "jupytext"
metadata field to the notebook file which specifies the
desired formats, and which the jupytext --sync
command recognizes when invoked.
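For reference, the resulting entry in the notebook's metadata looks roughly like this (the exact contents may vary with your jupytext version):
"jupytext": {
  "formats": "ipynb,md:myst"
}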
Notebooks within the Sphinx build#
Some of the notebooks are built automatically as part of the pre-submit checks and
as part of the Read the docs build.
The build will fail if cells raise errors. If the errors are intentional, you can either catch them,
or tag the cell with raises-exceptions
metadata (example PR).
You have to add this metadata by hand in the .ipynb
file. It will be preserved when somebody else
re-saves the notebook.
We exclude some notebooks from the build, e.g., because they contain long computations.
See exclude_patterns
in conf.py.
Documentation building on readthedocs.io
#
JAX’s auto-generated documentation is at https://jax.readthedocs.io/.
The documentation building is controlled for the entire project by the
readthedocs JAX settings. The current settings
trigger a documentation build as soon as code is pushed to the GitHub main
branch.
For each code version, the building process is driven by the
.readthedocs.yml
and the docs/conf.py
configuration files.
For each automated documentation build you can see the documentation build logs.
If you want to test the documentation generation on Readthedocs, you can push code to the test-docs
branch. That branch is also built automatically, and you can
see the generated documentation here. If the documentation build
fails you may want to wipe the build environment for test-docs.
For a local test, I was able to do it in a fresh directory by replaying the commands I saw in the Readthedocs logs:
mkvirtualenv jax-docs # A new virtualenv
mkdir jax-docs # A new directory
cd jax-docs
git clone --no-single-branch --depth 50 https://github.com/google/jax
cd jax
git checkout --force origin/test-docs
git clean -d -f -f
workon jax-docs
python -m pip install --upgrade --no-cache-dir pip
python -m pip install --upgrade --no-cache-dir -I Pygments==2.3.1 setuptools==41.0.1 docutils==0.14 mock==1.0.1 pillow==5.4.1 alabaster>=0.7,<0.8,!=0.7.5 commonmark==0.8.1 recommonmark==0.5.0 'sphinx<2' 'sphinx-rtd-theme<0.5' 'readthedocs-sphinx-ext<1.1'
python -m pip install --exists-action=w --no-cache-dir -r docs/requirements.txt
cd docs
python `which sphinx-build` -T -E -b html -d _build/doctrees-readthedocs -D language=en . _build/html
Internal APIs#
core#
Autodidax: JAX core from scratch#
Ever want to learn how JAX works, but the implementation seemed impenetrable? Well, you’re in luck! By reading this tutorial, you’ll learn every big idea in JAX’s core system. You’ll even get clued into our weird jargon!
This is a work-in-progress draft. There are some important ingredients missing, still to come in parts 5 and 6 (and more?). There are also some simplifications here that we haven’t yet applied to the main system, but we will.
Part 1: Transformations as interpreters: standard evaluation, jvp
, and vmap
#
We want to transform functions that look like this:
def f(x):
y = sin(x) * 2.
z = - y + x
return z
Think of functions like sin
and the arithmetic operations underlying the
infix operators (mul
, add
, and neg
) as primitive operations, meaning
atomic units of processing rather than compositions.
“Transform” means “interpret differently.” Instead of standard interpretation where we apply primitive operations to numerical inputs to produce numerical outputs, we want to override primitive application and let different values flow through our program. For example, we might want to replace the application of every primitive with an application of its JVP rule, and let primal-tangent pairs flow through our program. Moreover, we want to be able to compose multiple transformations, leading to stacks of interpreters.
JAX core machinery#
We can implement stacks of interpreters and even have them all discharge on the fly as we execute the Python function to be transformed. To start, let’s define these primitives so that we can intercept their application:
from typing import NamedTuple
class Primitive(NamedTuple):
name: str
add_p = Primitive('add')
mul_p = Primitive('mul')
neg_p = Primitive("neg")
sin_p = Primitive("sin")
cos_p = Primitive("cos")
reduce_sum_p = Primitive("reduce_sum")
greater_p = Primitive("greater")
less_p = Primitive("less")
transpose_p = Primitive("transpose")
broadcast_p = Primitive("broadcast")
def add(x, y): return bind1(add_p, x, y)
def mul(x, y): return bind1(mul_p, x, y)
def neg(x): return bind1(neg_p, x)
def sin(x): return bind1(sin_p, x)
def cos(x): return bind1(cos_p, x)
def greater(x, y): return bind1(greater_p, x, y)
def less(x, y): return bind1(less_p, x, y)
def transpose(x, perm): return bind1(transpose_p, x, perm=perm)
def broadcast(x, shape, axes): return bind1(broadcast_p, x, shape=shape, axes=axes)
def reduce_sum(x, axis=None):
if axis is None:
axis = tuple(range(np.ndim(x)))
if type(axis) is int:
axis = (axis,)
return bind1(reduce_sum_p, x, axis=axis)
def bind1(prim, *args, **params):
out, = bind(prim, *args, **params)
return out
We’ll set up array data types and infix operator methods in a moment.
A Primitive
is just an object with a name, to which we attach our
interpretation rules (one for each transformation). The bind
function is our
interception point: it’ll figure out which transformation rule to apply, based
on how the arguments are boxed in tracers and what interpreters are active.
The functions that user code calls, like add
and sin
, are just wrappers
around calls to bind
. These wrappers let us control how arguments are passed
to bind
, and in particular we follow a handy internal convention: when we
call bind
, we pass values representing array data as positional arguments,
and we pass metadata like the axis
argument to reduce_sum_p
via keyword. This
calling convention simplifies some core logic (since e.g. instances of the
Tracer
class to be defined below can only occur in positional arguments to
bind
). The wrappers can also provide docstrings!
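For instance, a documented variant of one of the wrappers above might look like this (an illustrative sketch with a hypothetical name; the tutorial's own wrappers stay as one-liners):
def sin_documented(x):
  """Elementwise sine, dispatched through the sin_p primitive via bind1."""
  return bind1(sin_p, x)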
We represent active interpreters as a stack. The stack is just a simple
list
, and each element is a container with an integer level (corresponding
to the element’s height in the stack), an interpreter type (which we’ll call a
trace_type
), and an optional field for any global data the interpreter
needs. We call each element a MainTrace
, though maybe “Interpreter” would be
more descriptive.
from collections.abc import Sequence
from contextlib import contextmanager
from typing import Optional, Any
class MainTrace(NamedTuple):
level: int
trace_type: type['Trace']
global_data: Optional[Any]
trace_stack: list[MainTrace] = []
dynamic_trace: Optional[MainTrace] = None # to be employed in Part 3
@contextmanager
def new_main(trace_type: type['Trace'], global_data=None):
level = len(trace_stack)
main = MainTrace(level, trace_type, global_data)
trace_stack.append(main)
try:
yield main
finally:
trace_stack.pop()
When we’re about to apply a transformation, we’ll push another interpreter
onto the stack using new_main
. Then, as we apply primitives in the function,
we can think of the bind
first being interpreted by the trace at the top of
the stack (i.e. with the highest level). If that first interpreter itself
binds other primitives in its interpretation rule for the primitive, like how
the JVP rule of sin_p
might bind cos_p
and mul_p
, then those bind
calls will be handled by the interpreter at the next level down.
What goes at the bottom of the interpreter stack? At the bottom, we know all the transformation interpreters are finished, and we just want to do standard evaluation. So at the bottom we’ll put an evaluation interpreter.
Let’s sketch out the interface for interpreters, which is based on the Trace
and Tracer
base classes. A Tracer
represents a boxed-up value, perhaps
carrying some extra context data used by the interpreter. A Trace
handles
boxing up values into Tracers
and also handles primitive application.
class Trace:
main: MainTrace
def __init__(self, main: MainTrace) -> None:
self.main = main
def pure(self, val): assert False # must override
def lift(self, val): assert False # must override
def process_primitive(self, primitive, tracers, params):
assert False # must override
The first two methods are about boxing up values in Tracer
s, which are the
objects that flow through the Python programs we transform. The last method is
the callback we’ll use to interpret primitive application.
The Trace
itself doesn’t contain any data, other than a reference to its
corresponding MainTrace
instance. In fact, multiple instances of a Trace
might be created and discarded during an application of a transformation,
whereas only a single MainTrace
instance is created per application of a
transformation.
As for Tracer
s themselves, each one carries an abstract value (and forwards
infix operators to it), and the rest is up to the transformation. (The
relationship between Tracer
s and AbstractValue
s is that there’s one
Tracer
per transformation, and at least one AbstractValue
per base type,
like arrays.)
import numpy as np
class Tracer:
_trace: Trace
__array_priority__ = 1000
@property
def aval(self):
assert False # must override
def full_lower(self):
return self # default implementation
def __neg__(self): return self.aval._neg(self)
def __add__(self, other): return self.aval._add(self, other)
def __radd__(self, other): return self.aval._radd(self, other)
def __mul__(self, other): return self.aval._mul(self, other)
def __rmul__(self, other): return self.aval._rmul(self, other)
def __gt__(self, other): return self.aval._gt(self, other)
def __lt__(self, other): return self.aval._lt(self, other)
def __bool__(self): return self.aval._bool(self)
def __nonzero__(self): return self.aval._nonzero(self)
def __getattr__(self, name):
try:
return getattr(self.aval, name)
except AttributeError:
raise AttributeError(f"{self.__class__.__name__} has no attribute {name}")
def swap(f): return lambda x, y: f(y, x)
class ShapedArray:
array_abstraction_level = 1
shape: tuple[int, ...]
dtype: np.dtype
def __init__(self, shape, dtype):
self.shape = shape
self.dtype = dtype
@property
def ndim(self):
return len(self.shape)
_neg = staticmethod(neg)
_add = staticmethod(add)
_radd = staticmethod(swap(add))
_mul = staticmethod(mul)
_rmul = staticmethod(swap(mul))
_gt = staticmethod(greater)
_lt = staticmethod(less)
@staticmethod
def _bool(tracer):
raise Exception("ShapedArray can't be unambiguously converted to bool")
@staticmethod
def _nonzero(tracer):
raise Exception("ShapedArray can't be unambiguously converted to bool")
def str_short(self):
return f'{self.dtype.name}[{",".join(str(d) for d in self.shape)}]'
def __hash__(self):
return hash((self.shape, self.dtype))
def __eq__(self, other):
return (type(self) is type(other) and
self.shape == other.shape and self.dtype == other.dtype)
def __repr__(self):
return f"ShapedArray(shape={self.shape}, dtype={self.dtype})"
class ConcreteArray(ShapedArray):
array_abstraction_level = 2
val: np.ndarray
def __init__(self, val):
self.val = val
self.shape = val.shape
self.dtype = val.dtype
@staticmethod
def _bool(tracer):
return bool(tracer.aval.val)
@staticmethod
def _nonzero(tracer):
return bool(tracer.aval.val)
def get_aval(x):
if isinstance(x, Tracer):
return x.aval
elif type(x) in jax_types:
return ConcreteArray(np.asarray(x))
else:
raise TypeError(x)
jax_types = {bool, int, float,
np.bool_, np.int32, np.int64, np.float32, np.float64, np.ndarray}
Notice that we actually have two AbstractValue
s for arrays, representing
different levels of abstraction. A ShapedArray
represents the set of all
possible arrays with a given shape and dtype. A ConcreteArray
represents a
singleton set consisting of a single array value.
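As a quick illustrative check (not part of the original tutorial), we can ask for the abstract values of some concrete inputs:
x = np.ones((2, 3), dtype=np.float32)
aval = get_aval(x)              # a concrete value gives a ConcreteArray
print(type(aval).__name__)      # ConcreteArray
print(aval.str_short())         # float32[2,3]
print(get_aval(3.0).val)        # 3.0 -- the ConcreteArray remembers the value itself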
Now that we’ve set up the interpreter stack, the Trace/Tracer API for
interpreters, and abstract values, we can come back to implement bind
:
def bind(prim, *args, **params):
top_trace = find_top_trace(args)
tracers = [full_raise(top_trace, arg) for arg in args]
outs = top_trace.process_primitive(prim, tracers, params)
return [full_lower(out) for out in outs]
The main action is that we call find_top_trace
to figure out which
interpreter should handle this primitive application. We then call that top
trace’s process_primitive
so that the trace can apply its interpretation
rule. The calls to full_raise
just ensure that the inputs are boxed in the
top trace’s Tracer
instances, and the call to full_lower
is an optional
optimization so that we unbox values out of Tracer
s as much as possible.
import operator as op
def find_top_trace(xs) -> Trace:
top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),
default=trace_stack[0], key=op.attrgetter('level'))
if dynamic_trace and dynamic_trace.level > top_main.level:
top_main = dynamic_trace
return top_main.trace_type(top_main)
In words, ignoring the dynamic_trace
step until Part 3, find_top_trace
returns the highest-level interpreter associated with the Tracer
s on its
inputs, and otherwise returns the interpreter at the bottom of the stack
(which is always an evaluation trace, at least for now). This is a deviation
from the description above, where we always start by running the interpreter
at the top of the stack and then work our way down, applying every interpreter
in the stack. Instead, we’re only applying an interpreter when the input
arguments to a primitive bind are boxed in a Tracer
corresponding to that
interpreter. This optimization lets us skip irrelevant transformations, but
bakes in an assumption that transformations mostly follow data dependence
(except for the special bottom-of-the-stack interpreter, which interprets
everything).
An alternative would be to have every interpreter in the stack interpret every operation. That’s worth exploring! JAX is designed around data dependence in large part because that’s so natural for automatic differentiation, and JAX’s roots are in autodiff. But it may be over-fit.
def full_lower(val: Any):
if isinstance(val, Tracer):
return val.full_lower()
else:
return val
def full_raise(trace: Trace, val: Any) -> Tracer:
if not isinstance(val, Tracer):
assert type(val) in jax_types
return trace.pure(val)
level = trace.main.level
if val._trace.main is trace.main:
return val
elif val._trace.main.level < level:
return trace.lift(val)
elif val._trace.main.level > level:
raise Exception(f"Can't lift level {val._trace.main.level} to {level}.")
else: # val._trace.level == level
raise Exception(f"Different traces at same level: {val._trace}, {trace}.")
The logic in full_raise
serves to box values into Tracer
s for a particular
Trace
, calling different methods on the Trace
based on context:
Trace.pure
is called on non-Tracer
constants, and Trace.lift
is called
for values that are already Tracer
s from a lower-level interpreter. These
two methods could share the same implementation, but by distinguishing them in
the core logic we can provide more information to the Trace
subclass.
That’s it for the JAX core! Now we can start adding interpreters.
Evaluation interpreter#
We’ll start with the simplest interpreter: the evaluation interpreter that will sit at the bottom of the interpreter stack.
class EvalTrace(Trace):
pure = lift = lambda self, x: x # no boxing in Tracers needed
def process_primitive(self, primitive, tracers, params):
return impl_rules[primitive](*tracers, **params)
trace_stack.append(MainTrace(0, EvalTrace, None)) # special bottom of the stack
# NB: in JAX, instead of a dict we attach impl rules to the Primitive instance
impl_rules = {}
impl_rules[add_p] = lambda x, y: [np.add(x, y)]
impl_rules[mul_p] = lambda x, y: [np.multiply(x, y)]
impl_rules[neg_p] = lambda x: [np.negative(x)]
impl_rules[sin_p] = lambda x: [np.sin(x)]
impl_rules[cos_p] = lambda x: [np.cos(x)]
impl_rules[reduce_sum_p] = lambda x, *, axis: [np.sum(x, axis)]
impl_rules[greater_p] = lambda x, y: [np.greater(x, y)]
impl_rules[less_p] = lambda x, y: [np.less(x, y)]
impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)]
def broadcast_impl(x, *, shape, axes):
for axis in sorted(axes):
x = np.expand_dims(x, axis)
return [np.broadcast_to(x, shape)]
impl_rules[broadcast_p] = broadcast_impl
With this interpreter, we can evaluate user functions:
def f(x):
y = sin(x) * 2.
z = - y + x
return z
print(f(3.0))
2.7177599838802657
Woo! Like going around in a big circle. But the point of this indirection is that now we can add some real transformations.
Forward-mode autodiff with jvp
#
First, a few helper functions:
import builtins
def zeros_like(val):
aval = get_aval(val)
return np.zeros(aval.shape, aval.dtype)
def unzip2(pairs):
lst1, lst2 = [], []
for x1, x2 in pairs:
lst1.append(x1)
lst2.append(x2)
return lst1, lst2
def map(f, *xs):
return list(builtins.map(f, *xs))
def zip(*args):
fst, *rest = args = map(list, args)
n = len(fst)
for arg in rest:
assert len(arg) == n
return list(builtins.zip(*args))
The Tracer
for forward-mode autodiff carries a primal-tangent pair. The
Trace
applies JVP rules.
class JVPTracer(Tracer):
def __init__(self, trace, primal, tangent):
self._trace = trace
self.primal = primal
self.tangent = tangent
@property
def aval(self):
return get_aval(self.primal)
class JVPTrace(Trace):
pure = lift = lambda self, val: JVPTracer(self, val, zeros_like(val))
def process_primitive(self, primitive, tracers, params):
primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
jvp_rule = jvp_rules[primitive]
primal_outs, tangent_outs = jvp_rule(primals_in, tangents_in, **params)
return [JVPTracer(self, x, t) for x, t in zip(primal_outs, tangent_outs)]
jvp_rules = {}
Notice both pure
and lift
package a value into a JVPTracer
with the
minimal amount of context, which is a zero tangent value.
Let’s add some JVP rules for primitives:
def add_jvp(primals, tangents):
(x, y), (x_dot, y_dot) = primals, tangents
return [x + y], [x_dot + y_dot]
jvp_rules[add_p] = add_jvp
def mul_jvp(primals, tangents):
(x, y), (x_dot, y_dot) = primals, tangents
return [x * y], [x_dot * y + x * y_dot]
jvp_rules[mul_p] = mul_jvp
def sin_jvp(primals, tangents):
(x,), (x_dot,) = primals, tangents
return [sin(x)], [cos(x) * x_dot]
jvp_rules[sin_p] = sin_jvp
def cos_jvp(primals, tangents):
(x,), (x_dot,) = primals, tangents
return [cos(x)], [-sin(x) * x_dot]
jvp_rules[cos_p] = cos_jvp
def neg_jvp(primals, tangents):
(x,), (x_dot,) = primals, tangents
return [neg(x)], [neg(x_dot)]
jvp_rules[neg_p] = neg_jvp
def reduce_sum_jvp(primals, tangents, *, axis):
(x,), (x_dot,) = primals, tangents
return [reduce_sum(x, axis)], [reduce_sum(x_dot, axis)]
jvp_rules[reduce_sum_p] = reduce_sum_jvp
def greater_jvp(primals, tangents):
(x, y), _ = primals, tangents
out_primal = greater(x, y)
return [out_primal], [zeros_like(out_primal)]
jvp_rules[greater_p] = greater_jvp
def less_jvp(primals, tangents):
(x, y), _ = primals, tangents
out_primal = less(x, y)
return [out_primal], [zeros_like(out_primal)]
jvp_rules[less_p] = less_jvp
Finally, we add a transformation API to kick off the trace:
def jvp_v1(f, primals, tangents):
with new_main(JVPTrace) as main:
trace = JVPTrace(main)
tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)]
out = f(*tracers_in)
tracer_out = full_raise(trace, out)
primal_out, tangent_out = tracer_out.primal, tracer_out.tangent
return primal_out, tangent_out
And with that, we can differentiate!
x = 3.0
y, sin_deriv_at_3 = jvp_v1(sin, (x,), (1.0,))
print(sin_deriv_at_3)
print(cos(3.0))
-0.9899924966004454
-0.9899924966004454
def f(x):
y = sin(x) * 2.
z = - y + x
return z
x, xdot = 3., 1.
y, ydot = jvp_v1(f, (x,), (xdot,))
print(y)
print(ydot)
2.7177599838802657
2.979984993200891
def deriv(f):
return lambda x: jvp_v1(f, (x,), (1.,))[1]
print(deriv(sin)(3.))
print(deriv(deriv(sin))(3.))
print(deriv(deriv(deriv(sin)))(3.))
print(deriv(deriv(deriv(deriv(sin))))(3.))
-0.9899924966004454
-0.1411200080598672
0.9899924966004454
0.1411200080598672
def f(x):
if x > 0.: # Python control flow
return 2. * x
else:
return x
print(deriv(f)(3.))
print(deriv(f)(-3.))
2.0
1.0
Pytrees and flattening user functions’ inputs and outputs#
A limitation with jvp_v1
is that it assumes the user function accepts arrays
as positional arguments and produces a single array as output. What if it
produced a list as output? Or accepted nested containers as inputs? It would
be a pain to deal with all the possible containers in inputs and outputs at
every layer of the stack. Instead, we can wrap the user function so that the
wrapped version accepts arrays as inputs and returns a flat list of arrays as
output. The wrapper just needs to unflatten its input, call the user function,
and flatten the output.
Here’s how we’d like to write jvp
, assuming the user always gives us
functions that take arrays as inputs and produces a flat list of arrays as
outputs:
def jvp_flat(f, primals, tangents):
with new_main(JVPTrace) as main:
trace = JVPTrace(main)
tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)]
outs = f(*tracers_in)
tracers_out = [full_raise(trace, out) for out in outs]
primals_out, tangents_out = unzip2((t.primal, t.tangent) for t in tracers_out)
return primals_out, tangents_out
To support user functions that have arbitrary containers in the inputs and
outputs, here’s how we’d write the user-facing jvp
wrapper:
def jvp(f, primals, tangents):
primals_flat, in_tree = tree_flatten(primals)
tangents_flat, in_tree2 = tree_flatten(tangents)
if in_tree != in_tree2: raise TypeError
f, out_tree = flatten_fun(f, in_tree)
primals_out_flat, tangents_out_flat = jvp_flat(f, primals_flat, tangents_flat)
primals_out = tree_unflatten(out_tree(), primals_out_flat)
tangents_out = tree_unflatten(out_tree(), tangents_out_flat)
return primals_out, tangents_out
Notice that we had to plumb the tree structure of the user function output
back to the caller of flatten_fun
. That information isn’t available until we
actually run the user function, so flatten_fun
just returns a reference to a
mutable cell, represented as a thunk. These side-effects are safe because we
always run the user function exactly once. (This safe regime is the reason for
the “linear” name in linear_util.py
, in the sense of linear
types.)
All that remains is to write tree_flatten
, tree_unflatten
, and
flatten_fun
.
def flatten_fun(f, in_tree):
store = Store()
def flat_fun(*args_flat):
pytree_args = tree_unflatten(in_tree, args_flat)
out = f(*pytree_args)
out_flat, out_tree = tree_flatten(out)
store.set_value(out_tree)
return out_flat
return flat_fun, store
class Empty: pass
empty = Empty()
class Store:
val = empty
def set_value(self, val):
assert self.val is empty
self.val = val
def __call__(self):
return self.val
from collections.abc import Hashable, Iterable, Iterator
import itertools as it
from typing import Callable
class NodeType(NamedTuple):
name: str
to_iterable: Callable
from_iterable: Callable
def register_pytree_node(ty: type, to_iter: Callable, from_iter: Callable
) -> None:
node_types[ty] = NodeType(str(ty), to_iter, from_iter)
node_types: dict[type, NodeType] = {}
register_pytree_node(tuple, lambda t: (None, t), lambda _, xs: tuple(xs))
register_pytree_node(list, lambda l: (None, l), lambda _, xs: list(xs))
register_pytree_node(dict,
lambda d: map(tuple, unzip2(sorted(d.items()))),
lambda keys, vals: dict(zip(keys, vals)))
class PyTreeDef(NamedTuple):
node_type: NodeType
node_metadata: Hashable
child_treedefs: tuple['PyTreeDef', ...]
class Leaf: pass
leaf = Leaf()
def tree_flatten(x: Any) -> tuple[list[Any], PyTreeDef]:
children_iter, treedef = _tree_flatten(x)
return list(children_iter), treedef
def _tree_flatten(x: Any) -> tuple[Iterable, PyTreeDef]:
node_type = node_types.get(type(x))
if node_type:
node_metadata, children = node_type.to_iterable(x)
children_flat, child_trees = unzip2(map(_tree_flatten, children))
flattened = it.chain.from_iterable(children_flat)
return flattened, PyTreeDef(node_type, node_metadata, tuple(child_trees))
else:
return [x], leaf
def tree_unflatten(treedef: PyTreeDef, xs: list[Any]) -> Any:
return _tree_unflatten(treedef, iter(xs))
def _tree_unflatten(treedef: PyTreeDef, xs: Iterator) -> Any:
if treedef is leaf:
return next(xs)
else:
children = (_tree_unflatten(t, xs) for t in treedef.child_treedefs)
return treedef.node_type.from_iterable(treedef.node_metadata, children)
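As a quick illustrative check (not in the original tutorial), flattening and unflattening round-trips a nested container:
leaves, treedef = tree_flatten({'a': [1, 2], 'b': (3,)})
print(leaves)                           # [1, 2, 3]
print(tree_unflatten(treedef, leaves))  # {'a': [1, 2], 'b': (3,)}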
With this pytree-handling jvp
implementation, we can now handle arbitrary
input and output containers. That’ll come in handy with future transformations
too!
def f(x):
y = sin(x) * 2.
z = - y + x
return {'hi': z, 'there': [x, y]}
x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
{'hi': 2.7177599838802657, 'there': [3.0, 0.2822400161197344]}
{'hi': 2.979984993200891, 'there': [1.0, -1.9799849932008908]}
Vectorized batching with vmap
#
First, a couple helper functions, one for producing mapped abstract values from unmapped ones (by removing an axis), and one for moving batch dimensions around:
def mapped_aval(batch_dim, aval):
shape = list(aval.shape)
del shape[batch_dim]
return ShapedArray(tuple(shape), aval.dtype)
def move_batch_axis(axis_size, src, dst, x):
if src is not_mapped:
target_shape = list(np.shape(x))
target_shape.insert(dst, axis_size)
return broadcast(x, target_shape, [dst])
elif src == dst:
return x
else:
return moveaxis(x, src, dst)
def moveaxis(x, src: int, dst: int):
perm = [i for i in range(np.ndim(x)) if i != src]
perm.insert(dst, src)
return transpose(x, perm)
The Tracer
for vectorized batching carries a batched value and an optional
integer indicating which axis (if any) is the batch axis.
from typing import Union
class NotMapped: pass
not_mapped = NotMapped()
BatchAxis = Union[NotMapped, int]
class BatchTracer(Tracer):
def __init__(self, trace, val, batch_dim: BatchAxis):
self._trace = trace
self.val = val
self.batch_dim = batch_dim
@property
def aval(self):
if self.batch_dim is not_mapped:
return get_aval(self.val)
else:
return mapped_aval(self.batch_dim, get_aval(self.val))
def full_lower(self):
if self.batch_dim is not_mapped:
return full_lower(self.val)
else:
return self
class BatchTrace(Trace):
pure = lift = lambda self, val: BatchTracer(self, val, not_mapped)
def process_primitive(self, primitive, tracers, params):
vals_in, bdims_in = unzip2((t.val, t.batch_dim) for t in tracers)
vmap_rule = vmap_rules[primitive]
val_outs, bdim_outs = vmap_rule(self.axis_size, vals_in, bdims_in, **params)
return [BatchTracer(self, x, bd) for x, bd in zip(val_outs, bdim_outs)]
@property
def axis_size(self):
return self.main.global_data
vmap_rules = {}
Here we’ve implemented the optional Tracer.full_lower
method, which lets us
peel off a batching tracer if it’s not needed because it doesn’t represent a
batched value.
For BatchTrace
, analogous to JVPTrace
, the methods pure
and lift
just
box a value in a BatchTracer
with the minimal amount of context, which in
this case is a batch_dim
taking the sentinel value not_mapped
. Notice we
use the MainTrace
’s interpreter-global data field to store the batch axis
size.
Next we can define batching interpreter rules for each primitive:
from functools import partial
def binop_batching_rule(op, axis_size, vals_in, dims_in):
(x, y), (x_bdim, y_bdim) = vals_in, dims_in
if x_bdim != y_bdim:
if x_bdim is not_mapped:
x = move_batch_axis(axis_size, x_bdim, y_bdim, x)
x_bdim = y_bdim
else:
y = move_batch_axis(axis_size, y_bdim, x_bdim, y)
return [op(x, y)], [x_bdim]
vmap_rules[add_p] = partial(binop_batching_rule, add)
vmap_rules[mul_p] = partial(binop_batching_rule, mul)
def vectorized_unop_batching_rule(op, axis_size, vals_in, dims_in):
(x,), (x_bdim,) = vals_in, dims_in
return [op(x)], [x_bdim]
vmap_rules[sin_p] = partial(vectorized_unop_batching_rule, sin)
vmap_rules[cos_p] = partial(vectorized_unop_batching_rule, cos)
vmap_rules[neg_p] = partial(vectorized_unop_batching_rule, neg)
def reduce_sum_batching_rule(axis_size, vals_in, dims_in, *, axis):
(x,), (x_bdim,) = vals_in, dims_in
new_axis = tuple(ax + (x_bdim <= ax) for ax in axis)
out_bdim = x_bdim - sum(ax < x_bdim for ax in axis)
return [reduce_sum(x, new_axis)], [out_bdim]
vmap_rules[reduce_sum_p] = reduce_sum_batching_rule
Finally, we add a transformation API to kick off the trace:
def vmap_flat(f, in_axes, *args):
axis_size, = {x.shape[ax] for x, ax in zip(args, in_axes)
if ax is not not_mapped}
with new_main(BatchTrace, axis_size) as main:
trace = BatchTrace(main)
tracers_in = [BatchTracer(trace, x, ax) if ax is not None else x
for x, ax in zip(args, in_axes)]
outs = f(*tracers_in)
tracers_out = [full_raise(trace, out) for out in outs]
vals_out, bdims_out = unzip2((t.val, t.batch_dim) for t in tracers_out)
outs_transposed = [move_batch_axis(axis_size, bdim, 0, val_out)
for val_out, bdim in zip(vals_out, bdims_out)]
return outs_transposed
def vmap(f, in_axes):
def batched_f(*args):
args_flat, in_tree = tree_flatten(args)
in_axes_flat, in_tree2 = tree_flatten(in_axes)
if in_tree != in_tree2: raise TypeError
f_flat, out_tree = flatten_fun(f, in_tree)
outs_flat = vmap_flat(f_flat, in_axes_flat, *args_flat)
return tree_unflatten(out_tree(), outs_flat)
return batched_f
def add_one_to_a_scalar(scalar):
assert np.ndim(scalar) == 0
return 1 + scalar
vector_in = np.arange(3.)
vector_out = vmap(add_one_to_a_scalar, (0,))(vector_in)
print(vector_in)
print(vector_out)
[0. 1. 2.]
[1. 2. 3.]
def jacfwd(f, x):
pushfwd = lambda v: jvp(f, (x,), (v,))[1]
vecs_in = np.eye(np.size(x)).reshape(np.shape(x) * 2)
return vmap(pushfwd, (0,))(vecs_in)
def f(x):
return sin(x)
jacfwd(f, np.arange(3.))
array([[ 1. , 0. , -0. ],
[ 0. , 0.54030231, -0. ],
[ 0. , 0. , -0.41614684]])
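As one more illustrative check (not in the original tutorial; the function and variable names here are just for demonstration), we can batch a small reduction over the leading axis of both arguments:
def row_dot(x, y):
  return reduce_sum(x * y, axis=0)

xs = np.arange(6.).reshape(2, 3)
ys = np.ones((2, 3))
print(vmap(row_dot, (0, 0))(xs, ys))   # per-example results: [ 3. 12.]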
That’s it for jvp
and vmap
!
Part 2: Jaxprs#
The next transformations on the horizon are jit
for just-in-time
compilation and vjp
for reverse-mode autodiff. (grad
is just a small
wrapper around vjp
.) Whereas jvp
and vmap
only needed each Tracer
to
carry a little bit of extra context, for both jit
and vjp
we need much
richer context: we need to represent programs. That is, we need jaxprs!
Jaxprs are JAX’s internal intermediate representation of programs. They are
explicitly typed, functional, first-order, and in ANF form. We need a
program representation for jit
because the purpose of jit
is to stage
computation out of Python. For any computation we want to stage out, we need
to be able to represent it as data, and build it up as we trace a Python
function. Similarly, vjp
needs a way to represent the computation for the
backward pass of reverse-mode autodiff. We use the same jaxpr program
representation for both needs.
(Building a program representation is the most free kind of trace-transformation, and so except for issues around handling native Python control flow, any transformation could be implemented by first tracing to a jaxpr and then interpreting the jaxpr.)
Jaxpr data structures#
The jaxpr term syntax is roughly:
jaxpr ::=
{ lambda <binder> , ... .
let <eqn>
...
in ( <atom> , ... ) }
binder ::= <var>:<array_type>
var ::= a | b | c | ...
atom ::= <var> | <literal>
literal ::= <int32> | <int64> | <float32> | <float64>
eqn ::= <binder> , ... = <primitive> [ <params> ] <atom> , ...
The syntax of types is:
jaxpr_type ::= [ <array_type> , ... ] -> [ <array_type> , ... ]
array_type ::= <dtype>[<shape>]
dtype ::= f32 | f64 | i32 | i64
shape ::= <int> , ...
How do we represent these as Python data structures? We reuse ShapedArrays to represent types, and we can represent the term syntax with a few Python structs:
class Var:
aval: ShapedArray
def __init__(self, aval): self.aval = aval
class Lit:
val: Any
aval: ShapedArray
def __init__(self, val):
self.aval = aval = raise_to_shaped(get_aval(val))
self.val = np.array(val, aval.dtype)
Atom = Union[Var, Lit]
class JaxprEqn(NamedTuple):
primitive: Primitive
inputs: list[Atom]
params: dict[str, Any]
out_binders: list[Var]
class Jaxpr(NamedTuple):
in_binders: list[Var]
eqns: list[JaxprEqn]
outs: list[Atom]
def __hash__(self): return id(self)
__eq__ = op.is_
def raise_to_shaped(aval):
return ShapedArray(aval.shape, aval.dtype)
Type-checking a jaxpr involves checking that there are no unbound variables, that variables are only bound once, and that for each equation the type of the primitive application matches the type of the output binders.
class JaxprType(NamedTuple):
in_types: list[ShapedArray]
out_types: list[ShapedArray]
def __repr__(self):
in_types = ', '.join(aval.str_short() for aval in self.in_types)
out_types = ', '.join(aval.str_short() for aval in self.out_types)
return f'({in_types}) -> ({out_types})'
def typecheck_jaxpr(jaxpr: Jaxpr) -> JaxprType:
env: set[Var] = set()
for v in jaxpr.in_binders:
if v in env: raise TypeError
env.add(v)
for eqn in jaxpr.eqns:
in_types = [typecheck_atom(env, x) for x in eqn.inputs]
out_types = abstract_eval_rules[eqn.primitive](*in_types, **eqn.params)
for out_binder, out_type in zip(eqn.out_binders, out_types):
if not out_type == out_binder.aval: raise TypeError
for out_binder in eqn.out_binders:
if out_binder in env: raise TypeError
env.add(out_binder)
in_types = [v.aval for v in jaxpr.in_binders]
out_types = [typecheck_atom(env, x) for x in jaxpr.outs]
return JaxprType(in_types, out_types)
def typecheck_atom(env: set[Var], x: Atom) -> ShapedArray:
if isinstance(x, Var):
if x not in env: raise TypeError("unbound variable")
return x.aval
elif isinstance(x, Lit):
return raise_to_shaped(get_aval(x.val))
else:
assert False
We can apply the function represented by a jaxpr to arguments with a simple interpreter.
def eval_jaxpr(jaxpr: Jaxpr, args: list[Any]) -> list[Any]:
env: dict[Var, Any] = {}
def read(x: Atom) -> Any:
return env[x] if type(x) is Var else x.val
def write(v: Var, val: Any) -> None:
assert v not in env # single-assignment
env[v] = val
map(write, jaxpr.in_binders, args)
for eqn in jaxpr.eqns:
in_vals = map(read, eqn.inputs)
outs = bind(eqn.primitive, *in_vals, **eqn.params)
map(write, eqn.out_binders, outs)
return map(read, jaxpr.outs)
def jaxpr_as_fun(jaxpr: Jaxpr):
return lambda *args: eval_jaxpr(jaxpr, args)
By using bind
in the interpreter, this interpreter itself is traceable.
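To see the interpreter in action before we have any tracing machinery, here is a tiny hand-built jaxpr for f(x) = sin(x) * 2. (illustrative only, with made-up variable names; the next section builds jaxprs automatically by tracing):
_x = Var(ShapedArray((), np.dtype('float64')))
_y = Var(ShapedArray((), np.dtype('float64')))
_z = Var(ShapedArray((), np.dtype('float64')))
sin_times_two = Jaxpr(in_binders=[_x],
                      eqns=[JaxprEqn(sin_p, [_x], {}, [_y]),
                            JaxprEqn(mul_p, [_y, Lit(2.)], {}, [_z])],
                      outs=[_z])
print(jaxpr_as_fun(sin_times_two)(3.0))   # [sin(3.0) * 2.0], roughly [0.2822]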
Building jaxprs with tracing#
Now that we have jaxprs as a data structure, we need ways to produce these
from tracing Python code. In general there are two variants of how we trace to
a jaxpr; jit
uses one and vjp
uses the other. We’ll start with the one
used by jit
, which is also used by control flow primitives like lax.cond
,
lax.while_loop
, and lax.scan
.
def split_list(lst: list[Any], n: int) -> tuple[list[Any], list[Any]]:
assert 0 <= n <= len(lst)
return lst[:n], lst[n:]
def partition_list(bs: list[bool], l: list[Any]) -> tuple[list[Any], list[Any]]:
assert len(bs) == len(l)
lists = lst1, lst2 = [], []
for b, x in zip(bs, l):
lists[b].append(x)
return lst1, lst2
# NB: the analogous class in JAX is called 'DynamicJaxprTracer'
class JaxprTracer(Tracer):
__slots__ = ['aval']
aval: ShapedArray
def __init__(self, trace, aval):
self._trace = trace
self.aval = aval
# NB: the analogous class in JAX is called 'DynamicJaxprTrace'
class JaxprTrace(Trace):
def new_arg(self, aval: ShapedArray) -> JaxprTracer:
aval = raise_to_shaped(aval)
tracer = self.builder.new_tracer(self, aval)
self.builder.tracer_to_var[id(tracer)] = Var(aval)
return tracer
def get_or_make_const_tracer(self, val: Any) -> JaxprTracer:
tracer = self.builder.const_tracers.get(id(val))
if tracer is None:
tracer = self.builder.new_tracer(self, raise_to_shaped(get_aval(val)))
self.builder.add_const(tracer, val)
return tracer
pure = lift = get_or_make_const_tracer
def process_primitive(self, primitive, tracers, params):
avals_in = [t.aval for t in tracers]
avals_out = abstract_eval_rules[primitive](*avals_in, **params)
out_tracers = [self.builder.new_tracer(self, a) for a in avals_out]
inputs = [self.builder.getvar(t) for t in tracers]
outvars = [self.builder.add_var(t) for t in out_tracers]
self.builder.add_eqn(JaxprEqn(primitive, inputs, params, outvars))
return out_tracers
@property
def builder(self):
return self.main.global_data
# NB: in JAX, we instead attach abstract eval rules to Primitive instances
abstract_eval_rules = {}
Notice that we keep as interpreter-global data a builder object, which keeps track of variables, constants, and eqns as we build up the jaxpr.
class JaxprBuilder:
eqns: list[JaxprEqn]
tracer_to_var: dict[int, Var]
const_tracers: dict[int, JaxprTracer]
constvals: dict[Var, Any]
tracers: list[JaxprTracer]
def __init__(self):
self.eqns = []
self.tracer_to_var = {}
self.const_tracers = {}
self.constvals = {}
self.tracers = []
def new_tracer(self, trace: JaxprTrace, aval: ShapedArray) -> JaxprTracer:
tracer = JaxprTracer(trace, aval)
self.tracers.append(tracer)
return tracer
def add_eqn(self, eqn: JaxprEqn) -> None:
self.eqns.append(eqn)
def add_var(self, tracer: JaxprTracer) -> Var:
assert id(tracer) not in self.tracer_to_var
var = self.tracer_to_var[id(tracer)] = Var(tracer.aval)
return var
def getvar(self, tracer: JaxprTracer) -> Var:
var = self.tracer_to_var.get(id(tracer))
assert var is not None
return var
def add_const(self, tracer: JaxprTracer, val: Any) -> Var:
var = self.add_var(tracer)
self.const_tracers[id(val)] = tracer
self.constvals[var] = val
return var
def build(self, in_tracers: list[JaxprTracer], out_tracers: list[JaxprTracer]
) -> tuple[Jaxpr, list[Any]]:
constvars, constvals = unzip2(self.constvals.items())
t2v = lambda t: self.tracer_to_var[id(t)]
in_binders = constvars + [t2v(t) for t in in_tracers]
out_vars = [t2v(t) for t in out_tracers]
jaxpr = Jaxpr(in_binders, self.eqns, out_vars)
typecheck_jaxpr(jaxpr)
jaxpr, constvals = _inline_literals(jaxpr, constvals)
return jaxpr, constvals
def _inline_literals(jaxpr: Jaxpr, consts: list[Any]) -> tuple[Jaxpr, list[Any]]:
const_binders, other_binders = split_list(jaxpr.in_binders, len(consts))
scalars = [type(x) in jax_types and not get_aval(x).shape for x in consts]
new_const_binders, lit_binders = partition_list(scalars, const_binders)
new_consts, lit_vals = partition_list(scalars, consts)
literals = dict(zip(lit_binders, map(Lit, lit_vals)))
new_eqns = [JaxprEqn(eqn.primitive, [literals.get(x, x) for x in eqn.inputs],
eqn.params, eqn.out_binders) for eqn in jaxpr.eqns]
new_outs = [literals.get(x, x) for x in jaxpr.outs]
new_jaxpr = Jaxpr(new_const_binders + other_binders, new_eqns, new_outs)
typecheck_jaxpr(new_jaxpr)
return new_jaxpr, new_consts
The rules we need for JaxprTrace.process_primitive
are essentially typing
rules for primitive applications: given the primitive, its parameters, and
types for the inputs, the rule must produce a type for the output, which is
then packaged with the output JaxprTracer
. We can use abstract evaluation
rules for this same purpose, even though they can be more general (since
abstract evaluation rules must accept ConcreteArray inputs, and since they
need only return an upper bound on the set of possible outputs, they can
produce ConcreteArray outputs as well). We’ll reuse these abstract evaluation
rules for the other jaxpr-producing trace machinery, where the potential extra
generality is useful.
def binop_abstract_eval(x: ShapedArray, y: ShapedArray) -> list[ShapedArray]:
if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):
raise TypeError
if raise_to_shaped(x) != raise_to_shaped(y): raise TypeError
return [ShapedArray(x.shape, x.dtype)]
abstract_eval_rules[add_p] = binop_abstract_eval
abstract_eval_rules[mul_p] = binop_abstract_eval
def compare_abstract_eval(x: ShapedArray, y: ShapedArray) -> list[ShapedArray]:
if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):
raise TypeError
if x.shape != y.shape: raise TypeError
return [ShapedArray(x.shape, np.dtype('bool'))]
abstract_eval_rules[greater_p] = compare_abstract_eval
abstract_eval_rules[less_p] = compare_abstract_eval
def vectorized_unop_abstract_eval(x: ShapedArray) -> list[ShapedArray]:
return [ShapedArray(x.shape, x.dtype)]
abstract_eval_rules[sin_p] = vectorized_unop_abstract_eval
abstract_eval_rules[cos_p] = vectorized_unop_abstract_eval
abstract_eval_rules[neg_p] = vectorized_unop_abstract_eval
def reduce_sum_abstract_eval(x: ShapedArray, *, axis: tuple[int, ...]
) -> list[ShapedArray]:
axis_ = set(axis)
new_shape = [d for i, d in enumerate(x.shape) if i not in axis_]
return [ShapedArray(tuple(new_shape), x.dtype)]
abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval
def broadcast_abstract_eval(x: ShapedArray, *, shape: Sequence[int],
axes: Sequence[int]) -> list[ShapedArray]:
return [ShapedArray(tuple(shape), x.dtype)]
abstract_eval_rules[broadcast_p] = broadcast_abstract_eval
To check our implementation of jaxprs, we can add a make_jaxpr
transformation and a pretty-printer:
from functools import lru_cache
@lru_cache() # ShapedArrays are hashable
def make_jaxpr_v1(f, *avals_in):
avals_in, in_tree = tree_flatten(avals_in)
f, out_tree = flatten_fun(f, in_tree)
builder = JaxprBuilder()
with new_main(JaxprTrace, builder) as main:
trace = JaxprTrace(main)
tracers_in = [trace.new_arg(aval) for aval in avals_in]
outs = f(*tracers_in)
tracers_out = [full_raise(trace, out) for out in outs]
jaxpr, consts = builder.build(tracers_in, tracers_out)
return jaxpr, consts, out_tree()
from collections import defaultdict
import string
class PPrint:
lines: list[tuple[int, str]]
def __init__(self, lines):
self.lines = lines
def indent(self, indent: int) -> 'PPrint':
return PPrint([(indent + orig_indent, s) for orig_indent, s in self.lines])
def __add__(self, rhs: 'PPrint') -> 'PPrint':
return PPrint(self.lines + rhs.lines)
def __rshift__(self, rhs: 'PPrint') -> 'PPrint':
if not rhs.lines: return self
if not self.lines: return rhs
indent, s = self.lines[-1]
indented_block = rhs.indent(indent + len(s))
common_line = s + ' ' * rhs.lines[0][0] + rhs.lines[0][1]
return PPrint(self.lines[:-1]
+ [(indent, common_line)]
+ indented_block.lines[1:])
def __str__(self) -> str:
return '\n'.join(' ' * indent + s for indent, s in self.lines)
def pp(s: Any) -> PPrint:
return PPrint([(0, line) for line in str(s).splitlines()])
def vcat(ps: list[PPrint]) -> PPrint:
return sum(ps, pp(''))
def pp_jaxpr(jaxpr: Jaxpr) -> PPrint:
namegen = (''.join(s) for r in it.count(1)
for s in it.permutations(string.ascii_lowercase, r))
names = defaultdict(lambda: next(namegen))
in_binders = ', '.join(var_str(names, x) for x in jaxpr.in_binders)
eqns = vcat([pp_eqn(names, e) for e in jaxpr.eqns])
outs = ', '.join(names[v] if isinstance(v, Var) else str(v.val)
for v in jaxpr.outs)
return (pp(f'{{ lambda {in_binders} .') +
((pp('let ') >> eqns) + pp(f'in ( {outs} ) }}')).indent(2))
def var_str(names: defaultdict[Var, str], v: Var) -> str:
return f'{names[v]}:{v.aval.str_short()}'
def pp_eqn(names: defaultdict[Var, str], eqn: JaxprEqn) -> PPrint:
rule = pp_rules.get(eqn.primitive)
if rule:
return rule(names, eqn)
else:
lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
rhs = (pp(eqn.primitive.name) >> pp_params(eqn.params) >>
pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
for x in eqn.inputs)))
return lhs >> pp(' = ') >> rhs
def pp_params(params: dict[str, Any]) -> PPrint:
items = sorted(params.items())
if items:
return pp(' [ ') >> vcat([pp(f'{k}={v}') for k, v in items]) >> pp(' ] ')
else:
return pp(' ')
Jaxpr.__repr__ = lambda self: str(pp_jaxpr(self))
pp_rules: dict[Primitive, Callable[..., PPrint]] = {}
jaxpr, consts, _ = make_jaxpr_v1(lambda x: 2. * x, raise_to_shaped(get_aval(3.)))
print(jaxpr)
print(typecheck_jaxpr(jaxpr))
{ lambda a:float64[] .
let b:float64[] = mul 2.0 a
in ( b ) }
(float64[]) -> (float64[])
But there’s a limitation here: because of how find_top_trace
operates by
data dependence, make_jaxpr_v1
can’t stage out all the primitive operations
performed by the Python callable it’s given. For example:
jaxpr, consts, _ = make_jaxpr_v1(lambda: mul(2., 2.))
print(jaxpr)
{ lambda .
let
in ( 4.0 ) }
This is precisely the issue that
omnistaging fixed.
We want to ensure that the JaxprTrace
started by make_jaxpr
is always
applied, regardless of whether any inputs to bind
are boxed in corresponding
JaxprTracer
instances. We can achieve this by employing the dynamic_trace
global defined in Part 1:
@contextmanager
def new_dynamic(main: MainTrace):
global dynamic_trace
prev_dynamic_trace, dynamic_trace = dynamic_trace, main
try:
yield
finally:
dynamic_trace = prev_dynamic_trace
@lru_cache()
def make_jaxpr(f: Callable, *avals_in: ShapedArray,
) -> tuple[Jaxpr, list[Any], PyTreeDef]:
avals_in, in_tree = tree_flatten(avals_in)
f, out_tree = flatten_fun(f, in_tree)
builder = JaxprBuilder()
with new_main(JaxprTrace, builder) as main:
with new_dynamic(main):
trace = JaxprTrace(main)
tracers_in = [trace.new_arg(aval) for aval in avals_in]
outs = f(*tracers_in)
tracers_out = [full_raise(trace, out) for out in outs]
jaxpr, consts = builder.build(tracers_in, tracers_out)
return jaxpr, consts, out_tree()
jaxpr, consts, _ = make_jaxpr(lambda: mul(2., 2.))
print(jaxpr)
{ lambda .
let a:float64[] = mul 2.0 2.0
in ( a ) }
Using dynamic_trace
this way is conceptually the same as stashing the
current interpreter stack and starting a new one with the JaxprTrace
at the
bottom. That is, no interpreters lower in the stack than the dynamic_trace
are applied (since JaxprTrace.process_primitive
doesn’t call bind
), though
if the Python callable being traced to a jaxpr itself uses transformations
then those can be pushed onto the interpreter stack above the JaxprTrace
.
But temporarily stashing the interpreter stack would break up the system
state. The dynamic_trace
tag achieves the same goals while keeping the
system state simpler.
That’s it for jaxprs! With jaxprs in hand, we can implement the remaining major JAX features.
Part 3: jit, simplified#
While jit
has a transformation-like API in that it accepts a Python callable
as an argument, under the hood it’s really a higher-order primitive rather
than a transformation. A primitive is higher-order when it’s parameterized
by a function.
On-the-fly (“final style”) and staged (“initial style”) processing#
There are two options for how to handle higher-order primitives. Each requires a different approach to tracing and engenders different tradeoffs:
On-the-fly processing, where bind takes a Python callable as an argument. We defer forming a jaxpr until as late as possible, namely until we’re running the final interpreter at the bottom of the interpreter stack. That way we can swap a JaxprTrace in at the bottom of the interpreter stack and thus stage out rather than execute all primitive operations. With this approach, transformations in the stack get applied as we execute the Python callable as usual. This approach can be very tricky to implement, but it’s as general as possible because it allows higher-order primitives not to raise the abstraction level of their arguments and thus allows data-dependent Python control flow. We refer to this approach as using a “final-style higher-order primitive” employing the discharge-at-tracing-time “final-style transformations” we’ve used so far.
Staged processing, where bind takes a jaxpr as an argument. Before we call bind, in the primitive wrapper we can just use make_jaxpr to form a jaxpr up-front and be done with the Python callable entirely. In this case, make_jaxpr puts its JaxprTrace at the top of the interpreter stack, and no transformations lower in the stack, which might enter via closed-over Tracers, are applied to the Python callable as we trace it. (Transformations applied within the Python callable are applied as usual, being added to the stack above the JaxprTrace.) Instead, the transformations lower in the stack are later applied to the call primitive, and the call primitive’s rules must then transform the jaxpr itself. Because we trace to a jaxpr up-front, this approach can’t support data-dependent Python control flow, but it is more straightforward to implement. We refer to this kind of higher-order primitive as an “initial-style higher-order primitive”, and say that its jaxpr-processing transformation rules are “initial-style transformation rules.”
The latter approach fits for jit
because we don’t need to support
data-dependent Python control flow in the user-provided Python callable, as
the whole purpose of jit
is to stage computation out of Python to be
executed by XLA. (In contrast, custom_jvp
is a higher-order primitive in
which we want to support data-dependent Python control flow.)
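To make the tradeoff concrete, here is a minimal sketch (not from the original, written against the mini-library built in this document and the jit wrapper defined just below) of the kind of function that staged processing cannot handle: a Python branch on a traced value. Under make_jaxpr the argument is an abstract ShapedArray, so the branch cannot be resolved at trace time, whereas on-the-fly processing with concrete values could follow it.
def absolute_value(x):
  return x if x > 0 else neg(x)   # Python control flow on the traced value

# absolute_value(2.)        # fine: ordinary Python evaluation on a concrete value
# jit(absolute_value)(2.)   # would fail at trace time: the comparison result is
#                           # abstract, so Python's `if` can't be resolved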
Historically, we started using the “initial-style” and “final-style” terminology after reading the typed tagless final interpreters paper, and jokingly referring to JAX as an implementation of “untyped tagful final interpreters.” We don’t claim to carry over (or understand) any deep meaning behind these terms; we loosely use “initial style” to mean “build an AST and then transform it”, and we use “final style” to mean “transform as we trace.” But it’s just imprecise yet sticky jargon.
With the initial-style approach, here’s the user-facing jit
wrapper:
def jit(f):
def f_jitted(*args):
avals_in = [raise_to_shaped(get_aval(x)) for x in args]
jaxpr, consts, out_tree = make_jaxpr(f, *avals_in)
outs = bind(xla_call_p, *consts, *args, jaxpr=jaxpr, num_consts=len(consts))
return tree_unflatten(out_tree, outs)
return f_jitted
xla_call_p = Primitive('xla_call')
With any new primitive, we need to give it transformation rules, starting with
its evaluation rule. When we evaluate an application of the xla_call
primitive, we want to stage out the computation to XLA. That involves
translating the jaxpr to an XLA HLO program, transferring the argument values
to the XLA device, executing the XLA program, and transferring back the
results. We’ll cache the XLA HLO compilation so that for each jitted function it only needs to be performed once per argument shape and dtype signature.
First, some utilities.
class IDHashable:
val: Any
def __init__(self, val):
self.val = val
def __hash__(self) -> int:
return id(self.val)
def __eq__(self, other):
return type(other) is IDHashable and id(self.val) == id(other.val)
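As a quick illustration (not from the original): values like numpy arrays aren’t hashable, so IDHashable wraps them by object identity, which is what lets us use them as lru_cache keys below.
# Object identity, not value equality, determines the cache key:
a = np.ones(3)
b = np.ones(3)
assert IDHashable(a) == IDHashable(a)   # same object: equal, same hash
assert IDHashable(a) != IDHashable(b)   # equal values but distinct objects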
Next, we’ll define the evaluation rule for xla_call
:
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
xe = xc._xla
xops = xc._xla.ops
def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int):
consts, args = args[:num_consts], args[num_consts:]
hashable_consts = tuple(map(IDHashable, consts))
execute = xla_callable(IDHashable(jaxpr), hashable_consts)
return execute(*args)
impl_rules[xla_call_p] = xla_call_impl
@lru_cache()
def xla_callable(hashable_jaxpr: IDHashable,
hashable_consts: tuple[IDHashable, ...]):
jaxpr: Jaxpr = hashable_jaxpr.val
typecheck_jaxpr(jaxpr)
consts = [x.val for x in hashable_consts]
in_avals = [v.aval for v in jaxpr.in_binders[len(consts):]]
c = xc.XlaBuilder('xla_call')
xla_consts = _xla_consts(c, consts)
xla_params = _xla_params(c, in_avals)
outs = jaxpr_subcomp(c, jaxpr, xla_consts + xla_params)
out = xops.Tuple(c, outs)
compiled = xb.get_backend(None).compile(
xc._xla.mlir.xla_computation_to_mlir_module(c.build(out)))
return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs])
def _xla_consts(c: xe.XlaBuilder, consts: list[Any]) -> list[xe.XlaOp]:
unique_consts = {id(cnst): cnst for cnst in consts}
xla_consts = {
id_: xops.ConstantLiteral(c, cnst) for id_, cnst in unique_consts.items()}
return [xla_consts[id(cnst)] for cnst in consts]
def _xla_params(c: xe.XlaBuilder, avals_in: list[ShapedArray]) -> list[xe.XlaOp]:
return [xops.Parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)]
def _xla_shape(aval: ShapedArray) -> xe.Shape:
return xc.Shape.array_shape(xc.dtype_to_etype(aval.dtype), aval.shape)
The main action is in xla_callable
, which compiles a jaxpr into an XLA HLO
program using jaxpr_subcomp
, then returns a callable which executes the
compiled program:
def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: list[xe.XlaOp]
) -> list[xe.XlaOp]:
env: dict[Var, xe.XlaOp] = {}
def read(x: Atom) -> xe.XlaOp:
return env[x] if type(x) is Var else xops.Constant(c, np.asarray(x.val))
def write(v: Var, val: xe.XlaOp) -> None:
env[v] = val
map(write, jaxpr.in_binders, args)
for eqn in jaxpr.eqns:
in_avals = [x.aval for x in eqn.inputs]
in_vals = map(read, eqn.inputs)
rule = xla_translations[eqn.primitive]
out_vals = rule(c, in_avals, in_vals, **eqn.params)
map(write, eqn.out_binders, out_vals)
return map(read, jaxpr.outs)
def execute_compiled(compiled, out_avals, *args):
input_bufs = [input_handlers[type(x)](x) for x in args]
out_bufs = compiled.execute(input_bufs)
return [handle_result(aval, buf) for aval, buf in zip(out_avals, out_bufs)]
default_input_handler = xb.get_backend(None).buffer_from_pyval
input_handlers = {ty: default_input_handler for ty in
[bool, int, float, np.ndarray, np.float64, np.float32]}
def handle_result(aval: ShapedArray, buf):
del aval # Unused for now
return np.asarray(buf)
xla_translations = {}
Notice that jaxpr_subcomp
has the structure of a simple interpreter. That’s
a common pattern: the way we process jaxprs is usually with an interpreter.
And as with any interpreter, we need an interpretation rule for each
primitive:
def direct_translation(op, c, in_avals, in_vals):
del c, in_avals
return [op(*in_vals)]
xla_translations[add_p] = partial(direct_translation, xops.Add)
xla_translations[mul_p] = partial(direct_translation, xops.Mul)
xla_translations[neg_p] = partial(direct_translation, xops.Neg)
xla_translations[sin_p] = partial(direct_translation, xops.Sin)
xla_translations[cos_p] = partial(direct_translation, xops.Cos)
xla_translations[greater_p] = partial(direct_translation, xops.Gt)
xla_translations[less_p] = partial(direct_translation, xops.Lt)
def reduce_sum_translation(c, in_avals, in_vals, *, axis):
(x_aval,), (x,) = in_avals, in_vals
zero = xops.ConstantLiteral(c, np.array(0, x_aval.dtype))
subc = xc.XlaBuilder('add')
shape = _xla_shape(ShapedArray((), x_aval.dtype))
xops.Add(xops.Parameter(subc, 0, shape), xops.Parameter(subc, 1, shape))
return [xops.Reduce(c, [x], [zero], subc.build(), axis)]
xla_translations[reduce_sum_p] = reduce_sum_translation
def broadcast_translation(c, in_avals, in_vals, *, shape, axes):
x, = in_vals
dims_complement = [i for i in range(len(shape)) if i not in axes]
return [xops.BroadcastInDim(x, shape, dims_complement)]
xla_translations[broadcast_p] = broadcast_translation
With that, we can now use jit
to stage out, compile, and execute programs
with XLA!
@jit
def f(x, y):
print('tracing!')
return sin(x) * cos(y)
z = f(3., 4.) # 'tracing!' prints the first time
print(z)
tracing!
-0.09224219304455371
z = f(4., 5.) # 'tracing!' doesn't print, compilation cache hit!
print(z)
-0.21467624978306993
@jit
def f(x):
return reduce_sum(x, axis=0)
print(f(np.array([1., 2., 3.])))
6.0
def f(x):
y = sin(x) * 2.
z = - y + x
return z
def deriv(f):
return lambda x: jvp(f, (x,), (1.,))[1]
print( deriv(deriv(f))(3.))
print(jit(deriv(deriv(f)))(3.))
0.2822400161197344
0.2822400161197344
Instead of implementing jit
to first trace to a jaxpr and then to lower the
jaxpr to XLA HLO, it might appear that we could have skipped the jaxpr step
and just lowered to HLO while tracing. That is, perhaps we could have instead
implemented jit
with a Trace
and Tracer
that appended to the XLA HLO
graph incrementally on each primitive bind. That’s correct for now, but won’t
be possible when we introduce compiled SPMD computations because there we must
know the number of replicas needed before compiling the program.
We haven’t yet defined any transformation rules for xla_call_p other than its evaluation rule. That is, we can’t yet do vmap-of-jit or jvp-of-jit or even jit-of-jit. Instead jit has to be at the “top level.” Let’s fix that!
def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):
del num_consts # Unused
new_jaxpr, new_consts = jvp_jaxpr(jaxpr)
outs = bind(xla_call_p, *new_consts, *primals, *tangents, jaxpr=new_jaxpr,
num_consts=len(new_consts))
n = len(outs) // 2
primals_out, tangents_out = outs[:n], outs[n:]
return primals_out, tangents_out
jvp_rules[xla_call_p] = xla_call_jvp_rule
@lru_cache()
def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:
def jvp_traceable(*primals_and_tangents):
n = len(primals_and_tangents) // 2
primals, tangents = primals_and_tangents[:n], primals_and_tangents[n:]
return jvp(jaxpr_as_fun(jaxpr), primals, tangents)
in_avals = [v.aval for v in jaxpr.in_binders]
new_jaxpr, new_consts, _ = make_jaxpr(jvp_traceable, *in_avals, *in_avals)
return new_jaxpr, new_consts
def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):
del num_consts # Unused
new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, tuple(dims_in))
outs = bind(xla_call_p, *new_consts, *vals_in, jaxpr=new_jaxpr,
num_consts=len(new_consts))
return outs, [0] * len(outs)
vmap_rules[xla_call_p] = xla_call_vmap_rule
@lru_cache()
def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...]
) -> tuple[Jaxpr, list[Any]]:
vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))
in_avals = [unmapped_aval(axis_size, d, v.aval)
for v, d in zip(jaxpr.in_binders, bdims_in)]
new_jaxpr, new_consts, _ = make_jaxpr(vmap_traceable, *in_avals)
return new_jaxpr, new_consts
def unmapped_aval(axis_size: int, batch_dim: BatchAxis, aval: ShapedArray
) -> ShapedArray:
if batch_dim is not_mapped:
return aval
else:
shape = list(aval.shape)
shape.insert(batch_dim, axis_size)
return ShapedArray(tuple(shape), aval.dtype)
def xla_call_abstract_eval_rule(*in_types, jaxpr, num_consts):
del num_consts # Unused
jaxpr_type = typecheck_jaxpr(jaxpr)
if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):
raise TypeError
return jaxpr_type.out_types
abstract_eval_rules[xla_call_p] = xla_call_abstract_eval_rule
def xla_call_translation(c, in_avals, in_vals, *, jaxpr, num_consts):
del num_consts # Only used at top-level.
# Calling jaxpr_subcomp directly would inline. We generate a Call HLO instead.
subc = xc.XlaBuilder('inner xla_call')
xla_params = _xla_params(subc, in_avals)
outs = jaxpr_subcomp(subc, jaxpr, xla_params)
subc = subc.build(xops.Tuple(subc, outs))
return destructure_tuple(c, xops.Call(c, subc, in_vals))
xla_translations[xla_call_p] = xla_call_translation
def destructure_tuple(c, tup):
num_elements = len(c.get_shape(tup).tuple_shapes())
return [xops.GetTupleElement(tup, i) for i in range(num_elements)]
@jit
def f(x):
print('tracing!')
y = sin(x) * 2.
z = - y + x
return z
x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
tracing!
2.7177599838802657
2.979984993200891
y, ydot = jvp(f, (x,), (xdot,)) # 'tracing!' not printed
ys = vmap(f, (0,))(np.arange(3.))
print(ys)
[ 0. -0.68294197 0.18140515]
One piece missing is device memory persistence for arrays. That is, we’ve
defined handle_result
to transfer results back to CPU memory as NumPy
arrays, but it’s often preferable to avoid transferring results just to
transfer them back for the next operation. We can do that by introducing an
Array
class, which can wrap XLA buffers and otherwise duck-type
numpy.ndarray
s:
def handle_result(aval: ShapedArray, buf): # noqa: F811
return Array(aval, buf)
class Array:
buf: Any
aval: ShapedArray
def __init__(self, aval, buf):
self.aval = aval
self.buf = buf
dtype = property(lambda self: self.aval.dtype)
shape = property(lambda self: self.aval.shape)
ndim = property(lambda self: self.aval.ndim)
def __array__(self): return np.asarray(self.buf)
def __repr__(self): return repr(np.asarray(self.buf))
def __str__(self): return str(np.asarray(self.buf))
_neg = staticmethod(neg)
_add = staticmethod(add)
_radd = staticmethod(add)
_mul = staticmethod(mul)
_rmul = staticmethod(mul)
_gt = staticmethod(greater)
_lt = staticmethod(less)
input_handlers[Array] = lambda x: x.buf
jax_types.add(Array)
@jit
def f(x):
y = sin(x) * 2.
z = - y + x
return z
x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
2.7177599838802657
2.979984993200891
def pprint_xla_call(names: defaultdict[Var, str], eqn: JaxprEqn) -> PPrint:
lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
params_without_jaxpr = {k:v for k, v in eqn.params.items() if k != 'jaxpr'}
rhs = (pp(eqn.primitive.name) >> pp_params(params_without_jaxpr) >>
pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
for x in eqn.inputs)))
return vcat([lhs >> pp(' = ') >> rhs,
pp_jaxpr(eqn.params['jaxpr']).indent(2)])
pp_rules[xla_call_p] = pprint_xla_call
Part 4: linearize and vjp (and grad!)#
The linearize
and vjp
autodiff functions are built on jvp
, but involve
jaxprs as well. That’s because both involve staging out, or delaying,
computation.
linearize#
In the case of linearize
, we want to stage out the linear part of a jvp
computation. That is, in terms of
Haskell-like type signatures,
if we have jvp : (a -> b) -> (a, T a) -> (b, T b)
,
then we write linearize : (a -> b) -> a -> (b, T a -o T b)
, using T a
to
mean “the tangent type of a
” and using the “lollipop” -o
rather than the
arrow ->
to indicate a linear function. We define the semantics of
linearize
in terms of jvp
too:
y, f_lin = linearize(f, x)
y_dot = f_lin(x_dot)
gives the same result for (y, y_dot)
as
y, y_dot = jvp(f, (x,), (x_dot,))
where the application of f_lin
does not redo any of the linearization work.
We’ll represent the delayed linear part f_lin : T a -o T b
as a jaxpr.
Tangentially, now that we have linear arrows -o
, we can provide a slightly
more informative type for jvp
:
jvp : (a -> b) -> (UnrestrictedUse a, T a) -o (UnrestrictedUse b, T b)
Here we’re writing UnrestrictedUse
just to indicate that we have a special
pair where the first element can be used in an unrestricted (nonlinear) way.
In conjunction with the linear arrow, this notation is just meant to express
that the function jvp f
uses its first input in a nonlinear way but its
second input in a linear way, producing a corresponding nonlinear output
(which can be used in a nonlinear way) paired with a linear output. This more
refined type signature encodes the data dependencies in jvp f
, which are
useful for partial evaluation.
To build the f_lin
jaxpr from a JVP, we need to perform partial evaluation:
we evaluate all the primal values as we trace, but stage the tangent
computations into a jaxpr. This is our second way to build jaxprs. But where
make_jaxpr
and its underlying JaxprTrace
/JaxprTracer
interpreters aim
to stage out every primitive bind, this second approach stages out only those
primitive binds with a data dependence on tangent inputs.
First, some utilities:
def split_half(lst: list[Any]) -> tuple[list[Any], list[Any]]:
assert not len(lst) % 2
return split_list(lst, len(lst) // 2)
def merge_lists(which: list[bool], l1: list[Any], l2: list[Any]) -> list[Any]:
l1, l2 = iter(l1), iter(l2)
out = [next(l2) if b else next(l1) for b in which]
assert next(l1, None) is next(l2, None) is None
return out
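Here’s a small usage sketch of these utilities (assuming split_list from earlier in the document splits a list at an index):
assert split_half([1, 2, 3, 4]) == ([1, 2], [3, 4])
# merge_lists interleaves two lists according to a boolean mask:
assert merge_lists([False, True, False], ['a', 'c'], ['b']) == ['a', 'b', 'c']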
Next, we’ll write linearize
by combining jvp
together with a general
partial evaluation transformation, to be added next:
def linearize_flat(f, *primals_in):
pvals_in = ([PartialVal.known(x) for x in primals_in] +
[PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])
def f_jvp(*primals_tangents_in):
primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))
return [*primals_out, *tangents_out]
jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)
primal_pvals, _ = split_half(pvals_out)
assert all(pval.is_known for pval in primal_pvals)
primals_out = [pval.const for pval in primal_pvals]
f_lin = lambda *tangents: eval_jaxpr(jaxpr, [*consts, *tangents])
return primals_out, f_lin
def linearize(f, *primals_in):
primals_in_flat, in_tree = tree_flatten(primals_in)
f, out_tree = flatten_fun(f, in_tree)
primals_out_flat, f_lin_flat = linearize_flat(f, *primals_in_flat)
primals_out = tree_unflatten(out_tree(), primals_out_flat)
def f_lin(*tangents_in):
tangents_in_flat, in_tree2 = tree_flatten(tangents_in)
if in_tree != in_tree2: raise TypeError
tangents_out_flat = f_lin_flat(*tangents_in_flat)
return tree_unflatten(out_tree(), tangents_out_flat)
return primals_out, f_lin
def vspace(aval: ShapedArray) -> ShapedArray:
return raise_to_shaped(aval) # TODO handle integers?
Now we turn to the general partial evaluation transformation. The goal is to accept a Python callable and a list of inputs, some known and some unknown, and to produce (1) all the outputs which can be computed from the known inputs, together with (2) a jaxpr representing the part of the Python callable’s computation which can only be performed after the remaining inputs are known.
This transformation is tricky to summarize in a type signature. If we
assume the input function’s type signature is (a1, a2) -> (b1, b2)
, where
a1
and a2
represent the known and unknown inputs, respectively, and where
b1
only has a data dependency on a1
while b2
has some data dependency on
a2
, then we might write
partial_eval : ((a1, a2) -> (b1, b2)) -> a1 -> exists r. (b1, r, (r, a2) -> b2)
In words, given values for the inputs of type a1
, partial_eval
produces
the outputs of type b1
along with “residual” values of
existentially-quantified type r
representing the intermediates required to
complete the computation in the second stage. It also produces a function of
type (r, a2) -> b2
which accepts the residual values as well as the
remaining inputs and produces the remaining outputs.
We like to think of partial evaluation as “unzipping” one computation into two. For example, consider this jaxpr:
{ lambda a:float64[] .
let b:float64[] = sin a
c:float64[] = neg b
in ( c ) }
A jaxpr for the JVP would look like:
{ lambda a:float64[] b:float64[] .
let c:float64[] = sin a
d:float64[] = cos a
e:float64[] = mul d b
f:float64[] = neg c
g:float64[] = neg e
in ( f, g ) }
If we imagine applying partial evaluation to this jaxpr with the first input known and the second unknown, we end up ‘unzipping’ the JVP jaxpr into primal and tangent jaxprs:
{ lambda a:float64[] .
let c:float64[] = sin a
d:float64[] = cos a
f:float64[] = neg c
in ( f, d ) }
{ lambda d:float64[] b:float64[] .
let e:float64[] = mul d b
g:float64[] = neg e
in ( g ) }
This second jaxpr represents the linear computation that we want from
linearize
.
However, unlike in this jaxpr example, we want the computation on known values
to occur while evaluating the input Python callable. That is, rather than
forming a jaxpr for the entire function (a1, a2) -> (b1, b2)
, staging all
operations out of Python first before sorting out what can be evaluated now
and what must be delayed, we want only to form a jaxpr for those operations
that must be delayed due to a dependence on unknown inputs. In the context
of automatic differentiation, this is the feature that ultimately enables us
to handle functions like grad(lambda x: x**2 if x > 0 else 0.)
. Python
control flow works because partial evaluation keeps the primal computation in
Python. As a consequence, our Trace
and Tracer
subclasses must sort out on the fly what can be evaluated and what must be staged out into a jaxpr.
First, we start with a PartialVal
class, which represents a value that can
be either known or unknown:
class PartialVal(NamedTuple):
aval: ShapedArray
const: Optional[Any]
@classmethod
def known(cls, val: Any):
return PartialVal(get_aval(val), val)
@classmethod
def unknown(cls, aval: ShapedArray):
return PartialVal(aval, None)
is_known = property(lambda self: self.const is not None)
is_unknown = property(lambda self: self.const is None)
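A tiny usage sketch (not from the original) of the two constructors:
pv_known = PartialVal.known(3.)                                  # carries the value itself
pv_unknown = PartialVal.unknown(raise_to_shaped(get_aval(3.)))   # carries only an aval
assert pv_known.is_known and pv_unknown.is_unknown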
Partial evaluation will take a list of PartialVal
s representing inputs, and
return a list of PartialVal
outputs along with a jaxpr representing the
delayed computation:
def partial_eval_flat(f: Callable, pvals_in: list[PartialVal]
) -> tuple[Jaxpr, list[PartialVal], list[Any]]:
with new_main(PartialEvalTrace) as main:
trace = PartialEvalTrace(main)
tracers_in = [trace.new_arg(pval) for pval in pvals_in]
outs = f(*tracers_in)
tracers_out = [full_raise(trace, out) for out in outs]
pvals_out = [t.pval for t in tracers_out]
unk_tracers_in = [t for t in tracers_in if t.pval.is_unknown]
unk_tracers_out = [t for t in tracers_out if t.pval.is_unknown]
jaxpr, consts = tracers_to_jaxpr(unk_tracers_in, unk_tracers_out)
return jaxpr, pvals_out, consts
Next we need to implement PartialEvalTrace
and its PartialEvalTracer
. This
interpreter will build a jaxpr on the fly while tracking data dependencies. To
do so, it builds a bipartite directed acyclic graph (DAG) between
PartialEvalTracer
nodes, representing staged-out values, and JaxprRecipe
nodes, representing formulas for how to compute some values from others. One
kind of recipe is a JaxprEqnRecipe
, corresponding to a JaxprEqn
’s
primitive application, but we also have recipe types for constants and lambda
binders:
from weakref import ref, ReferenceType
class LambdaBindingRecipe(NamedTuple):
pass
class ConstRecipe(NamedTuple):
val: Any
class JaxprEqnRecipe(NamedTuple):
prim: Primitive
tracers_in: list['PartialEvalTracer']
params: dict[str, Any]
avals_out: list[ShapedArray]
tracer_refs_out: list['ReferenceType[PartialEvalTracer]']
JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]
class PartialEvalTracer(Tracer):
pval: PartialVal
recipe: Optional[JaxprRecipe]
def __init__(self, trace, pval, recipe):
self._trace = trace
self.pval = pval
self.recipe = recipe
aval = property(lambda self: self.pval.aval)
def full_lower(self):
if self.pval.is_known:
return full_lower(self.pval.const)
return self
The PartialEvalTrace
contains the logic for constructing the graph of
JaxprRecipe
s and PartialEvalTracer
s. Each argument corresponds to a
LambdaBindingRecipe
leaf node, and each constant is a ConstRecipe
leaf
node holding a reference to the constant. All other tracers and recipes come
from process_primitive
, which forms tracers with JaxprEqnRecipe
s.
For most primitives, the process_primitive
logic is straightforward: if all
inputs are known then we can bind the primitive on the known values
(evaluating it in Python) and avoid forming tracers corresponding to the
output. If instead any input is unknown, we stage out into a
JaxprEqnRecipe
representing the primitive application. To build the tracers
representing unknown outputs, we need avals, which we get from the abstract
eval rules. (Notice that tracers reference JaxprEqnRecipe
s, and
JaxprEqnRecipe
s reference tracers; we avoid circular garbage by using
weakrefs.)
That process_primitive
logic applies to most primitives, but xla_call_p
requires recursive treatment. So we special-case its rule in a
partial_eval_rules
dict.
class PartialEvalTrace(Trace):
def new_arg(self, pval: PartialVal) -> Any:
return PartialEvalTracer(self, pval, LambdaBindingRecipe())
def lift(self, val: Any) -> PartialEvalTracer:
return PartialEvalTracer(self, PartialVal.known(val), None)
pure = lift
def instantiate_const(self, tracer: PartialEvalTracer) -> PartialEvalTracer:
if tracer.pval.is_unknown:
return tracer
else:
pval = PartialVal.unknown(raise_to_shaped(tracer.aval))
return PartialEvalTracer(self, pval, ConstRecipe(tracer.pval.const))
def process_primitive(self, primitive, tracers, params):
if all(t.pval.is_known for t in tracers):
return bind(primitive, *map(full_lower, tracers), **params)
rule = partial_eval_rules.get(primitive)
if rule: return rule(self, tracers, **params)
tracers_in = [self.instantiate_const(t) for t in tracers]
avals_in = [t.aval for t in tracers_in]
avals_out = abstract_eval_rules[primitive](*avals_in, **params)
tracers_out = [PartialEvalTracer(self, PartialVal.unknown(aval), None)
for aval in avals_out]
eqn = JaxprEqnRecipe(primitive, tracers_in, params, avals_out,
map(ref, tracers_out))
for t in tracers_out: t.recipe = eqn
return tracers_out
partial_eval_rules = {}
Now that we can build graph representations of jaxprs with PartialEvalTrace
,
we need a mechanism to convert the graph representation to a standard jaxpr.
The jaxpr corresponds to a topological sort of the graph.
def tracers_to_jaxpr(tracers_in: list[PartialEvalTracer],
tracers_out: list[PartialEvalTracer]):
tracer_to_var: dict[int, Var] = {id(t): Var(raise_to_shaped(t.aval))
for t in tracers_in}
constvar_to_val: dict[Var, Any] = {}
constid_to_var: dict[int, Var] = {}
processed_eqns: set[int] = set()
eqns: list[JaxprEqn] = []
for t in toposort(tracers_out, tracer_parents):
if isinstance(t.recipe, LambdaBindingRecipe):
assert id(t) in set(map(id, tracers_in))
elif isinstance(t.recipe, ConstRecipe):
val = t.recipe.val
var = constid_to_var.get(id(val))
if var is None:
aval = raise_to_shaped(get_aval(val))
var = constid_to_var[id(val)] = Var(aval)
constvar_to_val[var] = val
tracer_to_var[id(t)] = var
elif isinstance(t.recipe, JaxprEqnRecipe):
if id(t.recipe) not in processed_eqns:
eqns.append(recipe_to_eqn(tracer_to_var, t.recipe))
processed_eqns.add(id(t.recipe))
else:
raise TypeError(t.recipe)
constvars, constvals = unzip2(constvar_to_val.items())
in_binders = constvars + [tracer_to_var[id(t)] for t in tracers_in]
out_vars = [tracer_to_var[id(t)] for t in tracers_out]
jaxpr = Jaxpr(in_binders, eqns, out_vars)
typecheck_jaxpr(jaxpr)
return jaxpr, constvals
def recipe_to_eqn(tracer_to_var: dict[int, Var], recipe: JaxprEqnRecipe
) -> JaxprEqn:
inputs = [tracer_to_var[id(t)] for t in recipe.tracers_in]
out_binders = [Var(aval) for aval in recipe.avals_out]
for t_ref, var in zip(recipe.tracer_refs_out, out_binders):
if t_ref() is not None: tracer_to_var[id(t_ref())] = var
return JaxprEqn(recipe.prim, inputs, recipe.params, out_binders)
def tracer_parents(t: PartialEvalTracer) -> list[PartialEvalTracer]:
return t.recipe.tracers_in if isinstance(t.recipe, JaxprEqnRecipe) else []
def toposort(out_nodes: list[Any], parents: Callable[[Any], list[Any]]):
if not out_nodes: return []
out_nodes = remove_duplicates(out_nodes)
child_counts = {}
stack = list(out_nodes)
while stack:
node = stack.pop()
if id(node) in child_counts:
child_counts[id(node)] += 1
else:
child_counts[id(node)] = 1
stack.extend(parents(node))
for node in out_nodes:
child_counts[id(node)] -= 1
sorted_nodes = []
childless_nodes = [node for node in out_nodes if not child_counts[id(node)]]
while childless_nodes:
node = childless_nodes.pop()
sorted_nodes.append(node)
for parent in parents(node):
if child_counts[id(parent)] == 1:
childless_nodes.append(parent)
else:
child_counts[id(parent)] -= 1
sorted_nodes = sorted_nodes[::-1]
check_toposort(sorted_nodes, parents)
return sorted_nodes
def remove_duplicates(lst):
seen = set()
return [x for x in lst if id(x) not in seen and not seen.add(id(x))]
def check_toposort(nodes: list[Any], parents: Callable[[Any], list[Any]]):
seen = set()
for node in nodes:
assert all(id(parent) in seen for parent in parents(node))
seen.add(id(node))
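As a quick sanity check of toposort, here is a sketch using a hypothetical Node helper (not part of the original) whose parents attribute plays the role of tracer_parents:
class Node:
  def __init__(self, parents): self.parents = parents

a = Node([]); b = Node([a]); c = Node([a, b])
order = toposort([c], lambda n: n.parents)
assert order.index(a) < order.index(b) < order.index(c)   # parents come first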
Now we can linearize!
y, sin_lin = linearize(sin, 3.)
print(y, sin(3.))
print(sin_lin(1.), cos(3.))
0.1411200080598672 0.1411200080598672
-0.9899924966004454 -0.9899924966004454
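Since the delayed part is linear, scaling the tangent input scales the output (a quick check, not in the original):
print(sin_lin(2.), 2. * cos(3.))   # both equal 2 * cos(3)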
To handle linearize
-of-jit
, we still need to write a partial evaluation
rule for xla_call_p
. Other than tracer bookkeeping, the main task is to
perform partial evaluation of a jaxpr, ‘unzipping’ it into two jaxprs.
There are actually two rules to write: one for trace-time partial evaluation,
which we’ll call xla_call_partial_eval
, and one for partial evaluation of
jaxprs, which we’ll call xla_call_peval_eqn
.
def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
del num_consts # Unused
in_unknowns = [not t.pval.is_known for t in tracers]
jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, in_unknowns)
known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)
known_vals = [t.pval.const for t in known_tracers]
outs1_res = bind(xla_call_p, *known_vals, jaxpr=jaxpr1, num_consts=0)
outs1, res = split_list(outs1_res, len(jaxpr1.outs) - num_res)
res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]
outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)
for v in jaxpr2.outs]
eqn = JaxprEqnRecipe(xla_call_p, res_tracers + unknown_tracers,
dict(jaxpr=jaxpr2, num_consts=0),
[v.aval for v in jaxpr2.outs], map(ref, outs2))
for t in outs2: t.recipe = eqn
return merge_lists(out_unknowns, outs1, outs2)
partial_eval_rules[xla_call_p] = xla_call_partial_eval
def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool],
instantiate: Optional[list[bool]] = None,
) -> tuple[Jaxpr, Jaxpr, list[bool], int]:
env: dict[Var, bool] = {}
residuals: set[Var] = set()
def read(x: Atom) -> bool:
return type(x) is Var and env[x]
def write(unk: bool, v: Var) -> None:
env[v] = unk
def new_res(x: Atom) -> Atom:
if type(x) is Var: residuals.add(x)
return x
eqns1, eqns2 = [], []
map(write, in_unknowns, jaxpr.in_binders)
for eqn in jaxpr.eqns:
unks_in = map(read, eqn.inputs)
rule = partial_eval_jaxpr_rules.get(eqn.primitive)
if rule:
eqn1, eqn2, unks_out, res = rule(unks_in, eqn)
eqns1.append(eqn1); eqns2.append(eqn2); residuals.update(res)
map(write, unks_out, eqn.out_binders)
elif any(unks_in):
inputs = [v if unk else new_res(v) for unk, v in zip(unks_in, eqn.inputs)]
eqns2.append(JaxprEqn(eqn.primitive, inputs, eqn.params, eqn.out_binders))
map(partial(write, True), eqn.out_binders)
else:
eqns1.append(eqn)
map(partial(write, False), eqn.out_binders)
out_unknowns = map(read, jaxpr.outs)
if instantiate is not None:
for v, uk, inst in zip(jaxpr.outs, out_unknowns, instantiate):
if inst and not uk: new_res(v)
out_unknowns = map(op.or_, out_unknowns, instantiate)
residuals, num_res = list(residuals), len(residuals)
assert all(type(v) is Var for v in residuals), residuals
ins1, ins2 = partition_list(in_unknowns, jaxpr.in_binders)
outs1, outs2 = partition_list(out_unknowns, jaxpr.outs)
jaxpr1 = Jaxpr(ins1, eqns1, outs1 + residuals)
jaxpr2 = Jaxpr(residuals + ins2, eqns2, outs2)
typecheck_partial_eval_jaxpr(jaxpr, in_unknowns, out_unknowns, jaxpr1, jaxpr2)
return jaxpr1, jaxpr2, out_unknowns, num_res
def typecheck_partial_eval_jaxpr(jaxpr, unks_in, unks_out, jaxpr1, jaxpr2):
jaxprty = typecheck_jaxpr(jaxpr) # (a1, a2) -> (b1, b2 )
jaxpr1ty = typecheck_jaxpr(jaxpr1) # a1 -> (b1, res)
jaxpr2ty = typecheck_jaxpr(jaxpr2) # (res, a2) -> b2
a1, a2 = partition_list(unks_in, jaxprty.in_types)
b1, b2 = partition_list(unks_out, jaxprty.out_types)
b1_, res = split_list(jaxpr1ty.out_types, len(b1))
res_, a2_ = split_list(jaxpr2ty.in_types, len(res))
b2_ = jaxpr2ty.out_types
if jaxpr1ty.in_types != a1: raise TypeError
if jaxpr2ty.out_types != b2: raise TypeError
if b1 != b1_: raise TypeError
if res != res_: raise TypeError
if a2 != a2_: raise TypeError
if b2 != b2_: raise TypeError
partial_eval_jaxpr_rules = {}
def xla_call_peval_eqn(unks_in: list[bool], eqn: JaxprEqn,
) -> tuple[JaxprEqn, JaxprEqn, list[bool], list[Var]]:
jaxpr = eqn.params['jaxpr']
jaxpr1, jaxpr2, unks_out, num_res = partial_eval_jaxpr(jaxpr, unks_in)
ins1, ins2 = partition_list(unks_in, eqn.inputs)
out_binders1, out_binders2 = partition_list(unks_out, eqn.out_binders)
residuals = [Var(v.aval) for v in jaxpr2.in_binders[:num_res]]
eqn1 = JaxprEqn(xla_call_p, ins1, dict(jaxpr=jaxpr1, num_consts=0),
out_binders1 + residuals)
eqn2 = JaxprEqn(xla_call_p, residuals + ins2,
dict(jaxpr=jaxpr2, num_consts=0), out_binders2)
return eqn1, eqn2, unks_out, residuals
partial_eval_jaxpr_rules[xla_call_p] = xla_call_peval_eqn
With that, we can compose linearize
and jit
however we like:
@jit
def f(x):
y = sin(x) * 2.
z = - y + x
return z
y, f_lin = linearize(f, 3.)
y_dot = f_lin(1.)
print(y, y_dot)
2.7177599838802657 2.979984993200891
@jit
def f(x):
y = sin(x) * 2.
z = g(x, y)
return z
@jit
def g(x, y):
return cos(x) + y
y, f_lin = linearize(f, 3.)
y_dot = f_lin(1.)
print(y, y_dot)
-0.7077524804807109 -2.121105001260758
vjp and grad#
The vjp
transformation works a lot like linearize. Its type signature is
analogous:
linearize : (a -> b) -> a -> (b, T a -o T b)
vjp : (a -> b) -> a -> (b, T b -o T a)
The only difference is that we transpose the linear part of the computation
before returning it, so that it goes from type T a -o T b
to type T b -o T a
. That is, we’ll implement vjp
as, essentially,
def vjp(f, x):
y, f_lin = linearize(f, x)
f_vjp = lambda y_bar: transpose(f_lin)(y_bar)
return y, f_vjp
Since we have the linear computation as a jaxpr, not just a Python callable, we can implement the transpose transformation as a jaxpr interpreter.
def vjp_flat(f, *primals_in):
pvals_in = ([PartialVal.known(x) for x in primals_in] +
[PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])
primal_pvals_in, tangent_pvals_in = split_half(pvals_in)
def f_jvp(*primals_tangents_in):
primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))
return [*primals_out, *tangents_out]
jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in) # linearize
primal_pvals, _ = split_half(pvals_out)
assert all(pval.is_known for pval in primal_pvals)
primals_out = [pval.const for pval in primal_pvals]
transpose_inputs = consts + [UndefPrimal(p.aval) for p in tangent_pvals_in]
f_vjp = lambda *cts: eval_jaxpr_transposed(jaxpr, transpose_inputs, cts)
return primals_out, f_vjp
def vjp(f, *primals_in):
primals_in_flat, in_tree = tree_flatten(primals_in)
f, out_tree = flatten_fun(f, in_tree)
primals_out_flat, f_vjp_flat = vjp_flat(f, *primals_in_flat)
primals_out = tree_unflatten(out_tree(), primals_out_flat)
def f_vjp(*cotangents_out):
cotangents_out_flat, _ = tree_flatten(cotangents_out)
cotangents_in_flat = f_vjp_flat(*cotangents_out_flat)
return tree_unflatten(in_tree, cotangents_in_flat)
return primals_out, f_vjp
class UndefPrimal(NamedTuple):
aval: ShapedArray
register_pytree_node(UndefPrimal,
lambda u: (u.aval, ()),
lambda aval, _: UndefPrimal(aval))
We use UndefPrimal instances to indicate the arguments with respect to which we want to transpose. These arise because in general, being explicit
about closed-over values, we want to transpose functions of type
a -> b -o c
to functions of type a -> c -o b
. Even more generally, the
inputs with respect to which the function is linear could be scattered through
the argument list. So we indicate the linear positions using UndefPrimal
.
We register UndefPrimal
as a pytree node because the pytree mechanism gives
a handy way to prune these placeholders out of argument lists.
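For instance (a sketch, not from the original), flattening an argument list containing an UndefPrimal placeholder yields only the concrete values as leaves, because UndefPrimal was registered with no children:
args = [UndefPrimal(raise_to_shaped(get_aval(1.))), 2., 3.]
leaves, tree = tree_flatten(args)
assert leaves == [2., 3.]   # the placeholder contributes no leaves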
Next, we can write eval_jaxpr_transposed
, along with transpose rules for
all primitives which can be linear in at least one argument:
# NB: the analogous function in JAX is called 'backward_pass'
def eval_jaxpr_transposed(jaxpr: Jaxpr, args: list[Any], cotangents: list[Any]
) -> list[Any]:
primal_env: dict[Var, Any] = {}
ct_env: dict[Var, Any] = {}
def read_primal(x: Atom) -> Any:
return primal_env.get(x, UndefPrimal(x.aval)) if type(x) is Var else x.val
def write_primal(v: Var, val: Any) -> None:
if type(val) is not UndefPrimal:
primal_env[v] = val
def read_cotangent(v: Var) -> Any:
return ct_env.pop(v, np.zeros(v.aval.shape, v.aval.dtype))
def write_cotangent(x: Atom, val: Any):
if type(x) is Var and val is not None:
ct_env[x] = add(ct_env[x], val) if x in ct_env else val
map(write_primal, jaxpr.in_binders, args)
map(write_cotangent, jaxpr.outs, cotangents)
for eqn in jaxpr.eqns[::-1]:
primals_in = map(read_primal, eqn.inputs)
cts_in = map(read_cotangent, eqn.out_binders)
rule = transpose_rules[eqn.primitive]
cts_out = rule(cts_in, *primals_in, **eqn.params)
map(write_cotangent, eqn.inputs, cts_out)
return [read_cotangent(v) for v, x in zip(jaxpr.in_binders, args)
if type(x) is UndefPrimal]
transpose_rules = {}
def mul_transpose_rule(cts, x, y):
z_bar, = cts
assert (type(x) is UndefPrimal) ^ (type(y) is UndefPrimal)
return [mul(z_bar, y), None] if type(x) is UndefPrimal else [None, mul(x, z_bar)]
transpose_rules[mul_p] = mul_transpose_rule
def neg_transpose_rule(cts, x):
ybar, = cts
assert type(x) is UndefPrimal
return [neg(ybar)]
transpose_rules[neg_p] = neg_transpose_rule
def add_transpose_rule(cts, x, y):
z_bar, = cts
return [z_bar, z_bar]
transpose_rules[add_p] = add_transpose_rule
def reduce_sum_transpose_rule(cts, x, *, axis):
y_bar, = cts
return [broadcast(y_bar, x.aval.shape, axis)]
transpose_rules[reduce_sum_p] = reduce_sum_transpose_rule
def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):
del num_consts # Unused
undef_primals = [type(x) is UndefPrimal for x in invals]
transposed_jaxpr, new_consts = transpose_jaxpr(jaxpr, tuple(undef_primals))
residuals, _ = partition_list(undef_primals, invals)
outs = bind(xla_call_p, *new_consts, *residuals, *cts,
jaxpr=transposed_jaxpr, num_consts=len(new_consts))
outs = iter(outs)
return [next(outs) if undef else None for undef in undef_primals]
transpose_rules[xla_call_p] = xla_call_transpose_rule
@lru_cache()
def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...]
) -> tuple[Jaxpr, list[Any]]:
avals_in, avals_out = typecheck_jaxpr(jaxpr)
traceable = partial(eval_jaxpr_transposed, jaxpr)
args = [UndefPrimal(a) if u else a for a, u in zip(avals_in, undef_primals)]
trans_jaxpr, consts, _ = make_jaxpr(traceable, tuple(args), tuple(avals_out))
typecheck_jaxpr(trans_jaxpr)
return trans_jaxpr, consts
Now that we can linearize and transpose, we can finally write grad
:
def grad(f):
def gradfun(x, *xs):
y, f_vjp = vjp(f, x, *xs)
if np.shape(y) != (): raise TypeError
x_bar, *_ = f_vjp(np.ones(np.shape(y), np.result_type(y)))
return x_bar
return gradfun
y, f_vjp = vjp(sin, 3.)
print(f_vjp(1.), cos(3.))
(-0.9899924966004454,) -0.9899924966004454
def f(x):
y = sin(x) * 2.
z = - y + x
return z
print(grad(f)(3.))
2.979984993200891
@jit
def f(x):
y = x * 2.
z = g(y)
return z
@jit
def g(x):
return cos(x) * 2.
print(grad(f)(3.))
1.1176619927957034
Here’s something of a compositionality stress test:
# from core_test.py fun_with_nested_calls_2
def foo(x):
@jit
def bar(y):
def baz(w):
q = jit(lambda x: y)(x)
q = q + jit(lambda: y)()
q = q + jit(lambda y: w + y)(y)
q = jit(lambda w: jit(sin)(x) * y)(1.0) + q
return q
p, t = jvp(baz, (x + 1.0,), (y,))
return t + (x * p)
return bar(x)
def assert_allclose(*vals):
for v1, v2 in zip(vals[:-1], vals[1:]):
np.testing.assert_allclose(v1, v2)
ans1 = f(3.)
ans2 = jit(f)(3.)
ans3, _ = jvp(f, (3.,), (5.,))
ans4, _ = jvp(jit(f), (3.,), (5.,))
assert_allclose(ans1, ans2, ans3, ans4)
deriv1 = grad(f)(3.)
deriv2 = grad(jit(f))(3.)
deriv3 = jit(grad(jit(f)))(3.)
_, deriv4 = jvp(f, (3.,), (1.,))
_, deriv5 = jvp(jit(f), (3.,), (1.,))
assert_allclose(deriv1, deriv2, deriv3, deriv4, deriv5)
hess1 = grad(grad(f))(3.)
hess2 = grad(grad(jit(f)))(3.)
hess3 = grad(jit(grad(f)))(3.)
hess4 = jit(grad(grad(f)))(3.)
_, hess5 = jvp(grad(f), (3.,), (1.,))
_, hess6 = jvp(jit(grad(f)), (3.,), (1.,))
_, hess7 = jvp(jit(grad(f)), (3.,), (1.,))
assert_allclose(hess1, hess2, hess3, hess4, hess5, hess6, hess7)
Part 5: the control flow primitives cond#
Next we’ll add higher-order primitives for staged-out control flow. These
resemble jit
from Part 3, another higher-order primitive, but differ in that
they are parameterized by multiple callables rather than just one.
Adding cond#
We introduce a cond
primitive to represent conditional application of one
function or another inside a jaxpr. We write the type of cond
as
Bool -> (a -> b) -> (a -> b) -> a -> b
. In words, cond
takes a boolean
representing the predicate and two functions of equal types. Depending on the
value of the predicate, it applies one function or the other to its final
argument.
In Python, we represent it as a function which itself takes two functions as
arguments. As with jit
, the first step is to call make_jaxpr
on its
callable arguments to turn them into jaxprs:
def cond(pred, true_fn, false_fn, *operands):
avals_in = [raise_to_shaped(get_aval(x)) for x in operands]
true_jaxpr, true_consts, out_tree = make_jaxpr(true_fn, *avals_in)
false_jaxpr, false_consts, out_tree_ = make_jaxpr(false_fn, *avals_in)
if out_tree != out_tree_: raise TypeError
true_jaxpr, false_jaxpr = _join_jaxpr_consts(
true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
if typecheck_jaxpr(true_jaxpr) != typecheck_jaxpr(false_jaxpr):
raise TypeError
outs = bind_cond(pred, *true_consts, *false_consts, *operands,
true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
return tree_unflatten(out_tree, outs)
cond_p = Primitive('cond')
def _join_jaxpr_consts(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int
) -> tuple[Jaxpr, Jaxpr]:
jaxpr1_type, jaxpr2_type = typecheck_jaxpr(jaxpr1), typecheck_jaxpr(jaxpr2)
assert jaxpr1_type.in_types[n1:] == jaxpr2_type.in_types[n2:]
consts1, rest1 = split_list(jaxpr1.in_binders, n1)
consts2, rest2 = split_list(jaxpr2.in_binders, n2)
new_jaxpr1 = Jaxpr(consts1 + consts2 + rest1, jaxpr1.eqns, jaxpr1.outs)
new_jaxpr2 = Jaxpr(consts1 + consts2 + rest2, jaxpr2.eqns, jaxpr2.outs)
return new_jaxpr1, new_jaxpr2
def bind_cond(pred, *args, true_jaxpr, false_jaxpr):
assert len(args) == len(true_jaxpr.in_binders) == len(false_jaxpr.in_binders)
return bind(cond_p, pred, *args, true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
We require true_jaxpr
and false_jaxpr
to have the same type, but because
they might close over different constants (and because jaxprs can only
represent closed terms, i.e. can’t have free variables and are instead
closure-converted) we need to use the helper _join_jaxpr_consts
to make
consistent the input binder lists of the two jaxprs. (To be more economical we
could try to identify pairs of constants with the same shapes, but instead we
just concatenate the lists of constants.)
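Schematically, the joining works by concatenating the constant binders of both jaxprs onto each (an illustration in comments only, not meant to be run):
# Before joining (n1 = 1 constant in the first jaxpr, n2 = 2 in the second):
#   jaxpr1: { lambda c1       x . ... }
#   jaxpr2: { lambda c2 c3    x . ... }
# After _join_jaxpr_consts(jaxpr1, jaxpr2, 1, 2), both accept the same binders:
#   jaxpr1': { lambda c1 c2 c3 x . ... }   # c2, c3 unused by jaxpr1's body
#   jaxpr2': { lambda c1 c2 c3 x . ... }   # c1 unused by jaxpr2's body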
Next we can turn to adding interpreter rules for cond
. Its evaluation rule
is simple:
def cond_impl(pred, *operands, true_jaxpr, false_jaxpr):
if pred:
return eval_jaxpr(true_jaxpr, operands)
else:
return eval_jaxpr(false_jaxpr, operands)
impl_rules[cond_p] = cond_impl
out = cond(True, lambda: 3, lambda: 4)
print(out)
3
For its JVP and vmap rules, we only need to call the same jvp_jaxpr
and
vmap_jaxpr
utilities we created for jit
, followed by another pass of
_join_jaxpr_consts
:
def cond_jvp_rule(primals, tangents, *, true_jaxpr, false_jaxpr):
pred, *primals = primals
_ , *tangents = tangents
true_jaxpr , true_consts = jvp_jaxpr(true_jaxpr)
false_jaxpr, false_consts = jvp_jaxpr(false_jaxpr)
true_jaxpr, false_jaxpr = _join_jaxpr_consts(
true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
assert typecheck_jaxpr(true_jaxpr) == typecheck_jaxpr(false_jaxpr)
outs = bind_cond(pred, *true_consts, *false_consts, *primals, *tangents,
true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
primals_out, tangents_out = split_half(outs)
return primals_out, tangents_out
jvp_rules[cond_p] = cond_jvp_rule
out, out_tan = jvp(lambda x: cond(True, lambda: x * x, lambda: 0.), (1.,), (1.,))
print(out_tan)
2.0
def cond_vmap_rule(axis_size, vals_in, dims_in, *, true_jaxpr, false_jaxpr):
pred , *vals_in = vals_in
pred_dim, *dims_in = dims_in
if pred_dim is not not_mapped: raise NotImplementedError # TODO
true_jaxpr, true_consts = vmap_jaxpr(true_jaxpr, axis_size, tuple(dims_in))
false_jaxpr, false_consts = vmap_jaxpr(false_jaxpr, axis_size, tuple(dims_in))
true_jaxpr, false_jaxpr = _join_jaxpr_consts(
true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
assert typecheck_jaxpr(true_jaxpr) == typecheck_jaxpr(false_jaxpr)
outs = bind_cond(pred, *true_consts, *false_consts, *vals_in,
true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
return outs, [0] * len(outs)
vmap_rules[cond_p] = cond_vmap_rule
xs = np.array([1., 2., 3])
out = vmap(lambda x: cond(True, lambda: x + 1., lambda: 0.), (0,))(xs)
print(out)
[2. 3. 4.]
Notice that we’re not currently supporting the case where the predicate value
itself is batched. In mainline JAX, we handle this case by transforming the
conditional to a select primitive.
That transformation is semantically correct so long as true_fun
and
false_fun
do not involve any side-effecting primitives.
Another thing not represented here, but present in the mainline JAX, is that
applying transformations to two jaxprs of equal type might result in jaxprs of
different types. For example, applying the mainline JAX version of
vmap_jaxpr
to the identity-function jaxpr
{ lambda a:float32[] .
let
in ( a ) }
would result in a jaxpr with a batched output, of type
[float32[10]] -> [float32[10]]
if the batch size were 10, while applying it
to the zero-function jaxpr
{ lambda a:float32[] .
let
in ( 0. ) }
would result in a jaxpr with an unbatched output, of type
[float32[10]] -> [float32[]]
. This is an optimization, aimed at not batching
values unnecessarily. But it means that in cond
we’d need an extra step of
joining the two transformed jaxprs to have consistent output types. We don’t
need this step here because we chose vmap_jaxpr
always to batch all outputs
over the leading axis.
Next we can turn to abstract evaluation and XLA lowering rules:
def cond_abstract_eval(pred_type, *in_types, true_jaxpr, false_jaxpr):
if pred_type != ShapedArray((), np.dtype('bool')): raise TypeError
jaxpr_type = typecheck_jaxpr(true_jaxpr)
if jaxpr_type != typecheck_jaxpr(false_jaxpr):
raise TypeError
if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):
raise TypeError
return jaxpr_type.out_types
abstract_eval_rules[cond_p] = cond_abstract_eval
def cond_translation(c, in_avals, in_vals, *, true_jaxpr, false_jaxpr):
del in_avals # Unused
pred, *in_vals = in_vals
flat_vals, in_tree = tree_flatten(in_vals)
operand = xops.Tuple(c, flat_vals)
operand_shape = c.get_shape(operand)
def make_comp(name: str, jaxpr: Jaxpr) -> xe.XlaComputation:
c = xc.XlaBuilder(name)
operand = xops.Parameter(c, 0, operand_shape)
operands = tree_unflatten(in_tree, destructure_tuple(c, operand))
outs = jaxpr_subcomp(c, jaxpr, operands)
return c.build(xops.Tuple(c, outs))
true_comp = make_comp('true_fn', true_jaxpr)
false_comp = make_comp('false_fn', false_jaxpr)
int_etype = xc.dtype_to_etype(np.dtype('int32'))
out = xops.Conditional(xops.ConvertElementType(pred, int_etype),
[false_comp, true_comp], [operand] * 2)
return destructure_tuple(c, out)
xla_translations[cond_p] = cond_translation
out = jit(lambda: cond(False, lambda: 1, lambda: 2))()
print(out)
2
Finally, to support reverse-mode automatic differentiation, we need partial
evaluation and transposition rules. For partial evaluation, we need to
introduce another jaxpr-munging utility, _join_jaxpr_res
, to handle the fact
that applying partial evaluation to true_fun
and false_fun
will in general
result in distinct residuals. We use _join_jaxpr_res
to make the output
types of the transformed jaxprs consistent (while _join_jaxpr_consts
dealt
with input types).
def cond_partial_eval(trace, tracers, *, true_jaxpr, false_jaxpr):
pred_tracer, *tracers = tracers
assert pred_tracer.pval.is_known
pred = pred_tracer.pval.const
in_uks = [not t.pval.is_known for t in tracers]
*jaxprs, out_uks, num_res = _cond_partial_eval(true_jaxpr, false_jaxpr, in_uks)
t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2 = jaxprs
known_tracers, unknown_tracers = partition_list(in_uks, tracers)
known_vals = [t.pval.const for t in known_tracers]
outs1_res = bind_cond(pred, *known_vals,
true_jaxpr=t_jaxpr1, false_jaxpr=f_jaxpr1)
outs1, res = split_list(outs1_res, len(outs1_res) - num_res)
pred_tracer_ = trace.instantiate_const(full_raise(trace, pred_tracer))
res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]
outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)
for v in t_jaxpr2.outs]
eqn = JaxprEqnRecipe(cond_p, [pred_tracer_, *res_tracers, *unknown_tracers],
dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),
[v.aval for v in t_jaxpr2.outs], map(ref, outs2))
for t in outs2: t.recipe = eqn
return merge_lists(out_uks, outs1, outs2)
partial_eval_rules[cond_p] = cond_partial_eval
def _cond_partial_eval(true_jaxpr: Jaxpr, false_jaxpr: Jaxpr, in_uks: list[bool]
) -> tuple[Jaxpr, Jaxpr, Jaxpr, Jaxpr, list[bool], int]:
_, _, t_out_uks, _ = partial_eval_jaxpr(true_jaxpr , in_uks)
_, _, f_out_uks, _ = partial_eval_jaxpr(false_jaxpr, in_uks)
out_uks = map(op.or_, t_out_uks, f_out_uks)
t_jaxpr1, t_jaxpr2, _, t_nres = partial_eval_jaxpr(true_jaxpr , in_uks, out_uks)
f_jaxpr1, f_jaxpr2, _, f_nres = partial_eval_jaxpr(false_jaxpr, in_uks, out_uks)
t_jaxpr1, f_jaxpr1 = _join_jaxpr_res(t_jaxpr1, f_jaxpr1, t_nres, f_nres)
t_jaxpr2, f_jaxpr2 = _join_jaxpr_consts(t_jaxpr2, f_jaxpr2, t_nres, f_nres)
assert typecheck_jaxpr(t_jaxpr1) == typecheck_jaxpr(f_jaxpr1)
assert typecheck_jaxpr(t_jaxpr2) == typecheck_jaxpr(f_jaxpr2)
num_res = t_nres + f_nres
return t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2, out_uks, num_res
def _join_jaxpr_res(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int
) -> tuple[Jaxpr, Jaxpr]:
jaxpr1_type, jaxpr2_type = typecheck_jaxpr(jaxpr1), typecheck_jaxpr(jaxpr2)
out_types1, _ = split_list(jaxpr1_type.out_types, len(jaxpr1.outs) - n1)
out_types2, _ = split_list(jaxpr2_type.out_types, len(jaxpr2.outs) - n2)
assert out_types1 == out_types2
outs1, res1 = split_list(jaxpr1.outs, len(jaxpr1.outs) - n1)
outs2, res2 = split_list(jaxpr2.outs, len(jaxpr2.outs) - n2)
zeros_like1 = [Lit(np.zeros(v.aval.shape, v.aval.dtype)) for v in res1]
zeros_like2 = [Lit(np.zeros(v.aval.shape, v.aval.dtype)) for v in res2]
new_jaxpr1 = Jaxpr(jaxpr1.in_binders, jaxpr1.eqns, outs1 + res1 + zeros_like2)
new_jaxpr2 = Jaxpr(jaxpr2.in_binders, jaxpr2.eqns, outs2 + zeros_like1 + res2)
return new_jaxpr1, new_jaxpr2
_, f_lin = linearize(lambda x: cond(True, lambda: x, lambda: 0.), 1.)
out = f_lin(3.14)
print(out)
3.14
def cond_peval_eqn(unks_in: list[bool], eqn: JaxprEqn,
) -> tuple[JaxprEqn, JaxprEqn, list[bool], list[Atom]]:
pred_unk, *unks_in = unks_in
assert not pred_unk
true_jaxpr, false_jaxpr = eqn.params['true_jaxpr'], eqn.params['false_jaxpr']
*jaxprs, unks_out, num_res = _cond_partial_eval(true_jaxpr, false_jaxpr, unks_in)
t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2 = jaxprs
ins1, ins2 = partition_list(unks_in, eqn.inputs[1:])
outs1, outs2 = partition_list(unks_out, eqn.out_binders)
residuals, _ = split_list(t_jaxpr2.in_binders, num_res)
eqn1 = JaxprEqn(cond_p, [eqn.inputs[0], *ins1],
dict(true_jaxpr=t_jaxpr1, false_jaxpr=f_jaxpr1),
outs1 + residuals)
eqn2 = JaxprEqn(cond_p, [eqn.inputs[0], *residuals, *ins2],
dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),
outs2)
res = [eqn.inputs[0], *residuals] if type(eqn.inputs[0]) is Var else residuals
return eqn1, eqn2, unks_out, res
partial_eval_jaxpr_rules[cond_p] = cond_peval_eqn
_, f_lin = linearize(jit(lambda x: cond(True, lambda: x, lambda: 0.)), 1.)
out = f_lin(3.14)
print(out)
3.14
Transposition is a fairly straightforward application of transpose_jaxpr
:
def cond_transpose_rule(cts, pred, *invals, true_jaxpr, false_jaxpr):
undef_primals = tuple(type(x) is UndefPrimal for x in invals)
true_jaxpr, true_consts = transpose_jaxpr(true_jaxpr, undef_primals)
false_jaxpr, false_consts = transpose_jaxpr(false_jaxpr, undef_primals)
true_jaxpr, false_jaxpr = _join_jaxpr_consts(
true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
res = [x for x in invals if type(x) is not UndefPrimal]
outs = bind_cond(pred, *true_consts, *false_consts, *res, *cts,
true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
outs = iter(outs)
return [None] + [next(outs) if type(x) is UndefPrimal else None for x in invals]
transpose_rules[cond_p] = cond_transpose_rule
out = grad(lambda x: cond(True, lambda: x * x, lambda: 0.))(1.)
print(out)
2.0
def pprint_cond(names: defaultdict[Var, str], eqn: JaxprEqn) -> PPrint:
true_jaxpr, false_jaxpr = eqn.params['true_jaxpr'], eqn.params['false_jaxpr']
new_params = {k:v for k, v in eqn.params.items() if not k.endswith('jaxpr')}
lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
rhs = (pp(eqn.primitive.name) >> pp_params(new_params) >>
pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
for x in eqn.inputs)))
return vcat([lhs >> pp(' = ') >> rhs,
pp_jaxpr(true_jaxpr).indent(2),
pp_jaxpr(false_jaxpr).indent(2)])
pp_rules[cond_p] = pprint_cond
JAX Enhancement Proposals (JEPs)#
Most changes can be discussed with simple issues/discussions and pull requests.
Some changes, though, are larger in scope or require more discussion, and these should be implemented as a JEP. This allows for writing longer documents that can themselves be discussed in a pull request.
The structure of JEPs is kept as lightweight as possible to start and might be extended later on.
When you should use a JEP#
When your change requires a design doc. We prefer collecting the designs as JEPs for better discoverability and further reference.
When your change requires extensive discussion. It’s fine to have relatively short discussions on issues or pull requests, but when the discussion gets longer it becomes impractical to digest later. JEPs allow updating the main document with a summary of the discussion, and these updates can themselves be discussed in the pull request that adds the JEP.
How to start a JEP#
First, create an issue with the JEP label. All pull requests that relate to the JEP (i.e. adding the JEP itself as well as any implementing pull requests) should be linked to this issue.
Then create a pull request that adds a file named %d-{short-title}.md - with the number being the issue number.
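For example (hypothetical numbers and title), a JEP tracked in issue 1234 and titled “Sharded arrays” would be added as a file named 1234-sharded-arrays.md.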
JAX PRNG Design#
We want a PRNG design that
is expressive in that it is convenient to use and it doesn’t constrain the user’s ability to write numerical programs with exactly the behavior that they want,
enables reproducible program execution in a backend-independent way,
has semantics that are invariant to @jit compilation boundaries and device backends,
enables vectorization for generating array values using SIMD hardware,
is parallelizable in that it doesn’t add sequencing constraints between random function calls that otherwise would have no data dependence,
scales to multi-replica, multi-core, and distributed computation,
fits with JAX and XLA semantics and design philosophies (which are ultimately motivated by other practical concerns).
As a corollary of these we believe the design should be functional. Another corollary is that, at least given current hardware constraints, we’re going to do the PRNG in software.
TLDR JAX PRNG = Threefry counter PRNG + a functional array-oriented splitting model
Contents#
Three programming models and toy example programs#
Here’s a toy example of a stateful global PRNG like the one often used in Numpy programs:
def foo(): return bar() + baz()
def bar(): return rand(RNG, (3, 4))
def baz(): return rand(RNG, (3, 4))
def main():
global RNG
RNG = RandomState(0)
return foo()
To achieve reproducibility here we would need to control the order of evaluation for bar() and baz() even though there is no explicit data dependence from one to the other. This kind of sequencing requirement stemming from reproducibility (#2) violates parallelizability (#5) and doesn’t fit with JAX or XLA’s functional semantics (#7) in which subexpressions can be evaluated in any order. Even if we didn’t require reproducibility and thus allowed any evaluation order, parallelization across calls (#5) would still be made difficult by the need to update shared state. Moreover, because the same PRNG state would need to be accessed and maintained in both Python and any compiled code, this model would likely lead to engineering challenges to achieve compilation invariance (#3) and scaling to multiple replicas (#6). Finally, the expressiveness is limited (#1) because there is no way for foo() to call bar() or baz() without affecting its own (implicit) PRNG state.
Whether the model supports vectorization (#4) depends on some additional details. In Numpy, PRNG vectorization is limited by a sequential-equivalent guarantee:
In [1]: rng = np.random.RandomState(0)
In [2]: rng.randn(2)
Out[2]: array([1.76405235, 0.40015721])
In [3]: rng = np.random.RandomState(0)
In [4]: np.stack([rng.randn() for _ in range(2)])
Out[4]: array([1.76405235, 0.40015721])
To allow for vectorization (#4) within primitive PRNG function calls that generate arrays (e.g. to rand() with a shape argument), we drop this sequential-equivalent guarantee. This vectorization can be supported by any of the three programming models discussed in this section, though it motivates the implementation in terms of a counter-based PRNG as described in the next section.
The stateful PRNG user programming model is not promising. Here’s an example of a functional model, but one lacking a key ingredient that we call splitting:
def foo(rng_1):
y, rng_2 = baz(rng_1)
z, rng_3 = bar(rng_2)
return y + z, rng_3
def bar(rng):
    val, new_rng = rand(rng, (3, 4))
    return val, new_rng
def baz(rng):
    val, new_rng = rand(rng, (3, 4))
    return val, new_rng
def main():
foo(RandomState(0))
This model explicitly threads the PRNG state through all functions (primitive or non-primitive) that generate random values: that is, every random function must both accept and return the state. Now there is an explicit data dependence between the call to baz() and the call to bar() in foo(), so the data flow (and hence sequencing) is made explicit and fits with JAX’s existing semantics (#7), unlike in the previous model. This explicit threading can also make the semantics invariant to compilation boundaries (#3).
Explicit threading is inconvenient for the programmer. But worse, it hasn’t actually improved the expressiveness (#1): there is still no way for foo() to call into bar() or baz() while maintaining its own PRNG state. Without knowledge of their callers or the subroutines they call, functions must defensively pass in and return the rng state everywhere. Moreover, it also doesn’t improve the prospects for parallelization (#5) or scaling to multiple replicas (#6) because everything is still sequential, even if the sequencing is made explicit in the functional programming sense.
In short, making the code functional by explicitly threading state isn’t enough to achieve our expressiveness (#1) and performance (#5, #6) goals.
The key problem in both the previous models is that there’s too much sequencing. To reduce the amount of sequential dependence we use functional splittable PRNGs. Splitting is a mechanism to ‘fork’ a new PRNG state into two PRNG states while maintaining the usual desirable PRNG properties (the two new streams are computationally parallelizable and produce independent random values, i.e. they behave like multistreams).
def foo(rng_1):
rng_2, rng_3 = split(rng_1, 2)
return bar(rng_2) + baz(rng_3)
def bar(rng):
    return rand(rng, (3, 4))
def baz(rng):
    return rand(rng, (3, 4))
def main():
foo(RandomState(0))
Some points to notice:
1. there is no sequential dependence between the calls to bar() and baz(), and they can be evaluated in either order without affecting the value of the result, which solves the remaining performance goals (#5, #6),
2. functions do not need to return updated versions of PRNGs, and it is straightforward to call a random subroutine without affecting existing PRNG states, improving the expressiveness (#1) over the other functional model.
The example doesn’t show it, but as a consequence of the choice (2) the only way to advance the PRNG state is to call split(). That is, we have two ways to achieve (1), and they differ in whether they burden the user program with explicit calls to split(), as in the above example, or instead burden the user program with explicit threading. We prefer the former, i.e. the version with explicit splitting, because we can easily implement the explicit-threading version in terms of it.
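For reference, here is a minimal sketch of the same splitting model written against today’s jax.random API (assuming a recent JAX version where jax.random.key is available; bar and baz are just the toy names from above):

import jax

def bar(key):
    return jax.random.normal(key, (3, 4))

def baz(key):
    return jax.random.normal(key, (3, 4))

def foo(key):
    key_bar, key_baz = jax.random.split(key)   # fork two independent keys
    return bar(key_bar) + baz(key_baz)

out = foo(jax.random.key(0))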
Design#
We can use the counter-based PRNG design, and in particular the Threefry hash function, as described in Parallel random numbers: as easy as 1, 2, 3. We use the counter to achieve efficient vectorization: for a given key we can generate an array of values in a vectorized fashion by mapping the hash function over a range of integers [k + 1, …, k + sample_size]. We use the key together with the hash function to implement splittable PRNGs: that is, splitting is a way to generate two new keys from an existing one.
type Sample = Int256
type Key = Sample -- important identification for splitting
type Count = Int32
hash :: Key -> Count -> Int256 -- output type equal to Key and Sample
split :: Key -> (Key, Key)
split key = (hash key 0, hash key 1)
draw_samples :: Key -> Int -> [Sample]
draw_samples key n = map (hash key) [1..n]
Surprisingly, drawing a sample is very similar to splitting! The essential difference is in how the output is used (even though the types are identified): in one case the value is used to form random samples of interest (e.g. turning random bits into a Float representing a random normal), while in the other case the value is used as a key for further hashing.
The asymmetry in the hash function arguments, of type Key and Count, is that the latter is trivial and computationally cheap to advance by an arbitrary amount, since we just need to increase the integer value, while the former is only advanced by hashing. That’s why we use the count argument for vectorization.
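In the user-facing jax.random API, jax.random.fold_in plays a role loosely analogous to hashing a key together with an integer: it derives a new key from an existing key and an int, which is a cheap way to get per-iteration keys. A small illustration (not part of the original design text):

import jax

key = jax.random.key(0)
# derive one key per loop index by folding the index into the base key
step_keys = [jax.random.fold_in(key, i) for i in range(3)]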
More realistic example user programs#
Here’s what a training loop on the host might look like when the step requires a PRNG (maybe for dropout or for VAE training):
rng = lax.rng.new_rng()
for i in range(num_steps):
rng, rng_input = lax.rng.split(rng)
params = compiled_update(rng_input, params, next(batches))
Notice that we’re burdening the user with explicit splitting of the rng, but the rng does not need to be returned from the code at all.
Here’s how we can use this PRNG model with the stax neural net builder library to implement dropout:
def Dropout(rate, mode='train'):
def init_fun(input_shape):
return input_shape, ()
def apply_fun(rng, params, inputs):
if mode == 'train':
keep = lax.random.bernoulli(rng, rate, inputs.shape)
return np.where(keep, inputs / rate, 0)
else:
return inputs
return init_fun, apply_fun
The rng value here is just the key used for the hash, not a special object. The rng argument is passed to every apply_fun, and so it needs to be handled in the serial and parallel combinators with splitting:
def serial(*layers):
init_funs, apply_funs = zip(*layers)
def init_fun(input_shape):
...
def apply_fun(rng, params, inputs):
rngs = split(rng, len(layers))
for rng, param, apply_fun in zip(rngs, params, apply_funs):
inputs = apply_fun(rng, param, inputs)
return inputs
return init_fun, apply_fun
def parallel(*layers):
init_funs, apply_funs = zip(*layers)
def init_fun(input_shape):
...
def apply_fun(rng, params, inputs):
rngs = split(rng, len(layers))
return [f(r, p, x) for f, r, p, x in zip(apply_funs, rngs, params, inputs)]
return init_fun, apply_fun
Here we’re using a simple extended version of split that can produce multiple copies.
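In the current jax.random API this extended split is simply the optional count argument to split; a minimal sketch:

import jax

rng = jax.random.key(0)
rngs = jax.random.split(rng, 3)   # three new keys derived from one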
Tradeoffs and alternatives#
We’re not exploiting any device hardware PRNG
We don’t currently have enough control over the hardware PRNG’s state for all backends.
Even if we did, it would be backend-dependent and we might have to introduce sequential dependencies between random calls to ensure deterministic ordering and hence reproducibility.
We don’t know of any workloads for which the software PRNG should become a bottleneck.
We could consider providing an additional API that allows access to a hardware PRNG for users who want to give up other desiderata (like strict reproducibility).
We give up the sequential equivalent guarantee, in which creating a random array in one call produces the same values as creating the flattened array one random element at a time.
This property is likely incompatible with vectorization (a high priority).
We don’t know of any users or examples for which this property is important.
Users could write a layer on top of this API to provide this guarantee.
We can’t follow the numpy.random API exactly.
Custom JVP/VJP rules for JAX-transformable functions#
This is a design document, explaining some of the thinking behind the design and implementation of jax.custom_jvp and jax.custom_vjp. For user-oriented documentation, see the tutorial notebook.
There are two ways to define differentiation rules in JAX:
using jax.custom_jvp and jax.custom_vjp to define custom differentiation rules for Python functions that are already JAX-transformable; and
defining new core.Primitive instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems.
This document is about #1 only.
Contents#
Goals#
We want users to customize the forward- and/or reverse-mode differentiation behavior of their code. This customization
should have a clear and consistent semantics in how it works and how it composes with other JAX transformations; and
should be flexible in supporting use cases and workflows like in Autograd and PyTorch, including cases involving differentiation of Python control flow and workflows for NaN debugging.
As JAX developers we want to write library functions, like logit and expit, that are defined in terms of other primitives, but for the purposes of differentiation have primitive-like behavior in the sense that we want to define custom differentiation rules for them, which may be more numerically stable or performant. In particular, we don’t want to have to specify vmap or jit rules for functions like logit and expit.
As a stretch goal, we’d like to make JAX a great environment for power users looking to add custom differentiation rules for higher-order functions like fixed_point, odeint, etc.; this design doc won’t solve that problem, but we want to be confident we’re not going to preclude good solutions to that problem.
That is, our primary goals are
solve the vmap-removes-custom-jvp semantics problem (#1249), and
allow Python in custom VJPs, e.g. to debug NaNs (#1275).
Secondary goals are
3. clean up and simplify user experience (symbolic zeros, kwargs, etc.), and
4. make progress towards a world where users can easily add fixed_point, odeint, root, etc.
Overall, we want to close #116, #1097, #1249, #1275, #1366, #1723, #1670, #1875, #1938, and replace the custom_transforms machinery (from #636, #818, and others).
Non-goals#
Here are objectives we’re not aiming to achieve:
The custom_transforms machinery aimed to provide a transformation-generic mechanism for customizing behavior, in principle (though never really used in practice) allowing users to customize rules for any transformation while somehow inheriting the “transparent” behavior for others. We are instead only going to solve the customization problem for differentiation (JVP and VJP, separately). Differentiation is the only case actually requested, and by specializing to differentiation we can reduce complexity and improve flexibility. To control all rules one can just write a primitive.
We’re not going to prioritize mathematical aesthetics over flexibility and clarity on the user side, and simplicity on the implementation side. In particular, while the custom VJP signature a -> (b, CT b --o CT a) is mathematically pleasing, if it’s hard to implement in a Python mechanism because of the closure in the return type, we’re fine doing something that handles residuals more explicitly.
Serialization support, of the form where the staged-out serialized program representation can be loaded and further JAX-transformed as opposed to just evaluated, is currently out of scope for these custom JVP/VJP transformation rules. Serialization may be useful not only for researchers who want to save some representation of their computation (and transform it after loading it), but also for future considerations like having jaxpr transformations implemented outside Python, or having jaxprs as an MLIR dialect. By defining this as a non-goal for the purpose of this design, we have fewer constraints on where we can stash Python callables.
Main problem descriptions#
The vmap-removes-custom-jvp semantics problem#
The vmap-removes-custom-jvp semantics problem is that vmap does not compose properly with differentiation of functions with custom_transforms rules:
# old custom_transforms api to be replaced
@jax.custom_transforms
def f(x):
return 2. * x
# f_vjp :: a -> (b, CT b --o CT a)
def f_vjp(x):
return f(x), lambda g: 3. * x # 3 instead of 2
jax.defvjp_all(f, f_vjp)
grad(f)(1.) # 3.
vmap(grad(f))(np.ones(4)) # [3., 3., 3., 3.]
grad(lambda x: vmap(f)(x).sum())(np.ones(4)) # [2., 2., 2., 2.]
The last grad-of-vmap line has an unexpected result! In general, applying vmap, or really any non-differentiation transformation, has the effect of removing the custom differentiation rule. (Applying jvp causes a failure when a custom VJP rule is defined.)
The problem exists because transformations are like rewrites, and the vmap transformation effectively rewrites the function to no longer call the newly-introduced primitive for which there is a custom rule (and hence grad then doesn’t produce the custom rule’s result). In more detail, the custom_transforms machinery sets things up so that evaluating f(x) applies the function
{ lambda ; ; a.
let b = f_primitive a
in [b] }
where f_primitive is a new primitive (introduced for every custom_transforms function and in fact for every call of the function) to which the custom VJP rule is associated. When we evaluate grad(f)(x), the differentiation machinery encounters f_primitive and processes it with the custom rule.
However, because f_primitive is transparent to vmap, in the sense that vmap operates on (effectively by inlining) the definition of f_primitive, the function vmap(f) is effectively
{ lambda ; ; a.
let b = mul 2. a
in [b] }
In words, vmap rewrites the function in terms of its underlying primitives and their transformation rules, removing f_primitive entirely.
More generally, because vmap(f) has semantics defined in terms of calls to f, it is semantically inconsistent to remove the custom derivative rule. That is, since we define
vmap(f)(xs) == np.stack([f(x) for x in xs])
we must have
jvp(vmap(f))(xs) == jvp(lambda xs: np.stack([f(x) for x in xs]))
yet this property is not observed when f has a custom derivative rule defined, as the custom derivative rule is used in the right-hand version but not the left-hand one.
This issue isn’t specific to vmap; it applies to all transformations for which the semantics of transforming a function f are defined in terms of calls to the function f, rather than rewriting it into another function. The mask transformation also falls into this class. Differentiation transforms and the hypothetical all-unary-functions-become-cosine transform are not in this class.
(The interaction between additional custom rules, like custom vmap rules, is likely to get even more complex, suggesting the problem framing of custom_transforms is too broad.)
The Python flexibility problem#
In JAX, as in Autograd and PyTorch but not TF1, differentiation of a Python function is performed while the function is being executed and traced. This behavior delights users for a few reasons.
First and most importantly, it enables pdb-based workflows, e.g. for
inspecting numerics or catching NaNs. That is, users can employ the standard
Python debugger and other Python-native tools to debug their code, even being
able to inspect runtime values to understand numerical behavior on examples and
to catch fundamentally runtime errors like NaNs. In fact, just while working on
the PR corresponding to this design, especially on the odeint
primitive, I
used runtime value inspection to debug issues many times, increasing my
confidence that this is a key user workflow in Python. One especially handy
trick, which I’ve used in both JAX and Autograd many times, is the ability to
insert a debugger breakpoint in a custom VJP rule to enter a debugger at a
specific point in the backward pass.
Second, it allows differentiation of Python native control flow. We’re not sure how often this is used in practice in finalized software artifacts, but when users first poke around JAX or Autograd they’re often impressed by this freedom. There’s a reason we include it at the top of our JAX and Autograd READMEs, slide decks, and demos. Ceding this capability would be a step backward from Autograd. We want JAX to have the best automatic differentiation.
However, the custom_transforms machinery does not provide this Python-support flexibility. That is, because it’s implemented in terms of up-front jaxpr formation from the Python code for both the user function and custom differentiation rules, code like this leads to an abstract value tracing error:
# old custom_transforms api to be replaced
@jax.custom_transforms
def f(x):
if x > 0:
return x
else:
return 0.
def f_vjp(x):
return ...
jax.defvjp_all(f, f_vjp)
grad(f)(1.) # Error!
Solution idea#
The main idea is that dougalm@ already solved these problems with core.call. That is, we can frame the task of specifying a custom JVP rule for a user function in terms of a new Python-level call primitive (not to be added to the jaxpr language; see below). This new call primitive has a user Python function associated with it just like core.call, but additionally has a second Python callable representing the JVP rule. Let’s refer to this new call primitive as custom_jvp_call.
Transformations like vmap interact with custom_jvp_call as with core.call: they effectively pass right through it and are applied to the underlying Python callables. Schematically, writing in terms of curried versions of the primitives for convenience, analogously to how vmap interacts with core.call by applying to the function to be called:
vmap(call(f)) == call(vmap(f))
for the new primitive custom_jvp_call we simply apply vmap to the two functions it entails:
vmap(custom_jvp_call(f, f_jvp)) == custom_jvp_call(vmap(f), vmap(f_jvp))
This behavior means we’ve solved the vmap-removes-custom-jvp semantics problem.
The jvp transformation interacts as one might expect: it just calls f_jvp,
jvp(call(f)) == call(jvp(f))
jvp(custom_jvp_call(f, f_jvp)) == f_jvp
Because custom_jvp_call acts like core.call (and not like xla.xla_call) in that it doesn’t raise the abstraction level of its inputs (because it’s not delaying anything or staging anything out), it means we’ve solved the Python flexibility problem: there are no constraints on the user Python function (above the usual functional programming constraints required by jvp or vjp).
What about evaluation and compilation? These are two ways to “exit” the JAX system, in the sense that no additional transformations can be applied after these steps. As a result, their rules are trivial:
eval(call(f)) == eval(f)
jit(call(f)) == hlo_call(jit(f))
eval(custom_jvp_call(f, f_jvp)) == eval(f)
jit(custom_jvp_call(f, f_jvp)) == hlo_call(jit(f))
In words, if a JVP rule hasn’t already rewritten custom_jvp_call(f, f_jvp) into f_jvp, when we get to the point of evaluation with eval or staging out to XLA with jit, differentiation is never going to be applied, so we just ignore f_jvp and behave just like core.call. However, due to the wrinkle discussed next, the partial eval rule for custom_jvp_call must be a bit more complex, since partial evaluation isn’t just used to stage out to XLA with jit.
The only remaining wrinkle has to do with “initial-style” jaxpr-forming primitives, like lax.scan, and their transformation rules. These represent a different kind of “staging out to a jaxpr” than that for compilation because we can perform additional transformations on the staged-out jaxpr. That is, when lax.scan forms a jaxpr, it does not exit the transformation system, since when we apply a jvp or vmap to a lax.scan we need to apply it to the function represented by the jaxpr.
Another way to state the wrinkle is that initial-style primitives like lax.scan
rely on the ability to round-trip to a jaxpr and back to a Python callable while
preserving semantics. That must mean preserving custom differentiation rule
semantics too.
The solution is to use a bit of dynamic scoping: when we’re staging out to a jaxpr for an initial-style primitive, like those in lax_control_flow.py, we set a bit on the global trace state. When that bit is set, instead of using the final-style custom_jvp_call primitive, we use an initial-style custom_jvp_call_jaxpr primitive, and trace the functions f and f_jvp to jaxprs up-front to make initial-style processing easier. The custom_jvp_call_jaxpr primitive is otherwise similar to the final-style version.
(Footnote: while morally we form jaxprs for both f and f_jvp before binding custom_jvp_call_jaxpr, we need to delay the formation of the jaxpr of f_jvp because it may call the custom-JVP function and thus eager processing would lead to an infinite recursion. We delay that jaxpr formation in a thunk.)
If we gave up on the Python flexibility problem, we could get away with only having custom_jvp_call_jaxpr and not having the separate Python-level primitive custom_jvp_call.
API#
The custom JVP for an a -> b function is specified with an (a, T a) -> (b, T b) function:
# f :: a -> b
@jax.custom_jvp
def f(x):
return np.sin(x)
# f_jvp :: (a, T a) -> (b, T b)
def f_jvp(primals, tangents):
x, = primals
t, = tangents
return f(x), np.cos(x) * t
f.defjvp(f_jvp)
(Interesting autodiff aside: for the rule to apply to higher-order differentiation, one must call f in the body of f_jvp; that precludes some kinds of work sharing between the internals of f and the tangent calculation.)
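As a quick usage check (assuming np in the snippet above refers to jax.numpy, per the document’s convention), the rule composes with both grad and vmap, which is exactly the property the earlier custom_transforms example lacked:

import jax
import jax.numpy as jnp

print(jax.grad(f)(1.0))                        # cos(1.0), computed via f_jvp
print(jax.vmap(jax.grad(f))(jnp.arange(3.0)))  # cos applied elementwise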
The custom VJP for an a -> b function is specified with an a -> (b, c) forward pass function paired with a (c, CT b) -> CT a backward pass function:
# f :: a -> b
@jax.custom_vjp
def f(x):
return np.sin(x)
# f_fwd :: a -> (b, c)
def f_fwd(x):
return f(x), np.cos(x)
# f_bwd :: (c, CT b) -> CT a
def f_bwd(cos_x, g):
return (cos_x * g,)
f.defvjp(f_fwd, f_bwd)
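And a quick check that the backward pass goes through the rule (again assuming np is jax.numpy):

import jax

y, pullback = jax.vjp(f, 1.0)
print(pullback(1.0))   # (cos(1.0),), produced by f_bwd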
The signature a -> (b, CT b --o CT a) is more aesthetically pleasing, but supporting it would make the implementation more complex and might require compromising expressibility desiderata. The basic reason is that Python callables are opaque (unless we trace them to a jaxpr eagerly, which places expressiveness constraints), and in this case we may be returning a callable with vmap tracers inside its closure that we need to know about during the forward pass.
We could add convenience wrappers, for example to define the JVP rule for a single argument at a time (like we do internally for primitives). But because this proposal is complicated enough as it is, I decided against convenience layers; let’s keep things minimal for now.
There are some other bells and whistles to the API:
Inputs and output types a, b, and c can be arbitrary pytrees of jaxtypes.
Passing arguments by name (keyword arguments) is supported when they can be resolved to positions using the inspect module. This is a bit of an experiment with Python 3’s improved ability to programmatically inspect argument signatures. I believe it is sound but not complete, which is a fine place to be. (See also #2069.)
Arguments can be marked non-differentiable using nondiff_argnums, and as with jit’s static_argnums these arguments don’t have to be JAX types. We need to set a convention for how these arguments are passed to the rules. For a primal function with type signature (d, a) -> b where d represents the non-differentiable type, the JVP rule’s signature is (a, T a, d) -> T b and the VJP rule’s reverse component signature is (d, c, CT b) -> CT a. That is, the non-differentiable arguments are passed in order after primals and tangents for a custom JVP rule, and passed in order preceding the residuals in a custom VJP rule’s reverse function.
Implementation notes#
Updated jax.experimental.odeint: Since odeint is a pretty complex user of a custom VJP rule, in addition to just updating it to work at all, I wanted to revise it to be a canonical user of the new custom VJP API as a way to test that the API was a good one. Along the way I made other improvements to the odeint implementation:
remove raveling/unraveling boilerplate
make use of lax.scan to remove the index-update logic
speed up by 20+% on the simple pendulum benchmark
Added a custom bind method on each transform for the custom derivative call primitives, custom_jvp_call and custom_vjp_call. It’s like core.call_bind, except we don’t process env traces: those are just errors.
Added custom_lin primitive, which gets staged out into linear jaxprs to be transposed when using a custom VJP rule.
Because our reverse-mode autodiff is decomposed into linearization, partial evaluation, and transposition, our custom VJP rules are processed in two separate steps: one during linearization and one during transposition.
The linearization step, i.e. the JVP rule for custom_vjp_call, applies custom_lin to the tangent values; custom_lin carries with it the user’s custom backward-pass function, and as a primitive it only has a transpose rule.
This mechanism is described more in #636.
To prevent
custom_vjp and nondiff_argnums update guide#
mattjj@ Oct 14 2020
This doc assumes familiarity with jax.custom_vjp, as described in the Custom derivative rules for JAX-transformable Python functions notebook.
What to update#
After JAX PR #4008, the arguments passed into a custom_vjp function’s nondiff_argnums can’t be Tracers (or containers of Tracers), which basically means to allow for arbitrarily-transformable code nondiff_argnums shouldn’t be used for array-valued arguments. Instead, nondiff_argnums should be used only for non-array values, like Python callables or shape tuples or strings.
Wherever we used to use nondiff_argnums
for array values, we should just pass
those as regular arguments. In the bwd
rule, we need to produce values for them,
but we can just produce None
values to indicate there’s no corresponding
gradient value.
For example, here’s the old way to write clip_gradient, which won’t work when hi and/or lo are Tracers from some JAX transformation.
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def clip_gradient(lo, hi, x):
    return x  # identity function

def clip_gradient_fwd(lo, hi, x):
    return x, None  # no residual values to save

def clip_gradient_bwd(lo, hi, _, g):
    return (jnp.clip(g, lo, hi),)

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
Here’s the new, awesome way, which supports arbitrary transformations:
import jax
import jax.numpy as jnp

@jax.custom_vjp  # no nondiff_argnums!
def clip_gradient(lo, hi, x):
    return x  # identity function

def clip_gradient_fwd(lo, hi, x):
    return x, (lo, hi)  # save lo and hi values as residuals

def clip_gradient_bwd(res, g):
    lo, hi = res
    return (None, None, jnp.clip(g, lo, hi))  # return None for lo and hi

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
If you use the old way instead of the new way, you’ll get a loud error in any case where something might go wrong (namely when there’s a Tracer passed into a nondiff_argnums argument).
Here’s a case where we actually need nondiff_argnums with custom_vjp:
from functools import partial
import jax
@partial(jax.custom_vjp, nondiff_argnums=(0,))
def skip_app(f, x):
return f(x)
def skip_app_fwd(f, x):
return skip_app(f, x), None
def skip_app_bwd(f, _, g):
return (g,)
skip_app.defvjp(skip_app_fwd, skip_app_bwd)
Explanation#
Passing Tracers into nondiff_argnums arguments was always buggy. While there were some cases that worked correctly, others would lead to complex and confusing error messages.
The essence of the bug was that nondiff_argnums was implemented in a way that acted very much like lexical closure. But lexical closure over Tracers wasn’t at the time intended to work with custom_jvp/custom_vjp. Implementing nondiff_argnums that way was a mistake!
PR #4008 fixes all lexical closure issues with custom_jvp and custom_vjp. Woohoo! That is, now custom_jvp and custom_vjp functions and rules can close over Tracers to our hearts’ content. For all non-autodiff transformations, things will Just Work. For autodiff transformations, we’ll get a clear error message about why we can’t differentiate with respect to values over which a custom_jvp or custom_vjp closes:
Detected differentiation of a custom_jvp function with respect to a closed-over value. That isn’t supported because the custom JVP rule only specifies how to differentiate the custom_jvp function with respect to explicit input parameters.
Try passing the closed-over value into the custom_jvp function as an argument, and adapting the custom_jvp rule.
In tightening up and robustifying custom_jvp and custom_vjp in this way, we found that allowing custom_vjp to accept Tracers in its nondiff_argnums would take a significant amount of bookkeeping: we’d need to rewrite the user’s fwd function to return the values as residuals, and rewrite the user’s bwd function to accept them as normal residuals (rather than accepting them as special leading arguments, as happens with nondiff_argnums). This seems maybe manageable, until you think through how we have to handle arbitrary pytrees!
Moreover, that complexity isn’t necessary: if user code treats array-like
non-differentiable arguments just like regular arguments and residuals,
everything already works. (Before
#4039 JAX might’ve complained about
involving integer-valued inputs and outputs in autodiff, but after
#4039 those will just work!)
Unlike custom_vjp, it was easy to make custom_jvp work with nondiff_argnums arguments that were Tracers. So these updates only need to happen with custom_vjp.
Omnistaging#
mattjj@ Sept 25 2020
This is more of an upgrade guide than a design doc.
Contents#
tl;dr#
What’s going on?#
A change to JAX’s tracing infrastructure called “omnistaging” (google/jax#3370) was switched on in jax==0.2.0. This change improves memory performance, trace execution time, and simplifies jax internals, but may cause some existing code to break. Breakage is usually a result of buggy code, so long-term it’s best to fix the bugs, but omnistaging can also be disabled as a temporary workaround. And we’re happy to help you with fixes!
How do I know if omnistaging broke my code?#
The easiest way to tell if omnistaging is responsible is to disable omnistaging and see if the issues go away. See the What issues can arise when omnistaging is switched on? section below.
How can I disable omnistaging for now?#
Note: this applies to JAX versions 0.2.0 through 0.2.11; omnistaging cannot be disabled in JAX versions 0.2.12 and higher
It is temporarily possible to disable omnistaging by
setting the shell environment variable JAX_OMNISTAGING to something falsey;
setting the boolean flag jax_omnistaging to something falsey if your code parses flags with absl;
using this statement near the top of your main file:
jax.config.disable_omnistaging()
How do I fix bugs exposed by omnistaging?#
By far the most common issue with omnistaging is using jax.numpy
to compute
shape values or other trace-time constants. See the code block below for a quick
example, and for full details along with other issues see the section What
issues can arise when omnistaging is switched
on?.
Instead of this:
@jit
def f(x):
input_size = jnp.prod(x.shape)
if input_size > 100:
...
do this:
import numpy as np
@jit
def f(x):
input_size = np.prod(x.shape)
if input_size > 100:
...
Instead of thinking of jax.numpy as a drop-in replacement for numpy, it’s now better to think of using jax.numpy operations only when you want to perform a computation on an accelerator (like your GPU).
What is “omnistaging” and why is it useful?#
Omnistaging is the name for a JAX core upgrade aimed at staging out more computation from op-by-op Python to XLA, and avoiding any “trace-time constant folding” in jit, pmap, and control flow primitives. As a result, omnistaging improves JAX’s memory performance (sometimes dramatically) both by reducing fragmentation during tracing and by producing fewer large compile-time constants for XLA. It can also improve tracing performance by eliminating op-by-op execution at tracing time. Further, omnistaging simplifies JAX core internals, fixing many outstanding bugs and setting the stage for important upcoming features.
The name “omnistaging” means staging out everything possible.
Toy example#
JAX transformations like jit and pmap stage out computations to XLA. That is, we apply them to functions comprising multiple primitive operations so that, rather than being executed one at a time from Python, the operations are all part of one end-to-end optimized XLA computation.
But exactly which operations get staged out? Until omnistaging, JAX staged out computation based on data dependence only. Here’s an example function, followed by the XLA HLO program it stages out before the omnistaging change:
from jax import jit
import jax.numpy as jnp
@jit
def f(x):
y = jnp.add(1, 1)
return x * y
f(3)
ENTRY jit_f.6 {
constant.2 = pred[] constant(false)
parameter.1 = s32[] parameter(0)
constant.3 = s32[] constant(2)
multiply.4 = s32[] multiply(parameter.1, constant.3)
ROOT tuple.5 = (s32[]) tuple(multiply.4)
}
Notice that the add operation is not staged out. Instead, we only see a multiply.
Here’s the HLO generated from this function after the omnistaging change:
ENTRY jit_f.8 {
constant.2 = pred[] constant(false)
parameter.1 = s32[] parameter(0)
constant.3 = s32[] constant(1)
constant.4 = s32[] constant(1)
add.5 = s32[] add(constant.3, constant.4)
multiply.6 = s32[] multiply(parameter.1, add.5)
ROOT tuple.7 = (s32[]) tuple(multiply.6)
}
Slightly less toy example#
Here’s a less toy example which can arise in practice when we want to create boolean masks:
import numpy as np
import jax.numpy as jnp
from jax import jit, lax

@jit
def select_tril(x):
    mask = jnp.arange(x.shape[0])[:, None] > jnp.arange(x.shape[1])
    return lax.select(mask, x, jnp.zeros_like(x))  # lax.select is like jnp.where

x = np.arange(12).reshape((3, 4))
select_tril(x)
Before omnistaging:
ENTRY jit_select_tril.8 {
constant.3 = pred[] constant(false)
constant.1 = pred[3,4]{1,0} constant({...})
parameter.2 = s32[3,4]{1,0} parameter(0)
constant.4 = s32[] constant(0)
broadcast.5 = s32[3,4]{1,0} broadcast(constant.4), dimensions={}
select.6 = s32[3,4]{1,0} select(constant.1, parameter.2, broadcast.5)
ROOT tuple.7 = (s32[3,4]{1,0}) tuple(select.6)
}
The select operation is staged out, but the operations for constructing the constant mask are not. Rather than being staged out, the operations that construct mask are executed op-by-op at Python tracing time, and XLA only sees a compile time constant constant.1 representing the value of mask. That’s unfortunate, because if we had staged out the operations for constructing mask, XLA could have fused them into the select and avoided materializing the result at all. As a result we end up wasting memory with a potentially-large constant, wasting time dispatching multiple un-fused op-by-op XLA computations, and potentially even fragmenting memory.
(The broadcast that corresponds to the construction of the zeros array for jnp.zeros_like(x) is staged out because JAX is lazy about very simple expressions from google/jax#1668. After omnistaging, we can remove that lazy sublanguage and simplify JAX internals.)
The reason the creation of mask is not staged out is that, before omnistaging, jit operates based on data dependence. That is, jit stages out only those operations in a function that have a data dependence on an argument. Control flow primitives and pmap behave similarly. In the case of select_tril, the operations to construct the constant mask do not have a data dependence on the argument x, so they are not staged out; only the lax.select call has a data dependence.
With omnistaging all jax.numpy calls in the dynamic context of a jit-transformed function are staged out to XLA. That is, after omnistaging the computation XLA sees for select_tril is
ENTRY jit_select_tril.16 {
constant.4 = pred[] constant(false)
iota.1 = s32[3]{0} iota(), iota_dimension=0
broadcast.5 = s32[3,1]{1,0} broadcast(iota.1), dimensions={0}
reshape.7 = s32[3]{0} reshape(broadcast.5)
broadcast.8 = s32[3,4]{1,0} broadcast(reshape.7), dimensions={0}
iota.2 = s32[4]{0} iota(), iota_dimension=0
broadcast.6 = s32[1,4]{1,0} broadcast(iota.2), dimensions={1}
reshape.9 = s32[4]{0} reshape(broadcast.6)
broadcast.10 = s32[3,4]{1,0} broadcast(reshape.9), dimensions={1}
compare.11 = pred[3,4]{1,0} compare(broadcast.8, broadcast.10), direction=GT
parameter.3 = s32[3,4]{1,0} parameter(0)
constant.12 = s32[] constant(0)
broadcast.13 = s32[3,4]{1,0} broadcast(constant.12), dimensions={}
select.14 = s32[3,4]{1,0} select(compare.11, parameter.3, broadcast.13)
ROOT tuple.15 = (s32[3,4]{1,0}) tuple(select.14)
}
What issues can arise when omnistaging is switched on?#
As a consequence of staging out all jax.numpy operations from Python to XLA when in the dynamic context of a jit or pmap, some code that worked previously can start raising loud errors. As explained below, these behaviors were already buggy before omnistaging, but omnistaging makes them into hard errors.
Using jax.numpy for shape computations#
Example#
from jax import jit
import jax.numpy as jnp
@jit
def ex1(x):
size = jnp.prod(jnp.array(x.shape))
return x.reshape((size,))
ex1(jnp.ones((3, 4)))
Error message#
[... full traceback ...]
File "/home/mattjj/packages/jax/jax/core.py", line 862, in raise_concretization_error
raise ConcretizationTypeError(msg)
jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.
The error arose in jax.numpy.reshape.
While tracing the function ex1 at ex1.py:4, this value became a tracer due to JAX operations on these lines:
operation c:int32[] = reduce_prod[ axes=(0,) ] b:int32[2]
from line ex1.py:6 (ex1)
You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.
See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.
Encountered tracer value: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>
Explanation#
With omnistaging, we can’t use jax.numpy for shape computations as in the use of jnp.prod above, because in the dynamic context of a jit function those operations will be staged out of Python as values to be computed at execution time, yet we need them to be compile-time (and hence trace-time) constants. Before omnistaging, this code wouldn’t have raised an error, but it was a common performance bug: the jnp.prod computation would have been executed on the device at tracing time, meaning extra compilation, transfers, synchronization, allocations, and potentially memory fragmentation.
Solution#
The solution is simply to use the original numpy
for shape calculations like
these. Not only do we avoid the error, but also we keep the computations on the
host (and with lower overheads).
This issue was common enough in code that we tried to make the error message especially good. In addition to the stack trace showing where an abstract tracer value caused a problem (the jnp.reshape line in the full stack trace, on omni.py:10), we also explain why this value became a tracer in the first place by pointing to the upstream primitive operation that caused it to become an abstract tracer (the reduce_prod from jnp.prod on omni.py:9) and to which jit-decorated function the tracer belongs (ex1 on omni.py:6).
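The error message also mentions static_argnums. As a hedged sketch of that alternative (the function pad_to and its arguments are hypothetical, just for illustration): marking an integer argument as static keeps it a plain Python value at trace time, so it can be used in shape computations.

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(1,))
def pad_to(x, size):
    # `size` is static, so it is an ordinary Python int at trace time
    return jnp.zeros(size, x.dtype).at[:x.shape[0]].set(x)

pad_to(jnp.arange(3), 8)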
Side-effects#
Example#
from jax import jit
from jax import random
key = random.PRNGKey(0)
def init():
global key
key, subkey = random.split(key)
return random.normal(subkey, ())
print(init()) # -1.2515389
print(init()) # -0.58665067
init = jit(init)
print(init()) # 0.48648298
print(init()) # 0.48648298 !!
That last call has repeated randomness but no hard error, because we aren’t re-executing the Python. But if we look at key, we see an escaped tracer when omnistaging is on:
print(key) # Traced<ShapedArray(uint32[2])>with<DynamicJaxprTrace(level=0/1)>
Before omnistaging, the random.split call would not be staged out and so we wouldn’t get an escaped tracer. The code would still be buggy in that the jitted function wouldn’t be reproducing the semantics of the original function (because of the repeated use of the same PRNG key), ultimately due to the side effect.
With omnistaging on, if we touch key again, we’ll get an escaped tracer error:
random.normal(key, ())
Error message#
[... full stack trace …]
File "/home/mattjj/packages/jax/jax/interpreters/partial_eval.py", line 836, in _assert_live
raise core.escaped_tracer_error(msg)
jax.core.UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state. Detail: tracer created on line example.py:8 (init).
Explanation#
The second largest category of omnistaging issues we found had to do with side-effecting code. This code already voided the JAX warranty by transforming effectful functions, but due to pre-omnistaging “trace-time constant folding” behavior, some side effecting functions could nevertheless behave correctly. Omnistaging catches more of these errors.
Solution#
The solution is to identify JAX-transformed functions that rely on side effects, and to rewrite them not to be effectful.
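A minimal sketch of such a rewrite for the init example above: thread the key through explicitly instead of mutating a global.

from jax import jit, random

@jit
def init(key):
    new_key, subkey = random.split(key)
    return random.normal(subkey, ()), new_key

key = random.PRNGKey(0)
val1, key = init(key)
val2, key = init(key)   # different value; no global state is touched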
Small numerical differences based on XLA optimizations#
Because with omnistaging more computations are being staged out to XLA, rather than some being executed at trace time, that can have the effect of reordering floating point operations. As a result, we’ve seen numerical behaviors change in a way that causes tests with overly tight tolerances to fail when omnistaging is switched on.
Dependence on JAX internal APIs that changed#
Omnistaging involved some big revisions to JAX’s core code, including removing or changing internal functions. Any code that relies on such internal JAX APIs can break when omnistaging is switched on, either with build errors (from pytype) or runtime errors.
Triggering XLA compile time bugs#
Because omnistaging involves staging out more code to XLA, we’ve seen it trigger pre-existing XLA compile-time bugs on some backends. The best thing to do with these is to report them so we can work with the XLA teams on fixes.
JEP 9263: Typed keys & pluggable RNGs#
Jake VanderPlas, Roy Frostig
August 2023
Overview#
Going forward, RNG keys in JAX will be more type-safe and customizable.
Rather than representing a single PRNG key by a length-2 uint32 array, it will be represented as a scalar array with a special RNG dtype that satisfies jnp.issubdtype(key.dtype, jax.dtypes.prng_key).
For now, old-style RNG keys can still be created with jax.random.PRNGKey():
>>> key = jax.random.PRNGKey(0)
>>> key
Array([0, 0], dtype=uint32)
>>> key.shape
(2,)
>>> key.dtype
dtype('uint32')
Starting now, new-style RNG keys can be created with jax.random.key():
>>> key = jax.random.key(0)
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]
>>> key.shape
()
>>> key.dtype
key<fry>
This (scalar-shaped) array behaves the same as any other JAX array, except that its element type is a key (and associated metadata). We can make non-scalar key arrays as well, for example by applying jax.vmap() to jax.random.key():
>>> key_arr = jax.vmap(jax.random.key)(jnp.arange(4))
>>> key_arr
Array((4,), dtype=key<fry>) overlaying:
[[0 0]
[0 1]
[0 2]
[0 3]]
>>> key_arr.shape
(4,)
Aside from switching to a new constructor, most PRNG-related code should continue to work as expected. You can continue to use keys in jax.random APIs as before; for example:
# split
new_key, subkey = jax.random.split(key)
# random number generation
data = jax.random.uniform(key, shape=(5,))
However, not all numerical operations work on key arrays. They now intentionally raise errors:
>>> key = key + 1
ValueError: dtype=key<fry> is not a valid dtype for JAX type promotion.
If for some reason you need to recover the underlying buffer (the old-style key), you can do so with jax.random.key_data():
>>> jax.random.key_data(key)
Array([0, 0], dtype=uint32)
For old-style keys, key_data() is an identity operation.
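There is also an inverse, jax.random.wrap_key_data(), which rebuilds a typed key from a raw buffer; a short round-trip sketch (assuming the default threefry implementation):

import jax

key = jax.random.key(0)
raw = jax.random.key_data(key)            # uint32[2] buffer
restored = jax.random.wrap_key_data(raw)  # typed key again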
What does this mean for users?#
For JAX users, this change does not require any code changes now, but we hope that you will find the upgrade worthwhile and switch to using typed keys. To try this out, replace uses of jax.random.PRNGKey() with jax.random.key(). This may introduce breakages in your code that fall into one of a few categories:
If your code performs unsafe/unsupported operations on keys (such as indexing, arithmetic, transposition, etc; see Type Safety section below), this change will catch it. You can update your code to avoid such unsupported operations, or use jax.random.key_data() and jax.random.wrap_key_data() to manipulate raw key buffers in an unsafe way.
If your code includes explicit logic about key.shape, you may need to update this logic to account for the fact that the trailing key buffer dimension is no longer an explicit part of the shape.
If your code includes explicit logic about key.dtype, you will need to upgrade it to use the new public APIs for reasoning about RNG dtypes, such as dtypes.issubdtype(dtype, dtypes.prng_key).
If you call a JAX-based library which does not yet handle typed PRNG keys, you can use raw_key = jax.random.key_data(key) for now to recover the raw buffer, but please keep a TODO to remove this once the downstream library supports typed RNG keys.
At some point in the future, we plan to deprecate jax.random.PRNGKey() and require the use of jax.random.key().
Detecting new-style typed keys#
To check whether an object is a new-style typed PRNG key, you can use jax.dtypes.issubdtype or jax.numpy.issubdtype:
>>> typed_key = jax.random.key(0)
>>> jax.dtypes.issubdtype(typed_key.dtype, jax.dtypes.prng_key)
True
>>> raw_key = jax.random.PRNGKey(0)
>>> jax.dtypes.issubdtype(raw_key.dtype, jax.dtypes.prng_key)
False
Type annotations for PRNG Keys#
The recommended type annotation for both old and new-style PRNG keys is jax.Array. A PRNG key is distinguished from other arrays based on its dtype, and it is not currently possible to specify dtypes of JAX arrays within a type annotation. Previously it was possible to use jax.random.KeyArray or jax.random.PRNGKeyArray as type annotations, but these have always been aliased to Any under type checking, and so jax.Array has much more specificity.
Note: jax.random.KeyArray and jax.random.PRNGKeyArray were deprecated in JAX version 0.4.16, and removed in JAX version 0.4.24.
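For example, a function that accepts either key style can simply be annotated with jax.Array; a minimal sketch:

import jax

def sample(key: jax.Array, shape: tuple) -> jax.Array:
    return jax.random.uniform(key, shape)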
Motivation#
Two major motivating factors for this change are customizability and safety.
Customizing PRNG implementations#
JAX currently operates with a single, globally configured PRNG algorithm. A PRNG key is a vector of unsigned 32-bit integers, which jax.random APIs consume to produce pseudorandom streams. Any higher-rank uint32 array is interpreted as an array of such key buffers, where the trailing dimension represents keys.
The drawbacks of this design became clearer as we introduced alternative PRNG implementations, which must be selected by setting a global or local configuration flag. Different PRNG implementations have different size key buffers, and different algorithms for generating random bits. Determining this behavior with a global flag is error-prone, especially when there is more than one key implementation in use process-wide.
Our new approach is to carry the implementation as part of the PRNG key type, i.e. with the element type of the key array. Using the new key API, here is an example of generating pseudorandom values under the default threefry2x32 implementation (which is implemented in pure Python and compiled with JAX), and under the non-default rbg implementation (which corresponds to a single XLA random-bit generation operation):
>>> key = jax.random.key(0, impl='threefry2x32') # this is the default impl
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]
>>> jax.random.uniform(key, shape=(3,))
Array([0.9653214 , 0.31468165, 0.63302994], dtype=float32)
>>> key = jax.random.key(0, impl='rbg')
>>> key
Array((), dtype=key<rbg>) overlaying:
[0 0 0 0]
>>> jax.random.uniform(key, shape=(3,))
Array([0.39904642, 0.8805201 , 0.73571277], dtype=float32)
Safe PRNG key use#
PRNG keys are really only meant to support a few operations in principle, namely key derivation (e.g. splitting) and random number generation. The PRNG is designed to generate independent pseudorandom numbers, provided keys are properly split and that every key is consumed once.
Code that manipulates or consumes key data in other ways often indicates an accidental bug, and representing key arrays as raw uint32 buffers has allowed for easy misuse along these lines. Here are a few example misuses that we’ve encountered in the wild:
Key buffer indexing#
Access to the underlying integer buffers makes it easy to try and derive keys in non-standard ways, sometimes with unexpectedly bad consequences:
# Incorrect
key = random.PRNGKey(999)
new_key = random.PRNGKey(key[1]) # identical to the original key!
# Correct
key = random.PRNGKey(999)
key, new_key = random.split(key)
If this key were a new-style typed key made with random.key(999), indexing into the key buffer would error instead.
Key arithmetic#
Key arithmetic is a similarly treacherous way to derive keys from other keys.
Deriving keys in a way that avoids jax.random.split() or jax.random.fold_in() by manipulating key data directly produces a batch of keys that (depending on the PRNG implementation) might then generate correlated random numbers within the batch:
# Incorrect
key = random.PRNGKey(0)
batched_keys = key + jnp.arange(10, dtype=key.dtype)[:, None]
# Correct
key = random.PRNGKey(0)
batched_keys = random.split(key, 10)
New-style typed keys created with random.key(0) address this by disallowing arithmetic operations on keys.
Inadvertent transposing of key buffers#
With “raw” old-style key arrays, it’s easy to accidentally swap batch (leading) dimensions and key buffer (trailing) dimensions. Again this possibly results in keys that produce correlated pseudorandomness. A pattern that we’ve seen over time boils down to this:
# Incorrect
keys = random.split(random.PRNGKey(0))
data = jax.vmap(random.uniform, in_axes=1)(keys)
# Correct
keys = random.split(random.PRNGKey(0))
data = jax.vmap(random.uniform, in_axes=0)(keys)
The bug here is subtle. By mapping over in_axes=1, this code makes new keys by combining a single element from each key buffer in the batch. The resulting keys are different from one another, but are effectively “derived” in a non-standard way. Again, the PRNG is not designed or tested to produce independent random streams from such a key batch.
New-style typed keys created with random.key(0) address this by hiding the buffer representation of individual keys, instead treating keys as opaque elements of a key array. Key arrays have no trailing “buffer” dimension to index, transpose, or map over.
Key reuse#
Unlike state-based PRNG APIs like numpy.random, JAX’s functional PRNG does not implicitly update a key when it has been used.
# Incorrect
key = random.PRNGKey(0)
x = random.uniform(key, (100,))
y = random.uniform(key, (100,)) # Identical values!
# Correct
key = random.PRNGKey(0)
key1, key2 = random.split(key)
x = random.uniform(key1, (100,))
y = random.uniform(key2, (100,))
We’re actively working on tools to detect and prevent unintended key reuse. This is still work in progress, but it relies on typed key arrays. Upgrading to typed keys now sets us up to introduce these safety features as we build them out.
Design of typed PRNG keys#
Typed PRNG keys are implemented as an instance of extended dtypes within JAX, of which the new PRNG dtypes are a sub-dtype.
Extended dtypes#
From the user perspective, an extended dtype dt has the following user-visible properties:
jax.dtypes.issubdtype(dt, jax.dtypes.extended) returns True: this is the public API that should be used to detect whether a dtype is an extended dtype.
It has a class-level attribute dt.type, which returns a typeclass in the hierarchy of numpy.generic. This is analogous to how np.dtype('int32').type returns numpy.int32, which is not a dtype but rather a scalar type, and a subclass of numpy.generic.
Unlike numpy scalar types, we do not allow instantiation of dt.type scalar objects: this is in accordance with JAX’s decision to represent scalar values as zero-dimensional arrays.
From a non-public implementation perspective, an extended dtype has the following properties:
Its type is a subclass of the private base class jax._src.dtypes.ExtendedDtype, the non-public base class used for extended dtypes. An instance of ExtendedDtype is analogous to an instance of np.dtype, like np.dtype('int32').
It has a private _rules attribute which allows the dtype to define how it behaves under particular operations. For example, jax.lax.full(shape, fill_value, dtype) will delegate to dtype._rules.full(shape, fill_value, dtype) when dtype is an extended dtype.
Why introduce extended dtypes in generality, beyond PRNGs? We reuse this same extended dtype mechanism elsewhere internally. For example, the jax._src.core.bint object, a bounded integer type used for experimental work on dynamic shapes, is another extended dtype. In recent JAX versions it satisfies the properties above (see jax/_src/core.py#L1789-L1802).
PRNG dtypes#
PRNG dtypes are defined as a particular case of extended dtypes. Specifically, this change introduces a new public scalar type class jax.dtypes.prng_key, which has the following property:
>>> jax.dtypes.issubdtype(jax.dtypes.prng_key, jax.dtypes.extended)
True
PRNG key arrays then have a dtype with the following properties:
>>> key = jax.random.key(0)
>>> jax.dtypes.issubdtype(key.dtype, jax.dtypes.extended)
True
>>> jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key)
True
And in addition to key.dtype._rules as outlined for extended dtypes in general, PRNG dtypes define key.dtype._impl, which contains the metadata that defines the PRNG implementation. The PRNG implementation is currently defined by the non-public jax._src.prng.PRNGImpl class. For now, PRNGImpl isn’t meant to be a public API, but we might revisit this soon to allow for fully custom PRNG implementations.
Progress#
Following is a non-comprehensive list of key Pull Requests implementing the above design. The main tracking issue is #9263.
Implement pluggable PRNG via PRNGImpl: #6899
Implement PRNGKeyArray, without dtype: #11952
Add a “custom element” dtype property to PRNGKeyArray with _rules attribute: #12167
Rename “custom element type” to “opaque dtype”: #12170
Refactor bint to use the opaque dtype infrastructure: #12707
Add jax.random.key to create typed keys directly: #16086
Add impl argument to key and PRNGKey: #16589
Rename “opaque dtype” to “extended dtype” & define jax.dtypes.extended: #16824
Introduce jax.dtypes.prng_key and unify PRNG dtype with Extended dtype: #16781
Add a jax_legacy_prng_key flag to support warning or erroring when using legacy (raw) PRNG keys: #17225
Design of Type Promotion Semantics for JAX#
Jake VanderPlas, December 2021
One of the challenges faced in the design of any numerical computing library is the choice of how to handle operations between values of different types. This document outlines the thought process behind the promotion semantics used by JAX, summarized in JAX Type Promotion Semantics.
Goals of JAX Type Promotion#
JAX’s numerical computing API is modeled after that of NumPy, with a few enhancements including the ability to target accelerators like GPU and TPU. This makes adoption of NumPy’s type promotion system disadvantageous for JAX users: NumPy’s type promotion rules heavily favor 64-bit outputs, which is problematic for computation on accelerators. Devices such as GPUs and TPUs often pay a significant performance penalty to use 64-bit floating point types, and in some cases do not support native 64-bit floating point types at all.
A simple example of this problematic type promotion semantics can be seen in binary operations between 32-bit integers and floats:
import numpy as np
np.dtype(np.int32(1) + np.float32(1))
dtype('float64')
NumPy’s tendency to produce 64-bit values is a long-standing issue with using NumPy’s API for accelerator computations, for which there isn’t yet a good solution. For this reason, JAX has sought to re-think NumPy-style type promotion with accelerators in mind.
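For comparison, the analogous operation under the semantics developed below stays at 32 bits (a minimal illustrative check, assuming jax is installed):
import jax.numpy as jnp
print((jnp.int32(1) + jnp.float32(1)).dtype)  # float32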
Stepping Back: Tables and Lattices#
Before we dive into the details, let’s take a moment to step back and think about how to think about the problem of type promotion. Consider arithmetic operations between built-in numerical types in Python, namely those of type int
, float
, and complex
. With a few lines of code we can generate the type promotion table used by Python for addition between values of these types:
import pandas as pd
types = [int, float, complex]
name = lambda t: t.__name__
pd.DataFrame([[name(type(t1(1) + t2(1))) for t1 in types] for t2 in types],
index=[name(t) for t in types], columns=[name(t) for t in types])
| | int | float | complex |
|---|---|---|---|
int | int | float | complex |
float | float | float | complex |
complex | complex | complex | complex |
This table enumerates Python’s numerical type promotion behavior, but it turns out there is a complementary representation that is much more compact: a Lattice representation, where the supremum between any two nodes is the type that they promote to. The lattice representation of Python’s promotion table is much simpler:
Show code cell source
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {'int': ['float'], 'float': ['complex']}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {'int': [0, 0], 'float': [1, 0], 'complex': [2, 0]}
fig, ax = plt.subplots(figsize=(8, 2))
nx.draw(graph, with_labels=True, node_size=4000, node_color='lightgray', pos=pos, ax=ax, arrowsize=20)

This lattice is a compact encoding of the information in the promotion table above. You can find the result of a type promotion for two inputs by tracing the graph to the first common child of the two nodes (including the nodes themselves); mathematically, this common child is known as the supremum, or least upper bound, or join of the pair on the lattice; here we will refer to this operation as the join.
Conceptually, an arrow means that implicit type promotion is allowed between the source and the destination: for example, implicit promotion from integer to float is allowed, but implicit promotion from float to integer is not.
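To make the join operation concrete, here is a minimal sketch (an illustrative join helper, not anything JAX uses internally) that computes it by graph reachability with networkx:
import networkx as nx
lattice = {'int': ['float'], 'float': ['complex']}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)

def join(graph, a, b):
    # Upper bounds of a node are its descendants in the DAG, plus the node itself.
    upper_a = nx.descendants(graph, a) | {a}
    upper_b = nx.descendants(graph, b) | {b}
    common = upper_a & upper_b
    # The least upper bound is the common bound from which every other common bound is reachable.
    return next(c for c in common
                if all(o == c or o in nx.descendants(graph, c) for o in common))

assert join(graph, 'int', 'float') == 'float'
assert join(graph, 'int', 'complex') == 'complex'
assert join(graph, 'float', 'float') == 'float'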
Keep in mind that in general not every directed acyclic graph (DAG) will satisfy the properties of a lattice. A lattice requires the existence of a unique least upper bound for every pair of nodes; so, for example the following two DAGs are not lattices:
Show code cell source
#@title
import networkx as nx
import matplotlib.pyplot as plt
fig, ax = plt.subplots(1, 2, figsize=(10, 2))
lattice = {'A': ['B', 'C']}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {'A': [0, 0], 'B': [1, 0.5], 'C': [1, -0.5]}
nx.draw(graph, with_labels=True, node_size=2000, node_color='lightgray', pos=pos, ax=ax[0], arrowsize=20)
ax[0].set(xlim=[-0.5, 1.5], ylim=[-1, 1])
lattice = {'A': ['C', 'D'], 'B': ['C', 'D']}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {'A': [0, 0.5], 'B': [0, -0.5], 'C': [1, 0.5], 'D': [1, -0.5]}
nx.draw(graph, with_labels=True, node_size=2000, node_color='lightgray', pos=pos, ax=ax[1], arrowsize=20)
ax[1].set(xlim=[-0.5, 1.5], ylim=[-1, 1]);

The left DAG is not a lattice because there exists no upper bound for nodes B
and C
; the right DAG fails on two counts: first, there exists no upper bound for nodes C
and D
, and for nodes A
and B
the least upper bound cannot be uniquely determined: both C
and D
are candidates, but they are unorderable.
Properties of a Type Promotion Lattice#
Specifying type promotions in terms of a lattice ensures a number of useful properties. Denoting the join on the lattice with the \(\vee\) operator, we have:
Existence: A lattice by definition requires that a unique lattice join exists for every pair of elements: \(\forall (a, b): \exists !(a \vee b)\)
Commutativity: The lattice join is commutative: \(\forall (a, b): a\vee b = b \vee a\).
Associativity: The lattice join is associative: \(\forall (a, b, c): a \vee (b \vee c) = (a \vee b) \vee c\).
On the other hand, these properties imply restrictions on the type promotion systems they can represent; in particular not every type promotion table can be represented by a lattice. A ready example of this is NumPy’s full type promotion table; this can be shown quickly by counterexample: here are three scalar types whose promotion behavior in NumPy is non-associative:
import numpy as np
a, b, c = np.int8(1), np.uint8(1), np.float16(1)
print(np.dtype((a + b) + c))
print(np.dtype(a + (b + c)))
float32
float16
Such a result may come as a surprise to users: we generally expect mathematical expressions to map to mathematical concepts, so, for example, a + b + c
should be equivalent to c + b + a
; x * (y + z)
should be equivalent to x * y + x * z
. If type promotion is non-associative or non-commutative, these properties no longer apply.
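By contrast, under the lattice-based scheme JAX adopts below, the same three scalar types promote associatively; a quick illustrative check:
import jax.numpy as jnp
a, b, c = jnp.int8(1), jnp.uint8(1), jnp.float16(1)
print(((a + b) + c).dtype)  # float16
print((a + (b + c)).dtype)  # float16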
Further, a lattice-based type promotion system is simpler to conceptualize and understand when compared to a table-based system. For example, JAX recognizes 18 distinct types: a promotion lattice consisting of 18 nodes and sparse, well-motivated connections between them is far easier to hold in one’s mind than a table of 324 entries.
For this reason, we opt to use a lattice-based type promotion system for JAX.
Type Promotion within Categories#
Numerical computing libraries generally provide more than just int
, float
, and complex
; within each of these categories there are a variety of possible precisions, denoted by the number of bits used in the numerical representation. The categories we will consider here are:
unsigned integers, which include uint8, uint16, uint32 & uint64 (we’ll use u8, u16, u32, u64 for short)
signed integers, which include int8, int16, int32 & int64 (we’ll use i8, i16, i32, i64 for short)
floating point, which includes float16, float32 & float64 (we’ll use f16, f32, f64 for short)
complex floating point, which includes complex64 & complex128 (we’ll use c64, c128 for short)
Numpy’s type promotion semantics within each of these four categories is relatively straightforward: the ordered hierarchy of types translates directly to four separate lattices representing in-category type promotion rules:
Show code cell source
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
'u8': ['u16'], 'u16': ['u32'], 'u32': ['u64'],
'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],
'f16': ['f32'], 'f32': ['f64'],
'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
'u8': [0, 0], 'u16': [1, 0], 'u32': [2, 0], 'u64': [3, 0],
'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
'f16': [1, 2], 'f32': [2, 2], 'f64': [3, 2],
'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 4))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)

In terms of promotion of values to 64-bit that JAX seeks to avoid, these same-kind promotion semantics within each type category are unproblematic: the only way to produce a 64-bit output is to have a 64-bit input.
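For instance, a quick illustrative check of in-category promotion in NumPy:
import numpy as np
print((np.zeros(1, dtype='uint8') + np.zeros(1, dtype='uint32')).dtype)     # uint32
print((np.zeros(1, dtype='float16') + np.zeros(1, dtype='float32')).dtype)  # float32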
Enter Python Scalars#
Let’s now think about where Python scalars fit into the mix.
In NumPy, promotion behavior differs depending on whether the inputs are arrays or scalars. For example, when operating on two scalars, normal promotion rules apply:
x = np.int8(0) # int8 scalar
y = 1 # Python int = int64 scalar
(x + y).dtype
dtype('int64')
Here the Python value 1
is treated as an int64
, and straightforward within-category rules lead to an int64
result.
In operations between Python scalars and NumPy arrays, however, scalars defer to the dtype of the array. For example:
x = np.zeros(1, dtype='int8') # int8 array
y = 1 # Python int = int64 scalar
(x + y).dtype
dtype('int8')
Here the bit width of the int64
scalar is ignored, deferring to the bit width of the array.
There is another detail here: when NumPy type promotion involves a scalar, the output dtype is value-dependent: if the Python scalar is too large for the given dtype, it is promoted to a compatible type:
x = np.zeros(1, dtype='int8') # int8 array
y = 1000 # int64 scalar
(x + y).dtype
dtype('int16')
For the purposes of JAX, value-dependent promotion is a non-starter because of the nature of JIT compilation and other transformations, which act on abstract representations of data without reference to their value.
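To see why, note that under jit a function is traced with abstract values that carry only a shape and dtype, so there is no concrete value on which promotion could depend (a small illustrative check):
import jax
import jax.numpy as jnp

def add_scalar(x):
    return x + 1  # the Python scalar must promote based on x's dtype alone

# The printed jaxpr shows x only as an abstract value (shape and dtype, e.g. i8[1]);
# its contents are unknown at trace time.
print(jax.make_jaxpr(add_scalar)(jnp.zeros(1, dtype='int8')))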
Ignoring value-dependent effects, the signed integer branch of NumPy’s type promotion can be represented in the following lattice, where we’ll use *
to mark scalar dtypes:
Show code cell source
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
'i8*': ['i16*'], 'i16*': ['i32*'], 'i32*': ['i64*'], 'i64*': ['i8'],
'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
'i8*': [0, 1], 'i16*': [2, 1], 'i32*': [4, 1], 'i64*': [6, 1],
'i8': [9, 1], 'i16': [11, 1], 'i32': [13, 1], 'i64': [15, 1],
}
fig, ax = plt.subplots(figsize=(12, 4))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
ax.text(3, 1.6, "Scalar Types", ha='center', fontsize=14)
ax.text(12, 1.6, "Array Types", ha='center', fontsize=14)
ax.set_ylim(-1, 3);

A similar pattern holds within the uint
, float
, and complex
lattices.
For the sake of simplicity, let’s collapse each category of scalar types into a single node, denoted by u*
, i*
, f*
, and c*
respectively. Our set of in-category lattices can now be represented like this:
Show code cell source
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
'u*': ['u8'], 'u8': ['u16'], 'u16': ['u32'], 'u32': ['u64'],
'i*': ['i8'], 'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],
'f*': ['f16'], 'f16': ['f32'], 'f32': ['f64'],
'c*': ['c64'], 'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
'u*': [0, 0], 'u8': [3, 0], 'u16': [5, 0], 'u32': [7, 0], 'u64': [9, 0],
'i*': [0, 1], 'i8': [3, 1], 'i16': [5, 1], 'i32': [7, 1], 'i64': [9, 1],
'f*': [0, 2], 'f16': [5, 2], 'f32': [7, 2], 'f64': [9, 2],
'c*': [0, 3], 'c64': [7, 3], 'c128': [9, 3],
}
fig, ax = plt.subplots(figsize=(6, 4))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)

In some senses, putting scalars at the left is a strange choice: the scalar types may contain values of any width, but when interacting with an array of a given type, the promotion result defers to the array type.
The benefit of this is that when you perform an operation like x + 2
for an array x
, the type of x
will carry to the result no matter its width:
for dtype in [np.int8, np.int16, np.int32, np.int64]:
x = np.arange(10, dtype=dtype)
assert (x + 2).dtype == dtype
This behavior gives motivation to our *
notation for scalar values: the *
is reminiscent of a wildcard that can take on any desired value.
The benefit of these semantics is that you can readily express sequences of operations with clean Python code, without having to explicitly cast scalars to the appropriate type. Imagine if rather than writing this:
3 * (x + 1) ** 2
you had to write this:
np.int32(3) * (x + np.int32(1)) ** np.int32(2)
Although the second form is explicit, numerical code written this way quickly becomes tedious to read and write. With the scalar promotion semantics described above, given an array x of type int32, the types in the second statement are implicit within the first.
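As a quick sanity check of this convenience (illustrative, using NumPy as in the surrounding discussion):
import numpy as np
x = np.arange(10, dtype='int32')
print((3 * (x + 1) ** 2).dtype)  # int32: the scalars defer to x's dtype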
Combining Lattices#
Recall that we began our discussion by introducing the lattice representing type promotion within Python: int -> float -> complex
. Let’s rewrite this as i* -> f* -> c*
, and let’s further allow i*
to subsume u*
(after all, there is no unsigned integer scalar type in Python).
Putting these all together, we get the following partial lattice representing type promotion between Python scalars and numpy arrays:
Show code cell source
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
'u8': ['u16'], 'u16': ['u32'], 'u32': ['u64'],
'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],
'f16': ['f32'], 'f32': ['f64'],
'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],
'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)

Notice that this is not (yet) a true lattice: there are many pairs of nodes for which a join does not exist. However, we can think of this as a partial lattice, in which some pairs of nodes do not have a defined promotion behavior, and the defined portion of this partial lattice does correctly describe NumPy’s array promotion behavior (leaving aside value-dependent semantics mentioned above).
This sets up a nice framework by which we can think about filling out these undefined promotion rules by adding connections on this graph. But which connections to add? Broadly speaking, we want any additional connections to satisfy a few properties:
1. Promotion should satisfy the commutative and associative properties: in other words, the graph should remain a (partial) lattice.
2. Promotion should never allow for dropping entire components of data: for example, we should never promote complex to float, as it would discard any imaginary parts.
3. Promotion should never lead to an unhandled overflow. For example, the maximum possible uint32 is twice as large as the maximum possible int32, so we should not implicitly promote uint32 to int32.
4. Wherever possible, promotion should avoid loss of precision. For example, an int64 value may have 64 bits of mantissa, so promoting int64 to float64 represents a possible loss of precision. However, the maximum representable float64 is larger than the maximum representable int64, so in this case criterion #3 is still satisfied.
5. Wherever possible, binary promotion should avoid resulting in types that are wider than the inputs. This is to ensure that JAX’s implicit promotions remain friendly to accelerator-based workflows, in which users often want to restrict types to 32-bit (or in some cases 16-bit) values.
Each new connection on the lattice introduces some level of convenience to the user (a new set of types that can interact without explicit casting), but the convenience may become too costly if any of the above criteria are violated. Developing a full promotion lattice involves striking a balance between this convenience and this cost.
Mixed Promotion: Float and Complex#
Let’s begin with what is perhaps the easiest case, that of promotion between float and complex values.
Complex numbers are made up of pairs of floating point numbers, and so we have a natural path of promotion between them: cast float to complex while maintaining the width of the real part. In terms of our partial lattice representation, it would look like this:
Show code cell source
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
'u8': ['u16'], 'u16': ['u32'], 'u32': ['u64'],
'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],
'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],
'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)

This turns out to represent exactly the semantics used by Numpy in mixed float/complex type promotion.
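A quick illustrative check with NumPy arrays:
import numpy as np
print((np.zeros(1, dtype='float32') + np.zeros(1, dtype='complex64')).dtype)  # complex64
print((np.zeros(1, dtype='float64') + np.zeros(1, dtype='complex64')).dtype)  # complex128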
Mixed Promotion: Signed & Unsigned Integers#
For the next case, let’s consider something a bit more difficult: promotion between signed and unsigned integers. For example, when promoting uint8
to a signed integer, how many bits do we need?
At first glance, you might think it natural to promote uint8
to int8
; but the largest uint8
numbers are not representable in int8
. For this reason, it makes more sense to promote unsigned integers to integers with twice the number of bits; this promotion behavior can be represented by adding the following connections to the promotion lattice:
Show code cell source
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],
'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],
'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],
'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)

Again, the connections added here are precisely the promotion semantics implemented by Numpy for mixed-integer promotion.
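A quick illustrative check against NumPy:
import numpy as np
print((np.zeros(1, dtype='uint8') + np.zeros(1, dtype='int8')).dtype)    # int16
print((np.zeros(1, dtype='uint32') + np.zeros(1, dtype='int32')).dtype)  # int64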
How to handle uint64
?#
The approach to mixed signed/unsigned integer promotion leaves out one type: uint64
. Following the pattern above, the output of a mixed-integer operation involving uint64
should result in int128
, but this is not a standard available dtype.
Numpy’s choice here is to promote to float64
:
(np.uint64(1) + np.int64(1)).dtype
dtype('float64')
However, this may be a surprising convention: it’s the only case in which promotion of integer types does not result in an integer.
For now, we will leave uint64
promotion undefined, and return to it later.
Mixed Promotion: Integer and Floating#
When promoting integers to floating point, we might start with the same thought process as mixed promotion between signed and unsigned integers. A 16-bit signed or unsigned integer cannot be represented at full precision by a 16-bit float, which has only 10 bits of mantissa. Therefore, it might make sense to promote integers to floats represented by twice the number of bits:
Show code cell source
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
'u8': ['u16', 'i16', 'f16'], 'u16': ['u32', 'i32', 'f32'], 'u32': ['u64', 'i64', 'f64'],
'i8': ['i16', 'f16'], 'i16': ['i32', 'f32'], 'i32': ['i64', 'f64'],
'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],
'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)

This is effectively what Numpy type promotion does, but in doing so it breaks the lattice property of the graph: for example, the pair {i8, u8} no longer has a unique least upper bound: the possibilities are i16 and f16, which are unorderable on the graph. This turns out to be the source of NumPy’s non-associative type promotion highlighted above.
Can we come up with a modification of NumPy’s promotion rules, such that it will satisfy the lattice property, while also giving sensible results for mixed type promotion? There are a few approaches we could take here.
Option 0: Leave integer/floating mixed precision undefined#
To make behavior utterly predictable (at some cost to user convenience), a defensible choice would be to leave as undefined any mixed integer/float promotion beyond Python scalars, stopping with the partial lattice from the previous section. The downside would be the requirement for users to explicitly type-cast when operating between integer and floating-point quantities.
Option 1: Avoiding All Precision Loss#
If our focus is on avoiding precision loss at all costs, we can restore the lattice property by promoting unsigned integers to float via their existing signed integer paths:
Show code cell source
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],
'i8': ['i16', 'f16'], 'i16': ['i32', 'f32'], 'i32': ['i64', 'f64'],
'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],
'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)

A disadvantage of this approach is that it still leaves int64
and uint64
promotion undefined, because there is no standard floating point type with enough bits of mantissa to represent their full range of values. We could relax the precision constraint and complete the lattice by drawing connections from i64->f64
and u64->f64
, but those links would run counter to the motivation for this promotion scheme.
A second disadvantage is that this lattice makes it difficult to find a sensible place to insert bfloat16
(see below) while maintaining the lattice property.
A third disadvantage of this approach, more important for JAX’s accelerator backends, is that some operations result in types that are much wider than necessary; for example mixed operations between uint16
and float16
would promote all the way to float64
, which is not ideal.
Option 2: Avoid most wider-than-necessary promotions#
To address the unnecessary promotions to wider types, we could accept the possibility of some precision loss in integer/float promotion, promoting signed integers to floats of the same width:
Show code cell source
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],
'i8': ['i16'], 'i16': ['f16', 'i32'], 'i32': ['f32', 'i64'], 'i64': ['f64'],
'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
'f16': [1.5, 2], 'f32': [2.5, 2], 'f64': [3.5, 2],
'c64': [3, 3], 'c128': [4, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)

While this does allow for precision-losing promotions between integers and floats, these promotions will not mis-represent the magnitude of the result: though the floating point mantissa is not wide enough to represent all values, the exponent is wide enough to approximate them.
This approach also allows a natural promotion path from int64
to float64
, though uint64
remains unpromotable in this scheme. That said, a connection from u64
to f64
could be justified more readily here than before.
This promotion scheme still results in some wider than necessary promotion paths; for example operations between float32
and uint32
result in float64
. Additionally, this lattice makes it difficult to find a sensible place to insert bfloat16
(see below) while maintaining the lattice property.
Option 3: Avoid all wider-than-necessary promotions#
We can avoid all non-ideal 64-bit promotions if we’re willing to fundamentally change our thinking around integer and float promotions. Just as scalars always defer to the widths of array types, we can make integers always defer to the width of float types:
Show code cell source
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
'i*': ['u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],
'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'], 'i64': ['f*'],
'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
'f16': [1.5, 2], 'f32': [2.5, 2], 'f64': [3.5, 2],
'c64': [3, 3], 'c128': [4, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)

This involves a small sleight of hand: previously we had used f*
to refer to a scalar type. In this lattice, f*
might be applied to the array output of a mixed computation. Instead of thinking of f*
as a scalar, we could think of it as a special kind of float
value with distinct promotion rules: in JAX we refer to this as a weak float; see below.
The advantage of this approach is that, outside unsigned ints, it avoids all wider-than-necessary promotions: you can never get an f64 output without a 64-bit input, and you can never get an f32 output without a 32-bit input: this results in convenient semantics for working on accelerators while avoiding inadvertent 64-bit values.
This feature of giving primacy to floating point types resembles the type promotion behavior of PyTorch. This lattice also happens to generate a promotion table that very closely resembles JAX’s original ad hoc type promotion scheme, which was not based on a lattice but had the property of giving primacy to floating point types.
This lattice additionally offers a natural location to insert bfloat16
, without the need to impose an ordering between bf16
and f16
:
Show code cell source
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
'i*': ['u8', 'i8'], 'f*': ['c*', 'f16', 'bf16'], 'c*': ['c64'],
'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],
'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'], 'i64': ['f*'],
'f16': ['f32'], 'bf16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
'f16': [1.8, 1.7], 'bf16': [1.8, 2.3], 'f32': [3.0, 2], 'f64': [4.0, 2],
'c64': [3.5, 3], 'c128': [4.5, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)

This is important because f16 and bf16 are not comparable: they utilize their bits differently, with bf16 representing a larger range at lower precision, and f16 representing a smaller range at higher precision.
However, these advantages come with a few tradeoffs:
mixed float/integer promotion is very prone to precision loss: for example, int64 (with a maximum value of \(9.2 \times 10^{18}\)) can be promoted to float16 (with a maximum value of \(6.5 \times 10^4\)), meaning most representable values will become inf.
as mentioned above, f* can no longer be thought of as a “scalar type”, but as a different flavor of float64. In JAX’s parlance, this is referred to as a weak type, in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values.
Note also that this approach still leaves the uint64 promotion question unanswered, although it is perhaps reasonable to close the lattice by connecting u64 to f*.
Type Promotion in JAX#
In designing the type promotion semantics of JAX, we kept in mind many of these ideas, and leaned heavily on a few things:
1. We chose to constrain JAX’s type promotion semantics to graphs that satisfy the lattice property: this is to ensure associativity and commutativity, but also to allow the semantics to be compactly described in a DAG, rather than requiring a large table.
2. We leaned toward semantics that avoid inadvertent promotion to wider types, particularly when it comes to float values, in order to benefit computation on accelerators.
3. We were fine accepting potential loss of precision (but not loss of magnitude) in mixed type promotion if it were required to maintain (1) and (2).
With this in mind, JAX has adopted Option 3. Or rather, a slightly modified version of Option 3 that draws the connection between u64
and f*
, in order to create a true lattice.
Rearranging the nodes for clarity, JAX’s type promotion lattice then looks like this:
Show code cell source
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
'i*': ['u8', 'i8'], 'f*': ['c*', 'f16', 'bf16'], 'c*': ['c64'],
'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'], 'u64': ['f*'],
'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'], 'i64': ['f*'],
'f16': ['f32'], 'bf16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
'i*': [-1.25, 0.5], 'f*': [4.5, 0.5], 'c*': [5, 1.5],
'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
'f16': [5.75, 0.8], 'bf16': [5.75, 0.2], 'f32': [7, 0.5], 'f64': [8, 0.5],
'c64': [7.5, 1.5], 'c128': [8.5, 1.5],
}
fig, ax = plt.subplots(figsize=(10, 4))
ax.set_ylim(-0.5, 2)
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
# ax.patches[12].set_linestyle((0, (2, 4)))

The behavior resulting from this choice is summarized in JAX Type Promotion Semantics. Notably, aside from the inclusion of larger unsigned types (u16
, u32
, u64
) and some details about the behavior of scalar/weak types (i*
, f*
, c*
), this type promotion scheme turns out to be very close to that chosen by PyTorch.
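For example, two illustrative spot checks of the adopted lattice (integer inputs defer to the width of floating-point inputs):
import jax.numpy as jnp
print((jnp.ones(3, dtype='uint16') + jnp.ones(3, dtype='float16')).dtype)  # float16
print((jnp.ones(3, dtype='int32') + jnp.ones(3, dtype='float32')).dtype)   # float32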
For those interested, the appendix below prints the full promotion tables used by NumPy, Tensorflow, PyTorch, and JAX.
Appendix: Example Type Promotion Tables#
The following are some examples of implicit type promotion tables implemented by various Python array computing libraries.
NumPy Type Promotion#
Note that NumPy does not include the bfloat16
dtype, and that the table below ignores value-dependent effects.
Show code cell source
# @title
import numpy as np
import pandas as pd
from IPython import display
np_dtypes = {
'b': np.bool_,
'u8': np.uint8, 'u16': np.uint16, 'u32': np.uint32, 'u64': np.uint64,
'i8': np.int8, 'i16': np.int16, 'i32': np.int32, 'i64': np.int64,
'bf16': 'invalid', 'f16': np.float16, 'f32': np.float32, 'f64': np.float64,
'c64': np.complex64, 'c128': np.complex128,
'i*': int, 'f*': float, 'c*': complex}
np_dtype_to_code = {val: key for key, val in np_dtypes.items()}
def make_np_zero(dtype):
if dtype in {int, float, complex}:
return dtype(0)
else:
return np.zeros(1, dtype=dtype)
def np_result_code(dtype1, dtype2):
try:
out = np.add(make_np_zero(dtype1), make_np_zero(dtype2))
except TypeError:
return '-'
else:
if type(out) in {int, float, complex}:
return np_dtype_to_code[type(out)]
else:
return np_dtype_to_code[out.dtype.type]
grid = [[np_result_code(dtype1, dtype2)
for dtype2 in np_dtypes.values()]
for dtype1 in np_dtypes.values()]
table = pd.DataFrame(grid, index=np_dtypes.keys(), columns=np_dtypes.keys())
display.HTML(table.to_html())
| | b | u8 | u16 | u32 | u64 | i8 | i16 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i* | f* | c* |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
b | b | u8 | u16 | u32 | u64 | i8 | i16 | i32 | i64 | - | f16 | f32 | f64 | c64 | c128 | i64 | f64 | c128 |
u8 | u8 | u8 | u16 | u32 | u64 | i16 | i16 | i32 | i64 | - | f16 | f32 | f64 | c64 | c128 | u8 | f64 | c128 |
u16 | u16 | u16 | u16 | u32 | u64 | i32 | i32 | i32 | i64 | - | f32 | f32 | f64 | c64 | c128 | u16 | f64 | c128 |
u32 | u32 | u32 | u32 | u32 | u64 | i64 | i64 | i64 | i64 | - | f64 | f64 | f64 | c128 | c128 | u32 | f64 | c128 |
u64 | u64 | u64 | u64 | u64 | u64 | f64 | f64 | f64 | f64 | - | f64 | f64 | f64 | c128 | c128 | u64 | f64 | c128 |
i8 | i8 | i16 | i32 | i64 | f64 | i8 | i16 | i32 | i64 | - | f16 | f32 | f64 | c64 | c128 | i8 | f64 | c128 |
i16 | i16 | i16 | i32 | i64 | f64 | i16 | i16 | i32 | i64 | - | f32 | f32 | f64 | c64 | c128 | i16 | f64 | c128 |
i32 | i32 | i32 | i32 | i64 | f64 | i32 | i32 | i32 | i64 | - | f64 | f64 | f64 | c128 | c128 | i32 | f64 | c128 |
i64 | i64 | i64 | i64 | i64 | f64 | i64 | i64 | i64 | i64 | - | f64 | f64 | f64 | c128 | c128 | i64 | f64 | c128 |
bf16 | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - |
f16 | f16 | f16 | f32 | f64 | f64 | f16 | f32 | f64 | f64 | - | f16 | f32 | f64 | c64 | c128 | f16 | f16 | c64 |
f32 | f32 | f32 | f32 | f64 | f64 | f32 | f32 | f64 | f64 | - | f32 | f32 | f64 | c64 | c128 | f32 | f32 | c64 |
f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | - | f64 | f64 | f64 | c128 | c128 | f64 | f64 | c128 |
c64 | c64 | c64 | c64 | c128 | c128 | c64 | c64 | c128 | c128 | - | c64 | c64 | c128 | c64 | c128 | c64 | c64 | c64 |
c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | - | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 |
i* | i64 | u8 | u16 | u32 | u64 | i8 | i16 | i32 | i64 | - | f16 | f32 | f64 | c64 | c128 | i64 | f64 | c128 |
f* | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | - | f16 | f32 | f64 | c64 | c128 | f64 | f64 | c128 |
c* | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | - | c64 | c64 | c128 | c64 | c128 | c128 | c128 | c128 |
Tensorflow Type Promotion#
Tensorflow avoids defining implicit type promotion, except for Python scalars in limited cases. The table is asymmetric because in tf.add(x, y)
, the type of y
must be coercible to the type of x
.
Show code cell source
# @title
import tensorflow as tf
import pandas as pd
from IPython import display
tf_dtypes = {
'b': tf.bool,
'u8': tf.uint8, 'u16': tf.uint16, 'u32': tf.uint32, 'u64': tf.uint64,
'i8': tf.int8, 'i16': tf.int16, 'i32': tf.int32, 'i64': tf.int64,
'bf16': tf.bfloat16, 'f16': tf.float16, 'f32': tf.float32, 'f64': tf.float64,
'c64': tf.complex64, 'c128': tf.complex128,
'i*': int, 'f*': float, 'c*': complex}
tf_dtype_to_code = {val: key for key, val in tf_dtypes.items()}
def make_tf_zero(dtype):
if dtype in {int, float, complex}:
return dtype(0)
else:
return tf.zeros(1, dtype=dtype)
def result_code(dtype1, dtype2):
try:
out = tf.add(make_tf_zero(dtype1), make_tf_zero(dtype2))
except (TypeError, tf.errors.InvalidArgumentError):
return '-'
else:
if type(out) in {int, float, complex}:
return tf_dtype_to_code[type(out)]
else:
return tf_dtype_to_code[out.dtype]
grid = [[result_code(dtype1, dtype2)
for dtype2 in tf_dtypes.values()]
for dtype1 in tf_dtypes.values()]
table = pd.DataFrame(grid, index=tf_dtypes.keys(), columns=tf_dtypes.keys())
display.HTML(table.to_html())
| | b | u8 | u16 | u32 | u64 | i8 | i16 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i* | f* | c* |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
b | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - |
u8 | - | u8 | - | - | - | - | - | - | - | - | - | - | - | - | - | u8 | - | - |
u16 | - | - | u16 | - | - | - | - | - | - | - | - | - | - | - | - | u16 | - | - |
u32 | - | - | - | u32 | - | - | - | - | - | - | - | - | - | - | - | u32 | - | - |
u64 | - | - | - | - | u64 | - | - | - | - | - | - | - | - | - | - | u64 | - | - |
i8 | - | - | - | - | - | i8 | - | - | - | - | - | - | - | - | - | i8 | - | - |
i16 | - | - | - | - | - | - | i16 | - | - | - | - | - | - | - | - | i16 | - | - |
i32 | - | - | - | - | - | - | - | i32 | - | - | - | - | - | - | - | i32 | - | - |
i64 | - | - | - | - | - | - | - | - | i64 | - | - | - | - | - | - | i64 | - | - |
bf16 | - | - | - | - | - | - | - | - | - | bf16 | - | - | - | - | - | bf16 | bf16 | - |
f16 | - | - | - | - | - | - | - | - | - | - | f16 | - | - | - | - | f16 | f16 | - |
f32 | - | - | - | - | - | - | - | - | - | - | - | f32 | - | - | - | f32 | f32 | - |
f64 | - | - | - | - | - | - | - | - | - | - | - | - | f64 | - | - | f64 | f64 | - |
c64 | - | - | - | - | - | - | - | - | - | - | - | - | - | c64 | - | c64 | c64 | c64 |
c128 | - | - | - | - | - | - | - | - | - | - | - | - | - | - | c128 | c128 | c128 | c128 |
i* | - | - | - | - | - | - | - | i32 | - | - | - | - | - | - | - | i32 | - | - |
f* | - | - | - | - | - | - | - | - | - | - | - | f32 | - | - | - | f32 | f32 | - |
c* | - | - | - | - | - | - | - | - | - | - | - | - | - | - | c128 | c128 | c128 | c128 |
PyTorch Type Promotion#
Notice that torch does not include unsigned integer types larger than uint8
.
Aside from this and some details about promotion with scalar/weak types, the table is close to that used by jax.numpy
.
Show code cell source
# @title
import torch
import pandas as pd
from IPython import display
torch_dtypes = {
'b': torch.bool,
'u8': torch.uint8, 'u16': 'invalid', 'u32': 'invalid', 'u64': 'invalid',
'i8': torch.int8, 'i16': torch.int16, 'i32': torch.int32, 'i64': torch.int64,
'bf16': torch.bfloat16, 'f16': torch.float16, 'f32': torch.float32, 'f64': torch.float64,
'c64': torch.complex64, 'c128': torch.complex128,
'i*': int, 'f*': float, 'c*': complex}
torch_dtype_to_code = {val: key for key, val in torch_dtypes.items()}
def make_torch_zero(dtype):
if dtype in {int, float, complex}:
return dtype(0)
else:
return torch.zeros(1, dtype=dtype)
def torch_result_code(dtype1, dtype2):
try:
out = torch.add(make_torch_zero(dtype1), make_torch_zero(dtype2))
except TypeError:
return '-'
else:
if type(out) in {int, float, complex}:
return torch_dtype_to_code[type(out)]
else:
return torch_dtype_to_code[out.dtype]
grid = [[torch_result_code(dtype1, dtype2)
for dtype2 in torch_dtypes.values()]
for dtype1 in torch_dtypes.values()]
table = pd.DataFrame(grid, index=torch_dtypes.keys(), columns=torch_dtypes.keys())
display.HTML(table.to_html())
| | b | u8 | u16 | u32 | u64 | i8 | i16 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i* | f* | c* |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
b | b | u8 | - | - | - | i8 | i16 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i64 | f32 | c64 |
u8 | u8 | u8 | - | - | - | i16 | i16 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | u8 | f32 | c64 |
u16 | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - |
u32 | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - |
u64 | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - |
i8 | i8 | i16 | - | - | - | i8 | i16 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i8 | f32 | c64 |
i16 | i16 | i16 | - | - | - | i16 | i16 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i16 | f32 | c64 |
i32 | i32 | i32 | - | - | - | i32 | i32 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i32 | f32 | c64 |
i64 | i64 | i64 | - | - | - | i64 | i64 | i64 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i64 | f32 | c64 |
bf16 | bf16 | bf16 | - | - | - | bf16 | bf16 | bf16 | bf16 | bf16 | f32 | f32 | f64 | c64 | c128 | bf16 | bf16 | c64 |
f16 | f16 | f16 | - | - | - | f16 | f16 | f16 | f16 | f32 | f16 | f32 | f64 | c64 | c128 | f16 | f16 | c64 |
f32 | f32 | f32 | - | - | - | f32 | f32 | f32 | f32 | f32 | f32 | f32 | f64 | c64 | c128 | f32 | f32 | c64 |
f64 | f64 | f64 | - | - | - | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | c128 | c128 | f64 | f64 | c128 |
c64 | c64 | c64 | - | - | - | c64 | c64 | c64 | c64 | c64 | c64 | c64 | c128 | c64 | c128 | c64 | c64 | c64 |
c128 | c128 | c128 | - | - | - | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 |
i* | i64 | u8 | - | - | - | i8 | i16 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i64 | f32 | c64 |
f* | f32 | f32 | - | - | - | f32 | f32 | f32 | f32 | bf16 | f16 | f32 | f64 | c64 | c128 | f32 | f64 | c64 |
c* | c64 | c64 | - | - | - | c64 | c64 | c64 | c64 | c64 | c64 | c64 | c128 | c64 | c128 | c64 | c64 | c128 |
JAX Type Promotion: jax.numpy
#
jax.numpy
follows type promotion rules laid out at https://jax.readthedocs.io/en/latest/type_promotion.html. Here we use i*
, f*
, c*
to indicate both Python scalars and weakly-typed arrays.
Show code cell source
# @title
from jax import dtypes
import jax
import jax.numpy as jnp
import pandas as pd
from IPython import display
jax.config.update('jax_enable_x64', True)
jnp_dtypes = {
'b': jnp.bool_.dtype,
'u8': jnp.uint8.dtype, 'u16': jnp.uint16.dtype, 'u32': jnp.uint32.dtype, 'u64': jnp.uint64.dtype,
'i8': jnp.int8.dtype, 'i16': jnp.int16.dtype, 'i32': jnp.int32.dtype, 'i64': jnp.int64.dtype,
'bf16': jnp.bfloat16.dtype, 'f16': jnp.float16.dtype, 'f32': jnp.float32.dtype, 'f64': jnp.float64.dtype,
'c64': jnp.complex64.dtype, 'c128': jnp.complex128.dtype,
'i*': int, 'f*': float, 'c*': complex}
jnp_dtype_to_code = {val: key for key, val in jnp_dtypes.items()}
def make_jnp_zero(dtype):
if dtype in {int, float, complex}:
return dtype(0)
else:
return jnp.zeros((), dtype=dtype)
def jnp_result_code(dtype1, dtype2):
try:
out = jnp.add(make_jnp_zero(dtype1), make_jnp_zero(dtype2))
except TypeError:
return '-'
else:
if hasattr(out, 'aval') and out.aval.weak_type:
return out.dtype.kind + '*'
elif type(out) in {int, float, complex}:
return jnp_dtype_to_code[type(out)]
else:
return jnp_dtype_to_code[out.dtype]
grid = [[jnp_result_code(dtype1, dtype2)
for dtype2 in jnp_dtypes.values()]
for dtype1 in jnp_dtypes.values()]
table = pd.DataFrame(grid, index=jnp_dtypes.keys(), columns=jnp_dtypes.keys())
display.HTML(table.to_html())
| | b | u8 | u16 | u32 | u64 | i8 | i16 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i* | f* | c* |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
b | b | u8 | u16 | u32 | u64 | i8 | i16 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i* | f* | c* |
u8 | u8 | u8 | u16 | u32 | u64 | i16 | i16 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | u8 | f* | c* |
u16 | u16 | u16 | u16 | u32 | u64 | i32 | i32 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | u16 | f* | c* |
u32 | u32 | u32 | u32 | u32 | u64 | i64 | i64 | i64 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | u32 | f* | c* |
u64 | u64 | u64 | u64 | u64 | u64 | f* | f* | f* | f* | bf16 | f16 | f32 | f64 | c64 | c128 | u64 | f* | c* |
i8 | i8 | i16 | i32 | i64 | f* | i8 | i16 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i8 | f* | c* |
i16 | i16 | i16 | i32 | i64 | f* | i16 | i16 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i16 | f* | c* |
i32 | i32 | i32 | i32 | i64 | f* | i32 | i32 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i32 | f* | c* |
i64 | i64 | i64 | i64 | i64 | f* | i64 | i64 | i64 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i64 | f* | c* |
bf16 | bf16 | bf16 | bf16 | bf16 | bf16 | bf16 | bf16 | bf16 | bf16 | bf16 | f32 | f32 | f64 | c64 | c128 | bf16 | bf16 | c64 |
f16 | f16 | f16 | f16 | f16 | f16 | f16 | f16 | f16 | f16 | f32 | f16 | f32 | f64 | c64 | c128 | f16 | f16 | c64 |
f32 | f32 | f32 | f32 | f32 | f32 | f32 | f32 | f32 | f32 | f32 | f32 | f32 | f64 | c64 | c128 | f32 | f32 | c64 |
f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | c128 | c128 | f64 | f64 | c128 |
c64 | c64 | c64 | c64 | c64 | c64 | c64 | c64 | c64 | c64 | c64 | c64 | c64 | c128 | c64 | c128 | c64 | c64 | c64 |
c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 |
i* | i* | u8 | u16 | u32 | u64 | i8 | i16 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i* | f* | c* |
f* | f* | f* | f* | f* | f* | f* | f* | f* | f* | bf16 | f16 | f32 | f64 | c64 | c128 | f* | f* | c* |
c* | c* | c* | c* | c* | c* | c* | c* | c* | c* | c64 | c64 | c64 | c128 | c64 | c128 | c* | c* | c* |
JAX Type Promotion: jax.lax
#
jax.lax
is lower-level, and does not do any implicit promotion. Here we use i*
, f*
, c*
to indicate both Python scalars and weakly-typed arrays.
Show code cell source
# @title
from jax import dtypes
import jax
import jax.numpy as jnp
import pandas as pd
from IPython import display
jax.config.update('jax_enable_x64', True)
jnp_dtypes = {
'b': jnp.bool_.dtype,
'u8': jnp.uint8.dtype, 'u16': jnp.uint16.dtype, 'u32': jnp.uint32.dtype, 'u64': jnp.uint64.dtype,
'i8': jnp.int8.dtype, 'i16': jnp.int16.dtype, 'i32': jnp.int32.dtype, 'i64': jnp.int64.dtype,
'bf16': jnp.bfloat16.dtype, 'f16': jnp.float16.dtype, 'f32': jnp.float32.dtype, 'f64': jnp.float64.dtype,
'c64': jnp.complex64.dtype, 'c128': jnp.complex128.dtype,
'i*': int, 'f*': float, 'c*': complex}
jnp_dtype_to_code = {val: key for key, val in jnp_dtypes.items()}
def make_jnp_zero(dtype):
if dtype in {int, float, complex}:
return dtype(0)
else:
return jnp.zeros((), dtype=dtype)
def jnp_result_code(dtype1, dtype2):
try:
out = jax.lax.add(make_jnp_zero(dtype1), make_jnp_zero(dtype2))
except TypeError:
return '-'
else:
if hasattr(out, 'aval') and out.aval.weak_type:
return out.dtype.kind + '*'
elif type(out) in {int, float, complex}:
return jnp_dtype_to_code[type(out)]
else:
return jnp_dtype_to_code[out.dtype]
grid = [[jnp_result_code(dtype1, dtype2)
for dtype2 in jnp_dtypes.values()]
for dtype1 in jnp_dtypes.values()]
table = pd.DataFrame(grid, index=jnp_dtypes.keys(), columns=jnp_dtypes.keys())
display.HTML(table.to_html())
| | b | u8 | u16 | u32 | u64 | i8 | i16 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i* | f* | c* |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
b | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - |
u8 | - | u8 | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - |
u16 | - | - | u16 | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - |
u32 | - | - | - | u32 | - | - | - | - | - | - | - | - | - | - | - | - | - | - |
u64 | - | - | - | - | u64 | - | - | - | - | - | - | - | - | - | - | - | - | - |
i8 | - | - | - | - | - | i8 | - | - | - | - | - | - | - | - | - | - | - | - |
i16 | - | - | - | - | - | - | i16 | - | - | - | - | - | - | - | - | - | - | - |
i32 | - | - | - | - | - | - | - | i32 | - | - | - | - | - | - | - | - | - | - |
i64 | - | - | - | - | - | - | - | - | i64 | - | - | - | - | - | - | i64 | - | - |
bf16 | - | - | - | - | - | - | - | - | - | bf16 | - | - | - | - | - | - | - | - |
f16 | - | - | - | - | - | - | - | - | - | - | f16 | - | - | - | - | - | - | - |
f32 | - | - | - | - | - | - | - | - | - | - | - | f32 | - | - | - | - | - | - |
f64 | - | - | - | - | - | - | - | - | - | - | - | - | f64 | - | - | - | f64 | - |
c64 | - | - | - | - | - | - | - | - | - | - | - | - | - | c64 | - | - | - | - |
c128 | - | - | - | - | - | - | - | - | - | - | - | - | - | - | c128 | - | - | c128 |
i* | - | - | - | - | - | - | - | - | i64 | - | - | - | - | - | - | i* | - | - |
f* | - | - | - | - | - | - | - | - | - | - | - | - | f64 | - | - | - | f* | - |
c* | - | - | - | - | - | - | - | - | - | - | - | - | - | - | c128 | - | - | c* |
Jax and Jaxlib versioning#
Why are jax
and jaxlib
separate packages?#
We publish JAX as two separate Python wheels, namely jax
, which is a pure
Python wheel, and jaxlib
, which is a mostly-C++ wheel that contains libraries
such as:
XLA,
pieces of LLVM used by XLA,
MLIR infrastructure, such as the StableHLO Python bindings.
JAX-specific C++ libraries for fast JIT and PyTree manipulation.
We distribute separate jax
and jaxlib
packages because it makes it easy to
work on the Python parts of JAX without having to build C++ code or even having
a C++ toolchain installed. jaxlib
is a large library that is not easy for
many users to build, but most changes to JAX only touch Python code. By
allowing the Python pieces to be updated independently of the C++ pieces, we
improve the development velocity for Python changes.
In addition, jaxlib is not cheap to build, but we want to be able to iterate on and run the JAX tests in environments without a lot of CPU, for example in GitHub Actions or on a laptop. Many of our CI builds simply use a prebuilt jaxlib, rather than needing to rebuild the C++ pieces of JAX on each PR.
As we will see, distributing jax
and jaxlib
separately comes with a cost, in
that it requires that changes to jaxlib
maintain a backward compatible API.
However, we believe that on balance it is preferable to make Python changes
easy, even if at the cost of making C++ changes slightly harder.
How are jax
and jaxlib
versioned?#
Summary: jax
and jaxlib
share the same version number in the JAX source tree, but are released as separate Python packages.
When installed, the jax
package version must be greater than or equal to jaxlib
’s version,
and jaxlib
’s version must be greater than or equal to the minimum jaxlib
version specified by jax
.
Both jax and jaxlib releases are numbered x.y.z, where x is the major version, y is the minor version, and z is an optional patch release. Version numbers must follow PEP 440. Version number comparisons are lexicographic comparisons on tuples of integers.
Each jax
release has an associated minimum jaxlib
version mx.my.mz
. The
minimum jaxlib
version for jax
version x.y.z
must be no greater than
x.y.z
.
For jax
version x.y.z
and jaxlib
version lx.ly.lz
to be compatible,
the following must hold:
The jaxlib version (lx.ly.lz) must be greater than or equal to the minimum jaxlib version (mx.my.mz).
The jax version (x.y.z) must be greater than or equal to the jaxlib version (lx.ly.lz).
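As a concrete illustration of these two constraints, here is a minimal sketch of the check as tuple comparisons (parse_version and is_compatible are hypothetical helpers; JAX's actual check lives in its own version-handling code and differs in detail):
def parse_version(version: str) -> tuple[int, ...]:
    # Parse 'x.y.z' into a tuple of integers for lexicographic comparison.
    return tuple(int(part) for part in version.split("."))

def is_compatible(jax_version: str, jaxlib_version: str, minimum_jaxlib: str) -> bool:
    # The two constraints above: minimum jaxlib <= installed jaxlib <= installed jax.
    return parse_version(minimum_jaxlib) <= parse_version(jaxlib_version) <= parse_version(jax_version)

assert is_compatible("0.4.14", "0.4.13", "0.4.11")      # compatible
assert not is_compatible("0.4.14", "0.4.16", "0.4.11")  # jaxlib newer than jax
assert not is_compatible("0.4.14", "0.4.10", "0.4.11")  # jaxlib older than the minimum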
These constraints imply the following rules for releases:
jax may be released on its own at any time, without updating jaxlib.
If a new jaxlib is released, a jax release must be made at the same time.
These
version constraints
are currently checked by jax
at import time, instead of being expressed as
Python package version constraints. jax
checks the jaxlib
version at
runtime rather than using a pip
package version constraint because we
provide separate jaxlib
wheels
for a variety of hardware and software versions (e.g., GPU, TPU, etc.). Since we
do not know which is the right choice for any given user, we do not want pip
to install a jaxlib
package for us automatically.
In the future, we hope to separate out the hardware-specific pieces of jaxlib
into separate plugins, at which point the minimum version could be expressed as
a Python package dependency. For now, we do provide
platform-specific extra requirements that install a compatible jaxlib version,
e.g., jax[cuda]
.
How can I safely make changes to the API of jaxlib
?#
jax may drop compatibility with older jaxlib releases at any time, so long as the minimum jaxlib version is increased to a compatible version. However, note that the minimum jaxlib, even for unreleased versions of jax, must be a released version! This allows us to use released jaxlib wheels in our CI builds, and allows Python developers to work on jax at HEAD without ever needing to build jaxlib.
For example, to remove an old backwards compatibility path in the jax Python code, it is sufficient to bump the minimum jaxlib version and then delete the compatibility path.
jaxlib may drop compatibility with older jax releases lower than its own release version number. The version constraints enforced by jax would forbid the use of an incompatible jaxlib.
For example, for jaxlib to drop a Python binding API used by an older jax version, the jaxlib minor or major version number must be incremented.
If possible, changes to jaxlib should be made in a backwards-compatible way.
In general, jaxlib may freely change its API, so long as the rules about jax being compatible with all jaxlibs at least as new as the minimum version are followed. This implies that jax must always be compatible with at least two versions of jaxlib, namely the last release and the tip-of-tree version, effectively the next release. This is easier to do if compatibility is maintained, although incompatible changes can be made using version tests from jax; see below.
For example, it is usually safe to add a new function to jaxlib, but unsafe to remove an existing function or to change its signature if current jax is still using it. Changes to jax must work or degrade gracefully for all jaxlib releases greater than the minimum up to HEAD.
Note that the compatibility rules here only apply to released versions of
jax
and jaxlib
. They do not apply to unreleased versions; that is, it is ok
to introduce and then remove an API from jaxlib
if it is never released, or if
no released jax
version uses that API.
How is the source to jaxlib
laid out?#
jaxlib is split across two main repositories, namely the jaxlib/ subdirectory in the main JAX repository and the XLA source tree, which lives inside the XLA repository.
The JAX-specific pieces inside XLA are primarily in the
xla/python
subdirectory.
The reason that C++ pieces of JAX, such as Python bindings and runtime components, are inside the XLA tree is partially historical and partially technical.
The historical reason is that originally the
xla/python
bindings were envisaged as general purpose Python bindings that
might be shared with other frameworks. In practice this is increasingly less
true, and xla/python
incorporates a number of JAX-specific pieces and is
likely to incorporate more. So it is probably best to simply think of
xla/python
as part of JAX.
The technical reason is that the XLA C++ API is not stable. By keeping the
XLA:Python bindings in the XLA tree, their C++ implementation can be updated
atomically with the C++ API of XLA. It is easier to maintain backward and forward
compatibility of Python APIs than C++ ones, so xla/python
exposes Python APIs
and is responsible for maintaining backward compatibility at the Python
level.
jaxlib
is built using Bazel out of the jax
repository. The pieces of
jaxlib
from the XLA repository are incorporated into the build
as a Bazel submodule.
To update the version of XLA used during the build, one must update the pinned
version in the Bazel WORKSPACE
. This is done manually on an
as-needed basis, but can be overridden on a build-by-build basis.
How do we make changes across the jax
and jaxlib
boundary between releases?#
The jaxlib version is a coarse instrument: it only lets us reason about releases.
However, since the jax
and jaxlib
code is split across repositories that
cannot be updated atomically in a single change, we need to manage compatibility
at a finer granularity than our release cycle. To manage fine-grained
compatibility, we have additional versioning that is independent of the jaxlib
release version numbers.
We maintain an additional version number (_version
) in
xla_client.py
in the XLA repository.
The idea is that this version number, which is defined in xla/python together with the C++ parts of JAX, is also accessible to JAX Python as jax._src.lib.xla_extension_version, and must be incremented every time a change is made to the XLA/Python code that has backwards compatibility implications for jax. The JAX Python code can then use this version number to maintain backwards compatibility, e.g.:
from jax._src.lib import xla_extension_version

# 123 is the new version number for _version in xla_client.py
if xla_extension_version >= 123:
  # Use new code path
  ...
else:
  # Use old code path.
  ...
Note that this version number is in addition to the constraints on the released version numbers, that is, this version number exists to help manage compatibility during development for unreleased code. Releases must also follow the compatibility rules given above.
Sequencing side-effects in JAX#
sharadmv@ May 9 2022
Overview#
When we write JAX code, we can usually pretend we’re writing single-threaded, eagerly-executed Python even though underneath the hood, JAX and its runtime may execute it asynchronously in the background. As long as we write pure (side-effect-free) code, these performance optimizations are usually invisible to us and don’t interfere with our single-threaded mental model. Asynchronous execution is great – we get performant, parallel code without having to think about it at all!
However, in the presence of side-effects, the illusion begins to break down and the cracks in our mental model start to show. Specifically, these differences show up when we think about the order in which side-effects happen.
In this design note, we explore the interaction between JAX’s execution model, and the ordering of side-effects. We also provide a way of enforcing a “single-threaded” ordering of effects.
Background#
When we write the following Python code
def f():
print("hello")
return 2
def g():
print("world")
return 3
f()
g()
we expect "hello"
to be printed before "world"
. This might seem obvious
but consider the following JAX code:
@partial(jax.jit, device=<device 0>)
def f():
return 2
@partial(jax.jit, device=<device 1>)
def g():
return 3
f()
g()
In many cases, JAX will execute f
and g
in parallel, dispatching
the computations onto different threads – g
might actually be executed
before f
. Parallel execution is a nice performance optimization, especially if copying
to and from a device is expensive (see the asynchronous dispatch note for more details).
In practice, however, we often don’t need to
think about asynchronous dispatch because we’re writing pure functions and only
care about the inputs and outputs of functions – we’ll naturally block on future
values.
However, now imagine that we have a jax.print
function that works inside of
JIT-ted JAX functions (host_callback.id_print
is an example of this). Let’s
return to the previous example except with prints in the mix.
@partial(jax.jit, device=<device 0>)
def f():
jax.print("hello")
return 2
@partial(jax.jit, device=<device 1>)
def g():
jax.print("world")
return 3
f()
g()
Thanks to asynchronous dispatch, we could actually see "world"
being printed
before "hello"
. The reordering of the print side-effects breaks the illusion
of a single-threaded execution model.
Another example of where side-effects can “reveal” out-of-order execution is when we compile JAX programs. Consider the following JAX code:
@jax.jit
def f(x):
jax.print("hello")
jax.print("world")
return x
Even though in Python, we’ve written the "hello"
print before the "world"
print,
a compiler like XLA is free to reorder them because there’s no explicit data-dependence between the prints.
Motivation#
We’d like to support “ordered” effects. When we say ordered, we mean that the effects
occur in the same order as we would if we were executing a single-threaded Python program.
This is our main desideratum. In the presence of explicit parallelism like pmap
or
user threads, we don’t need to maintain this behavior but at least if the user is not
explicitly requesting parallelism, we’d like to preserve a single-threaded ordering.
Before we dive in more, let’s first step back and ask ourselves if it is okay if we reorder effects in the name of performance, and conversely, do we need to enforce an ordering on effects at all? In some cases, we don’t need ordering. Maybe some side-effects shouldn’t adversely affect the performance of a JAX program. However, for other side-effects, we may want to enforce a single-threaded program order so users don’t get counterintuitive behavior. Consider a logging effect.
@jax.jit
def f(x, y):
log_value(x)
log_value(y)
f(1, 2)
If `log_value` is mutating a global list, we might expect that we add `x` before adding `y`. For a stricter effect like this, we may want the option to order the effects.
Enforcing ordered effects#
The main tool we have to enforce the ordering of computations is data-dependence.
Simply put, if a function g
has an input that is the output of a function f
,
f
must be executed before g
.
However, we may have side effects like prints that have no inputs at all so naively we couldn’t sequence them. We thus use tokens as a means of injecting artificial data-dependence into a computation.
What is a token? A token is just a dummy value that can be threaded in and out of a computation. By threading the same token in and out of several computations, we enforce that they have to happen in a certain order. Let’s take the previous print example and see what it would look like with tokens in the mix:
@jax.jit
def f(token, x):
token = jax.print(token, "hello")
token = jax.print(token, "world")
return token, x
If we rewrite jax.print
to take in and return a token, we have now sequenced
the two prints since the input to the second print depends on the output of the first print.
The actual value of token
can be anything really, but we’ll see in practice
that the tokens are invisible to users.
Runtime tokens vs. compiler tokens#
Here we will actually start talking about implementation details. In practice, we’ll need two separate types of tokens to sequence effects: one for each of the aforementioned sources of reordering. We’ll need runtime tokens to sequence asynchronously dispatched side-effecting computations and we’ll need compiler tokens to sequence effects within computations.
In practice, our computation will be rewritten to look like this:
@jax.jit
def f(runtime_token, x):
compiler_token = new_compiler_token()
compiler_token = jax.print(compiler_token, "hello")
compiler_token = jax.print(compiler_token, "world")
return runtime_token, x
Notice how the runtime tokens are only used at the JIT boundary and the compiler tokens are only within the compiled code. Compiler tokens are created during “lowering” (we convert Python code to a lower level representation like HLO or StableHLO) but runtime tokens need to be managed in Python since they’re being threaded in and out of JIT-ted functions.
Furthermore, notice that the runtime tokens are “disconnected” from the compiler tokens, meaning there’s no data dependency between them. This could potentially be dangerous, since it seems we would lose the data dependence between the bodies of two dispatched function calls. However, if we assume “strict execution” – i.e. a dispatched function will only start executing when all of its inputs are ready, and all of its outputs become ready at the same time – we are safe to create a fresh compiler token and return a non-output-dependent runtime token.
Managing runtime tokens#
To manage runtime tokens on behalf of the user, we’ll need to hook into JAX’s dispatch machinery. Whenever we call a JIT-ted function, we eventually bottom out in a function that looks like this:
def _execute(compiled_computation, *args):
outputs = compiled_computation.execute(*args)
return outputs
At this point we need to “inject” the runtime tokens into the computation and “extract” them from the computation’s outputs:
def _execute(compiled_computation, *args):
runtime_token = get_runtime_token() # Grab global token
runtime_token, *outputs = compiled_computation.execute(runtime_token, *args)
update_runtime_token(runtime_token) # Update global token
return outputs
What is runtime_token
exactly? Well we need to be able to pass it into a compiled_computation
,
which means it needs to be some sort of array (for now, since there’s no shared token representation
inside and outside compiled JAX code). In practice we can use a (0,)
-shaped array to minimize overheads.
We also need to think about the multiple device use case, e.g. the first example where we first call a JIT-ted function on device 0 and then one on device 1. In that case, we need to also copy the runtime token returned from the first computation (which lives on device 0) to device 1 so we can pass it into the second computation. If two subsequent computations share the same device, this copy is not necessary.
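As a concrete illustration, here is a minimal sketch of such a global token store. The helper names are a variant of the `get_runtime_token`/`update_runtime_token` helpers from the pseudocode above, extended with a device argument to show the cross-device copy; this is not JAX's actual implementation.

```python
import jax
import jax.numpy as jnp

_runtime_token = None  # global runtime token threaded through dispatches

def get_runtime_token(device):
  global _runtime_token
  if _runtime_token is None:
    _runtime_token = jnp.zeros((0,))  # dummy (0,)-shaped value; contents don't matter
  # Copy the token to the target device if the previous computation ran elsewhere.
  return jax.device_put(_runtime_token, device)

def update_runtime_token(token):
  global _runtime_token
  _runtime_token = token
```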
Adding compiler tokens#
When we lower Python code to HLO or StableHLO we need to create a token at the start of the computation and ensure it is available wherever we have side-effecting computations that need to be ordered. The side-effecting computations will take the token as input and return it as an output.
The implementation of this token threading involves upgrading the JAX lowering machinery to do this bookkeeping automatically. The main challenges involve dealing with higher-order primitives like call primitives and control-flow primitives. We won’t go into details on how to handle those in this design note.
Blocking on output tokens#
Adding support for runtime and compiler tokens for side-effecting computations is important for sequencing
but there’s also another subtle use-case for tokens, which is blocking on side-effecting computations.
Even if we don’t want a side-effecting computation to be ordered we may still want to wait on its
completion. Currently we have jax.block_until_ready
, which waits until a future value has its
result ready. However, with side-effecting computations, we may have functions that don’t have a return
value but are still executing a side-effect. Take the simple example here:
@jax.jit
def f():
jax.print("hello world")
return
f() # Executed asynchronously
This compiled computation takes no explicit inputs and has no explicit outputs. If it was an ordered print effect,
we could block on the returned runtime token. However,
when this is an unordered computation we don’t do any token threading. How do we wait for f()
to
finish executing when we have no output value to call block_until_ready
on? Well, we could apply our same
token strategy except we only return runtime tokens and don’t take them as inputs. This will give us
a value to block on that will only be ready once f()
is done being executed. We’ll call these tokens
output tokens. We end up with a function that looks like this:
@jax.jit
def f():
jax.print("hello world")
return new_runtime_token()
f() # Executed asynchronously
Underneath the hood, we’ll manage the output tokens in the same way we manage the runtime tokens but provide a method for users to block on the current set of output tokens. Unlike runtime tokens, output tokens need to be device-specific. Consider a single device use-case:
@jax.jit
def f():
jax.print("hello")
@jax.jit
def g():
jax.print("world")
f()
g()
Since f()
and g()
are executed on the same device, blocking on g()
’s output token
effectively blocks on f()
since (as of now!), the JAX runtime does not interleave computations
executed on the same device. We’ll have to revise this entire design if that changes, of course.
However, consider the two device use-case:
@partial(jax.jit, device=<device 0>)
def f():
jax.print("hello")
@partial(jax.jit, device=<device 1>)
def g():
jax.print("world")
f()
g()
Here we don’t want to explicitly sequence f()
and g()
but want to wait for both of them to finish.
We’ll need one output token for f()
and one for g()
and we’ll block on both of those tokens:
@partial(jax.jit, device=<device 0>)
def f():
jax.print("hello")
return new_runtime_token()
@partial(jax.jit, device=<device 1>)
def g():
jax.print("world")
return new_runtime_token()
t0 = f()
t1 = g()
block_until_ready((t0, t1))
We’ll thus need a per-device output token so we can avoid sequencing computations on different devices while offering the ability to block on side-effecting computations. We end up with the following (approximate) change to the JAX dispatch machinery:
def _execute(compiled_computation, *args):
output_token, *outputs = compiled_computation.execute(runtime_token, *args)
update_output_token(output_token, compiled_computation.device)
return outputs
We’ll also need to expose a function that blocks on the output token:
def effects_barrier():
output_token.block_until_ready()
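Since output tokens are per-device, a slightly more concrete sketch (hypothetical globals, consistent with the prose above rather than JAX's actual internals) might track one token per device and block on all of them:

```python
_output_tokens = {}  # maps device -> most recent output token for that device

def update_output_token(output_token, device):
  _output_tokens[device] = output_token

def effects_barrier():
  # Block until every device's most recent side-effecting computation is done.
  for token in _output_tokens.values():
    token.block_until_ready()
```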
Note that blocking on output tokens will likely not be common, since most JAX computations return a value to block on. However, output tokens are helpful for testing and profiling, and they are good to support so that we have a consistent and cohesive effect system.
Some more details#
All of the aforementioned token management infrastructure will be thread-local. This means that each user thread will have their own independent stream of runtime tokens. Sequencing is only promised at a user thread level.
In practice, we have one runtime token per effect. Different instances of the same effect will be sequenced with respect to each other. This avoids sequencing effectful computations that may have no relation to each other. Technically, this goes against our original goal of enforcing a single-threaded Python program ordering, but it is a tradeoff that could be modulated by having both “effect”-specific tokens and “global” tokens.
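A minimal sketch of what thread-local, per-effect runtime tokens might look like (the structure and names here are illustrative, not JAX's internals):

```python
import threading
import jax.numpy as jnp

class _EffectTokenStore(threading.local):
  """Thread-local store: each user thread gets its own per-effect tokens."""
  def __init__(self):
    self.tokens = {}  # maps an effect (e.g. a print effect) -> runtime token

_store = _EffectTokenStore()

def get_token(effect):
  return _store.tokens.setdefault(effect, jnp.zeros((0,)))

def update_token(effect, token):
  _store.tokens[effect] = token
```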
jax.remat
/ jax.checkpoint
changes: what you need to know#
Contents#
What’s going on?#
As of #11830 we’re switching on a new implementation of jax.checkpoint()
, aka jax.remat()
(the two names are aliases of one another). For most code, there will be no changes. But there may be some observable differences in edge cases; see What are the possible issues after the upgrade?
How can I disable the change, and go back to the old behavior for now?#
In case you have a problem with this change, through version jax==0.3.16
it is possible to switch off the new implementation by setting the jax_new_checkpoint
config option to be False, in any one of these ways:
- set the shell environment variable `JAX_NEW_CHECKPOINT=0`;
- execute `jax.config.update('jax_new_checkpoint', False)`;
- if you parse flags with `absl`, pass the `--jax_new_checkpoint=False` option.
If you need to revert to the old implementation, please reach out on a GitHub issue so that we can make the new implementation work for you.
As of jax==0.3.17
the jax_new_checkpoint
config option is no longer
available. If you have an issue, please reach out on the issue
tracker so we can help fix it!
Why are we doing this?#
At the time of writing, JAX has two parallel implementations of jax.checkpoint
. The new one has been used for months (e.g. by Pax and Flaxformer/T5X) on an opt-in basis. But it hasn’t been on-by-default.
We want to switch the new implementation to be on-by-default, and then delete the old implementation. Using the new implementation, and removing the old implementation, gives users several benefits.
User-customizable rematerialization policies#
The main upside of the new implementation is a new feature corresponding to the policy
argument. The idea is to give precise user control over what intermediates get saved (versus rematerialized) during the forward pass of automatic differentiation. By exercising this control over the memory-usage vs recomputation tradeoff, users can get significant performance wins, especially in large models and in our LLM MLPerf submission!
The full documentation for this feature is still forthcoming, but here’s a quick example:
from functools import partial
import jax
import jax.numpy as jnp

def apply_layer(W, x):
  return jnp.sin(jnp.dot(W, x))

@partial(jax.checkpoint, policy=jax.checkpoint_policies.checkpoint_dots)
def predict(params, x):
  for W in params[:-1]:
    x = apply_layer(W, x)
  return jnp.dot(params[-1], x)
By applying jax.checkpoint
with policy=jax.checkpoint_policies.checkpoint_dots
here, we ensure that only the results of matrix multiplies are allowed to be saved during the forward pass. The Jacobian coefficient values from cos
applications, and the values of sin
applications needed to compute them, are not saved from the forward pass and are instead recomputed during the backward pass. (Policies like this one can be effective on TPUs, where elementwise computations are effectively free but results from the matrix unit are worth saving.)
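For illustration, differentiating through the policy-annotated `predict` above works exactly as usual; the policy only changes which intermediates are saved versus recomputed. This snippet assumes the `apply_layer`/`predict` definitions (and imports) from the example above, with made-up shapes:

```python
params = [jnp.ones((4, 4)) for _ in range(3)]
x = jnp.ones(4)

loss = lambda params, x: jnp.sum(predict(params, x) ** 2)
grads = jax.grad(loss)(params, x)  # gradients w.r.t. every layer's weights
```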
Ability to rematerialize constants, not just operations with data dependence on arguments#
The old jax.checkpoint
implementation couldn’t actually rematerialize computations without a data dependence on arguments to the decorated function. Consider this toy example:
@jax.checkpoint
def f(x):
a = some_function(jnp.arange(10_000_000)) # `a` does not depend on `x`
return a * x
The old jax.checkpoint
implementation was forced to save the value of a
, which could require a lot of memory. The new jax.checkpoint
implementation can rematerialize rather than save the value of a
.
Significantly less Python overhead in some cases#
The new jax.checkpoint
incurs significantly less Python overhead in some cases. Simple overhead benchmarks got 10x faster. These overheads only arise in eager op-by-op execution, so in the common case of using a jax.checkpoint
under a jax.jit
or similar the speedups aren’t relevant. But still, nice!
Enabling new JAX features by simplifying internals#
This change unlocks big future user benefits too, like custom batching rules (the vmap
analogue of custom_vjp
) and a forward-differentiable upgrade to custom_vjp
. It also significantly reduces complexity in parts of the JAX codebase, which will be good for maintainability and bug-fixing in general.
What are the possible issues after the upgrade?#
Innocuous numerical changes#
Because the new implementation can rematerialize more computations, including those of potentially large constants, some code may see small numerical changes. The magnitude of any numerical changes should be within the range we expect from changing compiler optimizations, like reordering of floating point operations. But some overly tight test tolerances may need to be slightly relaxed.
The concrete=True
option is removed.#
The old jax.checkpoint
implementation had a boolean concrete
option, which allowed tracing on concrete Python values (rather than delaying all computations and only tracing on abstracted values). That option was seldom used, and in the cases where it was used there were much simpler alternatives. So we removed the option in the new jax.checkpoint
.
For example, the overwhelmingly common use of concrete=True
in Google code was to support passing an argument like is_training
:
@partial(jax.checkpoint, concrete=True) # OLD jax.checkpoint API
def foo(x, is_training):
if is_training:
return g(x)
else:
return h(x)
With the new `jax.checkpoint` implementation, we can accomplish the same using the `static_argnums` option:
@partial(jax.checkpoint, static_argnums=(1,))  # NEW jax.checkpoint API
def foo(x, is_training):
  if is_training:
    return g(x)
  else:
    return h(x)
If jax.numpy
operations need to be performed on static arguments, with their numerical results computed during Python tracing rather than delayed, we can use static_argnums
with jax.ensure_compile_time_eval()
. But it seems unlikely that you’d need this!
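As a hedged sketch of what that combination might look like (the function and the block-size arithmetic here are made up for illustration):

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.checkpoint, static_argnums=(1,))
def foo(x, n):
  with jax.ensure_compile_time_eval():
    # `n` is static, so this arithmetic runs eagerly during tracing.
    num_blocks = int(jnp.ceil(n / 128))
  return x * num_blocks
```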
Type Annotation Roadmap for JAX#
Author: jakevdp
Date: August 2022
Background#
Python 3.0 introduced optional function annotations (PEP 3107), which were later codified for use in static type checking around the release of Python 3.5 (PEP 484). To some degree, type annotations and static type checking have become an integral part of many Python development workflows, and to this end we have added annotations in a number of places throughout the JAX API. The current state of type annotations in JAX is a bit patchwork, and efforts to add more have been hampered by more fundamental design questions. This doc attempts to summarize those issues and generate a roadmap for the goals and non-goals of type annotations in JAX.
Why do we need such a roadmap? Better/more comprehensive type annotations are a frequent request from users, both internally and externally. In addition, we frequently receive pull requests from external users (for example, PR #9917, PR #10322) seeking to improve JAX’s type annotations: it’s not always clear to the JAX team member reviewing the code whether such contributions are beneficial, particularly when they introduce complex Protocols to address the challenges inherent to full-fledged annotation of JAX’s use of Python. This document details JAX’s goals and recommendations for type annotations within the package.
Why type annotations?#
There are a number of reasons that a Python project might wish to annotate their code-base; we’ll summarize them in this document as Level 1, Level 2, and Level 3.
Level 1: Annotations as documentation#
When originally introduced in PEP 3107, type annotations were motivated partly by the ability to use them as concise, inline documentation of function parameter types and return types. JAX has long utilized annotations in this manner; an example is the common pattern of creating type names aliased to Any
. An example can be found in lax/slicing.py
[source]:
Array = Any
Shape = core.Shape
def slice(operand: Array, start_indices: Sequence[int],
limit_indices: Sequence[int],
strides: Optional[Sequence[int]] = None) -> Array:
...
For the purposes of static type checking, this use of Array = Any
for array type annotations puts no constraint on the argument values (Any
is equivalent to no annotation at all), but it does serve as a form of useful in-code documentation for the developer.
For the sake of generated documentation, the name of the alias gets lost (the HTML docs for jax.lax.slice
report operand as type Any
), so the documentation benefit does not go beyond the source code (though we could enable some sphinx-autodoc
options to improve this: See autodoc_type_aliases).
A benefit of this level of type annotation is that it is never wrong to annotate a value with Any
, so it will provide a concrete benefit to developers and users in the form of documentation, without added complexity of satisfying the stricter needs of any particular static type checker.
Level 2: Annotations for intelligent autocomplete#
Many modern IDEs take advantage of type annotations as inputs to intelligent code completion systems. One example of this is the Pylance extension for VSCode, which uses Microsoft’s pyright static type checker as a source of information for VSCode’s IntelliSense completions.
This use of type checking requires going further than the simple aliases used above; for example, knowing that the slice
function returns an alias of Any
named Array
does not add any useful information to the code completion engine. However, were we to annotate the function with a DeviceArray
return type, the autocomplete would know how to populate the namespace of the result, and thus be able to suggest more relevant autocompletions during the course of development.
JAX has begun to add this level of type annotation in a few places; one example is the jnp.ndarray
return type within the jax.random
package [source]:
def shuffle(key: KeyArray, x: Array, axis: int = 0) -> jnp.ndarray:
...
In this case jnp.ndarray
is an abstract base class that forward-declares the attributes and methods of JAX arrays (see source), and so Pylance in VSCode can offer the full set of autocompletions on results from this function. Here is a screenshot showing the result:
Listed in the autocomplete field are all methods and attributes declared by the abstract ndarray
class.
We’ll discuss further below why it was necessary to create this abstract class rather than annotating with DeviceArray
directly.
Level 3: Annotations for static type-checking#
These days, static type-checking often is the first thing people think of when considering the purpose of type annotations in Python code. While Python does not do any runtime checking of types, several mature static type checking tools exist that can do this as part of a CI test suite. The most important ones for JAX are the following:
python/mypy is more or less the standard in the open Python world. JAX currently runs mypy on a subset of source files within the Github Actions CI checks.
google/pytype is Google’s static type checker, and projects which depend on JAX within Google frequently use this.
microsoft/pyright is important as the static type checker used within VSCode for the Pylance completions mentioned previously.
Full static type checking is the strictest of all the type annotation applications, because it will surface an error any time your type annotations are not precisely correct.
On the one hand, this is nice because your static type analysis may catch faulty type annotations (for example, a case where a DeviceArray
method is missing from the jnp.ndarray
abstract class).
On the other hand, this strictness can make the type checking process very brittle in packages that often rely on duck-typing rather than strict type-safe APIs.
You’ll currently find code comments like #type: ignore
(for mypy) or #pytype: disable
(for pytype) peppered throughout the JAX codebase in several hundred places.
These typically represent cases where typing problems have arisen; they may be inaccuracies in JAX type annotations, or inaccuracies in the static type checker’s ability to correctly follow the control flow in the code.
On occasion, they are due to real & subtle bugs in the behavior of pytype or mypy.
In rare cases, they may be due to the fact that JAX uses Python patterns that are difficult or even impossible to express in terms of Python’s static type annotation syntax.
Type annotation challenges for JAX#
JAX currently has type annotations that are a mixture of different styles, and aimed at all three levels of type annotation discussed above. Partly, this comes from the fact that JAX’s source code poses a number of unique challenges for Python’s type annotation system. We’ll outline them here.
Challenge 1: pytype, mypy and developer friction#
One challenge JAX currently faces is that package development must satisfy the constraints of two different static type checking systems, pytype
(used by internal CI and internal Google projects) and mypy
(used by external CI and external dependencies).
Although the two type checkers have broad overlap in their behavior, each presents its own unique corner cases, as evidenced by the numerous #type: ignore
and #pytype: disable
statements throughout the JAX codebase.
This creates friction in development: internal contributors may iterate until tests pass, only to find that on export their pytype-approved code falls afoul of mypy.
For external contributors, it’s often the opposite: a recent example is #9596 which had to be rolled-back after it failed internal Google pytype checks.
Each time we move a type annotation from Level 1 (Any
everywhere) to Level 2 or 3 (stricter annotations), it introduces more potential for such frustrating developer experiences.
Challenge 2: array duck-typing#
One particular challenge for annotating JAX code is its heavy use of duck-typing. An input to a function marked Array
in general could be one of many different types: a JAX DeviceArray
, a NumPy np.ndarray
, a NumPy scalar, a Python scalar, a Python sequence, an object with an __array__
attribute, an object with a __jax_array__
attribute, or any flavor of jax.Tracer
.
For this reason, simple annotations like def func(x: DeviceArray)
will not be sufficient, and will lead to false positives for many valid uses.
This means that type annotations for JAX functions will not be short or trivial, but we would have to effectively develop a set of JAX-specific typing extensions similar to those in the numpy.typing
package.
Challenge 3: transformations and decorators#
JAX’s Python API relies heavily on function transformations (jit()
, vmap()
, grad()
, etc.), and this type of API poses a particular challenge for static type analysis.
Flexible annotation for decorators has been a long-standing issue in the mypy package, which was only recently resolved by the introduction of ParamSpec
, discussed in PEP 612 and added in Python 3.10.
Because JAX follows NEP 29, it cannot rely on Python 3.10 features until sometime after mid-2024.
In the meantime, Protocols can be used as a partial solution to this (JAX added this for jit and other methods in #9950) and ParamSpec is possible to use via the typing_extensions
package (a prototype is in #9999) though this currently reveals fundamental bugs in mypy (see python/mypy#12593).
All that to say: it’s not yet clear that the API of JAX’s function transforms can be suitably annotated within the current constraints of Python type annotation tools.
Challenge 4: array annotation lack of granularity#
Another challenge here is common to all array-oriented APIs in Python, and has been part of the JAX discussion for several years (see #943). Type annotations have to do with the Python class or type of an object, whereas in array-based languages often the attributes of the class are more important. In the case of NumPy, JAX, and similar packages, often we would wish to annotate particular array shapes and data types.
For example, the arguments to the jnp.linspace
function must be scalar values, but in JAX scalars are represented by zero-dimensional arrays.
So in order for annotations to not raise false positives, we must allow these arguments to be arbitrary arrays.
Another example is the second argument to jax.random.choice
, which must have dtype=int
when shape=()
.
Python has a plan to enable type annotations with this level of granularity via Variadic Type Generics (see PEP 646, slated for Python 3.11) but like ParamSpec
, support for this feature will take a while to stabilize.
There are some third-party projects that may help in the meantime, in particular google/jaxtyping, but this uses non-standard annotations and may not be suitable for annotating the core JAX library itself. All told, the array-type-granularity challenge is less of an issue than the other challenges, because the main effect is that array-like annotations will be less specific than they otherwise could be.
Challenge 5: imprecise APIs inherited from NumPy#
A large part of JAX’s user-facing API is inherited from NumPy within the jax.numpy
submodule.
NumPy’s API was developed years before static type checking became part of the Python language, and follows Python’s historic recommendations to use a duck-typing/EAFP coding style, in which strict type-checking at runtime is discouraged. As a concrete example of this, consider the numpy.tile()
function, which is defined like this:
def tile(A, reps):
try:
tup = tuple(reps)
except TypeError:
tup = (reps,)
d = len(tup)
...
Here the intent is that reps
would contain either an int
or a sequence of int
values, but the implementation allows tup
to be any iterable.
When adding annotations to this kind of duck-typed code, we could take one of two routes:
1. We may choose to annotate the intent of the function’s API, which here might be something like `reps: Union[int, Sequence[int]]`.

2. Conversely, we may choose to annotate the implementation of the function, which here might look something like `reps: Union[ConvertibleToInt, Iterable[ConvertibleToInt]]`, where `ConvertibleToInt` is a special protocol that covers the exact mechanism by which our function converts the inputs to integers (i.e. via `__int__`, via `__index__`, via `__array__`, etc.). Note also that in a strict sense, `Iterable` is not sufficient here, because there are objects in Python that duck-type as iterables but do not satisfy a static type check against `Iterable` (namely, an object that is iterable via `__getitem__` rather than `__iter__`).
The advantage of #1, annotating intent, is that the annotations are more useful to the user in communicating the API contract; while for the developer the flexibility leaves room for refactoring when necessary. The down-side (particularly for gradually-typed APIs like JAX’s) is that it’s quite likely that user code exists which runs correctly, but would be flagged as incorrect by a type checker.
Gradual typing of an existing duck-typed API means that the current annotation is implicitly Any
, so changing this to a stricter type may present to users as a breaking change.
Broadly speaking, annotating intent better serves Level 1 type checking, while annotating implementation better serves Level 3, while Level 2 is more of a mixed bag (both intent and implementation are important when it comes to annotations in IDEs).
JAX type annotation roadmap#
With this framing (Level 1/2/3) and JAX-specific challenges in mind, we can begin to develop our roadmap for implementing consistent type annotations across the JAX project.
Guiding Principles#
For JAX type annotation, we will be guided by the following principles:
Purpose of type annotations#
We would like to support full, Level 1, 2, and 3 type annotation as far as possible. In particular, this means that we should have restrictive type annotations on both inputs and outputs to public API functions.
Annotate for intent#
JAX type annotations should in general indicate the intent of APIs, rather than the implementation, so that the annotations become useful to communicate the contract of the API. This means that at times inputs that are valid at runtime may not be recognized as valid by the static type checker (one example might be an arbitrary iterator passed in place of a shape that is annotated as Shape = Sequence[int]
).
Inputs should be permissively-typed#
Inputs to JAX functions and methods should be typed as permissively as is reasonable: for example, while shapes are typically tuples, functions that accept a shape should accept arbitrary sequences. Similarly, functions that accept a dtype need not require an instance of class np.dtype
, but rather any dtype-convertible object. This might include strings, built-in scalar types, or scalar object constructors such as np.float64
and jnp.float64
. In order to make this as uniform as possible across the package, we will add a jax.typing
module with common type specifications, starting with broad categories such as:
- `ArrayLike` would be a union of anything that can be implicitly converted into an array: for example, jax arrays, numpy arrays, JAX tracers, and python or numpy scalars.
- `DTypeLike` would be a union of anything that can be implicitly converted into a dtype: for example, numpy dtypes, numpy dtype objects, jax dtype objects, strings, and built-in types.
- `ShapeLike` would be a union of anything that could be converted into a shape: for example, sequences of integer or integer-like objects.
- etc.
Note that these will in general be simpler than the equivalent protocols used in numpy.typing
. For example, in the case of DTypeLike
, JAX does not support structured dtypes, so JAX can use a simpler implementation. Similarly, in ArrayLike
, JAX generally does not support list or tuple inputs in place of arrays, so the type definition will be simpler than the NumPy analog.
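As a rough sketch only (not the actual `jax.typing` definitions, which did not yet exist at the time of writing), these broad categories might look something like simple unions:

```python
from typing import Sequence, Union
import numpy as np
import jax

# Illustrative only; the real jax.typing definitions may differ.
ArrayLike = Union[jax.Array, np.ndarray, np.bool_, np.number, bool, int, float, complex]
DTypeLike = Union[str, type, np.dtype]
ShapeLike = Sequence[Union[int, np.integer]]
```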
Outputs should be strictly-typed#
Conversely, outputs of functions and methods should be typed as strictly as possible: for example, for a JAX function that returns an array, the output should be annotated with something similar to jnp.ndarray
rather than ArrayLike
. Functions returning a dtype should always be annotated np.dtype
, and functions returning a shape should always be Tuple[int]
or a strictly-typed NamedShape equivalent. For this purpose, we will implement in jax.typing
several strictly-typed analogs of the permissive types mentioned above, namely:
- `Array` or `NDArray` (see below) for type annotation purposes is effectively equivalent to `Union[Tracer, jnp.ndarray]` and should be used to annotate array outputs.
- `DType` is an alias of `np.dtype`, perhaps with the ability to also represent key types and other generalizations used within JAX.
- `Shape` is essentially `Tuple[int, ...]`, perhaps with some additional flexibility to account for dynamic shapes.
- `NamedShape` is an extension of `Shape` that allows for named shapes as used internally in JAX.
- etc.
We will also explore whether the current implementation of jax.numpy.ndarray
can be dropped in favor of making ndarray
an alias of Array
or similar.
Err toward simplicity#
Aside from common typing protocols gathered in jax.typing
, we should err on the side of simplicity. We should avoid constructing overly-complex protocols for arguments passed to API functions, and instead use simple unions such as Union[simple_type, Any]
in the case that the full type specification of the API cannot be succinctly specified. This is a compromise that achieves the goals of Level 1 and 2 annotations, while punting on Level 3 in favor of avoiding unnecessary complexity.
Avoid unstable typing mechanisms#
In order to not add undue development friction (due to the internal/external CI differences), we would like to be conservative in the type annotation constructs we use: in particular, when it comes to recently-introduced mechanisms such as ParamSpec
(PEP 612) and Variadic Type Generics (PEP 646), we would like to wait until support in mypy and other tools matures and stabilizes before relying on them.
One impact of this is that for the time being, when functions are decorated by JAX transformations like jit
, vmap
, grad
, etc. JAX will effectively strip all annotations from the decorated function.
While this is unfortunate, at the time of this writing mypy has a laundry-list of incompatibilities with the potential solution offered by ParamSpec
(see ParamSpec
mypy bug tracker), and we therefore judge it as not ready for full adoption in JAX at this time.
We will revisit this question in the future once support for such features stabilizes.
Similarly, for the time being we will avoid adding the more complex & granular array type annotations offered by the jaxtyping project. This is a decision we could revisit at a future date.
Array
Type Design Considerations#
As mentioned above, type annotation of arrays in JAX poses a unique challenge because of JAX’s extensive use of duck-typing, i.e. passing and returning Tracer
objects in place of actual arrays within jax transformations.
This becomes increasingly confusing because objects used for type annotation often overlap with objects used for runtime instance checking, and may or may not correspond to the actual type hierarchy of the objects in question.
For JAX, we need to provide duck-typed objects for use in two contexts: static type annotations and runtime instance checks.
The following discussion will assume that jax.Array
is the runtime type of on-device arrays, which is not yet the case but will be once the work in #12016 is complete.
Static type annotations#
We need to provide an object that can be used for duck-typed type annotations.
Assuming for the moment that we call this object ArrayAnnotation
, we need a solution which satisfies mypy
and pytype
for a case like the following:
@jit
def f(x: ArrayAnnotation) -> ArrayAnnotation:
assert isinstance(x, core.Tracer)
return x
This could be accomplished via a number of approaches, for example:
1. Use a type union: `ArrayAnnotation = Union[Array, Tracer]` (a minimal sketch follows this list).

2. Create an interface file that declares `Tracer` and `Array` should be treated as subclasses of `ArrayAnnotation`.

3. Restructure `Array` and `Tracer` so that `ArrayAnnotation` is a true base class of both.
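A minimal sketch of the first approach, using stand-in classes (these are not JAX's real `Array`/`Tracer` types):

```python
from typing import Union

class Array: ...   # stand-in for the concrete array class
class Tracer: ...  # stand-in for JAX's tracer class

ArrayAnnotation = Union[Array, Tracer]

def f(x: ArrayAnnotation) -> ArrayAnnotation:
  return x
```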
Runtime instance checks#
We also must provide an object that can be used for duck-typed runtime isinstance
checks.
Assuming for the moment that we call this object ArrayInstance
, we need a solution that passes the following runtime check:
def f(x):
return isinstance(x, ArrayInstance)
x = jnp.array([1, 2, 3])
assert f(x) # x will be an array
assert jit(f)(x) # x will be a tracer
Again, there are a couple mechanisms that could be used for this:
1. Override `type(ArrayInstance).__instancecheck__` to return `True` for both `Array` and `Tracer` objects; this is how `jnp.ndarray` is currently implemented (source). A sketch of this mechanism follows this list.

2. Define `ArrayInstance` as an abstract base class and dynamically register it to `Array` and `Tracer`.

3. Restructure `Array` and `Tracer` so that `ArrayInstance` is a true base class of both `Array` and `Tracer`.
and ArrayInstance
should be the same or different objects. There is some precedent here; for example in the core Python language spec, typing.Dict
and typing.List
exist for the sake of annotation, while the built-in dict
and list
serve the purposes of instance checks.
However, Dict
and List
are deprecated in newer Python versions in favor of using dict
and list
for both annotation and instance checks.
Following NumPy’s lead#
In NumPy’s case, np.typing.NDArray
serves the purpose of type annotations, while np.ndarray
serves the purpose of instance checks (as well as array type identity).
Given this, it may be reasonable to conform to NumPy’s precedent and implement the following:
- `jax.Array` is the actual type of on-device arrays.
- `jax.typing.NDArray` is the object used for duck-typed array annotations.
- `jax.numpy.ndarray` is the object used for duck-typed array instance checks.
This might feel somewhat natural to NumPy power-users, however this trifurcation would likely be a source of confusion: the choice of which to use for instance checks and annotations is not immediately clear.
Unifying instance checks and annotation#
Another approach would be to unify type checking and annotation via override mechanisms mentioned above.
Option 1: Partial unification#
A partial unification might look like this:
- `jax.Array` is the actual type of on-device arrays.
- `jax.typing.Array` is the object used for duck-typed array annotations (via `.pyi` interfaces on `Array` and `Tracer`).
- `jax.typing.Array` is also the object used for duck-typed instance checks (via an `__instancecheck__` override in its metaclass).

In this approach, `jax.numpy.ndarray` would become a simple alias of `jax.typing.Array` for backward compatibility.
Option 2: Full unification via overrides#
Alternatively, we could opt for full unification via overrides:
- `jax.Array` is the actual type of on-device arrays.
- `jax.Array` is also the object used for duck-typed array annotations (via a `.pyi` interface on `Tracer`).
- `jax.Array` is also the object used for duck-typed instance checks (via an `__instancecheck__` override in its metaclass).

Here, `jax.numpy.ndarray` would become a simple alias of `jax.Array` for backward compatibility.
Option 3: Full unification via class hierarchy#
Finally, we could opt for full unification via restructuring of the class hierarchy and replacing duck-typing with OOP object hierarchies:
- `jax.Array` is the actual type of on-device arrays.
- `jax.Array` is also the object used for array type annotations, by ensuring that `Tracer` inherits from `jax.Array`.
- `jax.Array` is also the object used for instance checks, via the same mechanism.

Here `jnp.ndarray` could be an alias for `jax.Array`.
This final approach is in some senses the most pure, but it is somewhat forced from an OOP design standpoint (Tracer
is an Array
?).
Option 4: Partial unification via class hierarchy#
We could make the class hierarchy more sensible by making Tracer
and the class for
on-device arrays inherit from a common base class. So, for example:
- `jax.Array` is a base class for `Tracer` as well as the actual type of on-device arrays, which might be `jax._src.ArrayImpl` or similar.
- `jax.Array` is the object used for array type annotations.
- `jax.Array` is also the object used for instance checks.

Here `jnp.ndarray` would be an alias for `jax.Array`.
This may be purer from an OOP perspective, but compared to Options 2 and 3 it drops the notion
that type(x) is jax.Array
will evaluate to True.
Evaluation#
Considering the overall strengths and weaknesses of each potential approach:
- From a user perspective, the unified approaches (options 2 and 3) are arguably best, because they remove the cognitive overhead involved in remembering which objects to use for instance checks or annotations: `jax.Array` is all you need to know.
- However, both options 2 and 3 introduce some strange and/or confusing behavior. Option 2 depends on potentially confusing overrides of instance checks, which are not well supported for classes defined in pybind11. Option 3 requires `Tracer` to be a subclass of `Array`. This breaks the inheritance model, because it would require `Tracer` objects to carry all the baggage of `Array` objects (data buffers, sharding, devices, etc.).
- Option 4 is purer in an OOP sense, and avoids the need for any overrides of typical instance check or type annotation behavior. The tradeoff is that the actual type of on-device arrays becomes something separate (here `jax._src.ArrayImpl`). But the vast majority of users would never have to touch this private implementation directly.
There are different tradeoffs here, but after discussion we’ve landed on Option 4 as our way forward.
Implementation Plan#
To move forward with type annotations, we will do the following:
1. Iterate on this JEP doc until developers and stakeholders are bought-in.

2. Create a private `jax._src.typing` (not providing any public APIs for now) and put in it the first level of simple types mentioned above:

   - Alias `Array = Any` for the time being, as this will take a bit more thought.
   - `ArrayLike`: a Union of types valid as inputs to normal `jax.numpy` functions.
   - `DType` / `DTypeLike` (Note: numpy uses camel-cased `DType`; we should follow this convention for ease of use).
   - `Shape` / `NamedShape` / `ShapeLike`

   The beginnings of this are done in #12300.

3. Begin work on a `jax.Array` base class that follows Option 4 from the previous section. Initially this will be defined in Python, and use the dynamic registration mechanism currently found in the `jnp.ndarray` implementation to ensure correct behavior of `isinstance` checks. A `pyi` override for each tracer and array-like class would ensure correct behavior for type annotations. `jnp.ndarray` could then be made into an alias of `jax.Array`.

4. As a test, use these new typing definitions to comprehensively annotate functions within `jax.lax` according to the guidelines above.

5. Continue adding additional annotations one module at a time, focusing on public API functions.

6. In parallel, begin re-implementing a `jax.Array` base class in pybind11, so that `ArrayImpl` and `Tracer` can inherit from it. Use a `pyi` definition to ensure static type checkers recognize the appropriate attributes of the class.

7. Once `jax.Array` and `jax._src.ArrayImpl` have fully landed, remove these temporary Python implementations.

8. When all is finalized, create a public `jax.typing` module that makes the above types available to users, along with documentation of annotation best practices for code using JAX.
We will track this work in #12049, from which this JEP gets its number.
shmap
(shard_map
) for simple per-device code#
sholto@, sharadmv@, jekbradbury@, zhangqiaorjc@, mattjj@
January 2023
Motivation#
JAX supports two schools of thought for multi-device programming:
Compiler, take the wheel! Let the compiler automatically partition bulk array functions over devices.
Just let me write what I mean, damnit! Give me per-device code and explicit communication collectives.
We need great APIs for both, and rather than being mutually exclusive alternatives, they need to compose with each other.
With pjit
(now just jit
) we have a next-gen
API
for the first school. But we haven’t quite leveled-up the second school. pmap
follows the second school, but over time we found it has fatal
flaws. xmap
solved those flaws,
but it doesn’t quite give us per-device shapes, and it includes several other
big ideas too. Meanwhile, new demands for per-device explicit-collectives
programming have emerged, like in Efficiently Scaling Transformer
Inference.
We can level-up the second school with shmap
. shmap
is:
- a simple multi-device parallelism API which lets us write per-device code with explicit collectives, where logical shapes match per-device physical buffer shapes and collectives correspond exactly to cross-device communication;
- a specialization of `xmap` with scaled-back features and a few tweaks;
- a fairly direct surfacing of the XLA SPMD Partitioner’s ‘manual’ mode;
- a fun-to-say Seussian name which could stand for `shard_map`, `shpecialized_xmap`, `sholto_map`, or `sharad_map`.
For pjit
users, shmap
is a complementary tool. It can be used inside a
pjit
computation to drop temporarily into a “manual collectives” mode, like an
escape hatch from the compiler’s automatic partitioning. That way, users get the
convenience and familiar just-NumPy programming model of pjit
for most of their
code, along with the ability to hand-optimize collective communication with
shmap
wherever it’s needed. It’s the best of both worlds!
For pmap
users, shmap
is a strict upgrade. It’s more expressive,
performant, and composable with other JAX APIs, without making basic batch data
parallelism any harder.
For more on practical use, you can jump to When should you use shmap
and when
should you use pjit
?.
If you’re wondering why we need a new thing at all, or what
the problems with pmap
are, jump to Why don’t pmap
or xmap
already solve
this?.
Or keep reading the next section to see some shmap
examples and the API spec.
So, let’s see shmap
!#
TL;DR example (with a more detailed explanation to follow)#
Sho shick:
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('i', 'j'))
a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 32.).reshape(16, 32)
@partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P('j', None)),
out_specs=P('i', None))
def matmul_basic(a_block, b_block):
# a_block: f32[2, 8]
# b_block: f32[8, 32]
z_partialsum = jnp.dot(a_block, b_block)
z_block = jax.lax.psum(z_partialsum, 'j')
return z_block
c = matmul_basic(a, b) # c: f32[8, 32]
Notice:
- no nesting needed (or `axis_index_groups`) for multiple axes of parallelism, unlike `pmap`;
- no reshapes in the caller, unlike `pmap` and hard-`xmap`, and logical shapes correspond to per-device physical shapes, unlike (non-hard) `xmap`;
- precise device placement control by using `mesh`, unlike `pmap`;
- there’s only one set of axis names for logical and physical, unlike `xmap`;
- the result is a `jax.Array` which could be efficiently passed to a `pjit`, unlike `pmap`;
- this same code works efficiently inside a `pjit`/`jit`, unlike `pmap`;
- this code works eagerly, so we can `pdb` in the middle and print values, unlike `xmap`’s current implementation (though by design `xmap` without the sequential schedule can in principle work eagerly too).
Here’s another matmul variant with a fully sharded result:
@partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P('j', None)),
out_specs=P('i', 'j'))
def matmul_reduce_scatter(a_block, b_block):
# c_partialsum: f32[8/X, 32]
c_partialsum = jnp.matmul(a_block, b_block)
# c_block: f32[8/X, 32/Y]
c_block = jax.lax.psum_scatter(c_partialsum, 'j', scatter_dimension=1, tiled=True)
return c_block
c = matmul_reduce_scatter(a, b)
Slow down, start with the basics!#
Rank-reducing vs rank-preserving maps over array axes#
We can think of pmap
(and vmap
and xmap
) as unstacking each array input
along an axis (e.g. unpacking a 2D matrix into its 1D rows), applying its body
function to each piece, and stacking the results back together, at least when
collectives aren’t involved:
pmap(f, in_axes=[0], out_axes=0)(xs) == jnp.stack([f(x) for x in xs])
For example, if xs
had shape f32[8,5]
then each x
has shape f32[5]
, and
if each f(x)
has shape f32[3,7]
then the final stacked result pmap(f)(xs)
has shape f32[8,3,7]
. That is, each application of the body function f
takes
as argument inputs with one fewer axis than the corresponding argument to
pmap(f)
. We can say these are rank-reducing maps with unstacking/stacking of
inputs/outputs.
The number of logical applications of f
is determined by the size of the input
axis being mapped over: for example, if we map over an input axis of size 8,
semantically we get 8 logical applications of the function, which for pmap
always correspond to 8 devices physically computing them.
In contrast, shmap
does not have this rank-reducing behavior. Instead, we can
think of it as slicing (or “unconcatenating”) along input axes into blocks,
applying the body function, and concatenating the results back together (again
when collectives aren’t involved):
devices = np.array(jax.devices()[:4])
m = Mesh(devices, ('i',)) # mesh.shape['i'] = 4
shard_map(f, m, in_specs=P('i'), out_specs=P('i'))(y)
==
jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, 4)])
Recall that jnp.split
slices its input into equally-sized blocks with the same
rank, so that if in the above example y
has shape f32[8,5]
then each y_blk
has shape f32[2,5]
, and if each f(y_blk)
has shape f32[3,7]
then the final
concatenated result shard_map(f, ...)(y)
has shape f32[12,7]
. So shmap
(shard_map
) maps over shards, or blocks, of its inputs. We can say it’s a
rank-preserving map with unconcatenating/concatenating of its inputs/outputs.
The number of logical applications of f
is determined by the mesh size, not by
any input axis size: for example, if we have a mesh of total size 4 (i.e. over 4
devices) then semantically we get 4 logical applications of the function,
corresponding to the 4 devices physically computing them.
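For a concrete, hedged illustration of the rank-preserving identity above, assuming a runtime with at least 4 devices (for example by setting `XLA_FLAGS=--xla_force_host_platform_device_count=4` on CPU):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

devices = np.array(jax.devices()[:4])
mesh = Mesh(devices, ('i',))

def f(y_blk):           # y_blk: f32[2, 5], a per-device block of y
  return y_blk * 2.0    # any per-block computation with no collectives

y = jnp.arange(8 * 5.).reshape(8, 5)
out = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))(y)

# The rank-preserving identity: same result as splitting, applying f, and concatenating.
expected = jnp.concatenate([f(blk) for blk in jnp.split(y, 4)])
assert jnp.allclose(out, expected)
```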
Controlling how each input is split (unconcatenated) and tiled with in_specs
#
Each of the in_specs
identifies some of the corresponding input array’s axes
with mesh axes by name using PartitionSpec
s, representing how to split (or
unconcatenate) that input into the blocks to which the body function is applied.
That identification determines the shard sizes; when an input axis is identified
with a mesh axis, the input is split (unconcatenated) along that logical axis
into a number of pieces equal to the corresponding mesh axis size. (It’s an
error if the corresponding mesh axis size does not evenly divide the input array
axis size.) If an input’s pspec does not mention a mesh axis name, then there’s
no splitting over that mesh axis. For example:
devices = np.array(jax.devices())
m = Mesh(devices.reshape(4, 2), ('i', 'j'))
@partial(shard_map, mesh=m, in_specs=P('i', None), out_specs=P('i', 'j'))
def f1(x_block):
print(x_block.shape)
return x_block
x1 = np.arange(12 * 12).reshape(12, 12)
y = f1(x1) # prints (3,12)
Here, because the input pspec did not mention the mesh axis name 'j'
, no input
array axis is split over that mesh axis; similarly, because the second axis of
the input array is not identified with (and hence split over) any mesh axis,
application of f1
gets a full view of the input along that axis.
When a mesh axis is not mentioned in an input pspec, we can always rewrite to a
less efficient program where all mesh axes are mentioned but the caller performs
a jnp.tile
, for example:
@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P('i', 'j'))
def f2(x_block):
print(x_block.shape)
return x_block
x = np.arange(12 * 12).reshape(12, 12)
x_ = jnp.tile(x, (1, m.shape['j']))  # x_ has shape (12, 24)
y = f2(x_)  # prints (3,12), and f1(x) == f2(x_)
In other words, because each input pspec can mention each mesh axis name zero or
one times, rather than having to mention each name exactly once, we can say that
in addition to the jnp.split
built into its input, shard_map
also has a
jnp.tile
built into its input, at least logically (though the tiling may not
need to be carried out physically, depending on the arguments’ physical sharding
layout). The tiling to use is not unique; we could also have tiled along the
first axis, and used the pspec P(('j', 'i'), None)
.
Physical data movement is possible on inputs, as each device needs to have a copy of the appropriate data.
Controlling how each output is assembled by concatenation, block transposition, and untiling using out_specs
#
Analogously to the input side, each of the out_specs
identifies some of the
corresponding output array’s axes with mesh axes by name, representing how the
output blocks (one for each application of the body function, or equivalently
one for each physical device) should be assembled back together to form the
final output value. For example, in both the f1
and f2
examples above the
out_specs
indicate we should form the final output by concatenating together
the block results along both axes, resulting in both cases an array y
of shape
(12,24)
. (It’s an error if an output shape of the body function, i.e. an
output block shape, has a rank too small for the concatenation described by the
corresponding output pspec.)
When a mesh axis name is not mentioned in an output pspec, it represents an un-tiling: when the user writes an output pspec which does not mention one of the mesh axis names, they promise that the output blocks are equal along that mesh axis, and so only one block along that axis is used in the output (rather than concatenating all the blocks together along that mesh axis). For example, using the same mesh as above:
x = jnp.array([[3.]])
z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P('i', 'j'))()
print(z) # prints the same as jnp.tile(x, (4, 2))
z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P('i', None))()
print(z) # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,))
z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P(None, None))()
print(z) # prints the same as jnp.tile(x, (1, 1)), or just x
Notice that the body function closing over an array value is equivalent to
passing it as an argument with a corresponding input pspec of P(None, None)
. As
another example, following more closely to the other examples above:
@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P('i', None))
def f3(x_block):
  return jax.lax.psum(x_block, 'j')

x = np.arange(12 * 12).reshape(12, 12)
y3 = f3(x)
print(y3.shape)  # (12,6)
Notice that the result has a second axis size of 6, half the size of the input’s
second axis. In this case, the un-tile expressed by not mentioning the mesh axis
name 'j'
in the output pspec was safe because of the collective psum
, which
ensures each output block is equal along the corresponding mesh axis. Here are
two more examples where we vary which mesh axes are mentioned in the output
pspec:
@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f4(x_block):
  return jax.lax.psum(x_block, 'i')

x = np.arange(12 * 12).reshape(12, 12)
y4 = f4(x)
print(y4.shape)  # (3,12)

@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P(None, None))
def f5(x_block):
  return jax.lax.psum(x_block, ('i', 'j'))

y5 = f5(x)
print(y5.shape)  # (3,6)
On the physical side, not mentioning a mesh axis name in an output pspec
assembles an Array
from the output device buffers with replicated layout along
that mesh axis.
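One way to see this concretely is to inspect the result's sharding (a small sketch reusing f4 and x from the examples above; the exact repr may differ across JAX versions): the unmentioned mesh axis 'i' shows up as replication, while 'j' partitions the second axis.
y4 = f4(x)
print(y4.shape)     # (3, 12)
print(y4.sharding)  # expected: a NamedSharding over mesh m with a spec like P(None, 'j')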
There is no runtime check that the output blocks are actually equal along a mesh axis to be un-tiled along, or equivalently that the corresponding physical buffers have equal values and thus can be interpreted as a replicated layout for a single logical array. But we can provide a static check mechanism which raises an error on all potentially-incorrect programs.
Because the out_specs
can mention mesh axis names zero or one times, and
because they can be mentioned in any order, we can say that in addition to the
jnp.concatenate
built into its output, shard_map
also has both an untile and
a block transpose built into its output.
Physical data movement is not possible on outputs, no matter the output pspec.
Instead, out_specs
just encodes how to assemble the block outputs into
Array
s, or physically how to interpret the buffers across devices as the
physical layout of a single logical Array
.
API Specification#
from jax.sharding import Mesh

Specs = PyTree[PartitionSpec]

def shard_map(f: Callable, mesh: Mesh, in_specs: Specs,
              out_specs: Specs) -> Callable:
  ...
where:
- mesh encodes devices arranged in an array and with associated axis names, just like it does for xmap and for sharding.NamedSharding;
- in_specs and out_specs are PartitionSpecs which can affinely mention axis names from mesh (not separate logical names as in xmap) to express slicing/unconcatenation and concatenation of inputs and outputs, respectively (not unstacking and stacking like pmap and xmap do), with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively;
- the shapes of the arguments passed to f have the same ranks as the arguments passed to shard_map-of-f (unlike pmap and xmap where the ranks are reduced), and the shape of an argument to f is computed from the shape shape of the corresponding argument to shard_map-of-f and the corresponding PartitionSpec spec as roughly tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec)) (see the small check after this list);
- the body of f can apply collectives using names from mesh.
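As a quick sanity check of the block-shape rule in the third bullet, here is the arithmetic in plain Python (mesh sizes and the shape/spec are stand-ins taken from the f1 example earlier; no JAX needed):
mesh_sizes = {'i': 4, 'j': 2}
shape, spec = (12, 12), ('i', None)   # stand-ins for the argument shape and P('i', None)
block_shape = tuple(sz // (1 if n is None else mesh_sizes[n])
                    for sz, n in zip(shape, spec))
print(block_shape)  # (3, 12), matching the f1 example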
shmap
is eager by default, meaning that we dispatch computations
primitive-by-primitive, so that the user can employ Python control flow on fully
replicated values and interactive pdb
debugging to print any values. To stage
out and end-to-end compile a shmap-ped function, just put a jit
around it. A
consequence is that shmap
doesn’t have its own dispatch and compilation paths
like xmap
and pmap
currently do; it’s just the jit
path.
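For instance (a minimal sketch reusing the mesh m and array x1 from the earlier examples), the same shard_map'd function can be run eagerly or staged out and compiled simply by wrapping it in jit:
@partial(shard_map, mesh=m, in_specs=P('i', None), out_specs=P('i', None))
def g_block(x_block):
  return 2 * x_block

y_eager = g_block(x1)             # dispatched eagerly, primitive by primitive
y_staged = jax.jit(g_block)(x1)   # staged out and compiled end to end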
When it’s staged out by e.g. an enclosing jit
, the lowering of shmap
to
StableHLO is trivial: it just involves switching into ‘manual SPMD mode’ on the
inputs, and switching back on the outputs. (We don’t currently plan to support
partially-manual-partially-automatic modes.)
The interaction with effects is the same as with pmap
.
The interaction with autodiff is also just like pmap
(rather than attempting
the new semantics that xmap
did, corresponding to having unmapped
intermediates and hence grad
’s reduce_axes
as well as making psum
transpose to pbroadcast
rather than psum
). But it thus inherits an unsolved
problem from pmap
: in some cases, instead of transposing psum
to psum
, and
thus performing a backward pass psum
corresponding to the forward pass psum
,
it can be beneficial to move the backward pass psum
to elsewhere in the
backward pass, exploiting linearity. Many advanced pmap
users addressed this
challenge by using custom_vjp
to implement psum_idrev
and id_psumrev
functions, but since it’s easy to accidentally leave those imbalanced, that
technique is a foot-cannon. We have some ideas on how to provide this
functionality in a safer way.
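For reference, a rough sketch of that custom_vjp pattern follows. These are not JAX-provided functions: the names simply mirror the ones mentioned above, the mesh axis name 'i' is assumed, and the two must be kept balanced by hand, which is exactly the footgun being described.
import jax

@jax.custom_vjp
def psum_idrev(x):          # psum on the forward pass, identity on the backward pass
  return jax.lax.psum(x, 'i')

def _psum_idrev_fwd(x):
  return psum_idrev(x), None

def _psum_idrev_bwd(_, g):
  return (g,)

psum_idrev.defvjp(_psum_idrev_fwd, _psum_idrev_bwd)

@jax.custom_vjp
def id_psumrev(x):          # identity on the forward pass, psum on the backward pass
  return x

def _id_psumrev_fwd(x):
  return x, None

def _id_psumrev_bwd(_, g):
  return (jax.lax.psum(g, 'i'),)

id_psumrev.defvjp(_id_psumrev_fwd, _id_psumrev_bwd)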
When should you use shmap
and when should you use pjit
?#
One philosophy is: it is almost always simpler to write a program in jit==pjit
— but if a given part of the program is less optimized by the compiler than it
could be, drop into shmap
!
A realistic transformer example#
In fact, we can implement a simple version of the “collective
matmul” algorithm
recently introduced in XLA to overlap communication and computation using shmap
and 30 lines of Python. The basic idea of the algorithm can be grasped with a
simple example.
Suppose we want to compute C = A @ B
where A
is sharded by a 1D mesh on the
0-th dimension while B
and C
are replicated.
M, K, N = 4096, 2048, 1024
A = jnp.arange(np.prod((M, K))).reshape((M, K))
B = jnp.arange(np.prod((K, N))).reshape((K, N))
mesh = Mesh(np.array(jax.devices()), axis_names=('i',))
A_x = jax.device_put(A, NamedSharding(mesh, P('i', None)))

@jax.jit
def f(lhs, rhs):
  return lhs @ rhs

C = f(A_x, B)
A profile shows the blocking all-gather across 8 devices before the matmul can
start. This is suboptimal because A
is sharded on a non-contracting dimension,
and each shard of A
can be matmul’ed with B
independently and this chunked
computation can be overlapped with fetching of the next shard of A
from
another device.
[profile image]
This overlap can be implemented using shmap
and explicit collectives.
def collective_matmul_allgather_lhs_non_contracting(lhs, rhs):
  # lhs is the looped operand; rhs is the local operand
  axis_size = jax.lax.psum(1, axis_name='i')
  axis_index = jax.lax.axis_index(axis_name='i')
  chunk_size = lhs.shape[0]

  def f(i, carrys):
    accum, lhs = carrys
    # matmul for a chunk
    update = lhs @ rhs
    # circular shift to the left
    lhs = jax.lax.ppermute(
        lhs,
        axis_name='i',
        perm=[(j, (j - 1) % axis_size) for j in range(axis_size)])
    # device 0 computes chunks 0, 1, ...
    # device 1 computes chunks 1, 2, ...
    update_index = (((axis_index + i) % axis_size) * chunk_size, 0)
    accum = jax.lax.dynamic_update_slice(accum, update, update_index)
    return accum, lhs

  accum = jnp.zeros((lhs.shape[0] * axis_size, rhs.shape[1]), dtype=lhs.dtype)
  # fori_loop causes a crash: hlo_sharding.cc:817 Check failed: !IsManual()
  # accum, lhs = jax.lax.fori_loop(0, axis_size - 1, f, (accum, lhs))
  for i in range(0, axis_size - 1):
    accum, lhs = f(i, (accum, lhs))

  # compute the last chunk, without the ppermute
  update = lhs @ rhs
  i = axis_size - 1
  update_index = (((axis_index + i) % axis_size) * chunk_size, 0)
  accum = jax.lax.dynamic_update_slice(accum, update, update_index)
  return accum
jit_sharded_f = jax.jit(shard_map(
collective_matmul_allgather_lhs_non_contracting, mesh,
in_specs=(P('i', None), P()), out_specs=P()))
C = jit_sharded_f(A_x, B)
A profile shows that the all-gather is gone, and replaced with overlapped matmul with async collective permute. This profile matches very closely with the collective matmul paper result.
[profile image]
This collective matmul technique can be used to speed up feedforward blocks in
transformer layers. This typically consists of two matrix multiplications
followed by a ReduceScatter
(to resolve partial sums from a parallelized
matrix multiplication) and preceded by an AllGather
(to collect the sharded
dimensions along some axes and allow partial sum computation). Together, the
ReduceScatter
from one layer and the AllGather
for the next amount to an
AllReduce
.
In a typical profile, the two matmuls will be followed by an AllReduce
, and
they will not be overlapped. Collective matmul can be used to achieve the
overlap, but is difficult to trigger, has a minimum slice size and does not yet
cover all topologies, tensor shapes, and variants of collective matmul (i.e.
latency- and throughput-optimized variants). In a recent
paper, we found a ~40% gain in many
circumstances from manually implementing collective matmul variants in shmap
style.
But it isn’t always more complex! We expect this to be a much more natural way to think about pipelined computation, and plan to do some demos of that soon!
Another realistic example#
Here’s how shmap
might look in a transformer layer pass with a 2D weight
gathered pattern (paper, Sec 3.2.3 on p. 5):
def matmul_2D_wg_manual(xnorm, q_wi, layer):
  '''Calls a custom manual implementation of matmul_reducescatter'''
  # [batch, maxlen, embed.X] @ [heads.YZ, embed.X, q_wi_per_head]
  # -> (matmul)
  # -> [batch, maxlen, heads.YZ, q_wi_per_head]{x unreduced}
  # -> (reducescatter over x into X heads, B batches)
  # -> [batch, maxlen, heads.YZX, q_wi_per_head]
  with jax.named_scope('q_wi'):
    xnorm = intermediate_dtype(xnorm)
    q_wi = matmul_reducescatter(
        'bte,hed->bthd',
        xnorm,
        params.q_wi,
        scatter_dimension=(0, 2),
        axis_name='i',
        layer=layer)
  return q_wi
import partitioning.logical_to_physical as l2phys

def pjit_transformer_layer(
    hparams: HParams, layer: int, params: weights.Layer, sin: jnp.ndarray,
    cos: jnp.ndarray, kv_caches: Sequence[attention.KVCache],
    x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
  """Forward pass through a single layer, returning output, K, V."""

  def my_layer(t, axis=0):
    """Gets the parameters corresponding to a given layer."""
    return lax.dynamic_index_in_dim(t, layer, axis=axis, keepdims=False)

  # 2D: [batch.Z, time, embed.XY]
  x = _with_sharding_constraint(
      x, ('residual_batch', 'residual_time', 'residual_embed'))
  xnorm = _layernorm(x)

  # 2D: [batch, time, embed.X]
  xnorm = _with_sharding_constraint(
      xnorm, ('post_norm_batch', 'time', 'post_norm_embed'))

  # jump into manual mode where you want to optimise
  if manual:
    q_wi = shard_map(
        matmul_2D_wg_manual, mesh,
        in_specs=(l2phys('post_norm_batch', 'time', 'post_norm_embed'),
                  l2phys('layers', 'heads', 'embed', 'q_wi_per_head')),
        out_specs=l2phys('post_norm_batch', 'time', 'heads', 'q_wi_per_head'))(
            xnorm, q_wi, layer)
  else:
    q_wi = jnp.einsum('bte,hed->bthd', xnorm, my_layer(params.q_wi))

  # 2D: [batch, time, heads.YZX, None]
  q_wi = _with_sharding_constraint(
      q_wi, ('post_norm_batch', 'time', 'heads', 'qkv'))
  q = q_wi[:, :, :, :hparams.qkv]
  q = _rope(sin, cos, q)

  # unlike in https://arxiv.org/pdf/2002.05202.pdf, PaLM implements
  # swiGLU with full d_ff dimension, rather than 2/3 scaled
  wi0 = q_wi[:, :, :, hparams.qkv:hparams.qkv + (hparams.ff // hparams.heads)]
  wi1 = q_wi[:, :, :, hparams.qkv + (hparams.ff // hparams.heads):]

  kv = jnp.einsum('bte,ezd->btzd', xnorm, my_layer(params.kv))
  k = kv[:, :, 0, :hparams.qkv]
  v = kv[:, :, 0, hparams.qkv:]
  k = _rope(sin, cos, k)

  y_att = jnp.bfloat16(attention.attend(q, k, v, kv_caches, layer))

  y_mlp = special2.swish2(wi0) * wi1
  # 2D: [batch, time, heads.YZX, None]
  y_mlp = _with_sharding_constraint(
      y_mlp, ('post_norm_batch', 'time', 'heads', None))

  y_fused = jnp.concatenate([y_att, y_mlp], axis=-1)
  # do the second half of the mlp and the self-attn projection in parallel
  y_out = jnp.einsum('bthd,hde->bte', y_fused, my_layer(params.o_wo))
  # 2D: [batch.Z, time, embed.XY]
  y_out = _with_sharding_constraint(
      y_out, ('residual_batch', 'residual_time', 'residual_embed'))

  z = y_out + x
  z = _with_sharding_constraint(
      z, ('residual_batch', 'residual_time', 'residual_embed'))
  return z, k, v
In the profile below, both the first and second matmul were replaced by manually lowered versions, where the compute (fusions) is fully overlapped with the communication (ppermute)! One fun hint that we are using a latency-optimised variant is that the ppermute pixels are jittered, because there are two overlapping ppermutes using opposite ICI axes at the same time!
All-to-all is much harder to overlap, so was left on the table.
[profile image]
Why don’t pmap
or xmap
already solve this?#
pmap
was our first multi-device parallelism API. It follows the
per-device-code-and-explicit-collectives school. But it had major shortcomings
which make it unsuitable for today’s programs:
- Mapping multiple axes required nested pmaps. Not only are nested pmaps cumbersome to write, but also they make it difficult to control (or even predict) the device placement of data and computation, and difficult to preserve data sharding (see the next two bullets). Today's programs require multiple axes of parallelism.
- Controlling device placement was impossible. Especially with multiple axes of parallelism, programmers need to control how those axes are aligned with hardware resources and their communication topologies. But (nested) pmap doesn't offer control over how mapped program instances are placed on hardware; there's just an automatic device order which the user can't control. (Gopher's use of axis_index_groups and a single un-nested pmap was essentially a hack to get around this by flattening multiple axes of parallelism down to one.)
- jit/pjit composability. jit-of-pmap is a performance footgun, as is nesting pmaps, as is e.g. scan-of-pmap, because sharding is not preserved when returning from an inner pmap. To preserve sharding we would need pattern matching on jaxprs to ensure we're working with perfectly nested pmaps, or a pmap just inside a jit. Moreover, pjit was no help here because pmap targets XLA replicas while pjit targets the XLA SPMD Partitioner, and composing those two is hard.
- jax.Array compatibility (and hence pjit compatibility). Because the sharding of pmap outputs can't be expressed as Shardings/OpShardings, due to pmap's stacking rather than concatenative semantics, the output of a pmap computation can't currently be passed to a pjit computation without bouncing to host (or dispatching a reshaping computation).
- Multi-controller semantics (and hence pjit compatibility). Multi-controller pmap concatenates values across controllers, which works well but differs from single-controller pmap's stacking semantics. More practically, it precludes the use of non-fully-addressable jax.Array inputs and outputs as we use with multi-controller pjit.
- Eager mode. We didn't make pmap eager-first, and though we eventually (after 4+ years!) added eager operation with disable_jit(), the fact that pmap has jit fused into it means it has its own compilation and dispatch path (actually two dispatch paths: in Python for handling Tracers, and in C++ for performance on raw Array inputs!), a heavy implementation burden.
- Reshapes needed in the caller. A typical use case with pmap on 8 devices might look like starting with a batch axis of size 128, reshaping it to split into two axes with sizes (8, 16), and then pmapping over the first. These reshapes are awkward, and the compiler often interprets them as copies instead of views, increasing memory and time usage.
These shortcomings aren’t so bad when only doing batch data parallelism. But
when more parallelism is involved, pmap
just can’t cut it!
xmap
paved the way as a next-gen evolution of pmap
and solved (almost) all these
issues. shmap
follows in xmap
’s footsteps and solves these problems in
essentially the same ways; indeed, shmap
is like a specialized subset of xmap
(what some call the “hard xmap
” subset), with a few tweaks.
For the initial prototype, we chose to implement shmap
as a separate primitive
from xmap
, because limiting the set of features it supports makes it easier to
focus on the core functionality. For example, shmap
doesn’t allow unmapped
intermediates, making it easier not to worry about the interactions between
named axes and autodiff. Furthermore, not having to reason about interactions of
all pairs of features makes it easier to add capabilities beyond what’s
implemented in xmap
today, such as support for eager mode.
Both shmap
and xmap
share significant portions of the lowering code. We
could consider merging both in the future, or even focusing solely on shmap
,
depending on how the usage will evolve.
jax.extend
: a module for extensions#
@froystig, @sharadmv, @jakevdp, @yashk2810
May 2023
import jax.extend as jex
Several projects depend on JAX’s codebase internals, often to use its core machinery (e.g. to write a transformation over its IR) or to extend it (e.g. to define new primitives). Two challenges for these dependencies are (a) that our internals aren’t all solidly designed for external use, and (b) that circumventing JAX’s public API is unsupported. In other words, our internals are often used like a library, but are neither structured nor updated like one.
This proposal considers introducing a jax.extend
module that
defines a library view of some of JAX’s internal components. We would
treat this as a second-tier API, still guaranteeing essentially no
compatibility policy, but hopefully making
it easier to spot changes when they happen.
The audience for jax.extend
includes JAX-adjacent Python libraries
like Oryx,
jax-triton, and many others,
as well as projects experimenting with function transformations,
autodiff systems, compiler frontends for numerical programming, etc.
This note gives an overview of how jax.extend
might look, now and
eventually. It doesn’t lay things out in great detail, instead
proposing that we begin iteratively developing
the module.
Note that jax.extend
differs from jax.experimental
, which is a
staging ground for new features and ideas in progress. Typically, work
in jax.experimental
eventually makes its way into another JAX module or is
removed altogether.
No compatibility policy#
To keep development overhead low, jax.extend
would not follow the
public
API compatibility
policy. It would promise no deprecation windows nor backwards
compatibility between releases. Every release may break existing
callers without simple recourse (e.g. without a flag reintroducing
prior behavior). We would rely on the
changelog
to call out such changes.
Callers of jax.extend
that need to upgrade their code regularly
alongside JAX releases might find it useful to pin JAX versions as an
intermediate step between releases. This is a common habit among
projects that rely on JAX’s internals today. The difference is that it
would now come with the help of changelog announcements and better
intentions regarding library design and naming.
Iterative development#
Having no compatibility policy makes it easier to get started on
implementation: on day one, we can move a handful of symbols over from
internal packages such as jax._src
and today’s jax.core
and
jax.interpreters
. Then we can iterate to improve things from there.
Possible module overview#
We can imagine that eventually jax.extend
would include the
following modules:
- core – primitives, the Jaxpr IR, etc.
- interpreters – core transformations (e.g. autodiff, batching) and lowerings.
- random – random bit generation, key splitting and folding, key arrays.
- sharding – extra functionality around distributed arrays.
We might also have other symbols in the module at first, such as
jex.api_util
, as we work to remove or replace them. Others will be
decided in time. For instance, jex.lib
could offer an entry point to
jaxlib (and would do so in the immediate term), but it’s not clear
whether we want to keep it for long.
Some preliminary thoughts on what each of these might comprise follow.
jax.extend.core
#
This should enable callers at least to define new JAX primitives and
to process the Jaxpr IR (the output of
jax.make_jaxpr(...)
). Supporting this might involve providing:
Access to existing core system primitives, such as today’s
jax._src.lax.add_p
.Access to IR types, such as the current
jax._src.core.ShapedArray
.Functions for checking and pretty-printing jaxprs.
Functions for building jaxprs explicitly, rather than by staging Python functions via
jax.make_jaxpr
(or not!).
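As a small illustration of the jaxpr-processing side, here is what staging and walking a jaxpr looks like using only today's public jax.make_jaxpr (the exact names eventually exposed under jax.extend.core may differ):
import jax
import jax.numpy as jnp

closed_jaxpr = jax.make_jaxpr(lambda x: jnp.sin(x) * 2.0)(1.0)
print(closed_jaxpr)                                             # pretty-printed jaxpr
print([eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns])  # e.g. ['sin', 'mul']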
At initialization, this module will contain many more symbols than
what’s needed to define primitives and rules, including various names
used in setting up
“final-style transformations”,
such as the current jax._src.core.Trace
and Tracer
classes. We can
revisit whether jex.core
should also support final-style extensions
alongside initial style approaches, and whether it can do so by a more
narrow API than exposing Trace
and Tracer
entirely.
Oryx might help guide these decisions.
We can also consider relocating make_jaxpr
itself to jex.core
.
jax.extend.interpreters
#
This module would provide a means of registering various transformation rules for primitives—defining their behavior under AD, batching, lowering, etc.
It would initially reflect jax._src.interpreters
in providing
the modules ad
, batching
, partial_eval
(for staging Python to
Jaxpr, and for linearization in AD), mlir
, pxla
, and xla
. The
first three might be replaceable by a single primitive extension API
in jex.core
. The latter three, used for lowering, could be
simplified into one module, maybe.
Today, to write transformation rules, e.g. for AD and batching,
callers may need symbols relating to tracers, e.g. JVPTracer
and
BatchTracer
. This may be avoidable later on, and allow us to remove
tracer types from jex
.
This module plus jex.core
ought to suffice for replicating today’s
custom primitive tutorials (e.g.
ours
and
dfm’s).
For instance, defining a primitive and its behavior under jax.jit
would be possible as follows (in the immediate term):
from jax.extend import core               # Previously: from jax import core
from jax.extend.interpreters import mlir  # ... and similarly

mul_add_p = core.Primitive('mul_add')
mul_add_p.def_impl(lambda x, y, z: x * y + z)

@mul_add_p.def_abstract_eval
def mul_add_abstract(x_sa, y_sa, z_sa):
  return core.ShapedArray(x_sa.shape, x_sa.dtype)

def mul_add_mlir(ctx, xc, yc, zc):
  add = mlir.hlo.AddOp
  mul = mlir.hlo.MulOp
  return add(mul(xc, yc), zc).results

mlir.register_lowering(mul_add_p, mul_add_mlir)

import jax
print(mul_add_p.bind(2, 3, 4))            # -> 10
print(jax.jit(mul_add_p.bind)(2, 3, 4))   # -> Array(10, dtype=int32)
jax.extend.random
#
This module could expose our mechanism for defining new RNG
implementations, and functions for working with PRNG key internals
(see issue #9263),
such as the current jax._src.prng.random_wrap
and
random_unwrap
.
It could also expose the keyed hash functions that underlie the
built-in RNG implementations, such as jax._src.prng.threefry_2x32
.
jax.extend.sharding
#
This module could expose low-level utilities for sharding distributed arrays.
We have only one item in mind for now. The XLA compiler’s
array sharding format is more expressive than those provided by
JAX. We could
provide this as jex.sharding.XlaOpShardingProto
, corresponding to
today’s jax._src.lib.xla_client.OpSharding
internally.
Efficient transposition of replication-inducing collectives#
mattjj@, dougalm@
August 2023
Motivation#
We have an efficiency problem in automatically transposing shmap
s containing
certain collectives. The issue arises with psum
and all_gather
, specifically
when the output of the collective is returned to the caller as an unmapped
output. And it’s not an edge case: for example, it arises when applying grad
to a shmap
-based batch data parallel neural network loss function which uses
psum
to compute the total loss.
We’ve known about this problem for some time. An analogous issue exists with
pmap
, though it’s been worked around by keeping grad
inside pmap
rather than
outside. A primary goal of the incomplete avals-with-names work was to address a
version of this transpose efficiency problem. This doc draws on those ideas,
while extending and revising them to handle more cases and to be much easier to
land. Indeed the solution proposed here only affects the shmap
implementation.
The rest of the system need not be changed (yet).
The main purpose of this doc is to define this transpose efficiency problem and propose an easy-to-land solution.
This doc is not about:
- logical axis names on arrays (the only axis names here are just like in shmap and OG pmap);
- changing autodiff semantics (all the numbers and (non)errors are staying the same, we're just making things more efficient);
- allowing user code to reflect on any new information, or really affecting user code at all.
Problem: efficient transpose of psum
or all_gather
depends on whether cotangents are invariant across devices#
Consider this semi-realistic example, meant to resemble a replicated-parameter batch data parallel loss function:
devices = jax.devices()  # 8 devices

@partial(shmap, mesh=Mesh(devices, ('batch',)),
         in_specs=(P(None, None), P('batch', None)),
         out_specs=P())
def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  local_loss = jnp.mean(jnp.sum(predictions - targets, -1))
  global_loss = lax.pmean(local_loss, 'batch')
  return global_loss
Notice the out_specs=P()
, which indicates an unmapped output. If you’re not
familiar with the notion of unmapped outputs, see the appendix at the bottom of
this document.
Most of the details in the loss
example aren’t important. All that matters for
our purposes is that we’re applying psum
(or rather pmean = lambda x, name: psum(x, name) / psum(1, name)
) at the end. So a distilled version looks like
this:
# Example 1: shmap involving psum and unmapped output with inefficient transpose
f1 = shmap(lambda x: psum(g(x), 'i'),
in_specs=P('i'), out_specs=P())
We even simplified notation by suppressing the mesh
argument. In the examples to
follow it can be inferred from context.
What does the transpose look like? Writing t
to mean function transpose, we
could evaluate t(f1)(ybar)
for any ybar
efficiently by applying the function
¿f1_transpose?
below:
# An efficient "transpose" of Example 1 (but don't transpose this again!)
¿f1_transpose? = shmap(t(g), in_specs=P(), out_specs=P('i'))
But that’s not the transpose we currently get as t(f1).
Instead, the current recipe for transposition is roughly that we switch
in_specs
and out_specs
, do some division rescaling for unmapped outputs, and
transpose the body. Because psum
is its own transpose (as an all-reduce sum),
we end up producing this transpose:
# The transpose we currently get for Example 1 (which is fine to transpose again)
t(f1) = shmap(lambda ybar: t(g)(psum(ybar / 8, 'i')),
in_specs=P(), out_specs=P('i'))
This transpose gets the numbers right, but it’s wasteful. We know statically
from the transpose’s in_specs=P()
that ybar
has the same value for each function
instance, i.e. that its value is device-invariant for devices along the mesh
axis named i
, and yet we apply a psum
to it! That uses expensive communication
just to multiply the value on each device by 8. (Here 8 refers to the size of
axis i. The division by 8 comes from the original function’s out_specs=P()
; it
and the trivial psum
basically cancel each other out.)
What are we doing wrong? We’re not exploiting the fact that cotangents ybar
corresponding to f1
’s unmapped outputs are guaranteed to be device-invariant;
instead, we’re defensively psum
ming them as if they weren’t because psum
’s
transpose can’t be sure given the local information it has. Sometimes the psum
is necessary, as in transposing f2
with respect to its first argument:
# Example 2: shmap involving psum and *mapped* output with efficient transpose
f2 = shmap(lambda x, y: psum(g(x), 'i') * y,
in_specs=(P('i'), P('i')), out_specs=P('i'))
# The transpose we currently get for Example 2 is efficient
t(f2, 0) = shmap(lambda y, zbar: t(g)(psum(zbar * y, 'i')),
in_specs=(P('i'), P('i')), out_specs=P('i'))
Intuitively, if our transpose machinery could tell the difference between Example 1 and Example 2, we could do better by avoiding the psum and division where possible.
The inefficient examples can be even smaller. Consider transposing this cursed identity function:
# Example 3: cursed identity
cursed_identity = shmap(lambda x: x, P(), P())
# Currently we get these inefficient transposes
t(cursed_identity) = shmap(lambda x: psum(x / 8, 'i'), P(), P())
t(t(cursed_identity)) = shmap(lambda x: psum(psum(x / 8 / 8, 'i'), 'i'), P(), P())
...
It keeps getting bigger the more we transpose. How embarrassing!
And psum
isn’t the only culprit. Something analogous holds true for
all_gather
:
# Example 4: all_gather to an unmapped output
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())
# Currently we get this inefficient transpose
t(f4) = shmap(lambda ybar: psum_scatter(ybar / 8, 'i'), P(), P('i'))
This program is a bit artificial. Why do an all_gather
and feed the result into
an unmapped output, rather than skipping the all_gather
in the body and just
using out_specs=P('i')
to collect the results? But even though it’s cooked-up,
this example nevertheless exhibits a transpose which unnecessarily performs
communication (we could have just performed a non-communicating slice),
analogous to Example 1 for psum
.
Also analogously to the psum
examples, the defensive psum_scatter
is
necessary in some cases:
# Example 5: all_gather to a mapped output
f5 = shmap(lambda x, y: all_gather(x, 'i') * y,
in_specs=(P('i'), P('i')), out_specs=P('i'))
# Currently we get this efficient transpose
t(f5, 0) = shmap(lambda y, zbar: psum_scatter(zbar * y, 'i'),
in_specs=(P('i'), P('i')), out_specs=P('i'))
So how do we avoid these inefficient transposes?
Solutions#
Here are two solution ideas. They aren’t mutually exclusive. But (spoilers) the second one is better, and it’s all we need.
Partial solution “P-sum”: build the ability to express a psum
into out_specs
#
This solution is a bit of a strawperson because it would offer only an awkward way to write programs. And it wouldn’t even fix everything! But it’s worth considering, if only to motivate a more complete solution.
Example 4 above is artificial because we could have just used out_specs
instead
of an all_gather
in the body:
# Example 4 again
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())
# Why didn't we just write it like this?
f4_better = shmap(lambda x: x, P('i'), P('i'))
The f4_better
version doesn’t have any transposition problems, since the
transpose problems arise from collectives in the body.
Analogously, we could fix Example 1 by extending out_specs
so that they can
express summing:
# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
in_specs=P('i'), out_specs=P())
# What if we could write an output sum like this?
f1_better = shmap(g, in_specs=P('i'), out_specs=P(sum='i')) # sum='i' means sum over that axis
# Then it could transpose like this:
t(f1_better) = shmap(t(g), in_specs=P(), out_specs=P('i'))
t(t(f1_better)) = shmap(t(t(g)), in_specs=P('i'), out_specs=P(sum='i'))
So offering psum
s built into out_specs
fixes the transpose problem of
Example 1. But it doesn’t fully fix the cursed identity transpose in Example 3:
# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())
# How it would transpose with the P-sum partial solution:
t(cursed_identity) = shmap(lambda x: x / 8, P(), P(sum='i'))
t(t(cursed_identity)) = shmap(lambda x: x / 8, P(), P(sum='i'))
It’s an improvement since the program doesn’t continue to get bigger as we keep transposing, but we’re still doing wasteful communication.
Full solution: statically track device-varying vs device-invariant intermediates, plus new primitives#
This solution has two components:
1. track when values are guaranteed to be device-invariant vs device-varying over particular mesh axes, and
2. decompose psum into a two-step process, introducing a new pbroadcast primitive, and introduce new primitives for all_gather and its transposes.
Morally, the tracking of device-invariant vs device-varying information is a type-level consideration. But for the expedience of our first implementation, we don’t need to literally add the information to abstract values or jaxpr types. Before we get to implementation, we’ll first introduce the idea using types.
Also to follow is a discussion of making the user API convenient and backward compatible. But to first introduce the idea, we’ll ignore convenience and instead write code that is as explicit as possible.
Tracking device invariance in avals (a.k.a. avals-with-names, revived)#
We can sometimes tell from static information alone that the values of some
intermediate variables in the body of a shmap
are guaranteed to be invariant
along a mesh axis, in the sense that the function instances (and their
corresponding devices) along the mesh axis must all be computing with the same
value. We’ll call such values device-invariant. For values that are not
device-invariant, we’ll say they’re device-varying, though really we mean
potentially device-varying from the point of view of the type system.
To encode device variance in types, we’ll extend the syntax of types for arrays.
We’ll write things like x:f32[3,4]{i}
to indicate that x
is (potentially)
device-varying along mesh axis i
(and device-invariant over any other mesh
axes of the shmap
). More generally, we’ll say the grammar for array type
syntax is something like
shaped_array ::= <dtype>[<int_literal>, ...]<device_variance_type>
device_variance_type ::= {<axis_name>, ...}
We’ll also update the typing rules to handle device variance types:
- for first-order primitives other than collectives
  - for multi-arity primitives, the operand device variance types must be equal where shapes must be equal, e.g. mul x:f32[s1]{r1} y:f32[s2]{r2} requires r1 == r2 in addition to s1 == s2
  - the output device variance type must be the same as the operand(s)
- for higher-order primitives
  - we just instantiate any type variables including the device variance type (and checking types for equality checks their device variance types are equal)
  - (when performing type inference, e.g. for branches of a cond, we take the union of the sets of axis names in device variance types)
- for first-order collectives
  - a collective can either accept a device-varying or device-invariant input (along a mesh axis corresponding to its axis name parameter); it's an error to pass a device-invariant operand to a collective which accepts device-varying operands and vice-versa
  - a collective can either produce a device-varying or device-invariant output
  - see the table below

As a side benefit, whatever logic implements this type checking can subsume shmap's "static analysis" check for whether a shmap body function is compatible with any unmapped out_specs.
Here’s a table summarizing the device variance typing for collective primitives:
| Name | Device variance type | Example | Lowers to HLO | Transpose |
|---|---|---|---|---|
| psum | Varying -> Invariant | y:f32[3]{j} = psum(x:f32[3]{i,j}, 'i') | AllReduceSum (communication) | pbroadcast |
| pbroadcast | Invariant -> Varying | y:f32[3]{i} = pbroadcast(x:f32[3], 'i') | no-op (no communication) | psum |
| all_to_all | Varying -> Varying | y:f32[16]{i} = all_to_all(x:f32[16]{i}, 'i', 0, 0) | AllToAll (communication) | all_to_all |
| axis_index | () -> Varying | idx:i32[]{i} = axis_index('i') | ReplicaId and some arithmetic (no communication) | n/a |
| psum_scatter | Varying -> Varying | y:f32[2]{i} = psum_scatter(x:f32[16]{i}, 'i') | ReduceScatterSum (communication) | all_gather |
| all_gather | Varying -> Varying | y:f32[16]{i} = all_gather(x:f32[2]{i}, 'i') | AllGather (communication) | psum_scatter |
| pscatter | Invariant -> Varying | y:f32[2]{i} = pscatter(x:f32[16], 'i') | slicing by axis_index (no communication) | all_gather_invariant |
| all_gather_invariant | Varying -> Invariant | y:f32[16] = all_gather_invariant(x:f32[2]{i}, 'i') | AllGather (communication) | pscatter |
There are some surprising things here!
- We introduced several new primitives, including:
  - pbroadcast, which interestingly lowers to a no-op
  - all_gather_invariant, which lowers to the same thing as all_gather but has a different device variance type (essentially all_gather has a pbroadcast fused into it, whereas all_gather_invariant does not)
  - pscatter, which is the dual (transpose) of all_gather_invariant
- all_gather has a device-varying result
Intuitively, the reason to introduce pbroadcast
(other than to make the typing
rules work) is so that psum
can transpose to a physical no-op. The reason we
need all_gather
to have a device-varying result is so that we can transpose it
to psum_scatter
; if we instead left it with a device-invariant result, we
might need a downstream pbroadcast
, and that composition would transpose to an
inefficient psum
followed by slicing / pscatter
. So instead we have a
pbroadcast
“fused into” the all_gather
, thus allowing for an efficient
transpose into psum_scatter
. We provide all_gather_invariant
and its
transpose pscatter
mainly for completeness; it’s unlikely users will need it
(it corresponds to the situation in Example 4, which is easy to write
differently using out_specs
).
Interestingly, the psum
and pbroadcast
transpose pair correspond to the
psum_idrev
and id_psumrev
that users introduced while training LLMs with
pmap
.
How this system solves the inefficient transpose examples#
Consider again the simplified motivating example:
# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
in_specs=P('i'), out_specs=P())
# Example 1 with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1(x: f32[3,4]{i}):
  w:f32[]{i} = g(x)
  y:f32[]{} = psum(w, 'i')
  return y
With these new rules, the transpose is:
# Example 1 transpose using device variance types (go ahead and transpose this again!)
t(f1) = shmap(lambda ybar: t(g)(pbroadcast(ybar, 'i')),
in_specs=P(), out_specs=P('i'))
# Example 1 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P(), out_specs=P('i'))
def f1_transpose(ybar: f32[]):
  wbar:f32[]{i} = pbroadcast(ybar, 'i')
  xbar:f32[3,4]{i} = transpose(g)(wbar)
  return xbar
where evaluating the pbroadcast
application involves no communication or FLOPs
at all; it’s a no-op. Notice that if we keep transposing the body does not grow
in size; indeed t(t(f1)) == f1
. Efficiency achieved!
And we wouldn’t mess up the other examples either, so long as we pbroadcast
to
make the types check where needed:
# Example 2 rewritten with explicit pbroadcast
f2 = shmap(lambda x, y: pbroadcast(psum(g(x), 'i'), 'i') * y,
in_specs=(P('i'), P('i')), out_specs=P('i'))
# Example 2 transpose using device variance types
t(f2, 0) = shmap(lambda y, zbar: t(g)(pbroadcast(psum(zbar * y, 'i'), 'i')),
in_specs=(P('i'), P('i')), out_specs=P('i'))
# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())
# Notice here the body is `f32[...] -> f32[...]`, i.e. no device varying type.
# Example 3 transpose using device variance types
t(cursed_identity) = shmap(lambda x: x, P(), P())
t(t(cursed_identity)) = shmap(lambda x: x, P(), P())
Intuitively, in Example 1 we now only have “half the original psum”, whereas in Example 2 we get both “halves”. For Example 3 we never need any operations in the body at all.
For the all_gather
examples, Example 4 would need to use all_gather_invariant to have an efficient
transpose (though it'd be better to use out_specs instead of the collective in
the body):
# Example 4 rewritten with explicit all_gather_invariant
f4 = shmap(lambda x: all_gather_invariant(x, 'i'), P('i'), P())

# Example 4 with intermediate device variance types annotated
@partial(shmap, P('i'), P())
def f4(x:f32[1]{i}):
  y:f32[8]{} = all_gather_invariant(x, 'i')
  return y

# Example 4 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P(), out_specs=P('i'))
def f4_transpose(ybar:f32[8]):
  xbar:f32[1]{i} = pscatter(ybar, 'i')
  return xbar
For Example 5, using the device-varying all_gather
works as we’d want:
# Example 5 with intermediate device variance types annotated
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5(x:f32[1]{i}, y:f32[8]{i}):
  z:f32[8]{i} = all_gather(x, 'i')
  w:f32[8]{i} = z * y
  return w

# Transpose with respect to first argument
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5_transpose(y:f32[8]{i}, wbar:f32[8]{i}):
  zbar:f32[8]{i} = wbar * y
  xbar:f32[1]{i} = psum_scatter(zbar, 'i')
  return xbar
How to make the API convenient for users (and backward compatible)#
But what user wants to write pbroadcast
s? And what developer wants to break
lots of existing user code involving psum
s which are not fed into unmapped
outputs? Not me!
Instead we can automatically insert the pbroadcast
s. It’s a bit analogous to how
we do automatic rank promotion at the jax.numpy
layer, inserting broadcasts to
avoid rank mismatch errors in binary operators. But it’s much simpler since we
don’t need to contend with shape tuples. The typical rule is: whenever we see a
multi-arity operation where the operands disagree in their device variance
types, take the union of operands’ device variance types’ axis name sets and
insert pbroadcast
s to lift each operand to the resulting device variance type.
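Schematically, the insertion rule could look like the following sketch (plain Python, not JAX internals; pbroadcast here is a stand-in placeholder for the proposed primitive):
def pbroadcast(x, axis_names):
  # placeholder: a physical no-op that marks x as device-varying along axis_names
  return x

def lift_operands(operands, variance_types):
  # variance_types: one set of mesh axis names per operand, e.g. [{'i'}, set()]
  target = set().union(*variance_types)
  lifted = []
  for x, axes in zip(operands, variance_types):
    missing = tuple(sorted(target - axes))
    lifted.append(pbroadcast(x, missing) if missing else x)
  return lifted, target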
Automatically inserting pbroadcast
s just before they’re needed may mean we apply
the same pbroadcast
to the same operand multiple times, creating common
subexpressions. When we transpose, those could turn into a sum-of-psum
s rather
than a psum
-of-sum. We’ll rely on the compiler to clean that up as appropriate.
If it’s a problem then we could add some simple memoization to the
pbroadcast
-insertion pass.
The user API for all_gather
will mean all_gather_p
by default (not
all_gather_invariant_p
), covering the common case and meaning no pbroadcast
s
must be inserted.
We can provide an option on shmap
to disable this automatic insertion of
pbroadcast
s, in which case it’ll be up to the user to ensure type-correctness.
This explicit option may be appealing to some who want to be explicit about
where the psum
s occur in the backward pass.
How to implement the solution#
The key to making the implementation lightweight is that we aren’t going to add these types to avals or jaxprs. At least, not at first. That can be expensive because it requires updating the rest of JAX, e.g. all consumers of avals and jaxprs may need to handle the new types. We’re not falling for that again!
Instead we’re going to keep these extended types as metadata internal to
shmap
, just like the current “replication checking for out_specs
” machinery
is internal to shmap
. Indeed this solution amounts to a relatively small
extension to that existing machinery: it was already tracking the same
information; now we’re just adding the pbroadcast
s.
We have at least two options for where to perform the pbroadcast
insertion:
just before transposition, in the transpose rule, where we have a jaxpr of the computation to be transposed;
in every
shmap
body, whether eagerly executed or staged out, like the current “replication checking forout_specs
” machinery. The former may end up being easier since we only have to handle the jaxpr case, and only linear primitives. But we’ll start by trying the latter so the implementation here is a strict revision/extension to the existing replication-checking logic.
Appendix: defining and motivating maps with unmapped inputs and outputs#
For concreteness, we’ll mostly focus on shmap
, though these same ideas apply
to e.g. pmap
and probably xmap
.
An argument/input is unmapped along a mesh axis when the corresponding entry
of in_specs
doesn’t mention that mesh axis’s name. Logically it means that
each function instance along that mesh axis gets the same value for the
argument. To the caller, each operand is sliced according to the mesh axes over
which the operand is mapped, whereas there is no slicing for mesh axes over
which the operand is unmapped.
An output is unmapped along a mesh axis when the corresponding entry of
out_specs
doesn’t mention that mesh axis’s name. Logically it means each
function instance along that mesh axis must return the same value. To the
caller, each result of the shmap
is formed by concatenating the return values
of every function instance along which the outputs are mapped, whereas for mesh
axes over which the output is unmapped only one copy of the value is used.
See the shmap
JEP for examples
of unmapped inputs and outputs. For comparison, in vmap
unmapped
inputs/outputs are indicated by using in_axes
/ out_axes
of None
(rather
than an int
).
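For a tiny vmap analogy of the unmapped-input idea (using the public jax.vmap API): with in_axes=None, every mapped instance sees the same value of w.
import jax
import jax.numpy as jnp

def scale(w, x):
  return w * x

xs = jnp.arange(4.0)
print(jax.vmap(scale, in_axes=(None, 0))(2.0, xs))  # [0. 2. 4. 6.]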
Here are reasons we like unmapped inputs and outputs for shmap
:
- Same expressiveness as pjit. Anything pjit can do, the shmap escape hatch should be able to do too. Or else we'd have a lacking escape hatch! If we didn't have unmapped outputs in shmap then we couldn't express the same batch-parallel loss function computations as pjit.
- Closed-over inputs. Closed-over inputs essentially correspond to unmapped inputs, and…
- Closure under transposition. Once we have unmapped inputs, it's natural to be able to transpose to unmapped outputs.
So unmapped outputs are both canonical and useful!
JEP 18137: Scope of JAX NumPy & SciPy Wrappers#
Jake VanderPlas
October 2023
Until now, the intended scope of jax.numpy
and jax.scipy
has been relatively
ill-defined. This document proposes a well-defined scope for these packages to better guide
and evaluate future contributions, and to motivate the removal of some out-of-scope code.
Background#
From the beginning, JAX has aimed to provide a NumPy-like API for executing code in XLA,
and a big part of the project’s development has been building out the jax.numpy
and
jax.scipy
namespaces as JAX-based implementations of NumPy and SciPy APIs. There has always
been an implicit understanding that some parts of numpy
and scipy
are out-of-scope
for JAX, but this scope has not been well defined. This can lead to confusion and frustration
for contributors, because there’s no clear answer to whether potential jax.numpy
and
jax.scipy
contributions will be accepted into JAX.
Why Limit the Scope?#
To avoid leaving this unsaid, we should be explicit: it is a fact that any code included in a project like JAX incurs a small but nonzero ongoing maintenance burden for the developers. The success of a project over time directly relates to the ability of maintainers to continue this maintenance for the sum of all the project’s parts: documenting functionality, responding to questions, fixing bugs, etc. For long-term success and sustainability of any software tool, it’s vital that maintainers carefully weigh whether any particular contribution will be a net positive for the project given its goals and resources.
Evaluation Rubric#
This document proposes a rubric of six axes along which the scope of any particular numpy
or scipy
API can be judged for inclusion into JAX. An API which is strong along all axes
is an excellent candidate for inclusion in the JAX package; a strong weakness along any of
the six axes is a good argument against inclusion in JAX.
Axis 1: XLA alignment#
The first axis we consider is the degree to which the proposed API aligns with native XLA
operations. For example, jax.numpy.exp()
is a function that more-or-less directly mirrors
jax.lax.exp
. A large number of functions in numpy
, scipy.special
, numpy.linalg
,
scipy.linalg
, and others meet this criterion: such functions pass the XLA-alignment check
when considering their inclusion into JAX.
On the other end, there are functions like numpy.unique()
, which do not directly correspond
to any XLA operation, and in some cases are fundamentally incompatible with JAX’s current
computational model, which requires statically-shaped arrays (e.g. unique
returns a
value-dependent dynamic array shape). Such functions do not pass the XLA alignment check
when considering their inclusion into JAX.
We also consider as part of this axis the need for pure function semantics. For example,
numpy.random
is built on an implicitly-updated state-based RNG, which is fundamentally
incompatible with JAX’s computational model built on XLA.
Axis 2: Array API Alignment#
The second axis we consider focuses on the
Python Array API Standard: this is in some
senses a community-driven outline of which array operations are central to array-oriented
programming across a wide range of user communities. If an API in numpy
or scipy
is
listed within the Array API standard, it is a strong signal that JAX should include it.
Using the example from above, the Array API standard includes several variants of
numpy.unique()
(unique_all
, unique_counts
, unique_inverse
, unique_values
) which
suggests that, despite the function not being precisely aligned with XLA, it is important
enough to the Python user community that JAX should perhaps implement it.
Axis 3: Existence of Downstream Implementations#
For functionality that does not align with Axis 1 or 2, an important consideration for
inclusion into JAX is whether there exist well-supported downstream packages that supply
the functionality in question. A good example of this is scipy.optimize
: while JAX does
include a minimal set of wrappers of scipy.optimize
functionality, a much more complete
treatment exists in the JAXopt package, which is actively
maintained by JAX collaborators. In cases like this, we should lean toward pointing users
and contributors to these specialized packages rather than re-implementing such APIs in
JAX itself.
Axis 4: Complexity & Robustness of Implementation#
For functionality that does not align with XLA, one consideration is the degree of
complexity of the proposed implementation. This aligns to some degree with Axis 1,
but nevertheless is important to call out. A number of functions have been contributed
to JAX which have relatively complex implementations which are difficult to validate
and introduce outsized maintenance burdens; an example is jax.scipy.special.bessel_jn()
:
as of the writing of this JEP, its current implementation is a non-straightforward
iterative approximation that has
convergence issues in some domains,
and proposed fixes introduce further
complexity. Had we more carefully weighed the complexity and robustness of the
implementation when accepting the contribution, we may have chosen not to accept this
contribution to the package.
Axis 5: Functional vs. Object-Oriented APIs#
JAX works best with functional APIs rather than object-oriented APIs. Object-oriented APIs can often hide impure semantics, making them difficult to implement well. NumPy and SciPy generally stick to functional APIs, but sometimes provide object-oriented convenience wrappers.
An example of this is numpy.polynomial.Polynomial
, which wraps lower-level operations
like numpy.polyadd()
, numpy.polydiv()
, etc. In general, when there are both functional
and object-oriented APIs available, JAX should avoid providing wrappers for the
object-oriented APIs and instead provide wrappers for the functional APIs.
In cases where only the object-oriented APIs exist, JAX should avoid providing wrappers unless the case is strong along other axes.
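As a concrete illustration of this distinction (coefficient conventions aside), jax.numpy mirrors the functional polynomial routines but not the object-oriented wrapper:
import numpy as np
import jax.numpy as jnp

p, q = jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])
print(jnp.polyadd(p, q))                     # functional API, wrapped by jax.numpy
print(np.polynomial.Polynomial([2.0, 1.0]))  # object-oriented wrapper, not mirrored in JAX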
Axis 6: General “Importance” to JAX Users & Stakeholders#
The decision to include a NumPy/SciPy API in JAX should also take into account the importance of the algorithm to the general user community. It is admittedly difficult to quantify who is a “stakeholder” and how this importance should be measured; but we include this to make clear that any decision about what to include in JAX’s NumPy and SciPy wrappers will involve some amount of discretion that cannot be easily quantified.
For existing APIs, searches for usage in github may be useful in establishing importance
or lack thereof; as an example, we might return to jax.scipy.special.bessel_jn()
discussed above: a search shows that this function has only a
handful of uses
on github, probably due in part to the previously mentioned accuracy issues.
Evaluation: what’s in scope?#
In this section, we’ll attempt to evaluate the NumPy and SciPy APIs, including some examples from the current JAX API, in light of the above rubric. This will not be a comprehensive listing of all existing functions and classes, but rather a more general discussion by submodule and topic, with relevant examples.
NumPy APIs#
✅ numpy
namespace#
We consider the functions in the main numpy
namespace to be essentially all in-scope
for JAX, due to its general alignment with XLA (Axis 1) and the Python Array API
(Axis 2), as well as its general importance to the JAX user community (Axis 6).
Some functions are perhaps borderline (functions like numpy.intersect1d()
,
np.setdiff1d()
, np.union1d()
arguably fail parts of the rubric) but for
simplicity we declare that all array functions in the main numpy namespace are in-scope
for JAX.
✅ numpy.linalg
& numpy.fft
#
The numpy.linalg
and numpy.fft
submodules contain many functions that
broadly align with functionality provided by XLA. Others have complicated device-specific
lowerings, but represent a case where importance to stakeholders (Axis 6) outweighs complexity.
For this reason, we deem both of these submodules in-scope for JAX.
❌ numpy.random
#
numpy.random
is out-of-scope for JAX, because state-based RNGs are fundamentally
incompatible with JAX’s computation model. We instead focus on jax.random
,
which offers similar functionality using a counter-based PRNG.
❌ numpy.ma
& numpy.polynomial
#
The numpy.ma
and numpy.polynomial
submodules are mostly concerned with
providing object-oriented interfaces to computations that can be expressed via other
functional means (Axis 5); for this reason, we deem them out-of-scope for JAX.
❌ numpy.testing
#
NumPy’s testing functionality only really makes sense for host-side computation,
and so we don’t include any wrappers for it in JAX. That said, JAX arrays are
compatible with numpy.testing
, and JAX makes frequent use of it throughout
the JAX test suite.
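For example, JAX arrays can be passed directly to NumPy's testing helpers:
import numpy as np
import jax.numpy as jnp

np.testing.assert_allclose(jnp.arange(3.0), np.array([0.0, 1.0, 2.0]))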
SciPy APIs#
SciPy has no functions in the top-level namespace, but includes a number of submodules. We consider each below, leaving out modules which have been deprecated.
❌ scipy.cluster
#
The scipy.cluster
module includes tools for hierarchical clustering, k-means,
and related algorithms. These are weak along several axes, and would be better
served by a downstream package. One function already exists within JAX
(jax.scipy.cluster.vq.vq()
) but has
no obvious usage
on github: this suggests that clustering is not broadly important to JAX users.
Recommendation: deprecate and remove jax.scipy.cluster.vq()
.
❌ scipy.constants
#
The scipy.constants
module includes mathematical and physical constants.
These constants can be used directly with JAX, and so there is no reason to
re-implement this in JAX.
❌ scipy.datasets
#
The scipy.datasets
module includes tools to fetch and load datasets.
These fetched datasets can be used directly with JAX, and so there is no
reason to re-implement this in JAX.
✅ scipy.fft
#
The scipy.fft
module contains functions that broadly align with functionality
provided by XLA, and fare well along other axes as well. For this reason,
we deem them in-scope for JAX.
❌ scipy.integrate
#
The scipy.integrate
module contains functions for numerical integration. The
more sophisticated of these (quad
, dblquad
, ode
) are out-of-scope for JAX by
axes 1 & 4, since they tend to be loopy algorithms based on dynamic numbers of
evaluations. jax.experimental.ode.odeint()
is related, but rather limited and not
under any active development.
JAX does currently include jax.scipy.integrate.trapezoid()
, but this is only because
numpy.trapz()
was recently deprecated in favor of this. For any particular input,
its implementation could be replaced with one line of jax.numpy
expressions, so
it’s not a particularly useful API to provide.
Based on Axes 1, 2, 4, and 6, scipy.integrate
should be considered out-of-scope for JAX.
Recommendation: remove jax.scipy.integrate.trapezoid()
, which was added in JAX 0.4.14.
❌ scipy.interpolate
#
The scipy.interpolate
module provides both low-level and object-oriented routines
for interpolating in one or more dimensions. These APIs rate poorly along a number
of the axes above: they are class-based rather than low-level, and none but the
simplest methods can be expressed efficiently in terms of XLA operations.
JAX does currently have wrappers for scipy.interpolate.RegularGridInterpolator
.
Were we considering this contribution today, we would probably reject it by the
above criteria. But this code has been fairly stable so there is not much downside
to continuing to maintain it.
Going forward, we should consider other members of scipy.interpolate
to be
out-of-scope for JAX.
❌ scipy.io
#
The scipy.io
submodule has to do with file input/output. There is no reason
to re-implement this in JAX.
✅ scipy.linalg
#
The scipy.linalg
submodule contains functions that broadly align with functionality
provided by XLA, and fast linear algebra is broadly important to the JAX user community.
For this reason, we deem it in-scope for JAX.
❌ scipy.ndimage
#
The scipy.ndimage
submodule contains a set of tools for working on image data. Many
of these overlap with tools in scipy.signal
(e.g. convolutions and filtering). JAX
currently provides one scipy.ndimage
API, in jax.scipy.ndimage.map_coordinates()
.
Additionally, JAX provides some image-related tools in the jax.image
module. The
deepmind ecosystem includes dm-pix, a
more full-featured set of tools for image manipulation in JAX. Given all these factors,
I’d suggest that scipy.ndimage
should be considered out-of-scope for JAX core; we can
point interested users and contributors to dm-pix. We can consider moving map_coordinates
to dm-pix
or to another appropriate package.
❌ scipy.odr
#
The scipy.odr
module provides an object-oriented wrapper around ODRPACK
for
performing orthogonal distance regressions. It is not clear that this could be cleanly
expressed using existing JAX primitives, and so we deem it out of scope for JAX itself.
❌ scipy.optimize
#
The scipy.optimize
module provides high-level and low-level interfaces for optimization.
Such functionality is important to a lot of JAX users, and very early on JAX created
jax.scipy.optimize
wrappers. However, developers of these routines soon realized that
the scipy.optimize
API was too constraining, and different teams began working on the
JAXopt package and the
Optimistix package, each of which contain
a much more comprehensive and better-tested set of optimization routines in JAX.
Because of these well-supported external packages, we now consider scipy.optimize
to be out-of-scope for JAX.
Recommendation: deprecate jax.scipy.optimize
and/or make it a lightweight wrapper
around an optional JAXopt or Optimistix dependency.
🟡 scipy.signal
#
The scipy.signal
module is mixed: some functions are squarely in-scope for JAX
(e.g. correlate
and convolve
, which are more user-friendly wrappers of
lax.conv_general_dilated
), while many others are squarely out-of-scope (domain-specific
tools with no viable lowering path to XLA). Potential contributions to jax.scipy.signal
will have to be weighed on a case-by-case basis.
🟡 scipy.sparse
#
The scipy.sparse
submodule mainly contains data structures for storing and operating
on sparse matrices and arrays in a variety of formats. Additionally, scipy.sparse.linalg
contains a number of matrix-free solvers, suitable for use with sparse matrices,
dense matrices, and linear operators.
The scipy.sparse
array and matrix data structures are out-of-scope for JAX, because
they do not align with JAX’s computational model (e.g. many operations depend on
dynamically-sized buffers). JAX has developed the jax.experimental.sparse
module
as an alternative set of data structures that are more in-line with JAX’s computational
constraints. For these reasons, we consider the data structures in scipy.sparse
to
be out-of-scope for JAX.
On the other hand, scipy.sparse.linalg
has proven to be an interesting area, and
jax.scipy.sparse.linalg
includes the bicgstab
, cg
, and gmres
solvers. These
are useful to the JAX user community (Axis 6) but aside from this do not fare well
along other axes. They would be very suitable for moving into a downstream library;
one potential option may be Lineax, which features
a number of linear solvers built on JAX.
Recommendation: explore moving sparse solvers into Lineax, and otherwise treat scipy.sparse as out-of-scope for JAX.
❌ scipy.spatial
#
The scipy.spatial
module contains mainly object-oriented interfaces to spatial/distance
computations and nearest neighbor searches. It is mostly out-of-scope for JAX.
The scipy.spatial.transform
submodule provides tools for manipulating three-dimensional
spatial rotations. It is a relatively complicated object-oriented interface, and could
perhaps be better served by a downstream project. JAX currently contains partial
implementations of Rotation
and
Slerp
within jax.scipy.spatial.transform
;
these are object-oriented wrappers of otherwise basic
functions, which introduce a very large API surface and have very few users. It is our
judgment that they are out-of-scope for JAX itself, with users better-served by a
hypothetical downstream project.
The scipy.spatial.distance
submodule contains a useful collection of distance metrics,
and it might be tempting to provide JAX wrappers for these. That said, with jit and vmap
it would be straightforward for a user to define efficient versions of most of these from
scratch if needed, so adding them to JAX is not particularly beneficial.
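As a hedged illustration of that point (the function names here are hypothetical, not an existing JAX API), a cdist-style pairwise Euclidean distance can be assembled from jit and vmap in a few lines:
import jax
import jax.numpy as jnp

def euclidean(p, q):
    # Distance between two points.
    return jnp.sqrt(jnp.sum((p - q) ** 2))

# vmap over both arguments yields the full pairwise distance matrix,
# roughly analogous to scipy.spatial.distance.cdist.
pairwise = jax.jit(jax.vmap(jax.vmap(euclidean, (None, 0)), (0, None)))

a = jnp.ones((5, 3))
b = jnp.zeros((7, 3))
print(pairwise(a, b).shape)  # (5, 7)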
Recommendation: consider deprecating and removing the Rotation
and Slerp
APIs, and
consider scipy.spatial
as a whole out-of-scope for future contributions.
✅ scipy.special
#
The scipy.special
module includes implementations of a number of more specialized
functions. In many cases, these functions are squarely in scope: for example, functions
like gammaln
, betainc
, digamma
, and many others correspond directly to available
XLA primitives, and are clearly in-scope by Axis 1 and others.
Other functions require more complicated implementations; one example mentioned above
is bessel_jn
. Despite not aligning on Axes 1 and 2, these functions tend to be very
strong along Axis 6: scipy.special
provides fundamental functions necessary for
computation in a variety of domains, so even functions with complicated implementations
should lean toward in-scope, so long as the implementations are well-designed and robust.
There are a few existing function wrappers that we should take a closer look at; for example:
jax.scipy.special.lpmn(): this generates Legendre polynomials via a complicated fori_loop, in a way that does not match the scipy API (e.g. for scipy, z must be a scalar, while for JAX, z must be a 1D array). The function has few discoverable uses, making it a weak candidate along Axes 1, 2, 4, and 6.
jax.scipy.special.lpmn_values(): this has similar weaknesses to lpmn above.
jax.scipy.special.sph_harm(): this is built on lpmn, and similarly has an API that diverges from the corresponding scipy function.
jax.scipy.special.bessel_jn(): as discussed under Axis 4 above, this has weaknesses in terms of robustness of implementation and little usage. We might consider replacing it with a new, more robust implementation (e.g. #17038).
Recommendation: refactor and improve robustness & test coverage for bessel_jn
. Consider deprecating lpmn
, lpmn_values
, and sph_harm
if they cannot be modified to more closely match the scipy
APIs.
✅ scipy.stats
#
The scipy.stats
module contains a wide range of statistical functions, including discrete
and continuous distributions, summary statistics, and hypothesis testing. JAX currently wraps
a number of these in jax.scipy.stats
, primarily including 20 or so statistical distributions,
along with a few other functions (mode
, rankdata
, gaussian_kde
). In general these are
well-aligned with JAX: distributions usually are expressible in terms of efficient XLA operations,
and the APIs are clean and functional.
We don’t currently have any wrappers for hypothesis testing functions, probably because these are less useful to the primary user-base of JAX.
Regarding distributions, in some cases, tensorflow_probability
provides similar functionality,
and in the future we might consider whether to deprecate the scipy.stats distributions in favor
of that implementation.
Recommendation: going forward, we should treat statistical distributions and summary statistics as in-scope, and consider hypothesis tests and related functionality generally out-of-scope.
Several early JEPs were converted in hindsight from other documentation, issues, and pull requests, so they might not exactly reflect the process outlined above.
Investigating a regression#
So you updated JAX and hit a speed regression? You have a little bit of time and are ready to investigate? First, please open a JAX issue. But if you can pinpoint the commit that triggered the regression, it will really help us.
This document explains how we identified the commit that caused a 15% performance regression.
Steps#
This is a brute-force method rather than a bisection, but if the reproducer is quick enough it works well. It also makes sure that you always test XLA and JAX commits that are compatible, and it limits XLA recompilation.
Here is a suggested investigation strategy:
A brute-force test of the nightly containers between the two releases.
Hourly recompilation while keeping XLA and JAX in sync.
Final verification: maybe a manual check of a few commits (or a git bisect).
Nightly investigation.#
This can be done by using JAX-Toolbox nightly containers.
Some days, bugs prevent the container from being built, or there are temporary regressions. Just discard those days.
So you should end up with a specific day or a few days where the regression happens.
To automate this, you need two shell scripts:
test_runner.sh: starts the containers and the test.
test.sh: installs the missing dependencies and runs the test.
Here are real example scripts used for the issue: https://github.com/google/jax/issues/17686
test_runner.sh:
for m in 7 8 9; do
  for d in `seq -w 1 30`; do
    docker run -v $PWD:/dir --gpus=all ghcr.io/nvidia/jax:nightly-2023-0${m}-${d} /bin/bash /dir/test.sh &> OUT-0${m}-${d}
  done
done
test.sh:
pip install jmp pyvista numpy matplotlib Rtree trimesh termcolor orbax
git clone https://github.com/Autodesk/XLB
cd XLB
export PYTHONPATH=.
export CUDA_VISIBLE_DEVICES=0 # only 1 GPU is needed
python3 examples/performance/MLUPS3d.py 256 200
Then you can grep each output to see when the regression happens:
grep MLUPS OUT*
Here are the results we got:
OUT-07-06:MLUPS: 587.9240990200157
OUT-07-07:MLUPS: 587.8907972116419
OUT-07-08:MLUPS: 587.3186499464459
OUT-07-09:MLUPS: 587.3130127722537
OUT-07-10:MLUPS: 587.8526619429658
OUT-07-17:MLUPS: 570.1631097290182
OUT-07-18:MLUPS: 570.2819775617064
OUT-07-19:MLUPS: 570.1672213357352
OUT-07-20:MLUPS: 587.437153685251
OUT-07-21:MLUPS: 587.6702557143142
OUT-07-25:MLUPS: 577.3063618431178
OUT-07-26:MLUPS: 577.2362978080912
OUT-07-27:MLUPS: 577.2101850145785
OUT-07-28:MLUPS: 577.0716349809895
OUT-07-29:MLUPS: 577.4223280707176
OUT-07-30:MLUPS: 577.2255967221336
OUT-08-01:MLUPS: 577.277685388252
OUT-08-02:MLUPS: 577.0137874289354
OUT-08-03:MLUPS: 577.1333281553946
OUT-08-04:MLUPS: 577.305012020407
OUT-08-05:MLUPS: 577.2143988866626
OUT-08-06:MLUPS: 577.2409145495443
OUT-08-07:MLUPS: 577.2602819927345
OUT-08-08:MLUPS: 577.2823738293221
OUT-08-09:MLUPS: 577.3453199728248
OUT-08-11:MLUPS: 577.3161423260563
OUT-08-12:MLUPS: 577.1697775786824
OUT-08-13:MLUPS: 577.3049883393633
OUT-08-14:MLUPS: 576.9051978525331
OUT-08-15:MLUPS: 577.5331743016213
OUT-08-16:MLUPS: 577.5117505070573
OUT-08-18:MLUPS: 577.5930698237612
OUT-08-19:MLUPS: 577.3539885757353
OUT-08-20:MLUPS: 577.4190113959127
OUT-08-21:MLUPS: 577.300394253605
OUT-08-22:MLUPS: 577.4263792037783
OUT-08-23:MLUPS: 577.4087536357031
OUT-08-24:MLUPS: 577.1094728438082
OUT-08-25: File "/XLB/examples/performance/MLUPS3d.py", line 5, in <module>
OUT-08-26:MLUPS: 537.0164618489928
OUT-08-27:MLUPS: 536.9545448661609
OUT-08-28:MLUPS: 536.2887650464874
OUT-08-29:MLUPS: 536.7178471720636
OUT-08-30:MLUPS: 536.6978912984252
OUT-09-01:MLUPS: 536.7030899164106
OUT-09-04:MLUPS: 536.5339818238837
OUT-09-05:MLUPS: 536.6507808565617
OUT-09-06:MLUPS: 536.7144494518315
OUT-09-08:MLUPS: 536.7376612408998
OUT-09-09:MLUPS: 536.7798324141778
OUT-09-10:MLUPS: 536.726157440174
OUT-09-11:MLUPS: 536.7446210750584
OUT-09-12:MLUPS: 536.6707332269023
OUT-09-13:MLUPS: 536.6777936517823
OUT-09-14:MLUPS: 536.7581523280307
OUT-09-15:MLUPS: 536.6156273667873
OUT-09-16:MLUPS: 536.7320935035265
OUT-09-17:MLUPS: 536.7104991444398
OUT-09-18:MLUPS: 536.7492269469092
OUT-09-19:MLUPS: 536.6760131792959
OUT-09-20:MLUPS: 536.7361260076634
This found that 8-24 was good, but 8-26 was bad. On 8-25 another issue prevented us from getting results. So we need to investigate hourly between 8-24 and 8-26. There was also a smaller slowdown earlier; let's ignore it for this example, as tracking it down would just be another hourly investigation between the relevant dates.
Hourly investigation.#
This does a checkout of JAX and XLA at each hour between the two dates, rebuilds everything, and runs the test. The scripts are structured differently: we start the working container once and keep it running, and inside it only the first build is a full build; later XLA builds are incremental, so each iteration after the first is much faster.
test_runner2.sh:
# Execute this script inside the container:
# docker run -v $PWD:/dir --gpus=all ghcr.io/nvidia/jax:nightly-2023-08-24 /bin/bash
cd /opt/xla-source
git remote update
cd /opt/jax-source
git remote update
pip install jmp pyvista numpy matplotlib Rtree trimesh jmp termcolor orbax
cd /tmp
git clone https://github.com/Autodesk/XLB
cd XLB
for d in `seq -w 24 26`; do
  for h in `seq -w 0 24`; do
    echo $d $h
    /bin/bash /dir/test2.sh Aug $d 2023 $h:00:00 &> OUT-08-${d}-$h
  done
done
test2.sh:
echo "param: $@"
cd /opt/xla-source
git checkout `git rev-list -1 --before="$*" origin/main`
git show -q
cd /opt/jax-source
git checkout `git rev-list -1 --before="$*" origin/main`
git show -q
rm /opt/jax-source/dist/jax*.whl
build-jax.sh # The script is in the nightly container
cd /tmp/XLB  # run the reproducer from the XLB checkout made by test_runner2.sh
export PYTHONPATH=.
export CUDA_VISIBLE_DEVICES=0 # only 1 GPU is needed
python3 examples/performance/MLUPS3d.py 256 200
Now you can run the same grep command on the new output files to see between which hours the issue appeared.
Final verification#
With this, you need to check the JAX and XLA history between those hours. Maybe there are a few commits to test. You can use git bisect if you want to be fancy.
Can this be improved?#
Yes! If it was a crash regression, being able to do a bisect would be useful. But it would be more complicated. If someone wants to contribute such instructions, please submit a PR ;)
For speed regressions, a bisect can hide some information. We wouldn’t see as easily that there were two regressions here.
Building on JAX#
A great way to learn advanced JAX usage is to see how other libraries use JAX: how they integrate it into their APIs, what functionality it adds mathematically, and how they use it for computational speedup.
Below are examples of how JAX’s features can be used to define accelerated computation across numerous domains and software packages.
Gradient Computation#
Easy gradient calculation is a key feature of JAX. In the JAXopt library, value_and_grad is used directly in multiple optimization algorithms in its source code.
Similarly, the Dynamax and Optax pairing described below is an example of gradients enabling estimation methods that were historically challenging, such as Maximum Likelihood Estimation using Optax.
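For context, jax.value_and_grad evaluates a function and its gradient in a single call; this toy sketch (not taken from JAXopt or Dynamax) shows the pattern:
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared-error loss of a linear model, used only for illustration.
    return jnp.mean((x @ w - y) ** 2)

x = jnp.ones((10, 3))
y = jnp.zeros(10)
w = jnp.array([1.0, -2.0, 0.5])

value, grad = jax.value_and_grad(loss)(w, x, y)  # loss and d(loss)/dw together
print(value, grad.shape)  # scalar loss, gradient of shape (3,)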
Computational Speedup on a Single Core across Multiple Devices#
Models defined in JAX can be compiled with JIT to speed up a single computation. The same compiled code can then be sent to a CPU device, or to a GPU or TPU device, for additional speedup, typically with no code changes needed. This allows for a smooth workflow from development into production. In Dynamax, the computationally expensive portion of a Linear State Space Model solver has been jitted. A more complex example comes from PyTensor, which compiles a JAX function dynamically and then jits the constructed function.
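For illustration, here is a toy example of the jit pattern (not code from Dynamax or PyTensor): the function is compiled once and then reused, on whichever backend is available:
import jax
import jax.numpy as jnp

@jax.jit
def step(state, drive):
    # One update of a toy linear dynamical system.
    return 0.9 * state + drive

state = jnp.zeros(4)
for i in range(100):
    # The first call compiles; later calls reuse the compiled code.
    state = step(state, jnp.full(4, 0.01 * i))
print(state)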
Single and Multi Computer Speedup Using Parallelization#
Another benefit of JAX is the simplicity of parallelizing computation using
pmap
and vmap
function calls or decorators.
In Dynamax, state space models are parallelized with a vmap decorator; a practical example of this use case is multi-object tracking.
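A minimal sketch of the vmap pattern (the per-object function here is hypothetical, not Dynamax code): a single-example function is vectorized over a batch without being rewritten:
import jax
import jax.numpy as jnp

def track_one(observation, gain):
    # Toy per-object update, for illustration only.
    return gain * observation

observations = jnp.arange(12.0).reshape(4, 3)  # 4 objects, 3 features each
track_all = jax.vmap(track_one, in_axes=(0, None))  # map over objects, share the gain
print(track_all(observations, 0.5).shape)  # (4, 3)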
Incorporating JAX code into your, or your users', workflows#
JAX is quite composable and can be used in multiple ways. It can be used standalone, where the user defines all the calculations themselves. Alternatively, it can be used through libraries built on JAX that provide specific functionality: libraries that define particular types of models, such as neural networks or state space models, or libraries that provide specific capabilities such as optimization. Here are more specific examples of each pattern.
Direct Usage#
JAX can be directly imported and used to build models “from scratch” as shown across this website, for example in JAX Tutorials or Neural Network with JAX. This may be the best option if you are unable to find prebuilt code for your particular challenge, or if you’re looking to reduce the number of dependencies in your codebase.
Composable Domain Specific Libraries with JAX exposed#
Another common approach is to use packages that provide prebuilt functionality, whether for model definition or for some type of computation. Combinations of these packages can then be mixed and matched for a full end-to-end workflow where a model is defined and its parameters are estimated.
One example is Flax, which simplifies the construction of neural networks. Flax is typically paired with Optax, where Flax defines the neural network architecture and Optax supplies the optimization and model-fitting capabilities.
Another is Dynamax, which allows easy definition of state space models. With Dynamax, parameters can be estimated using maximum likelihood with Optax, or a full Bayesian posterior can be estimated using MCMC from Blackjax.
Notes#
This section contains shorter notes on topics relevant to using JAX; see also the longer design discussions in JAX Enhancement Proposals (JEPs).
- Dependencies and version compatibility:
API compatibility outlines JAX’s policies with regard to API compatibility across releases.
Python and NumPy version support policy outlines JAX’s policies with regard to compatibility with Python and NumPy.
- Migrations and deprecations:
jax.Array migration summarizes the changes to the default array type in jax v 0.4.1
- Memory and computation usage:
Asynchronous dispatch describes JAX’s asynchronous dispatch model.
Concurrency describes how JAX interacts with other Python concurrency.
GPU memory allocation describes how JAX interacts with memory allocation on GPU.
- Programmer guardrails:
Rank promotion warning describes how to configure
jax.numpy
to avoid implicit rank promotion.
API compatibility#
JAX is constantly evolving, and we want to be able to make improvements to its APIs. That said, we want to minimize churn for the JAX user community, and we try to make breaking changes rarely.
JAX follows a 3 month deprecation policy. When an incompatible change is made to an API, we will make our best effort to obey the following procedure:
the change will be announced in CHANGELOG.md and in the docstring for the deprecated API, and the old API will issue a DeprecationWarning.
three months after the jax release that deprecated an API, we may remove the deprecated API at any time. Note that three months is a lower bound, and is intentionally chosen to be faster than that of many more mature projects. In practice, deprecations may take considerably longer, particularly if there are many users of a feature. If a three-month deprecation period becomes problematic, please raise this with us.
We reserve the right to change this policy at any time.
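One way to catch uses of deprecated APIs early, for example in a test suite, is to escalate DeprecationWarnings into errors with Python's standard warnings machinery (a generic Python pattern, not a JAX-specific API):
import warnings

# Fail fast as soon as a deprecated API is used, well before the
# three-month removal window expires.
warnings.simplefilter("error", DeprecationWarning)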
What is covered?#
Only public JAX APIs are covered, which includes the following modules:
jax
jax.dlpack
jax.image
jax.lax
jax.nn
jax.numpy
jax.ops
jax.profiler
jax.random (see details below)
jax.scipy
jax.tree_util
jax.test_util
Not everything in these modules is public. Over time, we are working to separate public and private APIs. Public APIs are documented in the JAX documentation. Additionally, our goal is that all non-public APIs should have names prefixed with underscores, although we do not entirely comply with this yet.
What is not covered?#
anything prefixed with an underscore.
jax._src
jax.core
jax.linear_util
jax.lib
jax.prng
jax.interpreters
jax.experimental
jax.example_libraries
jax.extend
(see details)
This list is not exhaustive.
Numerics and randomness#
The exact values of numerical operations are not guaranteed to be
stable across JAX releases. In fact, exact numerics are not
necessarily stable at a given JAX version, across accelerator
platforms, within or without jax.jit
, and more.
For a fixed PRNG key input, the outputs of pseudorandom functions in
jax.random
may vary across JAX versions. The compatibility policy
applies only to the output distribution. For example, the expression
jax.random.gumbel(jax.random.key(72))
may return a different value
across JAX releases, but jax.random.gumbel
will remain a
pseudorandom generator for the Gumbel distribution.
We try to make such changes to pseudorandom values infrequently. When they happen, the changes are announced in the changelog, but do not follow a deprecation cycle. In some situations, JAX might expose a transient configuration flag that reverts the new behavior, to help users diagnose and update affected code. Such flags will last a deprecation window’s amount of time.
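To make the guarantee concrete: with a fixed key, re-evaluation in the same environment returns the same value, while a future JAX release may return a different draw from the same distribution. A small sketch:
import jax

key = jax.random.key(72)
sample = jax.random.gumbel(key)
# Same key, same environment: the identical value is returned.
# A future JAX release may return a different value, but it will
# still be a draw from the Gumbel distribution.
assert sample == jax.random.gumbel(key)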
Python and NumPy version support policy#
For NumPy and SciPy version support, JAX follows the Python scientific community’s SPEC 0.
For Python version support, we have heard from users that a 36-month support window can be too short, for example due to the delays in propagation of new CPython releases to Linux vendor releases. For this reason JAX supports Python versions for at least nine months longer than SPEC-0 recommends.
This means we support at least:
All minor Python releases in the 45 months prior to each JAX release. For example:
Python 3.9 was released October 2020, and will be supported in new JAX releases at least until July 2024.
Python 3.10 was released October 2021, and will be supported in new JAX releases at least until July 2025.
Python 3.11 was released October 2022, and will be supported in new JAX releases at least until July 2026.
All minor NumPy releases in the 24 months prior to each JAX release. For example:
NumPy 1.22 was released December 2021, and will be supported in new JAX releases at least until December 2023.
NumPy 1.23 was released June 2022, and will be supported in new JAX releases at least until June 2024.
NumPy 1.24 was released December 2022, and will be supported in new JAX releases at least until December 2024.
All minor SciPy releases in the 24 months prior to each JAX release, starting with SciPy version 1.9. For example:
Scipy 1.9 was released July 2022, and will be supported in new JAX releases at least until July 2024.
Scipy 1.10 was released January 2023, and will be supported in new JAX releases at least until January 2025.
Scipy 1.11 was released June 2023, and will be supported in new JAX releases at least until June 2025.
JAX releases may support older versions of Python, NumPy, and SciPy than strictly required by this policy, but support for older versions may be dropped at any time beyond the listed dates.
jax.Array migration#
yashkatariya@
TL;DR#
JAX switched its default array implementation to the new jax.Array
as of version 0.4.1.
This guide explains the reasoning behind this, the impact it might have on your code,
and how to (temporarily) switch back to the old behavior.
What’s going on?#
jax.Array
is a unified array type that subsumes DeviceArray
, ShardedDeviceArray
,
and GlobalDeviceArray
types in JAX. The jax.Array
type helps make parallelism a
core feature of JAX, simplifies and unifies JAX internals, and allows us to
unify jit and pjit. If your code doesn’t mention DeviceArray
vs
ShardedDeviceArray
vs GlobalDeviceArray
, no changes are needed. But code that
depends on details of these separate classes may need to be tweaked to work with
the unified jax.Array. After the migration is complete, jax.Array will be the only type of array in JAX.
This doc explains how to migrate existing codebases to jax.Array
. For more information on using jax.Array
and JAX parallelism APIs, see the Distributed arrays and automatic parallelization tutorial.
How to enable jax.Array?#
You can enable jax.Array
by:
setting the shell environment variable JAX_ARRAY to something true-like (e.g., 1);
setting the boolean flag jax_array to something true-like if your code parses flags with absl;
using this statement at the top of your main file:
import jax
jax.config.update('jax_array', True)
How do I know if jax.Array broke my code?#
The easiest way to tell if jax.Array
is responsible for any problems is to
disable jax.Array
and see if the issues go away.
How can I disable jax.Array for now?#
Through March 15, 2023 it will be possible to disable jax.Array by:
setting the shell environment variable JAX_ARRAY to something falsey (e.g., 0);
setting the boolean flag jax_array to something falsey if your code parses flags with absl;
using this statement at the top of your main file:
import jax
jax.config.update('jax_array', False)
Why create jax.Array?#
Currently JAX has three types: DeviceArray, ShardedDeviceArray, and GlobalDeviceArray. jax.Array merges these three types and cleans up JAX’s internals while adding new parallelism features.
We also introduce a new Sharding
abstraction that describes how a logical
Array is physically sharded out across one or more devices, such as TPUs or
GPUs. The change also upgrades, simplifies and merges the parallelism features
of pjit
into jit
. Functions decorated with jit
will be able to operate
over sharded arrays without copying data onto a single device.
Features you get with jax.Array
:
C++ pjit dispatch path
Op-by-op parallelism (even if the array is distributed across multiple devices across multiple hosts)
Simpler batch data parallelism with pjit/jit.
Ways to create Shardings that do not necessarily consist of a mesh and partition spec. You can fully utilize the flexibility of OpSharding if you want, or any other Sharding that you want.
and many more
Example:
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
import numpy as np
x = jnp.arange(8)
# Let's say there are 8 devices in jax.devices()
mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x'))
sharded_x = jax.device_put(x, sharding)
# `matmul_sharded_x` and `sin_sharded_x` are sharded. `jit` is able to operate over a
# sharded array without copying data to a single device.
matmul_sharded_x = sharded_x @ sharded_x.T
sin_sharded_x = jnp.sin(sharded_x)
# Even jnp.copy preserves the sharding on the output.
copy_sharded_x = jnp.copy(sharded_x)
# double_out is also sharded
double_out = jax.jit(lambda x: x * 2)(sharded_x)
What issues can arise when jax.Array is switched on?#
New public type named jax.Array#
All isinstance(..., jnp.DeviceArray)
or isinstance(..., jax.xla.DeviceArray)
and other variants of DeviceArray
should be switched to using isinstance(..., jax.Array)
.
Since jax.Array
can represent DA, SDA and GDA, you can differentiate those 3
types in jax.Array
via:
x.is_fully_addressable and len(x.sharding.device_set) == 1 – this means that the jax.Array is like a DA
x.is_fully_addressable and len(x.sharding.device_set) > 1 – this means that the jax.Array is like a SDA
not x.is_fully_addressable – this means that the jax.Array is like a GDA and spans across multiple processes
For ShardedDeviceArray
, you can move isinstance(..., pxla.ShardedDeviceArray)
to isinstance(..., jax.Array) and x.is_fully_addressable and len(x.sharding.device_set) > 1
.
In general it is not possible to differentiate a ShardedDeviceArray
on 1
device from any other kind of single-device Array.
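For example, the checks above could be wrapped in a small hypothetical helper:
import jax

def array_kind(x: jax.Array) -> str:
    # Hypothetical helper mirroring the checks described above.
    if not x.is_fully_addressable:
        return "GDA-like"  # spans multiple processes
    if len(x.sharding.device_set) > 1:
        return "SDA-like"  # sharded over several local devices
    return "DA-like"       # single-device array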
GDA’s API name changes#
GDA’s local_shards
and local_data
have been deprecated.
Please use addressable_shards
and addressable_data
which are compatible with
jax.Array
and GDA
.
Creating jax.Array#
All JAX functions will output jax.Array
when the jax_array
flag is True. If
you were using GlobalDeviceArray.from_callback
or make_sharded_device_array
or make_device_array
functions to explicitly create the respective JAX data
types, you will need to switch them to use jax.make_array_from_callback()
or jax.make_array_from_single_device_arrays()
.
For GDA:
GlobalDeviceArray.from_callback(shape, mesh, pspec, callback)
can become
jax.make_array_from_callback(shape, jax.sharding.NamedSharding(mesh, pspec), callback)
in a 1:1 switch.
If you were using the raw GDA constructor to create GDAs, then do this:
GlobalDeviceArray(shape, mesh, pspec, buffers)
can become
jax.make_array_from_single_device_arrays(shape, jax.sharding.NamedSharding(mesh, pspec), buffers)
For SDA:
make_sharded_device_array(aval, sharding_spec, device_buffers, indices)
can
become jax.make_array_from_single_device_arrays(shape, sharding, device_buffers)
.
To decide what the sharding should be, it depends on why you were creating the SDAs:
If it was created to give as an input to pmap
, then sharding can be:
jax.sharding.PmapSharding(devices, sharding_spec)
.
If it was created to give as an input
to pjit
, then sharding can be jax.sharding.NamedSharding(mesh, pspec)
.
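As a runnable sketch of the callback-based replacement, jax.make_array_from_callback builds a sharded jax.Array from per-shard data (this example assumes 8 local devices; adjust the mesh shape to your device count):
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

global_shape = (8, 2)
global_data = np.arange(np.prod(global_shape)).reshape(global_shape)

mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))

arr = jax.make_array_from_callback(
    global_shape, sharding,
    # The callback receives the index (a tuple of slices) of each shard
    # and returns the corresponding piece of the global data.
    lambda index: global_data[index])
print(arr.shape, len(arr.addressable_shards))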
Breaking change for pjit after switching to jax.Array for host local inputs#
If you are exclusively using GDA arguments to pjit, you can skip this section! 🎉
With jax.Array
enabled, all inputs to pjit
must be globally shaped. This is
a breaking change from the previous behavior where pjit
would concatenate
process-local arguments into a global value; this concatenation no longer
occurs.
Why are we making this breaking change? Each array now says explicitly how its
local shards fit into a global whole, rather than leaving it implicit. The more
explicit representation also unlocks additional flexibility, for example the use
of non-contiguous meshes with pjit
which can improve efficiency on some TPU
models.
Running multi-process pjit computation and passing host-local inputs when
jax.Array
is enabled can lead to an error similar to this:
Example:
Mesh = {'x': 2, 'y': 2, 'z': 2}
and host local input shape == (4,)
and
pspec = P(('x', 'y', 'z'))
Since pjit
doesn’t lift host local shapes to global shapes with jax.Array
,
you get the following error:
Note: You will only see this error if your host local shape is smaller than the shape of the mesh.
ValueError: One of pjit arguments was given the sharding of
NamedSharding(mesh={'x': 2, 'y': 2, 'chips': 2}, partition_spec=PartitionSpec(('x', 'y', 'chips'),)),
which implies that the global size of its dimension 0 should be divisible by 8,
but it is equal to 4
The error makes sense because you can’t shard dimension 0, 8 ways when the value
on dimension 0
is 4
.
How can you migrate if you still pass host local inputs to pjit
? We are
providing transitional APIs to help you migrate:
Note: You don’t need these utilities if you run your pjitted computation on a single process.
from jax.experimental import multihost_utils
global_inps = multihost_utils.host_local_array_to_global_array(
local_inputs, mesh, in_pspecs)
global_outputs = pjit(f, in_shardings=in_pspecs,
out_shardings=out_pspecs)(global_inps)
local_outs = multihost_utils.global_array_to_host_local_array(
global_outputs, mesh, out_pspecs)
host_local_array_to_global_array
is a type cast that looks at a value with
only local shards and changes its local shape to the shape that pjit
would
have previously assumed if that value was passed before the change.
Passing in fully replicated inputs i.e. same shape on each process with
P(None)
as in_axis_resources
is still supported. In this case you do not
have to use host_local_array_to_global_array
because the shape is already
global.
key = jax.random.PRNGKey(1)
# As you can see, using host_local_array_to_global_array is not required since in_axis_resources says
# that the input is fully replicated via P(None)
pjit(f, in_shardings=None, out_shardings=None)(key)
# Mixing inputs
global_inp = multihost_utils.host_local_array_to_global_array(
local_inp, mesh, P('data'))
global_out = pjit(f, in_shardings=(P(None), P('data')),
out_shardings=...)(key, global_inp)
FROM_GDA and jax.Array#
If you were using FROM_GDA
in in_axis_resources
argument to pjit
, then
with jax.Array
there is no need to pass anything to in_axis_resources
as
jax.Array
will follow the “computation follows sharding” semantics.
For example:
pjit(f, in_shardings=FROM_GDA, out_shardings=...) can be replaced by pjit(f, out_shardings=...)
If you have PartitionSpecs mixed in with FROM_GDA
for inputs like numpy
arrays, etc, then use host_local_array_to_global_array
to convert them to
jax.Array
.
For example:
If you had this:
pjitted_f = pjit(
f, in_shardings=(FROM_GDA, P('x'), FROM_GDA, P(None)),
out_shardings=...)
pjitted_f(gda1, np_array1, gda2, np_array2)
then you can replace it with:
pjitted_f = pjit(f, out_shardings=...)
array2, array4 = multihost_utils.host_local_array_to_global_array(
    (np_array1, np_array2), mesh, (P('x'), P(None)))
pjitted_f(gda1, array2, gda2, array4)
live_buffers replaced with live_arrays#
live_buffers
attribute on jax Device
has been deprecated. Please use jax.live_arrays()
instead which is compatible
with jax.Array
.
Handling of host local inputs to pjit like batch, etc#
If you are passing host local inputs to pjit
in a multi-process
environment, then please use
multihost_utils.host_local_array_to_global_array
to convert the batch to a
global jax.Array
and then pass that to pjit
.
The most common example of such a host local input is a batch of input data.
This will work for any host local input (not just a batch of input data).
from jax.experimental import multihost_utils
batch = multihost_utils.host_local_array_to_global_array(
batch, mesh, batch_partition_spec)
See the pjit section above for more details about this change and more examples.
RecursionError: Recursively calling jit#
This happens when some part of your code has jax.Array
disabled and then you
enable it only for some other part. For example, if you use some third-party
code which has jax.Array
disabled and you get a DeviceArray
from that
library and then you enable jax.Array
in your library and pass that
DeviceArray
to JAX functions, it will lead to a RecursionError.
This error should go away when jax.Array
is enabled by default so that all
libraries return jax.Array
unless they explicitly disable it.
Asynchronous dispatch#
JAX uses asynchronous dispatch to hide Python overheads. Consider the following program:
>>> import numpy as np
>>> import jax.numpy as jnp
>>> from jax import random
>>> x = random.uniform(random.key(0), (1000, 1000))
>>> # Printing the result (i.e. evaluating `repr(result)` or `str(result)`)
>>> # will block until the value is ready.
>>> jnp.dot(x, x) + 3.
Array([[258.01971436, 249.64862061, 257.13372803, ...,
236.67948914, 250.68939209, 241.36853027],
[265.65979004, 256.28912354, 262.18252563, ...,
242.03181458, 256.16757202, 252.44122314],
[262.38916016, 255.72747803, 261.23059082, ...,
240.83563232, 255.41094971, 249.62471008],
...,
[259.15814209, 253.09197998, 257.72174072, ...,
242.23876953, 250.72680664, 247.16642761],
[271.22662354, 261.91204834, 265.33398438, ...,
248.26651001, 262.05389404, 261.33700562],
[257.16134644, 254.7543335, 259.08300781, ..., 241.59848022,
248.62597656, 243.22348022]], dtype=float32)
When an operation such as jnp.dot(x, x)
is executed, JAX does not wait
for the operation to complete before returning control to the Python program.
Instead, JAX returns a jax.Array
value, which is a future,
i.e., a value that will be produced in the future on an accelerator device but
isn’t necessarily available immediately. We can inspect the shape or type of a
jax.Array
without waiting for the computation that produced it to
complete, and we can even pass it to another JAX computation, as we do with the
addition operation here. Only if we actually inspect the value of the array from
the host, for example by printing it or by converting it into a plain old
numpy.ndarray, will JAX force the Python code to wait for the
computation to complete.
Asynchronous dispatch is useful since it allows Python code to “run ahead” of an accelerator device, keeping Python code out of the critical path. Provided the Python code enqueues work on the device faster than it can be executed, and provided that the Python code does not actually need to inspect the output of a computation on the host, then a Python program can enqueue arbitrary amounts of work and avoid having the accelerator wait.
Asynchronous dispatch has a slightly surprising consequence for microbenchmarks.
>>> %time jnp.dot(x, x)
CPU times: user 267 µs, sys: 93 µs, total: 360 µs
Wall time: 269 µs
Array([[255.01972961, 246.64862061, 254.13371277, ...,
233.67948914, 247.68939209, 238.36853027],
[262.65979004, 253.28910828, 259.18252563, ...,
239.03181458, 253.16757202, 249.44122314],
[259.38916016, 252.72747803, 258.23059082, ...,
237.83563232, 252.41094971, 246.62471008],
...,
[256.15814209, 250.09197998, 254.72172546, ...,
239.23876953, 247.72680664, 244.16642761],
[268.22662354, 258.91204834, 262.33398438, ...,
245.26651001, 259.05389404, 258.33700562],
[254.16134644, 251.7543335, 256.08300781, ..., 238.59848022,
245.62597656, 240.22348022]], dtype=float32)
269µs is a surprisingly small time for a 1000x1000 matrix multiplication on CPU!
However it turns out that asynchronous dispatch is misleading us and we are not
timing the execution of the matrix multiplication, only the time to dispatch
the work. To measure the true cost of the operation we must either read the
value on the host (e.g., convert it to a plain old host-side numpy array), or
use the block_until_ready()
method on a
jax.Array
value to wait for the computation that produced it to
complete.
>>> %time np.asarray(jnp.dot(x, x))
CPU times: user 61.1 ms, sys: 0 ns, total: 61.1 ms
Wall time: 8.09 ms
array([[255.01973, 246.64862, 254.13371, ..., 233.67949, 247.68939,
238.36853],
[262.6598 , 253.28911, 259.18253, ..., 239.03181, 253.16757,
249.44122],
[259.38916, 252.72748, 258.2306 , ..., 237.83563, 252.41095,
246.62471],
...,
[256.15814, 250.09198, 254.72173, ..., 239.23877, 247.7268 ,
244.16643],
[268.22662, 258.91205, 262.33398, ..., 245.26651, 259.0539 ,
258.337 ],
[254.16135, 251.75433, 256.083 , ..., 238.59848, 245.62598,
240.22348]], dtype=float32)
>>> %time jnp.dot(x, x).block_until_ready()
CPU times: user 50.3 ms, sys: 928 µs, total: 51.2 ms
Wall time: 4.92 ms
Array([[255.01972961, 246.64862061, 254.13371277, ...,
233.67948914, 247.68939209, 238.36853027],
[262.65979004, 253.28910828, 259.18252563, ...,
239.03181458, 253.16757202, 249.44122314],
[259.38916016, 252.72747803, 258.23059082, ...,
237.83563232, 252.41094971, 246.62471008],
...,
[256.15814209, 250.09197998, 254.72172546, ...,
239.23876953, 247.72680664, 244.16642761],
[268.22662354, 258.91204834, 262.33398438, ...,
245.26651001, 259.05389404, 258.33700562],
[254.16134644, 251.7543335, 256.08300781, ..., 238.59848022,
245.62597656, 240.22348022]], dtype=float32)
Blocking without transferring the result back to Python is usually faster, and is often the best choice when writing microbenchmarks of computation times.
Concurrency#
JAX has limited support for Python concurrency.
Clients may call JAX APIs (e.g., jit()
or grad()
)
concurrently from separate Python threads.
It is not permitted to manipulate JAX trace values concurrently from multiple
threads. In other words, while it is permissible to call functions that use JAX
tracing (e.g., jit()
) from multiple threads, you must not use
threading to manipulate JAX values inside the implementation of the function
f that is passed to jit()
. The most likely outcome if you do this
is a mysterious error from JAX.
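For example, it is fine to call the same jitted function from a thread pool, as long as each thread only passes in and reads back array values (a minimal sketch):
from concurrent.futures import ThreadPoolExecutor

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sin(x) ** 2

inputs = [jnp.full((100,), i, dtype=jnp.float32) for i in range(8)]
with ThreadPoolExecutor(max_workers=4) as pool:
    # Each call is independent; no JAX values are mutated across threads.
    results = list(pool.map(f, inputs))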
GPU memory allocation#
JAX will preallocate 75% of the total GPU memory when the first JAX operation is run. Preallocating minimizes allocation overhead and memory fragmentation, but can sometimes cause out-of-memory (OOM) errors. If your JAX process fails with OOM, the following environment variables can be used to override the default behavior:
XLA_PYTHON_CLIENT_PREALLOCATE=false
This disables the preallocation behavior. JAX will instead allocate GPU memory as needed, potentially decreasing the overall memory usage. However, this behavior is more prone to GPU memory fragmentation, meaning a JAX program that uses most of the available GPU memory may OOM with preallocation disabled.
XLA_PYTHON_CLIENT_MEM_FRACTION=.XX
If preallocation is enabled, this makes JAX preallocate XX% of the total GPU memory, instead of the default 75%. Lowering the amount preallocated can fix OOMs that occur when the JAX program starts.
XLA_PYTHON_CLIENT_ALLOCATOR=platform
This makes JAX allocate exactly what is needed on demand, and deallocate memory that is no longer needed (note that this is the only configuration that will deallocate GPU memory, instead of reusing it). This is very slow, so is not recommended for general use, but may be useful for running with the minimal possible GPU memory footprint or debugging OOM failures.
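These variables are read when JAX first initializes its GPU backend, so one common pattern is to set them in the shell (e.g. XLA_PYTHON_CLIENT_PREALLOCATE=false python my_script.py, where my_script.py stands for your own program) or at the very top of the program, before importing JAX:
import os

# Choose one of the overrides described above *before* importing JAX.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# or: os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

import jax
print(jax.devices())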
Common causes of OOM failures#
- Running multiple JAX processes concurrently.
Either use
XLA_PYTHON_CLIENT_MEM_FRACTION
to give each process an appropriate amount of memory, or set XLA_PYTHON_CLIENT_PREALLOCATE=false.
- Running JAX and GPU TensorFlow concurrently.
TensorFlow also preallocates by default, so this is similar to running multiple JAX processes concurrently.
One solution is to use CPU-only TensorFlow (e.g. if you’re only doing data loading with TF). You can prevent TensorFlow from using the GPU with the command
tf.config.experimental.set_visible_devices([], "GPU")
Alternatively, use XLA_PYTHON_CLIENT_MEM_FRACTION or XLA_PYTHON_CLIENT_PREALLOCATE. There are also similar options to configure TensorFlow’s GPU memory allocation (gpu_memory_fraction and allow_growth in TF1, which should be set in a tf.ConfigProto passed to tf.Session. See Using GPUs: Limiting GPU memory growth for TF2).
- Running JAX on the display GPU.
Use
XLA_PYTHON_CLIENT_MEM_FRACTION
or XLA_PYTHON_CLIENT_PREALLOCATE.
Rank promotion warning#
NumPy broadcasting rules allow the automatic promotion of arguments from one rank (number of array axes) to another. This behavior can be convenient when intended but can also lead to surprising bugs where a silent rank promotion masks an underlying shape error.
Here’s an example of rank promotion:
>>> import numpy as np
>>> x = np.arange(12).reshape(4, 3)
>>> y = np.array([0, 1, 0])
>>> x + y
array([[ 0, 2, 2],
[ 3, 5, 5],
[ 6, 8, 8],
[ 9, 11, 11]])
To avoid potential surprises, jax.numpy
is configurable so that
expressions requiring rank promotion can lead to a warning, error, or can be
allowed just like regular NumPy. The configuration option is named
jax_numpy_rank_promotion
and it can take on string values
allow
, warn
, and raise
. The default setting is
allow
, which allows rank promotion without warning or error.
The raise
setting raises an error on rank promotion, and warn
raises a warning on the first occurrence of rank promotion.
Rank promotion can be enabled or disabled locally with the jax.numpy_rank_promotion()
context manager:
with jax.numpy_rank_promotion("warn"):
  z = x + y
This configuration can also be set globally in several ways.
One is by using jax.config
in your code:
import jax
jax.config.update("jax_numpy_rank_promotion", "warn")
You can also set the option using the environment variable
JAX_NUMPY_RANK_PROMOTION
, for example as
JAX_NUMPY_RANK_PROMOTION='warn'
. Finally, when using absl-py
the option can be set with a command-line flag.
Public API: jax package#
Subpackages#
jax.numpy
module#
Implements the NumPy API, using the primitives in jax.lax
.
While JAX tries to follow the NumPy API as closely as possible, sometimes JAX cannot follow NumPy exactly.
Notably, since JAX arrays are immutable, NumPy APIs that mutate arrays in-place cannot be implemented in JAX. However, often JAX is able to provide an alternative API that is purely functional. For example, instead of in-place array updates (x[i] = y), JAX provides an alternative pure indexed update function x.at[i].set(y) (see ndarray.at); a short example follows this list.
Relatedly, some NumPy functions often return views of arrays when possible (examples are transpose() and reshape()). JAX versions of such functions will return copies instead, although such copies are often optimized away by XLA when sequences of operations are compiled using jax.jit().
NumPy is very aggressive at promoting values to float64 type. JAX sometimes is less aggressive about type promotion (see Type promotion semantics).
Some NumPy routines have data-dependent output shapes (examples include unique() and nonzero()). Because the XLA compiler requires array shapes to be known at compile time, such operations are not compatible with JIT. For this reason, JAX adds an optional size argument to such functions, which may be specified statically in order to use them with JIT.
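As a brief illustration of the indexed-update pattern mentioned in the first item above:
import jax.numpy as jnp

x = jnp.zeros(5)
# x[2] = 7.0 would fail, since JAX arrays are immutable.
y = x.at[2].set(7.0)    # returns a new array; x is unchanged
z = y.at[1:3].add(1.0)  # other update operators include add, mul, min, max
print(x, y, z)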
Nearly all applicable NumPy functions are implemented in the jax.numpy
namespace; they are listed below.
Helper property for index update functionality. |
|
|
Calculate the absolute value element-wise. |
|
Calculate the absolute value element-wise. |
|
Trigonometric inverse cosine, element-wise. |
|
Inverse hyperbolic cosine, element-wise. |
|
Add arguments element-wise. |
|
Test whether all array elements along a given axis evaluate to True. |
|
Returns True if two arrays are element-wise equal within a tolerance. |
|
Return the maximum of an array or maximum along an axis. |
|
Return the minimum of an array or minimum along an axis. |
|
Return the angle of the complex argument. |
|
Test whether any array element along a given axis evaluates to True. |
|
Append values to the end of an array. |
|
Apply a function to 1-D slices along the given axis. |
|
Apply a function repeatedly over multiple axes. |
|
Return evenly spaced values within a given interval. |
|
Trigonometric inverse cosine, element-wise. |
|
Inverse hyperbolic cosine, element-wise. |
|
Inverse sine, element-wise. |
|
Inverse hyperbolic sine element-wise. |
|
Trigonometric inverse tangent, element-wise. |
|
Element-wise arc tangent of |
|
Inverse hyperbolic tangent element-wise. |
|
Returns the indices of the maximum values along an axis. |
|
Returns the indices of the minimum values along an axis. |
|
Perform an indirect partition along the given axis using the |
|
Returns the indices that would sort an array. |
|
Find the indices of nonzero array elements |
|
Round an array to the given number of decimals. |
|
Create an array. |
|
True if two arrays have the same shape and elements, False otherwise. |
|
Returns True if input arrays are shape consistent and all elements equal. |
|
Return the string representation of an array. |
|
Split an array into multiple sub-arrays. |
|
Return a string representation of the data in an array. |
|
Convert the input to an array. |
|
Inverse sine, element-wise. |
|
Inverse hyperbolic sine element-wise. |
|
This is implemented via |
|
Trigonometric inverse tangent, element-wise. |
|
Inverse hyperbolic tangent element-wise. |
|
Element-wise arc tangent of |
Convert inputs to arrays with at least one dimension. |
|
View inputs as arrays with at least two dimensions. |
|
View inputs as arrays with at least three dimensions. |
|
|
Compute the weighted average along the specified axis. |
|
Return the Bartlett window. |
|
Count number of occurrences of each value in array of non-negative ints. |
|
Compute the bit-wise AND of two arrays element-wise. |
|
|
|
Compute bit-wise inversion, or bit-wise NOT, element-wise. |
|
Shift the bits of an integer to the left. |
|
Compute bit-wise inversion, or bit-wise NOT, element-wise. |
|
Compute the bit-wise OR of two arrays element-wise. |
|
Shift the bits of an integer to the right. |
|
Compute the bit-wise XOR of two arrays element-wise. |
|
Return the Blackman window. |
|
Assemble an nd-array from nested lists of blocks. |
|
|
|
Broadcast any number of arrays against each other. |
Broadcast the input shapes into a single shape. |
|
|
Broadcast an array to a new shape. |
Concatenate slices, scalars and array-like objects along the last axis. |
|
|
Returns True if cast between data types can occur according to the casting rule. |
|
Return the cube-root of an array, element-wise. |
alias of |
|
|
Return the ceiling of the input, element-wise. |
Abstract base class of all character string scalar types. |
|
|
Construct an array from an index array and a list of arrays to choose from. |
|
Clip (limit) the values in an array. |
|
Stack 1-D arrays as columns into a 2-D array. |
alias of |
|
|
|
|
|
Abstract base class of all complex number scalar types that are made up of floating-point numbers. |
|
The warning raised when casting a complex dtype to a real dtype. |
|
|
Compress an array along a given axis using a boolean condition. |
|
|
|
Join a sequence of arrays along an existing axis. |
|
Return the complex conjugate, element-wise. |
|
Return the complex conjugate, element-wise. |
|
Returns the discrete, linear convolution of two one-dimensional sequences. |
|
Return an array copy of the given object. |
|
Change the sign of x1 to that of x2, element-wise. |
|
Return Pearson product-moment correlation coefficients. |
|
Cross-correlation of two 1-dimensional sequences. |
|
Cosine element-wise. |
|
Hyperbolic cosine, element-wise. |
|
Counts the number of non-zero values in the array |
|
Estimate a covariance matrix, given data and weights. |
|
Return the cross product of two (arrays of) vectors. |
alias of |
|
|
Return the cumulative product of elements along a given axis. |
|
Return the cumulative sum of the elements along a given axis. |
|
|
|
Convert angles from degrees to radians. |
|
Convert angles from radians to degrees. |
|
Delete entry or entries from an array. |
|
Extract a diagonal or construct a diagonal array. |
|
Return the indices to access the main diagonal of an array. |
|
Return the indices to access the main diagonal of an n-dimensional array. |
|
Create a two-dimensional array with the flattened input as a diagonal. |
|
Return specified diagonals. |
|
Calculate the n-th discrete difference along the given axis. |
|
Return the indices of the bins to which each value in input array belongs. |
|
Divide arguments element-wise. |
|
Return element-wise quotient and remainder simultaneously. |
|
Compute the dot product of two arrays. |
alias of |
|
|
Split array into multiple sub-arrays along the 3rd axis (depth). |
|
Stack arrays in sequence depth wise (along third axis). |
|
Create a data type object. |
|
The differences between consecutive elements of an array. |
|
Evaluates the Einstein summation convention on the operands. |
|
Evaluates the lowest cost contraction order for an einsum expression by |
|
Return a new array of given shape and type, without initializing entries. |
|
Return a new array with the same shape and type as a given array. |
|
Return (x1 == x2) element-wise. |
|
Calculate the exponential of all elements in the input array. |
|
Calculate 2**p for all p in the input array. |
|
Expand the shape of an array. |
|
Calculate |
|
Return the elements of an array that satisfy a condition. |
|
Return a 2-D array with ones on the diagonal and zeros elsewhere. |
|
Compute the absolute values element-wise. |
|
Fill the main diagonal of the given array of any dimensionality. |
|
Machine limits for floating point types. |
|
Round to nearest integer towards zero. |
|
Return indices of nonzero elements in a flattened array |
|
Abstract base class of all scalar types without predefined length. |
|
Reverse the order of elements in an array along the given axis. |
|
Reverse the order of elements along axis 1 (left/right). |
|
Reverse the order of elements along axis 0 (up/down). |
alias of |
|
|
First array elements raised to powers from second array, element-wise. |
|
|
|
|
|
|
|
Abstract base class of all floating-point scalar types. |
|
Return the floor of the input, element-wise. |
|
Return the largest integer smaller or equal to the division of the inputs. |
|
Element-wise maximum of array elements. |
|
Element-wise minimum of array elements. |
|
Returns the element-wise remainder of division. |
|
Decompose the elements of x into mantissa and twos exponent. |
|
Interpret a buffer as a 1-dimensional array. |
|
Unimplemented JAX wrapper for jnp.fromfile. |
|
Construct an array by executing a function over each coordinate. |
|
Unimplemented JAX wrapper for jnp.fromiter. |
|
Create a JAX ufunc from an arbitrary JAX-compatible scalar function. |
|
A new 1-D array initialized from text data in a string. |
|
Create a NumPy array from an object implementing the |
|
Return a new array of given shape and type, filled with fill_value. |
|
Return a full array with the same shape and type as a given array. |
|
Returns the greatest common divisor of |
|
Base class for numpy scalar types. |
|
Return numbers spaced evenly on a log scale (a geometric progression). |
Return the current print options. |
|
|
Return the gradient of an N-dimensional array. |
|
Return the truth value of (x1 > x2) element-wise. |
|
Return the truth value of (x1 >= x2) element-wise. |
|
Return the Hamming window. |
|
Return the Hanning window. |
|
Compute the Heaviside step function. |
|
Compute the histogram of a dataset. |
|
Function to calculate only the edges of the bins used by the histogram |
|
Compute the bi-dimensional histogram of two data samples. |
|
Compute the multidimensional histogram of some data. |
|
Split an array into multiple sub-arrays horizontally (column-wise). |
|
Stack arrays in sequence horizontally (column wise). |
|
Given the "legs" of a right triangle, return its hypotenuse. |
Modified Bessel function of the first kind, order 0. |
|
|
Return the identity array. |
|
|
|
Return the imaginary part of the complex argument. |
A nicer way to build up index tuples for arrays. |
|
|
Return an array representing the indices of a grid. |
|
Abstract base class of all numeric scalar types with a (potentially) inexact representation of the values in its range, such as floating-point numbers. |
|
Compute the inner product of two arrays. |
|
Insert values along the given axis before the given indices. |
alias of |
|
|
|
|
|
|
|
|
|
|
Abstract base class of all integer scalar types. |
|
One-dimensional linear interpolation for monotonically increasing sample points. |
|
Find the intersection of two arrays. |
|
Compute bit-wise inversion, or bit-wise NOT, element-wise. |
|
Returns a boolean array where two arrays are element-wise equal within a |
|
Returns a bool array, where True if input element is complex. |
|
Check for a complex type or an array of complex numbers. |
|
Returns a boolean indicating whether a provided dtype is of a specified kind. |
|
Test element-wise for finiteness (not infinity and not Not a Number). |
|
Calculates |
|
Test element-wise for positive or negative infinity. |
|
Test element-wise for NaN and return result as a boolean array. |
|
Test element-wise for negative infinity, return result as bool array. |
|
Test element-wise for positive infinity, return result as bool array. |
|
Returns a bool array, where True if input element is real. |
|
Return True if x is a not complex type or an array of complex numbers. |
|
Returns True if the type of element is a scalar type. |
|
Returns True if first argument is a typecode lower/equal in type hierarchy. |
|
Check whether or not an object can be iterated over. |
|
Return a multi-dimensional grid (open mesh) from N one-dimensional sequences. |
|
Return the Kaiser window. |
|
Kronecker product of two arrays. |
|
Returns the lowest common multiple of |
|
Returns x1 * 2**x2, element-wise. |
|
Shift the bits of an integer to the left. |
|
Return the truth value of (x1 < x2) element-wise. |
|
Return the truth value of (x1 <= x2) element-wise. |
|
Perform an indirect stable sort using a sequence of keys. |
|
Return evenly spaced numbers over a specified interval. |
|
Load arrays or pickled objects from |
|
Natural logarithm, element-wise. |
|
Return the base 10 logarithm of the input array, element-wise. |
|
Return the natural logarithm of one plus the input array, element-wise. |
|
Base-2 logarithm of x. |
Logarithm of the sum of exponentiations of the inputs. |
|
Logarithm of the sum of exponentiations of the inputs in base-2. |
|
|
Compute the truth value of x1 AND x2 element-wise. |
|
Compute the truth value of NOT x element-wise. |
|
Compute the truth value of x1 OR x2 element-wise. |
|
Compute the truth value of x1 XOR x2, element-wise. |
|
Return numbers spaced evenly on a log scale. |
|
Return the indices to access (n, n) arrays, given a masking function. |
|
Perform a matrix multiplication. |
|
Transpose the last two dimensions of an array. |
|
Return the maximum of an array or maximum along an axis. |
|
Element-wise maximum of array elements. |
|
Compute the arithmetic mean along the specified axis. |
|
Compute the median along the specified axis. |
|
Return a list of coordinate matrices from coordinate vectors. |
Return dense multi-dimensional "meshgrid". |
|
|
Return the minimum of an array or minimum along an axis. |
|
Element-wise minimum of array elements. |
|
Returns the element-wise remainder of division. |
|
Return the fractional and integral parts of an array, element-wise. |
|
Move axes of an array to new positions. |
|
Multiply arguments element-wise. |
|
Replace NaN with zero and infinity with large finite numbers (default |
|
Return the indices of the maximum values in the specified axis ignoring |
|
Return the indices of the minimum values in the specified axis ignoring |
|
Return the cumulative product of array elements over a given axis treating Not a |
|
Return the cumulative sum of array elements over a given axis treating Not a |
|
Return the maximum of an array or maximum along an axis, ignoring any |
|
Compute the arithmetic mean along the specified axis, ignoring NaNs. |
|
Compute the median along the specified axis, while ignoring NaNs. |
|
Return minimum of an array or minimum along an axis, ignoring any NaNs. |
|
Compute the qth percentile of the data along the specified axis, |
|
Return the product of array elements over a given axis treating Not a |
|
Compute the qth quantile of the data along the specified axis, |
|
Compute the standard deviation along the specified axis, while |
|
Return the sum of array elements over a given axis treating Not a |
|
Compute the variance along the specified axis, while ignoring NaNs. |
alias of |
|
|
Return the number of dimensions of an array. |
|
Numerical negative, element-wise. |
|
Return the next floating-point value after x1 towards x2, element-wise. |
|
Return indices of nonzero elements of an array. |
|
Return (x1 != x2) element-wise. |
|
Abstract base class of all numeric scalar types. |
Any Python object. |
|
Return open multi-dimensional "meshgrid". |
|
|
Return a new array of given shape and type, filled with ones. |
|
Return an array of ones with the same shape and type as a given array. |
|
Compute the outer product of two vectors. |
|
Packs the elements of a binary-valued array into bits in a uint8 array. |
|
Pad an array. |
|
Return a partitioned copy of an array. |
|
Compute the q-th percentile of the data along the specified axis. |
|
|
|
Evaluate a piecewise-defined function. |
|
Change elements of an array based on conditional and input values. |
|
Find the coefficients of a polynomial with the given sequence of roots. |
|
Find the sum of two polynomials. |
|
Return the derivative of the specified order of a polynomial. |
|
Returns the quotient and remainder of polynomial division. |
|
Least squares polynomial fit. |
|
Return an antiderivative (indefinite integral) of a polynomial. |
|
Find the product of two polynomials. |
|
Difference (subtraction) of two polynomials. |
|
Evaluate a polynomial at specific values. |
|
Numerical positive, element-wise. |
|
First array elements raised to powers from second array, element-wise. |
|
First array elements raised to powers from second array, element-wise. |
|
Context manager for setting print options. |
|
Return the product of array elements over a given axis. |
|
Returns the type to which a binary operation should cast its arguments. |
|
Range of values (maximum - minimum) along an axis. |
|
Replaces specified elements of an array with given values. |
|
Compute the q-th quantile of the data along the specified axis. |
Concatenate slices, scalars and array-like objects along the first axis. |
|
|
Convert angles from radians to degrees. |
|
Convert angles from degrees to radians. |
|
Flatten array into a 1-dimensional shape. |
|
Converts a tuple of index arrays into an array of flat |
|
Return the real part of the complex argument. |
|
Return the reciprocal of the argument, element-wise. |
|
Returns the element-wise remainder of division. |
|
Repeat each element of an array after itself. |
|
Return a reshaped copy of an array. |
|
Return a new array with the specified shape. |
|
Returns the type that results from applying the NumPy |
|
Shift the bits of an integer to the right. |
|
Round elements of the array to the nearest integer. |
|
Roll array elements along a given axis. |
|
Roll the specified axis backwards, until it lies in a given position. |
|
Return the roots of a polynomial with coefficients given in p. |
|
Rotate an array by 90 degrees in the plane specified by axes. |
|
Round an array to the given number of decimals. |
|
Round an array to the given number of decimals. |
A nicer way to build up index tuples for arrays. |
|
|
Save an array to a binary file in NumPy |
|
Save several arrays into a single file in uncompressed |
|
Find indices where elements should be inserted to maintain order. |
|
Return an array drawn from elements in choicelist, depending on conditions. |
|
Set printing options. |
|
Find the set difference of two arrays. |
|
Find the set exclusive-or of two arrays. |
|
Return the shape of an array. |
|
Returns an element-wise indication of the sign of a number. |
|
Returns element-wise True where signbit is set (less than zero). |
Abstract base class of all signed integer scalar types. |
|
|
Trigonometric sine, element-wise. |
|
Return the normalized sinc function. |
alias of |
|
|
Hyperbolic sine, element-wise. |
|
Return the number of elements along a given axis. |
|
Return a sorted copy of an array. |
|
Sort a complex array using the real part first, then the imaginary part. |
|
Split an array into multiple sub-arrays as views into ary. |
|
Return the non-negative square-root of an array, element-wise. |
|
Return the element-wise square of the input. |
|
Remove axes of length one from a. |
|
Join a sequence of arrays along a new axis. |
|
Compute the standard deviation along the specified axis. |
|
Subtract arguments, element-wise. |
|
Sum of array elements over a given axis. |
|
Interchange two axes of an array. |
|
Take elements from an array along an axis. |
|
Take values from the input array by matching 1d index and data slices. |
|
Compute tangent element-wise. |
|
Compute hyperbolic tangent element-wise. |
|
Compute the tensor dot product of two N-dimensional arrays. |
|
Construct an array by repeating A the number of times given by reps. |
|
Return the sum along diagonals of the array. |
|
Integrate along the given axis using the composite trapezoidal rule. |
|
Return a transposed version of an N-dimensional array. |
|
An array with ones at and below the given diagonal and zeros elsewhere. |
|
Lower triangle of an array. |
|
Return the indices for the lower-triangle of an (n, m) array. |
|
Return the indices for the lower-triangle of arr. |
|
Trim the leading and/or trailing zeros from a 1-D array or sequence. |
|
Upper triangle of an array. |
|
Return the indices for the upper-triangle of an (n, m) array. |
|
Return the indices for the upper-triangle of arr. |
|
Divide arguments element-wise. |
|
Return the truncated value of the input, element-wise. |
|
Functions that operate element-by-element on whole arrays. |
alias of |
|
|
|
|
|
|
|
|
|
|
Find the union of two arrays. |
|
Find the unique elements of an array. |
|
|
|
|
|
|
|
|
|
Unpacks elements of a uint8 array into a binary-valued output array. |
|
Converts a flat index or array of flat indices into a tuple |
|
|
Abstract base class of all unsigned integer scalar types. |
|
|
Unwrap by taking the complement of large deltas with respect to the period. |
|
Generate a Vandermonde matrix. |
|
Compute the variance along the specified axis. |
|
Perform a conjugate multiplication of two 1D vectors. |
|
Perform a conjugate multiplication of two batched vectors. |
|
Define a vectorized function with broadcasting. |
|
Split an array into multiple sub-arrays vertically (row-wise). |
|
Stack arrays in sequence vertically (row wise). |
|
Select elements from two arrays based on a condition. |
|
Return a new array of given shape and type, filled with zeros. |
|
Return an array of zeros with the same shape and type as a given array. |
jax.numpy.fft#
|
Compute the one-dimensional discrete Fourier Transform. |
|
Compute the 2-dimensional discrete Fourier Transform. |
|
Return the Discrete Fourier Transform sample frequencies. |
|
Compute the N-dimensional discrete Fourier Transform. |
|
Shift the zero-frequency component to the center of the spectrum. |
|
Compute the FFT of a signal that has Hermitian symmetry, i.e., a real |
|
Compute the one-dimensional inverse discrete Fourier Transform. |
|
Compute the 2-dimensional inverse discrete Fourier Transform. |
|
Compute the N-dimensional inverse discrete Fourier Transform. |
|
The inverse of fftshift. |
|
Compute the inverse FFT of a signal that has Hermitian symmetry. |
|
Computes the inverse of rfft. |
|
Computes the inverse of rfft2. |
|
Computes the inverse of rfftn. |
|
Compute the one-dimensional discrete Fourier Transform for real input. |
|
Compute the 2-dimensional FFT of a real array. |
|
Return the Discrete Fourier Transform sample frequencies |
|
Compute the N-dimensional discrete Fourier Transform for real input. |
jax.numpy.linalg#
|
Compute the Cholesky decomposition of a matrix. |
|
Compute the condition number of a matrix. |
|
Compute the cross-product of two 3D vectors. |
Computes the determinant of an array. |
|
|
Extract the diagonal of a matrix or stack of matrices. |
|
Computes the eigenvalues and eigenvectors of a square array. |
|
Computes the eigenvalues and eigenvectors of a Hermitian matrix. |
|
Computes the eigenvalues of a general matrix. |
|
Computes the eigenvalues of a Hermitian matrix. |
|
Return the inverse of a square matrix |
|
Return the least-squares solution to a linear equation. |
|
Perform a matrix multiplication. |
|
Compute the norm of a matrix or stack of matrices. |
|
Raise a square matrix to an integer power. |
|
Compute the rank of a matrix. |
|
Transpose a matrix or stack of matrices. |
|
Efficiently compute matrix products between a sequence of arrays. |
|
Compute the norm of a matrix or vector. |
|
Compute the outer product of two 1-dimensional arrays. |
Compute the (Moore-Penrose) pseudo-inverse of a matrix. |
|
|
Compute the QR decomposition of an array |
|
Computes the sign and (natural) logarithm of the determinant of an array. |
|
Solve a linear system of equations |
|
Compute the singular value decomposition. |
|
Compute the singular values of a matrix. |
|
Compute the tensor dot product of two N-dimensional arrays. |
|
Compute the tensor inverse of an array. |
|
Solve the tensor equation a x = b for x. |
|
Computes the vector norm of a vector or batch of vectors. |
|
Compute the (batched) vector conjugate dot product of two arrays. |
JAX Array#
The JAX Array
(along with its alias, jax.numpy.ndarray
) is
the core array object in JAX: you can think of it as JAX’s equivalent of a
numpy.ndarray
. Like numpy.ndarray
, most users will not need to
instantiate Array
objects manually, but rather will create them via
jax.numpy
functions like array()
, arange()
,
linspace()
, and others listed above.
Copying and Serialization#
JAX Array
objects are designed to work seamlessly with Python
standard library tools where appropriate.
With the built-in copy
module, when copy.copy()
or copy.deepcopy()
encounters an Array
, it is equivalent to calling the
copy()
method, which will create a copy of
the buffer on the same device as the original array. This will work correctly within
traced/JIT-compiled code, though copy operations may be elided by the compiler
in this context.
When the built-in pickle
module encounters an Array
,
it will be serialized via a compact bit representation in a similar manner to pickled
numpy.ndarray
objects. When unpickled, the result will be a new
Array
object on the default device.
This is because in general, pickling and unpickling may take place in different runtime
environments, and there is no general way to map the device IDs of one runtime
to the device IDs of another. If pickle
is used in traced/JIT-compiled code,
it will result in a ConcretizationTypeError
.
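For example, a minimal sketch of both behaviors (exact device placement depends on your runtime configuration):
>>> import copy, pickle
>>> import jax.numpy as jnp
>>> x = jnp.arange(4)
>>> y = copy.deepcopy(x)                 # copies the buffer on the same device as x
>>> z = pickle.loads(pickle.dumps(x))    # round-trips through bytes; lands on the default device
>>> bool(jnp.all(x == y)) and bool(jnp.all(x == z))
True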
jax.scipy
module#
jax.scipy.cluster#
|
Assign codes from a code book to a set of observations. |
jax.scipy.fft#
|
Computes the discrete cosine transform of the input |
|
Computes the multidimensional discrete cosine transform of the input |
|
Computes the inverse discrete cosine transform of the input |
|
Computes the multidimensional inverse discrete cosine transform of the input |
jax.scipy.integrate#
|
Integrate along the given axis using the composite trapezoidal rule. |
jax.scipy.linalg#
|
Create a block diagonal matrix from input arrays. |
|
Factorization for Cholesky-based linear solves |
|
Solve a linear system using a Cholesky factorization |
|
Compute the Cholesky decomposition of a matrix. |
|
Compute the determinant of a matrix |
|
Compute eigenvalues and eigenvectors for a Hermitian matrix |
|
Solve the eigenvalue problem for a symmetric real tridiagonal matrix |
|
Compute the matrix exponential |
Compute the Frechet derivative of the matrix exponential. |
|
|
Evaluate a matrix-valued function |
Compute the Hessenberg form of the matrix |
|
|
Create a Hilbert matrix of order n. |
|
Return the inverse of a square matrix |
|
Compute the LU decomposition |
|
Factorization for LU-based linear solves |
|
Solve a linear system using an LU factorization |
|
Computes the polar decomposition. |
|
Compute the QR decomposition of an array |
|
Convert real Schur form to complex Schur form. |
|
Compute the Schur decomposition |
|
Solve a linear system of equations |
|
Solve a triangular linear system of equations |
|
Compute the matrix square root |
|
Compute the singular value decomposition. |
|
Construct a Toeplitz matrix |
jax.scipy.ndimage#
|
Map the input array to new coordinates using interpolation. |
jax.scipy.optimize#
|
Minimization of scalar function of one or more variables. |
|
Object holding optimization results. |
jax.scipy.signal#
|
Convolve two N-dimensional arrays using Fast Fourier Transform (FFT). |
|
Convolution of two N-dimensional arrays. |
|
Convolution of two 2-dimensional arrays. |
|
Cross-correlation of two N-dimensional arrays. |
|
Cross-correlation of two 2-dimensional arrays. |
|
Estimate cross power spectral density (CSD) using Welch's method. |
|
Remove linear or piecewise linear trends from data. |
|
Perform the inverse short-time Fourier transform (ISTFT). |
|
Compute the short-time Fourier transform (STFT). |
|
Estimate power spectral density (PSD) using Welch's method. |
jax.scipy.spatial.transform#
|
Rotation in 3 dimensions. |
|
Spherical Linear Interpolation of Rotations. |
jax.scipy.sparse.linalg#
|
Use Bi-Conjugate Gradient Stable iteration to solve |
|
Use Conjugate Gradient iteration to solve |
|
GMRES solves the linear system A x = b for x, given A and b. |
jax.scipy.special#
|
Generate the first N Bernoulli numbers. |
|
The beta function |
|
The regularized incomplete beta function. |
|
Natural log of the absolute value of the beta function |
|
The digamma function |
|
The entropy function |
|
The error function |
|
The complement of the error function |
|
The inverse of the error function |
|
Exponential integral function. |
Exponential integral function. |
|
|
The logistic sigmoid (expit) function |
Generalized exponential integral function. |
|
|
Factorial function |
|
The gamma function. |
|
The regularized lower incomplete gamma function. |
|
The regularized upper incomplete gamma function. |
|
Natural log of the absolute value of the gamma function. |
|
Sign of the gamma function. |
The 1F1 hypergeometric function. |
|
|
Modified bessel function of zeroth order. |
|
Exponentially scaled modified bessel function of zeroth order. |
|
Modified bessel function of first order. |
|
Exponentially scaled modified bessel function of first order. |
Log Normal distribution function. |
|
The logit function |
|
Log-sum-exp reduction. |
|
|
The associated Legendre functions (ALFs) of the first kind. |
|
The associated Legendre functions (ALFs) of the first kind. |
|
The natural log of the multivariate gamma function. |
|
Normal distribution function. |
|
The inverse of the CDF of the Normal distribution function. |
The Pochhammer symbol. |
|
|
The polygamma function. |
|
Spence's function, also known as the dilogarithm for real values. |
|
Computes the spherical harmonics. |
Compute x*log(1 + y), returning 0 for x=0. |
|
Compute x*log(y), returning 0 for x=0. |
|
The Hurwitz zeta function. |
|
|
The Kullback-Leibler divergence. |
|
The relative entropy function. |
jax.scipy.stats#
|
Compute the mode (most common value) along an axis of an array. |
|
Compute the rank of data along an array axis. |
|
Compute the standard error of the mean. |
jax.scipy.stats.bernoulli#
|
Bernoulli log probability mass function. |
|
Bernoulli probability mass function. |
|
Bernoulli cumulative distribution function. |
|
Bernoulli percent point function. |
jax.scipy.stats.beta#
|
Beta log probability distribution function. |
|
Beta probability distribution function. |
|
Beta cumulative distribution function |
|
Beta log cumulative distribution function. |
|
Beta distribution survival function. |
|
Beta distribution log survival function. |
jax.scipy.stats.betabinom#
|
Beta-binomial log probability mass function. |
|
Beta-binomial probability mass function. |
jax.scipy.stats.binom#
|
Binomial log probability mass function. |
|
Binomial probability mass function. |
jax.scipy.stats.cauchy#
|
Cauchy log probability distribution function. |
|
Cauchy probability distribution function. |
|
Cauchy cumulative distribution function. |
|
Cauchy log cumulative distribution function. |
|
Cauchy distribution survival function. |
|
Cauchy distribution log survival function. |
|
Cauchy distribution inverse survival function. |
|
Cauchy distribution percent point function. |
jax.scipy.stats.chi2#
|
Chi-square log probability distribution function. |
|
Chi-square probability distribution function. |
|
Chi-square cumulative distribution function. |
|
Chi-square log cumulative distribution function. |
|
Chi-square survival function. |
|
Chi-square log survival function. |
jax.scipy.stats.dirichlet#
|
Dirichlet log probability distribution function. |
|
Dirichlet probability distribution function. |
jax.scipy.stats.expon#
|
Exponential log probability distribution function. |
|
Exponential probability distribution function. |
jax.scipy.stats.gamma#
|
Gamma log probability distribution function. |
|
Gamma probability distribution function. |
|
Gamma cumulative distribution function. |
|
Gamma log cumulative distribution function. |
|
Gamma survival function. |
|
Gamma log survival function. |
jax.scipy.stats.gennorm#
|
Generalized normal cumulative distribution function. |
|
Generalized normal log probability distribution function. |
|
Generalized normal probability distribution function. |
jax.scipy.stats.geom#
|
Geometric log probability mass function. |
|
Geometric probability mass function. |
jax.scipy.stats.laplace#
|
Laplace cumulative distribution function. |
|
Laplace log probability distribution function. |
|
Laplace probability distribution function. |
jax.scipy.stats.logistic#
|
Logistic cumulative distribution function. |
|
Logistic distribution inverse survival function. |
|
Logistic log probability distribution function. |
|
Logistic probability distribution function. |
|
Logistic distribution percent point function. |
|
Logistic distribution survival function. |
jax.scipy.stats.multinomial#
|
Multinomial log probability mass function. |
|
Multinomial probability mass function. |
jax.scipy.stats.multivariate_normal#
|
Multivariate normal log probability distribution function. |
|
Multivariate normal probability distribution function. |
jax.scipy.stats.nbinom#
|
Negative-binomial log probability mass function. |
|
Negative-binomial probability mass function. |
jax.scipy.stats.norm#
|
Normal log probability distribution function. |
|
Normal probability distribution function. |
|
Normal cumulative distribution function. |
|
Normal log cumulative distribution function. |
|
Normal distribution percent point function. |
|
Normal distribution survival function. |
|
Normal distribution log survival function. |
|
Normal distribution inverse survival function. |
jax.scipy.stats.pareto#
|
Pareto log probability distribution function. |
|
Pareto probability distribution function. |
jax.scipy.stats.poisson#
|
Poisson log probability mass function. |
|
Poisson probability mass function. |
|
Poisson cumulative distribution function. |
jax.scipy.stats.t#
|
Student's T log probability distribution function. |
|
Student's T probability distribution function. |
jax.scipy.stats.truncnorm#
|
Truncated normal cumulative distribution function. |
|
Truncated normal log cumulative distribution function. |
|
Truncated normal log probability distribution function. |
|
Truncated normal distribution log survival function. |
|
Truncated normal probability distribution function. |
|
Truncated normal distribution survival function. |
jax.scipy.stats.uniform#
|
Uniform log probability distribution function. |
|
Uniform probability distribution function. |
|
Uniform cumulative distribution function. |
|
Uniform distribution percent point function. |
jax.scipy.stats.gaussian_kde#
|
Gaussian Kernel Density Estimator |
|
Evaluate the Gaussian KDE on the given points. |
|
Integrate the distribution weighted by a Gaussian. |
|
Integrate the distribution over the given limits. |
|
Integrate the product of two Gaussian KDE distributions. |
|
Randomly sample a dataset from the estimated pdf |
Probability density function |
|
Log probability density function |
jax.scipy.stats.vonmises#
|
von Mises log probability distribution function. |
|
von Mises probability distribution function. |
jax.scipy.stats.wrapcauchy#
|
Wrapped Cauchy log probability distribution function. |
|
Wrapped Cauchy probability distribution function. |
jax.lax
module#
jax.lax
is a library of primitive operations that underpins libraries
such as jax.numpy
. Transformation rules, such as JVP and batching rules,
are typically defined as transformations on jax.lax
primitives.
Many of the primitives are thin wrappers around equivalent XLA operations, described by the XLA operation semantics documentation. In a few cases JAX diverges from XLA, usually to ensure that the set of operations is closed under the operation of JVP and transpose rules.
Where possible, prefer to use libraries such as jax.numpy
instead of
using jax.lax
directly. The jax.numpy
API follows NumPy, and is
therefore more stable and less likely to change than the jax.lax
API.
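For instance, the stricter jax.lax API requires explicit dtype agreement where jax.numpy would promote automatically. A minimal sketch (exact array reprs may vary across JAX versions and backends):
>>> import jax.numpy as jnp
>>> from jax import lax
>>> jnp.add(1, 1.0)                     # jax.numpy promotes mixed types
Array(2., dtype=float32, weak_type=True)
>>> lax.add(jnp.float32(1), 1.0)        # jax.lax expects matching dtypes; lax.add(1, 1.0) errors
Array(2., dtype=float32)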
Operators#
|
Elementwise absolute value: \(|x|\). |
|
Elementwise arc cosine: \(\mathrm{acos}(x)\). |
|
Elementwise inverse hyperbolic cosine: \(\mathrm{acosh}(x)\). |
|
Elementwise addition: \(x + y\). |
|
Merges one or more XLA token values. |
|
Returns max |
|
Returns min |
|
Computes the index of the maximum element along |
|
Computes the index of the minimum element along |
|
Elementwise arc sine: \(\mathrm{asin}(x)\). |
|
Elementwise inverse hyperbolic sine: \(\mathrm{asinh}(x)\). |
|
Elementwise arc tangent: \(\mathrm{atan}(x)\). |
|
Elementwise arc tangent of two variables: \(\mathrm{atan}({x \over y})\). |
|
Elementwise inverse hyperbolic tangent: \(\mathrm{atanh}(x)\). |
|
Batch matrix multiplication. |
|
Exponentially scaled modified Bessel function of order 0: \(\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)\) |
|
Exponentially scaled modified Bessel function of order 1: \(\mathrm{i1e}(x) = e^{-|x|} \mathrm{i1}(x)\) |
|
Elementwise regularized incomplete beta integral. |
|
Elementwise bitcast. |
|
Elementwise AND: \(x \wedge y\). |
|
Elementwise NOT: \(\neg x\). |
|
Elementwise OR: \(x \vee y\). |
|
Elementwise exclusive OR: \(x \oplus y\). |
Elementwise popcount, count the number of set bits in each element. |
|
|
Broadcasts an array, adding new leading dimensions |
|
Wraps XLA's BroadcastInDim operator. |
Returns the shape that results from NumPy broadcasting of shapes. |
|
|
Adds leading dimensions of |
|
Convenience wrapper around |
|
Elementwise cube root: \(\sqrt[3]{x}\). |
|
Elementwise ceiling: \(\left\lceil x \right\rceil\). |
|
Elementwise clamp. |
|
Elementwise count-leading-zeros. |
|
Collapses dimensions of an array into a single dimension. |
|
Elementwise make complex number: \(x + jy\). |
|
Concatenates a sequence of arrays along dimension. |
|
Elementwise complex conjugate function: \(\overline{x}\). |
|
Convenience wrapper around conv_general_dilated. |
|
Elementwise cast. |
|
Converts convolution dimension_numbers to a ConvDimensionNumbers. |
|
General n-dimensional convolution operator, with optional dilation. |
|
General n-dimensional unshared convolution operator with optional dilation. |
|
Extract patches subject to the receptive field of conv_general_dilated. |
|
Convenience wrapper for calculating the N-d convolution "transpose". |
|
Convenience wrapper around conv_general_dilated. |
|
Elementwise cosine: \(\mathrm{cos}(x)\). |
|
Elementwise hyperbolic cosine: \(\mathrm{cosh}(x)\). |
|
Computes a cumulative logsumexp along axis. |
|
Computes a cumulative maximum along axis. |
|
Computes a cumulative minimum along axis. |
|
Computes a cumulative product along axis. |
|
Computes a cumulative sum along axis. |
|
Elementwise digamma: \(\psi(x)\). |
|
Elementwise division: \(x \over y\). |
|
Vector/vector, matrix/vector, and matrix/matrix multiplication. |
|
General dot product/contraction operator. |
|
Convenience wrapper around dynamic_slice to perform int indexing. |
|
Wraps XLA's DynamicSlice operator. |
|
Convenience wrapper around |
|
Convenience wrapper around |
|
Wraps XLA's DynamicUpdateSlice operator. |
|
Convenience wrapper around |
|
Elementwise equals: \(x = y\). |
|
Elementwise error function: \(\mathrm{erf}(x)\). |
|
Elementwise complementary error function: \(\mathrm{erfc}(x) = 1 - \mathrm{erf}(x)\). |
|
Elementwise inverse error function: \(\mathrm{erf}^{-1}(x)\). |
|
Elementwise exponential: \(e^x\). |
|
Insert any number of size 1 dimensions into an array. |
|
Elementwise \(e^{x} - 1\). |
|
|
|
Elementwise floor: \(\left\lfloor x \right\rfloor\). |
|
Returns an array of shape filled with fill_value. |
|
Create a full array like np.full based on the example array x. |
|
Gather operator. |
|
Elementwise greater-than-or-equals: \(x \geq y\). |
|
Elementwise greater-than: \(x > y\). |
|
Elementwise regularized incomplete gamma function. |
|
Elementwise complementary regularized incomplete gamma function. |
|
Elementwise extract imaginary part: \(\mathrm{Im}(x)\). |
|
Convenience wrapper around |
|
|
|
Elementwise power: \(x^y\), where \(y\) is a fixed integer. |
|
Wraps XLA's Iota operator. |
|
Elementwise \(\mathrm{isfinite}\). |
|
Elementwise less-than-or-equals: \(x \leq y\). |
|
Elementwise log gamma: \(\mathrm{log}(\Gamma(x))\). |
|
Elementwise natural logarithm: \(\mathrm{log}(x)\). |
|
Elementwise \(\mathrm{log}(1 + x)\). |
|
Elementwise logistic (sigmoid) function: \(\frac{1}{1 + e^{-x}}\). |
|
Elementwise less-than: \(x < y\). |
|
Elementwise maximum: \(\mathrm{max}(x, y)\) |
|
Elementwise minimum: \(\mathrm{min}(x, y)\) |
|
Elementwise multiplication: \(x \times y\). |
|
Elementwise not-equals: \(x \neq y\). |
|
Elementwise negation: \(-x\). |
|
Returns the next representable value after x1 in the direction of x2. |
|
Applies low, high, and/or interior padding to an array. |
|
Elementwise polygamma: \(\psi^{(m)}(x)\). |
Elementwise popcount, count the number of set bits in each element. |
|
|
Elementwise power: \(x^y\). |
|
Elementwise derivative of samples from Gamma(a, 1). |
|
Elementwise extract real part: \(\mathrm{Re}(x)\). |
|
Elementwise reciprocal: \(1 \over x\). |
|
Wraps XLA's Reduce operator. |
|
Wraps XLA's ReducePrecision operator. |
|
Wraps XLA's ReduceWindowWithGeneralPadding operator. |
|
Elementwise remainder: \(x \bmod y\). |
|
Wraps XLA's Reshape operator. |
|
Wraps XLA's Rev operator. |
|
Stateless PRNG bit generator. |
|
Stateful PRNG generator. |
|
Elementwise round. |
|
Elementwise reciprocal square root: \(1 \over \sqrt{x}\). |
|
Scatter-update operator. |
|
Scatter-add operator. |
|
Scatter-apply operator. |
|
Scatter-max operator. |
|
Scatter-min operator. |
|
Scatter-multiply operator. |
|
Elementwise left shift: \(x \ll y\). |
|
Elementwise arithmetic right shift: \(x \gg y\). |
|
Elementwise logical right shift: \(x \gg y\). |
|
Elementwise sign. |
|
Elementwise sine: \(\mathrm{sin}(x)\). |
|
Elementwise hyperbolic sine: \(\mathrm{sinh}(x)\). |
|
Wraps XLA's Slice operator. |
|
Convenience wrapper around |
|
Wraps XLA's Sort operator. |
|
Sorts |
|
Elementwise square root: \(\sqrt{x}\). |
|
Elementwise square: \(x^2\). |
|
Squeeze any number of size 1 dimensions from an array. |
|
Elementwise subtraction: \(x - y\). |
|
Elementwise tangent: \(\mathrm{tan}(x)\). |
|
Elementwise hyperbolic tangent: \(\mathrm{tanh}(x)\). |
|
Deprecated. |
|
Returns top |
|
Wraps XLA's Transpose operator. |
|
Elementwise Hurwitz zeta function: \(\zeta(x, q)\) |
Control flow operators#
|
Performs a scan with an associative binary operation, in parallel. |
|
Conditionally apply |
|
Loop from |
|
Map a function over leading array axes. |
|
Scan a function over leading array axes while carrying along state. |
|
Selects between two branches based on a boolean predicate. |
|
Selects array values from multiple cases. |
|
Apply exactly one of the |
|
Call |
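For example, a minimal sketch of jax.lax.scan carrying a running sum (shapes shown instead of exact array reprs):
>>> import jax.numpy as jnp
>>> from jax import lax
>>> def step(carry, x):
...     new_carry = carry + x
...     return new_carry, new_carry      # (carry for the next step, per-step output)
>>> total, running = lax.scan(step, jnp.float32(0.0), jnp.arange(4.0))
>>> float(total), running.shape
(6.0, (4,))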
Custom gradient operators#
Stops gradient computation. |
|
|
Perform a matrix-free linear solve with implicitly defined gradients. |
|
Differentiably solve for the roots of a function. |
Parallel operators#
|
Gather values of x across all replicas. |
|
Materialize the mapped axis and map a different axis. |
|
|
|
Compute an all-reduce sum on |
|
Like |
|
Compute an all-reduce max on |
|
Compute an all-reduce min on |
|
Compute an all-reduce mean on |
|
Perform a collective permutation according to the permutation |
|
Convenience wrapper of jax.lax.ppermute with alternate permutation encoding |
|
Swap the pmapped axis |
|
Return the index along the mapped axis |
Linear algebra operators (jax.lax.linalg)#
|
Cholesky decomposition. |
|
Eigendecomposition of a general matrix. |
|
Eigendecomposition of a Hermitian matrix. |
|
Reduces a square matrix to upper Hessenberg form. |
|
LU decomposition with partial pivoting. |
|
Product of elementary Householder reflectors. |
|
QR-based dynamically weighted Halley iteration for polar decomposition. |
|
QR decomposition. |
|
|
|
Singular value decomposition. |
|
Triangular solve. |
|
Reduces a symmetric/Hermitian matrix to tridiagonal form. |
|
Computes the solution of a tridiagonal linear system. |
Argument classes#
- class jax.lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)[source]#
Describes batch, spatial, and feature dimensions of a convolution.
- Parameters:
lhs_spec (Sequence[int]) – a tuple of nonnegative integer dimension numbers containing (batch dimension, feature dimension, spatial dimensions…).
rhs_spec (Sequence[int]) – a tuple of nonnegative integer dimension numbers containing (out feature dimension, in feature dimension, spatial dimensions…).
out_spec (Sequence[int]) – a tuple of nonnegative integer dimension numbers containing (batch dimension, feature dimension, spatial dimensions…).
- jax.lax.ConvGeneralDilatedDimensionNumbers#
- class jax.lax.GatherDimensionNumbers(offset_dims, collapsed_slice_dims, start_index_map)[source]#
Describes the dimension number arguments to an XLA’s Gather operator. See the XLA documentation for more details of what the dimension numbers mean.
- Parameters:
offset_dims (tuple[int, ...]) – the set of dimensions in the gather output that offset into an array sliced from operand. Must be a tuple of integers in ascending order, each representing a dimension number of the output.
collapsed_slice_dims (tuple[int, ...]) – the set of dimensions i in operand that have slice_sizes[i] == 1 and that should not have a corresponding dimension in the output of the gather. Must be a tuple of integers in ascending order.
start_index_map (tuple[int, ...]) – for each dimension in start_indices, gives the corresponding dimension in the operand that is to be sliced. Must be a tuple of integers with size equal to start_indices.shape[-1].
Unlike XLA’s GatherDimensionNumbers structure, index_vector_dim is implicit; there is always an index vector dimension and it must always be the last dimension. To gather scalar indices, add a trailing dimension of size 1.
- class jax.lax.GatherScatterMode(value)[source]#
Describes how to handle out-of-bounds indices in a gather or scatter.
Possible values are:
- CLIP:
Indices will be clamped to the nearest in-range value, i.e., such that the entire window to be gathered is in-range.
- FILL_OR_DROP:
If any part of a gathered window is out of bounds, the entire window that is returned, even those elements that were otherwise in-bounds, will be filled with a constant. If any part of a scattered window is out of bounds, the entire window will be discarded.
- PROMISE_IN_BOUNDS:
The user promises that indices are in bounds. No additional checking will be performed. In practice, with the current XLA implementation this means that out-of-bounds gathers will be clamped but out-of-bounds scatters will be discarded. Gradients will not be correct if indices are out-of-bounds.
- class jax.lax.Precision(value)[source]#
Precision enum for lax functions
The precision argument to JAX functions generally controls the tradeoff between speed and accuracy for array computations on accelerator backends (i.e., TPU and GPU). Members are:
- DEFAULT:
Fastest mode, but least accurate. Performs computations in bfloat16. Aliases:
'default'
,'fastest'
,'bfloat16'
.- HIGH:
Slower but more accurate. Performs float32 computations in 3 bfloat16 passes, or using tensorfloat32 where available. Aliases:
'high'
,'bfloat16_3x'
,'tensorfloat32'
.- HIGHEST:
Slowest but most accurate. Performs computations in float32 or float64 as applicable. Aliases:
'highest'
,'float32'
.
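For example, a minimal sketch of passing a precision to a jax.numpy operation (either enum members or their string aliases are accepted):
>>> import jax.numpy as jnp
>>> from jax import lax
>>> x = jnp.ones((128, 128))
>>> y = jnp.ones((128, 128))
>>> accurate = jnp.dot(x, y, precision=lax.Precision.HIGHEST)   # slowest, most accurate
>>> fast = jnp.dot(x, y, precision='bfloat16')                  # alias for Precision.DEFAULT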
- jax.lax.PrecisionLike#
alias of str | Precision | tuple[str, str] | tuple[Precision, Precision] | None
- class jax.lax.ScatterDimensionNumbers(update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims)[source]#
Describes the dimension number arguments to an XLA’s Scatter operator. See the XLA documentation for more details of what the dimension numbers mean.
- Parameters:
update_window_dims (Sequence[int]) – the set of dimensions in the updates that are window dimensions. Must be a tuple of integers in ascending order, each representing a dimension number.
inserted_window_dims (Sequence[int]) – the set of size 1 window dimensions that must be inserted into the shape of updates. Must be a tuple of integers in ascending order, each representing a dimension number of the output. These are the mirror image of collapsed_slice_dims in the case of gather.
scatter_dims_to_operand_dims (Sequence[int]) – for each dimension in scatter_indices, gives the corresponding dimension in operand. Must be a sequence of integers with size equal to scatter_indices.shape[-1].
Unlike XLA’s ScatterDimensionNumbers structure, index_vector_dim is implicit; there is always an index vector dimension and it must always be the last dimension. To scatter scalar indices, add a trailing dimension of size 1.
jax.random
module#
Utilities for pseudo-random number generation.
The jax.random
package provides a number of routines for deterministic
generation of sequences of pseudorandom numbers.
Basic usage#
>>> seed = 1701
>>> num_steps = 100
>>> key = jax.random.key(seed)
>>> for i in range(num_steps):
... key, subkey = jax.random.split(key)
... params = compiled_update(subkey, params, next(batches))
PRNG keys#
Unlike the stateful pseudorandom number generators (PRNGs) that users of NumPy and
SciPy may be accustomed to, JAX random functions all require an explicit PRNG state to
be passed as a first argument.
The random state is described by a special array element type that we call a key,
usually generated by the jax.random.key()
function:
>>> from jax import random
>>> key = random.key(0)
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]
This key can then be used in any of JAX’s random number generation routines:
>>> random.uniform(key)
Array(0.41845703, dtype=float32)
Note that using a key does not modify it, so reusing the same key will lead to the same result:
>>> random.uniform(key)
Array(0.41845703, dtype=float32)
If you need a new random number, you can use jax.random.split()
to generate new subkeys:
>>> key, subkey = random.split(key)
>>> random.uniform(subkey)
Array(0.10536897, dtype=float32)
Note
Typed key arrays, with element types such as key<fry>
above,
were introduced in JAX v0.4.16. Before then, keys were
conventionally represented in uint32
arrays, whose final
dimension represented the key’s bit-level representation.
Both forms of key array can still be created and used with the
jax.random
module. New-style typed key arrays are made with
jax.random.key()
. Legacy uint32
key arrays are made
with jax.random.PRNGKey()
.
To convert between the two, use jax.random.key_data()
and
jax.random.wrap_key_data()
. The legacy key format may be
needed when interfacing with systems outside of JAX (e.g. exporting
arrays to a serializable format), or when passing keys to JAX-based
libraries that assume the legacy format.
Otherwise, typed keys are recommended. Caveats of legacy keys relative to typed ones include:
They have an extra trailing dimension.
They have a numeric dtype (
uint32
), allowing for operations that are typically not meant to be carried out over keys, such as integer arithmetic.They do not carry information about the RNG implementation. When legacy keys are passed to
jax.random
functions, a global configuration setting determines the RNG implementation (see “Advanced RNG configuration” below).
To learn more about this upgrade, and the design of key types, see JEP 9263.
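For example, a minimal sketch of converting between the two key formats (assuming the default threefry2x32 implementation, whose key data has shape (2,)):
>>> from jax import random
>>> typed_key = random.key(0)               # new-style typed key array
>>> legacy_key = random.PRNGKey(0)          # legacy uint32 key array
>>> raw = random.key_data(typed_key)        # recover the underlying uint32 bits
>>> restored = random.wrap_key_data(raw)    # wrap the bits back into a typed key
>>> legacy_key.shape, raw.shape
((2,), (2,))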
Advanced#
Design and background#
TLDR: JAX PRNG = Threefry counter PRNG + a functional array-oriented splitting model
See docs/jep/263-prng.md for more details.
To summarize, among other requirements, the JAX PRNG aims to:
ensure reproducibility,
parallelize well, both in terms of vectorization (generating array values) and multi-replica, multi-core computation. In particular it should not use sequencing constraints between random function calls.
Advanced RNG configuration#
JAX provides several PRNG implementations. A specific one can be selected with the optional impl keyword argument to jax.random.key. When no impl option is passed to the key constructor, the implementation is determined by the global jax_default_prng_impl configuration flag.
- default, “threefry2x32”: A counter-based PRNG built around the Threefry hash function.
- experimental: A PRNG that thinly wraps the XLA Random Bit Generator (RBG) algorithm. See TF doc.
- “rbg” uses ThreeFry for splitting, and XLA RBG for data generation.
- “unsafe_rbg” exists only for demonstration purposes, using RBG both for splitting (using an untested, made-up algorithm) and generating.
The random streams generated by these experimental implementations haven’t been subject to any empirical randomness testing (e.g. Big Crush). The random bits generated may change between JAX versions.
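For example, a minimal sketch of selecting an implementation explicitly via the impl argument (here the experimental “rbg” implementation):
>>> import jax
>>> rbg_key = jax.random.key(0, impl='rbg')   # explicit choice of the RBG-based PRNG
>>> default_key = jax.random.key(0)           # falls back to jax_default_prng_impl
>>> jax.random.uniform(rbg_key).shape
()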
The possible reasons not to use the default RNG are:
it may be slow to compile (specifically for Google Cloud TPUs)
it’s slower to execute on TPUs
it doesn’t support efficient automatic sharding / partitioning
Here is a short summary:
Property | Threefry | Threefry* | rbg | unsafe_rbg | rbg** | unsafe_rbg** |
---|---|---|---|---|---|---|
Fastest on TPU | | | ✅ | ✅ | ✅ | ✅ |
efficiently shardable (w/ pjit) | | ✅ | | | ✅ | ✅ |
identical across shardings | ✅ | ✅ | ✅ | ✅ | | |
identical across CPU/GPU/TPU | ✅ | ✅ | | | | |
identical across JAX/XLA versions | ✅ | ✅ | | | | |
(*): with jax_threefry_partitionable=1
set
(**): with XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1
set
The difference between “rbg” and “unsafe_rbg” is that while “rbg” uses a less robust/studied hash function for random value generation (but not for jax.random.split or jax.random.fold_in), “unsafe_rbg” additionally uses less robust hash functions for jax.random.split and jax.random.fold_in. It is therefore less safe in the sense that the quality of the random streams it generates from different keys is less well understood.
For more about jax_threefry_partitionable, see https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
API Reference#
Key Creation & Manipulation#
|
Create a pseudo-random number generator (PRNG) key given an integer seed. |
|
Create a pseudo-random number generator (PRNG) key given an integer seed. |
|
Recover the bits of key data underlying a PRNG key array. |
|
Wrap an array of key data bits into a PRNG key array. |
|
Folds in data to a PRNG key to form a new PRNG key. |
|
Splits a PRNG key into num new keys by adding a leading axis. |
|
Clone a key for reuse |
Random Samplers#
|
Sample uniformly from the unit Lp ball. |
|
Sample Bernoulli random values with given shape and mean. |
|
Sample Beta random values with given shape and float dtype. |
|
Sample Binomial random values with given shape and float dtype. |
|
Sample uniform bits in the form of unsigned integers. |
|
Sample random values from categorical distributions. |
|
Sample Cauchy random values with given shape and float dtype. |
|
Sample Chisquare random values with given shape and float dtype. |
|
Generates a random sample from a given array. |
|
Sample Dirichlet random values with given shape and float dtype. |
|
Sample from a double sided Maxwell distribution. |
|
Sample Exponential random values with given shape and float dtype. |
|
Sample F-distribution random values with given shape and float dtype. |
|
Sample Gamma random values with given shape and float dtype. |
|
Sample from the generalized normal distribution. |
|
Sample Geometric random values with given shape and float dtype. |
|
Sample Gumbel random values with given shape and float dtype. |
|
Sample Laplace random values with given shape and float dtype. |
|
Sample log-gamma random values with given shape and float dtype. |
|
Sample logistic random values with given shape and float dtype. |
|
Sample lognormal random values with given shape and float dtype. |
|
Sample from a one sided Maxwell distribution. |
|
Sample multivariate normal random values with given mean and covariance. |
|
Sample standard normal random values with given shape and float dtype. |
|
Sample uniformly from the orthogonal group O(n). |
|
Sample Pareto random values with given shape and float dtype. |
|
Returns a randomly permuted array or range. |
|
Sample Poisson random values with given shape and integer dtype. |
|
Sample from a Rademacher distribution. |
|
Sample uniform random values in [minval, maxval) with given shape/dtype. |
|
Sample Rayleigh random values with given shape and float dtype. |
|
Sample Student's t random values with given shape and float dtype. |
|
Sample Triangular random values with given shape and float dtype. |
|
Sample truncated standard normal random values with given shape and dtype. |
|
Sample uniform random values in [minval, maxval) with given shape/dtype. |
|
Sample Wald random values with given shape and float dtype. |
|
Sample from a Weibull distribution. |
jax.sharding
module#
Classes#
- class jax.sharding.Sharding#
Describes how a
jax.Array
is laid out across devices.- property addressable_devices: set[Device]#
The set of devices in the
Sharding
that are addressable by the current process.
- addressable_devices_indices_map(global_shape)[source]#
A mapping from addressable devices to the slice of array data each contains.
addressable_devices_indices_map
contains that part ofdevice_indices_map
that applies to the addressable devices.
- property device_set: set[Device][source]#
The set of devices that this
Sharding
spans.In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.
- devices_indices_map(global_shape)[source]#
Returns a mapping from devices to the array slices each contains.
The mapping includes all global devices, i.e., including non-addressable devices from other processes.
- is_equivalent_to(other, ndim)[source]#
Returns
True
if two shardings are equivalent.Two shardings are equivalent if they place the same logical array shards on the same devices.
For example, a
NamedSharding
may be equivalent to aPositionalSharding
if both place the same shards of the array on the same devices.
- property is_fully_addressable: bool[source]#
Is this sharding fully addressable?
A sharding is fully addressable if the current process can address all of the devices named in the
Sharding
.is_fully_addressable
is equivalent to “is_local” in multi-process JAX.
- property is_fully_replicated: bool[source]#
Is this sharding fully replicated?
A sharding is fully replicated if each device has a complete copy of the entire data.
- class jax.sharding.XLACompatibleSharding#
Bases:
Sharding
A
Sharding
that describes shardings expressible to XLA.Subclasses of
XLACompatibleSharding
work with all JAX APIs and transformations that use XLA.- devices_indices_map(global_shape)[source]#
Returns a mapping from devices to the array slices each contains.
The mapping includes all global devices, i.e., including non-addressable devices from other processes.
- is_equivalent_to(other, ndim)[source]#
Returns
True
if two shardings are equivalent.Two shardings are equivalent if they place the same logical array shards on the same devices.
For example, a
NamedSharding
may be equivalent to aPositionalSharding
if both place the same shards of the array on the same devices.- Parameters:
self (XLACompatibleSharding)
other (XLACompatibleSharding)
ndim (int)
- Return type:
- class jax.sharding.SingleDeviceSharding#
Bases:
XLACompatibleSharding
A
Sharding
that places its data on a single device.- Parameters:
device – A single
Device
.
Example
>>> single_device_sharding = jax.sharding.SingleDeviceSharding( ... jax.devices()[0])
- property device_set: set[Device][source]#
The set of devices that this
Sharding
spans.In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.
- devices_indices_map(global_shape)[source]#
Returns a mapping from devices to the array slices each contains.
The mapping includes all global devices, i.e., including non-addressable devices from other processes.
- property is_fully_addressable: bool[source]#
Is this sharding fully addressable?
A sharding is fully addressable if the current process can address all of the devices named in the
Sharding
.is_fully_addressable
is equivalent to “is_local” in multi-process JAX.
- property is_fully_replicated: bool[source]#
Is this sharding fully replicated?
A sharding is fully replicated if each device has a complete copy of the entire data.
- class jax.sharding.NamedSharding#
Bases:
XLACompatibleSharding
A
NamedSharding
expresses sharding using named axes.A
NamedSharding
is a pair of aMesh
of devices andPartitionSpec
which describes how to shard an array across that mesh.A
Mesh
is a multidimensional NumPy array of JAX devices, where each axis of the mesh has a name, e.g.'x'
or'y'
.A
PartitionSpec
is a tuple, whose elements can be aNone
, a mesh axis, or a tuple of mesh axes. Each element describes how an input dimension is partitioned across zero or more mesh dimensions. For example,PartitionSpec('x', 'y')
says that the first dimension of data is sharded acrossx
axis of the mesh, and the second dimension is sharded acrossy
axis of the mesh.The Distributed arrays and automatic parallelization (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names) tutorial has more details and diagrams that explain how
Mesh
andPartitionSpec
are used.- Parameters:
mesh – A
jax.sharding.Mesh
object.spec – A
jax.sharding.PartitionSpec
object.
Example
>>> from jax.sharding import Mesh >>> from jax.sharding import PartitionSpec as P >>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y')) >>> spec = P('x', 'y') >>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
- property addressable_devices: set[Device][source]#
The set of devices in the
Sharding
that are addressable by the current process.
- property device_set: set[Device][source]#
The set of devices that this
Sharding
spans.In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.
- property is_fully_addressable: bool[source]#
Is this sharding fully addressable?
A sharding is fully addressable if the current process can address all of the devices named in the
Sharding
.is_fully_addressable
is equivalent to “is_local” in multi-process JAX.
- property is_fully_replicated: bool#
Is this sharding fully replicated?
A sharding is fully replicated if each device has a complete copy of the entire data.
- property mesh#
(self) -> object
- property spec#
(self) -> object
- class jax.sharding.PositionalSharding(devices, *, memory_kind=None)[source]#
Bases:
XLACompatibleSharding
- Parameters:
devices (Sequence[xc.Device] | np.ndarray)
memory_kind (str | None)
- property device_set: set[Device]#
The set of devices that this
Sharding
spans.In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.
- property is_fully_addressable: bool#
Is this sharding fully addressable?
A sharding is fully addressable if the current process can address all of the devices named in the
Sharding
.is_fully_addressable
is equivalent to “is_local” in multi-process JAX.
- property is_fully_replicated: bool#
Is this sharding fully replicated?
A sharding is fully replicated if each device has a complete copy of the entire data.
- class jax.sharding.PmapSharding#
Bases:
XLACompatibleSharding
Describes a sharding used by
jax.pmap()
.- classmethod default(shape, sharded_dim=0, devices=None)[source]#
Creates a
PmapSharding
which matches the default placement used byjax.pmap()
.- Parameters:
sharded_dim (int) – Dimension the input array is sharded on. Defaults to 0.
devices (Sequence[Device] | None) – Optional sequence of devices to use. If omitted, the implicit
used (device order used by pmap is) –
jax.local_devices()
.of (which is the order) –
jax.local_devices()
.
- Return type:
- property device_set: set[Device]#
The set of devices that this
Sharding
spans.In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.
- property devices#
(self) -> ndarray
- devices_indices_map(global_shape)[source]#
Returns a mapping from devices to the array slices each contains.
The mapping includes all global devices, i.e., including non-addressable devices from other processes.
- is_equivalent_to(other, ndim)[source]#
Returns
True
if two shardings are equivalent.Two shardings are equivalent if they place the same logical array shards on the same devices.
For example, a
NamedSharding
may be equivalent to aPositionalSharding
if both place the same shards of the array on the same devices.- Parameters:
self (PmapSharding)
other (PmapSharding)
ndim (int)
- Return type:
- property is_fully_addressable: bool#
Is this sharding fully addressable?
A sharding is fully addressable if the current process can address all of the devices named in the
Sharding
.is_fully_addressable
is equivalent to “is_local” in multi-process JAX.
- property is_fully_replicated: bool#
Is this sharding fully replicated?
A sharding is fully replicated if each device has a complete copy of the entire data.
- shard_shape(global_shape)[source]#
Returns the shape of the data on each device.
The shard shape returned by this function is calculated from
global_shape
and the properties of the sharding.
- property sharding_spec#
(self) -> jax::ShardingSpec
- class jax.sharding.GSPMDSharding#
Bases:
XLACompatibleSharding
- property device_set: set[Device]#
The set of devices that this
Sharding
spans.In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.
- devices_indices_map(global_shape)[source]#
Returns a mapping from devices to the array slices each contains.
The mapping includes all global devices, i.e., including non-addressable devices from other processes.
- property is_fully_addressable: bool#
Is this sharding fully addressable?
A sharding is fully addressable if the current process can address all of the devices named in the
Sharding
.is_fully_addressable
is equivalent to “is_local” in multi-process JAX.
- property is_fully_replicated: bool#
Is this sharding fully replicated?
A sharding is fully replicated if each device has a complete copy of the entire data.
- class jax.sharding.PartitionSpec(*partitions)[source]#
Tuple describing how to partition an array across a mesh of devices.
Each element is either
None
, a string, or a tuple of strings. See the documentation ofjax.sharding.NamedSharding
for more details. This class exists so JAX’s pytree utilities can distinguish partition specifications from tuples that should be treated as pytrees.
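For example, a minimal sketch of a spec that shards the first dimension and replicates the second:
>>> from jax.sharding import PartitionSpec as P
>>> spec = P('x', None)   # dim 0 sharded across mesh axis 'x', dim 1 replicated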
- class jax.sharding.Mesh(devices, axis_names)[source]#
Declare the hardware resources available in the scope of this manager.
In particular, all
axis_names
become valid resource names inside the managed block and can be used e.g. in thein_axis_resources
argument ofjax.experimental.pjit.pjit()
. Also see JAX’s multi-process programming model (https://jax.readthedocs.io/en/latest/multi_process.html) and the Distributed arrays and automatic parallelization tutorial (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)If you are compiling in multiple threads, make sure that the
with Mesh
context manager is inside the function that the threads will execute.- Parameters:
devices (ndarray) – A NumPy ndarray object containing JAX device objects (as obtained e.g. from
jax.devices()
).axis_names (tuple[Any, ...]) – A sequence of resource axis names to be assigned to the dimensions of the
devices
argument. Its length should match the rank ofdevices
.
Example
>>> from jax.experimental.pjit import pjit >>> from jax.sharding import Mesh >>> from jax.sharding import PartitionSpec as P >>> import numpy as np ... >>> inp = np.arange(16).reshape((8, 2)) >>> devices = np.array(jax.devices()).reshape(4, 2) ... >>> # Declare a 2D mesh with axes `x` and `y`. >>> global_mesh = Mesh(devices, ('x', 'y')) >>> # Use the mesh object directly as a context manager. >>> with global_mesh: ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # Initialize the Mesh and use the mesh as the context manager. >>> with Mesh(devices, ('x', 'y')) as global_mesh: ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # Also you can use it as `with ... as ...`. >>> global_mesh = Mesh(devices, ('x', 'y')) >>> with global_mesh as m: ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # You can also use it as `with Mesh(...)`. >>> with Mesh(devices, ('x', 'y')): ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
jax.debug
module#
Runtime value debugging utilities#
The jax.debug.print and jax.debug.breakpoint guide describes how to make use of JAX’s runtime value debugging features.
|
Calls a stageable Python callback. |
|
Prints values and works in staged out JAX functions. |
|
Enters a breakpoint at a point in a program. |
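For example, a minimal sketch of jax.debug.print inside a jitted function (print ordering is not guaranteed when staged out):
>>> import jax
>>> import jax.numpy as jnp
>>> @jax.jit
... def f(x):
...     jax.debug.print("value of x: {x}", x=x)   # printed at runtime, even under jit
...     return x * 2
>>> _ = f(jnp.arange(3))
value of x: [0 1 2]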
jax.dlpack
module#
|
Returns a |
|
Returns a DLPack tensor that encapsulates a |
jax.distributed
module#
|
Initializes the JAX distributed system. |
|
Shuts down the distributed system. |
jax.dtypes
module#
bfloat16 floating-point values |
|
|
Convert from a dtype to a canonical dtype based on config.x64_enabled. |
DType class corresponding to the scalar type and dtype of the same name. |
|
|
Returns True if first argument is a typecode lower/equal in type hierarchy. |
|
Scalar class for PRNG Key dtypes. |
|
Convenience function to apply JAX argument dtype promotion. |
Return the scalar type associated with a JAX value. |
jax.flatten_util
module#
List of Functions#
|
Ravel (flatten) a pytree of arrays down to a 1D array. |
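For example, a minimal sketch (leaf order follows JAX’s pytree flattening of the dict):
>>> import jax.numpy as jnp
>>> from jax.flatten_util import ravel_pytree
>>> params = {'w': jnp.ones((2, 2)), 'b': jnp.zeros(2)}
>>> flat, unravel = ravel_pytree(params)   # 1D vector plus a function that rebuilds the pytree
>>> flat.shape
(6,)
>>> restored = unravel(flat)               # same structure and shapes as `params`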
jax.image
module#
Image manipulation functions.
More image manipulation functions can be found in libraries built on top of JAX, such as PIX.
Image manipulation functions#
|
Image resize. |
|
Apply a scale and translation to an image. |
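For example, a minimal sketch of jax.image.resize on a small HWC image ('bilinear' is one of the supported resize method names):
>>> import jax.numpy as jnp
>>> from jax import image
>>> img = jnp.zeros((8, 8, 3))                                         # height x width x channels
>>> resized = image.resize(img, shape=(16, 16, 3), method='bilinear')
>>> resized.shape
(16, 16, 3)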
Argument classes#
- class jax.image.ResizeMethod(value)[source]#
Image resize method.
Possible values are:
- NEAREST:
Nearest-neighbor interpolation.
- LINEAR:
- LANCZOS3:
Lanczos resampling, using a kernel of radius 3.
- LANCZOS5:
Lanczos resampling, using a kernel of radius 5.
- CUBIC:
Cubic interpolation, using the Keys cubic kernel.
jax.nn
module#
jax.nn.initializers
module#
Common neural network layer initializers, consistent with definitions used in Keras and Sonnet.
Initializers#
An initializer is a function that takes three arguments:
(key, shape, dtype)
and returns an array with dimensions shape
and
data type dtype
. Argument key
is a PRNG key (e.g. from
jax.random.key()
), used to generate random numbers to initialize the array.
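For example, an initializer built by one of the factories listed below is called with a key, a shape, and a dtype:

import jax
import jax.numpy as jnp
from jax import nn

key = jax.random.key(0)
init = nn.initializers.glorot_normal()     # returns an initializer function
W = init(key, (128, 256), jnp.float32)     # (key, shape, dtype) -> array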
|
Builds an initializer that returns arrays full of a constant |
|
Builds an initializer for delta orthogonal kernels. |
|
Builds a Glorot normal initializer (aka Xavier normal initializer). |
|
Builds a Glorot uniform initializer (aka Xavier uniform initializer). |
|
Builds a He normal initializer (aka Kaiming normal initializer). |
|
Builds a He uniform initializer (aka Kaiming uniform initializer). |
|
Builds a Lecun normal initializer. |
|
Builds a Lecun uniform initializer. |
|
Builds an initializer that returns real normally-distributed random arrays. |
|
An initializer that returns a constant array full of ones. |
|
Builds an initializer that returns uniformly distributed orthogonal matrices. |
|
Builds an initializer that returns truncated-normal random arrays. |
|
Builds an initializer that returns real uniformly-distributed random arrays. |
|
Initializer that adapts its scale to the shape of the weights tensor. |
|
An initializer that returns a constant array full of zeros. |
Common functions for neural network libraries.
Activation functions#
Rectified linear unit activation function. |
|
Rectified Linear Unit 6 activation function. |
|
|
Sigmoid activation function. |
|
Softplus activation function. |
|
Sparse plus function. |
|
Soft-sign activation function. |
|
SiLU (aka swish) activation function. |
|
SiLU (aka swish) activation function. |
|
Log-sigmoid activation function. |
|
Leaky rectified linear unit activation function. |
|
Hard Sigmoid activation function. |
|
Hard SiLU (swish) activation function. |
|
Hard SiLU (swish) activation function. |
|
Hard \(\mathrm{tanh}\) activation function. |
|
Exponential linear unit activation function. |
|
Continuously-differentiable exponential linear unit activation. |
|
Scaled exponential linear unit activation. |
|
Gaussian error linear unit activation function. |
|
Gated linear unit activation function. |
|
Squareplus activation function. |
|
Mish activation function. |
Other functions#
|
Softmax function. |
|
Log-Softmax function. |
Log-sum-exp reduction. |
|
|
Normalizes an array by subtracting |
|
One-hot encodes the given indices. |
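A brief sketch of the helpers listed above:

import jax.numpy as jnp
from jax import nn

logits = jnp.array([1.0, 2.0, 3.0])
probs = nn.softmax(logits)                        # normalized along the last axis
log_probs = nn.log_softmax(logits)
labels = nn.one_hot(jnp.array([0, 2]), num_classes=3)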
jax.ops
module#
The functions jax.ops.index_update
, jax.ops.index_add
, etc., which were
deprecated in JAX 0.2.22, have been removed. Please use the
jax.numpy.ndarray.at
property on JAX arrays instead.
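For example, the removed index-update helpers correspond to the .at[] syntax roughly as follows:

import jax.numpy as jnp

x = jnp.zeros(5)
y = x.at[2].set(1.0)   # functional replacement for the old index_update
z = x.at[2].add(3.0)   # functional replacement for the old index_add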
Segment reduction operators#
|
Computes the maximum within segments of an array. |
|
Computes the minimum within segments of an array. |
|
Computes the product within segments of an array. |
|
Computes the sum within segments of an array. |
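A small sketch of a segment reduction:

import jax.numpy as jnp
from jax import ops

data = jnp.array([1.0, 2.0, 3.0, 4.0])
segment_ids = jnp.array([0, 0, 1, 1])
# Sums entries sharing a segment id: -> [3., 7.]
ops.segment_sum(data, segment_ids, num_segments=2)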
jax.profiler
module#
Tracing and time profiling#
Profiling JAX programs describes how to make use of JAX’s tracing and time profiling features.
|
Starts the profiler server on port port. |
|
Starts a profiler trace. |
Stops the currently-running profiler trace. |
|
|
Context manager to take a profiler trace. |
|
Decorator that generates a trace event for the execution of a function. |
Context manager that generates a trace event in the profiler. |
|
|
Context manager that generates a step trace event in the profiler. |
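For example, a trace can be captured with the context manager listed above; the log directory below is just an illustrative path:

import jax
import jax.numpy as jnp

with jax.profiler.trace("/tmp/jax-trace"):   # illustrative log directory
  x = jnp.ones((1000, 1000))
  (x @ x).block_until_ready()   # ensure the work finishes inside the trace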
Device memory profiling#
See Device Memory Profiling for an introduction to JAX’s device memory profiling features.
|
Captures a JAX device memory profile as |
|
Collects a device memory profile and writes it to a file. |
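A minimal sketch of writing a device memory profile to disk (the output file path below is illustrative):

import jax
import jax.numpy as jnp

x = jnp.ones((1024, 1024))
jax.profiler.save_device_memory_profile("/tmp/memory.prof")  # illustrative path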
jax.stages
module#
Interfaces to stages of the compiled execution process.
JAX transformations that compile just in time for execution, such as
jax.jit
and jax.pmap
, also support a common means of explicit
lowering and compilation ahead of time. This module defines types
that represent the stages of this process.
For more, see the AOT walkthrough.
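A minimal sketch of the ahead-of-time path: explicitly lower, inspect, compile, then execute:

import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x) * 2

lowered = jax.jit(f).lower(jnp.arange(4.0))   # a jax.stages.Lowered
print(lowered.as_text()[:80])                 # compiler-input (MLIR) text
compiled = lowered.compile()                  # a jax.stages.Compiled
compiled(jnp.arange(4.0))                     # execute without re-tracing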
Classes#
- class jax.stages.Wrapped(*args, **kwargs)[source]#
A function ready to be specialized, lowered, and compiled.
This protocol reflects the output of functions such as
jax.jit
. Calling it results in JIT (just-in-time) lowering, compilation, and execution. It can also be explicitly lowered prior to compilation, and the result compiled prior to execution.- lower(*args, **kwargs)[source]#
Lower this function explicitly for the given arguments.
A lowered function is staged out of Python and translated to a compiler’s input language, possibly in a backend-dependent manner. It is ready for compilation but not yet compiled.
- Returns:
A
Lowered
instance representing the lowering.- Return type:
- class jax.stages.Lowered(lowering, args_info, out_tree, no_kwargs=False)[source]#
Lowering of a function specialized to argument types and values.
A lowering is a computation ready for compilation. This class carries a lowering together with the remaining information needed to later compile and execute it. It also provides a common API for querying properties of lowered computations across JAX’s various lowering paths (
jit()
,pmap()
, etc.).- as_text(dialect=None)[source]#
A human-readable text representation of this lowering.
Intended for visualization and debugging purposes. This need not be a valid nor reliable serialization. It is relayed directly to external callers.
- compiler_ir(dialect=None)[source]#
An arbitrary object representation of this lowering.
Intended for debugging purposes. This is not a valid nor reliable serialization. The output has no guarantee of consistency across invocations.
Returns
None
if unavailable, e.g. based on backend, compiler, or runtime.
- cost_analysis()[source]#
A summary of execution cost estimates.
Intended for visualization and debugging purposes. The object output by this is some simple data structure that can easily be printed or serialized (e.g. nested dicts, lists, and tuples with numeric leaves). However, its structure can be arbitrary: it may be inconsistent across versions of JAX and jaxlib, or even across invocations.
Returns
None
if unavailable, e.g. based on backend, compiler, or runtime.- Return type:
Any | None
- class jax.stages.Compiled(executable, args_info, out_tree, no_kwargs=False)[source]#
Compiled representation of a function specialized to types/values.
A compiled computation is associated with an executable and the remaining information needed to execute it. It also provides a common API for querying properties of compiled computations across JAX’s various compilation paths and backends.
- Parameters:
args_info (Any)
out_tree (PyTreeDef)
- as_text()[source]#
A human-readable text representation of this executable.
Intended for visualization and debugging purposes. This is not a valid nor reliable serialization.
Returns
None
if unavailable, e.g. based on backend, compiler, or runtime.- Return type:
str | None
- cost_analysis()[source]#
A summary of execution cost estimates.
Intended for visualization and debugging purposes. The object output by this is some simple data structure that can easily be printed or serialized (e.g. nested dicts, lists, and tuples with numeric leaves). However, its structure can be arbitrary: it may be inconsistent across versions of JAX and jaxlib, or even across invocations.
Returns
None
if unavailable, e.g. based on backend, compiler, or runtime.- Return type:
Any | None
- property in_tree: PyTreeDef[source]#
Tree structure of the pair (positional arguments, keyword arguments).
- memory_analysis()[source]#
A summary of estimated memory requirements.
Intended for visualization and debugging purposes. The object output by this is some simple data structure that can easily be printed or serialized (e.g. nested dicts, lists, and tuples with numeric leaves). However, its structure can be arbitrary: it may be inconsistent across versions of JAX and jaxlib, or even across invocations.
Returns
None
if unavailable, e.g. based on backend, compiler, or runtime.- Return type:
Any | None
- runtime_executable()[source]#
An arbitrary object representation of this executable.
Intended for debugging purposes. This is not valid nor reliable serialization. The output has no guarantee of consistency across invocations.
Returns
None
if unavailable, e.g. based on backend, compiler, or runtime.- Return type:
Any | None
jax.tree
module#
Utilities for working with tree-like container data structures.
The jax.tree
namespace contains aliases of utilities from jax.tree_util
.
List of Functions#
|
Call all() over the leaves of a tree. |
|
Flattens a pytree. |
|
Gets the leaves of a pytree. |
|
Maps a multi-input function over pytree args to produce a new pytree. |
|
Call reduce() over the leaves of a tree. |
|
Gets the treedef for a pytree. |
|
Transform a tree having tree structure (outer, inner) into one having structure (inner, outer). |
|
Reconstructs a pytree from the treedef and the leaves. |
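A short sketch of the most common of these utilities:

import jax

params = {"w": [1.0, 2.0], "b": (3.0,)}
doubled = jax.tree.map(lambda v: 2 * v, params)   # maps over every leaf
leaves, treedef = jax.tree.flatten(params)
restored = jax.tree.unflatten(treedef, leaves)    # round-trips the structure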
jax.tree_util
module#
Utilities for working with tree-like container data structures.
This module provides a small set of utility functions for working with tree-like data structures, such as nested tuples, lists, and dicts. We call these structures pytrees. They are trees in that they are defined recursively (any non-pytree is a pytree, i.e. a leaf, and any pytree of pytrees is a pytree) and can be operated on recursively (object identity equivalence is not preserved by mapping operations, and the structures cannot contain reference cycles).
The set of Python types that are considered pytree nodes (e.g. that can be mapped over, rather than treated as leaves) is extensible. There is a single module-level registry of types, and class hierarchy is ignored. By registering a new pytree node type, that type in effect becomes transparent to the utility functions in this file.
The primary purpose of this module is to enable the interoperability between user defined data structures and JAX transformations (e.g. jit). This is not meant to be a general purpose tree-like data structure handling library.
See the JAX pytrees note for examples.
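For example, a user-defined container can be registered as a pytree node roughly as follows (the Point class here is purely illustrative):

import jax
from jax.tree_util import register_pytree_node

class Point:  # illustrative user-defined container
  def __init__(self, x, y):
    self.x, self.y = x, y

register_pytree_node(
    Point,
    lambda p: ((p.x, p.y), None),            # flatten: (children, aux_data)
    lambda aux, children: Point(*children),  # unflatten
)

shifted = jax.tree_util.tree_map(lambda v: v + 1.0, Point(1.0, 2.0))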
List of Functions#
|
A version of functools.partial that works in pytrees. |
|
Tests whether all elements in the given iterable are all leaves. |
|
|
|
Extends the set of types that are considered internal nodes in pytrees. |
Extends the set of types that are considered internal nodes in pytrees. |
|
|
Extends the set of types that are considered internal nodes in pytrees. |
Extends the set of types that are considered internal nodes in pytrees. |
|
|
Call all() over the leaves of a tree. |
|
Flattens a pytree. |
|
Flattens a pytree like |
|
Gets the leaves of a pytree. |
|
Gets the leaves of a pytree like |
|
Maps a multi-input function over pytree args to produce a new pytree. |
|
Maps a multi-input function over pytree key path and args to produce a new pytree. |
Call reduce() over the leaves of a tree. |
|
|
Gets the treedef for a pytree. |
|
Transform a tree having tree structure (outer, inner) into one having structure (inner, outer). |
|
Reconstructs a pytree from the treedef and the leaves. |
|
|
|
|
|
Makes a tuple treedef from an iterable of child treedefs. |
|
Helper to pretty-print a tuple of keys. |
jax.typing
module#
The JAX typing module is where JAX-specific static type annotations live. This submodule is a work in progress; to see the proposal behind the types exported here, see https://jax.readthedocs.io/en/latest/jep/12049-type-annotations.html.
The currently-available types are:
jax.Array
: annotation for any JAX array or tracer (i.e. representations of arrays within JAX transforms).jax.typing.ArrayLike
: annotation for any value that is safe to implicitly cast to a JAX array; this includesjax.Array
,numpy.ndarray
, as well as Python builtin numeric values (e.g.int
,float
, etc.) and numpy scalar values (e.g.numpy.int32
, numpy.float64
, etc.)jax.typing.DTypeLike
: annotation for any value that can be cast to a JAX-compatible dtype; this includes strings (e.g. ‘float32’, ‘int32’), scalar types (e.g. float, np.float32), dtypes (e.g. np.dtype(‘float32’)), or objects with a dtype attribute (e.g. jnp.float32, jnp.int32).
We may add additional types here in future releases.
JAX Typing Best Practices#
When annotating JAX arrays in public API functions, we recommend using ArrayLike
for array inputs, and Array
for array outputs.
For example, your function might look like this:
import numpy as np
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike

def my_function(x: ArrayLike) -> Array:
  # Runtime type validation, Python 3.10 or newer:
  if not isinstance(x, ArrayLike):
    raise TypeError(f"Expected arraylike input; got {x}")
  # Runtime type validation, any Python version:
  if not (isinstance(x, (np.ndarray, Array)) or np.isscalar(x)):
    raise TypeError(f"Expected arraylike input; got {x}")

  # Convert input to jax.Array:
  x_arr = jnp.asarray(x)

  # ... do some computation; JAX functions will return Array types:
  result = x_arr.sum(0) / x_arr.shape[0]

  # return an Array
  return result
Most of JAX’s public APIs follow this pattern. Note in particular that we recommend JAX functions
to not accept sequences such as list
or tuple
in place of arrays, as this can
cause extra overhead in JAX transforms like jit()
and can behave in unexpected ways with
batch-wise transforms like vmap()
or jax.pmap()
. For more information on this,
see Non-array inputs NumPy vs JAX.
List of Members#
Type annotation for JAX array-like objects. |
|
jax.extend
module#
Modules for JAX extensions.
The jax.extend
package provides modules for access to JAX
internal machinery. See
JEP #15856.
API policy#
Unlike the public API, this package offers no compatibility guarantee across releases. Breaking changes will be announced via the JAX project changelog.
Modules#
jax.extend.linear_util
module#
|
Represents a function f to which transforms are to be applied. |
|
Memoization decorator for functions taking a WrappedFun as first argument. |
|
|
Adds one more transformation to a WrappedFun. |
|
Adds one more transformation with auxiliary output to a WrappedFun. |
|
|
Wraps function f as a WrappedFun, suitable for transformation. |
jax.extend.mlir
module#
jax.extend.random
module#
|
|
|
|
|
Apply the Threefry 2x32 hash. |
Specifies PRNG key shape and operations. |
|
Specifies PRNG key shape and operations. |
|
Specifies PRNG key shape and operations. |
jax.example_libraries
module#
JAX provides some small, experimental libraries for machine learning. These libraries are in part about providing tools and in part about serving as examples for how to build such libraries using JAX. Each one is only <300 source lines of code, so take a look inside and adapt them as you need!
Note
Each mini-library is meant to be an inspiration, but not a prescription.
To serve that purpose, it is best to keep their code samples minimal; so we generally will not merge PRs adding new features. Instead, please send your lovely pull requests and design ideas to more fully-featured libraries like Haiku or Flax.
jax.example_libraries.optimizers
module#
Examples of how to write optimizers with JAX.
You likely do not mean to import this module! The optimizers in this library are intended as examples only. If you are looking for a fully featured optimizer library, two good options are JAXopt and Optax.
This module contains some convenient optimizer definitions, specifically initialization and update functions, which can be used with ndarrays or arbitrarily-nested tuple/list/dicts of ndarrays.
An optimizer is modeled as an (init_fun, update_fun, get_params)
triple of
functions, where the component functions have these signatures:
init_fun(params)
Args:
params: pytree representing the initial parameters.
Returns:
A pytree representing the initial optimizer state, which includes the
initial parameters and may also include auxiliary values like initial
momentum. The optimizer state pytree structure generally differs from that
of `params`.
update_fun(step, grads, opt_state)
Args:
step: integer representing the step index.
grads: a pytree with the same structure as `get_params(opt_state)`
representing the gradients to be used in updating the optimizer state.
opt_state: a pytree representing the optimizer state to be updated.
Returns:
A pytree with the same structure as the `opt_state` argument representing
the updated optimizer state.
get_params(opt_state)
Args:
opt_state: pytree representing an optimizer state.
Returns:
A pytree representing the parameters extracted from `opt_state`, such that
the invariant `params == get_params(init_fun(params))` holds true.
Notice that an optimizer implementation has a lot of flexibility in the form of opt_state: it just has to be a pytree of JaxTypes (so that it can be passed to the JAX transforms defined in api.py) and it has to be consumable by update_fun and get_params.
Example Usage:
opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
opt_state = opt_init(params)

def step(step, opt_state):
  value, grads = jax.value_and_grad(loss_fn)(get_params(opt_state))
  opt_state = opt_update(step, grads, opt_state)
  return value, opt_state

for i in range(num_steps):
  value, opt_state = step(i, opt_state)
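A fully runnable variant of the sketch above, assuming a trivial quadratic loss:

import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

def loss_fn(params):
  return jnp.sum(params ** 2)

opt_init, opt_update, get_params = optimizers.sgd(step_size=0.1)
opt_state = opt_init(jnp.arange(4.0))

@jax.jit
def step(i, opt_state):
  value, grads = jax.value_and_grad(loss_fn)(get_params(opt_state))
  return value, opt_update(i, grads, opt_state)

for i in range(100):
  value, opt_state = step(i, opt_state)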
- class jax.example_libraries.optimizers.JoinPoint(subtree)[source]#
Bases:
object
Marks the boundary between two joined (nested) pytrees.
- class jax.example_libraries.optimizers.Optimizer(init_fn, update_fn, params_fn)[source]#
Bases:
NamedTuple
- Parameters:
init_fn (Callable[[Any], OptimizerState])
update_fn (Callable[[int, Any, OptimizerState], OptimizerState])
params_fn (Callable[[OptimizerState], Any])
- init_fn: Callable[[Any], OptimizerState]#
Alias for field number 0
- params_fn: Callable[[OptimizerState], Any]#
Alias for field number 2
- update_fn: Callable[[int, Any, OptimizerState], OptimizerState]#
Alias for field number 1
- class jax.example_libraries.optimizers.OptimizerState(packed_state, tree_def, subtree_defs)#
Bases:
tuple
- packed_state#
Alias for field number 0
- subtree_defs#
Alias for field number 2
- tree_def#
Alias for field number 1
- jax.example_libraries.optimizers.adagrad(step_size, momentum=0.9)[source]#
Construct optimizer triple for Adagrad.
Adaptive Subgradient Methods for Online Learning and Stochastic Optimization: http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf
- Parameters:
step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.
momentum – optional, a positive scalar value for momentum
- Returns:
An (init_fun, update_fun, get_params) triple.
- jax.example_libraries.optimizers.adam(step_size, b1=0.9, b2=0.999, eps=1e-08)[source]#
Construct optimizer triple for Adam.
- Parameters:
step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.
b1 – optional, a positive scalar value for beta_1, the exponential decay rate for the first moment estimates (default 0.9).
b2 – optional, a positive scalar value for beta_2, the exponential decay rate for the second moment estimates (default 0.999).
eps – optional, a positive scalar value for epsilon, a small constant for numerical stability (default 1e-8).
- Returns:
An (init_fun, update_fun, get_params) triple.
- jax.example_libraries.optimizers.adamax(step_size, b1=0.9, b2=0.999, eps=1e-08)[source]#
Construct optimizer triple for AdaMax (a variant of Adam based on infinity norm).
- Parameters:
step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.
b1 – optional, a positive scalar value for beta_1, the exponential decay rate for the first moment estimates (default 0.9).
b2 – optional, a positive scalar value for beta_2, the exponential decay rate for the second moment estimates (default 0.999).
eps – optional, a positive scalar value for epsilon, a small constant for numerical stability (default 1e-8).
- Returns:
An (init_fun, update_fun, get_params) triple.
- jax.example_libraries.optimizers.clip_grads(grad_tree, max_norm)[source]#
Clip gradients stored as a pytree of arrays to maximum norm max_norm.
- jax.example_libraries.optimizers.inverse_time_decay(step_size, decay_steps, decay_rate, staircase=False)[source]#
- jax.example_libraries.optimizers.l2_norm(tree)[source]#
Compute the l2 norm of a pytree of arrays. Useful for weight decay.
- jax.example_libraries.optimizers.momentum(step_size, mass)[source]#
Construct optimizer triple for SGD with momentum.
- jax.example_libraries.optimizers.nesterov(step_size, mass)[source]#
Construct optimizer triple for SGD with Nesterov momentum.
- jax.example_libraries.optimizers.optimizer(opt_maker)[source]#
Decorator to make an optimizer defined for arrays generalize to containers.
With this decorator, you can write init, update, and get_params functions that each operate only on single arrays, and convert them to corresponding functions that operate on pytrees of parameters. See the optimizers defined in optimizers.py for examples.
- Parameters:
opt_maker (Callable[[...], tuple[Callable[[Any], Any], Callable[[int, Any, Any], Any], Callable[[Any], Any]]]) –
a function that returns an
(init_fun, update_fun, get_params)
triple of functions that might only work with ndarrays, as perinit_fun :: ndarray -> OptStatePytree ndarray update_fun :: OptStatePytree ndarray -> OptStatePytree ndarray get_params :: OptStatePytree ndarray -> ndarray
- Returns:
An
(init_fun, update_fun, get_params)
triple of functions that work on arbitrary pytrees, as perinit_fun :: ParameterPytree ndarray -> OptimizerState update_fun :: OptimizerState -> OptimizerState get_params :: OptimizerState -> ParameterPytree ndarray
The OptimizerState pytree type used by the returned functions is isomorphic to
ParameterPytree (OptStatePytree ndarray)
, but may store the state instead as e.g. a partially-flattened data structure for performance.- Return type:
- jax.example_libraries.optimizers.pack_optimizer_state(marked_pytree)[source]#
Converts a marked pytree to an OptimizerState.
The inverse of unpack_optimizer_state. Converts a marked pytree with the leaves of the outer pytree represented as JoinPoints back into an OptimizerState. This function is intended to be useful when deserializing optimizer states.
- Parameters:
marked_pytree – A pytree containing JoinPoint leaves that hold more pytrees.
- Returns:
An equivalent OptimizerState to the input argument.
- jax.example_libraries.optimizers.polynomial_decay(step_size, decay_steps, final_step_size, power=1.0)[source]#
- jax.example_libraries.optimizers.rmsprop(step_size, gamma=0.9, eps=1e-08)[source]#
Construct optimizer triple for RMSProp.
- Parameters:
step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.
gamma – Decay parameter.
eps – Epsilon parameter.
- Returns:
An (init_fun, update_fun, get_params) triple.
- jax.example_libraries.optimizers.rmsprop_momentum(step_size, gamma=0.9, eps=1e-08, momentum=0.9)[source]#
Construct optimizer triple for RMSProp with momentum.
This optimizer is separate from the rmsprop optimizer because it needs to keep track of additional parameters.
- Parameters:
step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.
gamma – Decay parameter.
eps – Epsilon parameter.
momentum – Momentum parameter.
- Returns:
An (init_fun, update_fun, get_params) triple.
- jax.example_libraries.optimizers.sgd(step_size)[source]#
Construct optimizer triple for stochastic gradient descent.
- Parameters:
step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.
- Returns:
An (init_fun, update_fun, get_params) triple.
- jax.example_libraries.optimizers.sm3(step_size, momentum=0.9)[source]#
Construct optimizer triple for SM3.
Memory-Efficient Adaptive Optimization for Large-Scale Learning. https://arxiv.org/abs/1901.11150
- Parameters:
step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.
momentum – optional, a positive scalar value for momentum
- Returns:
An (init_fun, update_fun, get_params) triple.
- jax.example_libraries.optimizers.unpack_optimizer_state(opt_state)[source]#
Converts an OptimizerState to a marked pytree.
Converts an OptimizerState to a marked pytree with the leaves of the outer pytree represented as JoinPoints to avoid losing information. This function is intended to be useful when serializing optimizer states.
- Parameters:
opt_state – An OptimizerState
- Returns:
A pytree with JoinPoint leaves that contain a second level of pytrees.
jax.example_libraries.stax
module#
Stax is a small but flexible neural net specification library from scratch.
You likely do not mean to import this module! Stax is intended as an example library only. There are a number of other much more fully-featured neural network libraries for JAX, including Flax from Google, and Haiku from DeepMind.
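For orientation only, a small network might be assembled with the stax combinators roughly like this:

import jax
import jax.numpy as jnp
from jax.example_libraries import stax

# A small MLP: Dense -> Relu -> Dense.
init_fn, apply_fn = stax.serial(
    stax.Dense(128), stax.Relu,
    stax.Dense(10),
)
key = jax.random.key(0)
out_shape, params = init_fn(key, (-1, 784))   # input shape, with -1 for the batch dimension
logits = apply_fn(params, jnp.ones((32, 784)))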
- jax.example_libraries.stax.AvgPool(window_shape, strides=None, padding='VALID', spec=None)[source]#
Layer construction function for a pooling layer.
- jax.example_libraries.stax.BatchNorm(axis=(0, 1, 2), epsilon=1e-05, center=True, scale=True, beta_init=<function zeros>, gamma_init=<function ones>)[source]#
Layer construction function for a batch normalization layer.
- jax.example_libraries.stax.Conv(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)#
Layer construction function for a general convolution layer.
- jax.example_libraries.stax.Conv1DTranspose(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)#
Layer construction function for a general transposed-convolution layer.
- jax.example_libraries.stax.ConvTranspose(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)#
Layer construction function for a general transposed-convolution layer.
- jax.example_libraries.stax.Dense(out_dim, W_init=<function variance_scaling.<locals>.init>, b_init=<function normal.<locals>.init>)[source]#
Layer constructor function for a dense (fully-connected) layer.
- jax.example_libraries.stax.Dropout(rate, mode='train')[source]#
Layer construction function for a dropout layer with given rate.
- jax.example_libraries.stax.FanInConcat(axis=-1)[source]#
Layer construction function for a fan-in concatenation layer.
- jax.example_libraries.stax.GeneralConv(dimension_numbers, out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)[source]#
Layer construction function for a general convolution layer.
- jax.example_libraries.stax.GeneralConvTranspose(dimension_numbers, out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)[source]#
Layer construction function for a general transposed-convolution layer.
- jax.example_libraries.stax.MaxPool(window_shape, strides=None, padding='VALID', spec=None)[source]#
Layer construction function for a pooling layer.
- jax.example_libraries.stax.SumPool(window_shape, strides=None, padding='VALID', spec=None)[source]#
Layer construction function for a pooling layer.
- jax.example_libraries.stax.elementwise(fun, **fun_kwargs)[source]#
Layer that applies a scalar function elementwise on its inputs.
- jax.example_libraries.stax.parallel(*layers)[source]#
Combinator for composing layers in parallel.
The layer resulting from this combinator is often used with the FanOut and FanInSum layers.
- Parameters:
*layers – a sequence of layers, each an (init_fun, apply_fun) pair.
- Returns:
A new layer, meaning an (init_fun, apply_fun) pair, representing the parallel composition of the given sequence of layers. In particular, the returned layer takes a sequence of inputs and returns a sequence of outputs with the same length as the argument layers.
- jax.example_libraries.stax.serial(*layers)[source]#
Combinator for composing layers in serial.
- Parameters:
*layers – a sequence of layers, each an (init_fun, apply_fun) pair.
- Returns:
A new layer, meaning an (init_fun, apply_fun) pair, representing the serial composition of the given sequence of layers.
- jax.example_libraries.stax.shape_dependent(make_layer)[source]#
Combinator to delay layer constructor pair until input shapes are known.
- Parameters:
make_layer – a one-argument function that takes an input shape as an argument (a tuple of positive integers) and returns an (init_fun, apply_fun) pair.
- Returns:
A new layer, meaning an (init_fun, apply_fun) pair, representing the same layer as returned by make_layer but with its construction delayed until input shapes are known.
jax.experimental
module#
jax.experimental.optix
has been moved into its own Python package
(deepmind/optax).
jax.experimental.ann
has been moved into jax.lax
.
Experimental Modules#
jax.experimental.array_api
module#
This module includes experimental JAX support for the Python array API standard. Support for this is currently experimental and not fully complete.
Example Usage:
>>> from jax.experimental import array_api as xp
>>> xp.__array_api_version__
'2023.12'
>>> arr = xp.arange(1000)
>>> arr.sum()
Array(499500, dtype=int32)
The xp
namespace is the array API compliant analog of jax.numpy
,
and implements most of the API listed in the standard.
jax.experimental.checkify
module#
API#
|
Functionalize check calls in fun, and optionally add run-time error checks. |
|
Check a predicate, add an error with msg if predicate is False. |
|
Raise an Exception if |
The module also exports several predefined error-check sets (frozensets of error categories) that can be passed to checkify.checkify via its errors argument.
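A minimal sketch of functionalizing runtime checks with checkify:

import jax
import jax.numpy as jnp
from jax.experimental import checkify

def f(x):
  checkify.check(x > 0, "x must be positive")
  return jnp.log(x)

checked_f = checkify.checkify(f)
err, out = jax.jit(checked_f)(1.0)
err.throw()   # no-op here; raises if a check failed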
jax.experimental.host_callback
module#
Primitives for calling Python functions on the host from JAX accelerator code.
Warning
The host_callback APIs are deprecated as of March 20, 2024. The functionality is subsumed by the new JAX external callbacks See google/jax#20385.
This module introduces the host callback functions call()
,
id_tap()
, and id_print()
, that send their arguments from the device
to the host and invoke user-defined Python functions on the host, optionally
returning results back to the device computation.
We show below how these functions can be used. We start with call()
,
and we discuss examples of calling from JAX to arbitrary Python functions
on the CPU, e.g., to use NumPy CPU custom kernels. Then we
show uses of id_tap()
and id_print()
, which have the restriction
that they cannot return values from the host to the device.
These primitives are generally faster
because they are executed asynchronously with the device code.
In particular, they can be used to tap into and to debug JAX code.
Using call()
to call a host function and return results to device#
Use call()
to invoke a computation on the host and return
NumPy arrays to the device computation.
Host computation is useful, e.g., when a device computation needs some data
that requires I/O on the host, or it needs a library that is available on the
host and you do not want to code it in JAX.
For example, eigen decomposition for general matrices in JAX does not work on TPU.
We can call the Numpy implementation from any JAX accelerator computation,
using a host computation:
# This function runs on the host
def host_eig(m: np.ndarray) -> np.ndarray:
  return np.linalg.eigvals(m)

# This function is used in JAX
def device_fun(m):
  # We send "m" to the host, asking it to call "host_eig" and return the result.
  # We have to specify the result shape and dtype, either in the form of an
  # example return value or any object that has `shape` and `dtype` attributes,
  # e.g., a NumPy array or a `jax.ShapeDtypeStruct`.
  return hcb.call(host_eig, m,
                  # Given an input of shape (..., d, d), eig output has shape (..., d)
                  result_shape=jax.ShapeDtypeStruct(m.shape[:-1], m.dtype))
The call()
function and the Python host function both take a single argument
and return a single result, but those can be pytrees. Note that we must tell
the call()
what shape and dtype to expect from the host invocation, using
the result_shape
keyword argument.
This is important because the device code is compiled with that expectation.
There will be an error raised at runtime if the actual invocation produces a
different result shape. In general, such errors and also exceptions raised
by the host computation may be difficult to debug. See the Debugging section
below.
This is a problem for call()
but not for id_tap()
because for the
latter the device code does not expect a returned value.
The call()
API can be used inside a jit or pmap computation or inside
cond/scan/while control flow. When used inside jax.pmap()
, there will be
separate calls to the host from each of the participating devices:
def host_sin(x, *, device):
  # The ``device`` argument is passed due to ``call_with_device=True`` below.
  print(f"Invoking host_sin with {x.shape} on {device}")
  return np.sin(x)

# Use pmap to run the computation on two devices
jax.pmap(lambda x: hcb.call(host_sin, x,
                            result_shape=x,
                            # Ask that the `host_sin` function be passed `device=dev`
                            call_with_device=True))(
    np.ones((2, 4), dtype=np.float32))

# prints (in arbitrary order)
# Invoking host_sin with (4,) on cpu:0
# Invoking host_sin with (4,) on cpu:1
Note that call()
does not support any JAX transformations, but as we
show below one can make use of the
existing support for Custom differentiation in JAX.
Using id_tap()
to call a Python function on the host, with no returned values#
The id_tap()
and id_print()
are special cases of call()
, when
you just want the side effects of your Python callback. These functions have
the advantage that once the arguments have been sent to the host, the device
computation can proceed without waiting for the Python callback to return.
For id_tap()
you can specify your Python callback to be called, while
id_print()
uses a built-in callback that prints the arguments to
stdout on the host.
The Python function passed
to id_tap()
takes two positional arguments (the value tapped
from the device computation along with a transforms
tuple,
described below). Optionally, the function may be passed a keyword argument
device
with the Device from which the value was tapped.
A few examples:
def host_func(arg, transforms):
  ...do something with arg...

# calls host_func(2x, []) on host
id_tap(host_func, 2 * x)

# calls host_func((2x, 3x), [])
id_tap(host_func, (2 * x, 3 * x))  # The argument can be a pytree

# calls host_func(2x, [], device=jax.devices()[0])
id_tap(host_func, 2 * x, tap_with_device=True)  # Pass the device to the tap

# calls host_func(2x, [], what='activation')
id_tap(functools.partial(host_func, what='activation'), 2 * x)

# calls host_func(dict(x=x, y=y), what='data')
id_tap(lambda tap, transforms: host_func(tap, what='data'), dict(x=x, y=y))
The above examples can all be adapted to use id_print()
instead, with
the difference that id_print()
prints on the host the positional argument,
along with any additional kwargs and the automatic kwarg transforms
.
Using barrier_wait()
to wait until all callbacks have executed#
If your Python callbacks have side-effects you may need to wait until the
computation has finished to ensure that the side-effects have been observed.
You can use the barrier_wait()
function for that purpose:
accumulator = []

def host_log(arg, transforms):
  # We just record the arguments in a list
  accumulator.append(arg)

def device_fun(x):
  id_tap(host_log, x)
  id_tap(host_log, 2. * x)

jax.jit(device_fun)(1.)
jax.jit(device_fun)(1.)

# At this point, we have started two computations, each with two
# taps, but they may not have yet executed.
barrier_wait()
# Now we know that all the computations started before `barrier_wait`
# on all devices, have finished, and all the callbacks have finished
# executing.
Note that barrier_wait()
will start one
tiny computation with one tap on each of the jax.local_devices() and
will wait for all these taps to be received.
An alternative to using barrier_wait()
is to just wait for the end
of the computation, if all the callbacks are call()
:
accumulator = []

def host_log(arg):
  # We just record the arguments in a list
  accumulator.append(arg)
  return 0.  # return something

def device_fun(x):
  y = call(host_log, x, result_shape=jax.ShapeDtypeStruct((), np.float32))
  z = call(host_log, 2. * x, result_shape=jax.ShapeDtypeStruct((), np.float32))
  return y + z  # return something that uses both results

res1 = jax.jit(device_fun)(1.)
res2 = jax.jit(device_fun)(1.)
res1.block_until_ready()
res2.block_until_ready()
Behavior under parallelization transformations#
In presence of jax.pmap()
the code will run on multiple devices and
each device will tap its values independently.
It may be helpful to use the tap_with_device
option for id_print()
or id_tap()
, so that you see which device is sending which data:
jax.pmap(power3, devices=jax.local_devices()[:2])(np.array([3., 4.]))
# device=cpu:0 what=x,x^2: (3., 9.) # from the first device
# device=cpu:1 what=x,x^2: (4., 16.) # from the second device
When using jax.pmap()
with multiple devices on multiple hosts, every
host will receive callbacks from all of its local devices, with an operand
that corresponds to each device slice. For a
call()
, the callback must return to each device only the slice of the
result that pertains to the corresponding device.
When using the experimental pjit.pjit()
the code will run on multiple
devices on different shards of the input. The current implementation of
host callbacks will ensure that a single device will collect and outfeed
the entire operand, in a single callback. The callback function is supposed
to return the entire array, which will then be sent in a single infeed to the
same device that issued the outfeed. This device is then responsible for
sending the required shards to the other devices:
with jax.sharding.Mesh(jax.local_devices()[:2], ["d"]):
  pjit.pjit(power3, in_shardings=(P("d"),),
            out_shardings=(P("d"),))(np.array([3., 4.]))

# device=TPU:0 what=x,x^2: ( [3., 4.],
#                            [9., 16.] )
Note that the collection of the operand on one device may result in OOM if the operand was sharded across devices.
When using pjit.pjit()
with multiple devices on multiple hosts, only
the host for the device 0 (w.r.t. the mesh) will receive the callback, with
the operand collected
from all participating devices on all hosts. For a call()
, the callback
must return the entire array for all devices on all hosts.
Behavior under JAX autodiff transformations#
When used under a JAX autodiff transformation, the host callback functions operate on the primal values only. Consider the following example:
def power3(x):
  y = x * x
  # Print both 'x' and 'x^2'. Must pack as a tuple.
  hcb.id_print((x, y), what="x,x^2")
  return y * x

power3(3.)
# what: x,x^2 : (3., 9.)
(You can see these examples tested in host_callback_test.HostCallbackTapTest.test_tap_transforms.)
When used under jax.jvp()
there will be one callback with the primal
values only:
jax.jvp(power3, (3.,), (0.1,))
# what: x,x^2 : (3., 9.)
Similarly for jax.grad()
, we get a callback from the forward computation
only:
jax.grad(power3)(3.)
# what: x,x^2 : (3., 9.)
If you want to invoke the callback on the tangents during a jax.jvp()
,
you can use a custom_jvp. For example, you can define a function that does
nothing interesting except that its custom_jvp will print the tangents:
@jax.custom_jvp
def print_tangents(arg):
  return None

@print_tangents.defjvp
def print_tangents_jvp(primals, tangents):
  arg_dot, = tangents
  hcb.id_print(arg_dot, what="tangents")
  return primals, tangents
Then you use this function in the places where you want to tap the tangents:
def power3_with_tangents(x):
  y = x * x
  # Print both 'x' and 'x^2'. Must pack as a tuple.
  hcb.id_print((x, y), what="x,x^2")
  print_tangents((x, y))
  return y * x

jax.jvp(power3_with_tangents, (3.,), (0.1,))
# what: x,x^2 : (3., 9.)
# what: tangents : (0.1, 0.6)
You can do a similar thing for the cotangents during jax.grad()
. This
time you must be careful to use in the rest of the computation the values whose
cotangents you want to tap. Hence we make the print_cotangents
return
its argument:
@jax.custom_vjp
def print_cotangents(arg):
  # Must return the argument for which we want the cotangent.
  return arg

# f_fwd: a -> (b, residual)
def print_cotangents_fwd(arg):
  return print_cotangents(arg), None

# f_bwd: (residual, CT b) -> [CT a]
def print_cotangents_bwd(residual, ct_b):
  hcb.id_print(ct_b, what="cotangents", output_stream=testing_stream)
  return ct_b,

print_cotangents.defvjp(print_cotangents_fwd, print_cotangents_bwd)

def power3_with_cotangents(x):
  y = x * x
  # Print both 'x' and 'x^2'. Must pack as a tuple.
  hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream)
  (x1, y1) = print_cotangents((x, y))
  # Must use the output of print_cotangents
  return y1 * x1

jax.grad(power3_with_cotangents)(3.)
# what: x,x^2 : (3., 9.)
# what: cotangents : (9., 3.)
If you use ad_checkpoint.checkpoint()
to rematerialize the residuals
for the backward pass, then the callbacks from the primal computation will
be called twice:
jax.grad(lambda x: power3(ad_checkpoint.checkpoint(power3)(x)))(3.)
# what: x,x^2 : (3., 9.)
# what: x,x^2 : (27., 729.)
# what: x,x^2 : (3., 9.)
The callbacks are, in order from: the primal computation of the inner power3
,
the primal computation of the outer power3
, and the rematerialization
of the residuals for the inner power3
.
Behavior under jax.vmap#
The host callback functions id_print()
and id_tap()
support the
vectorization transformation jax.vmap()
.
For jax.vmap()
the arguments to the callback are batched,
and the callback function is
passed an additional special transforms
containing a list of transformation descriptors
in the form ("batch", {"batch_dims": ...}), where ... denotes the
batched dimensions for the tapped values (one entry per argument, None denotes an argument that was broadcast).

jax.vmap(power3)(np.array([2., 3.]))
# transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 : ([2., 3.], [4., 9.])
See documentation for id_tap()
, id_print()
, and call()
.
For more usage examples, see tests/host_callback_test.py.
Using call()
to call a TensorFlow function, with reverse-mode autodiff support#
Another possible use for host computation is to invoke a library written for
another framework, such as TensorFlow.
In this case it becomes interesting to support JAX autodiff for host callbacks
by deferring to the autodiff mechanism in TensorFlow,
using the jax.custom_vjp()
mechanism.
This is relatively easy to do, once one understands both the JAX custom VJP
and the TensorFlow autodiff mechanisms.
The code for how this can be done is shown in the call_tf_full_ad
function in host_callback_to_tf_test.py.
This example supports arbitrary higher-order differentiation as well.
Note that if you just want to call TensorFlow functions from JAX, you can also use the jax2tf.call_tf function.
Using call()
to call a JAX function on another device, with reverse-mode autodiff support#
It should not be surprising that we can use host computation to invoke a JAX computation on another device. The arguments are sent from the accelerator to the host, and then to the outside device on which the JAX host computation will run, and then the results are sent back to the original accelerator.
The code for how this can be done is shown in the call_jax_other_device function
in host_callback_test.py.
Low-level details and debugging#
The host callback functions will be executed for each device in the order in which the send operations were performed on the device.
The host callback functions for multiple devices may be interleaved.
The data from the devices is received by separate threads managed by the JAX
runtime (one thread per device). The runtime maintains a buffer of
configurable size (see the flag --jax_host_callback_max_queue_byte_size
).
When the buffer is full, all the receiving threads are paused
which eventually pauses the computation on devices. The runtime has one
additional thread for each device to invoke the Python user functions with the
received data. If the processing of the callbacks is slow, it may actually
lead to the runtime buffer filling up, and eventually pausing the computation
on the devices when they need to send something.
For more details on the outfeed receiver runtime mechanism see
runtime code.
In order to pause the execution until all data from computations already
started on devices has arrived and has been processed, use barrier_wait()
.
Exceptions from the user-defined callback functions are logged along with their
stack traces, but the receiving threads are not stopped. Instead the last
exception is recorded and the subsequent barrier_wait()
will
raise CallbackException
if any exception had occurred
in one of the tap functions. This exception will include the text and the
stack trace of the last exception encountered.
One further complication arises for callback functions that must return
results to the call origin device, such as call()
. This is handled
differently on CPU/GPU devices compared to TPU devices.
On CPU/GPU devices, in order to avoid the device computation
being stuck waiting for a result that will never arrive, in case of any
error during the processing of the callback (whether raised by the user-code
itself or due to a mismatch of the returned value and the expected return_shape)
we send the device a “fake” result of shape int8[12345]
.
This will make the device
computation abort because the received data is different than the one that
it expects. On CPU the runtime will crash with a distinctive error message:
Check failed: buffer->length() == buffer_length (12345 vs. ...)

On GPU, the failure is more user-friendly and will be surfaced to the Python program as:

RET_CHECK failure ... Mismatch between infeed source buffer shape s8[12345] ...
To debug the underlying cause for these messages, see the Debugging section.
On TPU devices, there is currently no shape check for infeed, so we take the safer route of not sending this fake result in case of errors. This means that the computation will hang, and no exception will be raised (but any exceptions in the callback functions will still appear in the logs).
The current implementation uses the outfeed mechanism provided by XLA. The mechanism itself is quite primitive in the sense that a receiver must know exactly the shape of each incoming packet, and how many packets are expected. This makes it hard to use for multiple kinds of data in the same computation, and it is practically impossible to use it under conditionals or in loops of non-constant iteration count. Furthermore, code that uses the outfeed mechanism directly cannot be transformed by JAX. All these limitations are addressed by the host callback functions. The tapping API introduced here makes it easy to share the outfeed mechanism for multiple purposes, while supporting all transformations.
Note that after you have used the host callback functions, you cannot
use lax.outfeed directly. You may want to stop_outfeed_receiver()
if you later need to use lax.outfeed.
Since the actual calls to your callback functions are made from the C++
receiver, it may be hard to debug the calls. In particular, the stack trace
will not include the calling code. You can use the flag
jax_host_callback_inline
(or the environment variable
JAX_HOST_CALLBACK_INLINE
) to ensure that the calls to the callbacks are
inlined. This works only if the calls are outside a staging context
(jit()
or a control-flow primitive).
The C++ receiver
is started automatically on the first call to id_tap()
. In order to stop
it properly, upon start an atexit
handler is registered to call
barrier_wait()
with the logging name “at_exit”.
There are a few environment variables that you can use to turn on logging for the C++ outfeed receiver backend.
TF_CPP_MIN_LOG_LEVEL=0
: will turn on INFO logging, needed for all below.
TF_CPP_MIN_VLOG_LEVEL=3
: will make all VLOG logging up to level 3 behave like INFO logs. This may be too much, but you will see which modules are logging relevant info, and then you can select which modules to log from.
TF_CPP_VMODULE=<module_name>=3
(the module name can be either C++ or Python, without the extension).
You should also use the --verbosity=2
flag so that you see the logs
from Python.
For example, you can try to enable logging in the host_callback
module:
TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=host_callback=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple
If you want to enable logging in lower-level implementation modules try:
TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,infeed_thunk=3,cpu_transfer_manager=3,cpu_runtime=3,xfeed_manager=3,pjrt_client=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple
(For bazel tests use --test_arg=--vmodule=…)
- Still to do:
More performance tests.
Explore implementation with outside compilation for TPU.
Explore implementation with XLA CustomCall for CPU and GPU.
API#
|
Host-callback tap primitive, like identity function with a call to |
|
Like |
|
Make a call to the host, and expect a result. |
|
Blocks the calling thread until all current outfeed is processed. |
Signals that some callback function had exceptions. |
jax.experimental.maps
module#
API#
|
Assign a positional signature to a program that uses named array axes. |
jax.experimental.pjit
module#
API#
- jax.experimental.pjit.pjit(fun, in_shardings=UnspecifiedValue, out_shardings=UnspecifiedValue, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None)[source]#
Makes
fun
compiled and automatically partitioned across multiple devices.NOTE: This function is now equivalent to jax.jit please use that instead. The returned function has semantics equivalent to those of
fun
, but is compiled to an XLA computation that runs across multiple devices (e.g. multiple GPUs or multiple TPU cores). This can be useful if the jitted version offun
would not fit in a single device’s memory, or to speed upfun
by running each operation in parallel across multiple devices.The partitioning over devices happens automatically based on the propagation of the input partitioning specified in
in_shardings
and the output partitioning specified inout_shardings
. The resources specified in those two arguments must refer to mesh axes, as defined by thejax.sharding.Mesh()
context manager. Note that the mesh definition atpjit()
application time is ignored, and the returned function will use the mesh definition available at each call site.Inputs to a
pjit()
’d function will be automatically partitioned across devices if they’re not already correctly partitioned based onin_shardings
. In some scenarios, ensuring that the inputs are already correctly pre-partitioned can increase performance. For example, if passing the output of onepjit()
’d function to anotherpjit()
’d function (or the samepjit()
’d function in a loop), make sure the relevantout_shardings
match the correspondingin_shardings
.Note
Multi-process platforms: On multi-process platforms such as TPU pods,
pjit()
can be used to run computations across all available devices across processes. To achieve this,pjit()
is designed to be used in SPMD Python programs, where every process is running the same Python code such that all processes run the samepjit()
’d function in the same order.When running in this configuration, the mesh should contain devices across all processes. However, any input argument dimensions partitioned over multi-process mesh axes should be of size equal to the corresponding local mesh axis size, and outputs will be similarly sized according to the local mesh.
fun
will still be executed across all devices in the mesh, including those from other processes, and will be given a global view of the data spread across multiple processes as a single array. However, outside ofpjit()
every process only “sees” its local piece of the input and output, corresponding to its local sub-mesh.This means that each process’s participating local devices must form a _contiguous_ local sub-mesh within the full global mesh. A contiguous sub-mesh is one where all of its devices are adjacent within the global mesh, and form a rectangular prism.
The SPMD model also requires that the same multi-process
pjit()
’d functions must be run in the same order on all processes, but they can be interspersed with arbitrary operations running in a single process.- Parameters:
fun (Callable) – Function to be compiled. Should be a pure function, as side-effects may only be executed once. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. Positional arguments indicated by
static_argnums
can be anything at all, provided they are hashable and have an equality operation defined. Static arguments are included as part of a compilation cache key, which is why hash and equality operators must be defined.in_shardings –
Pytree of structure matching that of arguments to
fun
, with all actual arguments replaced by resource assignment specifications. It is also valid to specify a pytree prefix (e.g. one value in place of a whole subtree), in which case the leaves get broadcast to all values in that subtree.The
in_shardings
argument is optional. JAX will infer the shardings from the inputjax.Array
’s, and defaults to replicating the input if the sharding cannot be inferred.The valid resource assignment specifications are:
XLACompatibleSharding
, which will decide how the value will be partitioned. With this, using a mesh context manager is not required.
None is a special case whose semantics are: if the mesh context manager is not provided, JAX has the freedom to choose whatever sharding it wants. For in_shardings, JAX will mark it as replicated, but this behavior can change in the future. For out_shardings, we will rely on the XLA GSPMD partitioner to determine the output shardings.
If the mesh context manager is provided, None will imply that the value will be replicated on all devices of the mesh.
For backwards compatibility, in_shardings still supports ingesting
PartitionSpec
. This option can only be used with the mesh context manager.PartitionSpec
, a tuple of length at most equal to the rank of the partitioned value. Each element can be aNone
, a mesh axis or a tuple of mesh axes, and specifies the set of resources assigned to partition the value’s dimension matching its position in the spec.
The size of every dimension has to be a multiple of the total number of resources assigned to it.
out_shardings – Like
in_shardings
, but specifies resource assignment for function outputs. Theout_shardings
argument is optional. If not specified,jax.jit()
will use GSPMD’s sharding propagation to determine how to shard the outputs.static_argnums (int | Sequence[int] | None) –
An optional int or collection of ints that specify which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded in Python (during tracing), and so the corresponding argument values can be any Python object.
Static arguments should be hashable, meaning both
__hash__
and__eq__
are implemented, and immutable. Calling the jitted function with different values for these constants will trigger recompilation. Arguments that are not arrays or containers thereof must be marked as static.If
static_argnums
is not provided, no arguments are treated as static.static_argnames (str | Iterable[str] | None) – An optional string or collection of strings specifying which named arguments to treat as static (compile-time constant). See the comment on
static_argnums
for details. If not provided butstatic_argnums
is set, the default is based on callinginspect.signature(fun)
to find corresponding named arguments.donate_argnums (int | Sequence[int] | None) –
Specify which positional argument buffers are “donated” to the computation. It is safe to donate argument buffers if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation, JAX will raise an error if you try to. By default, no argument buffers are donated.
If neither donate_argnums nor donate_argnames is provided, no arguments are donated. If donate_argnums is not provided but donate_argnames is, or vice versa, JAX uses inspect.signature(fun) to find any positional arguments that correspond to donate_argnames (or vice versa). If both donate_argnums and donate_argnames are provided, inspect.signature is not used, and only actual parameters listed in either donate_argnums or donate_argnames will be donated.
For more details on buffer donation see the FAQ.
donate_argnames (str | Iterable[str] | None) – An optional string or collection of strings specifying which named arguments are donated to the computation. See the comment on
donate_argnums
for details. If not provided but donate_argnums is set, the default is based on calling inspect.signature(fun) to find corresponding named arguments.
keep_unused (bool) – If False (the default), arguments that JAX determines to be unused by fun may be dropped from resulting compiled XLA executables. Such arguments will not be transferred to the device nor provided to the underlying executable. If True, unused arguments will not be pruned.
device (Device | None) – This argument is deprecated. Please put your arguments on the device you want before passing them to jit. Optional, the Device the jitted function will run on. (Available devices can be retrieved via
jax.devices()
.) The default is inherited from XLA’s DeviceAssignment logic and is usually to use jax.devices()[0].
backend (str | None) – This argument is deprecated. Please put your arguments on the backend you want before passing them to jit. Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.
inline (bool)
abstracted_axes (Any | None)
- Returns:
A wrapped version of
fun
, set up for just-in-time compilation and automatically partitioned by the mesh available at each call site.
- Return type:
JitWrapped
For example, a convolution operator can be automatically partitioned over an arbitrary set of devices by a single
pjit()
application:
>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> from jax.sharding import Mesh, PartitionSpec
>>> from jax.experimental.pjit import pjit
>>>
>>> x = jnp.arange(8, dtype=jnp.float32)
>>> f = pjit(lambda x: jax.numpy.convolve(x, jnp.asarray([0.5, 1.0, 0.5]), 'same'),
...          in_shardings=None, out_shardings=PartitionSpec('devices'))
>>> with Mesh(np.array(jax.devices()), ('devices',)):
...   print(f(x))
[ 0.5 2. 4. 6. 8. 10. 12. 10. ]
jax.experimental.sparse
module#
The jax.experimental.sparse
module includes experimental support for sparse matrix
operations in JAX. It is under active development, and the API is subject to change. The
primary interfaces made available are the BCOO
sparse array type, and the
sparsify()
transform.
Batched-coordinate (BCOO) sparse matrices#
The main high-level sparse object currently available in JAX is the BCOO
,
or batched coordinate sparse array, which offers a compressed storage format compatible
with JAX transformations, in particular JIT (e.g. jax.jit()
), batching
(e.g. jax.vmap()
) and autodiff (e.g. jax.grad()
).
Here is an example of creating a sparse array from a dense array:
>>> from jax.experimental import sparse
>>> import jax.numpy as jnp
>>> import numpy as np
>>> M = jnp.array([[0., 1., 0., 2.],
... [3., 0., 0., 0.],
... [0., 0., 4., 0.]])
>>> M_sp = sparse.BCOO.fromdense(M)
>>> M_sp
BCOO(float32[3, 4], nse=4)
Convert back to a dense array with the todense()
method:
>>> M_sp.todense()
Array([[0., 1., 0., 2.],
[3., 0., 0., 0.],
[0., 0., 4., 0.]], dtype=float32)
The BCOO format is a somewhat modified version of the standard COO format, and its underlying sparse
representation can be seen in the data
and indices
attributes:
>>> M_sp.data # Explicitly stored data
Array([1., 2., 3., 4.], dtype=float32)
>>> M_sp.indices # Indices of the stored data
Array([[0, 1],
[0, 3],
[1, 0],
[2, 2]], dtype=int32)
BCOO objects have familiar array-like attributes, as well as sparse-specific attributes:
>>> M_sp.ndim
2
>>> M_sp.shape
(3, 4)
>>> M_sp.dtype
dtype('float32')
>>> M_sp.nse # "number of specified elements"
4
BCOO objects also implement a number of array-like methods, to allow you to use them directly within JAX programs. For example, here we compute the transposed matrix-vector product:
>>> y = jnp.array([3., 6., 5.])
>>> M_sp.T @ y
Array([18., 3., 20., 6.], dtype=float32)
>>> M.T @ y # Compare to dense version
Array([18., 3., 20., 6.], dtype=float32)
BCOO objects are designed to be compatible with JAX transforms, including jax.jit()
,
jax.vmap()
, jax.grad()
, and others. For example:
>>> from jax import grad, jit
>>> def f(y):
... return (M_sp.T @ y).sum()
...
>>> jit(grad(f))(y)
Array([3., 3., 4.], dtype=float32)
Note, however, that under normal circumstances jax.numpy
and jax.lax
functions
do not know how to handle sparse matrices, so attempting to compute things like
jnp.dot(M_sp.T, y)
will result in an error (however, see the next section).
Sparsify transform#
An overarching goal of the JAX sparse implementation is to provide a means to switch from
dense to sparse computation seamlessly, without having to modify the dense implementation.
This sparse experiment accomplishes this through the sparsify()
transform.
Consider this function, which computes a more complicated result from a matrix and a vector input:
>>> def f(M, v):
... return 2 * jnp.dot(jnp.log1p(M.T), v) + 1
...
>>> f(M, y)
Array([17.635532, 5.158883, 17.09438 , 7.591674], dtype=float32)
Were we to pass a sparse matrix to this directly, it would result in an error, because jnp
functions do not recognize sparse inputs. However, with sparsify()
, we get a version of
this function that does accept sparse matrices:
>>> f_sp = sparse.sparsify(f)
>>> f_sp(M_sp, y)
Array([17.635532, 5.158883, 17.09438 , 7.591674], dtype=float32)
Support for sparsify()
includes a large number of the most common primitives, including:
generalized (batched) matrix products & Einstein summations (dot_general_p)
zero-preserving elementwise binary operations (e.g. add_p, mul_p, etc.)
zero-preserving elementwise unary operations (e.g. abs_p, jax.lax.neg_p, etc.)
summation reductions (reduce_sum_p)
general indexing operations (slice_p, lax.dynamic_slice_p, lax.gather_p)
concatenation and stacking (concatenate_p)
transposition & reshaping (transpose_p, reshape_p, squeeze_p, broadcast_in_dim_p)
some higher-order functions (cond_p, while_p, scan_p)
some simple 1D convolutions (conv_general_dilated_p)
Nearly any jax.numpy
function that lowers to these supported primitives can be used
within a sparsify transform to operate on sparse arrays. This set of primitives is enough
to enable relatively sophisticated sparse workflows, as the next section will show.
Example: sparse logistic regression#
As an example of a more complicated sparse workflow, let’s consider a simple logistic regression implemented in JAX. Notice that the following implementation has no reference to sparsity:
>>> import functools
>>> from sklearn.datasets import make_classification
>>> from jax.scipy import optimize
>>> def sigmoid(x):
... return 0.5 * (jnp.tanh(x / 2) + 1)
...
>>> def y_model(params, X):
... return sigmoid(jnp.dot(X, params[1:]) + params[0])
...
>>> def loss(params, X, y):
... y_hat = y_model(params, X)
... return -jnp.mean(y * jnp.log(y_hat) + (1 - y) * jnp.log(1 - y_hat))
...
>>> def fit_logreg(X, y):
... params = jnp.zeros(X.shape[1] + 1)
... result = optimize.minimize(functools.partial(loss, X=X, y=y),
... x0=params, method='BFGS')
... return result.x
>>> X, y = make_classification(n_classes=2, random_state=1701)
>>> params_dense = fit_logreg(X, y)
>>> print(params_dense)
[-0.7298445 0.29893667 1.0248291 -0.44436368 0.8785025 -0.7724008
-0.62893456 0.2934014 0.82974285 0.16838408 -0.39774987 -0.5071844
0.2028872 0.5227761 -0.3739224 -0.7104083 2.4212713 0.6310087
-0.67060554 0.03139788 -0.05359547]
This returns the best-fit parameters of a dense logistic regression problem.
To fit the same model on sparse data, we can apply the sparsify()
transform:
>>> Xsp = sparse.BCOO.fromdense(X) # Sparse version of the input
>>> fit_logreg_sp = sparse.sparsify(fit_logreg) # Sparse-transformed fit function
>>> params_sparse = fit_logreg_sp(Xsp, y)
>>> print(params_sparse)
[-0.72971725 0.29878938 1.0246326 -0.44430563 0.8784217 -0.77225566
-0.6288222 0.29335397 0.8293481 0.16820715 -0.39764675 -0.5069753
0.202579 0.522672 -0.3740134 -0.7102678 2.4209507 0.6310593
-0.670236 0.03132951 -0.05356663]
Sparse API Reference#
|
Experimental sparsification transform. |
|
Sparse-aware version of |
|
Sparse-aware version of |
|
Create an empty sparse array. |
|
Create 2D sparse identity matrix. |
|
Convert input to a dense matrix. |
|
Generate a random BCOO matrix. |
|
Base class for high-level JAX sparse objects. |
BCOO Data Structure#
BCOO
is the Batched COO format, and is the main sparse data structure
implemented in jax.experimental.sparse
.
Its operations are compatible with JAX’s core transformations, including batching
(e.g. jax.vmap()
) and autodiff (e.g. jax.grad()
).
|
Experimental batched COO matrix implemented in JAX |
|
Expand the size and rank of a BCOO array by duplicating the data. |
|
Sparse implementation of |
|
A general contraction operation. |
|
A contraction operation with output computed at given sparse indices. |
|
Sparse implementation of {func}`jax.lax.dynamic_slice`. |
|
Extract values from a dense array according to the sparse array's indices. |
|
Create BCOO-format sparse matrix from a dense matrix. |
|
BCOO version of lax.gather. |
|
An element-wise multiplication between a sparse and a dense array. |
|
An element-wise multiplication of two sparse arrays. |
|
Update the storage layout (i.e. n_batch & n_dense) of a BCOO matrix. |
|
Sum array element over given axes. |
|
Sparse implementation of {func}`jax.lax.reshape`. |
|
Sparse implementation of {func}`jax.lax.slice`. |
|
Sort indices of a BCOO array. |
|
Sparse implementation of {func}`jax.lax.squeeze`. |
|
Sums duplicate indices within a BCOO array, returning an array with sorted indices. |
|
Convert batched sparse matrix to a dense matrix. |
|
Transpose a BCOO-format array. |
BCSR Data Structure#
BCSR
is the Batched Compressed Sparse Row format, and is under development.
Its operations are compatible with JAX’s core transformations, including batching
(e.g. jax.vmap()
) and autodiff (e.g. jax.grad()
).
|
Experimental batched CSR matrix implemented in JAX. |
|
A general contraction operation. |
|
Extract values from a dense matrix at given BCSR (indices, indptr). |
|
Create BCSR-format sparse matrix from a dense matrix. |
|
Convert batched sparse matrix to a dense matrix. |
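As with BCOO above, a BCSR matrix can be constructed from a dense array and converted back; a minimal sketch (the exact repr of the result may differ between versions):

>>> from jax.experimental import sparse
>>> import jax.numpy as jnp
>>> M = jnp.array([[0., 1., 0., 2.],
...                [3., 0., 0., 0.],
...                [0., 0., 4., 0.]])
>>> M_bcsr = sparse.BCSR.fromdense(M)
>>> M_bcsr.todense()
Array([[0., 1., 0., 2.],
       [3., 0., 0., 0.],
       [0., 0., 4., 0.]], dtype=float32)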
Other Sparse Data Structures#
Other sparse data structures include COO
, CSR
, and CSC
. These are
reference implementations of simple sparse structures with a few core operations implemented.
Their operations are generally compatible with autodiff transformations such as jax.grad()
,
but not with batching transforms like jax.vmap()
.
|
Experimental COO matrix implemented in JAX. |
|
Experimental CSC matrix implemented in JAX; API subject to change. |
|
Experimental CSR matrix implemented in JAX. |
|
Create a COO-format sparse matrix from a dense matrix. |
|
Product of COO sparse matrix and a dense matrix. |
|
Product of COO sparse matrix and a dense vector. |
|
Convert a COO-format sparse matrix to a dense matrix. |
|
Create a CSR-format sparse matrix from a dense matrix. |
|
Product of CSR sparse matrix and a dense matrix. |
|
Product of CSR sparse matrix and a dense vector. |
|
Convert a CSR-format sparse matrix to a dense matrix. |
jax.experimental.sparse.linalg
#
Sparse linear algebra routines.
|
A sparse direct solver using QR factorization. |
|
Compute the top-k standard eigenvalues using the LOBPCG routine. |
jax.experimental.jet
module#
Jet is an experimental module for higher-order automatic differentiation that does not rely on repeated first-order automatic differentiation.
How? Through the propagation of truncated Taylor polynomials.
Consider a function \(f = g \circ h\), some point \(x\)
and some offset \(v\).
First-order automatic differentiation (such as jax.jvp()
)
computes the pair \((f(x), \partial f(x)[v])\) from the pair
\((h(x), \partial h(x)[v])\).
jet()
implements the higher-order analogue:
Given the tuple
\[(h_0, h_1, \ldots, h_K) := (h(x), \partial h(x)[v], \partial^2 h(x)[v, v], \ldots, \partial^K h(x)[v, \ldots, v]),\]
which represents a \(K\)-th order Taylor approximation
of \(h\) at \(x\), jet()
returns a \(K\)-th order
Taylor approximation of \(f\) at \(x\),
\[(f_0, f_1, \ldots, f_K) := (f(x), \partial f(x)[v], \partial^2 f(x)[v, v], \ldots, \partial^K f(x)[v, \ldots, v]).\]
More specifically, jet()
computes the mapping
\[(h_0, (h_1, \ldots, h_K)) \mapsto (f_0, (f_1, \ldots, f_K)),\]
and can thus be used for high-order automatic differentiation of \(f\). Details are explained in these notes.
Note
Help improve jet()
by contributing
outstanding primitive rules.
API#
- jax.experimental.jet.jet(fun, primals, series)[source]#
Taylor-mode higher-order automatic differentiation.
- Parameters:
fun – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
primals – The primal values at which the Taylor approximation of
fun
should be evaluated. Should be either a tuple or a list of arguments, and its length should be equal to the number of positional parameters offun
.series – Higher order Taylor-series-coefficients. Together, primals and series make up a truncated Taylor polynomial. Should be either a tuple or a list of tuples or lists, and its length dictates the degree of the truncated Taylor polynomial.
- Returns:
A
(primals_out, series_out)
pair, whereprimals_out
isfun(*primals)
, and together,primals_out
andseries_out
are a truncated Taylor polynomial of \(f(h(\cdot))\). Theprimals_out
value has the same Python tree structure asprimals
, and theseries_out
value the same Python tree structure asseries
.
For example:
>>> import jax
>>> import jax.numpy as np
>>> from jax.experimental.jet import jet
Consider the function \(h(z) = z^3\), \(x = 0.5\), and the first few Taylor coefficients \(h_0=x^3\), \(h_1=3x^2\), and \(h_2=6x\). Let \(f(y) = \sin(y)\).
>>> h0, h1, h2 = 0.5**3., 3.*0.5**2., 6.*0.5
>>> f, df, ddf = np.sin, np.cos, lambda *args: -np.sin(*args)
jet()
returns the Taylor coefficients of \(f(h(z)) = \sin(z^3)\) according to Faà di Bruno’s formula:
>>> f0, (f1, f2) = jet(f, (h0,), ((h1, h2),))
>>> print(f0, f(h0))
0.12467473 0.12467473
>>> print(f1, df(h0) * h1)
0.7441479 0.74414825
>>> print(f2, ddf(h0) * h1 ** 2 + df(h0) * h2)
2.9064622 2.9064634
jax.experimental.custom_partitioning
module#
API#
- jax.experimental.custom_partitioning.custom_partitioning(fun, static_argnums=())[source]#
Inserts a CustomCallOp into the XLA graph with custom SPMD lowering rules.
@custom_partitioning
def f(*args):
  return ...

def propagate_user_sharding(mesh, user_shape):
  '''Update the sharding of the op from a user's shape.sharding.'''
  user_sharding = jax.tree.map(lambda x: x.sharding, user_shape)

def partition(mesh, arg_shapes, result_shape):
  def lower_fn(*args):
    ... builds computation on per-device shapes ...
  result_shardings = jax.tree.map(lambda x: x.sharding, result_shape)
  arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
  # result_sharding and arg_shardings may optionally be modified and the
  # partitioner will insert collectives to reshape.
  return mesh, lower_fn, result_sharding, arg_shardings

def infer_sharding_from_operands(mesh, arg_shapes, shape):
  '''Compute the result sharding from the sharding of the operands.'''
  arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)

f.def_partition(partition, propagate_user_sharding, infer_sharding_from_operands)
The args to
def_partition
are as follows:
propagate_user_sharding
: Callable which takes the sharding of a user (in the dag) and returns a suggestion for a new NamedSharding. The default implementation is just to return the suggested sharding.
partition
: Callable which takes the SPMD suggested partition shapes and partition specs and returns the mesh, a per-shard lowering function, and the final input and output sharding specs (the SPMD partitioner will repartition the inputs to match). The mesh is returned to allow configuring axis_names for collectives when no mesh is provided.
infer_sharding_from_operands
: Callable which computes an output NamedSharding from the NamedSharding chosen for each argument.
decode_shardings
: When set to True, convert input GSPMDShardings to NamedSharding
Positional arguments can be specified as static using static_argnums. JAX uses
inspect.signature(fun)
to resolve these positional arguments.
Example
As an example, assume we want to enhance the existing
jax.numpy.fft.fft
. This function computes the discrete Fourier transform of an N-dimensional input along the last dimension, and is batched along the first N-1 dimensions. By default, however, it will ignore the sharding of the input and gather the input on all devices. However, since jax.numpy.fft.fft
is batched along the first N-1 dimensions, this is unnecessary. We will create a new my_fft
op that, instead, does not alter the sharding along the first N-1 dimensions, and only gathers the input along the last dimension if needed.
import jax
from jax.sharding import NamedSharding
from jax.experimental.custom_partitioning import custom_partitioning
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P
from jax.sharding import Mesh
from jax.numpy.fft import fft
import regex as re
import numpy as np

# Pattern to detect all-gather or dynamic-slice in the generated HLO
_PATTERN = '(dynamic-slice|all-gather)'

# For an N-D input, keeps sharding along the first N-1 dimensions
# but replicate along the last dimension
def supported_sharding(sharding, shape):
  rank = len(shape.shape)
  max_shared_dims = min(len(sharding.spec), rank-1)
  names = tuple(sharding.spec[:max_shared_dims]) + tuple(None for _ in range(rank - max_shared_dims))
  return NamedSharding(sharding.mesh, P(*names))

def partition(mesh, arg_shapes, result_shape):
  result_shardings = jax.tree.map(lambda x: x.sharding, result_shape)
  arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
  return mesh, fft, supported_sharding(arg_shardings[0], arg_shapes[0]), (supported_sharding(arg_shardings[0], arg_shapes[0]),)

def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
  arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
  return supported_sharding(arg_shardings[0], arg_shapes[0])

@custom_partitioning
def my_fft(x):
  return fft(x)

my_fft.def_partition(
  infer_sharding_from_operands=infer_sharding_from_operands,
  partition=partition)
Now create a 2D array sharded along the first axis, pass it through
my_fft
and notice how it is still sharded as expected, and identical to the output of fft
. However, inspecting the HLO (using lower(x).compile().runtime_executable().hlo_modules()
) reveals that my_fft
does not create any all-gather or dynamic-slice, while fft
does.
with Mesh(np.array(jax.devices()), ('x',)):
  x = np.asarray(np.random.randn(32*1024, 1024), dtype=np.complex64)
  y = pjit(lambda x: x, in_shardings=None, out_shardings=P('x'))(x)
  pjit_my_fft = pjit(my_fft, in_shardings=P('x'), out_shardings=P('x'))
  pjit_fft = pjit(fft, in_shardings=P('x'), out_shardings=P('x'))
  print(pjit_my_fft(y))
  print(pjit_fft(y))
  # dynamic-slice or all-gather are not present in the HLO for my_fft, because x is a 2D array
  assert(re.search(_PATTERN, pjit_my_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is None)
  # dynamic-slice or all-gather are present in the HLO for fft
  assert(re.search(_PATTERN, pjit_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is not None)
# my_fft
[[-38.840824 +0.j -40.649452 +11.845365j ... -1.6937828 +0.8402481j 15.999859 -4.0156755j]]

# jax.numpy.fft.fft
[[-38.840824 +0.j -40.649452 +11.845365j ... -1.6937828 +0.8402481j 15.999859 -4.0156755j]]
Because of the logic in
supported_sharding
, my_fft
also works on 1-dimensional arrays. However, in this case, the HLO of my_fft
does show a dynamic-slice, since the last dimension is the dimension along which FFTs are calculated and needs to be replicated on all devices before the computation can be done.
with Mesh(np.array(jax.devices()), ('x',)):
  x = np.asarray(np.random.randn(32*1024*1024), dtype=np.complex64)
  y = pjit(lambda x: x, in_shardings=None, out_shardings=P('x'))(x)
  pjit_my_fft = pjit(my_fft, in_shardings=P('x'), out_shardings=P('x'))
  pjit_fft = pjit(fft, in_shardings=P('x'), out_shardings=P('x'))
  print(pjit_my_fft(y))
  print(pjit_fft(y))
  # dynamic-slice or all-gather are present in the HLO for my_fft, because x is a 1D array
  assert(re.search(_PATTERN, pjit_my_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is not None)
  # dynamic-slice or all-gather are present in the HLO for fft
  assert(re.search(_PATTERN, pjit_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is not None)
# my_fft
[ 7.217285 +0.j -3012.4937 +4287.635j -405.83594 +3042.984j ... 1422.4502 +7271.4297j -405.84033 -3042.983j -3012.4963 -4287.6343j]

# jax.numpy.fft.fft
[ 7.217285 +0.j -3012.4937 +4287.635j -405.83594 +3042.984j ... 1422.4502 +7271.4297j -405.84033 -3042.983j -3012.4963 -4287.6343j]
jax.experimental.multihost_utils
module#
Utilities for synchronizing and communication across multiple hosts.
Multihost Utils API Reference#
|
Broadcast data from a source host (host 0 by default) to all other hosts. |
|
Creates a barrier across all hosts/devices. |
|
Gather data from across processes. |
|
Verifies that all the hosts have the same tree of values. |
Converts a host local value to a globally sharded jax.Array. |
|
Converts a global jax.Array to a host local jax.Array. |
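For illustration, a rough sketch of how these utilities are typically used in a multi-process job; on a single process the calls are effectively no-ops, and exact shapes depend on the number of participating processes:

import jax
import numpy as np
from jax.experimental import multihost_utils

# Each process contributes its own local block of data...
local_data = np.arange(4) + 10 * jax.process_index()

# ...and receives the values from every process, stacked along a new leading axis.
gathered = multihost_utils.process_allgather(local_data)

# Broadcast a value chosen on process 0 to all other processes.
shared_seed = multihost_utils.broadcast_one_to_all(np.int32(42))

# Barrier: block until every process reaches this point.
multihost_utils.sync_global_devices("setup_done")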
jax.experimental.compilation_cache
module#
JAX disk compilation cache.
API#
- jax.experimental.compilation_cache.compilation_cache.is_initialized()[source]#
Deprecated.
Return whether the cache is enabled. Initialization can be deferred, so initialized status is not checked. The name is retained for backwards compatibility.
- Return type:
- jax.experimental.compilation_cache.compilation_cache.initialize_cache(path)[source]#
This API is deprecated; use set_cache_dir instead.
Set the path. To take effect, should be called prior to any calls to get_executable_and_time() and put_executable_and_time().
- Return type:
None
- jax.experimental.compilation_cache.compilation_cache.set_cache_dir(path)[source]#
Sets the persistent compilation cache directory.
After calling this, jit-compiled functions are saved to path, so they do not need to be recompiled if the process is restarted or otherwise run again. This also tells JAX where to look for compiled functions before compiling.
- Return type:
None
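A minimal sketch of enabling the cache (the directory shown is only an example path):

import jax
import jax.numpy as jnp
from jax.experimental.compilation_cache import compilation_cache as cc

cc.set_cache_dir("/tmp/jax_cache")   # compiled executables are written here

jax.jit(jnp.sin)(1.0)                # compiled once; later runs can reuse the cached executable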
jax.experimental.key_reuse
module#
Experimental Key Reuse Checking#
This module contains experimental functionality for detecting reuse of random keys within JAX programs. It is under active development and the APIs here are likely to change. The usage below requires JAX version 0.4.26 or newer.
Key reuse checking can be enabled using the jax_debug_key_reuse
configuration.
This can be set globally using:
>>> jax.config.update('jax_debug_key_reuse', True)
Or it can be enabled locally with the jax.debug_key_reuse()
context manager.
When enabled, using the same key twice will result in a KeyReuseError
:
>>> import jax
>>> with jax.debug_key_reuse(True):
... key = jax.random.key(0)
... val1 = jax.random.normal(key)
... val2 = jax.random.normal(key)
Traceback (most recent call last):
...
KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0
The key reuse checker is currently experimental, but in the future we will likely enable it by default.
jax.experimental.mesh_utils
module#
Utils for building a device mesh.
API#
|
Creates a performant device mesh for jax.sharding.Mesh. |
|
Creates a device mesh for hybrid (e.g., ICI and DCN) parallelism. |
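For example, a 2D device mesh might be built as in the sketch below; it assumes eight devices are available, and the axis names are arbitrary:

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

devices = mesh_utils.create_device_mesh((2, 4))      # ndarray of devices with shape (2, 4)
mesh = Mesh(devices, axis_names=('data', 'model'))   # usable as a mesh context manager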
Experimental APIs#
|
Experimental context manager to temporarily enable X64 mode. |
Experimental context manager to temporarily disable X64 mode. |
|
|
Functionalize check calls in fun, and optionally add run-time error checks. |
|
Check a predicate, add an error with msg if predicate is False. |
Raise an Exception if |
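A brief sketch of two of these experimental APIs: the enable_x64 context manager and checkify-based runtime checks (behavior may change, since these APIs are experimental):

import jax.numpy as jnp
from jax.experimental import enable_x64, checkify

with enable_x64():
    print(jnp.arange(3.0).dtype)        # 64-bit default dtypes inside the context

def safe_log(x):
    checkify.check(x > 0, "x must be positive")   # functionalized runtime check
    return jnp.log(x)

err, out = checkify.checkify(safe_log)(2.0)
err.throw()                              # no-op here; raises if the check failed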
jax.lib
module#
The jax.lib package is a set of internal tools and types for bridging between JAX’s Python frontend and its XLA backend.
jax.lib.xla_bridge#
Returns the platform name of the default XLA backend. |
|
|
|
|
Returns the compile options to use, as derived from flag values. |
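For instance (a sketch only; jax.lib is an internal bridge and its contents may change between releases):

from jax.lib import xla_bridge

print(xla_bridge.default_backend())        # e.g. 'cpu', 'gpu', or 'tpu'
print(xla_bridge.get_backend().platform)   # platform of the default backend object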
jax.lib.xla_client#
Configuration#
Context manager for jax_check_tracer_leaks config option. |
|
Context manager for jax_check_tracer_leaks config option. |
|
Context manager for jax_debug_nans config option. |
|
Context manager for jax_debug_infs config option. |
|
Context manager for jax_default_device config option. |
|
Context manager for jax_default_matmul_precision config option. |
|
Context manager for jax_default_prng_impl config option. |
|
Context manager for jax_enable_checks config option. |
|
Context manager for jax_enable_custom_prng config option (transient). |
|
Context manager for jax_enable_custom_vjp_by_custom_transpose config option (transient). |
|
Context manager for jax_log_compiles config option. |
|
Context manager for jax_numpy_rank_promotion config option. |
|
|
A contextmanager to control the transfer guard level for all transfers. |
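As a quick sketch of how these context managers scope a configuration change to a block of code:

import jax
import jax.numpy as jnp

with jax.debug_nans(True):
    jnp.log(2.0)            # fine; producing a NaN inside this block would raise an error

with jax.default_matmul_precision("float32"):
    jnp.ones((4, 4)) @ jnp.ones((4, 4))    # matmuls in this block use float32 precision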
Just-in-time compilation (jit
)#
|
Sets up |
|
Context manager that disables |
Context manager to ensure evaluation at trace/compile time (or error). |
|
|
Creates a function that produces its XLA computation given example args. |
|
Creates a function that produces its jaxpr given example args. |
|
Compute the shape/dtype of |
|
A container for the shape, dtype, and other static attributes of an array. |
|
Transfers |
|
Transfer array(s) to each specified device and form Array(s). |
|
Transfer array shards to specified devices and form Array(s). |
|
Transfer |
Returns the platform name of the default XLA backend. |
|
|
Adds a user specified name to a function when staging out JAX computations. |
|
A context manager that adds a user specified name to the JAX name stack. |
Tries to call a |
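A minimal sketch combining a few of the utilities above:

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sin(x) ** 2

x = jnp.arange(4.0)
f(x).block_until_ready()        # wait for the asynchronous computation to finish

print(jax.make_jaxpr(f)(x))     # the jaxpr staged out for these example arguments
print(jax.eval_shape(f, x))     # shape/dtype of the result without running it

with jax.disable_jit():
    f(x)                        # runs op-by-op, which is convenient for debugging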
Automatic differentiation#
|
Creates a function that evaluates the gradient of |
|
Create a function that evaluates both |
|
Jacobian of |
|
Jacobian of |
|
Hessian of |
|
Computes a (forward-mode) Jacobian-vector product of |
Produces a linear approximation to |
|
|
Transpose a function that is promised to be linear. |
|
Compute a (reverse-mode) vector-Jacobian product of |
|
Set up a JAX-transformable function for a custom JVP rule definition. |
|
Set up a JAX-transformable function for a custom VJP rule definition. |
|
Convenience function for defining custom VJP rules (aka custom gradients). |
|
Closure conversion utility, for use with higher-order custom derivatives. |
|
Make |
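For example, a few of these transformations applied to a simple scalar-valued function (a sketch):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.arange(3.0)

g = jax.grad(f)(x)                                   # reverse-mode gradient
val, g2 = jax.value_and_grad(f)(x)                   # value and gradient together
y, tangent = jax.jvp(f, (x,), (jnp.ones_like(x),))   # forward-mode JVP
y, f_vjp = jax.vjp(f, x)
(cotangent,) = f_vjp(1.0)                            # reverse-mode VJP
H = jax.hessian(f)(x)                                # dense Hessian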
jax.Array (jax.Array
)#
|
Array base class for JAX |
|
Returns a |
|
Returns a |
Vectorization (vmap
)#
|
Vectorizing map. |
|
Define a vectorized function with broadcasting. |
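For example (a small sketch):

import jax
import jax.numpy as jnp

def dot(a, b):
    return jnp.dot(a, b)              # written for single vectors

batched_dot = jax.vmap(dot)           # maps over the leading axis of both arguments
a = jnp.ones((8, 3))
b = jnp.ones((8, 3))
print(batched_dot(a, b).shape)        # (8,)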
Parallelization (pmap
)#
|
Parallel map with support for collective operations. |
|
Returns a list of all devices for a given backend. |
|
Like |
|
Returns the integer process index of this process. |
|
Returns the total number of devices. |
|
Returns the number of devices addressable by this process. |
|
Returns the number of JAX processes associated with the backend. |
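A sketch of pmap with a collective; the mapped axis size must equal jax.local_device_count(), so on a single-device machine this maps over just one element:

import jax
import jax.numpy as jnp

n = jax.local_device_count()
x = jnp.arange(n * 4.0).reshape(n, 4)

# One program instance per local device; psum sums the per-device values.
f = jax.pmap(lambda v: jax.lax.psum(v.sum(), axis_name='i'), axis_name='i')
print(f(x))   # every device ends up holding the same global sum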
Callbacks#
|
Calls a pure Python callback. |
|
Calls an impure Python callback. |
|
Calls a stageable Python callback. |
|
Prints values and works in staged out JAX functions. |
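For instance, jax.debug.print and jax.pure_callback can both be used from inside jit-compiled code (a sketch; the callback's result shape/dtype must be declared up front):

import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def f(x):
    jax.debug.print("x = {}", x)      # printed at runtime, even under jit
    # Call back into host-side NumPy; the ShapeDtypeStruct declares the result.
    return jax.pure_callback(np.sin, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

f(jnp.arange(3.0))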
Miscellaneous#
A descriptor of an available device. |
|
|
Returns a string containing local environment & JAX installation information. |
|
Return all live arrays in the backend for platform. |
Clear all compilation and staging caches. |
Change log#
Best viewed here.
jax 0.4.28#
Deprecations & removals
The
kind
argument to jax.numpy.sort()
and jax.numpy.argsort()
is now removed. Use stable=True
or stable=False
instead.
jaxlib 0.4.28#
jax 0.4.27 (May 7, 2024)#
New Functionality
Added
jax.numpy.unstack()
andjax.numpy.cumulative_sum()
, following their addition in the array API 2023 standard, soon to be adopted by NumPy.Added a new config option
jax_cpu_collectives_implementation
to select the implementation of cross-process collective operations used by the CPU backend. Choices available are'none'
(default),'gloo'
and'mpi'
(requires jaxlib 0.4.26). If set to'none'
, cross-process collective operations are disabled.
Changes
jax.pure_callback()
,jax.experimental.io_callback()
andjax.debug.callback()
now usejax.Array
instead ofnp.ndarray
. You can recover the old behavior by transforming the arguments viajax.tree.map(np.asarray, args)
before passing them to the callback.complex_arr.astype(bool)
now follows the same semantics as NumPy, returning False wherecomplex_arr
is equal to0 + 0j
, and True otherwise.core.Token
now is a non-trivial class which wraps ajax.Array
. It could be created and threaded in and out of computations to build up dependency. The singleton objectcore.token
has been removed, users now should create and use freshcore.Token
objects instead.On GPU, the Threefry PRNG implementation no longer lowers to a kernel call by default. This choice can improve runtime memory usage at a compile-time cost. Prior behavior, which produces a kernel call, can be recovered with
jax.config.update('jax_threefry_gpu_kernel_lowering', True)
. If the new default causes issues, please file a bug. Otherwise, we intend to remove this flag in a future release.
Deprecations & Removals
Pallas now exclusively uses XLA for compiling kernels on GPU. The old lowering pass via Triton Python APIs has been removed and the
JAX_TRITON_COMPILE_VIA_XLA
environment variable no longer has any effect.jax.numpy.clip()
has a new argument signature:a
,a_min
, anda_max
are deprecated in favor ofx
(positional only),min
, andmax
(#20550).The
device()
method of JAX arrays has been removed, after being deprecated since JAX v0.4.21. Usearr.devices()
instead.The
initial
argument tojax.nn.softmax()
andjax.nn.log_softmax()
is deprecated; empty inputs to softmax are now supported without setting this.In
jax.jit()
, passing invalidstatic_argnums
orstatic_argnames
now leads to an error rather than a warning.The minimum jaxlib version is now 0.4.23.
The
jax.numpy.hypot()
function now issues a deprecation warning when passing complex-valued inputs to it. This will raise an error when the deprecation is completed.Scalar arguments to
jax.numpy.nonzero()
,jax.numpy.where()
, and related functions now raise an error, following a similar change in NumPy.The config option
jax_cpu_enable_gloo_collectives
is deprecated. Usejax.config.update('jax_cpu_collectives_implementation', 'gloo')
instead.The
jax.Array.device_buffer
andjax.Array.device_buffers
methods have been removed after being deprecated in JAX v0.4.22. Instead usejax.Array.addressable_shards
andjax.Array.addressable_data()
.The
condition
,x
, andy
parameters ofjax.numpy.where
are now positional-only, following deprecation of the keywords in JAX v0.4.21.Non-array arguments to functions in
jax.lax.linalg
now must be specified by keyword. Previously, this raised a DeprecationWarning.Array-like arguments are now required in several :func:
jax.numpy
APIs, includingapply_along_axis()
,apply_over_axes()
,inner()
,outer()
,cross()
,kron()
, andlexsort()
.
Bug fixes
jax.numpy.astype()
will now always return a copy when copy=True
. Previously, no copy would be made when the output array would have the same dtype as the input array. This may result in some increased memory usage. The default value is set to copy=False
to preserve backwards compatibility.
jaxlib 0.4.27 (May 7, 2024)#
jax 0.4.26 (April 3, 2024)#
New Functionality
Added
jax.numpy.trapezoid()
, following the addition of this function in NumPy 2.0.
Changes
Complex-valued
jax.numpy.geomspace()
now chooses the logarithmic spiral branch consistent with that of NumPy 2.0.The behavior of
lax.rng_bit_generator
, and in turn the'rbg'
and'unsafe_rbg'
PRNG implementations, underjax.vmap
has changed so that mapping over keys results in random generation only from the first key in the batch.Docs now use
jax.random.key
for construction of PRNG key arrays rather thanjax.random.PRNGKey
.
Deprecations & Removals
jax.tree_map()
is deprecated; usejax.tree.map
instead, or for backward compatibility with older JAX versions, usejax.tree_util.tree_map()
.jax.clear_backends()
is deprecated as it does not necessarily do what its name suggests and can lead to unexpected consequences, e.g., it will not destroy existing backends and release corresponding owned resources. Usejax.clear_caches()
if you only want to clean up compilation caches. For backward compatibility or you really need to switch/reinitialize the default backend, usejax.extend.backend.clear_backends()
.The
jax.experimental.maps
module andjax.experimental.maps.xmap
are deprecated. Usejax.experimental.shard_map
orjax.vmap
with thespmd_axis_name
argument for expressing SPMD device-parallel computations.The
jax.experimental.host_callback
module is deprecated. Use instead the new JAX external callbacks. AddedJAX_HOST_CALLBACK_LEGACY
flag to assist in the transition to the new callbacks. See #20385 for a discussion.Passing arguments to
jax.numpy.array_equal()
andjax.numpy.array_equiv()
that cannot be converted to a JAX array now results in an exception.The deprecated flag
jax_parallel_functions_output_gda
has been removed. This flag was long deprecated and did nothing; its use was a no-op.The previously-deprecated imports
jax.interpreters.ad.config
andjax.interpreters.ad.source_info_util
have now been removed. Usejax.config
andjax.extend.source_info_util
instead.JAX export does not support older serialization versions anymore. Version 9 has been supported since October 27th, 2023 and has become the default since February 1, 2024. See a description of the versions. This change could break clients that set a specific JAX serialization version lower than 9.
jaxlib 0.4.26 (April 3, 2024)#
Changes
JAX now supports CUDA 12.1 or newer only. Support for CUDA 11.8 has been dropped.
JAX now supports NumPy 2.0.
jax 0.4.25 (Feb 26, 2024)#
New Features
Added CUDA Array Interface import support (requires jaxlib 0.4.24).
JAX arrays now support NumPy-style scalar boolean indexing, e.g.
x[True]
orx[False]
.Added
jax.tree
module, with a more convenient interface for referencing functions injax.tree_util
.jax.tree.transpose()
(i.e.jax.tree_util.tree_transpose()
) now acceptsinner_treedef=None
, in which case the inner treedef will be automatically inferred.
Changes
Pallas now uses XLA instead of the Triton Python APIs to compile Triton kernels. You can revert to the old behavior by setting the
JAX_TRITON_COMPILE_VIA_XLA
environment variable to"0"
.Several deprecated APIs in
jax.interpreters.xla
that were removed in v0.4.24 have been re-added in v0.4.25, includingbackend_specific_translations
,translations
,register_translation
,xla_destructure
,TranslationRule
,TranslationContext
, andXLAOp
. These are still considered deprecated, and will be removed again in the future when better replacements are available. Refer to #19816 for discussion.
Deprecations & Removals
jax.numpy.linalg.solve()
now shows a deprecation warning for batched 1D solves withb.ndim > 1
. In the future these will be treated as batched 2D solves.Conversion of a non-scalar array to a Python scalar now raises an error, regardless of the size of the array. Previously a deprecation warning was raised in the case of non-scalar arrays of size 1. This follows a similar deprecation in NumPy.
The previously deprecated configuration APIs have been removed following a standard 3 months deprecation cycle (see API compatibility). These include
the
jax.config.config
object andthe
define_*_state
andDEFINE_*
methods ofjax.config
.
Importing the
jax.config
submodule viaimport jax.config
is deprecated. To configure JAX useimport jax
and then reference the config object viajax.config
.The minimum jaxlib version is now 0.4.20.
jaxlib 0.4.25 (Feb 26, 2024)#
jax 0.4.24 (Feb 6, 2024)#
Changes
JAX lowering to StableHLO does not depend on physical devices anymore. If your primitive wraps custom_partitioning or JAX callbacks in the lowering rule, i.e. the function passed to the rule parameter of mlir.register_lowering, then add your primitive to the jax._src.dispatch.prim_requires_devices_during_lowering set. This is needed because custom_partitioning and JAX callbacks need physical devices to create Shardings during lowering. This is a temporary state until we can create Shardings without physical devices.
andjax.numpy.sort()
now support thestable
anddescending
arguments.Several changes to the handling of shape polymorphism (used in
jax.experimental.jax2tf
andjax.experimental.export
):cleaner pretty-printing of symbolic expressions (#19227)
added the ability to specify symbolic constraints on the dimension variables. This makes shape polymorphism more expressive, and gives a way to workaround limitations in the reasoning about inequalities. See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
with the addition of symbolic constraints (#19235) we now consider dimension variables from different scopes to be different, even if they have the same name. Symbolic expressions from different scopes cannot interact, e.g., in arithmetic operations. Scopes are introduced by
jax.experimental.jax2tf.convert()
,jax.experimental.export.symbolic_shape()
,jax.experimental.export.symbolic_args_specs()
. The scope of a symbolic expressione
can be read withe.scope
and passed into the above functions to direct them to construct symbolic expressions in a given scope. See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.simplified and faster equality comparisons, where we consider two symbolic dimensions to be equal if the normalized form of their difference reduces to 0 (#19231; note that this may result in user-visible behavior changes)
improved the error messages for inconclusive inequality comparisons (#19235).
the
core.non_negative_dim
API (introduced recently) was deprecated andcore.max_dim
andcore.min_dim
were introduced (#18953) to expressmax
andmin
for symbolic dimensions. You can usecore.max_dim(d, 0)
instead ofcore.non_negative_dim(d)
.the
shape_poly.is_poly_dim
is deprecated in favor ofexport.is_symbolic_dim
(#19282).the
export.args_specs
is deprecated in favor ofexport.symbolic_args_specs ({jax-issue}
#19283`).the
shape_poly.PolyShape
andjax2tf.PolyShape
are deprecated, use strings for polymorphic shapes specifications (#19284).JAX default native serialization version is now 9. This is relevant for
jax.experimental.jax2tf
andjax.experimental.export
. See description of version numbers.
Refactored the API for
jax.experimental.export
. Instead offrom jax.experimental.export import export
you should use nowfrom jax.experimental import export
. The old way of importing will continue to work for a deprecation period of 3 months.Added
jax.scipy.stats.sem()
.jax.numpy.unique()
withreturn_inverse = True
returns inverse indices reshaped to the dimension of the input, following a similar change tonumpy.unique()
in NumPy 2.0.jax.numpy.sign()
now returnsx / abs(x)
for nonzero complex inputs. This is consistent with the behavior ofnumpy.sign()
in NumPy version 2.0.jax.scipy.special.logsumexp()
withreturn_sign=True
now uses the NumPy 2.0 convention for the complex sign,x / abs(x)
. This is consistent with the behavior ofscipy.special.logsumexp()
in SciPy v1.13.JAX now supports the bool DLPack type for both import and export. Previously bool values could not be imported and were exported as integers.
Deprecations & Removals
A number of previously deprecated functions have been removed, following a standard 3+ month deprecation cycle (see API compatibility). This includes:
From
jax.core
:TracerArrayConversionError
,TracerIntegerConversionError
,UnexpectedTracerError
,as_hashable_function
,collections
,dtypes
,lu
,map
,namedtuple
,partial
,pp
,ref
,safe_zip
,safe_map
,source_info_util
,total_ordering
,traceback_util
,tuple_delete
,tuple_insert
, andzip
.From
jax.lax
:dtypes
,itertools
,naryop
,naryop_dtype_rule
,standard_abstract_eval
,standard_naryop
,standard_primitive
,standard_unop
,unop
, andunop_dtype_rule
.The
jax.linear_util
submodule and all its contents.The
jax.prng
submodule and all its contents.From
jax.random
:PRNGKeyArray
,KeyArray
,default_prng_impl
,threefry_2x32
,threefry2x32_key
,threefry2x32_p
,rbg_key
, andunsafe_rbg_key
.From
jax.tree_util
:register_keypaths
,AttributeKeyPathEntry
, andGetItemKeyPathEntry
.from
jax.interpreters.xla
:backend_specific_translations
,translations
,register_translation
,xla_destructure
,TranslationRule
,TranslationContext
,axis_groups
,ShapedArray
,ConcreteArray
,AxisEnv
,backend_compile
, andXLAOp
.from
jax.numpy
:NINF
,NZERO
,PZERO
,row_stack
,issubsctype
,trapz
, andin1d
.from
jax.scipy.linalg
:tril
andtriu
.
The previously-deprecated method
PRNGKeyArray.unsafe_raw_array
has been removed. Usejax.random.key_data()
instead.bool(empty_array)
now raises an error rather than returningFalse
. This previously raised a deprecation warning, and follows a similar change in NumPy.Support for the mhlo MLIR dialect has been deprecated. JAX no longer uses the mhlo dialect, in favor of stablehlo. APIs that refer to “mhlo” will be removed in the future. Use the “stablehlo” dialect instead.
jax.random
: passing batched keys directly to random number generation functions, such asbits()
,gamma()
, and others, is deprecated and will emit aFutureWarning
. Usejax.vmap
for explicit batching.jax.lax.tie_in()
is deprecated: it has been a no-op since JAX v0.2.0.
jaxlib 0.4.24 (Feb 6, 2024)#
Changes
JAX now supports CUDA 12.3 and CUDA 11.8. Support for CUDA 12.2 has been dropped.
cost_analysis
now works with cross-compiledCompiled
objects (i.e. when using.lower().compile()
with a topology object, e.g., to compile for Cloud TPU from a non-TPU computer).Added CUDA Array Interface import support (requires jax 0.4.25).
jax 0.4.23 (Dec 13, 2023)#
jaxlib 0.4.23 (Dec 13, 2023)#
Fixed a bug that caused verbose logging from the GPU compiler during compilation.
jax 0.4.22 (Dec 13, 2023)#
Deprecations
The
device_buffer
anddevice_buffers
properties of JAX arrays are deprecated. Explicit buffers have been replaced by the more flexible array sharding interface, but the previous outputs can be recovered this way:arr.device_buffer
becomesarr.addressable_data(0)
arr.device_buffers
becomes[x.data for x in arr.addressable_shards]
jaxlib 0.4.22 (Dec 13, 2023)#
jax 0.4.21 (Dec 4 2023)#
New Features
Added
jax.nn.squareplus
.
Changes
The minimum jaxlib version is now 0.4.19.
Released wheels are built now with clang instead of gcc.
Enforce that the device backend has not been initialized prior to calling
jax.distributed.initialize()
.Automate arguments to
jax.distributed.initialize()
in cloud TPU environments.
Deprecations
The previously-deprecated
sym_pos
argument has been removed fromjax.scipy.linalg.solve()
. Useassume_a='pos'
instead.Passing
None
tojax.array()
orjax.asarray()
, either directly or within a list or tuple, is deprecated and now raises aFutureWarning
. It currently is converted to NaN, and in the future will raise aTypeError
.Passing the
condition
,x
, andy
parameters tojax.numpy.where
by keyword arguments has been deprecated, to matchnumpy.where
.Passing arguments to
jax.numpy.array_equal()
andjax.numpy.array_equiv()
that cannot be converted to a JAX array is deprecated and now raises a DeprecationWarning
. Currently the functions return False, in the future this will raise an exception.The
device()
method of JAX arrays is deprecated. Depending on the context, it may be replaced with one of the following:jax.Array.devices()
returns the set of all devices used by the array.jax.Array.sharding
gives the sharding configuration used by the array.
jaxlib 0.4.21 (Dec 4 2023)#
Changes
In preparation for adding distributed CPU support, JAX now treats CPU devices identically to GPU and TPU devices, that is:
jax.devices()
includes all devices present in a distributed job, even those not local to the current process.jax.local_devices()
still only includes devices local to the current process, so if the change tojax.devices()
breaks you, you most likely want to usejax.local_devices()
instead.CPU devices now receive a globally unique ID number within a distributed job; previously CPU devices would receive a process-local ID number.
The
process_index
of each CPU device will now match any GPU or TPU devices within the same process; previously theprocess_index
of a CPU device was always 0.
On NVIDIA GPU, JAX now prefers a Jacobi SVD solver for matrices up to 1024x1024. The Jacobi solver appears faster than the non-Jacobi version.
Bug fixes
Fixed error/hang when an array with non-finite values is passed to a non-symmetric eigendecomposition (#18226). Arrays with non-finite values now produce arrays full of NaNs as outputs.
jax 0.4.20 (Nov 2, 2023)#
jaxlib 0.4.20 (Nov 2, 2023)#
Bug fixes
Fixed some type confusion between E4M3 and E5M2 float8 types.
jax 0.4.19 (Oct 19, 2023)#
New Features
Added
jax.typing.DTypeLike
, which can be used to annotate objects that are convertible to JAX dtypes.Added
jax.numpy.fill_diagonal
.
Changes
JAX now requires SciPy 1.9 or newer.
Bug fixes
Only process 0 in a multicontroller distributed JAX program will write persistent compilation cache entries. This fixes write contention if the cache is placed on a network file system such as GCS.
The version check for cusolver and cufft no longer considers the patch versions when determining if the installed version of these libraries is at least as new as the versions against which JAX was built.
jaxlib 0.4.19 (Oct 19, 2023)#
Changes
jaxlib will now always prefer pip-installed NVIDIA CUDA libraries (nvidia-… packages) over any other CUDA installation if they are installed, including installations named in
LD_LIBRARY_PATH
. If this causes problems and the intent is to use a system-installed CUDA, the fix is to remove the pip installed CUDA library packages.
jax 0.4.18 (Oct 6, 2023)#
jaxlib 0.4.18 (Oct 6, 2023)#
Changes
CUDA jaxlibs now depend on the user to install a compatible NCCL version. If using the recommended
cuda12_pip
installation, NCCL should be installed automatically. Currently, NCCL 2.16 or newer is required.We now provide Linux aarch64 wheels, both with and without NVIDIA GPU support.
jax.Array.item()
now supports optional index arguments.
Deprecations
A number of internal utilities and inadvertent exports in
jax.lax
have been deprecated, and will be removed in a future release.jax.lax.dtypes
: usejax.dtypes
instead.jax.lax.itertools
: useitertools
instead.naryop
,naryop_dtype_rule
,standard_abstract_eval
,standard_naryop
,standard_primitive
,standard_unop
,unop
, andunop_dtype_rule
are internal utilities, now deprecated without replacement.
Bug fixes
Fixed Cloud TPU regression where compilation would OOM due to smem.
jax 0.4.17 (Oct 3, 2023)#
New features
Added new
jax.numpy.bitwise_count()
function, matching the API of the similar function recently added to NumPy.
Deprecations
Removed the deprecated module
jax.abstract_arrays
and all its contents.Named key constructors in
jax.random
are deprecated. Pass theimpl
argument tojax.random.PRNGKey()
orjax.random.key()
instead:random.threefry2x32_key(seed)
becomesrandom.PRNGKey(seed, impl='threefry2x32')
random.rbg_key(seed)
becomesrandom.PRNGKey(seed, impl='rbg')
random.unsafe_rbg_key(seed)
becomesrandom.PRNGKey(seed, impl='unsafe_rbg')
Changes:
CUDA: JAX now verifies that the CUDA libraries it finds are at least as new as the CUDA libraries that JAX was built against. If older libraries are found, JAX raises an exception since that is preferable to mysterious failures and crashes.
Removed the “No GPU/TPU” found warning. Instead warn if, on Linux, an NVIDIA GPU or a Google TPU are found but not used and
--jax_platforms
was not specified.jax.scipy.stats.mode()
now returns a 0 count if the mode is taken across a size-0 axis, matching the behavior ofscipy.stats.mode
in SciPy 1.11.Most
jax.numpy
functions and attributes now have fully-defined type stubs. Previously many of these were treated asAny
by static type checkers likemypy
andpytype
.
jaxlib 0.4.17 (Oct 3, 2023)#
Changes:
Python 3.12 wheels were added in this release.
The CUDA 12 wheels now require CUDA 12.2 or newer and cuDNN 8.9.4 or newer.
Bug fixes:
Fixed log spam from ABSL when the JAX CPU backend was initialized.
jax 0.4.16 (Sept 18, 2023)#
Changes
Added
jax.numpy.ufunc
, as well asjax.numpy.frompyfunc()
, which can convert any scalar-valued function into anumpy.ufunc()
-like object, with methods such asouter()
,reduce()
,accumulate()
,at()
, andreduceat()
(#17054).When not running under IPython: when an exception is raised, JAX now filters out the entirety of its internal frames from tracebacks. (Without the “unfiltered stack trace” that previously appeared.) This should produce much friendlier-looking tracebacks. See here for an example. This behavior can be changed by setting
JAX_TRACEBACK_FILTERING=remove_frames
(for two separate unfiltered/filtered tracebacks, which was the old behavior) orJAX_TRACEBACK_FILTERING=off
(for one unfiltered traceback).jax2tf default serialization version is now 7, which introduces new shape safety assertions.
Devices passed to
jax.sharding.Mesh
should be hashable. This specifically applies to mock devices or user created devices.jax.devices()
are already hashable.
Breaking changes:
jax2tf now uses native serialization by default. See the jax2tf documentation for details and for mechanisms to override the default.
The option
--jax_coordination_service
has been removed. It is now alwaysTrue
.jax.jaxpr_util
has been removed from the public JAX namespace.JAX_USE_PJRT_C_API_ON_TPU
no longer has an effect (i.e. it always defaults to true).The backwards compatibility flag
--jax_host_callback_ad_transforms
introduced in December 2021, has been removed.
Deprecations:
Several
jax.numpy
APIs have been deprecated following NumPy NEP-52:jax.numpy.NINF
has been deprecated. Use-jax.numpy.inf
instead.jax.numpy.PZERO
has been deprecated. Use0.0
instead.jax.numpy.NZERO
has been deprecated. Use-0.0
instead.jax.numpy.issubsctype(x, t)
has been deprecated. Usejax.numpy.issubdtype(x.dtype, t)
.jax.numpy.row_stack
has been deprecated. Usejax.numpy.vstack
instead.jax.numpy.in1d
has been deprecated. Usejax.numpy.isin
instead.jax.numpy.trapz
has been deprecated. Usejax.scipy.integrate.trapezoid
instead.
jax.scipy.linalg.tril
andjax.scipy.linalg.triu
have been deprecated, following SciPy. Usejax.numpy.tril
andjax.numpy.triu
instead.jax.lax.prod
has been removed after being deprecated in JAX v0.4.11. Use the built-inmath.prod
instead.A number of exports from
jax.interpreters.xla
related to defining HLO lowering rules for custom JAX primitives have been deprecated. Custom primitives should be defined using the StableHLO lowering utilities injax.interpreters.mlir
instead.The following previously-deprecated functions have been removed after a three-month deprecation period:
jax.abstract_arrays.ShapedArray
: usejax.core.ShapedArray
.jax.abstract_arrays.raise_to_shaped
: usejax.core.raise_to_shaped
.jax.numpy.alltrue
: usejax.numpy.all
.jax.numpy.sometrue
: usejax.numpy.any
.jax.numpy.product
: usejax.numpy.prod
.jax.numpy.cumproduct
: usejax.numpy.cumprod
.
Deprecations/removals:
The internal submodule
jax.prng
is now deprecated. Its contents are available atjax.extend.random
.The internal submodule path
jax.linear_util
has been deprecated. Usejax.extend.linear_util
instead (Part of jax.extend: a module for extensions)jax.random.PRNGKeyArray
andjax.random.KeyArray
are deprecated. Usejax.Array
for type annotations, andjax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key)
for runtime detection of typed prng keys.The method
PRNGKeyArray.unsafe_raw_array
is deprecated. Usejax.random.key_data()
instead.jax.experimental.pjit.with_sharding_constraint
is deprecated. Usejax.lax.with_sharding_constraint
instead.The internal utilities
jax.core.is_opaque_dtype
andjax.core.has_opaque_dtype
have been removed. Opaque dtypes have been renamed to Extended dtypes; usejnp.issubdtype(dtype, jax.dtypes.extended)
instead (available since jax v0.4.14).The utility
jax.interpreters.xla.register_collective_primitive
has been removed. This utility did nothing useful in recent JAX releases and calls to it can be safely removed.The internal submodule path
jax.linear_util
has been deprecated. Usejax.extend.linear_util
instead (Part of jax.extend: a module for extensions)
jaxlib 0.4.16 (Sept 18, 2023)#
Changes:
Sparse CSR matrix multiplications via the experimental jax sparse APIs no longer uses a deterministic algorithm on NVIDIA GPUs. This change was made to improve compatibility with CUDA 12.2.1.
Bug fixes:
Fixed a crash on Windows due to a fatal LLVM error related to out-of-order sections and IMAGE_REL_AMD64_ADDR32NB relocations (https://github.com/openxla/xla/commit/cb732a921f0c4184995cbed82394931011d12bd4).
jax 0.4.14 (July 27, 2023)#
Changes
jax.jit
takesdonate_argnames
as an argument. It’s semantics are similar tostatic_argnames
. If neither donate_argnums nor donate_argnames is provided, no arguments are donated. If donate_argnums is not provided but donate_argnames is, or vice versa, JAX usesinspect.signature(fun)
to find any positional arguments that correspond to donate_argnames (or vice versa). If both donate_argnums and donate_argnames are provided, inspect.signature is not used, and only actual parameters listed in either donate_argnums or donate_argnames will be donated.jax.random.gamma()
has been re-factored to a more efficient algorithm with more robust endpoint behavior (#16779). This means that the sequence of values returned for a givenkey
will change between JAX v0.4.13 and v0.4.14 forgamma
and related samplers (includingjax.random.ball()
,jax.random.beta()
,jax.random.chisquare()
,jax.random.dirichlet()
,jax.random.generalized_normal()
,jax.random.loggamma()
,jax.random.t()
).
Deletions
in_axis_resources
andout_axis_resources
have been deleted from pjit since it has been more than 3 months since their deprecation. Please usein_shardings
andout_shardings
as the replacement. This is a safe and trivial name replacement. It does not change any of the current pjit semantics and doesn’t break any code. You can still pass inPartitionSpecs
to in_shardings and out_shardings.
Deprecations
Python 3.8 support has been dropped as per https://jax.readthedocs.io/en/latest/deprecation.html
JAX now requires NumPy 1.22 or newer as per https://jax.readthedocs.io/en/latest/deprecation.html
Passing optional arguments to
jax.numpy.ndarray.at()
by position is no longer supported, after being deprecated in JAX version 0.4.7. For example, instead ofx.at[i].get(True)
, usex.at[i].get(indices_are_sorted=True)
The following
jax.Array
methods have been removed, after being deprecated in JAX v0.4.5:jax.Array.broadcast
: usejax.lax.broadcast()
instead.jax.Array.broadcast_in_dim
: usejax.lax.broadcast_in_dim()
instead.jax.Array.split
: usejax.numpy.split()
instead.
The following APIs have been removed after previous deprecation:
jax.ad
: usejax.interpreters.ad
.jax.curry
: usecurry = lambda f: partial(partial, f)
.jax.partial_eval
: usejax.interpreters.partial_eval
.jax.pxla
: usejax.interpreters.pxla
.jax.xla
: usejax.interpreters.xla
.jax.ShapedArray
: usejax.core.ShapedArray
.jax.interpreters.pxla.device_put
: usejax.device_put()
.jax.interpreters.pxla.make_sharded_device_array
: usejax.make_array_from_single_device_arrays()
.jax.interpreters.pxla.ShardedDeviceArray
: usejax.Array
.jax.numpy.DeviceArray
: usejax.Array
.jax.stages.Compiled.compiler_ir
: usejax.stages.Compiled.as_text()
.
Breaking changes
JAX now requires ml_dtypes version 0.2.0 or newer.
To fix a corner case, calls to
jax.lax.cond()
with five arguments will always resolve to the “common operands”cond
behavior (as documented) if the second and third arguments are callable, even if other operands are callable as well. See #16413.The deprecated config options
jax_array
andjax_jit_pjit_api_merge
, which did nothing, have been removed. These options have been true by default for many releases.
New features
JAX now supports a configuration flag –jax_serialization_version and a JAX_SERIALIZATION_VERSION environment variable to control the serialization version (#16746).
jax2tf in presence of shape polymorphism now generates code that checks certain shape constraints, if the serialization version is at least 7. See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism.
jaxlib 0.4.14 (July 27, 2023)#
Deprecations
Python 3.8 support has been dropped as per https://jax.readthedocs.io/en/latest/deprecation.html
jax 0.4.13 (June 22, 2023)#
Changes
`jax.jit` now allows `None` to be passed to `in_shardings` and `out_shardings`. The semantics are as follows:
For `in_shardings`, JAX will mark it as replicated, but this behavior can change in the future.
For `out_shardings`, we will rely on the XLA GSPMD partitioner to determine the output shardings.
`jax.experimental.pjit.pjit` also allows `None` to be passed to `in_shardings` and `out_shardings`. The semantics are as follows:
If the mesh context manager is not provided, JAX has the freedom to choose whatever sharding it wants.
For `in_shardings`, JAX will mark it as replicated, but this behavior can change in the future.
For `out_shardings`, we will rely on the XLA GSPMD partitioner to determine the output shardings.
If the mesh context manager is provided, `None` will imply that the value will be replicated on all devices of the mesh.
`Executable.cost_analysis()` works on Cloud TPU.
Added a warning if a non-allowlisted `jaxlib` plugin is in use.
Added `jax.tree_util.tree_leaves_with_path` (a short sketch follows this list).
`None` is not a valid input to `jax.experimental.multihost_utils.host_local_array_to_global_array` or `jax.experimental.multihost_utils.global_array_to_host_local_array`. Please use `jax.sharding.PartitionSpec()` if you wanted to replicate your input.
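A rough illustration of the new key-path helper; the example tree is arbitrary:

```python
import jax

tree = {"a": 1, "b": (2, 3)}

# Each leaf comes paired with the key path that leads to it.
for path, leaf in jax.tree_util.tree_leaves_with_path(tree):
    print(jax.tree_util.keystr(path), leaf)
# prints something like: ['a'] 1, then ['b'][0] 2, then ['b'][1] 3
```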
Bug fixes
Fixed incorrect wheel name in CUDA 12 releases (#16362); the correct wheel is named `cudnn89` instead of `cudnn88`.
Deprecations
The `native_serialization_strict_checks` parameter to `jax.experimental.jax2tf.convert()` is deprecated in favor of the new `native_serialization_disabled_checks` (#16347).
jaxlib 0.4.13 (June 22, 2023)#
Changes
Added Windows CPU-only wheels to the `jaxlib` PyPI release.
Bug fixes
`__cuda_array_interface__` was broken in previous jaxlib versions and is now fixed (#16440).
Concurrent CUDA kernel tracing is now enabled by default on NVIDIA GPUs.
jax 0.4.12 (June 8, 2023)#
Changes
Deprecations
`jax.abstract_arrays` and its contents are now deprecated. See related functionality in `jax.core`.
`jax.numpy.alltrue`: use `jax.numpy.all`. This follows the deprecation of `numpy.alltrue` in NumPy version 1.25.0.
`jax.numpy.sometrue`: use `jax.numpy.any`. This follows the deprecation of `numpy.sometrue` in NumPy version 1.25.0.
`jax.numpy.product`: use `jax.numpy.prod`. This follows the deprecation of `numpy.product` in NumPy version 1.25.0.
`jax.numpy.cumproduct`: use `jax.numpy.cumprod`. This follows the deprecation of `numpy.cumproduct` in NumPy version 1.25.0.
`jax.sharding.OpShardingSharding` has been removed since it has been 3 months since it was deprecated.
jaxlib 0.4.12 (June 8, 2023)#
Changes
Includes PTX/SASS for Hopper (SM version 9.0+) GPUs. Previous versions of jaxlib should work on Hopper but would have a long JIT-compilation delay the first time a JAX operation was executed.
Bug fixes
Fixes incorrect source line information in JAX-generated Python tracebacks under Python 3.11.
Fixes crash when printing local variables of frames in JAX-generated Python tracebacks (#16027).
jax 0.4.11 (May 31, 2023)#
Deprecations
The following APIs have been removed after a 3 month deprecation period, in accordance with the API compatibility policy:
`jax.experimental.PartitionSpec`: use `jax.sharding.PartitionSpec`.
`jax.experimental.maps.Mesh`: use `jax.sharding.Mesh`.
`jax.experimental.pjit.NamedSharding`: use `jax.sharding.NamedSharding`.
`jax.experimental.pjit.PartitionSpec`: use `jax.sharding.PartitionSpec`.
`jax.experimental.pjit.FROM_GDA`: instead pass sharded `jax.Array` objects as input and remove the optional `in_shardings` argument to `pjit`.
`jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`.
`jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh`.
`jax.interpreters.xla.Buffer`: use `jax.Array`.
`jax.interpreters.xla.Device`: use `jax.Device`.
`jax.interpreters.xla.DeviceArray`: use `jax.Array`.
`jax.interpreters.xla.device_put`: use `jax.device_put`.
`jax.interpreters.xla.xla_call_p`: use `jax.experimental.pjit.pjit_p`.
The `axis_resources` argument of `with_sharding_constraint` has been removed. Please use `shardings` instead.
jaxlib 0.4.11 (May 31, 2023)#
Changes
Added a `memory_stats()` method to `Device`s. If supported, this returns a dict of string stat names with int values, e.g. `"bytes_in_use"`, or `None` if the platform doesn't support memory statistics. The exact stats returned may vary across platforms. Currently only implemented on Cloud TPU. (A minimal usage sketch follows this list.)
Re-added support for the Python buffer protocol (`memoryview`) on CPU devices.
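A minimal sketch of querying per-device memory statistics; the "bytes_in_use" key is platform-dependent:

```python
import jax

for device in jax.devices():
    stats = device.memory_stats()
    if stats is None:
        print(device, "memory statistics not supported")
    else:
        print(device, stats.get("bytes_in_use"))
```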
jax 0.4.10 (May 11, 2023)#
jaxlib 0.4.10 (May 11, 2023)#
Changes
Fixed the `'apple-m1' is not a recognized processor for this target (ignoring processor)` issue that prevented the previous release from running on Mac M1.
jax 0.4.9 (May 9, 2023)#
Changes
The flags experimental_cpp_jit, experimental_cpp_pjit and experimental_cpp_pmap have been removed. They are now always on.
Accuracy of singular value decomposition (SVD) on TPU has been improved (requires jaxlib 0.4.9).
Deprecations
`jax.experimental.gda_serialization` is deprecated and has been renamed to `jax.experimental.array_serialization`. Please change your imports to use `jax.experimental.array_serialization`.
The `in_axis_resources` and `out_axis_resources` arguments of `pjit` have been deprecated. Please use `in_shardings` and `out_shardings` respectively.
The function `jax.numpy.msort` has been removed. It has been deprecated since JAX v0.4.1. Use `jnp.sort(a, axis=0)` instead.
The `in_parts` and `out_parts` arguments have been removed from `jax.xla_computation` since they were only used with `sharded_jit`, and `sharded_jit` is long gone.
The `instantiate_const_outputs` argument has been removed from `jax.xla_computation` since it has been unused for a very long time.
jaxlib 0.4.9 (May 9, 2023)#
jax 0.4.8 (March 29, 2023)#
Breaking changes
A major component of the Cloud TPU runtime has been upgraded. This enables the following new features on Cloud TPU:
`jax.debug.print()`, `jax.debug.callback()`, and `jax.debug.breakpoint()` now work on Cloud TPU.
Automatic TPU memory defragmentation.
`jax.experimental.host_callback()` is no longer supported on Cloud TPU with the new runtime component. Please file an issue on the JAX issue tracker if the new `jax.debug` APIs are insufficient for your use case.
The old runtime component will be available for at least the next three months by setting the environment variable `JAX_USE_PJRT_C_API_ON_TPU=false`. If you find you need to disable the new runtime for any reason, please let us know on the JAX issue tracker.
Changes
The minimum jaxlib version has been bumped from 0.4.6 to 0.4.7.
Deprecations
CUDA 11.4 support has been dropped. JAX GPU wheels only support CUDA 11.8 and CUDA 12. Older CUDA versions may work if jaxlib is built from source.
global_arg_shapes
argument of pmap only worked with sharded_jit and has been removed from pmap. Please migrate to pjit and remove global_arg_shapes from pmap.
jax 0.4.7 (March 27, 2023)#
Changes
As per https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration:
`jax.config.jax_array` cannot be disabled anymore.
`jax.config.jax_jit_pjit_api_merge` cannot be disabled anymore.
`jax.experimental.jax2tf.convert()` now supports the `native_serialization` parameter to use JAX's native lowering to StableHLO to obtain a StableHLO module for the entire JAX function, instead of lowering each JAX primitive to a TensorFlow op (a short sketch follows this list). This simplifies the internals and increases the confidence that what you serialize matches the JAX native semantics. See documentation. As part of this change, the config flag `--jax2tf_default_experimental_native_lowering` has been renamed to `--jax2tf_native_serialization`.
JAX now depends on `ml_dtypes`, which contains definitions of NumPy types like bfloat16. These definitions were previously internal to JAX, but have been split into a separate package to facilitate sharing them with other projects.
JAX now requires NumPy 1.21 or newer and SciPy 1.7 or newer.
Deprecations
The type `jax.numpy.DeviceArray` is deprecated. Use `jax.Array` instead, for which it is an alias.
The type `jax.interpreters.pxla.ShardedDeviceArray` is deprecated. Use `jax.Array` instead.
Passing additional arguments to `jax.numpy.ndarray.at()` by position is deprecated. For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`.
`jax.interpreters.xla.device_put` is deprecated. Please use `jax.device_put`.
`jax.interpreters.pxla.device_put` is deprecated. Please use `jax.device_put`.
`jax.experimental.pjit.FROM_GDA` is deprecated. Please pass in sharded `jax.Array`s as input and remove the `in_shardings` argument to `pjit` since it is optional.
jaxlib 0.4.7 (March 27, 2023)#
Changes:
jaxlib now depends on `ml_dtypes`, which contains definitions of NumPy types like bfloat16. These definitions were previously internal to JAX, but have been split into a separate package to facilitate sharing them with other projects.
jax 0.4.6 (Mar 9, 2023)#
Changes
`jax.tree_util` now contains a set of APIs that allow users to define keys for their custom pytree nodes (a short sketch follows this list). This includes:
`tree_flatten_with_path`, which flattens a tree and returns not only each leaf but also its key path.
`tree_map_with_path`, which can map a function that takes the key path as an argument.
`register_pytree_with_keys`, to register what the key paths and leaves should look like for a custom pytree node.
`keystr`, which pretty-prints a key path.
`jax2tf.call_tf()` has a new parameter `output_shape_dtype` (default `None`) that can be used to declare the output shape and type of the result. This enables `jax2tf.call_tf()` to work in the presence of shape polymorphism (#14734).
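A minimal sketch of the key-path APIs; the `Point` container class here is hypothetical:

```python
from jax.tree_util import (GetAttrKey, register_pytree_with_keys,
                           tree_map_with_path, keystr)

class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y

def flatten_with_keys(p):
    # Each child is paired with a key describing how it is reached.
    return ((GetAttrKey("x"), p.x), (GetAttrKey("y"), p.y)), None

def unflatten(aux, children):
    return Point(*children)

register_pytree_with_keys(Point, flatten_with_keys, unflatten)

p = Point(1.0, 2.0)
# The mapped function receives (key_path, leaf); keystr pretty-prints the path.
tree_map_with_path(lambda path, leaf: print(keystr(path), leaf), p)
```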
Deprecations
The old key-path APIs in `jax.tree_util` are deprecated and will be removed 3 months from Mar 10 2023:
`register_keypaths`: use `jax.tree_util.register_pytree_with_keys()` instead.
`AttributeKeyPathEntry`: use `GetAttrKey` instead.
`GetitemKeyPathEntry`: use `SequenceKey` or `DictKey` instead.
jaxlib 0.4.6 (Mar 9, 2023)#
jax 0.4.5 (Mar 2, 2023)#
Deprecations
`jax.sharding.OpShardingSharding` has been renamed to `jax.sharding.GSPMDSharding`. `jax.sharding.OpShardingSharding` will be removed in 3 months from Feb 17, 2023.
The following `jax.Array` methods are deprecated and will be removed 3 months from Feb 23 2023:
`jax.Array.broadcast`: use `jax.lax.broadcast()` instead.
`jax.Array.broadcast_in_dim`: use `jax.lax.broadcast_in_dim()` instead.
`jax.Array.split`: use `jax.numpy.split()` instead.
jax 0.4.4 (Feb 16, 2023)#
Changes
The implementation of `jit` and `pjit` has been merged. Merging `jit` and `pjit` changes the internals of JAX without affecting the public API of JAX. Before, `jit` was a final-style primitive. Final style means that the creation of the jaxpr was delayed as much as possible and transformations were stacked on top of each other. With the `jit`-`pjit` implementation merge, `jit` becomes an initial-style primitive, which means that we trace to a jaxpr as early as possible. For more information see this section in autodidax. Moving to initial style should simplify JAX's internals and make development of features like dynamic shapes easier. You can disable the merge only via the environment variable, i.e. `os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'`. The merge must be disabled via an environment variable because it affects JAX at import time, so it needs to be disabled before JAX is imported.
The `axis_resources` argument of `with_sharding_constraint` is deprecated. Please use `shardings` instead. No change is needed if you were passing `axis_resources` positionally; if you were passing it as a keyword argument, please use `shardings` instead. `axis_resources` will be removed 3 months from Feb 13, 2023.
Added the `jax.typing` module, with tools for type annotations of JAX functions (a short sketch follows this list).
The following names have been deprecated:
`jax.xla.Device` and `jax.interpreters.xla.Device`: use `jax.Device`.
`jax.experimental.maps.Mesh`: use `jax.sharding.Mesh` instead.
`jax.experimental.pjit.NamedSharding`: use `jax.sharding.NamedSharding`.
`jax.experimental.pjit.PartitionSpec`: use `jax.sharding.PartitionSpec`.
`jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh`.
`jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`.
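A small sketch of annotating a function with the new typing helpers; the `scale` function is illustrative:

```python
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike

def scale(x: ArrayLike, factor: float) -> Array:
    # ArrayLike covers anything jnp.asarray accepts; the return is a jax.Array.
    return jnp.asarray(x) * factor
```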
Breaking Changes
The `initial` argument to reduction functions like `jax.numpy.sum` is now required to be a scalar, consistent with the corresponding NumPy API. The previous behavior of broadcasting the output against non-scalar `initial` values was an unintentional implementation detail (#14446).
jaxlib 0.4.4 (Feb 16, 2023)#
Breaking changes
Support for NVIDIA Kepler series GPUs has been removed from the default `jaxlib` builds. If Kepler support is needed, it is still possible to build `jaxlib` from source with Kepler support (via the `--cuda_compute_capabilities=sm_35` option to `build.py`); however, note that CUDA 12 has completely dropped support for Kepler GPUs.
jax 0.4.3 (Feb 8, 2023)#
Breaking changes
Deleted `jax.scipy.linalg.polar_unitary()`, which was a deprecated JAX extension to the scipy API. Use `jax.scipy.linalg.polar()` instead.
Changes
Added `jax.scipy.stats.rankdata()`.
jaxlib 0.4.3 (Feb 8, 2023)#
`jax.Array` now has the non-blocking `is_ready()` method, which returns `True` if the array is ready (see also `jax.block_until_ready()`).
jax 0.4.2 (Jan 24, 2023)#
Breaking changes
Deleted `jax.experimental.callback`.
Operations with dimensions in presence of jax2tf shape polymorphism have been generalized to work in more scenarios, by converting the symbolic dimensions to JAX arrays. Operations involving symbolic dimensions and `np.ndarray` can now raise errors when the result is used as a shape value (#14106).
jaxpr objects now raise an error on attribute setting in order to avoid problematic mutations (#14102).
Changes
`jax2tf.call_tf()` has a new parameter `has_side_effects` (default `True`) that can be used to declare whether an instance can be removed or replicated by JAX optimizations such as dead-code elimination (#13980).
Added more support for floordiv and mod for jax2tf shape polymorphism. Previously, certain division operations resulted in errors in presence of symbolic dimensions (#14108).
jaxlib 0.4.2 (Jan 24, 2023)#
Changes
Set JAX_USE_PJRT_C_API_ON_TPU=1 to enable new Cloud TPU runtime, featuring automatic device memory defragmentation.
jax 0.4.1 (Dec 13, 2022)#
Changes
Support for Python 3.7 has been dropped, in accordance with JAX’s Python and NumPy version support policy.
We introduce `jax.Array`, a unified array type that subsumes the `DeviceArray`, `ShardedDeviceArray`, and `GlobalDeviceArray` types in JAX. The `jax.Array` type helps make parallelism a core feature of JAX, simplifies and unifies JAX internals, and allows us to unify `jit` and `pjit`. `jax.Array` has been enabled by default in JAX 0.4 and makes some breaking changes to the `pjit` API. The jax.Array migration guide can help you migrate your codebase to `jax.Array`. You can also look at the Distributed arrays and automatic parallelization tutorial to understand the new concepts.
`PartitionSpec` and `Mesh` are now out of experimental. The new API endpoints are `jax.sharding.PartitionSpec` and `jax.sharding.Mesh`. `jax.experimental.maps.Mesh` and `jax.experimental.PartitionSpec` are deprecated and will be removed in 3 months.
`with_sharding_constraint`'s new public endpoint is `jax.lax.with_sharding_constraint` (a short sketch follows this list).
If using ABSL flags together with `jax.config`, the ABSL flag values are no longer read or written after the JAX configuration options are initially populated from the ABSL flags. This change improves performance of reading `jax.config` options, which are used pervasively in JAX.
The jax2tf.call_tf function now uses, for TF lowering, the first TF device of the same platform as the one used by the embedding JAX computation. Before, it was using the 0th device for the JAX-default backend.
A number of `jax.numpy` functions now have their arguments marked as positional-only, matching NumPy.
`jnp.msort` is now deprecated, following the deprecation of `np.msort` in NumPy 1.24. It will be removed in a future release, in accordance with the API compatibility policy. It can be replaced with `jnp.sort(a, axis=0)`.
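A small sketch of the non-experimental sharding endpoints, written against the current public names (the mesh axis name "x" is illustrative):

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
sharding = NamedSharding(mesh, PartitionSpec("x"))

@jax.jit
def f(a):
    # Public endpoint; previously this lived under experimental modules.
    return jax.lax.with_sharding_constraint(a * 2, sharding)

out = f(np.arange(2 * jax.device_count(), dtype=np.float32))
```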
jaxlib 0.4.1 (Dec 13, 2022)#
Changes
Support for Python 3.7 has been dropped, in accordance with JAX’s Python and NumPy version support policy.
The behavior of
XLA_PYTHON_CLIENT_MEM_FRACTION=.XX
has been changed to allocate XX% of the total GPU memory instead of the previous behavior of using currently available GPU memory to calculate preallocation. Please refer to GPU memory allocation for more details.The deprecated method
.block_host_until_ready()
has been removed. Use.block_until_ready()
instead.
jax 0.4.0 (Dec 12, 2022)#
The release was yanked.
jaxlib 0.4.0 (Dec 12, 2022)#
The release was yanked.
jax 0.3.25 (Nov 15, 2022)#
Changes
jax.numpy.linalg.pinv()
now supports thehermitian
option.jax.scipy.linalg.hessenberg()
is now supported on CPU only. Requires jaxlib > 0.3.24.New functions
jax.lax.linalg.hessenberg()
,jax.lax.linalg.tridiagonal()
, andjax.lax.linalg.householder_product()
were added. Householder reduction is currently CPU-only and tridiagonal reductions are supported on CPU and GPU only.The gradients of
svd
andjax.numpy.linalg.pinv
are now computed more economically for non-square matrices.
Breaking Changes
Deleted the
jax_experimental_name_stack
config option.Convert a string
axis_names
arguments to thejax.experimental.maps.Mesh
constructor into a singleton tuple instead of unpacking the string into a sequence of character axis names.
jaxlib 0.3.25 (Nov 15, 2022)#
Changes
Added support for tridiagonal reductions on CPU and GPU.
Added support for upper Hessenberg reductions on CPU.
Bugs
Fixed a bug that meant that frames in tracebacks captured by JAX were incorrectly mapped to source lines under Python 3.10+
jax 0.3.24 (Nov 4, 2022)#
Changes
JAX should be faster to import. We now import scipy lazily, which accounted for a significant fraction of JAX’s import time.
Setting the env var
JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=$N
can be used to limit the number of cache entries written to the persistent cache. By default, computations that take 1 second or more to compile will be cached.Added
jax.scipy.stats.mode()
.
The default device order used by
pmap
on TPU if no order is specified now matchesjax.devices()
for single-process jobs. Previously the two orderings differed, which could lead to unnecessary copies or out-of-memory errors. Requiring the orderings to agree simplifies matters.
Breaking Changes
jax.numpy.gradient()
now behaves like most other functions injax.numpy
, and forbids passing lists or tuples in place of arrays (#12958)Functions in
jax.numpy.linalg
andjax.numpy.fft
now uniformly require inputs to be array-like: i.e. lists and tuples cannot be used in place of arrays. Part of #7737.
Deprecations
jax.sharding.MeshPspecSharding
has been renamed tojax.sharding.NamedSharding
.jax.sharding.MeshPspecSharding
name will be removed in 3 months.
jaxlib 0.3.24 (Nov 4, 2022)#
Changes
Buffer donation now works on CPU. This may break code that marked buffers for donation on CPU but relied on donation not being implemented.
jax 0.3.23 (Oct 12, 2022)#
Changes
Update Colab TPU driver version for new jaxlib release.
jax 0.3.22 (Oct 11, 2022)#
Changes
Add
JAX_PLATFORMS=tpu,cpu
as default setting in TPU initialization, so JAX will raise an error if TPU cannot be initialized instead of falling back to CPU. SetJAX_PLATFORMS=''
to override this behavior and automatically choose an available backend (the original default), or setJAX_PLATFORMS=cpu
to always use CPU regardless of if the TPU is available.
Deprecations
Several test utilities deprecated in JAX v0.3.8 are now removed from
jax.test_util
.
jaxlib 0.3.22 (Oct 11, 2022)#
jax 0.3.21 (Sep 30, 2022)#
Changes
The persistent compilation cache will now warn instead of raising an exception on error (#12582), so program execution can continue if something goes wrong with the cache. Set
JAX_RAISE_PERSISTENT_CACHE_ERRORS=true
to revert this behavior.
jax 0.3.20 (Sep 28, 2022)#
jaxlib 0.3.20 (Sep 28, 2022)#
Bug fixes
Fixes support for limiting the visible CUDA devices via
jax_cuda_visible_devices
in distributed jobs. This functionality is needed for the JAX/SLURM integration on GPU (#12533).
jax 0.3.19 (Sep 27, 2022)#
Fixes required jaxlib version.
jax 0.3.18 (Sep 26, 2022)#
Changes
Ahead-of-time lowering and compilation functionality (tracked in #7733) is stable and public. See the overview and the API docs for
jax.stages
.Introduced
jax.Array
, intended to be used for bothisinstance
checks and type annotations for array types in JAX. Notice that this included some subtle changes to howisinstance
works forjax.numpy.ndarray
for jax-internal objects, asjax.numpy.ndarray
is now a simple alias ofjax.Array
.
Breaking changes
jax._src
is no longer imported into the publicjax
namespace. This may break users that were using JAX internals.jax.soft_pmap
has been deleted. Please usepjit
orxmap
instead.jax.soft_pmap
is undocumented. If it were documented, a deprecation period would have been provided.
jax 0.3.17 (Aug 31, 2022)#
Bugs
Fix corner case issue in gradient of
lax.pow
with an exponent of zero (#12041)
Breaking changes
jax.checkpoint()
, also known asjax.remat()
, no longer supports theconcrete
option, following the previous version’s deprecation; see JEP 11830.
Changes
Added
jax.pure_callback()
that enables calling back to pure Python functions from compiled functions (e.g. functions decorated withjax.jit
orjax.pmap
).
Deprecations:
The deprecated
DeviceArray.tile()
method has been removed. Usejax.numpy.tile()
(#11944).DeviceArray.to_py()
has been deprecated. Usenp.asarray(x)
instead.
jax 0.3.16#
Breaking changes
Support for NumPy 1.19 has been dropped, per the deprecation policy. Please upgrade to NumPy 1.20 or newer.
Changes
Added `jax.debug`, which includes utilities for runtime value debugging such as `jax.debug.print()` and `jax.debug.breakpoint()`.
Added new documentation for runtime value debugging.
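A tiny sketch of runtime value debugging inside a jitted function:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # jax.debug.print works inside jit/vmap, where a plain Python print
    # would only see tracers.
    jax.debug.print("x = {x}", x=x)
    return x * 2

f(jnp.arange(3))
```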
Deprecations
jax.mask()
jax.shapecheck()
APIs have been removed. See #11557.jax.experimental.loops
has been removed. See #10278 for an alternative API.jax.tree_util.tree_multimap()
has been removed. It has been deprecated since JAX release 0.3.5, andjax.tree_util.tree_map()
is a direct replacement.Removed
jax.experimental.stax
; it has long been a deprecated alias ofjax.example_libraries.stax
.Removed
jax.experimental.optimizers
; it has long been a deprecated alias ofjax.example_libraries.optimizers
.jax.checkpoint()
, also known asjax.remat()
, has a new implementation switched on by default, meaning the old implementation is deprecated; see JEP 11830.
jax 0.3.15 (July 22, 2022)#
Changes
JaxTestCase
andJaxTestLoader
have been removed fromjax.test_util
. These classes have been deprecated since v0.3.1 (#11248).Added
jax.scipy.gaussian_kde
(#11237).Binary operations between JAX arrays and built-in collections (
dict
,list
,set
,tuple
) now raise aTypeError
in all cases. Previously some cases (particularly equality and inequality) would return boolean scalars inconsistent with similar operations in NumPy (#11234).Several
jax.tree_util
routines accessed as top-level JAX package imports are now deprecated, and will be removed in a future JAX release in accordance with the API compatibility policy:jax.treedef_is_leaf()
is deprecated in favor ofjax.tree_util.treedef_is_leaf()
jax.tree_flatten()
is deprecated in favor ofjax.tree_util.tree_flatten()
jax.tree_leaves()
is deprecated in favor ofjax.tree_util.tree_leaves()
jax.tree_structure()
is deprecated in favor ofjax.tree_util.tree_structure()
jax.tree_transpose()
is deprecated in favor ofjax.tree_util.tree_transpose()
jax.tree_unflatten()
is deprecated in favor ofjax.tree_util.tree_unflatten()
The
sym_pos
argument ofjax.scipy.linalg.solve()
is deprecated in favor ofassume_a='pos'
, following a similar deprecation inscipy.linalg.solve()
.
jaxlib 0.3.15 (July 22, 2022)#
jax 0.3.14 (June 27, 2022)#
Breaking changes
`jax.experimental.compilation_cache.initialize_cache()` no longer supports `max_cache_size_bytes` and will not accept it as an input.
`JAX_PLATFORMS` now raises an exception when platform initialization fails.
Changes
Fixed compatibility problems with NumPy 1.23.
jax.numpy.linalg.slogdet()
now accepts an optionalmethod
argument that allows selection between an LU-decomposition based implementation and an implementation based on QR decomposition.jax.numpy.linalg.qr()
now supportsmode="raw"
.pickle
,copy.copy
, andcopy.deepcopy
now have more complete support when used on jax arrays (#10659). In particular:pickle
anddeepcopy
previously returnednp.ndarray
objects when used on aDeviceArray
; nowDeviceArray
objects are returned. Fordeepcopy
, the copied array is on the same device as the original. Forpickle
the deserialized array will be on the default device.Within function transformations (i.e. traced code),
deepcopy
andcopy
previously were no-ops. Now they use the same mechanism asDeviceArray.copy()
.Calling
pickle
on a traced array now results in an explicitConcretizationTypeError
.
The implementation of singular value decomposition (SVD) and symmetric/Hermitian eigendecomposition should be significantly faster on TPU, especially for matrices above 1000x1000 or so. Both now use a spectral divide-and-conquer algorithm for eigendecomposition (QDWH-eig).
jax.numpy.ldexp()
no longer silently promotes all inputs to float64, instead it promotes to float32 for integer inputs of size int32 or smaller (#10921).Add a
create_perfetto_link
option tojax.profiler.start_trace()
andjax.profiler.start_trace()
. When used, the profiler will generate a link to the Perfetto UI to view the trace.Changed the semantics of
jax.profiler.start_server(...)()
to store the keepalive globally, rather than requiring the user to keep a reference to it.Added
jax.random.ball()
.Added
jax.default_device()
.Added a
python -m jax.collect_profile
script to manually capture program traces as an alternative to the Tensorboard UI.Added a
jax.named_scope
context manager that adds profiler metadata to Python programs (similar tojax.named_call
).In scatter-update operations (i.e. :attr:
jax.numpy.ndarray.at
), unsafe implicit dtype casts are deprecated, and now result in aFutureWarning
. In a future release, this will become an error. An example of an unsafe implicit cast isjnp.zeros(4, dtype=int).at[0].set(1.5)
, in which1.5
previously was silently truncated to1
.jax.experimental.compilation_cache.initialize_cache()
now supports gcs bucket path as input.Added
jax.scipy.stats.gennorm()
.jax.numpy.roots()
is now better behaved whenstrip_zeros=False
when coefficients have leading zeros (#11215).
jaxlib 0.3.14 (June 27, 2022)#
-
x86-64 Mac wheels now require Mac OS 10.14 (Mojave) or newer. Mac OS 10.14 was released in 2018, so this should not be a very onerous requirement.
The bundled version of NCCL was updated to 2.12.12, fixing some deadlocks.
The Python flatbuffers package is no longer a dependency of jaxlib.
jax 0.3.13 (May 16, 2022)#
jax 0.3.12 (May 15, 2022)#
Changes
Fixes #10717.
jax 0.3.11 (May 15, 2022)#
Changes
jax.lax.eigh()
now accepts an optionalsort_eigenvalues
argument that allows users to opt out of eigenvalue sorting on TPU.
Deprecations
Non-array arguments to functions in
jax.lax.linalg
are now marked keyword-only. As a backward-compatibility step passing keyword-only arguments positionally yields a warning, but in a future JAX release passing keyword-only arguments positionally will fail. However, most users should prefer to usejax.numpy.linalg
instead.jax.scipy.linalg.polar_unitary()
, which was a JAX extension to the scipy API, is deprecated. Usejax.scipy.linalg.polar()
instead.
jax 0.3.10 (May 3, 2022)#
jaxlib 0.3.10 (May 3, 2022)#
Changes
TF commit fixes an issue in the MHLO canonicalizer that caused constant folding to take a long time or crash for certain programs.
jax 0.3.9 (May 2, 2022)#
Changes
Added support for fully asynchronous checkpointing for GlobalDeviceArray.
jax 0.3.8 (April 29 2022)#
Changes
jax.numpy.linalg.svd()
on TPUs uses a qdwh-svd solver.jax.numpy.linalg.cond()
on TPUs now accepts complex input.jax.numpy.linalg.pinv()
on TPUs now accepts complex input.jax.numpy.linalg.matrix_rank()
on TPUs now accepts complex input.jax.scipy.cluster.vq.vq()
has been added.jax.experimental.maps.mesh
has been deleted. Please usejax.experimental.maps.Mesh
. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh for more information.jax.scipy.linalg.qr()
now returns a length-1 tuple rather than the raw array whenmode='r'
, in order to match the behavior ofscipy.linalg.qr
(#10452)jax.numpy.take_along_axis()
now takes an optionalmode
parameter that specifies the behavior of out-of-bounds indexing. By default, invalid values (e.g., NaN) will be returned for out-of-bounds indices. In previous versions of JAX, invalid indices were clamped into range. The previous behavior can be restored by passingmode="clip"
.jax.numpy.take()
now defaults tomode="fill"
, which returns invalid values (e.g., NaN) for out-of-bounds indices.Scatter operations, such as
x.at[...].set(...)
, now have"drop"
semantics. This has no effect on the scatter operation itself, but it means that when differentiated the gradient of a scatter will yield zero cotangents for out-of-bounds indices. Previously out-of-bounds indices were clamped into range for the gradient, which was not mathematically correct.jax.numpy.take_along_axis()
now raises aTypeError
if its indices are not of an integer type, matching the behavior ofnumpy.take_along_axis()
. Previously non-integer indices were silently cast to integers.jax.numpy.ravel_multi_index()
now raises aTypeError
if itsdims
argument is not of an integer type, matching the behavior ofnumpy.ravel_multi_index()
. Previously non-integerdims
was silently cast to integers.jax.numpy.split()
now raises aTypeError
if itsaxis
argument is not of an integer type, matching the behavior ofnumpy.split()
. Previously non-integeraxis
was silently cast to integers.jax.numpy.indices()
now raises aTypeError
if its dimensions are not of an integer type, matching the behavior ofnumpy.indices()
. Previously non-integer dimensions were silently cast to integers.jax.numpy.diag()
now raises aTypeError
if itsk
argument is not of an integer type, matching the behavior ofnumpy.diag()
. Previously non-integerk
was silently cast to integers.Added
jax.random.orthogonal()
.
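To illustrate the out-of-bounds indexing behavior described in this list, a minimal sketch with an arbitrary array:

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
idx = jnp.array([0, 2, 5])          # 5 is out of bounds

jnp.take(x, idx)                    # default mode="fill": NaN for the bad index
jnp.take(x, idx, mode="clip")       # previous behavior: clamp to the last element
```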
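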
Deprecations
Many functions and objects available in
jax.test_util
are now deprecated and will raise a warning on import. This includescases_from_list
,check_close
,check_eq
,device_under_test
,format_shape_dtype_string
,rand_uniform
,skip_on_devices
,with_config
,xla_bridge
, and_default_tolerance
(#10389). These, along with previously-deprecatedJaxTestCase
,JaxTestLoader
, andBufferDonationTestCase
, will be removed in a future JAX release. Most of these utilities can be replaced by calls to standard python & numpy testing utilities found in e.g.unittest
,absl.testing
,numpy.testing
, etc. JAX-specific functionality such as device checking can be replaced through the use of public APIs such asjax.devices()
. Many of the deprecated utilities will still exist injax._src.test_util
, but these are not public APIs and as such may be changed or removed without notice in future releases.
jax 0.3.7 (April 15, 2022)#
Changes:
Fixed a performance problem if the indices passed to
jax.numpy.take_along_axis()
were broadcasted (#10281).jax.scipy.special.expit()
andjax.scipy.special.logit()
now require their arguments to be scalars or JAX arrays. They also now promote integer arguments to floating point.The
DeviceArray.tile()
method is deprecated, because numpy arrays do not have atile()
method. As a replacement for this, usejax.numpy.tile()
(#10266).
jaxlib 0.3.7 (April 15, 2022)#
Changes:
Linux wheels are now built conforming to the
manylinux2014
standard, instead ofmanylinux2010
.
jax 0.3.6 (April 12, 2022)#
jax 0.3.5 (April 7, 2022)#
Changes:
added
jax.random.loggamma()
& improved behavior ofjax.random.beta()
andjax.random.dirichlet()
for small parameter values (#9906).the private
lax_numpy
submodule is no longer exposed in thejax.numpy
namespace (#10029).added array creation routines
jax.numpy.frombuffer()
,jax.numpy.fromfunction()
, andjax.numpy.fromstring()
(#10049).DeviceArray.copy()
now returns aDeviceArray
rather than anp.ndarray
(#10069)jax.experimental.sharded_jit
has been deprecated and will be removed soon.
Deprecations:
jax.nn.normalize()
is being deprecated. Usejax.nn.standardize()
instead (#9899).jax.tree_util.tree_multimap()
is deprecated. Usejax.tree_util.tree_map()
instead (#5746).jax.experimental.sharded_jit
is deprecated. Usepjit
instead.
jaxlib 0.3.5 (April 7, 2022)#
jax 0.3.4 (March 18, 2022)#
jax 0.3.3 (March 17, 2022)#
jax 0.3.2 (March 16, 2022)#
Changes:
The functions
jax.ops.index_update
,jax.ops.index_add
, which were deprecated in 0.2.22, have been removed. Please use the.at
property on JAX arrays instead, e.g.,x.at[idx].set(y)
.Moved
jax.experimental.ann.approx_*_k
intojax.lax
. These functions are optimized alternatives tojax.lax.top_k
.jax.numpy.broadcast_arrays()
andjax.numpy.broadcast_to()
now require scalar or array-like inputs, and will fail if they are passed lists (part of #7737).The standard jax[tpu] install can now be used with Cloud TPU v4 VMs.
pjit
now works on CPU (in addition to previous TPU and GPU support).
jaxlib 0.3.2 (March 16, 2022)#
Changes
XlaComputation.as_hlo_text()
now supports printing large constants by passing boolean flagprint_large_constants=True
.
Deprecations:
The
.block_host_until_ready()
method on JAX arrays has been deprecated. Use.block_until_ready()
instead.
jax 0.3.1 (Feb 18, 2022)#
Changes:
jax.test_util.JaxTestCase
andjax.test_util.JaxTestLoader
are now deprecated. The suggested replacement is to useparametrized.TestCase
directly. For tests that rely on custom asserts such asJaxTestCase.assertAllClose()
, the suggested replacement is to use standard numpy testing utilities such asnumpy.testing.assert_allclose()
, which work directly with JAX arrays (#9620).jax.test_util.JaxTestCase
now setsjax_numpy_rank_promotion='raise'
by default (#9562). To recover the previous behavior, use the newjax.test_util.with_config
decorator:
@jtu.with_config(jax_numpy_rank_promotion='allow')
class MyTestCase(jtu.JaxTestCase):
  ...
Added
jax.scipy.linalg.schur()
,jax.scipy.linalg.sqrtm()
,jax.scipy.signal.csd()
,jax.scipy.signal.stft()
,jax.scipy.signal.welch()
.
jax 0.3.0 (Feb 10, 2022)#
Changes
jax version has been bumped to 0.3.0. Please see the design doc for the explanation.
jaxlib 0.3.0 (Feb 10, 2022)#
Changes
Bazel 5.0.0 is now required to build jaxlib.
jaxlib version has been bumped to 0.3.0. Please see the design doc for the explanation.
jax 0.2.28 (Feb 1, 2022)#
jax.jit(f).lower(...).compiler_ir()
now defaults to the MHLO dialect if nodialect=
is passed.The
jax.jit(f).lower(...).compiler_ir(dialect='mhlo')
now returns an MLIRir.Module
object instead of its string representation.
jaxlib 0.1.76 (Jan 27, 2022)#
New features
Includes precompiled SASS for NVidia compute capability 8.0 GPUS (e.g. A100). Removes precompiled SASS for compute capability 6.1 so as not to increase the number of compute capabilities: GPUs with compute capability 6.1 can use the 6.0 SASS.
With jaxlib 0.1.76, JAX uses the MHLO MLIR dialect as its primary target compiler IR by default.
Breaking changes
Support for NumPy 1.18 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
Bug fixes
Fixed a bug where apparently identical pytreedef objects constructed by different routes do not compare as equal (#9066).
The JAX jit cache requires two static arguments to have identical types for a cache hit (#9311).
jax 0.2.27 (Jan 18 2022)#
Breaking changes:
Support for NumPy 1.18 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
The host_callback primitives have been simplified to drop the special autodiff handling for hcb.id_tap and id_print. From now on, only the primals are tapped. The old behavior can be obtained (for a limited time) by setting the
JAX_HOST_CALLBACK_AD_TRANSFORMS
environment variable, or the--jax_host_callback_ad_transforms
flag. Additionally, added documentation for how to implement the old behavior using JAX custom AD APIs (#8678).Sorting now matches the behavior of NumPy for
0.0
andNaN
regardless of the bit representation. In particular,0.0
and-0.0
are now treated as equivalent, where previously-0.0
was treated as less than0.0
. Additionally allNaN
representations are now treated as equivalent and sorted to the end of the array. Previously negativeNaN
values were sorted to the front of the array, andNaN
values with different internal bit representations were not treated as equivalent, and were sorted according to those bit patterns (#9178).jax.numpy.unique()
now treatsNaN
values in the same way asnp.unique
in NumPy versions 1.21 and newer: at most oneNaN
value will appear in the uniquified output (#9184).
Bug fixes:
host_callback now supports ad_checkpoint.checkpoint (#8907).
New features:
Added `jax.block_until_ready` (#8941) (a short sketch of the related compile-time evaluation helper follows this list).
Added a new debugging flag/environment variable `JAX_DUMP_IR_TO=/path`. If set, JAX dumps the MHLO/HLO IR it generates for each computation to a file under the given path.
Added `jax.ensure_compile_time_eval` to the public API (#7987).
jax2tf now supports a flag jax2tf_associative_scan_reductions to change the lowering for associative reductions, e.g., jnp.cumsum, to behave like JAX on CPU and GPU (to use an associative scan). See the jax2tf README for more details (#9189).
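A small sketch of `jax.ensure_compile_time_eval`; the lookup-table computation is illustrative:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    with jax.ensure_compile_time_eval():
        # Evaluated eagerly at trace time, producing a concrete constant.
        table = jnp.arange(10) ** 2
    return x + table[3]
```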
jaxlib 0.1.75 (Dec 8, 2021)#
New features:
Support for python 3.10.
jax 0.2.26 (Dec 8, 2021)#
Bug fixes:
Out-of-bounds indices to
jax.ops.segment_sum
will now be handled withFILL_OR_DROP
semantics, as documented. This primarily affects the reverse-mode derivative, where gradients corresponding to out-of-bounds indices will now be returned as 0. (#8634).jax2tf will force the converted code to use XLA for the code fragments under jax.jit, e.g., most jax.numpy functions (#7839).
jaxlib 0.1.74 (Nov 17, 2021)#
Enabled peer-to-peer copies between GPUs. Previously, GPU copies were bounced via the host, which is usually slower.
Added experimental MLIR Python bindings for use by JAX.
jax 0.2.25 (Nov 10, 2021)#
New features:
(Experimental)
jax.distributed.initialize
exposes multi-host GPU backend.jax.random.permutation
supports newindependent
keyword argument (#8430)
Breaking changes
Moved
jax.experimental.stax
tojax.example_libraries.stax
Moved
jax.experimental.optimizers
tojax.example_libraries.optimizers
New features:
Added
jax.lax.linalg.qdwh
.
jax 0.2.24 (Oct 19, 2021)#
jaxlib 0.1.73 (Oct 18, 2021)#
Multiple cuDNN versions are now supported for jaxlib GPU
cuda11
wheels.cuDNN 8.2 or newer. We recommend using the cuDNN 8.2 wheel if your cuDNN installation is new enough, since it supports additional functionality.
cuDNN 8.0.5 or newer.
Breaking changes:
The install commands for GPU jaxlib are as follows:
pip install --upgrade pip

# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

# Installs the wheel compatible with Cuda 11 and cudnn 8.2 or newer.
pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html

# Installs the wheel compatible with Cuda 11 and cudnn 8.0.5 or newer.
pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
jax 0.2.22 (Oct 12, 2021)#
Breaking Changes
Static arguments to
jax.pmap
must now be hashable.Unhashable static arguments have long been disallowed on
jax.jit
, but they were still permitted onjax.pmap
;jax.pmap
compared unhashable static arguments using object identity.This behavior is a footgun, since comparing arguments using object identity leads to recompilation each time the object identity changes. Instead, we now ban unhashable arguments: if a user of
jax.pmap
wants to compare static arguments by object identity, they can define__hash__
and__eq__
methods on their objects that do that, or wrap their objects in an object that has those operations with object identity semantics. Another option is to usefunctools.partial
to encapsulate the unhashable static arguments into the function object.jax.util.partial
was an accidental export that has now been removed. Usefunctools.partial
from the Python standard library instead.
Deprecations
The functions
jax.ops.index_update
,jax.ops.index_add
etc. are deprecated and will be removed in a future JAX release. Please use the.at
property on JAX arrays instead, e.g.,x.at[idx].set(y)
. For now, these functions produce aDeprecationWarning
.
New features:
An optimized C++ code-path improving the dispatch time for
pmap
is now the default when using jaxlib 0.1.72 or newer. The feature can be disabled using the--experimental_cpp_pmap
flag (orJAX_CPP_PMAP
environment variable).jax.numpy.unique
now supports an optionalfill_value
argument (#8121)
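A tiny sketch of the new `fill_value` argument, which pairs with a static `size` so `unique` can be used under `jit`:

```python
import jax.numpy as jnp

x = jnp.array([3, 1, 3, 2])
# Unused slots beyond the number of unique values take fill_value.
jnp.unique(x, size=6, fill_value=0)   # -> [1, 2, 3, 0, 0, 0]
```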
jaxlib 0.1.72 (Oct 12, 2021)#
Breaking changes:
Support for CUDA 10.2 and CUDA 10.1 has been dropped. Jaxlib now supports CUDA 11.1+.
Bug fixes:
Fixes https://github.com/google/jax/issues/7461, which caused wrong outputs on all platforms due to incorrect buffer aliasing inside the XLA compiler.
jax 0.2.21 (Sept 23, 2021)#
Breaking Changes
jax.api
has been removed. Functions that were available asjax.api.*
were aliases for functions injax.*
; please use the functions injax.*
instead.jax.partial
, andjax.lax.partial
were accidental exports that have now been removed. Usefunctools.partial
from the Python standard library instead.Boolean scalar indices now raise a
TypeError
; previously this silently returned wrong results (#7925).Many more
jax.numpy
functions now require array-like inputs, and will error if passed a list (#7747 #7802 #7907). See #7737 for a discussion of the rationale behind this change.When inside a transformation such as
jax.jit
,jax.numpy.array
always stages the array it produces into the traced computation. Previouslyjax.numpy.array
would sometimes produce a on-device array, even under ajax.jit
decorator. This change may break code that used JAX arrays to perform shape or index computations that must be known statically; the workaround is to perform such computations using classic NumPy arrays instead.jnp.ndarray
is now a true base-class for JAX arrays. In particular, this means that for a standard numpy arrayx
,isinstance(x, jnp.ndarray)
will now returnFalse
(#7927).
New features:
Added
jax.numpy.insert()
implementation (#7936).
jax 0.2.20 (Sept 2, 2021)#
Breaking Changes
jaxlib 0.1.71 (Sep 1, 2021)#
Breaking changes:
Support for CUDA 11.0 and CUDA 10.1 has been dropped. Jaxlib now supports CUDA 10.2 and CUDA 11.1+.
jax 0.2.19 (Aug 12, 2021)#
Breaking changes:
Support for NumPy 1.17 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
The
jit
decorator has been added around the implementation of a number of operators on JAX arrays. This speeds up dispatch times for common operators such as+
.This change should largely be transparent to most users. However, there is one known behavioral change, which is that large integer constants may now produce an error when passed directly to a JAX operator (e.g.,
x + 2**40
). The workaround is to cast the constant to an explicit type (e.g.,np.float64(2**40)
).
New features:
Improved the support for shape polymorphism in jax2tf for operations that need to use a dimension size in array computation, e.g.,
jnp.mean
. (#7317)
Bug fixes:
Fixed some leaked trace errors from the previous release (#7613).
jaxlib 0.1.70 (Aug 9, 2021)#
Breaking changes:
Support for Python 3.6 has been dropped, per the deprecation policy. Please upgrade to a supported Python version.
Support for NumPy 1.17 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
The host_callback mechanism now uses one thread per local device for making the calls to the Python callbacks. Previously there was a single thread for all devices. This means that the callbacks may now be called interleaved. The callbacks corresponding to one device will still be called in sequence.
jax 0.2.18 (July 21 2021)#
Breaking changes:
Support for Python 3.6 has been dropped, per the deprecation policy. Please upgrade to a supported Python version.
The minimum jaxlib version is now 0.1.69.
The
backend
argument tojax.dlpack.from_dlpack()
has been removed.
New features:
Added a polar decomposition (
jax.scipy.linalg.polar()
).
Bug fixes:
Tightened the checks for lax.argmin and lax.argmax to ensure they are not used with an invalid
axis
value, or with an empty reduction dimension. (#7196)
jaxlib 0.1.69 (July 9 2021)#
Fixed bugs in the TFRT CPU backend that resulted in incorrect results.
jax 0.2.17 (July 9 2021)#
Bug fixes:
Default to the older “stream_executor” CPU runtime for jaxlib <= 0.1.68 to work around #7229, which caused wrong outputs on CPU due to a concurrency problem.
New features:
New SciPy function
jax.scipy.special.sph_harm()
.Reverse-mode autodiff functions (
jax.grad()
,jax.value_and_grad()
,jax.vjp()
, andjax.linear_transpose()
) support a parameter that indicates which named axes should be summed over in the backward pass if they were broadcasted over in the forward pass. This enables use of these APIs in a non-per-example way inside maps (initially onlyjax.experimental.maps.xmap()
) (#6950).
jax 0.2.16 (June 23 2021)#
jax 0.2.15 (June 23 2021)#
New features:
#7042 Turned on TFRT CPU backend with significant dispatch performance improvements on CPU.
The
jax2tf.convert()
supports inequalities and min/max for booleans (#6956).New SciPy function
jax.scipy.special.lpmn_values()
.
Breaking changes:
Support for NumPy 1.16 has been dropped, per the deprecation policy.
Bug fixes:
Fixed bug that prevented round-tripping from JAX to TF and back:
jax2tf.call_tf(jax2tf.convert)
(#6947).
jaxlib 0.1.68 (June 23 2021)#
Bug fixes:
Fixed a bug in the TFRT CPU backend that produced NaNs when transferring a TPU buffer to CPU.
jax 0.2.14 (June 10 2021)#
New features:
The
jax2tf.convert()
now has support forpjit
andsharded_jit
.A new configuration option JAX_TRACEBACK_FILTERING controls how JAX filters tracebacks.
A new traceback filtering mode using
__tracebackhide__
is now enabled by default in sufficiently recent versions of IPython.The
jax2tf.convert()
supports shape polymorphism even when the unknown dimensions are used in arithmetic operations, e.g.,jnp.reshape(-1)
(#6827).The
jax2tf.convert()
generates custom attributes with location information in TF ops. The code that XLA generates after jax2tf has the same location information as JAX/XLA.New SciPy function
jax.scipy.special.lpmn()
.
Bug fixes:
The
jax2tf.convert()
now ensures that it uses the same typing rules for Python scalars and for choosing 32-bit vs. 64-bit computations as JAX (#6883).The
jax2tf.convert()
now scopes theenable_xla
conversion parameter properly to apply only during the just-in-time conversion (#6720).The
jax2tf.convert()
now convertslax.dot_general
using theXlaDot
TensorFlow op, for better fidelity w.r.t. JAX numerical precision (#6717).The
jax2tf.convert()
now has support for inequality comparisons and min/max for complex numbers (#6892).
jaxlib 0.1.67 (May 17 2021)#
jaxlib 0.1.66 (May 11 2021)#
New features:
CUDA 11.1 wheels are now supported on all CUDA 11 versions 11.1 or higher.
NVidia now promises compatibility between CUDA minor releases starting with CUDA 11.1. This means that JAX can release a single CUDA 11.1 wheel that is compatible with CUDA 11.2 and 11.3.
There is no longer a separate jaxlib release for CUDA 11.2 (or higher); use the CUDA 11.1 wheel for those versions (cuda111).
Jaxlib now bundles
libdevice.10.bc
in CUDA wheels. There should be no need to point JAX to a CUDA installation to find this file.Added automatic support for static keyword arguments to the
jit()
implementation.Added support for pretransformation exception traces.
Initial support for pruning unused arguments from
jit()
-transformed computations. Pruning is still a work in progress.Improved the string representation of
PyTreeDef
objects.Added support for XLA’s variadic ReduceWindow.
Bug fixes:
Fixed a bug in the remote cloud TPU support when large numbers of arguments are passed to a computation.
Fix a bug that meant that JAX garbage collection was not triggered by
jit()
transformed functions.
jax 0.2.13 (May 3 2021)#
New features:
When combined with jaxlib 0.1.66,
jax.jit()
now supports static keyword arguments. A newstatic_argnames
option has been added to specify keyword arguments as static.jax.nonzero()
has a new optionalsize
argument that allows it to be used withinjit
(#6501)jax.numpy.unique()
now supports theaxis
argument (#6532).jax.experimental.host_callback.call()
now supportspjit.pjit
(#6569).Added
jax.scipy.linalg.eigh_tridiagonal()
that computes the eigenvalues of a tridiagonal matrix. Only eigenvalues are supported at present.The order of the filtered and unfiltered stack traces in exceptions has been changed. The traceback attached to an exception thrown from JAX-transformed code is now filtered, with an
UnfilteredStackTrace
exception containing the original trace as the__cause__
of the filtered exception. Filtered stack traces now also work with Python 3.6.If an exception is thrown by code that has been transformed by reverse-mode automatic differentiation, JAX now attempts to attach as a
__cause__
of the exception aJaxStackTraceBeforeTransformation
object that contains the stack trace that created the original operation in the forward pass. Requires jaxlib 0.1.66.
Breaking changes:
The following function names have changed. There are still aliases, so this should not break existing code, but the aliases will eventually be removed so please change your code.
host_id
–>process_index()
host_count
–>process_count()
host_ids
–>range(jax.process_count())
Similarly, the argument to
local_devices()
has been renamed fromhost_id
toprocess_index
.Arguments to
jax.jit()
other than the function are now marked as keyword-only. This change is to prevent accidental breakage when arguments are added tojit
.
Bug fixes:
jaxlib 0.1.65 (April 7 2021)#
jax 0.2.12 (April 1 2021)#
New features
New profiling APIs:
jax.profiler.start_trace()
,jax.profiler.stop_trace()
, andjax.profiler.trace()
jax.lax.reduce()
is now differentiable.
Breaking changes:
The minimum jaxlib version is now 0.1.64.
Some profiler APIs names have been changed. There are still aliases, so this should not break existing code, but the aliases will eventually be removed so please change your code.
TraceContext
–>TraceAnnotation()
StepTraceContext
–>StepTraceAnnotation()
trace_function
–>annotate_function()
Omnistaging can no longer be disabled. See omnistaging for more information.
Python integers larger than the maximum
int64
value will now lead to an overflow in all cases, rather than being silently converted touint64
in some cases (#6047).Outside X64 mode, Python integers outside the range representable by
int32
will now lead to anOverflowError
rather than having their value silently truncated.
Bug fixes:
host_callback
now supports empty arrays in arguments and results (#6262).jax.random.randint()
clips rather than wraps of out-of-bounds limits, and can now generate integers in the full range of the specified dtype (#5868)
jax 0.2.11 (March 23 2021)#
New features:
Bug fixes:
#6136 generalized
jax.flatten_util.ravel_pytree
to handle integer dtypes.#6129 fixed a bug with handling some constants like
enum.IntEnums
#6145 fixed batching issues with incomplete beta functions
#6014 fixed H2D transfers during tracing
#6165 avoids OverflowErrors when converting some large Python integers to floats
Breaking changes:
The minimum jaxlib version is now 0.1.62.
jaxlib 0.1.64 (March 18 2021)#
jaxlib 0.1.63 (March 17 2021)#
jax 0.2.10 (March 5 2021)#
New features:
jax.scipy.stats.chi2()
is now available as a distribution with logpdf and pdf methods.jax.scipy.stats.betabinom()
is now available as a distribution with logpmf and pmf methods.Added
jax.experimental.jax2tf.call_tf()
to call TensorFlow functions from JAX (#5627) and README).Extended the batching rule for
lax.pad
to support batching of the padding values.
Bug fixes:
jax.numpy.take()
properly handles negative indices (#5768)
Breaking changes:
JAX’s promotion rules were adjusted to make promotion more consistent and invariant to JIT. In particular, binary operations can now result in weakly-typed values when appropriate. The main user-visible effect of the change is that some operations result in outputs of different precision than before; for example the expression
jnp.bfloat16(1) + 0.1 * jnp.arange(10)
previously returned afloat64
array, and now returns abfloat16
array. JAX’s type promotion behavior is described at Type promotion semantics.jax.numpy.linspace()
now computes the floor of integer values, i.e., rounding towards -inf rather than 0. This change was made to match NumPy 1.20.0.jax.numpy.i0()
no longer accepts complex numbers. Previously the function computed the absolute value of complex arguments. This change was made to match the semantics of NumPy 1.20.0.Several
jax.numpy
functions no longer accept tuples or lists in place of array arguments:jax.numpy.pad()
, :funcjax.numpy.ravel
,jax.numpy.repeat()
,jax.numpy.reshape()
. In general,jax.numpy
functions should be used with scalars or array arguments.
jaxlib 0.1.62 (March 9 2021)#
New features:
jaxlib wheels are now built to require AVX instructions on x86-64 machines by default. If you want to use JAX on a machine that doesn’t support AVX, you can build a jaxlib from source using the
--target_cpu_features
flag tobuild.py
.--target_cpu_features
also replaces--enable_march_native
.
jaxlib 0.1.61 (February 12 2021)#
jaxlib 0.1.60 (February 3 2021)#
Bug fixes:
Fixed a memory leak when converting CPU DeviceArrays to NumPy arrays. The memory leak was present in jaxlib releases 0.1.58 and 0.1.59.
bool
,int8
, anduint8
are now considered safe to cast tobfloat16
NumPy extension type.
jax 0.2.9 (January 26 2021)#
New features:
Extend the
jax.experimental.loops
module with support for pytrees. Improved error checking and error messages.Add
jax.experimental.enable_x64()
andjax.experimental.disable_x64()
. These are context managers which allow X64 mode to be temporarily enabled/disabled within a session.
Breaking changes:
jax.ops.segment_sum()
now drops segment IDs that are out of range rather than wrapping them into the segment ID space. This was done for performance reasons.
jaxlib 0.1.59 (January 15 2021)#
jax 0.2.8 (January 12 2021)#
New features:
Add
jax.closure_convert()
for use with higher-order custom derivative functions. (#5244)Add
jax.experimental.host_callback.call()
to call a custom Python function on the host and return a result to the device computation. (#5243)
Bug fixes:
jax.numpy.arccosh
now returns the same branch asnumpy.arccosh
for complex inputs (#5156)host_callback.id_tap
now works forjax.pmap
also. There is an optional parameter forid_tap
andid_print
to request that the device from which the value is tapped be passed as a keyword argument to the tap function (#5182).
Breaking changes:
jax.numpy.pad
now takes keyword arguments. Positional argumentconstant_values
has been removed. In addition, passing unsupported keyword arguments raises an error.Changes for
jax.experimental.host_callback.id_tap()
(#5243):Removed support for
kwargs
forjax.experimental.host_callback.id_tap()
. (This support has been deprecated for a few months.)Changed the printing of tuples for
jax.experimental.host_callback.id_print()
to use ‘(’ instead of ‘[‘.Changed the
jax.experimental.host_callback.id_print()
in presence of JVP to print a pair of primal and tangent. Previously, there were two separate print operations for the primals and the tangent.host_callback.outfeed_receiver
has been removed (it is not necessary, and was deprecated a few months ago).
New features:
New flag for debugging
inf
, analogous to that forNaN
(#5224).
jax 0.2.7 (Dec 4 2020)#
New features:
Add
jax.device_put_replicated
Add multi-host support to
jax.experimental.sharded_jit
Add support for differentiating eigenvalues computed by
jax.numpy.linalg.eig
Add support for building on Windows platforms
Add support for general in_axes and out_axes in
jax.pmap
Add complex support for
jax.numpy.linalg.slogdet
Bug fixes:
Fix higher-than-second order derivatives of
jax.numpy.sinc
at zeroFix some hard-to-hit bugs around symbolic zeros in transpose rules
Breaking changes:
jax.experimental.optix
has been deleted, in favor of the standaloneoptax
Python package.indexing of JAX arrays with non-tuple sequences now raises a
TypeError
. This type of indexing has been deprecated in Numpy since v1.16, and in JAX since v0.2.4. See #4564.
jax 0.2.6 (Nov 18 2020)#
New Features:
Add support for shape-polymorphic tracing for the jax.experimental.jax2tf converter. See README.md.
Breaking change cleanup
Raise an error on non-hashable static arguments for jax.jit and xla_computation. See cb48f42.
Improve consistency of type promotion behavior (#4744):
Adding a complex Python scalar to a JAX floating point number respects the precision of the JAX float. For example, `jnp.float32(1) + 1j` now returns `complex64`, where previously it returned `complex128`.
Results of type promotion with 3 or more terms involving uint64, a signed int, and a third type are now independent of the order of arguments. For example: `jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)` and `jnp.result_type(jnp.float16, jnp.uint64, jnp.int64)` both return `float16`, where previously the first returned `float64` and the second returned `float16`.
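A short sketch of both promotion changes; the printed dtypes assume JAX's standard promotion rules:

```python
import jax.numpy as jnp

# A complex Python scalar no longer upcasts a float32 value to complex128.
x = jnp.array(1.0, dtype=jnp.float32) + 1j
print(x.dtype)  # complex64

# Promotion with three or more terms is now order-independent.
print(jnp.result_type(jnp.uint64, jnp.int64, jnp.float16))  # float16
print(jnp.result_type(jnp.float16, jnp.uint64, jnp.int64))  # float16
```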
The contents of the (undocumented) `jax.lax_linalg` linear algebra module are now exposed publicly as `jax.lax.linalg`.
`jax.random.PRNGKey` now produces the same results in and out of JIT compilation (#4877). This required changing the result for a given seed in a few particular cases:
With `jax_enable_x64=False`, negative seeds passed as Python integers now return a different result outside JIT mode. For example, `jax.random.PRNGKey(-1)` previously returned `[4294967295, 4294967295]`, and now returns `[0, 4294967295]`. This matches the behavior in JIT.
Seeds outside the range representable by `int64` outside JIT now result in an `OverflowError` rather than a `TypeError`. This matches the behavior in JIT.
To recover the keys returned previously for negative integers with `jax_enable_x64=False` outside JIT, you can use: `key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)`
DeviceArray now raises `RuntimeError` instead of `ValueError` when trying to access its value after it has been deleted.
jaxlib 0.1.58 (January 12ish 2021)#
Fixed a bug that meant JAX sometimes returned platform-specific types (e.g., `np.cint`) instead of standard types (e.g., `np.int32`). (#4903)
Fixed a crash when constant-folding certain int16 operations. (#4971)
Added an `is_leaf` predicate to `pytree.flatten()`.
jaxlib 0.1.57 (November 12 2020)#
Fixed manylinux2010 compliance issues in GPU wheels.
Switched the CPU FFT implementation from Eigen to PocketFFT.
Fixed a bug where the hash of bfloat16 values was not correctly initialized and could change (#4651).
Add support for retaining ownership when passing arrays to DLPack (#4636).
Fixed a bug for batched triangular solves with sizes greater than 128 but not a multiple of 128.
Fixed a bug when performing concurrent FFTs on multiple GPUs (#3518).
Fixed a bug in profiler where tools are missing (#4427).
Dropped support for CUDA 10.0.
jax 0.2.5 (October 27 2020)#
Improvements:
Ensure that `check_jaxpr` does not perform FLOPS. See #4650.
Expanded the set of JAX primitives converted by jax2tf. See primitives_with_limited_support.md.
jax 0.2.4 (October 19 2020)#
jaxlib 0.1.56 (October 14, 2020)#
jax 0.2.3 (October 14 2020)#
The reason for another release so soon is that we need to temporarily roll back a new jit fastpath while we look into a performance degradation.
jax 0.2.2 (October 13 2020)#
jax 0.2.1 (October 6 2020)#
Improvements:
As a benefit of omnistaging, the host_callback functions are executed (in program order) even if the result of the `jax.experimental.host_callback.id_print()` / `jax.experimental.host_callback.id_tap()` is not used in the computation.
jax (0.2.0) (September 23 2020)#
Improvements:
Omnistaging on by default. See #3370 and omnistaging
jax (0.1.77) (September 15 2020)#
Breaking changes:
New simplified interface for `jax.experimental.host_callback.id_tap()` (#4101).
jaxlib 0.1.55 (September 8, 2020)#
Update XLA:
Fix bug in DLPackManagedTensorToBuffer (#4196)
jax 0.1.76 (September 8, 2020)#
jax 0.1.75 (July 30, 2020)#
Bug Fixes:
make jnp.abs() work for unsigned inputs (#3914)
Improvements:
“Omnistaging” behavior added behind a flag, disabled by default (#3370)
jax 0.1.74 (July 29, 2020)#
New Features:
BFGS (#3101)
TPU support for half-precision arithmetic (#3878)
Bug Fixes:
Prevent some accidental dtype warnings (#3874)
Fix a multi-threading bug in custom derivatives (#3845, #3869)
Improvements:
Faster searchsorted implementation (#3873)
Better test coverage for jax.numpy sorting algorithms (#3836)
jaxlib 0.1.52 (July 22, 2020)#
Update XLA.
jax 0.1.73 (July 22, 2020)#
The minimum jaxlib version is now 0.1.51.
New Features:
jax.image.resize. (#3703)
hfft and ihfft (#3664)
jax.numpy.intersect1d (#3726)
jax.numpy.lexsort (#3812)
`lax.scan` and the `scan` primitive support an `unroll` parameter for loop unrolling when lowering to XLA (#3738).
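A minimal sketch of the `unroll` parameter on `lax.scan`; the cumulative-sum function here is illustrative only:

```python
import jax
import jax.numpy as jnp
from jax import lax

def cumsum(xs):
    # unroll=4 asks the lowering to unroll the scan body four steps at a time.
    def step(carry, x):
        total = carry + x
        return total, total
    init = jnp.zeros((), dtype=xs.dtype)
    _, ys = lax.scan(step, init, xs, unroll=4)
    return ys

print(jax.jit(cumsum)(jnp.arange(8.0)))  # [ 0.  1.  3.  6. 10. 15. 21. 28.]
```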
Bug Fixes:
Fix reduction repeated axis error (#3618)
Fix shape rule for lax.pad for input dimensions of size 0. (#3608)
make psum transpose handle zero cotangents (#3653)
Fix shape error when taking JVP of reduce-prod over size 0 axis. (#3729)
Support differentiation through jax.lax.all_to_all (#3733)
address nan issue in jax.scipy.special.zeta (#3777)
Improvements:
Many improvements to jax2tf
Reimplement argmin/argmax using a single pass variadic reduction. (#3611)
Enable XLA SPMD partitioning by default. (#3151)
Add support for 0d transpose convolution (#3643)
Make LU gradient work for low-rank matrices (#3610)
support multiple_results and custom JVPs in jet (#3657)
Generalize reduce-window padding to support (lo, hi) pairs. (#3728)
Implement complex convolutions on CPU and GPU. (#3735)
Make jnp.take work for empty slices of empty arrays. (#3751)
Relax dimension ordering rules for dot_general. (#3778)
Enable buffer donation for GPU. (#3800)
Add support for base dilation and window dilation to reduce window op… (#3803)
jaxlib 0.1.51 (July 2, 2020)#
Update XLA.
Add new runtime support for host_callback.
jax 0.1.72 (June 28, 2020)#
Bug fixes:
Fix an odeint bug introduced in the previous release, see #3587.
jax 0.1.71 (June 25, 2020)#
The minimum jaxlib version is now 0.1.48.
Bug fixes:
Allow `jax.experimental.ode.odeint` dynamics functions to close over values with respect to which we're differentiating #3562.
jaxlib 0.1.50 (June 25, 2020)#
Add support for CUDA 11.0.
Drop support for CUDA 9.2 (we only maintain support for the last four CUDA versions.)
Update XLA.
jaxlib 0.1.49 (June 19, 2020)#
Bug fixes:
Fix build issue that could result in slow compiles (tensorflow/tensorflow)
jaxlib 0.1.48 (June 12, 2020)#
New features:
Adds support for fast traceback collection.
Adds preliminary support for on-device heap profiling.
Implements `np.nextafter` for `bfloat16` types.
Complex128 support for FFTs on CPU and GPU.
Bug fixes:
Improved float64 `tanh` accuracy on GPU.
float64 scatters on GPU are much faster.
Complex matrix multiplication on CPU should be much faster.
Stable sorts on CPU should actually be stable now.
Concurrency bug fix in CPU backend.
jax 0.1.70 (June 8, 2020)#
New features:
`lax.switch` introduces indexed conditionals with multiple branches, together with a generalization of the `cond` primitive #3318.
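A minimal sketch of `lax.switch` with three branches; the branch functions are illustrative only:

```python
import jax.numpy as jnp
from jax import lax

branches = [lambda x: x + 1.0,   # index 0
            lambda x: x * 2.0,   # index 1
            lambda x: -x]        # index 2

# Selects one of several branches by integer index; out-of-range
# indices are clamped into range.
print(lax.switch(1, branches, jnp.array(3.0)))  # 6.0
```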
jax 0.1.69 (June 3, 2020)#
jax 0.1.68 (May 21, 2020)#
New features:
`lax.cond()` supports a single-operand form, taken as the argument to both branches #2993.
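A minimal sketch of the single-operand form, written against the current `lax.cond(pred, true_fun, false_fun, operand)` signature; the `safe_reciprocal` helper is hypothetical:

```python
from jax import lax

def safe_reciprocal(x):
    # A single operand is passed to whichever branch is selected.
    return lax.cond(x != 0.0,
                    lambda v: 1.0 / v,   # true branch
                    lambda v: 0.0 * v,   # false branch
                    x)                   # shared operand

print(safe_reciprocal(4.0))  # 0.25
print(safe_reciprocal(0.0))  # 0.0
```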
Notable changes:
The format of the `transforms` keyword for the `jax.experimental.host_callback.id_tap()` primitive has changed #3132.
jax 0.1.67 (May 12, 2020)#
New features:
Support for reduction over subsets of a pmapped axis using `axis_index_groups` #2382 (see the example after this list).
Experimental support for printing and calling a host-side Python function from compiled code. See id_print and id_tap (#3006).
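A minimal sketch of `axis_index_groups`, assuming at least four local devices are available (for CPU-only experimentation, something like `XLA_FLAGS=--xla_force_host_platform_device_count=4` can provide them):

```python
import jax
import jax.numpy as jnp

# Sum within two groups of devices (0,1) and (2,3) instead of across
# the whole pmapped axis. Assumes at least 4 local devices.
grouped_sum = jax.pmap(
    lambda x: jax.lax.psum(x, axis_name="i", axis_index_groups=[[0, 1], [2, 3]]),
    axis_name="i",
)

x = jnp.arange(4.0)
print(grouped_sum(x))  # [1. 1. 5. 5.]
```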
Notable changes:
The visibility of names exported from `jax.numpy` has been tightened. This may break code that was making use of names that were previously exported accidentally.
jaxlib 0.1.47 (May 8, 2020)#
Fixes crash for outfeed.
jax 0.1.66 (May 5, 2020)#
New features:
Support for `in_axes=None` on `pmap()` #2896.
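A minimal sketch of `in_axes=None` on `pmap()`, which broadcasts an argument to every device instead of splitting it along an axis:

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()

# `w` is broadcast to every device; `x` is split along its leading axis.
f = jax.pmap(lambda x, w: w * x, in_axes=(0, None))

x = jnp.arange(float(n))
print(f(x, 10.0))  # [0., 10., 20., ...] depending on device count
```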
jaxlib 0.1.46 (May 5, 2020)#
Fixes crash for linear algebra functions on Mac OS X (#432).
Fixes an illegal instruction crash caused by using AVX512 instructions when an operating system or hypervisor disabled them (#2906).
jax 0.1.65 (April 30, 2020)#
New features:
Differentiation of determinants of singular matrices #2809.
Bug fixes:
jaxlib 0.1.45 (April 21, 2020)#
Fixes segfault: #2755
Plumb is_stable option on Sort HLO through to Python.
jax 0.1.64 (April 21, 2020)#
New features:
Add syntactic sugar for functional indexed updates #2684.
Add `jax.numpy.unique()` #2760.
Add `jax.numpy.rint()` #2724.
Add more primitive rules for `jax.experimental.jet()`.
Bug fixes:
Better errors:
Improves error message for reverse-mode differentiation of `lax.while_loop()` #2129.
jaxlib 0.1.44 (April 16, 2020)#
Fixes a bug where if multiple GPUs of different models were present, JAX would only compile programs suitable for the first GPU.
Bugfix for `batch_group_count` convolutions.
Added precompiled SASS for more GPU versions to avoid startup PTX compilation hang.
jax 0.1.63 (April 12, 2020)#
Added `jax.custom_jvp` and `jax.custom_vjp` from #2026, see the tutorial notebook. Deprecated `jax.custom_transforms` and removed it from the docs (though it still works).
Add `scipy.sparse.linalg.cg` #2566.
Changed how Tracers are printed to show more useful information for debugging #2591.
Made `jax.numpy.isclose` handle `nan` and `inf` correctly #2501.
Added several new rules for `jax.experimental.jet` #2537.
Fixed `jax.experimental.stax.BatchNorm` when `scale`/`center` isn't provided.
Fix some missing cases of broadcasting in `jax.numpy.einsum` #2512.
Implement `jax.numpy.cumsum` and `jax.numpy.cumprod` in terms of a parallel prefix scan #2596 and make `reduce_prod` differentiable to arbitrary order #2597.
Add `batch_group_count` to `conv_general_dilated` #2635.
Add docstring for `test_util.check_grads` #2656.
Add `callback_transform` #2665.
Implement `rollaxis`, `convolve`/`correlate` 1d & 2d, `copysign`, `trunc`, `roots`, and `quantile`/`percentile` interpolation options.
jaxlib 0.1.43 (March 31, 2020)#
Fixed a performance regression for Resnet-50 on GPU.
jax 0.1.62 (March 21, 2020)#
JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
Removed the internal function `lax._safe_mul`, which implemented the convention `0. * nan == 0.`. This change means some programs when differentiated will produce nans when they previously produced correct values, though it ensures nans rather than silently incorrect results are produced for other programs. See #2447 and #1052 for details.
Added an `all_gather` parallel convenience function.
More type annotations in core code.
jaxlib 0.1.42 (March 19, 2020)#
jaxlib 0.1.41 broke cloud TPU support due to an API incompatibility. This release fixes it again.
JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
jax 0.1.61 (March 17, 2020)#
Fixes Python 3.5 support. This will be the last JAX or jaxlib release that supports Python 3.5.
jax 0.1.60 (March 17, 2020)#
New features:
`jax.pmap()` has a `static_broadcasted_argnums` argument which allows the user to specify arguments that should be treated as compile-time constants and should be broadcasted to all devices. It works analogously to `static_argnums` in `jax.jit()` (see the example after this list).
Improved error messages for when tracers are mistakenly saved in global state.
Added `jax.nn.one_hot()` utility function.
Added `jax.experimental.jet` for exponentially faster higher-order automatic differentiation.
Added more correctness checking to arguments of `jax.lax.broadcast_in_dim()`.
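A minimal sketch of `static_broadcasted_argnums` and `jax.nn.one_hot()` from this release; the function bodies are illustrative only:

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()

# `power` is a compile-time constant broadcast to all devices, analogous
# to static_argnums in jax.jit.
f = jax.pmap(lambda x, power: x ** power, static_broadcasted_argnums=(1,))
print(f(jnp.arange(float(n)), 2))

# jax.nn.one_hot, also added in this release.
print(jax.nn.one_hot(jnp.array([0, 2]), num_classes=3))
```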
The minimum jaxlib version is now 0.1.41.
jaxlib 0.1.40 (March 4, 2020)#
Adds experimental support in Jaxlib for TensorFlow profiler, which allows tracing of CPU and GPU computations from TensorBoard.
Includes prototype support for multihost GPU computations that communicate via NCCL.
Improves performance of NCCL collectives on GPU.
Adds TopK, CustomCallWithoutLayout, CustomCallWithLayout, IGammaGradA and RandomGamma implementations.
Supports device assignments known at XLA compilation time.
jax 0.1.59 (February 11, 2020)#
Breaking changes
The minimum jaxlib version is now 0.1.38.
Simplified `Jaxpr` by removing the `Jaxpr.freevars` and `Jaxpr.bound_subjaxprs` fields. The call primitives (`xla_call`, `xla_pmap`, `sharded_call`, and `remat_call`) get a new parameter `call_jaxpr` with a fully-closed (no `constvars`) jaxpr. Also, added a new field `call_primitive` to primitives.
New features:
Reverse-mode automatic differentiation (e.g. `grad`) of `lax.cond`, making it now differentiable in both modes (#2091).
JAX now supports DLPack, which allows sharing CPU and GPU arrays in a zero-copy way with other libraries, such as PyTorch (see the sketch after this list).
JAX GPU DeviceArrays now support `__cuda_array_interface__`, which is another zero-copy protocol for sharing GPU arrays with other libraries such as CuPy and Numba.
JAX CPU device buffers now implement the Python buffer protocol, which allows zero-copy buffer sharing between JAX and NumPy.
Added the `JAX_SKIP_SLOW_TESTS` environment variable to skip tests known to be slow.
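A rough sketch of the zero-copy interop described above, using the `jax.dlpack` module and the CPU buffer protocol; the exact DLPack entry points have evolved across releases, so treat the names here as indicative rather than definitive:

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.arange(4.0)

# Zero-copy round trip through a DLPack capsule.
capsule = jax.dlpack.to_dlpack(x)
y = jax.dlpack.from_dlpack(capsule)

# On CPU, the buffer protocol lets NumPy view the same memory.
z = np.asarray(x)
print(y, z)
```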
jaxlib 0.1.39 (February 11, 2020)#
Updates XLA.
jaxlib 0.1.38 (January 29, 2020)#
CUDA 9.0 is no longer supported.
CUDA 10.2 wheels are now built by default.
jax 0.1.58 (January 28, 2020)#
Breaking changes
JAX has dropped Python 2 support, because Python 2 reached its end of life on January 1, 2020. Please update to Python 3.5 or newer.
New features
Forward-mode automatic differentiation (`jvp`) of while loop (#1980).
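A minimal sketch of forward-mode differentiation through a while loop; the `scale_until` function is illustrative only:

```python
import jax
from jax import lax

def scale_until(x):
    # Keep doubling x until it reaches at least 100.
    return lax.while_loop(lambda v: v < 100.0, lambda v: 2.0 * v, x)

# Forward-mode derivative through the data-dependent loop.
y, y_dot = jax.jvp(scale_until, (1.5,), (1.0,))
print(y, y_dot)  # 192.0, 128.0
```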
New NumPy and SciPy functions:
Batched Cholesky decomposition on GPU now uses a more efficient batched kernel.
Notable bug fixes#
With the Python 3 upgrade, JAX no longer depends on `fastcache`, which should help with installation.
JAX Glossary of Terms#
- Array#
JAX's analog of `numpy.ndarray`. See `jax.Array`.
- CPU#
Short for Central Processing Unit, CPUs are the standard computational architecture available in most computers. JAX can run computations on CPUs, but often can achieve much better performance on GPU and TPU.
- Device#
The generic name used to refer to the CPU, GPU, or TPU used by JAX to perform computations.
- forward-mode autodiff#
See JVP.
- functional programming#
A programming paradigm in which programs are defined by applying and composing pure functions. JAX is designed for use with functional programs.
- GPU#
Short for Graphical Processing Unit, GPUs were originally specialized for operations related to rendering of images on screen, but now are much more general-purpose. JAX is able to target GPUs for fast operations on arrays (see also CPU and TPU).
- jaxpr#
Short for JAX Expression, a jaxpr is an intermediate representation of a computation that is generated by JAX, and is forwarded to XLA for compilation and execution. See Understanding Jaxprs for more discussion and examples.
- JIT#
Short for Just In Time compilation, JIT in JAX generally refers to the compilation of array operations to XLA, most often accomplished using `jax.jit()`.
- JVP#
Short for Jacobian Vector Product, also sometimes known as forward-mode automatic differentiation. For more details, see Jacobian-Vector products (JVPs, aka forward-mode autodiff). In JAX, JVP is a transformation that is implemented via `jax.jvp()`. See also VJP.
- primitive#
A primitive is a fundamental unit of computation used in JAX programs. Most functions in `jax.lax` represent individual primitives. When representing a computation in a jaxpr, each operation in the jaxpr is a primitive.
- pure function#
A pure function is a function whose outputs are based only on its inputs, and which has no side-effects. JAX’s transformation model is designed to work with pure functions. See also functional programming.
- pytree#
A pytree is an abstraction that lets JAX handle tuples, lists, dicts, and other more general containers of array values in a uniform way. Refer to Working with pytrees for a more detailed discussion.
- reverse-mode autodiff#
See VJP.
- SPMD#
Short for Single Program Multi Data, it refers to a parallel computation technique in which the same computation (e.g., the forward pass of a neural net) is run on different input data (e.g., different inputs in a batch) in parallel on different devices (e.g., several TPUs).
`jax.pmap()` is a JAX transformation that implements SPMD parallelism.
- static#
In a JIT compilation, a value that is not traced (see Tracer). Also sometimes refers to compile-time computations on static values.
- TPU#
Short for Tensor Processing Unit, TPUs are chips specifically engineered for fast operations on N-dimensional tensors used in deep learning applications. JAX is able to target TPUs for fast operations on arrays (see also CPU and GPU).
- Tracer#
An object used as a stand-in for a JAX Array in order to determine the sequence of operations performed by a Python function. Internally, JAX implements this via the `jax.core.Tracer` class.
- transformation#
A higher-order function: that is, a function that takes a function as input and outputs a transformed function. Examples in JAX include `jax.jit()`, `jax.vmap()`, and `jax.grad()`.
- VJP#
Short for Vector Jacobian Product, also sometimes known as reverse-mode automatic differentiation. For more details, see Vector-Jacobian products (VJPs, aka reverse-mode autodiff). In JAX, VJP is a transformation that is implemented via `jax.vjp()`. See also JVP.
- XLA#
Short for Accelerated Linear Algebra, XLA is a domain-specific compiler for linear algebra operations that is the primary backend for JIT-compiled JAX code. See https://www.tensorflow.org/xla/.
- weak type#
A JAX data type that has the same type promotion semantics as Python scalars; see Weakly-typed values in JAX.
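As a quick illustration of the JVP and VJP entries above (a minimal sketch; the cubic function is illustrative only):

```python
import jax

f = lambda x: x ** 3

# JVP: push a tangent vector forward through f at x = 2.0.
y, y_dot = jax.jvp(f, (2.0,), (1.0,))      # y = 8.0, y_dot = 12.0

# VJP: pull a cotangent vector back through f at the same point.
y, f_vjp = jax.vjp(f, 2.0)
(x_bar,) = f_vjp(1.0)                      # x_bar = 12.0
print(y_dot, x_bar)
```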