JAX: High-Performance Array Computing#

JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.

If you’re looking to train neural networks, use Flax and start with its documentation. Some associated tools are Optax and Orbax. For an end-to-end transformer library built on JAX, see MaxText.

Familiar API

JAX provides a familiar NumPy-style API for ease of adoption by researchers and engineers.

Transformations

JAX includes composable function transformations for compilation, batching, automatic differentiation, and parallelization.

Run Anywhere

The same code executes on multiple backends, including CPU, GPU, and TPU.
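
As a minimal sketch of how these pieces fit together (assuming JAX is already installed), the transformations can be freely composed; the toy loss function below is purely illustrative:

import jax
import jax.numpy as jnp

def loss(w, x):
  # A toy scalar-valued function written with the NumPy-style API.
  return jnp.sum((x @ w) ** 2)

# Compose transformations: differentiate, vectorize over a batch of inputs,
# and JIT-compile the result for whichever backend is available.
grad_fn = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))

w = jnp.ones(3)
xs = jnp.arange(6.0).reshape(2, 3)  # a batch of two inputs
print(grad_fn(w, xs).shape)         # (2, 3): one gradient per batch element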


Installing JAX#

Using JAX requires installing two packages: jax, which is pure Python and cross-platform, and jaxlib, which contains compiled binaries and requires different builds for different operating systems and accelerators.

TL;DR For most users, a typical JAX installation may look something like this:

  • CPU-only (Linux/macOS/Windows)

    pip install -U "jax[cpu]"
    
  • GPU (NVIDIA, CUDA 12, x86_64)

    pip install -U "jax[cuda12]"
    
  • GPU (NVIDIA, CUDA 12, x86_64) legacy

    You should prefer jax[cuda12], which uses the common CPU jaxlib and adds GPU support as a plugin. The monolithic jax[cuda12_pip] option will be removed in a future JAX release.

    pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
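
After installing, one quick way to confirm which backend JAX picked up is to list its devices from Python (a minimal check, separate from the install commands above):

import jax

print(jax.default_backend())  # e.g. "cpu", "gpu", or "tpu"
print(jax.devices())          # the devices JAX can use on this machine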

Supported platforms#

The table below shows all supported platforms and installation options. Check whether your setup is supported; if it says “yes” or “experimental”, click the corresponding link to learn how to install JAX in greater detail.

|                  | Linux, x86_64 | Linux, aarch64 | macOS, Intel x86_64, AMD GPU | macOS, Apple Silicon, ARM-based | Windows, x86_64 | Windows WSL2, x86_64 |
|------------------|---------------|----------------|------------------------------|---------------------------------|-----------------|----------------------|
| CPU              | yes           | yes            | yes                          | yes                             | yes             | yes                  |
| NVIDIA GPU       | yes           | yes            | no                           | n/a                             | no              | experimental         |
| Google Cloud TPU | yes           | n/a            | n/a                          | n/a                             | n/a             | n/a                  |
| AMD GPU          | experimental  | no             | no                           | n/a                             | no              | no                   |
| Apple GPU        | n/a           | no             | experimental                 | experimental                    | n/a             | n/a                  |

CPU#

pip installation: CPU#

Currently, the JAX team releases jaxlib wheels for the following operating systems and architectures:

  • Linux, x86_64

  • macOS, Intel

  • macOS, Apple ARM-based

  • Windows, x86_64 (experimental)

To install a CPU-only version of JAX, which might be useful for doing local development on a laptop, you can run:

pip install --upgrade pip
pip install --upgrade "jax[cpu]"

On Windows, you may also need to install the Microsoft Visual Studio 2019 Redistributable if it is not already installed on your machine.

Other operating systems and architectures require building from source. On unsupported platforms, pip install may install jax successfully without installing jaxlib, in which case jax will fail at runtime.

NVIDIA GPU#

JAX supports NVIDIA GPUs that have SM version 5.2 (Maxwell) or newer. Note that Kepler-series GPUs are no longer supported by JAX since NVIDIA has dropped support for Kepler GPUs in its software.

You must first install the NVIDIA driver. We recommend installing the newest driver available from NVIDIA, but the driver version must be >= 525.60.13 for CUDA 12 on Linux.

If you need to use a newer CUDA toolkit with an older driver, for example on a cluster where you cannot update the NVIDIA driver easily, you may be able to use the CUDA forward compatibility packages that NVIDIA provides for this purpose.

pip installation: NVIDIA GPU (CUDA, installed via pip, easier)#

There are two ways to install JAX with NVIDIA GPU support:

  • Using NVIDIA CUDA and cuDNN installed from pip wheels

  • Using a self-installed CUDA/cuDNN

The JAX team strongly recommends installing CUDA and cuDNN using the pip wheels, since it is much easier!

This method is only supported on x86_64, because NVIDIA has not released aarch64 CUDA pip packages.

pip install --upgrade pip

# NVIDIA CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12]"

# Legacy way of NVIDIA CUDA 12 installation. You should prefer `jax[cuda12]`,
# which uses the common CPU jaxlib and adds GPU support as a plugin. The
# monolithic `jax[cuda12_pip]` option will be removed in a future JAX release.
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

If JAX detects the wrong version of the NVIDIA CUDA libraries, there are several things you need to check:

  • Make sure that LD_LIBRARY_PATH is not set, since LD_LIBRARY_PATH can override the NVIDIA CUDA libraries.

  • Make sure that the NVIDIA CUDA libraries installed are those requested by JAX. Rerunning the installation command above should work.

pip installation: NVIDIA GPU (CUDA, installed locally, harder)#

If you prefer to use a preinstalled copy of NVIDIA CUDA, you must first install NVIDIA CUDA and cuDNN.

JAX provides pre-built CUDA-compatible wheels for Linux x86_64 only. Other combinations of operating system and architecture are possible, but require building from source (refer to Building from source to learn more).

You should use an NVIDIA driver version that is at least as new as your NVIDIA CUDA toolkit’s corresponding driver version. If you need to use a newer CUDA toolkit with an older driver, for example on a cluster where you cannot update the NVIDIA driver easily, you may be able to use the CUDA forward compatibility packages that NVIDIA provides for this purpose.

JAX currently ships one CUDA wheel variant:

| Built with | Compatible with   |
|------------|-------------------|
| CUDA 12.3  | CUDA >=12.1       |
| CUDNN 8.9  | CUDNN >=8.9, <9.0 |
| NCCL 2.19  | NCCL >=2.18       |

JAX checks the versions of your libraries, and will report an error if they are not sufficiently new. Setting the JAX_SKIP_CUDA_CONSTRAINTS_CHECK environment variable will disable the check, but using older versions of CUDA may lead to errors, or incorrect results.

NCCL is an optional dependency, required only if you are performing multi-GPU computations.

To install, run:

pip install --upgrade pip

# Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 8.9 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

These pip installations do not work with Windows, and may fail silently; refer to the table above.

You can find your CUDA version with the command:

nvcc --version

JAX uses LD_LIBRARY_PATH to find CUDA libraries and PATH to find binaries (ptxas, nvlink). Please make sure that these paths point to the correct CUDA installation.

Please let the JAX team know on the GitHub issue tracker if you run into any errors or problems with the pre-built wheels.

NVIDIA GPU Docker containers#

NVIDIA provides the JAX Toolbox containers, which are bleeding-edge containers that include nightly releases of jax along with some models and frameworks.

JAX nightly installation#

Nightly releases reflect the state of the main JAX repository at the time they are built, and may not pass the full test suite.

  • jax:

    pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html

  • jaxlib CPU:

    pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

  • jaxlib Google Cloud TPU:

    pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
    pip install -U libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

  • jaxlib NVIDIA GPU (CUDA 12):

    pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
    pip install -U --pre jax-cuda12-pjrt jax-cuda12-plugin -f https://storage.googleapis.com/jax-releases/jax_cuda_plugin_nightly_releases.html

  • jaxlib NVIDIA GPU (CUDA 12) legacy:

    pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html

Google Cloud TPU#

pip installation: Google Cloud TPU#

JAX provides pre-built wheels for Google Cloud TPU. To install JAX along with appropriate versions of jaxlib and libtpu, you can run the following in your cloud TPU VM:

pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

For users of Colab (https://colab.research.google.com/), be sure you are using TPU v2 and not the older, deprecated TPU runtime.

Apple Silicon GPU (ARM-based)#

pip installation: Apple ARM-based Silicon GPUs#

Apple provides an experimental Metal plugin for Apple ARM-based GPU hardware. For details, refer to Apple’s JAX on Metal documentation.

Note: There are several caveats with the Metal plugin:

  • The Metal plugin is new and experimental and has a number of known issues. Please report any issues on the JAX issue tracker.

  • The Metal plugin currently requires very specific versions of jax and jaxlib. This restriction will be relaxed over time as the plugin API matures.

AMD GPU#

JAX has experimental ROCm support. There are two ways to install JAX with ROCm support: use AMD’s Docker container, or build from source (refer to Building from source).

Conda (community-supported)#

Conda installation#

There is a community-supported Conda build of jax. To install it using conda, simply run:

conda install jax -c conda-forge

To install it on a machine with an NVIDIA GPU, run:

conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia

Note the cudatoolkit distributed by conda-forge is missing ptxas, which JAX requires. You must therefore either install the cuda-nvcc package from the nvidia channel, or install CUDA on your machine separately so that ptxas is in your path. The channel order above is important (conda-forge before nvidia).

If you would like to override which release of CUDA is used by JAX, or to install the CUDA build on a machine without GPUs, follow the instructions in the Tips & tricks section of the conda-forge website.

Go to the conda-forge jaxlib and jax repositories for more details.

Building JAX from source#

Refer to Building from source.

Installing older jaxlib wheels#

Due to storage limitations on the Python package index, the JAX team periodically removes older jaxlib wheels from the releases on http://pypi.org/project/jax. These can still be installed directly via the URLs here. For example:

# Install jaxlib on CPU via the wheel archive
pip install jax[cpu]==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_releases.html

# Install the jaxlib 0.3.25 CPU wheel directly
pip install jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_releases.html

For specific older GPU wheels, be sure to use the jax_cuda_releases.html URL; for example

pip install jaxlib==0.3.25+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Quickstart#

JAX is a library for array-oriented numerical computation (à la NumPy), with automatic differentiation and JIT compilation to enable high-performance machine learning research.

This document provides a quick overview of essential JAX features, so you can get started with JAX quickly:

  • JAX provides a unified NumPy-like interface to computations that run on CPU, GPU, or TPU, in local or distributed settings.

  • JAX features built-in Just-In-Time (JIT) compilation via Open XLA, an open-source machine learning compiler ecosystem.

  • JAX functions support efficient evaluation of gradients via its automatic differentiation transformations.

  • JAX functions can be automatically vectorized to efficiently map them over arrays representing batches of inputs.

Installation#

JAX can be installed for CPU on Linux, Windows, and macOS directly from the Python Package Index:

pip install "jax[cpu]"

or, for NVIDIA GPU:

pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

For more detailed platform-specific installation information, check out Installing JAX.

JAX as NumPy#

Most JAX usage is through the familiar jax.numpy API, which is typically imported under the jnp alias:

import jax.numpy as jnp

With this import, you can immediately use JAX in a similar manner to typical NumPy programs, including using NumPy-style array creation functions, Python functions and operators, and array attributes and methods:

def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(5.0)
print(selu(x))
[0.        1.05      2.1       3.1499999 4.2      ]

You’ll find a few differences between JAX arrays and NumPy arrays once you begin digging in; these are explored in 🔪 JAX - The Sharp Bits 🔪.

Just-in-time compilation with jax.jit()#

JAX runs transparently on the GPU or TPU (falling back to CPU if you don’t have one). However, in the above example, JAX is dispatching kernels to the chip one operation at a time. If we have a sequence of operations, we can use the jax.jit() function to compile this sequence of operations together using XLA.

We can use IPython’s %timeit to quickly benchmark our selu function, using block_until_ready() to account for JAX’s dynamic dispatch (See Asynchronous dispatch):

from jax import random

key = random.key(1701)
x = random.normal(key, (1_000_000,))
%timeit selu(x).block_until_ready()
2.94 ms ± 18.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

(notice we’ve used jax.random to generate some random numbers; for details on how to generate random numbers in JAX, check out Pseudorandom numbers).

We can speed up the execution of this function with the jax.jit() transformation, which will jit-compile selu the first time it is called and cache the compiled version thereafter.

from jax import jit

selu_jit = jit(selu)
_ = selu_jit(x)  # compiles on first call
%timeit selu_jit(x).block_until_ready()
859 µs ± 7.13 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

The above timings represent execution on CPU, but the same code can be run on GPU or TPU, typically for an even greater speedup.

For more on JIT compilation in JAX, check out Just-in-time compilation.

Taking derivatives with jax.grad()#

In addition to transforming functions via JIT compilation, JAX also provides other transformations. One such transformation is jax.grad(), which performs automatic differentiation (autodiff):

from jax import grad

def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
[0.25       0.19661197 0.10499357]

Let’s verify with finite differences that our result is correct.

def first_finite_differences(f, x, eps=1E-3):
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                   for v in jnp.eye(len(x))])

print(first_finite_differences(sum_logistic, x_small))
[0.24998187 0.1965761  0.10502338]

The grad() and jit() transformations compose and can be mixed arbitrarily. In the above example we jitted sum_logistic and then took its derivative. We can go further:

print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
-0.0353256

Beyond scalar-valued functions, the jax.jacobian() transformation can be used to compute the full Jacobian matrix for vector-valued functions:

from jax import jacobian
print(jacobian(jnp.exp)(x_small))
[[1.        0.        0.       ]
 [0.        2.7182817 0.       ]
 [0.        0.        7.389056 ]]

For more advanced autodiff operations, you can use jax.vjp() for reverse-mode vector-Jacobian products, and jax.jvp() and jax.linearize() for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another, and with other JAX transformations. For example, jax.jvp() and jax.vjp() are used to define the forward-mode jax.jacfwd() and reverse-mode jax.jacrev() for computing Jacobians in forward- and reverse-mode, respectively. Here’s one way to compose them to make a function that efficiently computes full Hessian matrices:

from jax import jacfwd, jacrev
def hessian(fun):
  return jit(jacfwd(jacrev(fun)))
print(hessian(sum_logistic)(x_small))
[[-0.         -0.         -0.        ]
 [-0.         -0.09085776 -0.        ]
 [-0.         -0.         -0.07996249]]

This kind of composition produces efficient code in practice; this is more-or-less how JAX’s built-in jax.hessian() function is implemented.
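
For comparison, calling the built-in jax.hessian() directly should give the same result as the composition above:

import jax
print(jax.hessian(sum_logistic)(x_small))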

For more on automatic differentiation in JAX, check out Automatic differentiation.

Auto-vectorization with jax.vmap()#

Another useful transformation is vmap(), the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of explicitly looping over function calls, it transforms the function into a natively vectorized version for better performance. When composed with jit(), it can be just as performant as manually rewriting your function to operate over an extra batch dimension.

We’re going to work with a simple example, and promote matrix-vector products into matrix-matrix products using vmap(). Although this is easy to do by hand in this specific case, the same technique can apply to more complicated functions.

key1, key2 = random.split(key)
mat = random.normal(key1, (150, 100))
batched_x = random.normal(key2, (10, 100))

def apply_matrix(x):
  return jnp.dot(mat, x)

The apply_matrix function maps a vector to a vector, but we may want to apply it row-wise across a matrix. We could do this by looping over the batch dimension in Python, but this usually results in poor performance.

def naively_batched_apply_matrix(v_batched):
  return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
Naively batched
1.07 ms ± 2.76 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

A programmer familiar with the jnp.dot function might recognize that apply_matrix can be rewritten to avoid explicit looping, using the built-in batching semantics of jnp.dot:

import numpy as np

@jit
def batched_apply_matrix(batched_x):
  return jnp.dot(batched_x, mat.T)

np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
                           batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()
Manually batched
18.5 µs ± 11.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

However, as functions become more complicated, this kind of manual batching becomes more difficult and error-prone. The vmap() transformation is designed to automatically transform a function into a batch-aware version:

from jax import vmap

@jit
def vmap_batched_apply_matrix(batched_x):
  return vmap(apply_matrix)(batched_x)

np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
                           vmap_batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
Auto-vectorized with vmap
26.2 µs ± 130 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

As you would expect, vmap() can be arbitrarily composed with jit(), grad(), and any other JAX transformation.
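
For example, composing vmap() with grad() yields per-example gradients; here is a minimal sketch reusing sum_logistic and x_small from the autodiff section above:

batch = jnp.stack([x_small, x_small + 1.0])       # shape (2, 3): two example inputs
per_example_grads = vmap(grad(sum_logistic))(batch)
print(per_example_grads.shape)                    # (2, 3): one gradient per example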

For more on automatic vectorization in JAX, check out Automatic vectorization.

This is just a taste of what JAX can do. We’re really excited to see what you do with it!

🔪 JAX - The Sharp Bits 🔪#

Open in Colab Open in Kaggle

levskaya@ mattjj@

When walking about the countryside of Italy, the people will not hesitate to tell you that JAX has “una anima di pura programmazione funzionale”.

JAX is a language for expressing and composing transformations of numerical programs. JAX is also able to compile numerical programs for CPU or accelerators (GPU/TPU). JAX works great for many numerical and scientific programs, but only if they are written with certain constraints that we describe below.

import numpy as np
from jax import grad, jit
from jax import lax
from jax import random
import jax
import jax.numpy as jnp

🔪 Pure functions#

JAX transformation and compilation are designed to work only on Python functions that are functionally pure: all the input data is passed through the function parameters, and all the results are output through the function results. A pure function will always return the same result if invoked with the same inputs.

Here are some examples of functions that are not functionally pure for which JAX behaves differently than the Python interpreter. Note that these behaviors are not guaranteed by the JAX system; the proper way to use JAX is to use it only on functionally pure Python functions.

def impure_print_side_effect(x):
  print("Executing function")  # This is a side-effect
  return x

# The side-effects appear during the first run
print ("First call: ", jit(impure_print_side_effect)(4.))

# Subsequent runs with parameters of same type and shape may not show the side-effect
# This is because JAX now invokes a cached compilation of the function
print ("Second call: ", jit(impure_print_side_effect)(5.))

# JAX re-runs the Python function when the type or shape of the argument changes
print ("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([5.])))
Executing function
First call:  4.0
Second call:  5.0
Executing function
Third call, different type:  [5.]
g = 0.
def impure_uses_globals(x):
  return x + g

# JAX captures the value of the global during the first run
print ("First call: ", jit(impure_uses_globals)(4.))
g = 10.  # Update the global

# Subsequent runs may silently use the cached value of the globals
print ("Second call: ", jit(impure_uses_globals)(5.))

# JAX re-runs the Python function when the type or shape of the argument changes
# This will end up reading the latest value of the global
print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))
First call:  4.0
Second call:  5.0
Third call, different type:  [14.]
g = 0.
def impure_saves_global(x):
  global g
  g = x
  return x

# JAX runs the transformed function once with special Traced values for its arguments
print ("First call: ", jit(impure_saves_global)(4.))
print ("Saved global: ", g)  # Saved global has an internal JAX value
First call:  4.0
Saved global:  Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>

A Python function can be functionally pure even if it actually uses stateful objects internally, as long as it does not read or write external state:

def pure_uses_internal_state(x):
  state = dict(even=0, odd=0)
  for i in range(10):
    state['even' if i % 2 == 0 else 'odd'] += x
  return state['even'] + state['odd']

print(jit(pure_uses_internal_state)(5.))
50.0

It is not recommended to use iterators in any JAX function you want to jit or in any control-flow primitive. The reason is that an iterator is a Python object that introduces state in order to retrieve the next element, and is therefore incompatible with JAX’s functional programming model. In the code below, there are some examples of incorrect attempts to use iterators with JAX. Most of them return an error, but some give unexpected results.

import jax.numpy as jnp
import jax.lax as lax
from jax import make_jaxpr

# lax.fori_loop
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0

# lax.scan
def func11(arr, extra):
    ones = jnp.ones(arr.shape)
    def body(carry, aelems):
        ae1, ae2 = aelems
        return (carry + ae1 * ae2 + extra, carry)
    return lax.scan(body, 0., (arr, ones))
make_jaxpr(func11)(jnp.arange(16), 5.)
# make_jaxpr(func11)(iter(range(16)), 5.) # throws error

# lax.cond
array_operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, array_operand)
iter_operand = iter(range(10))
# lax.cond(True, lambda x: next(x)+1, lambda x: next(x)-1, iter_operand) # throws error
45
0

🔪 In-Place Updates#

In Numpy you’re used to doing this:

numpy_array = np.zeros((3,3), dtype=np.float32)
print("original array:")
print(numpy_array)

# In place, mutating update
numpy_array[1, :] = 1.0
print("updated array:")
print(numpy_array)
original array:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
updated array:
[[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]

If we try to update a JAX device array in-place, however, we get an error! (☉_☉)

%xmode Minimal
Exception reporting mode: Minimal
jax_array = jnp.zeros((3,3), dtype=jnp.float32)

# In place update of JAX's array will yield an error!
jax_array[1, :] = 1.0
TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

Allowing mutation of variables in-place makes program analysis and transformation difficult. JAX requires that programs are pure functions.

Instead, JAX offers a functional array update using the .at property on JAX arrays.

️⚠️ Inside jit’d code and lax.while_loop or lax.fori_loop, the size of slices can’t be a function of argument values, only of argument shapes; the slice start indices have no such restriction. See the Control Flow section below for more information on this limitation.
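
For example, under jit a slice taken with lax.dynamic_slice may start at a traced index, but its size must be a static constant (a minimal sketch):

from jax import lax

@jit
def take_window(x, start):
  # The start index may depend on an argument value, but the window
  # size (3,) must be a static Python constant.
  return lax.dynamic_slice(x, (start,), (3,))

print(take_window(jnp.arange(10), 4))  # [4 5 6]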

Array updates: x.at[idx].set(y)#

For example, the update above can be written as:

updated_array = jax_array.at[1, :].set(1.0)
print("updated array:\n", updated_array)
updated array:
 [[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]

JAX’s array update functions, unlike their NumPy versions, operate out-of-place. That is, the updated array is returned as a new array and the original array is not modified by the update.

print("original array unchanged:\n", jax_array)
original array unchanged:
 [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]

However, inside jit-compiled code, if the input value x of x.at[idx].set(y) is not reused, the compiler will optimize the array update to occur in-place.
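
As a small sketch of this pattern, the jitted loop below builds up an array with repeated .at[].set() updates; because each intermediate array is not reused, XLA is free to compile the updates into in-place writes:

@jit
def fill_squares(buf):
  def body(i, acc):
    # Functional update; under jit this can become an in-place write.
    return acc.at[i].set(1.0 * i * i)
  return jax.lax.fori_loop(0, buf.shape[0], body, buf)

print(fill_squares(jnp.zeros(5)))  # [ 0.  1.  4.  9. 16.]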

Array updates with other operations#

Indexed array updates are not limited simply to overwriting values. For example, we can perform indexed addition as follows:

print("original array:")
jax_array = jnp.ones((5, 6))
print(jax_array)

new_jax_array = jax_array.at[::2, 3:].add(7.)
print("new array post-addition:")
print(new_jax_array)
original array:
[[1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]]
new array post-addition:
[[1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]]

For more details on indexed array updates, see the documentation for the .at property.

🔪 Out-of-Bounds Indexing#

In Numpy, you are used to errors being thrown when you index an array outside of its bounds, like this:

np.arange(10)[11]
IndexError: index 11 is out of bounds for axis 0 with size 10

However, raising an error from code running on an accelerator can be difficult or impossible. Therefore, JAX must choose some non-error behavior for out of bounds indexing (akin to how invalid floating point arithmetic results in NaN). When the indexing operation is an array index update (e.g. index_add or scatter-like primitives), updates at out-of-bounds indices will be skipped; when the operation is an array index retrieval (e.g. NumPy indexing or gather-like primitives) the index is clamped to the bounds of the array since something must be returned. For example, the last value of the array will be returned from this indexing operation:

jnp.arange(10)[11]
Array(9, dtype=int32)

If you would like finer-grained control over the behavior for out-of-bound indices, you can use the optional parameters of ndarray.at; for example:

jnp.arange(10.0).at[11].get()
Array(9., dtype=float32)
jnp.arange(10.0).at[11].get(mode='fill', fill_value=jnp.nan)
Array(nan, dtype=float32)

Note that due to this behavior for index retrieval, functions like jnp.nanargmin and jnp.nanargmax return -1 for slices consisting of NaNs whereas Numpy would throw an error.
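
For instance, based on the behavior described above:

print(jnp.nanargmax(jnp.array([jnp.nan, jnp.nan])))  # -1, where NumPy would raise an error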

Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) will not preserve the semantics of out of bounds indexing. Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of undefined behavior.

🔪 Non-array inputs: NumPy vs. JAX#

NumPy is generally happy accepting Python lists or tuples as inputs to its API functions:

np.sum([1, 2, 3])
6

JAX departs from this, generally returning a helpful error:

jnp.sum([1, 2, 3])
TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.

This is a deliberate design choice, because passing lists or tuples to traced functions can lead to silent performance degradation that might otherwise be difficult to detect.

For example, consider the following permissive version of jnp.sum that allows list inputs:

def permissive_sum(x):
  return jnp.sum(jnp.array(x))

x = list(range(10))
permissive_sum(x)
Array(45, dtype=int32)

The output is what we would expect, but this hides potential performance issues under the hood. In JAX’s tracing and JIT compilation model, each element in a Python list or tuple is treated as a separate JAX variable, and individually processed and pushed to device. This can be seen in the jaxpr for the permissive_sum function above:

make_jaxpr(permissive_sum)(x)
{ lambda ; a:i32[] b:i32[] c:i32[] d:i32[] e:i32[] f:i32[] g:i32[] h:i32[] i:i32[]
    j:i32[]. let
    k:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
    l:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    m:i32[] = convert_element_type[new_dtype=int32 weak_type=False] c
    n:i32[] = convert_element_type[new_dtype=int32 weak_type=False] d
    o:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
    p:i32[] = convert_element_type[new_dtype=int32 weak_type=False] f
    q:i32[] = convert_element_type[new_dtype=int32 weak_type=False] g
    r:i32[] = convert_element_type[new_dtype=int32 weak_type=False] h
    s:i32[] = convert_element_type[new_dtype=int32 weak_type=False] i
    t:i32[] = convert_element_type[new_dtype=int32 weak_type=False] j
    u:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] k
    v:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] l
    w:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] m
    x:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] n
    y:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] o
    z:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] p
    ba:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] q
    bb:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] r
    bc:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] s
    bd:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] t
    be:i32[10] = concatenate[dimension=0] u v w x y z ba bb bc bd
    bf:i32[] = reduce_sum[axes=(0,)] be
  in (bf,) }

Each entry of the list is handled as a separate input, resulting in a tracing & compilation overhead that grows linearly with the size of the list. To prevent surprises like this, JAX avoids implicit conversions of lists and tuples to arrays.

If you would like to pass a tuple or list to a JAX function, you can do so by first explicitly converting it to an array:

jnp.sum(jnp.array(x))
Array(45, dtype=int32)

🔪 Random Numbers#

If all scientific papers whose results are in doubt because of bad rand()s were to disappear from library shelves, there would be a gap on each shelf about as big as your fist. - Numerical Recipes

RNGs and State#

You’re used to stateful pseudorandom number generators (PRNGs) from numpy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness:

print(np.random.random())
print(np.random.random())
print(np.random.random())
0.4157599213412506
0.09201806380722433
0.44769303473910294

Underneath the hood, numpy uses the Mersenne Twister PRNG to power its pseudorandom functions. The PRNG has a period of \(2^{19937}-1\) and at any point can be described by 624 32-bit unsigned ints and a position indicating how much of this “entropy” has been used up.

np.random.seed(0)
rng_state = np.random.get_state()
# print(rng_state)
# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,
#       2481403966, 4042607538,  337614300, ... 614 more numbers...,
#       3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)

This pseudorandom state vector is automagically updated behind the scenes every time a random number is needed, “consuming” 2 of the uint32s in the Mersenne twister state vector:

_ = np.random.uniform()
rng_state = np.random.get_state()
#print(rng_state)
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
#       ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)

# Let's exhaust the entropy in this PRNG statevector
for i in range(311):
  _ = np.random.uniform()
rng_state = np.random.get_state()
#print(rng_state)
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
#       ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)

# Next call iterates the RNG state for a new batch of fake "entropy".
_ = np.random.uniform()
rng_state = np.random.get_state()
# print(rng_state)
# --> ('MT19937', array([1499117434, 2949980591, 2242547484,
#      4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)

The problem with magic PRNG state is that it’s hard to reason about how it’s being used and updated across different threads, processes, and devices, and it’s very easy to screw up when the details of entropy production and consumption are hidden from the end user.

The Mersenne Twister PRNG is also known to have a number of problems: it has a large 2.5 kB state size, which leads to problematic initialization issues; it fails modern BigCrush tests; and it is generally slow.

JAX PRNG#

JAX instead implements an explicit PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern Threefry counter-based PRNG that’s splittable. That is, its design allows us to fork the PRNG state into new PRNGs for use with parallel stochastic generation.

The random state is described by a special array element that we call a key:

from jax import random
key = random.key(0)
key
Array((), dtype=key<fry>) overlaying:
[0 0]

JAX’s random functions produce pseudorandom numbers from the PRNG state, but do not change the state!

Reusing the same state will cause sadness and monotony, depriving the end user of lifegiving chaos:

print(random.normal(key, shape=(1,)))
print(key)
# No no no!
print(random.normal(key, shape=(1,)))
print(key)
[-0.20584226]
Array((), dtype=key<fry>) overlaying:
[0 0]
[-0.20584226]
Array((), dtype=key<fry>) overlaying:
[0 0]

Instead, we split the PRNG to get usable subkeys every time we need a new pseudorandom number:

print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print("    \---SPLIT --> new key   ", key)
print("             \--> new subkey", subkey, "--> normal", normal_pseudorandom)
old key Array((), dtype=key<fry>) overlaying:
[0 0]
    \---SPLIT --> new key    Array((), dtype=key<fry>) overlaying:
[4146024105  967050713]
             \--> new subkey Array((), dtype=key<fry>) overlaying:
[2718843009 1272950319] --> normal [-1.2515389]

We propagate the key and make new subkeys whenever we need a new random number:

print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print("    \---SPLIT --> new key   ", key)
print("             \--> new subkey", subkey, "--> normal", normal_pseudorandom)
old key Array((), dtype=key<fry>) overlaying:
[4146024105  967050713]
    \---SPLIT --> new key    Array((), dtype=key<fry>) overlaying:
[2384771982 3928867769]
             \--> new subkey Array((), dtype=key<fry>) overlaying:
[1278412471 2182328957] --> normal [-0.58665055]

We can generate more than one subkey at a time:

key, *subkeys = random.split(key, 4)
for subkey in subkeys:
  print(random.normal(subkey, shape=(1,)))
[-0.37533438]
[0.98645043]
[0.14553197]

🔪 Control Flow#

✔ python control_flow + autodiff ✔#

If you just want to apply grad to your Python functions, you can use regular Python control-flow constructs with no problems, as if you were using Autograd (or PyTorch or TF Eager).

def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

print(grad(f)(2.))  # ok!
print(grad(f)(4.))  # ok!
12.0
-4.0

python control flow + JIT#

Using control flow with jit is more complicated, and by default it has more constraints.

This works:

@jit
def f(x):
  for i in range(3):
    x = 2 * x
  return x

print(f(3))
24

So does this:

@jit
def g(x):
  y = 0.
  for i in range(x.shape[0]):
    y = y + x[i]
  return y

print(g(jnp.array([1., 2., 3.])))
6.0

But this doesn’t, at least by default:

@jit
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

# This will fail!
f(2)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function f at /tmp/ipykernel_3575/3402096563.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

What gives!?

When we jit-compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don’t have to re-compile on each function evaluation.

For example, if we evaluate an @jit function on the array jnp.array([1., 2., 3.], jnp.float32), we might want to compile code that we can reuse to evaluate the function on jnp.array([4., 5., 6.], jnp.float32) to save on compile time.

To get a view of your Python code that is valid for many different argument values, JAX traces it on abstract values that represent sets of possible inputs. There are multiple different levels of abstraction, and different transformations use different abstraction levels.

By default, jit traces your code on the ShapedArray abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value ShapedArray((3,), jnp.float32), we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.

But there’s a tradeoff here: if we trace a Python function on a ShapedArray((), jnp.float32) that isn’t committed to a specific concrete value, when we hit a line like if x < 3, the expression x < 3 evaluates to an abstract ShapedArray((), jnp.bool_) that represents the set {True, False}. When Python attempts to coerce that to a concrete True or False, we get an error: we don’t know which branch to take, and can’t continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace.

The good news is that you can control this tradeoff yourself. By having jit trace on more refined abstract values, you can relax the traceability constraints. For example, using the static_argnums argument to jit, we can specify to trace on concrete values of some arguments. Here’s that example function again:

def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

f = jit(f, static_argnums=(0,))

print(f(2.))
12.0

Here’s another example, this time involving a loop:

def f(x, n):
  y = 0.
  for i in range(n):
    y = y + x[i]
  return y

f = jit(f, static_argnums=(1,))

f(jnp.array([2., 3., 4.]), 2)
Array(5., dtype=float32)

In effect, the loop gets statically unrolled. JAX can also trace at higher levels of abstraction, like Unshaped, but that’s not currently the default for any transformation.

️⚠️ functions with argument-value dependent shapes

These control-flow issues also come up in a more subtle way: numerical functions we want to jit can’t specialize the shapes of internal arrays on argument values (specializing on argument shapes is ok). As a trivial example, let’s make a function whose output happens to depend on the input variable length.

def example_fun(length, val):
  return jnp.ones((length,)) * val
# un-jit'd works fine
print(example_fun(5, 4))
[4. 4. 4. 4. 4.]
bad_example_jit = jit(example_fun)
# this will fail:
bad_example_jit(10, 4)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function example_fun at /tmp/ipykernel_3575/1210496444.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument length.
# static_argnums tells JAX to recompile on changes at these argument positions:
good_example_jit = jit(example_fun, static_argnums=(0,))
# first compile
print(good_example_jit(10, 4))
# recompiles
print(good_example_jit(5, 4))
[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]
[4. 4. 4. 4. 4.]

static_argnums can be handy if length in our example rarely changes, but it would be disastrous if it changed a lot!

Lastly, if your function has global side-effects, JAX’s tracer can cause weird things to happen. A common gotcha is trying to print arrays inside jit’d functions:

@jit
def f(x):
  print(x)
  y = 2 * x
  print(y)
  return y
f(2)
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
Array(4, dtype=int32, weak_type=True)

Structured control flow primitives#

There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that’s traceable, and that avoids un-rolling large loops. Then you can use these 4 structured control flow primitives:

  • lax.cond differentiable

  • lax.while_loop fwd-mode-differentiable

  • lax.fori_loop fwd-mode-differentiable in general; fwd and rev-mode differentiable if endpoints are static.

  • lax.scan differentiable

cond#

python equivalent:

def cond(pred, true_fun, false_fun, operand):
  if pred:
    return true_fun(operand)
  else:
    return false_fun(operand)
from jax import lax

operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, operand)
# --> array([1.], dtype=float32)
lax.cond(False, lambda x: x+1, lambda x: x-1, operand)
# --> array([-1.], dtype=float32)
Array([-1.], dtype=float32)

jax.lax provides two other functions that allow branching on dynamic predicates:

  • lax.select is like a batched version of lax.cond, with the choices expressed as pre-computed arrays rather than as functions.

  • lax.switch is like lax.cond, but allows switching between any number of callable choices.

In addition, jax.numpy provides several numpy-style interfaces to these functions (a short example follows this list):

  • jnp.where with three arguments is the numpy-style wrapper of lax.select.

  • jnp.piecewise is a numpy-style wrapper of lax.switch, but switches on a list of boolean conditions rather than a single scalar index.

  • jnp.select has an API similar to jnp.piecewise, but the choices are given as pre-computed arrays rather than as functions. It is implemented in terms of multiple calls to lax.select.
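
For instance, three-argument jnp.where selects elementwise on a traced boolean array, so it works under jit where a Python if on a traced value would fail (a minimal sketch):

@jit
def leaky_relu(x, slope=0.1):
  # Elementwise selection on a traced predicate; no Python branching needed.
  return jnp.where(x > 0, x, slope * x)

print(leaky_relu(jnp.array([-2.0, 0.5, 3.0])))  # [-0.2  0.5  3. ]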

while_loop#

python equivalent:

def while_loop(cond_fun, body_fun, init_val):
  val = init_val
  while cond_fun(val):
    val = body_fun(val)
  return val
init_val = 0
cond_fun = lambda x: x<10
body_fun = lambda x: x+1
lax.while_loop(cond_fun, body_fun, init_val)
# --> array(10, dtype=int32)
Array(10, dtype=int32, weak_type=True)

fori_loop#

python equivalent:

def fori_loop(start, stop, body_fun, init_val):
  val = init_val
  for i in range(start, stop):
    val = body_fun(i, val)
  return val
init_val = 0
start = 0
stop = 10
body_fun = lambda i,x: x+i
lax.fori_loop(start, stop, body_fun, init_val)
# --> array(45, dtype=int32)
Array(45, dtype=int32, weak_type=True)
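
scan#

A minimal sketch of lax.scan, the fourth primitive listed above: the body function takes (carry, x) and returns (new_carry, per-step output), and scan stacks the per-step outputs. Here it computes a running sum (the names below are illustrative):

xs = jnp.arange(5.0)

def cumsum_body(carry, x):
  new_carry = carry + x
  return new_carry, new_carry  # carry the running sum forward, also emit it per step

final, ys = lax.scan(cumsum_body, 0.0, xs)
print(ys)
# --> [ 0.  1.  3.  6. 10.]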

Summary#

\[\begin{array}{r|rr} \hline \textrm{construct} & \textrm{jit} & \textrm{grad} \\ \hline \textrm{if} & ❌ & ✔ \\ \textrm{for} & ✔* & ✔ \\ \textrm{while} & ✔* & ✔ \\ \textrm{lax.cond} & ✔ & ✔ \\ \textrm{lax.while\_loop} & ✔ & \textrm{fwd} \\ \textrm{lax.fori\_loop} & ✔ & \textrm{fwd} \\ \textrm{lax.scan} & ✔ & ✔ \\ \hline \end{array}\]

\(\ast\) = argument-value-independent loop condition - unrolls the loop

🔪 Dynamic Shapes#

JAX code used within transforms like jax.jit, jax.vmap, jax.grad, etc. requires all output arrays and intermediate arrays to have static shape: that is, the shape cannot depend on values within other arrays.

For example, if you were implementing your own version of jnp.nansum, you might start with something like this:

def nansum(x):
  mask = ~jnp.isnan(x)  # boolean mask selecting non-nan values
  x_without_nans = x[mask]
  return x_without_nans.sum()

Outside JIT and other transforms, this works as expected:

x = jnp.array([1, 2, jnp.nan, 3, 4])
print(nansum(x))
10.0

If you attempt to apply jax.jit or another transform to this function, it will error:

jax.jit(nansum)(x)
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[5])

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError

The problem is that the size of x_without_nans is dependent on the values within x, which is another way of saying its size is dynamic. Often in JAX it is possible to work around the need for dynamically-sized arrays via other means. For example, here it is possible to use the three-argument form of jnp.where to replace the NaN values with zeros, thus computing the same result while avoiding dynamic shapes:

@jax.jit
def nansum_2(x):
  mask = ~jnp.isnan(x)  # boolean mask selecting non-nan values
  return jnp.where(mask, x, 0).sum()

print(nansum_2(x))
10.0

Similar tricks can be played in other situations where dynamically-shaped arrays occur.

🔪 NaNs#

Debugging NaNs#

If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by:

  • setting the JAX_DEBUG_NANS=True environment variable;

  • adding jax.config.update("jax_debug_nans", True) near the top of your main file;

  • adding jax.config.parse_flags_with_absl() to your main file, then set the option using a command-line flag like --jax_debug_nans=True;

This will cause computations to error-out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an @jit. For code under an @jit, the output of every @jit function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of @jit at a time.

There could be tricky situations that arise, like nans that only occur under a @jit but don’t get produced in de-optimized mode. In that case you’ll see a warning message print out but your code will continue to execute.

If the nans are being produced in the backward pass of a gradient evaluation, when an exception is raised several frames up in the stack trace you will be in the backward_pass function, which is essentially a simple jaxpr interpreter that walks the sequence of primitive operations in reverse. In the example below, we started an ipython repl with the command line env JAX_DEBUG_NANS=True ipython, then ran this:

In [1]: import jax.numpy as jnp

In [2]: jnp.divide(0., 0.)
---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
<ipython-input-2-f2e2c413b437> in <module>()
----> 1 jnp.divide(0., 0.)

.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
    343     return floor_divide(x1, x2)
    344   else:
--> 345     return true_divide(x1, x2)
    346
    347

.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
    332   x1, x2 = _promote_shapes(x1, x2)
    333   return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334                  lax.convert_element_type(x2, result_dtype))
    335
    336

.../jax/jax/lax.pyc in div(x, y)
    244 def div(x, y):
    245   r"""Elementwise division: :math:`x \over y`."""
--> 246   return div_p.bind(x, y)
    247
    248 def rem(x, y):

... stack trace ...

.../jax/jax/interpreters/xla.pyc in handle_result(device_buffer)
    103         py_val = device_buffer.to_py()
    104         if np.any(np.isnan(py_val)):
--> 105           raise FloatingPointError("invalid value")
    106         else:
    107           return Array(device_buffer, *result_shape)

FloatingPointError: invalid value

The nan generated was caught. By running %debug, we can get a post-mortem debugger. This also works with functions under @jit, as the example below shows.

In [4]: from jax import jit

In [5]: @jit
   ...: def f(x, y):
   ...:     a = x * y
   ...:     b = (x + y) / (x - y)
   ...:     c = a + 2
   ...:     return a + b * c
   ...:

In [6]: x = jnp.array([2., 0.])

In [7]: y = jnp.array([3., 0.])

In [8]: f(x, y)
Invalid value encountered in the output of a jit function. Calling the de-optimized version.
---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
<ipython-input-8-811b7ddb3300> in <module>()
----> 1 f(x, y)

 ... stack trace ...

<ipython-input-5-619b39acbaac> in f(x, y)
      2 def f(x, y):
      3     a = x * y
----> 4     b = (x + y) / (x - y)
      5     c = a + 2
      6     return a + b * c

.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
    343     return floor_divide(x1, x2)
    344   else:
--> 345     return true_divide(x1, x2)
    346
    347

.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
    332   x1, x2 = _promote_shapes(x1, x2)
    333   return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334                  lax.convert_element_type(x2, result_dtype))
    335
    336

.../jax/jax/lax.pyc in div(x, y)
    244 def div(x, y):
    245   r"""Elementwise division: :math:`x \over y`."""
--> 246   return div_p.bind(x, y)
    247
    248 def rem(x, y):

 ... stack trace ...

When this code sees a nan in the output of an @jit function, it calls into the de-optimized code, so we still get a clear stack trace. And we can run a post-mortem debugger with %debug to inspect all the values to figure out the error.

⚠️ You shouldn’t have the NaN-checker on if you’re not debugging, as it can introduce lots of device-host round-trips and performance regressions!

⚠️ The NaN-checker doesn’t work with pmap. To debug nans in pmap code, one thing to try is replacing pmap with vmap.

🔪 Double (64bit) precision#

At the moment, JAX by default enforces single-precision numbers to mitigate the Numpy API’s tendency to aggressively promote operands to double. This is the desired behavior for many machine-learning applications, but it may catch you by surprise!

x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
x.dtype
/tmp/ipykernel_3575/1258726447.py:1: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'>  is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
dtype('float32')

To use double-precision numbers, you need to set the jax_enable_x64 configuration variable at startup.

There are a few ways to do this:

  1. You can enable 64-bit mode by setting the environment variable JAX_ENABLE_X64=True.

  2. You can manually set the jax_enable_x64 configuration flag at startup:

    # again, this only works on startup!
    import jax
    jax.config.update("jax_enable_x64", True)
    
  3. You can parse command-line flags with absl.app.run(main)

    import jax
    jax.config.config_with_absl()
    
  4. If you want JAX to run absl parsing for you, i.e. you don’t want to do absl.app.run(main), you can instead use

    import jax
    if __name__ == '__main__':
      # calls jax.config.config_with_absl() *and* runs absl parsing
      jax.config.parse_flags_with_absl()
    

Note that #2-#4 work for any of JAX’s configuration options.

We can then confirm that x64 mode is enabled:

import jax.numpy as jnp
from jax import random
x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
x.dtype # --> dtype('float64')
/tmp/ipykernel_3575/2819792939.py:3: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'>  is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
dtype('float32')

Caveats#

⚠️ XLA doesn’t support 64-bit convolutions on all backends!

🔪 Miscellaneous Divergences from NumPy#

While jax.numpy makes every attempt to replicate the behavior of numpy’s API, there do exist corner cases where the behaviors differ. Many such cases are discussed in detail in the sections above; here we list several other known places where the APIs diverge.

  • For binary operations, JAX’s type promotion rules differ somewhat from those used by NumPy. See Type Promotion Semantics for more details (and the short example after this list).

  • When performing unsafe type casts (i.e. casts in which the target dtype cannot represent the input value), JAX’s behavior may be backend dependent, and in general may diverge from NumPy’s behavior. Numpy allows control over the result in these scenarios via the casting argument (see np.ndarray.astype); JAX does not provide any such configuration, instead directly inheriting the behavior of XLA:ConvertElementType.

    Here is an example of an unsafe cast with differing results between NumPy and JAX:

    >>> np.arange(254.0, 258.0).astype('uint8')
    array([254, 255,   0,   1], dtype=uint8)
    
    >>> jnp.arange(254.0, 258.0).astype('uint8')
    Array([254, 255, 255, 255], dtype=uint8)
    

    This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.
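
To illustrate the first point above about type promotion: mixing float32 and int32 operands promotes to float64 in NumPy, whereas JAX keeps the result at 32 bits:

>>> (np.float32(1) + np.int32(1)).dtype    # NumPy promotes to 64-bit
dtype('float64')

>>> (jnp.float32(1) + jnp.int32(1)).dtype  # JAX stays at 32-bit
dtype('float32')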

Fin.#

If something’s not covered here that has caused you weeping and gnashing of teeth, please let us know and we’ll extend these introductory advisories!

JAX Frequently Asked Questions (FAQ)#

We are collecting answers to frequently asked questions here. Contributions welcome!

jit changes the behavior of my function#

If you have a Python function that changes behavior after using jax.jit(), perhaps your function uses global state, or has side-effects. In the following code, the impure_func uses the global y and has a side-effect due to print:

y = 0

# @jit   # Different behavior with jit
def impure_func(x):
  print("Inside:", y)
  return x + y

for y in range(3):
  print("Result:", impure_func(y))

Without jit the output is:

Inside: 0
Result: 0
Inside: 1
Result: 2
Inside: 2
Result: 4

and with jit it is:

Inside: 0
Result: 0
Result: 1
Result: 2

For jax.jit(), the function is executed once using the Python interpreter, at which time the Inside printing happens, and the first value of y is observed. Then, the function is compiled and cached, and executed multiple times with different values of x, but with the same first value of y.

Additional reading:

jit changes the exact numerics of outputs#

Sometimes users are surprised by the fact that wrapping a function with jit() can change the function’s outputs. For example:

>>> from jax import jit
>>> import jax.numpy as jnp
>>> def f(x):
...   return jnp.log(jnp.sqrt(x))
>>> x = jnp.pi
>>> print(f(x))
0.572365
>>> print(jit(f)(x))
0.5723649

This slight difference in output comes from optimizations within the XLA compiler: during compilation, XLA will sometimes rearrange or elide certain operations to make the overall computation more efficient.

In this case, XLA utilizes the properties of the logarithm to replace log(sqrt(x)) with 0.5 * log(x), which is a mathematically identical expression that can be computed more efficiently than the original. The difference in output comes from the fact that floating point arithmetic is only a close approximation of real math, so different ways of computing the same expression may have subtly different results.

Other times, XLA’s optimizations may lead to even more drastic differences. Consider the following example:

>>> def f(x):
...   return jnp.log(jnp.exp(x))
>>> x = 100.0
>>> print(f(x))
inf
>>> print(jit(f)(x))
100.0

In non-JIT-compiled op-by-op mode, the result is inf because jnp.exp(x) overflows and returns inf. Under JIT, however, XLA recognizes that log is the inverse of exp, and removes the operations from the compiled function, simply returning the input. In this case, JIT compilation produces a more accurate floating point approximation of the real result.

Unfortunately the full list of XLA’s algebraic simplifications is not well documented, but if you’re familiar with C++ and curious about what types of optimizations the XLA compiler makes, you can see them in the source code: algebraic_simplifier.cc.

jit decorated function is very slow to compile#

If your jit decorated function takes tens of seconds (or more!) to run the first time you call it, but executes quickly when called again, JAX is taking a long time to trace or compile your code.

This is usually a sign that calling your function generates a large amount of code in JAX’s internal representation, typically because it makes heavy use of Python control flow such as for loops. For a handful of loop iterations, Python is OK, but if you need many loop iterations, you should rewrite your code to make use of JAX’s structured control flow primitives (such as lax.scan()) or avoid wrapping the loop with jit (you can still use jit decorated functions inside the loop).
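
As an illustrative sketch (the function names cumsum_python_loop and cumsum_scan are hypothetical), here is the same cumulative sum written with a Python loop, which is unrolled at trace time so the jaxpr grows with the number of iterations, versus lax.scan(), which expresses the loop as a single primitive:

import jax
import jax.numpy as jnp

@jax.jit
def cumsum_python_loop(xs):
  total = jnp.zeros((), dtype=xs.dtype)
  out = []
  for i in range(xs.shape[0]):   # unrolled during tracing
    total = total + xs[i]
    out.append(total)
  return jnp.stack(out)

@jax.jit
def cumsum_scan(xs):
  def step(carry, x):
    carry = carry + x
    return carry, carry          # (new carry, per-step output)
  _, out = jax.lax.scan(step, jnp.zeros((), dtype=xs.dtype), xs)
  return out

xs = jnp.arange(5.0)
print(cumsum_python_loop(xs))    # [ 0.  1.  3.  6. 10.]
print(cumsum_scan(xs))           # [ 0.  1.  3.  6. 10.]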

If you’re not sure if this is the problem, you can try running jax.make_jaxpr() on your function. You can expect slow compilation if the output is many hundreds or thousands of lines long.

Sometimes it isn’t obvious how to rewrite your code to avoid Python loops because your code makes use of many arrays with different shapes. The recommended solution in this case is to make use of functions like jax.numpy.where() to do your computation on padded arrays with fixed shape.
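
As a minimal sketch of that padding pattern (MAX_LEN and masked_sum are hypothetical names chosen for illustration), you might mask out padded entries like this:

import jax
import jax.numpy as jnp

MAX_LEN = 8  # hypothetical fixed padded length

@jax.jit
def masked_sum(padded, length):
  # Operate on a fixed-shape padded array; mask out the padding with where.
  mask = jnp.arange(MAX_LEN) < length
  return jnp.sum(jnp.where(mask, padded, 0.0))

data = jnp.pad(jnp.arange(5.0), (0, MAX_LEN - 5))
print(masked_sum(data, 5))  # 10.0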

If your functions are slow to compile for another reason, please open an issue on GitHub.

How to use jit with methods?#

Most examples of jax.jit() concern decorating stand-alone Python functions, but decorating a method within a class introduces some complication. For example, consider the following simple class, where we’ve used a standard jit() annotation on a method:

>>> import jax.numpy as jnp
>>> from jax import jit

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   @jit  # <---- How to do this correctly?
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y

However, this approach will result in an error when you attempt to call this method:

>>> c = CustomClass(2, True)
>>> c.calc(3)  
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
  File "<stdin>", line 1, in <module>
TypeError: Argument '<CustomClass object at 0x7f7dd4125890>' of type <class 'CustomClass'> is not a valid JAX type.

The problem is that the first argument to the function is self, which has type CustomClass, and JAX does not know how to handle this type. There are three basic strategies we might use in this case, and we’ll discuss them below.

Strategy 1: JIT-compiled helper function#

The most straightforward approach is to create a helper function external to the class that can be JIT-decorated in the normal way. For example:

>>> from functools import partial

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   def calc(self, y):
...     return _calc(self.mul, self.x, y)

>>> @partial(jit, static_argnums=0)
... def _calc(mul, x, y):
...   if mul:
...     return x * y
...   return y

The result will work as expected:

>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6

The benefit of such an approach is that it is simple, explicit, and it avoids the need to teach JAX how to handle objects of type CustomClass. However, you may wish to keep all the method logic in the same place.

Strategy 2: Marking self as static#

Another common pattern is to use static_argnums to mark the self argument as static. But this must be done with care to avoid unexpected results. You may be tempted to simply do this:

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   # WARNING: this example is broken, as we'll see below. Don't copy & paste!
...   @partial(jit, static_argnums=0)
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y

If you call the method, it will no longer raise an error:

>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6

However, there is a catch: if you mutate the object after the first method call, the subsequent method call may return an incorrect result:

>>> c.mul = False
>>> print(c.calc(3))  # Should print 3
6

Why is this? When you mark an object as static, it will effectively be used as a dictionary key in JIT’s internal compilation cache, meaning its hash (i.e. hash(obj)), equality (i.e. obj1 == obj2), and object identity (i.e. obj1 is obj2) will be assumed to have consistent behavior. The default __hash__ for a custom object is its object ID, and so JAX has no way of knowing that a mutated object should trigger a re-compilation.

You can partially address this by defining appropriate __hash__ and __eq__ methods for your object; for example:

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   @partial(jit, static_argnums=0)
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y
...
...   def __hash__(self):
...     return hash((self.x, self.mul))
...
...   def __eq__(self, other):
...     return (isinstance(other, CustomClass) and
...             (self.x, self.mul) == (other.x, other.mul))

(see the object.__hash__() documentation for more discussion of the requirements when overriding __hash__).

This should work correctly with JIT and other transforms so long as you never mutate your object. Mutations of objects used as hash keys lead to several subtle problems, which is why for example mutable Python containers (e.g. dict, list) don’t define __hash__, while their immutable counterparts (e.g. tuple) do.

If your class relies on in-place mutations (such as setting self.attr = ... within its methods), then your object is not really “static” and marking it as such may lead to problems. Fortunately, there’s another option for this case.

Strategy 3: Making CustomClass a PyTree#

The most flexible approach to correctly JIT-compiling a class method is to register the type as a custom PyTree object; see Extending pytrees. This lets you specify exactly which components of the class should be treated as static and which should be treated as dynamic. Here’s how it might look:

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   @jit
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y
...
...   def _tree_flatten(self):
...     children = (self.x,)  # arrays / dynamic values
...     aux_data = {'mul': self.mul}  # static values
...     return (children, aux_data)
...
...   @classmethod
...   def _tree_unflatten(cls, aux_data, children):
...     return cls(*children, **aux_data)

>>> from jax import tree_util
>>> tree_util.register_pytree_node(CustomClass,
...                                CustomClass._tree_flatten,
...                                CustomClass._tree_unflatten)

This is certainly more involved, but it solves all the issues associated with the simpler approaches used above:

>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6

>>> c.mul = False  # mutation is detected
>>> print(c.calc(3))
3

>>> c = CustomClass(jnp.array(2), True)  # non-hashable x is supported
>>> print(c.calc(3))
6

So long as your tree_flatten and tree_unflatten functions correctly handle all relevant attributes in the class, you should be able to use objects of this type directly as arguments to JIT-compiled functions, without any special annotations.

Controlling data and computation placement on devices#

Let’s first look at the principles of data and computation placement in JAX.

In JAX, the computation follows data placement. JAX arrays have two placement properties: 1) the device where the data resides; and 2) whether it is committed to the device or not (the data is sometimes referred to as being sticky to the device).

By default, JAX arrays are placed uncommitted on the default device (jax.devices()[0]), which is the first GPU or TPU by default. If no GPU or TPU is present, jax.devices()[0] is the CPU. The default device can be temporarily overridden with the jax.default_device() context manager, or set for the whole process by setting the environment variable JAX_PLATFORMS or the absl flag --jax_platforms to “cpu”, “gpu”, or “tpu” (JAX_PLATFORMS can also be a list of platforms, which determines which platforms are available in priority order).

>>> from jax import numpy as jnp
>>> print(jnp.ones(3).devices())  
{CudaDevice(id=0)}
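
A minimal sketch of temporarily overriding the default device with the jax.default_device() context manager (assuming a CPU device is available, as it always is):

>>> import jax
>>> with jax.default_device(jax.devices("cpu")[0]):
...   x = jnp.ones(3)
...
>>> print(x.devices())
{CpuDevice(id=0)}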

Computations involving uncommitted data are performed on the default device and the results are uncommitted on the default device.

Data can also be placed explicitly on a device using jax.device_put() with a device parameter, in which case the data becomes committed to the device:

>>> import jax
>>> from jax import device_put
>>> arr = device_put(1, jax.devices()[2])  
>>> print(arr.devices())  
{CudaDevice(id=2)}

Computations involving some committed inputs will happen on the committed device and the result will be committed on the same device. Invoking an operation on arguments that are committed to more than one device will raise an error.

You can also use jax.device_put() without a device parameter. If the data is already on a device (committed or not), it’s left as-is. If the data isn’t on any device—that is, it’s a regular Python or NumPy value—it’s placed uncommitted on the default device.

Jitted functions behave like any other primitive operations—they will follow the data and will show errors if invoked on data committed on more than one device.

(Before PR #6002 in March 2021 there was some laziness in creation of array constants, so that jax.device_put(jnp.zeros(...), jax.devices()[1]) or similar would actually create the array of zeros on jax.devices()[1], instead of creating the array on the default device then moving it. But this optimization was removed so as to simplify the implementation.)

(As of April 2020, jax.jit() has a device parameter that affects the device placement. That parameter is experimental, is likely to be removed or changed, and its use is not recommended.)

For a worked-out example, we recommend reading through test_computation_follows_data in multi_device_test.py.

Benchmarking JAX code#

You just ported a tricky function from NumPy/SciPy to JAX. Did that actually speed things up?

Keep in mind these important differences from NumPy when measuring the speed of code using JAX:

  1. JAX code is Just-In-Time (JIT) compiled. Most code written in JAX can be written in such a way that it supports JIT compilation, which can make it run much faster (see To JIT or not to JIT). To get maximum performance from JAX, you should apply jax.jit() on your outer-most function calls.

    Keep in mind that the first time you run JAX code, it will be slower because it is being compiled. This is true even if you don’t use jit in your own code, because JAX’s builtin functions are also JIT compiled.

  2. JAX has asynchronous dispatch. This means that you need to call .block_until_ready() to ensure that computation has actually happened (see Asynchronous dispatch).

  3. JAX by default only uses 32-bit dtypes. You may want to either explicitly use 32-bit dtypes in NumPy or enable 64-bit dtypes in JAX (see Double (64 bit) precision) for a fair comparison.

  4. Transferring data between CPUs and accelerators takes time. If you only want to measure how long it takes to evaluate a function, you may want to transfer data to the device on which you want to run it first (see Controlling data and computation placement on devices).

Here’s an example of how to put together all these tricks into a microbenchmark for comparing JAX versus NumPy, making use of IPython’s convenient %time and %timeit magics:

import numpy as np
import jax.numpy as jnp
import jax

def f(x):  # function we're benchmarking (works in both NumPy & JAX)
  return x.T @ (x - x.mean(axis=0))

x_np = np.ones((1000, 1000), dtype=np.float32)  # same as JAX default dtype
%timeit f(x_np)  # measure NumPy runtime

%time x_jax = jax.device_put(x_np)  # measure JAX device transfer time
f_jit = jax.jit(f)
%time f_jit(x_jax).block_until_ready()  # measure JAX compilation time
%timeit f_jit(x_jax).block_until_ready()  # measure JAX runtime

When run with a GPU in Colab, we see:

  • NumPy takes 16.2 ms per evaluation on the CPU

  • JAX takes 1.26 ms to copy the NumPy arrays onto the GPU

  • JAX takes 193 ms to compile the function

  • JAX takes 485 µs per evaluation on the GPU

In this case, we see that once the data is transferred and the function is compiled, JAX on the GPU is about 30x faster for repeated evaluations.

Is this a fair comparison? Maybe. The performance that ultimately matters is for running full applications, which inevitably include some amount of both data transfer and compilation. Also, we were careful to pick large enough arrays (1000x1000) and an intensive enough computation (the @ operator is performing matrix-matrix multiplication) to amortize the increased overhead of JAX/accelerators vs NumPy/CPU. For example, if we switch this example to use 10x10 input instead, JAX/GPU runs 10x slower than NumPy/CPU (100 µs vs 10 µs).

Is JAX faster than NumPy?#

One question users frequently attempt to answer with such benchmarks is whether JAX is faster than NumPy; due to the difference in the two packages, there is not a simple answer.

Broadly speaking:

  • NumPy operations are executed eagerly, synchronously, and only on CPU.

  • JAX operations may be executed eagerly or after compilation (if inside jit()); they are dispatched asynchronously (see Asynchronous dispatch); and they can be executed on CPU, GPU, or TPU, each of which have vastly different and continuously evolving performance characteristics.

These architectural differences make meaningful direct benchmark comparisons between NumPy and JAX difficult.

Additionally, these differences have led to different engineering focus between the packages: for example, NumPy has put significant effort into decreasing the per-call dispatch overhead for individual array operations, because in NumPy’s computational model that overhead cannot be avoided. JAX, on the other hand, has several ways to avoid dispatch overhead (e.g. JIT compilation, asynchronous dispatch, batching transforms, etc.), and so reducing per-call overhead has been less of a priority.

Keeping all that in mind, in summary: if you’re doing microbenchmarks of individual array operations on CPU, you can generally expect NumPy to outperform JAX due to its lower per-operation dispatch overhead. If you’re running your code on GPU or TPU, or are benchmarking more complicated JIT-compiled sequences of operations on CPU, you can generally expect JAX to outperform NumPy.

Different kinds of JAX values#

In the process of transforming functions, JAX replaces some function arguments with special tracer values.

You could see this if you use a print statement:

def func(x):
  print(x)
  return jnp.cos(x)

res = jax.jit(func)(0.)

The above code does return the correct value 1. but it also prints Traced<ShapedArray(float32[])> for the value of x. Normally, JAX handles these tracer values internally in a transparent way, e.g., in the numeric JAX primitives that are used to implement the jax.numpy functions. This is why jnp.cos works in the example above.

More precisely, a tracer value is introduced for the argument of a JAX-transformed function, except the arguments identified by special parameters such as static_argnums for jax.jit() or static_broadcasted_argnums for jax.pmap(). Typically, computations that involve at least one tracer value will produce a tracer value. Besides tracer values, there are regular Python values: values that are computed outside JAX transformations, or arise from the above-mentioned static arguments of certain JAX transformations, or are computed solely from other regular Python values. These are the values that are used everywhere in the absence of JAX transformations.

A tracer value carries an abstract value, e.g., ShapedArray with information about the shape and dtype of an array. We will refer here to such tracers as abstract tracers. Some tracers, e.g., those that are introduced for arguments of autodiff transformations, carry ConcreteArray abstract values that actually include the regular array data, and are used, e.g., for resolving conditionals. We will refer here to such tracers as concrete tracers. Tracer values computed from these concrete tracers, perhaps in combination with regular values, result in concrete tracers. A concrete value is either a regular value or a concrete tracer.

Most often values computed from tracer values are themselves tracer values. There are very few exceptions, when a computation can be entirely done using the abstract value carried by a tracer, in which case the result can be a regular value. For example, getting the shape of a tracer with ShapedArray abstract value. Another example is when explicitly casting a concrete tracer value to a regular type, e.g., int(x) or x.astype(float). Another such situation is for bool(x), which produces a Python bool when concreteness makes it possible. That case is especially salient because of how often it arises in control flow.
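
For illustration, here is a minimal sketch (the function name first_three_if_long is hypothetical) where a Python conditional is resolved at trace time because it only depends on the abstract shape carried by the tracer:

import jax
import jax.numpy as jnp

@jax.jit
def first_three_if_long(x):
  # x.shape is part of the abstract value carried by the tracer, so this
  # Python conditional is resolved at trace time with a regular Python bool.
  if x.shape[0] > 3:
    return x[:3]
  return x

print(first_three_if_long(jnp.arange(5.0)))  # [0. 1. 2.]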

Here is how the transformations introduce abstract or concrete tracers:

  • jax.jit(): introduces abstract tracers for all positional arguments except those denoted by static_argnums, which remain regular values.

  • jax.pmap(): introduces abstract tracers for all positional arguments except those denoted by static_broadcasted_argnums.

  • jax.vmap(), jax.make_jaxpr(), xla_computation(): introduce abstract tracers for all positional arguments.

  • jax.jvp() and jax.grad() introduce concrete tracers for all positional arguments. An exception is when these transformations are within an outer transformation and the actual arguments are themselves abstract tracers; in that case, the tracers introduced by the autodiff transformations are also abstract tracers.

  • All higher-order control-flow primitives (lax.cond(), lax.while_loop(), lax.fori_loop(), lax.scan()) introduce abstract tracers when they process their function arguments, whether or not there is a JAX transformation in progress.

All of this is relevant when you have code that can operate only on regular Python values, such as code that has conditional control-flow based on data:

def divide(x, y):
  return x / y if y >= 1. else 0.

If we want to apply jax.jit(), we must specify static_argnums=1 so that y stays a regular value. This is due to the boolean expression y >= 1., which requires concrete values (regular values or concrete tracers). The same would happen if we wrote bool(y >= 1.), or int(y), or float(y) explicitly.

Interestingly, jax.grad(divide)(3., 2.) works because jax.grad() uses concrete tracers, and resolves the conditional using the concrete value of y.
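
Continuing the divide example above, a minimal sketch of both options might look like this (jitted_divide is a hypothetical name):

import jax

# Mark y as static so the Python-level conditional sees a regular value:
jitted_divide = jax.jit(divide, static_argnums=1)
print(jitted_divide(3., 2.))     # 1.5

# grad introduces concrete tracers, so the conditional on y resolves directly:
print(jax.grad(divide)(3., 2.))  # 0.5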

Buffer donation#

When JAX executes a computation it uses buffers on the device for all inputs and outputs. If you know that one of the inputs is not needed after the computation, and if it matches the shape and element type of one of the outputs, you can specify that you want the corresponding input buffer to be donated to hold an output. This will reduce the memory required for the execution by the size of the donated buffer.

If you have something like the following pattern, you can use buffer donation:

params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params, state)

You can think of this as a way to do a memory-efficient functional update on your immutable JAX arrays. Within the boundaries of a computation XLA can make this optimization for you, but at the jit/pmap boundary you need to guarantee to XLA that you will not use the donated input buffer after calling the donating function.

You achieve this by using the donate_argnums parameter to the functions jax.jit(), jax.pjit(), and jax.pmap(). This parameter is a sequence of indices (0 based) into the positional argument list:

def add(x, y):
  return x + y

x = jax.device_put(np.ones((2, 3)))
y = jax.device_put(np.ones((2, 3)))
# Execute `add` with donation of the buffer for `y`. The result has
# the same shape and type as `y`, so it will share its buffer.
z = jax.jit(add, donate_argnums=(1,))(x, y)

Note that this currently does not work when calling your function with keyword arguments! The following code will not donate any buffers:

params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params=params, state=state)

If an argument whose buffer is donated is a pytree, then all the buffers for its components are donated:

def add_ones(xs: list[jax.Array]):
  return [x + 1 for x in xs]

xs = [jax.device_put(np.ones((2, 3))), jax.device_put(np.ones((3, 4)))]
# Execute `add_ones` with donation of all the buffers for `xs`.
# The outputs have the same shape and type as the elements of `xs`,
# so they will share those buffers.
z = jax.jit(add_ones, donate_argnums=0)(xs)

It is not allowed to donate a buffer that is used subsequently in the computation, and JAX will give an error because the buffer for y has become invalid after it was donated:

# Donate the buffer for `y`
z = jax.jit(add, donate_argnums=(1,))(x, y)
w = y + 1  # Reuses `y` whose buffer was donated above
# >> RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer

You will get a warning if the donated buffer is not used, e.g., because there are more donated buffers than can be used for the outputs:

# Execute `add` with donation of the buffers for both `x` and `y`.
# One of those buffers will be used for the result, but the other will
# not be used.
z = jax.jit(add, donate_argnums=(0, 1))(x, y)
# >> UserWarning: Some donated buffers were not usable: f32[2,3]{1,0}

The donation may also be unused if there is no output whose shape matches the donation:

y = jax.device_put(np.ones((1, 3)))  # `y` has different shape than the output
# Execute `add` with donation of the buffer for `y`.
z = jax.jit(add, donate_argnums=(1,))(x, y)
# >> UserWarning: Some donated buffers were not usable: f32[1,3]{1,0}

Gradients contain NaN where using where#

If you define a function using where to avoid an undefined value, you may obtain a NaN for reverse differentiation if you are not careful:

def my_log(x):
  return jnp.where(x > 0., jnp.log(x), 0.)

my_log(0.) ==> 0.  # Ok
jax.grad(my_log)(0.)  ==> NaN

A short explanation is that during the grad computation the adjoint corresponding to the undefined jnp.log(x) is a NaN, and it gets accumulated into the adjoint of the jnp.where. The correct way to write such functions is to ensure that there is a jnp.where inside the partially-defined function, so that the adjoint is always finite:

def safe_for_grad_log(x):
  return jnp.log(jnp.where(x > 0., x, 1.))

safe_for_grad_log(0.) ==> 0.  # Ok
jax.grad(safe_for_grad_log)(0.)  ==> 0.  # Ok

The inner jnp.where may be needed in addition to the original one, e.g.:

def my_log_or_y(x, y):
  """Return log(x) if x > 0, otherwise y."""
  return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.)), y)

Why are gradients zero for functions based on sort order?#

If you define a function that processes the input using operations that depend on the relative ordering of inputs (e.g. max, greater, argsort, etc.) then you may be surprised to find that the gradient is everywhere zero. Here is an example, where we define f(x) to be a step function that returns 0 when x is negative, and 1 when x is positive:

import jax
import numpy as np
import jax.numpy as jnp

def f(x):
  return (x > 0).astype(float)

df = jax.vmap(jax.grad(f))

x = jnp.array([-1.0, -0.5, 0.0, 0.5, 1.0])

print(f"f(x)  = {f(x)}")
# f(x)  = [0. 0. 0. 1. 1.]

print(f"df(x) = {df(x)}")
# df(x) = [0. 0. 0. 0. 0.]

The fact that the gradient is everywhere zero may be confusing at first glance: after all, the output does change in response to the input, so how can the gradient be zero? However, zero turns out to be the correct result in this case.

Why is this? Remember that differentiation measures the change in f given an infinitesimal change in x. For x=1.0, f returns 1.0. If we perturb x to make it slightly larger or smaller, this does not change the output, so by definition, grad(f)(1.0) should be zero. This same logic holds for all values of f greater than zero: infinitesimally perturbing the input does not change the output, so the gradient is zero. Similarly, for all values of x less than zero, the output is zero. Perturbing x does not change this output, so the gradient is zero. That leaves us with the tricky case of x=0. Surely, if you perturb x upward, it will change the output, but this is problematic: an infinitesimal change in x produces a finite change in the function value, which implies the gradient is undefined. Fortunately, there’s another way for us to measure the gradient in this case: we perturb x downward, in which case the output does not change, and so the gradient is zero. JAX and other autodiff systems tend to handle discontinuities in this way: if the positive gradient and negative gradient disagree, but one is defined and the other is not, we use the one that is defined. Under this definition of the gradient, mathematically and numerically the gradient of this function is everywhere zero.

The problem stems from the fact that our function has a discontinuity at x = 0. Our f here is essentially a Heaviside Step Function, and we can use a Sigmoid Function as a smoothed replacement. The sigmoid is approximately equal to the heaviside function when x is far from zero, but replaces the discontinuity at x = 0 with a smooth, differentiable curve. As a result of using jax.nn.sigmoid(), we get a similar computation with well-defined gradients:

def g(x):
  return jax.nn.sigmoid(x)

dg = jax.vmap(jax.grad(g))

x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])

with np.printoptions(suppress=True, precision=2):
  print(f"g(x)  = {g(x)}")
  # g(x)  = [0.   0.27 0.5  0.73 1.  ]

  print(f"dg(x) = {dg(x)}")
  # dg(x) = [0.   0.2  0.25 0.2  0.  ]

The jax.nn submodule also has smooth versions of other common rank-based functions, for example jax.nn.softmax() can replace uses of jax.numpy.argmax(), jax.nn.soft_sign() can replace uses of jax.numpy.sign(), jax.nn.softplus() or jax.nn.squareplus() can replace uses of jax.nn.relu(), etc.
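
For instance, here is a minimal sketch comparing the gradient of jax.nn.relu() with that of its smooth replacement jax.nn.softplus(), whose gradient is the sigmoid:

import jax
import jax.numpy as jnp

x = jnp.array([-1.0, 0.0, 1.0])

# relu has a piecewise-constant gradient (zero below the kink, one above):
print(jax.vmap(jax.grad(jax.nn.relu))(x))      # [0. 0. 1.]

# softplus is smooth everywhere; its gradient is the sigmoid:
print(jax.vmap(jax.grad(jax.nn.softplus))(x))  # approximately [0.27 0.5 0.73]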

How can I convert a JAX Tracer to a NumPy array?#

When inspecting a transformed JAX function at runtime, you’ll find that array values are replaced by Tracer objects:

@jax.jit
def f(x):
  print(type(x))
  return x

f(jnp.arange(5))

This prints the following:

<class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>

A frequent question is how such a tracer can be converted back to a normal NumPy array. In short, it is impossible to convert a Tracer to a NumPy array, because a tracer is an abstract representation of every possible value with a given shape and dtype, while a numpy array is a concrete member of that abstract class. For more discussion of how tracers work within the context of JAX transformations, see JIT mechanics.

The question of converting Tracers back to arrays usually comes up within the context of another goal, related to accessing intermediate values in a computation at runtime. For example:

  • If you wish to print a traced value at runtime for debugging purposes, you might consider using jax.debug.print().

  • If you wish to call non-JAX code within a transformed JAX function, you might consider using jax.pure_callback(), an example of which is available at Pure callback example.

  • If you wish to input or output array buffers at runtime (for example, load data from file, or log the contents of the array to disk), you might consider using jax.experimental.io_callback(), an example of which can be found at IO callback example.

For more information on runtime callbacks and examples of their use, see External callbacks in JAX.
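
As a minimal sketch of the callback approach (host_sin is a hypothetical helper), jax.pure_callback() might be used roughly like this:

import jax
import jax.numpy as jnp
import numpy as np

def host_sin(x):
  # Runs on the host with plain NumPy, outside of JAX tracing.
  return np.sin(x)

@jax.jit
def f(x):
  # The second argument describes the callback's output shape and dtype.
  return jax.pure_callback(host_sin, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

print(f(jnp.arange(3.0)))  # sin of [0., 1., 2.]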

Why do some CUDA libraries fail to load/initialize?#

When resolving dynamic libraries, JAX uses the usual dynamic linker search pattern. JAX sets RPATH to point to the JAX-relative location of the pip-installed NVIDIA CUDA packages, preferring them if installed. If ld.so cannot find your CUDA runtime libraries along its usual search path, then you must include the paths to those libraries explicitly in LD_LIBRARY_PATH. The easiest way to ensure your CUDA files are discoverable is to simply install the nvidia-*-cu12 pip packages, which are included in the standard jax[cuda12] install option.

Occasionally, even when you have ensured that your runtime libraries are discoverable, there may still be some issues with loading or initializing them. A common cause of such issues is simply having insufficient memory for CUDA library initialization at runtime. This sometimes occurs because JAX will pre-allocate too large of a chunk of currently available device memory for faster execution, occasionally resulting in insufficient memory being left available for runtime CUDA library initialization.

This is especially likely when running multiple JAX instances, running JAX in tandem with TensorFlow which performs its own pre-allocation, or when running JAX on a system where the GPU is being heavily utilized by other processes. When in doubt, try running the program again with reduced pre-allocation, either by reducing XLA_PYTHON_CLIENT_MEM_FRACTION from the default of .75, or setting XLA_PYTHON_CLIENT_PREALLOCATE=false. For more details, please see the page on JAX GPU memory allocation.
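
One minimal sketch of adjusting these settings is to set the environment variables from Python before importing jax (they must be set before the backend is initialized):

import os

# These must be set before JAX initializes its backends, i.e. before importing jax.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"   # disable pre-allocation entirely
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"  # or keep it, but use a smaller fraction

import jax
import jax.numpy as jnp

print(jnp.ones(3))  # the backend is initialized with the settings above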

JAX tutorials#

Key Concepts#

This section briefly introduces some key concepts of the JAX package.

JAX arrays (jax.Array)#

The default array implementation in JAX is jax.Array. In many ways it is similar to the numpy.ndarray type that you may be familiar with from the NumPy package, but it has some important differences.

Array creation#

We typically don’t call the jax.Array constructor directly, but rather create arrays via JAX API functions. For example, jax.numpy provides familiar NumPy-style array construction functionality such as jax.numpy.zeros(), jax.numpy.linspace(), jax.numpy.arange(), etc.

import jax
import jax.numpy as jnp

x = jnp.arange(5)
isinstance(x, jax.Array)
True

If you use Python type annotations in your code, jax.Array is the appropriate annotation for jax array objects (see jax.typing for more discussion).
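
For example, a minimal sketch of such an annotation (the double function is hypothetical):

import jax
import jax.numpy as jnp

def double(x: jax.Array) -> jax.Array:
  # jax.Array annotates both the input and the output of a JAX-array function.
  return 2 * x

print(double(jnp.arange(3)))  # [0 2 4]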

Array devices and sharding#

JAX Array objects have a devices method that lets you inspect where the contents of the array are stored. In the simplest cases, this will be a single CPU device:

x.devices()
{CpuDevice(id=0)}

In general, an array may be sharded across multiple devices, in a manner that can be inspected via the sharding attribute:

x.sharding
SingleDeviceSharding(device=CpuDevice(id=0))

Here the array is on a single device, but in general a JAX array can be sharded across multiple devices, or even multiple hosts. To read more about sharded arrays and parallel computation, refer to Introduction to sharded computation.

Transformations#

Along with functions to operate on arrays, JAX includes a number of transformations which operate on JAX functions. These include jax.jit() (just-in-time compilation), jax.vmap() (automatic vectorization), and jax.grad() (automatic differentiation), as well as several others. Transformations accept a function as an argument, and return a new transformed function. For example, here’s how you might JIT-compile a simple SELU function:

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)
print(selu_jit(1.0))
1.05

Often you’ll see transformations applied using Python’s decorator syntax for convenience:

@jax.jit
def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

Transformations like jit(), vmap(), grad(), and others are key to using JAX effectively, and we’ll cover them in detail in later sections.

Tracing#

The magic behind transformations is the notion of a Tracer. Tracers are abstract stand-ins for array objects, and are passed to JAX functions in order to extract the sequence of operations that the function encodes.

You can see this by printing any array value within transformed JAX code; for example:

@jax.jit
def f(x):
  print(x)
  return x + 1

x = jnp.arange(5)
result = f(x)
Traced<ShapedArray(int32[5])>with<DynamicJaxprTrace(level=1/0)>

The value printed is not the array x, but a Tracer instance that represents essential attributes of x, such as its shape and dtype. By executing the function with traced values, JAX can determine the sequence of operations encoded by the function before those operations are actually executed: transformations like jit(), vmap(), and grad() can then map this sequence of input operations to a transformed sequence of operations.

Jaxprs#

JAX has its own intermediate representation for sequences of operations, known as a jaxpr. A jaxpr (short for JAX exPRession) is a simple representation of a functional program, comprising a sequence of primitive operations.

For example, consider the selu function we defined above:

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

We can use the jax.make_jaxpr() utility to convert this function into a jaxpr given a particular input:

x = jnp.arange(5.0)
jax.make_jaxpr(selu)(x)
{ lambda ; a:f32[5]. let
    b:bool[5] = gt a 0.0
    c:f32[5] = exp a
    d:f32[5] = mul 1.6699999570846558 c
    e:f32[5] = sub d 1.6699999570846558
    f:f32[5] = pjit[
      name=_where
      jaxpr={ lambda ; g:bool[5] h:f32[5] i:f32[5]. let
          j:f32[5] = select_n g i h
        in (j,) }
    ] b a e
    k:f32[5] = mul 1.0499999523162842 f
  in (k,) }

Comparing this to the Python function definition, we see that it encodes the precise sequence of operations that the function represents. We’ll go into more depth about jaxprs later in JAX internals: The jaxpr language.

Pytrees#

JAX functions and transformations fundamentally operate on arrays, but in practice it is convenient to write code that works with collections of arrays: for example, a neural network might organize its parameters in a dictionary of arrays with meaningful keys. Rather than handle such structures on a case-by-case basis, JAX relies on the pytree abstraction to treat such collections in a uniform manner.

Here are some examples of objects that can be treated as pytrees:

# (nested) list of parameters
params = [1, 2, (jnp.arange(3), jnp.ones(2))]

print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef([*, *, (*, *)])
[1, 2, Array([0, 1, 2], dtype=int32), Array([1., 1.], dtype=float32)]
# Dictionary of parameters
params = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}

print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef({'W': *, 'b': *, 'n': *})
[Array([[1., 1.],
       [1., 1.]], dtype=float32), Array([0., 0.], dtype=float32), 5]
# Named tuple of parameters
from typing import NamedTuple

class Params(NamedTuple):
  a: int
  b: float

params = Params(1, 5.0)
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef(CustomNode(namedtuple[Params], [*, *]))
[1, 5.0]

JAX has a number of general-purpose utilities for working with PyTrees; for example the functions jax.tree.map() can be used to map a function to every leaf in a tree, and jax.tree.reduce() can be used to apply a reduction across the leaves in a tree.
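
For example, a rough sketch of these utilities (the params values here are illustrative):

import jax
import jax.numpy as jnp

params = {'W': jnp.ones((2, 2)), 'b': jnp.zeros(2), 'n': 5}

# Apply a function to every leaf of the pytree:
doubled = jax.tree.map(lambda leaf: leaf * 2, params)

# Reduce across all leaves (here: count the total number of elements):
n_elements = jax.tree.reduce(lambda acc, leaf: acc + jnp.size(leaf), params, 0)
print(n_elements)  # 7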

You can learn more in the Working with pytrees tutorial.

Just-in-time compilation#

In this section, we will further explore how JAX works, and how we can make it performant. We will discuss the jax.jit() transformation, which will perform Just In Time (JIT) compilation of a JAX Python function so it can be executed efficiently in XLA.

How JAX transformations work#

In the previous section, we discussed that JAX allows us to transform Python functions. JAX accomplishes this by reducing each function into a sequence of primitive operations, each representing one fundamental unit of computation.

One way to see the sequence of primitives behind a function is using jax.make_jaxpr():

import jax
import jax.numpy as jnp

global_list = []

def log2(x):
  global_list.append(x)
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2

print(jax.make_jaxpr(log2)(3.0))
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) }

The Understanding Jaxprs section of the documentation provides more information on the meaning of the above output.

Importantly, notice that the jaxpr does not capture the side-effect present in the function: there is nothing in it corresponding to global_list.append(x). This is a feature, not a bug: JAX transformations are designed to understand side-effect-free (a.k.a. functionally pure) code. If pure function and side-effect are unfamiliar terms, this is explained in a little more detail in 🔪 JAX - The Sharp Bits 🔪: Pure Functions.

Impure functions are dangerous because under JAX transformations they are likely not to behave as intended; they might fail silently, or produce surprising downstream errors like leaked Tracers. Moreover, JAX often can’t detect when side effects are present. (If you want debug printing, use jax.debug.print(). To express general side-effects at the cost of performance, see jax.experimental.io_callback(). To check for tracer leaks at the cost of performance, use with jax.check_tracer_leaks()).

When tracing, JAX wraps each argument by a tracer object. These tracers then record all JAX operations performed on them during the function call (which happens in regular Python). Then, JAX uses the tracer records to reconstruct the entire function. The output of that reconstruction is the jaxpr. Since the tracers do not record the Python side-effects, they do not appear in the jaxpr. However, the side-effects still happen during the trace itself.

Note: the Python print() function is not pure: the text output is a side-effect of the function. Therefore, any print() calls will only happen during tracing, and will not appear in the jaxpr:

def log2_with_print(x):
  print("printed x:", x)
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2

print(jax.make_jaxpr(log2_with_print)(3.))
printed x: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) }

See how the printed x is a Traced object? That’s the JAX internals at work.

The fact that the Python code runs at least once is strictly an implementation detail, and so shouldn’t be relied upon. However, it’s useful to understand as you can use it when debugging to print out intermediate values of a computation.

A key thing to understand is that a jaxpr captures the function as executed on the parameters given to it. For example, if we have a Python conditional, the jaxpr will only know about the branch we take:

def log2_if_rank_2(x):
  if x.ndim == 2:
    ln_x = jnp.log(x)
    ln_2 = jnp.log(2.0)
    return ln_x / ln_2
  else:
    return x

print(jax.make_jaxpr(log2_if_rank_2)(jax.numpy.array([1, 2, 3])))
{ lambda ; a:i32[3]. let  in (a,) }

JIT compiling a function#

As explained before, JAX enables operations to execute on CPU/GPU/TPU using the same code. Let’s look at an example of computing a Scaled Exponential Linear Unit (SELU), an operation commonly used in deep learning:

import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(1000000)
%timeit selu(x).block_until_ready()
2.88 ms ± 15.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

The code above is sending one operation at a time to the accelerator. This limits the ability of the XLA compiler to optimize our functions.

Naturally, what we want to do is give the XLA compiler as much code as possible, so it can fully optimize it. For this purpose, JAX provides the jax.jit() transformation, which will JIT compile a JAX-compatible function. The example below shows how to use JIT to speed up the previous function.

selu_jit = jax.jit(selu)

# Pre-compile the function before timing...
selu_jit(x).block_until_ready()

%timeit selu_jit(x).block_until_ready()
1.02 ms ± 3.34 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Here’s what just happened:

  1. We defined selu_jit as the compiled version of selu.

  2. We called selu_jit once on x. This is where JAX does its tracing – it needs to have some inputs to wrap in tracers, after all. The jaxpr is then compiled using XLA into very efficient code optimized for your GPU or TPU. Finally, the compiled code is executed to satisfy the call. Subsequent calls to selu_jit will use the compiled code directly, skipping the python implementation entirely. (If we didn’t include the warm-up call separately, everything would still work, but then the compilation time would be included in the benchmark. It would still be faster, because we run many loops in the benchmark, but it wouldn’t be a fair comparison.)

  3. We timed the execution speed of the compiled version. (Note the use of block_until_ready(), which is required due to JAX’s Asynchronous dispatch).

Why can’t we just JIT everything?#

After going through the example above, you might be wondering whether we should simply apply jax.jit() to every function. To understand why this is not the case, and when we should/shouldn’t apply jit, let’s first check some cases where JIT doesn’t work.

# Condition on value of x.

def f(x):
  if x > 0:
    return x
  else:
    return 2 * x

jax.jit(f)(10)  # Raises an error
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function f at /tmp/ipykernel_2831/2956679937.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
# While loop conditioned on x and n.

def g(x, n):
  i = 0
  while i < n:
    i += 1
  return x + i

jax.jit(g)(10, 20)  # Raises an error
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function g at /tmp/ipykernel_2831/722961019.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument n.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

The problem in both cases is that we tried to condition the trace-time flow of the program using runtime values. Traced values within JIT, like x and n here, can only affect control flow via their static attributes, such as shape or dtype, and not via their values. For more detail on the interaction between Python control flow and JAX, see 🔪 JAX - The Sharp Bits 🔪: Control Flow.

One way to deal with this problem is to rewrite the code to avoid conditionals on value. Another is to use special Control flow operators like jax.lax.cond(). However, sometimes that is not possible or practical. In that case, you can consider JIT-compiling only part of the function. For example, if the most computationally expensive part of the function is inside the loop, we can JIT-compile just that inner part (though make sure to check the next section on caching to avoid shooting yourself in the foot):

# While loop conditioned on x and n with a jitted body.

@jax.jit
def loop_body(prev_i):
  return prev_i + 1

def g_inner_jitted(x, n):
  i = 0
  while i < n:
    i = loop_body(i)
  return x + i

g_inner_jitted(10, 20)
Array(30, dtype=int32, weak_type=True)

Marking arguments as static#

If we really need to JIT-compile a function that has a condition on the value of an input, we can tell JAX to help itself to a less abstract tracer for a particular input by specifying static_argnums or static_argnames. The cost of this is that the resulting jaxpr and compiled artifact depends on the particular value passed, and so JAX will have to re-compile the function for every new value of the specified static input. It is only a good strategy if the function is guaranteed to see a limited set of static values.

f_jit_correct = jax.jit(f, static_argnums=0)
print(f_jit_correct(10))
10
g_jit_correct = jax.jit(g, static_argnames=['n'])
print(g_jit_correct(10, 20))
30

To specify such arguments when using jit as a decorator, a common pattern is to use python’s functools.partial():

from functools import partial

@partial(jax.jit, static_argnames=['n'])
def g_jit_decorated(x, n):
  i = 0
  while i < n:
    i += 1
  return x + i

print(g_jit_decorated(10, 20))
30

JIT and caching#

With the compilation overhead of the first JIT call, understanding how and when jax.jit() caches previous compilations is key to using it effectively.

Suppose we define f = jax.jit(g). When we first invoke f, it will get compiled, and the resulting XLA code will get cached. Subsequent calls of f will reuse the cached code. This is how jax.jit makes up for the up-front cost of compilation.

If we specify static_argnums, then the cached code will be used only for the same values of arguments labelled as static. If any of them change, recompilation occurs. If there are many values, then your program might spend more time compiling than it would have executing ops one-by-one.

Avoid calling jax.jit() on temporary functions defined inside loops or other Python scopes. For most cases, JAX will be able to use the compiled, cached function in subsequent calls to jax.jit(). However, because the cache relies on the hash of the function, it becomes problematic when equivalent functions are redefined. This will cause unnecessary compilation each time in the loop:

from functools import partial

def unjitted_loop_body(prev_i):
  return prev_i + 1

def g_inner_jitted_partial(x, n):
  i = 0
  while i < n:
    # Don't do this! Each time, partial returns
    # a function with a different hash
    i = jax.jit(partial(unjitted_loop_body))(i)
  return x + i

def g_inner_jitted_lambda(x, n):
  i = 0
  while i < n:
    # Don't do this! A lambda will also return
    # a function with a different hash
    i = jax.jit(lambda x: unjitted_loop_body(x))(i)
  return x + i

def g_inner_jitted_normal(x, n):
  i = 0
  while i < n:
    # this is OK, since JAX can find the
    # cached, compiled function
    i = jax.jit(unjitted_loop_body)(i)
  return x + i

print("jit called in a loop with partials:")
%timeit g_inner_jitted_partial(10, 20).block_until_ready()

print("jit called in a loop with lambdas:")
%timeit g_inner_jitted_lambda(10, 20).block_until_ready()

print("jit called in a loop with caching:")
%timeit g_inner_jitted_normal(10, 20).block_until_ready()
jit called in a loop with partials:
222 ms ± 3.29 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with lambdas:
224 ms ± 6.07 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with caching:
2.61 ms ± 21 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Automatic vectorization#

In the previous section we discussed JIT compilation via the jax.jit() function. This notebook discusses another of JAX’s transforms: vectorization via jax.vmap().

Manual vectorization#

Consider the following simple code that computes the convolution of two one-dimensional vectors:

import jax
import jax.numpy as jnp

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])

def convolve(x, w):
  output = []
  for i in range(1, len(x)-1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

convolve(x, w)
Array([11., 20., 29.], dtype=float32)

Suppose we would like to apply this function to a batch of weights ws and a batch of vectors xs.

xs = jnp.stack([x, x])
ws = jnp.stack([w, w])

The most naive option would be to simply loop over the batch in Python:

def manually_batched_convolve(xs, ws):
  output = []
  for i in range(xs.shape[0]):
    output.append(convolve(xs[i], ws[i]))
  return jnp.stack(output)

manually_batched_convolve(xs, ws)
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

This produces the correct result, however it is not very efficient.

In order to batch the computation efficiently, you would normally have to rewrite the function manually to ensure it is done in vectorized form. This is not particularly difficult to implement, but does involve changing how the function treats indices, axes, and other parts of the input.

For example, we could manually rewrite convolve() to support vectorized computation across the batch dimension as follows:

def manually_vectorized_convolve(xs, ws):
  output = []
  for i in range(1, xs.shape[-1] -1):
    output.append(jnp.sum(xs[:, i-1:i+2] * ws, axis=1))
  return jnp.stack(output, axis=1)

manually_vectorized_convolve(xs, ws)
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

Such re-implementation can be messy and error-prone as the complexity of a function increases; fortunately JAX provides another way.

Automatic vectorization#

In JAX, the jax.vmap() transformation is designed to generate such a vectorized implementation of a function automatically:

auto_batch_convolve = jax.vmap(convolve)

auto_batch_convolve(xs, ws)
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

It does this by tracing the function similarly to jax.jit(), and automatically adding batch axes at the beginning of each input.

If the batch dimension is not the first, you may use the in_axes and out_axes arguments to specify the location of the batch dimension in inputs and outputs. These may be an integer if the batch axis is the same for all inputs and outputs, or lists, otherwise.

auto_batch_convolve_v2 = jax.vmap(convolve, in_axes=1, out_axes=1)

xst = jnp.transpose(xs)
wst = jnp.transpose(ws)

auto_batch_convolve_v2(xst, wst)
Array([[11., 11.],
       [20., 20.],
       [29., 29.]], dtype=float32)

jax.vmap() also supports the case where only one of the arguments is batched: for example, if you would like to convolve a single set of weights w with a batch of vectors x; in this case the in_axes argument can be set to None:

batch_convolve_v3 = jax.vmap(convolve, in_axes=[0, None])

batch_convolve_v3(xs, w)
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

Combining transformations#

As with all JAX transformations, jax.jit() and jax.vmap() are designed to be composable, which means you can wrap a vmapped function with jit, or a jitted function with vmap, and everything will work correctly:

jitted_batch_convolve = jax.jit(auto_batch_convolve)

jitted_batch_convolve(xs, ws)
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

Automatic differentiation#

In this section, you will learn about fundamental applications of automatic differentiation (autodiff) in JAX. JAX has a pretty general autodiff system. Computing gradients is a critical part of modern machine learning methods, and this tutorial will walk you through a few introductory autodiff topics.

Make sure to also check out the Advanced automatic differentiation tutorial for more advanced topics.

While understanding how automatic differentiation works “under the hood” isn’t crucial for using JAX in most contexts, you are encouraged to check out this quite accessible video to get a deeper sense of what’s going on.

1. Taking gradients with jax.grad#

In JAX, you can differentiate a scalar-valued function with the jax.grad() transformation:

import jax
import jax.numpy as jnp
from jax import grad

grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0))
0.070650816

jax.grad() takes a function and returns a function. If you have a Python function f that evaluates the mathematical function \(f\), then jax.grad(f) is a Python function that evaluates the mathematical function \(\nabla f\). That means grad(f)(x) represents the value \(\nabla f(x)\).

Since jax.grad() operates on functions, you can apply it to its own output to differentiate as many times as you like:

print(grad(grad(jnp.tanh))(2.0))
print(grad(grad(grad(jnp.tanh)))(2.0))
-0.13621868
0.25265405

JAX’s autodiff makes it easy to compute higher-order derivatives, because the functions that compute derivatives are themselves differentiable. Thus, higher-order derivatives are as easy as stacking transformations. This can be illustrated in the single-variable case:

The derivative of \(f(x) = x^3 + 2x^2 - 3x + 1\) can be computed as:

f = lambda x: x**3 + 2*x**2 - 3*x + 1

dfdx = jax.grad(f)

The higher-order derivatives of \(f\) are:

\[\begin{split} \begin{array}{l} f'(x) = 3x^2 + 4x -3\\ f''(x) = 6x + 4\\ f'''(x) = 6\\ f^{iv}(x) = 0 \end{array} \end{split}\]

Computing any of these in JAX is as easy as chaining the jax.grad() function:

d2fdx = jax.grad(dfdx)
d3fdx = jax.grad(d2fdx)
d4fdx = jax.grad(d3fdx)

Evaluating the above at \(x=1\) would give you:

\[\begin{split} \begin{array}{l} f'(1) = 4\\ f''(1) = 10\\ f'''(1) = 6\\ f^{iv}(1) = 0 \end{array} \end{split}\]

Using JAX:

print(dfdx(1.))
print(d2fdx(1.))
print(d3fdx(1.))
print(d4fdx(1.))
4.0
10.0
6.0
0.0

2. Computing gradients in a linear logistic regression#

The next example shows how to compute gradients with jax.grad() in a linear logistic regression model. First, the setup:

key = jax.random.key(0)

def sigmoid(x):
  return 0.5 * (jnp.tanh(x / 2) + 1)

# Outputs probability of a label being true.
def predict(W, b, inputs):
  return sigmoid(jnp.dot(inputs, W) + b)

# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12,  0.77],
                    [0.88, -1.08, 0.15],
                    [0.52, 0.06, -1.30],
                    [0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])

# Training loss is the negative log-likelihood of the training examples.
def loss(W, b):
  preds = predict(W, b, inputs)
  label_probs = preds * targets + (1 - preds) * (1 - targets)
  return -jnp.sum(jnp.log(label_probs))

# Initialize random model coefficients
key, W_key, b_key = jax.random.split(key, 3)
W = jax.random.normal(W_key, (3,))
b = jax.random.normal(b_key, ())

Use the jax.grad() function with its argnums argument to differentiate a function with respect to positional arguments.

# Differentiate `loss` with respect to the first positional argument:
W_grad = grad(loss, argnums=0)(W, b)
print(f'{W_grad=}')

# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print(f'{W_grad=}')

# But you can choose different values too, and drop the keyword:
b_grad = grad(loss, 1)(W, b)
print(f'{b_grad=}')

# Including tuple values
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print(f'{W_grad=}')
print(f'{b_grad=}')
W_grad=Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32)
W_grad=Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32)
b_grad=Array(-0.29227245, dtype=float32)
W_grad=Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32)
b_grad=Array(-0.29227245, dtype=float32)

The jax.grad() API has a direct correspondence to the excellent notation in Spivak’s classic Calculus on Manifolds (1965), also used in Sussman and Wisdom’s Structure and Interpretation of Classical Mechanics (2015) and their Functional Differential Geometry (2013). Both books are open-access. See in particular the “Prologue” section of Functional Differential Geometry for a defense of this notation.

Essentially, when using the argnums argument, if f is a Python function for evaluating the mathematical function \(f\), then the Python expression jax.grad(f, i) evaluates to a Python function for evaluating \(\partial_i f\).

3. Differentiating with respect to nested lists, tuples, and dicts#

Due to JAX’s PyTree abstraction (see Working with pytrees), differentiating with respect to standard Python containers just works, so use tuples, lists, and dicts (and arbitrary nesting) however you like.

Continuing the previous example:

def loss2(params_dict):
    preds = predict(params_dict['W'], params_dict['b'], inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

print(grad(loss2)({'W': W, 'b': b}))
{'W': Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32), 'b': Array(-0.29227245, dtype=float32)}

You can create Custom pytree nodes to work with not just jax.grad() but other JAX transformations (jax.jit(), jax.vmap(), and so on).

4. Evaluating a function and its gradient using jax.value_and_grad#

Another convenient function is jax.value_and_grad() for efficiently computing both a function’s value and its gradient in a single pass.

Continuing the previous examples:

loss_value, Wb_grad = jax.value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)
print('loss value', loss(W, b))
loss value 3.0519385
loss value 3.0519385

5. Checking against numerical differences#

A great thing about derivatives is that they’re straightforward to check with finite differences.

Continuing the previous examples:

# Set a step size for finite differences calculations
eps = 1e-4

# Check b_grad with scalar finite differences
b_grad_numerical = (loss(W, b + eps / 2.) - loss(W, b - eps / 2.)) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))

# Check W_grad with finite differences in a random direction
key, subkey = jax.random.split(key)
vec = jax.random.normal(subkey, W.shape)
unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))
W_grad_numerical = (loss(W + eps / 2. * unitvec, b) - loss(W - eps / 2. * unitvec, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(W, b), unitvec))
b_grad_numerical -0.29325485
b_grad_autodiff -0.29227245
W_dirderiv_numerical -0.2002716
W_dirderiv_autodiff -0.19909117

JAX provides a simple convenience function that does essentially the same thing, but checks up to any order of differentiation that you like:

from jax.test_util import check_grads

check_grads(loss, (W, b), order=2)  # check up to 2nd order derivatives

Next steps#

The Advanced automatic differentiation tutorial provides more advanced and detailed explanations of how the ideas covered in this document are implemented in the JAX backend. Some features, such as Custom derivative rules for JAX-transformable Python functions, depend on understanding advanced automatic differentiation, so do check out that section in the Advanced automatic differentiation tutorial if you are interested.

Introduction to debugging#

This section introduces you to a set of built-in JAX debugging methods — jax.debug.print(), jax.debug.breakpoint(), and jax.debug.callback() — that you can use with various JAX transformations.

Let’s begin with jax.debug.print().

JAX debug.print for high-level debugging#

TL;DR Here is a rule of thumb: use jax.debug.print() for traced (dynamic) array values, and use Python’s print() for static values such as dtypes and array shapes.

Recall from Just-in-time compilation that when transforming a function with jax.jit(), the Python code is executed with abstract tracers in place of your arrays. Because of this, the Python print() function will only print this tracer value:

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  print("print(x) ->", x)
  y = jnp.sin(x)
  print("print(y) ->", y)
  return y

result = f(2.)
print(x) -> Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
print(y) -> Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>

Python’s print executes at trace-time, before the runtime values exist. If you want to print the actual runtime values, you can use jax.debug.print():

@jax.jit
def f(x):
  jax.debug.print("jax.debug.print(x) -> {x}", x=x)
  y = jnp.sin(x)
  jax.debug.print("jax.debug.print(y) -> {y}", y=y)
  return y

result = f(2.)
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.9092974066734314

Similarly, within jax.vmap(), using Python’s print will only print the tracer; to print the values being mapped over, use jax.debug.print():

def f(x):
  jax.debug.print("jax.debug.print(x) -> {}", x)
  y = jnp.sin(x)
  jax.debug.print("jax.debug.print(y) -> {}", y)
  return y

xs = jnp.arange(3.)

result = jax.vmap(f)(xs)
jax.debug.print(x) -> 0.0
jax.debug.print(x) -> 1.0
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.0
jax.debug.print(y) -> 0.8414709568023682
jax.debug.print(y) -> 0.9092974066734314

Here’s the result with jax.lax.map(), which is a sequential map rather than a vectorization:

result = jax.lax.map(f, xs)
jax.debug.print(x) -> 0.0
jax.debug.print(y) -> 0.0
jax.debug.print(x) -> 1.0
jax.debug.print(y) -> 0.8414709568023682
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.9092974066734314

Notice the order is different, as jax.vmap() and jax.lax.map() compute the same results in different ways. When debugging, the evaluation order details are exactly what you may need to inspect.

Below is an example with jax.grad(), where jax.debug.print() only prints the forward pass. In this case, the behavior is similar to Python’s print(), but it’s consistent if you apply jax.jit() during the call.

def f(x):
  jax.debug.print("jax.debug.print(x) -> {}", x)
  return x ** 2

result = jax.grad(f)(1.)
jax.debug.print(x) -> 1.0

Sometimes, when the arguments don’t depend on one another, calls to jax.debug.print() may print them in a different order when staged out with a JAX transformation. If you need the original order, such as x: ... first and then y: ... second, add the ordered=True parameter.

For example:

@jax.jit
def f(x, y):
  jax.debug.print("jax.debug.print(x) -> {}", x, ordered=True)
  jax.debug.print("jax.debug.print(y) -> {}", y, ordered=True)
  return x + y

f(1, 2)
jax.debug.print(x) -> 1
jax.debug.print(y) -> 2
Array(3, dtype=int32, weak_type=True)

To learn more about jax.debug.print() and its Sharp Bits, refer to Advanced debugging.

JAX debug.breakpoint for pdb-like debugging#

TL;DR Use jax.debug.breakpoint() to pause the execution of your JAX program to inspect values.

To pause your compiled JAX program at certain points during debugging, you can use jax.debug.breakpoint(). The prompt is similar to Python pdb, and it allows you to inspect the values in the call stack. In fact, jax.debug.breakpoint() is an application of jax.debug.callback() that captures information about the call stack.

To print all available commands during a breakpoint debugging session, use the help command. (Full debugger commands, the Sharp Bits, its strengths and limitations are covered in Advanced debugging.)

Here is an example of what a debugger session might look like:

@jax.jit
def f(x):
  y, z = jnp.sin(x), jnp.cos(x)
  jax.debug.breakpoint()
  return y * z
f(2.) # ==> Pauses during execution

[screenshot: an interactive JAX debugger session]

For value-dependent breakpointing, you can use runtime conditionals like jax.lax.cond():

def breakpoint_if_nonfinite(x):
  is_finite = jnp.isfinite(x).all()
  def true_fn(x):
    pass
  def false_fn(x):
    jax.debug.breakpoint()
  jax.lax.cond(is_finite, true_fn, false_fn, x)

@jax.jit
def f(x, y):
  z = x / y
  breakpoint_if_nonfinite(z)
  return z

f(2., 1.) # ==> No breakpoint
Array(2., dtype=float32, weak_type=True)
f(2., 0.) # ==> Pauses during execution

JAX debug.callback for more control during debugging#

Both jax.debug.print() and jax.debug.breakpoint() are implemented using the more flexible jax.debug.callback(), which gives greater control over the host-side logic executed via a Python callback. It is compatible with jax.jit(), jax.vmap(), jax.grad() and other transformations (refer to the Flavors of callback table in External callbacks for more information).

For example:

import logging

def log_value(x):
  logging.warning(f'Logged value: {x}')

@jax.jit
def f(x):
  jax.debug.callback(log_value, x)
  return x

f(1.0);
WARNING:root:Logged value: 1.0

This callback is compatible with other transformations, including jax.vmap() and jax.grad():

x = jnp.arange(5.0)
jax.vmap(f)(x);
WARNING:root:Logged value: 0.0
WARNING:root:Logged value: 1.0
WARNING:root:Logged value: 2.0
WARNING:root:Logged value: 3.0
WARNING:root:Logged value: 4.0
jax.grad(f)(1.0);
WARNING:root:Logged value: 1.0

This can make jax.debug.callback() useful for general-purpose debugging.

You can learn more about jax.debug.callback() and other kinds of JAX callbacks in External callbacks.

Next steps#

Check out the Advanced debugging to learn more about debugging in JAX.

Pseudorandom numbers#

In this section we focus on jax.random and pseudo random number generation (PRNG); that is, the process of algorithmically generating sequences of numbers whose properties approximate the properties of sequences of random numbers sampled from an appropriate distribution.

PRNG-generated sequences are not truly random: they are fully determined by an initial value, typically referred to as the seed, and each step of random sampling is a deterministic function of some state that is carried over from one sample to the next.

Pseudo random number generation is an essential component of any machine learning or scientific computing framework. Generally, JAX strives to be compatible with NumPy, but pseudo random number generation is a notable exception.

To better understand the difference between the approaches taken by JAX and NumPy when it comes to random number generation, we will discuss both approaches in this section.

Random numbers in NumPy#

Pseudo random number generation is natively supported in NumPy by the numpy.random module. In NumPy, pseudo random number generation is based on a global state, which can be set to a deterministic initial condition using numpy.random.seed().

import numpy as np
np.random.seed(0)

You can inspect the content of the state using the following command.

def print_truncated_random_state():
  """To avoid spamming the outputs, print only part of the state."""
  full_random_state = np.random.get_state()
  print(str(full_random_state)[:460], '...')

print_truncated_random_state()
('MT19937', array([         0,          1, 1812433255, 1900727105, 1208447044,
       2481403966, 4042607538,  337614300, 3232553940, 1018809052,
       3202401494, 1775180719, 3192392114,  594215549,  184016991,
        829906058,  610491522, 3879932251, 3139825610,  297902587,
       4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
       2891506774, 1066338622,  135451537,  933040465, 2759011858,
       2273819758, 3545703099, 2516396728, 127 ...

The state is updated by each call to a random function:

np.random.seed(0)
print_truncated_random_state()
('MT19937', array([         0,          1, 1812433255, 1900727105, 1208447044,
       2481403966, 4042607538,  337614300, 3232553940, 1018809052,
       3202401494, 1775180719, 3192392114,  594215549,  184016991,
        829906058,  610491522, 3879932251, 3139825610,  297902587,
       4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
       2891506774, 1066338622,  135451537,  933040465, 2759011858,
       2273819758, 3545703099, 2516396728, 127 ...
_ = np.random.uniform()
print_truncated_random_state()
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
       3904844661,  676747479, 2085143622, 1056793272, 3812477442,
       2168787041,  275552121, 2696932952, 3432054210, 1657102335,
       3518946594,  962584079, 1051271004, 3806145045, 1414436097,
       2032348584, 1661738718, 1116708477, 2562755208, 3176189976,
        696824676, 2399811678, 3992505346,  569184356, 2626558620,
        136797809, 4273176064,  296167901, 343 ...

NumPy allows you to sample either individual numbers or entire vectors of numbers in a single function call. For instance, you may sample a vector of 3 scalars from a uniform distribution by doing:

np.random.seed(0)
print(np.random.uniform(size=3))
[0.5488135  0.71518937 0.60276338]

NumPy provides a sequential equivalent guarantee, meaning that sampling N numbers in a row individually or sampling a vector of N numbers results in the same pseudo-random sequences:

np.random.seed(0)
print("individually:", np.stack([np.random.uniform() for _ in range(3)]))

np.random.seed(0)
print("all at once: ", np.random.uniform(size=3))
individually: [0.5488135  0.71518937 0.60276338]
all at once:  [0.5488135  0.71518937 0.60276338]

Random numbers in JAX#

JAX’s random number generation differs from NumPy’s in important ways, because NumPy’s PRNG design makes it hard to simultaneously guarantee a number of desirable properties. Specifically, in JAX we want PRNG generation to be:

  1. reproducible,

  2. parallelizable,

  3. vectorizable.

We will discuss why below. First, we will focus on the implications of a PRNG design based on a global state. Consider the code:

import numpy as np

np.random.seed(0)

def bar(): return np.random.uniform()
def baz(): return np.random.uniform()

def foo(): return bar() + 2 * baz()

print(foo())
1.9791922366721637

The function foo sums two scalars sampled from a uniform distribution.

The output of this code can only satisfy requirement #1 if we assume a predictable order of execution for bar() and baz(). This is not a problem in NumPy, which always evaluates code in the order defined by the Python interpreter. In JAX, however, this is more problematic: for efficient execution, we want the JIT compiler to be free to reorder, elide, and fuse various operations in the function we define. Further, when executing in multi-device environments, execution efficiency would be hampered by the need for each process to synchronize a global state.

Explicit random state#

To avoid this issue, JAX avoids implicit global random state, and instead tracks state explicitly via a random key:

from jax import random

key = random.key(42)
print(key)
Array((), dtype=key<fry>) overlaying:
[ 0 42]

Note

This section uses the new-style typed PRNG keys produced by jax.random.key(), rather than the old-style raw PRNG keys produced by jax.random.PRNGKey(). For details, see JEP 9263: Typed keys & pluggable RNGs.

A key is an array with a special dtype corresponding to the particular PRNG implementation being used; in the default implementation each key is backed by a pair of uint32 values.
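
If you want to inspect that underlying data (a quick sketch; rarely needed in practice), random.key_data() exposes the raw values behind a typed key:

raw = random.key_data(key)
print(raw)        # [ 0 42]
print(raw.dtype)  # uint32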

The key is effectively a stand-in for NumPy’s hidden state object, but we pass it explicitly to jax.random functions. Importantly, random functions consume the key, but do not modify it: feeding the same key object to a random function will always result in the same sample being generated.

print(random.normal(key))
print(random.normal(key))
-0.18471177
-0.18471177

Re-using the same key, even with different random APIs, can result in correlated outputs, which is generally undesirable.

The rule of thumb is: never reuse keys (unless you want identical outputs).

In order to generate different and independent samples, you must split() the key explicitly before passing it to a random function:

for i in range(3):
  new_key, subkey = random.split(key)
  del key  # The old key is consumed by split() -- we must never use it again.

  val = random.normal(subkey)
  del subkey  # The subkey is consumed by normal().

  print(f"draw {i}: {val}")
  key = new_key  # new_key is safe to use in the next iteration.
draw 0: 1.369469404220581
draw 1: -0.19947023689746857
draw 2: -2.298278331756592

(Calling del here is not required, but we do so to emphasize that the key should not be reused once consumed.)

jax.random.split() is a deterministic function that converts one key into several independent (in the pseudorandomness sense) keys. We keep one of the outputs as the new_key, and can safely use the unique extra key (called subkey) as input into a random function, and then discard it forever. If you wanted to get another sample from the normal distribution, you would split key again, and so on: the crucial point is that you never use the same key twice.

It doesn’t matter which part of the output of split(key) we call key, and which we call subkey. They are all independent keys with equal status. The key/subkey naming convention is a typical usage pattern that helps track how keys are consumed: subkeys are destined for immediate consumption by random functions, while the key is retained to generate more randomness later.

Usually, the above example would be written concisely as

key, subkey = random.split(key)

which discards the old key automatically. It’s worth noting that split() can create as many keys as you need, not just 2:

key, *forty_two_subkeys = random.split(key, num=43)

Lack of sequential equivalence#

Another difference between NumPy’s and JAX’s random modules relates to the sequential equivalence guarantee mentioned above.

As in NumPy, JAX’s random module also allows sampling of vectors of numbers. However, JAX does not provide a sequential equivalence guarantee, because doing so would interfere with the vectorization on SIMD hardware (requirement #3 above).

In the example below, sampling 3 values out of a normal distribution individually using three subkeys gives a different result than using a single key and specifying shape=(3,):

key = random.key(42)
subkeys = random.split(key, 3)
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
print("individually:", sequence)

key = random.key(42)
print("all at once: ", random.normal(key, shape=(3,)))
individually: [-0.04838832  0.10796154 -1.2226542 ]
all at once:  [ 0.18693547 -1.2806505  -1.5593132 ]

The lack of sequential equivalence gives us freedom to write code more efficiently; for example, instead of generating sequence above via a sequential loop, we can use jax.vmap() to compute the same result in a vectorized manner:

import jax
print("vectorized:", jax.vmap(random.normal)(subkeys))
vectorized: [-0.04838832  0.10796154 -1.2226542 ]

Next Steps#

For more information on JAX random numbers, refer to the documentation of the jax.random module. If you’re interested in the details of the design of JAX’s random number generator, see JAX PRNG Design.

Working with pytrees#

JAX has built-in support for objects that look like dictionaries (dicts) of arrays, or lists of lists of dicts, or other nested structures — in JAX these are called pytrees. This section will explain how to use them, provide useful code examples, and point out common “gotchas” and patterns.

What is a pytree?#

A pytree is a container-like structure built out of container-like Python objects — “leaf” pytrees and/or more pytrees. A pytree can include lists, tuples, and dicts. A leaf is anything that’s not a pytree, such as an array, but a single leaf is also a pytree.

In the context of machine learning (ML), a pytree can contain:

  • Model parameters

  • Dataset entries

  • Reinforcement learning agent observations

When working with datasets, you can often come across pytrees (such as lists of lists of dicts).

Below is an example of a simple pytree. In JAX, you can use jax.tree.leaves() to extract the flattened leaves from the trees, as demonstrated here:

import jax
import jax.numpy as jnp

example_trees = [
    [1, 'a', object()],
    (1, (2, 3), ()),
    [1, {'k1': 2, 'k2': (3, 4)}, 5],
    {'a': 2, 'b': (2, 3)},
    jnp.array([1, 2, 3]),
]

# Print how many leaves the pytrees have.
for pytree in example_trees:
  # This `jax.tree.leaves()` method extracts the flattened leaves from the pytrees.
  leaves = jax.tree.leaves(pytree)
  print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")
[1, 'a', <object object at 0x7f62b3463730>]   has 3 leaves: [1, 'a', <object object at 0x7f62b3463730>]
(1, (2, 3), ())                               has 3 leaves: [1, 2, 3]
[1, {'k1': 2, 'k2': (3, 4)}, 5]               has 5 leaves: [1, 2, 3, 4, 5]
{'a': 2, 'b': (2, 3)}                         has 3 leaves: [2, 2, 3]
Array([1, 2, 3], dtype=int32)                 has 1 leaves: [Array([1, 2, 3], dtype=int32)]

Any tree-like structure built out of container-like Python objects can be treated as a pytree in JAX. Classes are considered container-like if they are in the pytree registry, which by default includes lists, tuples, and dicts. Any object whose type is not in the pytree container registry will be treated as a leaf node in the tree.

The pytree registry can be extended to include user-defined container classes by registering the class with functions that specify how to flatten the tree; see Custom pytree nodes below.

Common pytree functions#

JAX provides a number of utilities to operate over pytrees. These can be found in the jax.tree_util subpackage; for convenience many of these have aliases in the jax.tree module.
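
For example, a flatten/unflatten round trip using the jax.tree aliases looks like this (a short sketch; the treedef repr may differ between JAX versions):

pytree = {'a': 1, 'b': (2, 3)}

leaves, treedef = jax.tree.flatten(pytree)
print(leaves)                               # [1, 2, 3]
print(treedef)                              # PyTreeDef({'a': *, 'b': (*, *)})
print(jax.tree.unflatten(treedef, leaves))  # {'a': 1, 'b': (2, 3)}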

Common function: jax.tree.map#

The most commonly used pytree function is jax.tree.map(). It works analogously to Python’s native map, but transparently operates over entire pytrees.

Here’s an example:

list_of_lists = [
    [1, 2, 3],
    [1, 2],
    [1, 2, 3, 4]
]

jax.tree.map(lambda x: x*2, list_of_lists)
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]

jax.tree.map() also allows mapping an N-ary function over multiple arguments. For example:

another_list_of_lists = list_of_lists
jax.tree.map(lambda x, y: x+y, list_of_lists, another_list_of_lists)
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]

When using multiple arguments with jax.tree.map(), the structure of the inputs must exactly match. That is, lists must have the same number of elements, dicts must have the same keys, etc.
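
For example, mismatched list lengths raise an error rather than broadcasting (a small sketch; the exact message may vary across JAX versions):

try:
  jax.tree.map(lambda x, y: x + y, [1, 2, 3], [4, 5])
except ValueError as e:
  print("structure mismatch:", e)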

Example of jax.tree.map with ML model parameters#

This example demonstrates how pytree operations can be useful when training a simple multi-layer perceptron (MLP).

Begin with defining the initial model parameters:

import numpy as np

def init_mlp_params(layer_widths):
  params = []
  for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
    params.append(
        dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),
             biases=np.ones(shape=(n_out,))
            )
    )
  return params

params = init_mlp_params([1, 128, 128, 1])

Use jax.tree.map() to check the shapes of the initial parameters:

jax.tree.map(lambda x: x.shape, params)
[{'biases': (128,), 'weights': (1, 128)},
 {'biases': (128,), 'weights': (128, 128)},
 {'biases': (1,), 'weights': (128, 1)}]

Next, define the functions for training the MLP model:

# Define the forward pass.
def forward(params, x):
  *hidden, last = params
  for layer in hidden:
    x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
  return x @ last['weights'] + last['biases']

# Define the loss function.
def loss_fn(params, x, y):
  return jnp.mean((forward(params, x) - y) ** 2)

# Set the learning rate.
LEARNING_RATE = 0.0001

# Using stochastic gradient descent, define the parameter update function.
# Apply `@jax.jit` for JIT compilation (speed).
@jax.jit
def update(params, x, y):
  # Calculate the gradients with `jax.grad`.
  grads = jax.grad(loss_fn)(params, x, y)
  # Note that `grads` is a pytree with the same structure as `params`.
  # `jax.grad` is one of many JAX functions that has
  # built-in support for pytrees.
  # This is useful - you can apply the SGD update using JAX pytree utilities.
  return jax.tree.map(
      lambda p, g: p - LEARNING_RATE * g, params, grads
  )
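
A brief usage sketch (the toy data below, xs and ys, is made up for illustration and is not part of the original example):

# Hypothetical 1-D regression data, for illustration only.
xs = np.random.normal(size=(128, 1))
ys = xs ** 2

for _ in range(100):
  params = update(params, xs, ys)

print(loss_fn(params, xs, ys))  # The loss should (typically) decrease over the iterations.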

Custom pytree nodes#

This section explains how in JAX you can extend the set of Python types that will be considered internal nodes in pytrees (pytree nodes) by using jax.tree_util.register_pytree_node() with jax.tree.map().

Why would you need this? In the previous examples, pytrees were shown as lists, tuples, and dicts, with everything else as pytree leaves. This is because if you define your own container class, it will be considered to be a pytree leaf unless you register it with JAX. This is also the case even if your container class has trees inside it. For example:

class Special(object):
  def __init__(self, x, y):
    self.x = x
    self.y = y

jax.tree.leaves([
    Special(0, 1),
    Special(2, 4),
])
[<__main__.Special at 0x7f62b3d55060>, <__main__.Special at 0x7f62b3d57100>]

Accordingly, if you try to use jax.tree.map() expecting the leaves to be the elements inside the container, you will get an error:

jax.tree.map(lambda x: x + 1,
  [
    Special(0, 1),
    Special(2, 4)
  ])
TypeError: unsupported operand type(s) for +: 'Special' and 'int'

As a solution, JAX allows you to extend the set of types considered internal pytree nodes through a global registry of types. Additionally, the values of registered types are traversed recursively.

First, register a new type using jax.tree_util.register_pytree_node():

from jax.tree_util import register_pytree_node

class RegisteredSpecial(Special):
  def __repr__(self):
    return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)

def special_flatten(v):
  """Specifies a flattening recipe.

  Params:
    v: The value of the registered type to flatten.
  Returns:
    A pair of an iterable with the children to be flattened recursively,
    and some opaque auxiliary data to pass back to the unflattening recipe.
    The auxiliary data is stored in the treedef for use during unflattening.
    The auxiliary data could be used, for example, for dictionary keys.
  """
  children = (v.x, v.y)
  aux_data = None
  return (children, aux_data)

def special_unflatten(aux_data, children):
  """Specifies an unflattening recipe.

  Params:
    aux_data: The opaque data that was specified during flattening of the
      current tree definition.
    children: The unflattened children

  Returns:
    A reconstructed object of the registered type, using the specified
    children and auxiliary data.
  """
  return RegisteredSpecial(*children)

# Global registration
register_pytree_node(
    RegisteredSpecial,
    special_flatten,    # Tell JAX which values are the children nodes.
    special_unflatten   # Tell JAX how to pack them back into a `RegisteredSpecial`.
)

Now you can traverse the special container structure:

jax.tree.map(lambda x: x + 1,
  [
   RegisteredSpecial(0, 1),
   RegisteredSpecial(2, 4),
  ])
[RegisteredSpecial(x=1, y=2), RegisteredSpecial(x=3, y=5)]

Modern Python comes equipped with helpful tools to make defining containers easier. Some will work with JAX out-of-the-box, but others require more care.

For instance, a Python NamedTuple subclass doesn’t need to be registered to be considered a pytree node type:

from typing import NamedTuple, Any

class MyOtherContainer(NamedTuple):
  name: str
  a: Any
  b: Any
  c: Any

# NamedTuple subclasses are handled as pytree nodes, so
# this will work out-of-the-box.
jax.tree.leaves([
    MyOtherContainer('Alice', 1, 2, 3),
    MyOtherContainer('Bob', 4, 5, 6)
])
['Alice', 1, 2, 3, 'Bob', 4, 5, 6]

Notice that the name field now appears as a leaf, because all tuple elements are children. This is the trade-off of not registering the class explicitly: you get pytree behavior for free, but every field is treated as a child leaf.

Pytrees and JAX transformations#

Many JAX functions, like jax.lax.scan(), operate over pytrees of arrays. In addition, all JAX function transformations can be applied to functions that accept as input and produce as output pytrees of arrays.
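
For example (a small sketch), jax.jit() happily accepts a dict of arrays as input and returns one as output:

@jax.jit
def scale_tree(params, factor):
  return jax.tree.map(lambda p: factor * p, params)

scale_tree({'w': jnp.ones(2), 'b': jnp.zeros(2)}, 3.0)
# Returns a dict with 'w' scaled to [3., 3.] and 'b' still [0., 0.].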

Some JAX function transformations take optional parameters that specify how certain input or output values should be treated (such as the in_axes and out_axes arguments to jax.vmap()). These parameters can also be pytrees, and their structure must correspond to the pytree structure of the corresponding arguments. In particular, to be able to “match up” leaves in these parameter pytrees with values in the argument pytrees, the parameter pytrees are often constrained to be tree prefixes of the argument pytrees.

For example, if you pass the following input to jax.vmap() (note that the input arguments to a function are considered a tuple):

vmap(f, in_axes=(a1, {"k1": a2, "k2": a3}))

then you can use the following in_axes pytree to specify that only the k2 argument is mapped (axis=0), and the rest aren’t mapped over (axis=None):

vmap(f, in_axes=(None, {"k1": None, "k2": 0}))

The optional parameter pytree structure must match that of the main input pytree. However, the optional parameters can also be specified as a “prefix” pytree, meaning that a single leaf value can be applied to an entire sub-pytree.

For example, if you have the same jax.vmap() input as above, but wish to only map over the dictionary argument, you can use:

vmap(f, in_axes=(None, 0))  # equivalent to (None, {"k1": 0, "k2": 0})

Alternatively, if you want every argument to be mapped, you can write a single leaf value that is applied over the entire argument tuple pytree:

vmap(f, in_axes=0)  # equivalent to (0, {"k1": 0, "k2": 0})

This happens to be the default in_axes value for jax.vmap().

The same logic applies to other optional parameters that refer to specific input or output values of a transformed function, such as out_axes in jax.vmap().
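
Here is a small runnable sketch of the prefix behavior described above (the function g and its arguments are illustrative):

def g(scale, data):
  return scale * (data["k1"] + data["k2"])

scale = 2.0                        # not mapped
data = {"k1": jnp.arange(3.0),     # mapped along axis 0
        "k2": jnp.ones(3)}         # mapped along axis 0

# (None, 0) is a prefix pytree: the single 0 applies to every leaf of the dict.
print(jax.vmap(g, in_axes=(None, 0))(scale, data))  # [2. 4. 6.]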

Explicit key paths#

In a pytree, each leaf has a key path. A key path for a leaf is a list of keys, where the length of the list is equal to the depth of the leaf in the pytree. Each key is a hashable object that represents an index into the corresponding pytree node type. The type of the key depends on the pytree node type; for example, the type of keys for dicts is different from the type of keys for tuples.

For built-in pytree node types, the set of keys for any pytree node instance is unique. For a pytree comprising nodes with this property, the key path for each leaf is unique.

JAX has the following jax.tree_util.* methods for working with key paths:

  • jax.tree_util.tree_flatten_with_path(): Works similarly to jax.tree.flatten(), but returns the key path of each leaf alongside its value.

  • jax.tree_util.tree_map_with_path(): Works similarly to jax.tree.map(), but the mapped function also receives each leaf’s key path.

  • jax.tree_util.keystr(): Given a key path, returns a reader-friendly string expression.

For example, one use case is to print debugging information related to a certain leaf value:

import collections

ATuple = collections.namedtuple("ATuple", ('name'))

tree = [1, {'k1': 2, 'k2': (3, 4)}, ATuple('foo')]
flattened, _ = jax.tree_util.tree_flatten_with_path(tree)

for key_path, value in flattened:
  print(f'Value of tree{jax.tree_util.keystr(key_path)}: {value}')
Value of tree[0]: 1
Value of tree[1]['k1']: 2
Value of tree[1]['k2'][0]: 3
Value of tree[1]['k2'][1]: 4
Value of tree[2].name: foo

To express key paths, JAX provides a few default key types for the built-in pytree node types, namely:

  • SequenceKey(idx: int): For lists and tuples.

  • DictKey(key: Hashable): For dictionaries.

  • GetAttrKey(name: str): For namedtuples and preferably custom pytree nodes (more in the next section)

You are free to define your own key types for your custom nodes. They will work with jax.tree_util.keystr() as long as their __str__() method is also overridden with a reader-friendly expression.

for key_path, _ in flattened:
  print(f'Key path of tree{jax.tree_util.keystr(key_path)}: {repr(key_path)}')
Key path of tree[0]: (SequenceKey(idx=0),)
Key path of tree[1]['k1']: (SequenceKey(idx=1), DictKey(key='k1'))
Key path of tree[1]['k2'][0]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=0))
Key path of tree[1]['k2'][1]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=1))
Key path of tree[2].name: (SequenceKey(idx=2), GetAttrKey(name='name'))
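
As a small sketch of the custom key types mentioned above (MyAttrKey is a hypothetical name), any hashable class with a reader-friendly __str__() works with jax.tree_util.keystr():

from dataclasses import dataclass

@dataclass(frozen=True)
class MyAttrKey:
  name: str
  def __str__(self):
    return f'.{self.name}'

fake_path = (jax.tree_util.SequenceKey(0), MyAttrKey('foo'))
print(jax.tree_util.keystr(fake_path))  # [0].foo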

Common pytree gotchas#

This section covers some of the most common problems (“gotchas”) encountered when using JAX pytrees.

Mistaking pytree nodes for leaves#

A common gotcha to look out for is accidentally introducing tree nodes instead of leaves:

a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]

# Try to make another pytree with ones instead of zeros.
shapes = jax.tree.map(lambda x: x.shape, a_tree)
jax.tree.map(jnp.ones, shapes)
[(Array([1., 1.], dtype=float32), Array([1., 1., 1.], dtype=float32)),
 (Array([1., 1., 1.], dtype=float32), Array([1., 1., 1., 1.], dtype=float32))]

What happened here is that the shape of an array is a tuple, which is a pytree node, with its elements as leaves. Thus, in the map, instead of calling jnp.ones on e.g. (2, 3), it’s called on 2 and 3.

The solution will depend on the specifics, but there are two broadly applicable options:

  • Rewrite the code to avoid the intermediate jax.tree.map(); a short sketch of this option appears after this list.

  • Convert the tuple into a NumPy array (np.array) or a JAX NumPy array (jnp.array), which makes the entire sequence a leaf.
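
Here is a short sketch of the first option: build the new arrays directly from the original tree, so each shape stays a plain tuple argument rather than becoming a pytree of its own.

# `a_tree` is the same list of arrays defined above.
jax.tree.map(lambda x: jnp.ones(x.shape), a_tree)
# [Array of shape (2, 3), Array of shape (3, 4)], both filled with ones.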

Handling of None by jax.tree_util#

jax.tree_util functions treat None as the absence of a pytree node, not as a leaf:

jax.tree.leaves([None, None, None])
[]

To treat None as a leaf, you can use the is_leaf argument:

jax.tree.leaves([None, None, None], is_leaf=lambda x: x is None)
[None, None, None]
Custom pytrees and initialization with unexpected values#

Another common gotcha with user-defined pytree objects is that JAX transformations occasionally initialize them with unexpected values, so that any input validation done at initialization may fail. For example:

class MyTree:
  def __init__(self, a):
    self.a = jnp.asarray(a)

register_pytree_node(MyTree, lambda tree: ((tree.a,), None),
    lambda _, args: MyTree(*args))

tree = MyTree(jnp.arange(5.0))

jax.vmap(lambda x: x)(tree)      # Error because object() is passed to `MyTree`.
TypeError: Cannot interpret '<object object at 0x7f62b3463df0>' as a data type


The above exception was the direct cause of the following exception:

TypeError: Cannot determine dtype of <object object at 0x7f62b3463df0>


During handling of the above exception, another exception occurred:

TypeError: Value '<object object at 0x7f62b3463df0>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
jax.jacobian(lambda x: x)(tree)  # Error because MyTree(...) is passed to `MyTree`.
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:2662: FutureWarning: None encountered in jnp.array(); this is currently treated as NaN. In the future this will result in an error.
  return array(a, dtype=dtype, copy=bool(copy), order=order)  # type: ignore
TypeError: Cannot interpret '<object object at 0x7f629c32c360>' as a data type


The above exception was the direct cause of the following exception:

TypeError: Cannot determine dtype of <object object at 0x7f629c32c360>


During handling of the above exception, another exception occurred:

TypeError: Value '<object object at 0x7f629c32c360>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
  • In the first case with jax.vmap(...)(tree), JAX’s internals use arrays of object() values to infer the structure of the tree

  • In the second case with jax.jacobian(...)(tree), the Jacobian of a function mapping a tree to a tree is defined as a tree of trees.

Potential solution 1:

  • The __init__ and __new__ methods of custom pytree classes should generally avoid doing any array conversion or other input validation, or else anticipate and handle these special cases. For example:

class MyTree:
  def __init__(self, a):
    if not (type(a) is object or a is None or isinstance(a, MyTree)):
      a = jnp.asarray(a)
    self.a = a

Potential solution 2:

  • Structure your custom tree_unflatten function so that it avoids calling __init__. If you choose this route, make sure that your tree_unflatten function stays in sync with __init__ if and when the code is updated. Example:

def tree_unflatten(aux_data, children):
  del aux_data  # Unused in this class.
  obj = object.__new__(MyTree)
  obj.a, = children  # Assign the single child directly, without calling __init__.
  return obj

Common pytree patterns#

This section covers some of the most common patterns with JAX pytrees.

Transposing pytrees with jax.tree.map and jax.tree.transpose#

To transpose a pytree (turn a list of trees into a tree of lists), JAX has two functions: jax.tree.map() (more basic) and jax.tree.transpose() (more flexible, complex and verbose).

Option 1: Use jax.tree.map(). Here’s an example:

def tree_transpose(list_of_trees):
  """
  Converts a list of trees of identical structure into a single tree of lists.
  """
  return jax.tree.map(lambda *xs: list(xs), *list_of_trees)

# Convert a dataset from row-major to column-major.
episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]
tree_transpose(episode_steps)
{'obs': [3, 4], 't': [1, 2]}

Option 2: For more complex transposes, use jax.tree.transpose(), which is more verbose, but allows you to specify the structure of the inner and outer pytrees for more flexibility. For example:

jax.tree.transpose(
  outer_treedef = jax.tree.structure([0 for e in episode_steps]),
  inner_treedef = jax.tree.structure(episode_steps[0]),
  pytree_to_transpose = episode_steps
)
{'obs': [3, 4], 't': [1, 2]}

Introduction to sharded computation#

JAX’s jax.Array object is designed with distributed data and computation in mind.

This section will cover three modes of parallel computation:

  • Automatic parallelism via jit, where the compiler decides how to distribute the computation across devices.

  • Semi-automated sharding with constraints, where jax.lax.with_sharding_constraint() guides the compiler’s choices.

  • Manual parallelism with shard_map, where you write the code that runs on each individual shard.

These examples will be run on Colab’s free TPU runtime, which provides eight devices to work with:

import jax
jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

Key concept: data sharding#

Key to all of the distributed computation approaches below is the concept of data sharding, which describes how data is laid out on the available devices.

Each concrete jax.Array object has a sharding attribute and a devices() method that can give you insight into how the underlying data are stored. In the simplest cases, arrays are sharded on a single device:

import jax.numpy as jnp
arr = jnp.arange(32.0).reshape(4, 8)
arr.devices()
{TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}
arr.sharding
SingleDeviceSharding(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0))

For a more visual representation of the storage layout, the jax.debug module provides some helpers to visualize the sharding of an array:

jax.debug.visualize_array_sharding(arr)

(visualization: a single block covering the whole array, labeled TPU 0)

To create an array with a non-trivial sharding, we can define a sharding specification for the array and pass this to jax.device_put(). Here we’ll define a NamedSharding, which specifies an N-dimensional grid of devices with named axes:

# Pardon the boilerplate; constructing a sharding will become easier soon!
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils

devices = mesh_utils.create_device_mesh((2, 4))
mesh = Mesh(devices, ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
print(sharding)
NamedSharding(mesh=Mesh('x': 2, 'y': 4), spec=PartitionSpec('x', 'y'))

Passing this sharding to jax.device_put(), we obtain a sharded array:

arr_sharded = jax.device_put(arr, sharding)

print(arr_sharded)
jax.debug.visualize_array_sharding(arr_sharded)
[[ 0.  1.  2.  3.  4.  5.  6.  7.]
 [ 8.  9. 10. 11. 12. 13. 14. 15.]
 [16. 17. 18. 19. 20. 21. 22. 23.]
 [24. 25. 26. 27. 28. 29. 30. 31.]]
(visualization: the array is split into a 2x4 grid of blocks, laid out as)
   TPU 0       TPU 1       TPU 2       TPU 3
   TPU 6       TPU 7       TPU 4       TPU 5

The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the devices.

Automatic parallelism via jit#

Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a JIT-compiled function! The XLA compiler behind jit includes heuristics for optimizing computations across multiple devices. In the simplest of cases, those heuristics boil down to computation follows data.

For example, here’s a simple element-wise function: the computation for each shard will be performed on the device associated with that shard, and the output is sharded in the same way:

@jax.jit
def f_elementwise(x):
  return 2 * jnp.sin(x) + 1

result = f_elementwise(arr_sharded)

print("shardings match:", result.sharding == arr_sharded.sharding)
shardings match: True

As computations get more complex, the compiler makes decisions about how to best propagate the sharding of the data. Here we sum along the leading axis of x:

@jax.jit
def f_contract(x):
  return x.sum(axis=0)

result = f_contract(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)
 TPU 0,6  TPU 1,7  TPU 2,4  TPU 3,5 
                                    
[48. 52. 56. 60. 64. 68. 72. 76.]

The result is partially replicated: that is, the first two elements of the array are replicated on devices 0 and 6, the second two on devices 1 and 7, and so on.

Semi-automated sharding with constraints#

If you’d like to have some control over the sharding used within a particular computation, JAX offers the with_sharding_constraint() function.

For example, suppose that within f_contract above, you’d prefer the output not to be partially-replicated, but rather to be fully sharded across the eight devices:

@jax.jit
def f_contract_2(x):
  out = x.sum(axis=0)
  devices = mesh_utils.create_device_mesh(8)
  mesh = jax.sharding.Mesh(devices, 'x')
  sharding = jax.sharding.NamedSharding(mesh, P('x'))
  return jax.lax.with_sharding_constraint(out, sharding)

result = f_contract_2(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)
  TPU 0    TPU 1    TPU 2    TPU 3    TPU 6    TPU 7    TPU 4    TPU 5  
                                                                        
[48. 52. 56. 60. 64. 68. 72. 76.]

This gives you a function with the particular output sharding you’d like.

Manual parallelism with shard_map#

In the automatic parallelism methods explored above, you can write a function as if you’re operating on the full dataset, and jit will split that computation across multiple devices. By contrast, with shard_map you write the function that will handle a single shard of data, and shard_map will construct the full function.

shard_map works by mapping a function across a particular mesh of devices:

from jax.experimental.shard_map import shard_map
P = jax.sharding.PartitionSpec
mesh = jax.sharding.Mesh(jax.devices(), 'x')

f_elementwise_sharded = shard_map(
    f_elementwise,
    mesh=mesh,
    in_specs=P('x'),
    out_specs=P('x'))

arr = jnp.arange(32)
f_elementwise_sharded(arr)
Array([ 1.        ,  2.682942  ,  2.818595  ,  1.28224   , -0.513605  ,
       -0.9178486 ,  0.44116896,  2.3139732 ,  2.9787164 ,  1.824237  ,
       -0.08804226, -0.99998045, -0.07314599,  1.8403342 ,  2.9812148 ,
        2.3005757 ,  0.42419332, -0.92279506, -0.50197446,  1.2997544 ,
        2.8258905 ,  2.6733112 ,  0.98229736, -0.69244075, -0.81115675,
        0.7352965 ,  2.525117  ,  2.912752  ,  1.5418116 , -0.32726777,
       -0.97606325,  0.19192469], dtype=float32)

The function you write only “sees” a single batch of the data, which we can see by printing the device local shape:

x = jnp.arange(32)
print(f"global shape: {x.shape=}")

def f(x):
  print(f"device local shape: {x.shape=}")
  return x * 2

y = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
global shape: x.shape=(32,)
device local shape: x.shape=(4,)

Because your function only sees the device-local part of the data, aggregation-like functions require some extra thought. For example, here’s what a shard_map of a sum looks like:

def f(x):
  return jnp.sum(x, keepdims=True)

shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
Array([  6,  22,  38,  54,  70,  86, 102, 118], dtype=int32)

Our function f operates separately on each shard, and the resulting summation reflects this. If we want to sum across shards, we need to explicitly request it using collective operations like jax.lax.psum():

def f(x):
  sum_in_shard = x.sum()
  return jax.lax.psum(sum_in_shard, 'x')

shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x)
Array(496, dtype=int32)

Because the output no longer has a sharded dimension, we set out_specs=P().

Comparing the three approaches#

With these concepts fresh in our mind, let’s compare the three approaches for a simple neural network layer. We’ll define our canonical function like this:

@jax.jit
def layer(x, weights, bias):
  return jax.nn.sigmoid(x @ weights + bias)
import numpy as np
rng = np.random.default_rng(0)

x = rng.normal(size=(32,))
weights = rng.normal(size=(32, 4))
bias = rng.normal(size=(4,))

layer(x, weights, bias)
Array([0.02138912, 0.893112  , 0.59892005, 0.97742504], dtype=float32)

We can automatically run this in a distributed manner using jax.jit() and passing appropriately sharded data. If we shard the leading axis of both x and weights in the same way, then the matrix multiplication will automatically happen in parallel:

P = jax.sharding.PartitionSpec
mesh = jax.sharding.Mesh(jax.devices(), 'x')
sharding = jax.sharding.NamedSharding(mesh, P('x'))

x_sharded = jax.device_put(x, sharding)
weights_sharded = jax.device_put(weights, sharding)

layer(x_sharded, weights_sharded, bias)
Array([0.02138912, 0.893112  , 0.59892005, 0.97742504], dtype=float32)

Alternatively, we can use jax.lax.with_sharding_constraint() in the function to automatically distribute unsharded inputs:

@jax.jit
def layer_auto(x, weights, bias):
  x = jax.lax.with_sharding_constraint(x, sharding)
  weights = jax.lax.with_sharding_constraint(weights, sharding)
  return layer(x, weights, bias)

layer_auto(x, weights, bias)  # pass in unsharded inputs
Array([0.02138914, 0.89311206, 0.5989201 , 0.97742516], dtype=float32)

Finally, we can do the same thing with shard_map, using psum to indicate the cross-shard collective required for the matrix product:

from functools import partial

@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=(P('x'), P('x', None), P(None)),
         out_specs=P(None))
def layer_sharded(x, weights, bias):
  return jax.nn.sigmoid(jax.lax.psum(x @ weights, 'x') + bias)

layer_sharded(x, weights, bias)
Array([0.02138914, 0.89311206, 0.5989201 , 0.97742516], dtype=float32)

This section has been a brief introduction to sharded and parallel computation; for more discussion of shard_map, see SPMD multi-device parallelism with shard_map.

Stateful Computations#

JAX transformations like jit(), vmap(), and grad() require the functions they wrap to be pure: that is, functions whose outputs depend solely on the inputs, and which have no side effects such as updating of global state. You can find a discussion of this in JAX sharp bits: Pure functions.

This constraint can pose some challenges in the context of machine learning, where state may exist in many forms. For example:

  • model parameters,

  • optimizer state, and

  • stateful layers, such as BatchNorm.

This section offers some advice on how to properly handle state in a JAX program.

A simple example: Counter#

Let’s start by looking at a simple stateful program: a counter.

import jax
import jax.numpy as jnp

class Counter:
  """A simple counter."""

  def __init__(self):
    self.n = 0

  def count(self) -> int:
    """Increments the counter and returns the new value."""
    self.n += 1
    return self.n

  def reset(self):
    """Resets the counter to zero."""
    self.n = 0


counter = Counter()

for _ in range(3):
  print(counter.count())
1
2
3

The counter’s n attribute maintains the counter’s state between successive calls of count. It is modified as a side effect of calling count.

Let’s say we want to count fast, so we JIT-compile the count method. (In this example, this wouldn’t actually help speed anyway, for many reasons, but treat this as a toy model of JIT-compiling the update of model parameters, where jit() makes an enormous difference).

counter.reset()
fast_count = jax.jit(counter.count)

for _ in range(3):
  print(fast_count())
1
1
1

Oh no! Our counter isn’t working. This is because the line

self.n += 1

in count involves a side effect: it modifies the input counter in-place, and so this function is not supported by jit. Such side effects are executed only once when the function is first traced, and subsequent calls will not repeat the side effect. So, how do we fix it?

The solution: explicit state#

Part of the problem with our counter was that the returned value didn’t depend on the arguments, meaning a constant was “baked into” the compiled output. But it shouldn’t be a constant – it should depend on the state. Well, then why don’t we make the state into an argument?

CounterState = int

class CounterV2:

  def count(self, n: CounterState) -> tuple[int, CounterState]:
    # You could just return n+1, but here we separate its role as 
    # the output and as the counter state for didactic purposes.
    return n+1, n+1

  def reset(self) -> CounterState:
    return 0

counter = CounterV2()
state = counter.reset()

for _ in range(3):
  value, state = counter.count(state)
  print(value)
1
2
3

In this new version of Counter, we moved n to be an argument of count, and added another return value that represents the new, updated, state. To use this counter, we now need to keep track of the state explicitly. But in return, we can now safely jax.jit this counter:

state = counter.reset()
fast_count = jax.jit(counter.count)

for _ in range(3):
  value, state = fast_count(state)
  print(value)
1
2
3

A general strategy#

We can apply the same process to any stateful method to convert it into a stateless one. We took a class of the form

class StatefulClass

  state: State

  def stateful_method(*args, **kwargs) -> Output:

and turned it into a class of the form

class StatelessClass

  def stateless_method(state: State, *args, **kwargs) -> (Output, State):

This is a common functional programming pattern, and, essentially, is the way that state is handled in all JAX programs.

Notice that the need for a class becomes less clear once we have rewritten it this way. We could just keep stateless_method, since the class is no longer doing any work. This is because, like the strategy we just applied, object-oriented programming (OOP) is a way to help programmers understand program state.

In our case, the CounterV2 class is nothing more than a namespace bringing all the functions that use CounterState into one location. Exercise for the reader: do you think it makes sense to keep it as a class?

Incidentally, you’ve already seen an example of this strategy in the JAX pseudo-randomness API, jax.random, shown in the Pseudorandom numbers section above. Unlike NumPy, which manages random state using implicitly updated stateful classes, JAX requires the programmer to work directly with the random generator state – the PRNG key.
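
As a small illustration of that correspondence (a sketch, with the PRNG key playing the role of `state`):

def sample(state):
  """A stateless draw: returns (output, new_state), just like `stateless_method`."""
  new_state, subkey = jax.random.split(state)
  return jax.random.normal(subkey), new_state

state = jax.random.key(0)
for _ in range(3):
  value, state = sample(state)
  print(value)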

Simple worked example: Linear Regression#

Let’s apply this strategy to a simple machine learning model: linear regression via gradient descent.

Here, we only deal with one kind of state: the model parameters. But generally, you’ll see many kinds of state being threaded in and out of JAX functions, like optimizer state, layer statistics for batchnorm, and others.

The function to look at carefully is update.

from typing import NamedTuple

class Params(NamedTuple):
  weight: jnp.ndarray
  bias: jnp.ndarray


def init(rng) -> Params:
  """Returns the initial model params."""
  weights_key, bias_key = jax.random.split(rng)
  weight = jax.random.normal(weights_key, ())
  bias = jax.random.normal(bias_key, ())
  return Params(weight, bias)


def loss(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  """Computes the least squares error of the model's predictions on x against y."""
  pred = params.weight * x + params.bias
  return jnp.mean((pred - y) ** 2)


LEARNING_RATE = 0.005

@jax.jit
def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> Params:
  """Performs one SGD update step on params using the given data."""
  grad = jax.grad(loss)(params, x, y)

  # If we were using Adam or another stateful optimizer,
  # we would also do something like
  #
  #   updates, new_optimizer_state = optimizer(grad, optimizer_state)
  # 
  # and then use `updates` instead of `grad` to actually update the params.
  # (And we'd include `new_optimizer_state` in the output, naturally.)

  new_params = jax.tree.map(
      lambda param, g: param - g * LEARNING_RATE, params, grad)

  return new_params

Notice that we manually pipe the params in and out of the update function.

import matplotlib.pyplot as plt

rng = jax.random.key(42)

# Generate true data from y = w*x + b + noise
true_w, true_b = 2, -1
x_rng, noise_rng = jax.random.split(rng)
xs = jax.random.normal(x_rng, (128, 1))
noise = jax.random.normal(noise_rng, (128, 1)) * 0.5
ys = xs * true_w + true_b + noise

# Fit regression
params = init(rng)
for _ in range(1000):
  params = update(params, xs, ys)

plt.scatter(xs, ys)
plt.plot(xs, params.weight * xs + params.bias, c='red', label='Model Prediction')
plt.legend();
[plot: the training data scatter with the fitted model prediction line]

Taking it further#

The strategy described above is how any JAX program must handle state when using transformations like jit, vmap, grad, etc.

Handling parameters manually seems fine if you’re dealing with two parameters, but what if it’s a neural net with dozens of layers? You might already be getting worried about two things:

  1. Are we supposed to initialize them all manually, essentially repeating what we already write in the forward pass definition?

  2. Are we supposed to pipe all these things around manually?

The details can be tricky to handle, but there are examples of libraries that take care of this for you. See JAX Neural Network Libraries for some examples.

User Guides#

User guides are deeper dives into particular topics within JAX that become relevant as your JAX project matures into larger or deployed codebases.

How to Think in JAX#

Open in Colab Open in Kaggle

JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively.

JAX vs. NumPy#

Key Concepts:

  • JAX provides a NumPy-inspired interface for convenience.

  • Through duck-typing, JAX arrays can often be used as drop-in replacements of NumPy arrays.

  • Unlike NumPy arrays, JAX arrays are always immutable.

NumPy provides a well-known, powerful API for working with numerical data. For convenience, JAX provides jax.numpy which closely mirrors the numpy API and provides easy entry into JAX. Almost anything that can be done with numpy can be done with jax.numpy:

import matplotlib.pyplot as plt
import numpy as np

x_np = np.linspace(0, 10, 1000)
y_np = 2 * np.sin(x_np) * np.cos(x_np)
plt.plot(x_np, y_np);
[plot of y = 2 sin(x) cos(x) computed with NumPy]
import jax.numpy as jnp

x_jnp = jnp.linspace(0, 10, 1000)
y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)
plt.plot(x_jnp, y_jnp);
[the same plot, computed with jax.numpy]

The code blocks are identical aside from replacing np with jnp, and the results are the same. As we can see, JAX arrays can often be used directly in place of NumPy arrays for things like plotting.

The arrays themselves are implemented as different Python types:

type(x_np)
numpy.ndarray
type(x_jnp)
jaxlib.xla_extension.ArrayImpl

Python’s duck-typing allows JAX arrays and NumPy arrays to be used interchangeably in many places.

However, there is one important difference between JAX and NumPy arrays: JAX arrays are immutable, meaning that once created their contents cannot be changed.

Here is an example of mutating an array in NumPy:

# NumPy: mutable arrays
x = np.arange(10)
x[0] = 10
print(x)
[10  1  2  3  4  5  6  7  8  9]

The equivalent in JAX results in an error, as JAX arrays are immutable:

%xmode minimal
Exception reporting mode: Minimal
# JAX: immutable arrays
x = jnp.arange(10)
x[0] = 10
TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

For updating individual elements, JAX provides an indexed update syntax that returns an updated copy:

y = x.at[0].set(10)
print(x)
print(y)
[0 1 2 3 4 5 6 7 8 9]
[10  1  2  3  4  5  6  7  8  9]

NumPy, lax & XLA: JAX API layering#

Key Concepts:

  • jax.numpy is a high-level wrapper that provides a familiar interface.

  • jax.lax is a lower-level API that is stricter and often more powerful.

  • All JAX operations are implemented in terms of operations in XLA – the Accelerated Linear Algebra compiler.

If you look at the source of jax.numpy, you’ll see that all the operations are eventually expressed in terms of functions defined in jax.lax. You can think of jax.lax as a stricter, but often more powerful, API for working with multi-dimensional arrays.

For example, while jax.numpy will implicitly promote arguments to allow operations between mixed data types, jax.lax will not:

import jax.numpy as jnp
jnp.add(1, 1.0)  # jax.numpy API implicitly promotes mixed types.
Array(2., dtype=float32, weak_type=True)
from jax import lax
lax.add(1, 1.0)  # jax.lax API requires explicit type promotion.
MLIRError: Verification failed:
error: "jit(add)/jit(main)/add": op requires the same element type for all operands and results


The above exception was the direct cause of the following exception:

ValueError: Cannot lower jaxpr with verifier errors:
	op requires the same element type for all operands and results

If using jax.lax directly, you’ll have to do type promotion explicitly in such cases:

lax.add(jnp.float32(1), 1.0)
Array(2., dtype=float32)

Along with this strictness, jax.lax also provides efficient APIs for some operations that are more general than those NumPy supports.

For example, consider a 1D convolution, which can be expressed in NumPy this way:

x = jnp.array([1, 2, 1])
y = jnp.ones(10)
jnp.convolve(x, y)
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)

Under the hood, this NumPy operation is translated to a much more general convolution implemented by lax.conv_general_dilated:

from jax import lax
result = lax.conv_general_dilated(
    x.reshape(1, 1, 3).astype(float),  # note: explicit promotion
    y.reshape(1, 1, 10),
    window_strides=(1,),
    padding=[(len(y) - 1, len(y) - 1)])  # equivalent of padding='full' in NumPy
result[0, 0]
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)

This is a batched convolution operation designed to be efficient for the types of convolutions often used in deep neural nets. It requires much more boilerplate, but is far more flexible and scalable than the convolution provided by NumPy (See Convolutions in JAX for more detail on JAX convolutions).

At their heart, all jax.lax operations are Python wrappers for operations in XLA; here, for example, the convolution implementation is provided by XLA:ConvWithGeneralPadding. Every JAX operation is eventually expressed in terms of these fundamental XLA operations, which is what enables just-in-time (JIT) compilation.

To JIT or not to JIT#

Key Concepts:

  • By default JAX executes operations one at a time, in sequence.

  • Using a just-in-time (JIT) compilation decorator, sequences of operations can be optimized together and run at once.

  • Not all JAX code can be JIT compiled, as it requires array shapes to be static & known at compile time.

The fact that all JAX operations are expressed in terms of XLA allows JAX to use the XLA compiler to execute blocks of code very efficiently.

For example, consider this function that normalizes the rows of a 2D matrix, expressed in terms of jax.numpy operations:

import jax.numpy as jnp

def norm(X):
  X = X - X.mean(0)
  return X / X.std(0)

A just-in-time compiled version of the function can be created using the jax.jit transform:

from jax import jit
norm_compiled = jit(norm)

This function returns the same results as the original, up to standard floating-point accuracy:

np.random.seed(1701)
X = jnp.array(np.random.rand(10000, 10))
np.allclose(norm(X), norm_compiled(X), atol=1E-6)
True

But due to the compilation (which includes fusing of operations, avoidance of allocating temporary arrays, and a host of other tricks), execution times can be orders of magnitude faster in the JIT-compiled case (note the use of block_until_ready() to account for JAX’s asynchronous dispatch):

%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()
347 µs ± 2.11 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
278 µs ± 1.77 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

That said, jax.jit does have limitations: in particular, it requires all arrays to have static shapes. That means that some JAX operations are incompatible with JIT compilation.

For example, this operation can be executed in op-by-op mode:

def get_negatives(x):
  return x[x < 0]

x = jnp.array(np.random.randn(10))
get_negatives(x)
Array([-0.10570311, -0.59403396, -0.8680282 , -0.23489487], dtype=float32)

But it returns an error if you attempt to execute it in jit mode:

jit(get_negatives)(x)
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[10])

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError

This is because the function generates an array whose shape is not known at compile time: the size of the output depends on the values of the input array, and so it is not compatible with JIT.
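If a fixed-shape result is all you need downstream, one common workaround (a sketch, not part of the example above) is to mask rather than filter, so that the output shape stays static:

from jax import jit

@jit
def sum_of_negatives(x):
  # Three-argument jnp.where returns an array with the same (static) shape as x,
  # so this is JIT-compatible; non-negative entries are simply zeroed out.
  return jnp.sum(jnp.where(x < 0, x, 0))

sum_of_negatives(x)  # works under jit because no value-dependent shape is produced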

JIT mechanics: tracing and static variables#

Key Concepts:

  • JIT and other JAX transforms work by tracing a function to determine its effect on inputs of a specific shape and type.

  • Variables that you don’t want to be traced can be marked as static

To use jax.jit effectively, it is useful to understand how it works. Let’s put a few print() statements within a JIT-compiled function and then call the function:

@jit
def f(x, y):
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  result = jnp.dot(x + 1, y + 1)
  print(f"  result = {result}")
  return result

x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)
Running f():
  x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=1/0)>
  y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
  result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>
Array([0.25773212, 5.3623195 , 5.403243  ], dtype=float32)

Notice that the print statements execute, but rather than printing the data we passed to the function, they print tracer objects that stand in for it.

These tracer objects are what jax.jit uses to extract the sequence of operations specified by the function. Basic tracers are stand-ins that encode the shape and dtype of the arrays, but are agnostic to the values. This recorded sequence of computations can then be efficiently applied within XLA to new inputs with the same shape and dtype, without having to re-execute the Python code.

When we call the compiled function again on matching inputs, no re-compilation is required and nothing is printed because the result is computed in compiled XLA rather than in Python:

x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
f(x2, y2)
Array([1.4344584, 4.3004413, 7.9897013], dtype=float32)

The extracted sequence of operations is encoded in a JAX expression, or jaxpr for short. You can view the jaxpr using the jax.make_jaxpr transformation:

from jax import make_jaxpr

def f(x, y):
  return jnp.dot(x + 1, y + 1)

make_jaxpr(f)(x, y)
{ lambda ; a:f32[3,4] b:f32[4]. let
    c:f32[3,4] = add a 1.0
    d:f32[4] = add b 1.0
    e:f32[3] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] c d
  in (e,) }

Note one consequence of this: because JIT compilation is done without information on the content of the array, control flow statements in the function cannot depend on traced values. For example, this fails:

@jit
def f(x, neg):
  return -x if neg else x

f(1, True)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function f at /tmp/ipykernel_8133/2422663986.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument neg.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
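If the branch genuinely needs to depend on a traced value, one alternative (a sketch, not the approach taken in this guide) is to express the choice as data with jnp.where instead of Python control flow:

@jit
def f(x, neg):
  # Both branches are computed and the result is selected elementwise,
  # so `neg` may be a traced value here.
  return jnp.where(neg, -x, x)

f(1, True)   # Array(-1, ...)
f(1, False)  # Array(1, ...)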

If there are variables that you would not like to be traced, they can be marked as static for the purposes of JIT compilation:

from functools import partial

@partial(jit, static_argnums=(1,))
def f(x, neg):
  return -x if neg else x

f(1, True)
Array(-1, dtype=int32, weak_type=True)

Note that calling a JIT-compiled function with a different static argument results in re-compilation, so the function still works as expected:

f(1, False)
Array(1, dtype=int32, weak_type=True)

Understanding which values and operations will be static and which will be traced is a key part of using jax.jit effectively.

Static vs Traced Operations#

Key Concepts:

  • Just as values can be either static or traced, operations can be static or traced.

  • Static operations are evaluated at compile-time in Python; traced operations are compiled & evaluated at run-time in XLA.

  • Use numpy for operations that you want to be static; use jax.numpy for operations that you want to be traced.

This distinction between static and traced values makes it important to think about how to keep a static value static. Consider this function:

import jax.numpy as jnp
from jax import jit

@jit
def f(x):
  return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2, 3))
f(x)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function f at /tmp/ipykernel_8133/1983583872.py:4 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i32[2] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /tmp/ipykernel_8133/1983583872.py:6 (f)

This fails with an error specifying that a tracer was found instead of a 1D sequence of concrete values of integer type. Let’s add some print statements to the function to understand why this is happening:

@jit
def f(x):
  print(f"x = {x}")
  print(f"x.shape = {x.shape}")
  print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}")
  # comment this out to avoid the error:
  # return x.reshape(jnp.array(x.shape).prod())

f(x)
x = Traced<ShapedArray(float32[2,3])>with<DynamicJaxprTrace(level=1/0)>
x.shape = (2, 3)
jnp.array(x.shape).prod() = Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>

Notice that although x is traced, x.shape is a static value. However, when we use jnp.array and jnp.prod on this static value, it becomes a traced value, at which point it cannot be used in a function like reshape() that requires a static input (recall: array shapes must be static).

A useful pattern is to use numpy for operations that should be static (i.e. done at compile-time), and use jax.numpy for operations that should be traced (i.e. compiled and executed at run-time). For this function, it might look like this:

from jax import jit
import jax.numpy as jnp
import numpy as np

@jit
def f(x):
  return x.reshape((np.prod(x.shape),))

f(x)
Array([1., 1., 1., 1., 1., 1.], dtype=float32)

For this reason, a standard convention in JAX programs is to import numpy as np and import jax.numpy as jnp so that both interfaces are available for finer control over whether operations are performed in a static manner (with numpy, once at compile-time) or a traced manner (with jax.numpy, optimized at run-time).

Profiling JAX programs#

Viewing program traces with Perfetto#

We can use the JAX profiler to generate traces of a JAX program that can be visualized using the Perfetto visualizer. Currently, this method blocks the program until a link is clicked and the Perfetto UI loads the trace. If you wish to get profiling information without any interaction, check out the Tensorboard profiler below.

with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
  # Run the operations to be profiled
  key = jax.random.key(0)
  x = jax.random.normal(key, (5000, 5000))
  y = x @ x
  y.block_until_ready()

After this computation is done, the program will prompt you to open a link to ui.perfetto.dev. When you open the link, the Perfetto UI will load the trace file and open a visualizer.

Perfetto trace viewer

Program execution will continue after loading the link. The link is no longer valid after opening once, but it will redirect to a new URL that remains valid. You can then click the “Share” button in the Perfetto UI to create a permalink to the trace that can be shared with others.

Remote profiling#

When profiling code that is running remotely (for example on a hosted VM), you need to establish an SSH tunnel on port 9001 for the link to work. You can do that with this command:

$ ssh -L 9001:127.0.0.1:9001 <user>@<host>

or if you’re using Google Cloud:

$ gcloud compute ssh <machine-name> -- -L 9001:127.0.0.1:9001
Manual capture#

Instead of capturing traces programmatically using jax.profiler.trace, you can instead start a profiling server in the script of interest by calling jax.profiler.start_server(<port>). If you only need the profiler server to be active for a portion of your script, you can shut it down by calling jax.profiler.stop_server().
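For example, a long-running script might start and stop the server around the region of interest (a minimal sketch; the port number 9999 is arbitrary):

import jax
import jax.profiler

jax.profiler.start_server(9999)  # profiler server now listening on port 9999

# ... the long-running work you may want to capture a trace from ...
x = jax.random.normal(jax.random.key(0), (5000, 5000))
(x @ x).block_until_ready()

jax.profiler.stop_server()  # shut the server down when profiling is no longer needed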

Once the script is running and after the profiler server has started, we can manually capture and trace by running:

$ python -m jax.collect_profile <port> <duration_in_ms>

By default, the resulting trace information is dumped into a temporary directory but this can be overridden by passing in --log_dir=<directory of choice>. Also, by default, the program will prompt you to open a link to ui.perfetto.dev. When you open the link, the Perfetto UI will load the trace file and open a visualizer. This feature is disabled by passing in --no_perfetto_link into the command. Alternatively, you can also point Tensorboard to the log_dir to analyze the trace (see the “Tensorboard Profiling” section below).

TensorBoard profiling#

TensorBoard’s profiler can be used to profile JAX programs. Tensorboard is a great way to acquire and visualize performance traces and profiles of your program, including activity on GPU and TPU. The end result looks something like this:

TensorBoard profiler example

Installation#

The TensorBoard profiler is only available with the version of TensorBoard bundled with TensorFlow.

pip install tensorflow tensorboard-plugin-profile

If you already have TensorFlow installed, you only need to install the tensorboard-plugin-profile pip package. Be careful to only install one version of TensorFlow or TensorBoard, otherwise you may encounter the “duplicate plugins” error described below. See https://www.tensorflow.org/guide/profiler for more information on installing TensorBoard.

Programmatic capture#

You can instrument your code to capture a profiler trace via the jax.profiler.start_trace() and jax.profiler.stop_trace() methods. Call start_trace() with the directory to write trace files to. This should be the same --logdir directory used to start TensorBoard. Then, you can use TensorBoard to view the traces.

For example, to take a profiler trace:

import jax

jax.profiler.start_trace("/tmp/tensorboard")

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

Note the block_until_ready() call. We use this to make sure on-device execution is captured by the trace. See Asynchronous dispatch for details on why this is necessary.

You can also use the jax.profiler.trace() context manager as an alternative to start_trace and stop_trace:

import jax

with jax.profiler.trace("/tmp/tensorboard"):
  key = jax.random.key(0)
  x = jax.random.normal(key, (5000, 5000))
  y = x @ x
  y.block_until_ready()

To view the trace, first start TensorBoard if you haven’t already:

$ tensorboard --logdir=/tmp/tensorboard
[...]
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.5.0 at http://localhost:6006/ (Press CTRL+C to quit)

You should be able to load TensorBoard at http://localhost:6006/ in this example. You can specify a different port with the --port flag. See Profiling on a remote machine below if running JAX on a remote server.

Then, either select “Profile” in the upper-right dropdown menu, or go directly to http://localhost:6006/#profile. Available traces appear in the “Runs” dropdown menu on the left. Select the run you’re interested in, and then under “Tools”, select trace_viewer. You should now see a timeline of the execution. You can use the WASD keys to navigate the trace, and click or drag to select events to see more details at the bottom. See these TensorFlow docs for more details on using the trace viewer.

You can also use the memory_viewer, op_profile, and graph_viewer tools.

Manual capture via TensorBoard#

The following are instructions for capturing a manually-triggered N-second trace from a running program.

  1. Start a TensorBoard server:

    tensorboard --logdir /tmp/tensorboard/
    

    You should be able to load TensorBoard at http://localhost:6006/. You can specify a different port with the --port flag. See Profiling on a remote machine below if running JAX on a remote server.

  2. In the Python program or process you’d like to profile, add the following somewhere near the beginning:

    import jax.profiler
    jax.profiler.start_server(9999)
    

    This starts the profiler server that TensorBoard connects to. The profiler server must be running before you move on to the next step. When you’re done using the server, you can call jax.profiler.stop_server() to shut it down.

    If you’d like to profile a snippet of a long-running program (e.g. a long training loop), you can put this at the beginning of the program and start your program as usual. If you’d like to profile a short program (e.g. a microbenchmark), one option is to start the profiler server in an IPython shell, and run the short program with %run after starting the capture in the next step. Another option is to start the profiler server at the beginning of the program and use time.sleep() to give you enough time to start the capture.

  3. Open http://localhost:6006/#profile, and click the “CAPTURE PROFILE” button in the upper left. Enter “localhost:9999” as the profile service URL (this is the address of the profiler server you started in the previous step). Enter the number of milliseconds you’d like to profile for, and click “CAPTURE”.

  4. If the code you’d like to profile isn’t already running (e.g. if you started the profiler server in a Python shell), run it while the capture is running.

  5. After the capture finishes, TensorBoard should automatically refresh. (Not all of the TensorBoard profiling features are hooked up with JAX, so it may initially look like nothing was captured.) On the left under “Tools”, select trace_viewer.

    You should now see a timeline of the execution. You can use the WASD keys to navigate the trace, and click or drag to select events to see more details at the bottom. See these TensorFlow docs for more details on using the trace viewer.

    You can also use the memory_viewer, op_profile, and graph_viewer tools.

Adding custom trace events#

By default, the events in the trace viewer are mostly low-level internal JAX functions. You can add your own events and functions by using jax.profiler.TraceAnnotation and jax.profiler.annotate_function() in your code.
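For example (a minimal sketch; the labels and function names are arbitrary), a region can be wrapped in a jax.profiler.TraceAnnotation context manager, and a whole function can be labeled with the jax.profiler.annotate_function decorator; both show up as named events in the trace viewer:

import jax
import jax.numpy as jnp

@jax.profiler.annotate_function
def dense_layer(w, x):
  # The whole function appears as a named event in the trace.
  return jnp.dot(w, x)

def forward(w, x):
  with jax.profiler.TraceAnnotation("my_forward_pass"):
    # Everything inside this block is grouped under "my_forward_pass".
    return jax.nn.relu(dense_layer(w, x))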

Troubleshooting#
GPU profiling#

Programs running on GPU should produce traces for the GPU streams near the top of the trace viewer. If you’re only seeing the host traces, check your program logs and/or output for the following error messages.

If you get an error like: Could not load dynamic library 'libcupti.so.10.1'
Full error:

W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libcupti.so.10.1'; dlerror: libcupti.so.10.1: cannot open shared object file: No such file or directory
2020-06-12 13:19:59.822799: E external/org_tensorflow/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1422] function cupti_interface_->Subscribe( &subscriber_, (CUpti_CallbackFunc)ApiCallback, this)failed with error CUPTI could not be loaded or symbol could not be found.

Add the path to libcupti.so to the environment variable LD_LIBRARY_PATH. (Try locate libcupti.so to find the path.) For example:

export LD_LIBRARY_PATH=/usr/local/cuda-10.1/extras/CUPTI/lib64/:$LD_LIBRARY_PATH

If you still get the Could not load dynamic library message after doing this, check if the GPU trace shows up in the trace viewer anyway. This message sometimes occurs even when everything is working, since it looks for the libcupti library in multiple places.

If you get an error like: failed with error CUPTI_ERROR_INSUFFICIENT_PRIVILEGES
Full error:

E external/org_tensorflow/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1445] function cupti_interface_->EnableCallback( 0 , subscriber_, CUPTI_CB_DOMAIN_DRIVER_API, cbid)failed with error CUPTI_ERROR_INSUFFICIENT_PRIVILEGES
2020-06-12 14:31:54.097791: E external/org_tensorflow/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1487] function cupti_interface_->ActivityDisable(activity)failed with error CUPTI_ERROR_NOT_INITIALIZED

Run the following commands (note this requires a reboot):

echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"' | sudo tee -a /etc/modprobe.d/nvidia-kernel-common.conf
sudo update-initramfs -u
sudo reboot now

See NVIDIA’s documentation on this error for more information.

Profiling on a remote machine#

If the JAX program you’d like to profile is running on a remote machine, one option is to run all the instructions above on the remote machine (in particular, start the TensorBoard server on the remote machine), then use SSH local port forwarding to access the TensorBoard web UI from your local machine. Use the following SSH command to forward the default TensorBoard port 6006 from the local to the remote machine:

ssh -L 6006:localhost:6006 <remote server address>

or if you’re using Google Cloud:

$ gcloud compute ssh <machine-name> -- -L 6006:localhost:6006
Multiple TensorBoard installs#

If starting TensorBoard fails with an error like: ValueError: Duplicate plugins for name projector

It’s often because there are two versions of TensorBoard and/or TensorFlow installed (e.g. the tensorflow, tf-nightly, tensorboard, and tb-nightly pip packages all include TensorBoard). Uninstalling a single pip package can result in the tensorboard executable being removed which is then hard to replace, so it may be necessary to uninstall everything and reinstall a single version:

pip uninstall tensorflow tf-nightly tensorboard tb-nightly
pip install tensorflow

Nsight#

NVIDIA’s Nsight tools can be used to trace and profile JAX code on GPU. For details, see the Nsight documentation.

Device Memory Profiling#

Note

May 2023 update: we recommend using Tensorboard profiling for device memory analysis. After taking a profile, open the memory_viewer tab of the Tensorboard profiler for more detailed and understandable device memory usage.

The JAX Device Memory Profiler allows us to explore how and why JAX programs are using GPU or TPU memory. For example, it can be used to:

  • Figure out which arrays and executables are in GPU memory at a given time, or

  • Track down memory leaks.

Installation#

The JAX device memory profiler emits output that can be interpreted using pprof (google/pprof). Start by installing pprof, by following its installation instructions. At the time of writing, installing pprof requires first installing Go of version 1.16+, Graphviz, and then running

go install github.com/google/pprof@latest

which installs pprof as $GOPATH/bin/pprof, where GOPATH defaults to ~/go.

Note

The version of pprof from google/pprof is not the same as the older tool of the same name distributed as part of the gperftools package. The gperftools version of pprof will not work with JAX.

Understanding how a JAX program is using GPU or TPU memory#

A common use of the device memory profiler is to figure out why a JAX program is using a large amount of GPU or TPU memory, for example if trying to debug an out-of-memory problem.

To capture a device memory profile to disk, use jax.profiler.save_device_memory_profile(). For example, consider the following Python program:

import jax
import jax.numpy as jnp
import jax.profiler

def func1(x):
  return jnp.tile(x, 10) * 0.5

def func2(x):
  y = func1(x)
  return y, jnp.tile(x, 10) + 1

x = jax.random.normal(jax.random.key(42), (1000, 1000))
y, z = func2(x)

z.block_until_ready()

jax.profiler.save_device_memory_profile("memory.prof")

If we first run the program above and then execute

pprof --web memory.prof

pprof opens a web browser containing the following visualization of the device memory profile in callgraph format:

Device memory profiling example

The callgraph is a visualization of the Python stack at the point the allocation of each live buffer was made. For example, in this specific case, the visualization shows that func2 and its callees were responsible for allocating 76.30MB, of which 38.15MB was allocated inside the call from func2 to func1. For more information about how to interpret callgraph visualizations, see the pprof documentation.

Functions compiled with jax.jit() are opaque to the device memory profiler. That is, any memory allocated inside a jit-compiled function will be attributed to the function as a whole.

In the example, the call to block_until_ready() is to ensure that func2 completes before the device memory profile is collected. See Asynchronous dispatch for more details.

Debugging memory leaks#

We can also use the JAX device memory profiler to track down memory leaks by using pprof to visualize the change in memory usage between two device memory profiles taken at different times. For example, consider the following program which accumulates JAX arrays into a constantly-growing Python list.

import jax
import jax.numpy as jnp
import jax.profiler

def afunction():
  return jax.random.normal(jax.random.key(77), (1000000,))

z = afunction()

def anotherfunc():
  arrays = []
  for i in range(1, 10):
    x = jax.random.normal(jax.random.key(42), (i, 10000))
    arrays.append(x)
    x.block_until_ready()
    jax.profiler.save_device_memory_profile(f"memory{i}.prof")

anotherfunc()

If we simply visualize the device memory profile at the end of execution (memory9.prof), it may not be obvious that each iteration of the loop in anotherfunc accumulates more device memory allocations:

pprof --web memory9.prof

Device memory profile at end of execution

The large but fixed allocation inside afunction dominates the profile but does not grow over time.

By using pprof’s --diff_base feature to visualize the change in memory usage across loop iterations, we can identify why the memory usage of the program increases over time:

pprof --web --diff_base memory1.prof memory9.prof

Device memory profile diff between the first and last loop iterations

The visualization shows that the memory growth can be attributed to the call to normal inside anotherfunc.

Runtime value debugging in JAX#

Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page has TL;DR summaries and you can click the “Read more” links at the bottom to learn more.

Table of contents:

Interactive inspection with jax.debug#

TL;DR Use jax.debug.print() to print values to stdout in jax.jit-, jax.pmap-, and pjit-decorated functions, and jax.debug.breakpoint() to pause execution of your compiled function to inspect values in the call stack:

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  jax.debug.print("🤯 {x} 🤯", x=x)
  y = jnp.sin(x)
  jax.debug.breakpoint()
  jax.debug.print("🤯 {y} 🤯", y=y)
  return y

f(2.)
# Prints:
# 🤯 2.0 🤯
# Enters breakpoint to inspect values!
# 🤯 0.9092974662780762 🤯

Click here to learn more!

Functional error checks with jax.experimental.checkify#

TL;DR Checkify lets you add jit-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the checkify.checkify transformation together with the assert-like checkify.check function to add runtime checks to JAX code:

from jax.experimental import checkify
import jax
import jax.numpy as jnp

def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  y = x[i]
  z = jnp.sin(y)
  return z

jittable_f = checkify.checkify(f)

err, z = jax.jit(jittable_f)(jnp.ones((5,)), -1)
print(err.get())
# >> index needs to be non-negative! (check failed at <...>:6 (f))

You can also use checkify to automatically add common checks:

errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)

err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)

err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <…>:6 (f))

err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)

Click here to learn more!

Throwing Python errors with JAX’s debug flags#

TL;DR Enable the jax_debug_nans flag to automatically detect when NaNs are produced in jax.jit-compiled code (but not in jax.pmap or jax.pjit-compiled code) and enable the jax_disable_jit flag to disable JIT-compilation, enabling use of traditional Python debugging tools like print and pdb.

import jax
jax.config.update("jax_debug_nans", True)

def f(x, y):
  return x / y
jax.jit(f)(0., 0.)  # ==> raises FloatingPointError exception!

Click here to learn more!

jax.debug.print and jax.debug.breakpoint#

The jax.debug package offers some useful tools for inspecting values inside of JIT-ted functions.

Debugging with jax.debug.print and other debugging callbacks#

TL;DR Use jax.debug.print() to print traced array values to stdout in jit- and pmap-decorated functions:

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  jax.debug.print("🤯 {x} 🤯", x=x)
  y = jnp.sin(x)
  jax.debug.print("🤯 {y} 🤯", y=y)
  return y

f(2.)
# Prints:
# 🤯 2.0 🤯
# 🤯 0.9092974662780762 🤯

With some transformations, like jax.grad and jax.vmap, you can use Python’s builtin print function to print out numerical values. But print won’t work with jax.jit or jax.pmap because those transformations delay numerical evaluation. So use jax.debug.print instead!
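A small sketch of the difference under jax.jit: the builtin print fires at trace time and only sees a tracer, while jax.debug.print reports the runtime value (the output comments here are approximate):

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  print("builtin print:", x)                # executes while tracing; x is a tracer here
  jax.debug.print("debug print: {x}", x=x)  # executes when the compiled code runs
  return jnp.sin(x)

f(2.)
# builtin print: Traced<ShapedArray(float32[], weak_type=True)>with<...>
# debug print: 2.0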

Semantically, jax.debug.print is roughly equivalent to the following Python function

def debug.print(fmt: str, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
  print(fmt.format(*args, **kwargs))

except that it can be staged out and transformed by JAX. See the API reference for more details.

Note that fmt cannot be an f-string because f-strings are formatted immediately, whereas for jax.debug.print, we’d like to delay formatting until later.

When to use “debug” print?#

You should use jax.debug.print for dynamic (i.e. traced) array values within JAX transformations like jit, vmap, and others. For printing of static values (like array shapes or dtypes), you can use a normal Python print statement.
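For example (a small sketch), shapes and dtypes are static and can go through the ordinary print, while the traced data itself needs jax.debug.print:

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  print("static shape:", x.shape)          # shapes are static, so plain print works
  jax.debug.print("traced values: {}", x)  # runtime values need jax.debug.print
  return x * 2

f(jnp.arange(3.))
# static shape: (3,)
# traced values: [0. 1. 2.]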

Why “debug” print?#

In the name of debugging, jax.debug.print can reveal information about how computations are evaluated:

xs = jnp.arange(3.)

def f(x):
  jax.debug.print("x: {}", x)
  y = jnp.sin(x)
  jax.debug.print("y: {}", y)
  return y
jax.vmap(f)(xs)
# Prints: x: 0.0
#         x: 1.0
#         x: 2.0
#         y: 0.0
#         y: 0.841471
#         y: 0.9092974
jax.lax.map(f, xs)
# Prints: x: 0.0
#         y: 0.0
#         x: 1.0
#         y: 0.841471
#         x: 2.0
#         y: 0.9092974

Notice that the printed results are in different orders!

By revealing these inner-workings, the output of jax.debug.print doesn’t respect JAX’s usual semantics guarantees, like that jax.vmap(f)(xs) and jax.lax.map(f, xs) compute the same thing (in different ways). Yet these evaluation order details are exactly what we might want to see when debugging!

So use jax.debug.print for debugging, and not when semantics guarantees are important.

More examples of jax.debug.print#

In addition to the above examples using jit and vmap, here are a few more to have in mind.

Printing under jax.pmap#

When jax.pmap-ed, jax.debug.prints might be reordered!

xs = jnp.arange(2.)

def f(x):
  jax.debug.print("x: {}", x)
  return x
jax.pmap(f)(xs)
# Prints: x: 1.0
#         x: 0.0
# OR
# Prints: x: 0.0
#         x: 1.0
Printing under jax.grad#

Under a jax.grad, jax.debug.prints will only print on the forward pass:

def f(x):
  jax.debug.print("x: {}", x)
  return x * 2.

jax.grad(f)(1.)
# Prints: x: 1.0

This behavior is similar to how Python’s builtin print works under a jax.grad. But by using jax.debug.print here, the behavior is the same even if the caller applies a jax.jit.

To print on the backward pass, just use a jax.custom_vjp:

@jax.custom_vjp
def print_grad(x):
  return x

def print_grad_fwd(x):
  return x, None

def print_grad_bwd(_, x_grad):
  jax.debug.print("x_grad: {}", x_grad)
  return (x_grad,)

print_grad.defvjp(print_grad_fwd, print_grad_bwd)


def f(x):
  x = print_grad(x)
  return x * 2.
jax.grad(f)(1.)
# Prints: x_grad: 2.0
Printing in other transformations#

jax.debug.print also works in other transformations like xmap and pjit.

More control with jax.debug.callback#

In fact, jax.debug.print is a thin convenience wrapper around jax.debug.callback, which can be used directly for greater control over string formatting, or even the kind of output.

Semantically, jax.debug.callback is roughly equivalent to the following Python function

def callback(fun: Callable, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
  fun(*args, **kwargs)
  return None

As with jax.debug.print, these callbacks should only be used for debugging output, like printing or plotting. Printing and plotting are pretty harmless, but if you use callbacks for anything else, their behavior might surprise you under transformations. For example, it’s not safe to use jax.debug.callback for timing operations, since callbacks might be reordered and asynchronous (see below).
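For instance (a minimal sketch; the helper names are arbitrary), any Python function can be used for the output side effect; here the values are sent to the logging module instead of stdout:

import logging

import jax
import jax.numpy as jnp

logging.basicConfig(level=logging.INFO)

def log_value(x):
  # An ordinary Python function; it receives concrete arrays at run time.
  logging.info("intermediate value: %s", x)

@jax.jit
def f(x):
  y = jnp.sin(x)
  jax.debug.callback(log_value, y)  # staged out and transformed like jax.debug.print
  return y

f(2.0)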

Sharp bits#

Like most JAX APIs, jax.debug.print can cut you if you’re not careful.

Ordering of printed results#

When distinct calls to jax.debug.print involve arguments which don’t depend on one another, they might be reordered when staged out, e.g. by jax.jit:

@jax.jit
def f(x, y):
  jax.debug.print("x: {}", x)
  jax.debug.print("y: {}", y)
  return x + y

f(2., 3.)
# Prints: x: 2.0
#         y: 3.0
# OR
# Prints: y: 3.0
#         x: 2.0

Why? Under the hood, the compiler gets a functional representation of the staged-out computation, where the imperative order of the Python function is lost and only data dependence remains. This change is invisible to users with functionally pure code, but in the presence of side-effects like printing, it’s noticeable.

To preserve the original order of jax.debug.prints as written in your Python function, you can use jax.debug.print(..., ordered=True), which will ensure the relative order of prints is preserved. But using ordered=True will raise an error under jax.pmap and other JAX transformations involving parallelism, since ordering can’t be guaranteed under parallel execution.
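A small sketch of the ordered option:

@jax.jit
def f(x, y):
  # ordered=True preserves the order these prints appear in the source,
  # at the cost of disallowing parallel transformations like jax.pmap.
  jax.debug.print("x: {}", x, ordered=True)
  jax.debug.print("y: {}", y, ordered=True)
  return x + y

f(2., 3.)
# Always prints x: 2.0 before y: 3.0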

Asynchronous callbacks#

Depending on the backend, jax.debug.prints may happen asynchronously, i.e. not in your main program thread. This means that values could be printed to your screen even after your JAX function has returned a value.

@jax.jit
def f(x):
  jax.debug.print("x: {}", x)
  return x
f(2.).block_until_ready()
# <do something else>
# Prints: x: 2.

To block on the jax.debug.prints in a function, you can call jax.effects_barrier(), which will wait until any remaining side-effects in the function have completed as well:

@jax.jit
def f(x):
  jax.debug.print("x: {}", x)
  return x
f(2.).block_until_ready()
jax.effects_barrier()
# Prints: x: 2.
# <do something else>
Performance impacts#
Unnecessary materialization#

While jax.debug.print was designed to have a minimal performance footprint, it can interfere with compiler optimizations and potentially affect the memory profile of your JAX programs.

def f(w, b, x):
  logits = w.dot(x) + b
  jax.debug.print("logits: {}", logits)
  return jax.nn.relu(logits)

In this example, we are printing intermediate values in between a linear layer and the activation function. Compilers like XLA can perform fusion optimizations, which might avoid materializing logits in memory. But when we use jax.debug.print on logits, we are forcing those intermediates to be materialized, potentially slowing down the program and increasing memory usage.

Furthermore, when using jax.debug.print with jax.pjit, a global synchronization occurs that will materialize values on a single device.

Callback overhead#

jax.debug.print inherently incurs communication between an accelerator and its host. The underlying mechanism differs from backend to backend (e.g. GPU vs TPU) but in all cases, we’ll need to copy the printed values from device to host. In the CPU case, this overhead is smaller.

Furthermore, when using jax.debug.print with jax.pjit, a global synchronization occurs that adds some overhead.

Strengths and limitations of jax.debug.print#
Strengths#
  • Print debugging is simple and intuitive

  • jax.debug.callback can be used for other innocuous side-effects

Limitations#
  • Adding print statements is a manual process

  • Can have performance impacts

Interactive inspection with jax.debug.breakpoint()#

TL;DR Use jax.debug.breakpoint() to pause the execution of your JAX program to inspect values:

@jax.jit
def f(x):
  y, z = jnp.sin(x), jnp.cos(x)
  jax.debug.breakpoint()
  return y * z
f(2.) # ==> Pauses during execution!

JAX debugger

jax.debug.breakpoint() is actually just an application of jax.debug.callback(...) that captures information about the call stack. It has the same transformation behaviors as jax.debug.print as a result (e.g. vmap-ing jax.debug.breakpoint() unrolls it across the mapped axis).

Usage#

Calling jax.debug.breakpoint() in a compiled JAX function will pause your program when it hits the breakpoint. You’ll be presented with a pdb-like prompt that allows you to inspect the values in the call stack. Unlike pdb, you will not be able to step through the execution, but you are allowed to resume it.

Debugger commands:

  • help - prints out available commands

  • p - evaluates an expression and prints its result

  • pp - evaluates an expression and pretty-prints its result

  • u(p) - go up a stack frame

  • d(own) - go down a stack frame

  • w(here)/bt - print out a backtrace

  • l(ist) - print out code context

  • c(ont(inue)) - resumes the execution of the program

  • q(uit)/exit - exits the program (does not work on TPU)

Examples#
Usage with jax.lax.cond#

When combined with jax.lax.cond, the debugger can become a useful tool for detecting nans or infs.

def breakpoint_if_nonfinite(x):
  is_finite = jnp.isfinite(x).all()
  def true_fn(x):
    pass
  def false_fn(x):
    jax.debug.breakpoint()
  jax.lax.cond(is_finite, true_fn, false_fn, x)

@jax.jit
def f(x, y):
  z = x / y
  breakpoint_if_nonfinite(z)
  return z
f(2., 0.) # ==> Pauses during execution!
Sharp bits#

Because jax.debug.breakpoint is just an application of jax.debug.callback, it has the same sharp bits as jax.debug.print, with a few more caveats:

  • jax.debug.breakpoint materializes even more intermediates than jax.debug.print because it forces materialization of all values in the call stack

  • jax.debug.breakpoint has more runtime overhead than a jax.debug.print because it has to potentially copy all the intermediate values in a JAX program from device to host.

Strengths and limitations of jax.debug.breakpoint()#
Strengths#
  • Simple, intuitive and (somewhat) standard

  • Can inspect many values at the same time, up and down the call stack

Limitations#
  • Need to potentially use many breakpoints to pinpoint the source of an error

  • Materializes many intermediates

The checkify transformation#

TL;DR Checkify lets you add jit-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the checkify.checkify transformation together with the assert-like checkify.check function to add runtime checks to JAX code:

from jax.experimental import checkify
import jax
import jax.numpy as jnp

def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative, got {i}", i=i)
  y = x[i]
  z = jnp.sin(y)
  return z

jittable_f = checkify.checkify(f)

err, z = jax.jit(jittable_f)(jnp.ones((5,)), -2)
print(err.get())
# >> index needs to be non-negative, got -2! (check failed at <...>:6 (f))

You can also use checkify to automatically add common checks:

errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)

err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)

err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <…>:6 (f))

err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)

err, z = checked_f(jnp.array([5, 1]), 0)
err.throw()  # if no error occurred, throw does nothing!
Functionalizing checks#

The assert-like check API by itself is not functionally pure: it can raise a Python Exception as a side-effect, just like assert. So it can’t be staged out with jit, pmap, pjit, or scan:

jax.jit(f)(jnp.ones((5,)), -1)  # checkify transformation not used
# ValueError: Cannot abstractly evaluate a checkify.check which was not functionalized.

But the checkify transformation functionalizes (or discharges) these effects. A checkify-transformed function returns an error value as a new output and remains functionally pure. That functionalization means checkify-transformed functions can be composed with staging/transforms however we like:

err, z = jax.pmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
err.throw()
"""
ValueError:
..  at mapped index 0: index needs to be non-negative! (check failed at :6 (f))
..  at mapped index 2: out-of-bounds indexing at <..>:7 (f)
"""
Why does JAX need checkify?#

Under some JAX transformations you can express runtime error checks with ordinary Python assertions, for example when only using jax.grad and jax.numpy:

def f(x):
  assert x > 0., "must be positive!"
  return jnp.log(x)

jax.grad(f)(0.)
# AssertionError: must be positive!

But ordinary assertions don’t work inside jit, pmap, pjit, or scan. In those cases, numeric computations are staged out rather than evaluated eagerly during Python execution, and as a result numeric values aren’t available:

jax.jit(f)(0.)
# ConcretizationTypeError: "Abstract tracer value encountered ..."

JAX transformation semantics rely on functional purity, especially when composing multiple transformations, so how can we provide an error mechanism without disrupting all that? Beyond needing a new API, the situation is trickier still: XLA HLO doesn’t support assertions or throwing errors, so even if we had a JAX API which was able to stage out assertions, how would we lower these assertions to XLA?

You could imagine manually adding run-time checks to your function and plumbing out values representing errors:

def f_checked(x):
  error = x <= 0.
  result = jnp.log(x)
  return error, result

err, y = jax.jit(f_checked)(0.)
if err:
  raise ValueError("must be positive!")
# ValueError: "must be positive!"

The error is a regular value computed by the function, and the error is raised outside of f_checked. f_checked is functionally pure, so we know by construction that it’ll already work with jit, pmap, pjit, scan, and all of JAX’s transformations. The only problem is that this plumbing can be a pain!

checkify does this rewrite for you: that includes plumbing the error value through the function, rewriting checks to boolean operations and merging the result with the tracked error value, and returning the final error value as an output of the checkified function:

def f(x):
  checkify.check(x > 0., "{} must be positive!", x)  # convenient but effectful API
  return jnp.log(x)

f_checked = checkify.checkify(f)

err, x = jax.jit(f_checked)(-1.)
err.throw()
# ValueError: -1. must be positive! (check failed at <...>:2 (f))

We call this functionalizing or discharging the effect introduced by calling check. (In the “manual” example above the error value is just a boolean. checkify’s error values are conceptually similar but also track error messages and expose throw and get methods; see jax.experimental.checkify). checkify.check also allows you to add run-time values to your error message by providing them as format arguments to the error message.

You could now manually instrument your code with run-time checks, but checkify can also automatically add checks for common errors! Consider these error cases:

jnp.arange(3)[5]                # out of bounds
jnp.sin(jnp.inf)                # NaN generated
jnp.ones((5,)) / jnp.arange(5)  # division by zero

By default checkify only discharges checkify.checks, and won’t do anything to catch errors like the above. But if you ask it to, checkify will also instrument your code with checks automatically.

def f(x, i):
  y = x[i]        # i could be out of bounds.
  z = jnp.sin(y)  # z could become NaN
  return z

errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)

err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)

err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)

The API for selecting which automatic checks to enable is based on Sets. See jax.experimental.checkify for more details.

checkify under JAX transformations.#

As demonstrated in the examples above, a checkified function can be happily jitted. Here are a few more examples of checkify with other JAX transformations. Note that checkified functions are functionally pure, and should trivially compose with all JAX transformations!

jit#

You can safely add jax.jit to a checkified function, or checkify a jitted function; both will work.

def f(x, i):
  return x[i]

checkify_of_jit = checkify.checkify(jax.jit(f))
jit_of_checkify = jax.jit(checkify.checkify(f))
err, _ =  checkify_of_jit(jnp.ones((5,)), 100)
err.get()
# out-of-bounds indexing at <..>:2 (f)
err, _ = jit_of_checkify(jnp.ones((5,)), 100)
# out-of-bounds indexing at <..>:2 (f)
vmap/pmap#

You can vmap and pmap checkified functions (or checkify mapped functions). Mapping a checkified function will give you a mapped error, which can contain different errors for every element of the mapped dimension.

def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  return x[i]

checked_f = checkify.checkify(f, errors=checkify.all_checks)
errs, out = jax.vmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
errs.throw()
"""
ValueError:
  at mapped index 0: index needs to be non-negative! (check failed at <...>:2 (f))
  at mapped index 2: out-of-bounds indexing at <...>:3 (f)
"""

However, a checkify-of-vmap will produce a single (unmapped) error!

@jax.vmap
def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  return x[i]

checked_f = checkify.checkify(f, errors=checkify.all_checks)
err, out = checked_f(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
err.throw()
# ValueError: index needs to be non-negative! (check failed at <...>:2 (f))
pjit#

pjit of a checkified function just works; you only need to specify an additional output sharding of None for the error value output (the first element of out_shardings below).

from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec

def f(x):
  return x / x

f = checkify.checkify(f, errors=checkify.float_checks)
f = pjit(
  f,
  in_shardings=PartitionSpec('x', None),
  out_shardings=(None, PartitionSpec('x', None)))

# `mesh` and `input_data` are assumed to have been created elsewhere.
with jax.sharding.Mesh(mesh.devices, mesh.axis_names):
  err, data = f(input_data)
err.throw()
# ValueError: divided by zero at <...>:4 (f)
grad#

Your gradient computation will also be instrumented if you checkify-of-grad:

def f(x):
 return x / (1 + jnp.sqrt(x))

grad_f = jax.grad(f)

err, _ = checkify.checkify(grad_f, errors=checkify.nan_checks)(0.)
print(err.get())
>> nan generated by primitive mul at <...>:3 (f)

Note that there’s no multiply in f, but there is a multiply in its gradient computation (and this is where the NaN is generated!). So use checkify-of-grad to add automatic checks to both forward and backward pass operations.

checkify.checks will only be applied to the primal value of your function. If you want to use a check on a gradient value, use a custom_vjp:

@jax.custom_vjp
def assert_gradient_negative(x):
 return x

def fwd(x):
 return assert_gradient_negative(x), None

def bwd(_, grad):
 checkify.check(grad < 0, "gradient needs to be negative!")
 return (grad,)

assert_gradient_negative.defvjp(fwd, bwd)

jax.grad(assert_gradient_negative)(-1.)
# ValueError: gradient needs to be negative!
Strengths and limitations of jax.experimental.checkify#
Strengths#
  • You can use it everywhere (errors are “just values” and behave intuitively under transformations like other values)

  • Automatic instrumentation: you don’t need to make local modifications to your code. Instead, checkify can instrument all of it!

Limitations#
  • Adding a lot of runtime checks can be expensive (e.g. adding a NaN check to every primitive will add a lot of operations to your computation)

  • Requires threading error values out of functions and manually throwing the error. If the error is not explicitly thrown, you might miss out on errors!

  • Throwing an error value will materialize that error value on the host, meaning it’s a blocking operation which defeats JAX’s async run-ahead.

JAX debugging flags#

JAX offers flags and context managers that enable catching errors more easily.

jax_debug_nans configuration option and context manager#

TL;DR Enable the jax_debug_nans flag to automatically detect when NaNs are produced in jax.jit-compiled code (but not in jax.pmap or jax.pjit-compiled code).

jax_debug_nans is a JAX flag that, when enabled, automatically raises an error as soon as a NaN is detected. It has special handling for JIT-compiled code: when a NaN output is detected from a JIT-ted function, the function is re-run eagerly (i.e. without compilation) and an error is thrown at the specific primitive that produced the NaN.

Usage#

If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by:

  • setting the JAX_DEBUG_NANS=True environment variable;

  • adding jax.config.update("jax_debug_nans", True) near the top of your main file;

  • adding jax.config.parse_flags_with_absl() to your main file, then setting the option using a command-line flag like --jax_debug_nans=True.

Example(s)#
import jax
jax.config.update("jax_debug_nans", True)

def f(x, y):
  return x / y
jax.jit(f)(0., 0.)  # ==> raises FloatingPointError exception!
Strengths and limitations of jax_debug_nans#
Strengths#
  • Easy to apply

  • Precisely detects where NaNs were produced

  • Throws a standard Python exception and is compatible with PDB postmortem

Limitations#
  • Not compatible with jax.pmap or jax.pjit

  • Re-running functions eagerly can be slow

  • Errors on false positives (e.g. intentionally created NaNs)

jax_disable_jit configuration option and context manager#

TL;DR Enable the jax_disable_jit flag to disable JIT-compilation, enabling use of traditional Python debugging tools like print and pdb

jax_disable_jit is a JAX flag that when enabled, disables JIT-compilation throughout JAX (including in control flow functions like jax.lax.cond and jax.lax.scan).

Usage#

You can disable JIT-compilation by:

  • setting the JAX_DISABLE_JIT=True environment variable;

  • adding jax.config.update("jax_disable_jit", True) near the top of your main file;

  • adding jax.config.parse_flags_with_absl() to your main file, then setting the option using a command-line flag like --jax_disable_jit=True.

Examples#
import jax
import jax.numpy as jnp
jax.config.update("jax_disable_jit", True)

def f(x):
  y = jnp.log(x)
  if jnp.isnan(y):
    breakpoint()
  return y
jax.jit(f)(-2.)  # ==> Enters PDB breakpoint!
Strengths and limitations of jax_disable_jit#
Strengths#
  • Easy to apply

  • Enables use of Python’s built-in breakpoint and print

  • Throws standard Python exceptions and is compatible with PDB postmortem

Limitations#
  • Not compatible with jax.pmap or jax.pjit

  • Running functions without JIT-compilation can be slow

GPU performance tips#

This document focuses on performance tips for neural network workloads

Matmul precision#

On recent GPU generations, such as the Nvidia A100 generation or later, it can be a good idea to perform most computations in bfloat16 precision. For example, if using Flax, instantiate Dense layers using flax.linen.Dense(..., dtype=jax.numpy.bfloat16).
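As a rough illustration in plain JAX (not taken from the Flax examples), casting the matmul operands to bfloat16 looks like the sketch below; passing dtype=jnp.bfloat16 to Flax layers has the analogous effect inside a model:

import jax
import jax.numpy as jnp

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(k1, (1024, 1024)).astype(jnp.bfloat16)
b = jax.random.normal(k2, (1024, 1024)).astype(jnp.bfloat16)

@jax.jit
def matmul(a, b):
  # With bfloat16 inputs, the GEMM runs in bfloat16 on GPUs that support it.
  return a @ b

print(matmul(a, b).dtype)  # bfloat16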

XLA performance flags#

The existence and exact behavior of XLA flags may be jaxlib-version dependent.

As of jaxlib==0.4.18 (released Oct 6 2023), setting these XLA flags can improve performance. Some are related to communication between GPUs, and so are only relevant when running computations on multiple devices, while others are related to code generation on each device.

Some of these may be set by default in future releases.

These flags can be set via the XLA_FLAGS shell environment variable. For example, we can add this to the top of a Python file:

import os
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_enable_triton_softmax_fusion=true '
    '--xla_gpu_triton_gemm_any=True '
    '--xla_gpu_enable_async_collectives=true '
    '--xla_gpu_enable_latency_hiding_scheduler=true '
    '--xla_gpu_enable_highest_priority_async_stream=true '
)

For more examples, see also XLA Flags recommended for Pax training on Nvidia GPUs.

Code generation flags#
  • --xla_gpu_enable_triton_softmax_fusion This flag enables an automatic softmax fusion, based on pattern-matching backed by Triton code generation. The default value is False.

  • --xla_gpu_triton_gemm_any Use the Triton-based GEMM (matmul) emitter for any GEMM that it supports. The default value is False.

Communication flags#
  • --xla_gpu_enable_async_collectives This flag enables the collective ops such as AllReduce, AllGather, ReduceScatter and CollectivePermute to be asynchronous. Asynchronous communication can overlap cross-core communication with computation. The default value is False.

  • --xla_gpu_enable_latency_hiding_scheduler This flag enables latency hiding schedulers to overlap asynchronous communication with computation efficiently. The default value is False.

  • --xla_gpu_enable_pipelined_collectives When using pipeline parallelism, this flag enables overlapping the (i+1)-th layer weight AllGather with the i-th layer computation. It also enables overlapping (i+1)-th layer weight Reduce/ReduceScatter with i-th layer’s computation. The default value is False. There are some bugs when this flag is turned on.

  • --xla_gpu_collective_permute_decomposer_threshold This flag is useful when performing GSPMD pipelining. Setting a nonzero threshold decomposes CollectivePermutes into CollectivePermuteReceiveDone and CollectivePermuteSendDone pairs, so that computation can be performed between each corresponding ReceiveDone/SendDone pair and hence achieve more overlap. By default the threshold is 0 and there is no decomposition. Setting it to threshold > 0 such as --xla_gpu_collective_permute_decomposer_threshold=1024 can enable this feature.

  • --xla_gpu_all_gather_combine_threshold_bytes, --xla_gpu_reduce_scatter_combine_threshold_bytes, --xla_gpu_all_reduce_combine_threshold_bytes These flags tune when to combine multiple small AllGather/ReduceScatter/AllReduce operations into one big AllGather/ReduceScatter/AllReduce to reduce time spent on cross-device communication. For example, for the AllGather/ReduceScatter thresholds on a Transformer-based workload, consider tuning them high enough so as to combine at least a Transformer layer’s weight AllGather/ReduceScatter. By default, combine_threshold_bytes is set to 256. (An illustrative example of setting these flags follows below.)
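For instance, a hypothetical configuration raising these thresholds via the XLA_FLAGS environment variable could look like the following; the byte values are purely illustrative, not tuned recommendations:

import os

# Append to any flags already set, rather than overwriting them.
os.environ['XLA_FLAGS'] = os.environ.get('XLA_FLAGS', '') + (
    ' --xla_gpu_all_gather_combine_threshold_bytes=134217728'
    ' --xla_gpu_reduce_scatter_combine_threshold_bytes=134217728'
    ' --xla_gpu_all_reduce_combine_threshold_bytes=134217728'
)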

NCCL flags#

These Nvidia NCCL flag values may be useful for single-host multi-device computations on Nvidia GPUs:

import os

os.environ.update({
    "NCCL_LL128_BUFFSIZE": "-2",
    "NCCL_LL_BUFFSIZE": "-2",
    "NCCL_PROTO": "SIMPLE,LL,LL128",
})

These NCCL flags can improve single-host communication speed; they do not appear to help multi-host communication yet.

Persistent Compilation Cache#

JAX has an optional disk cache for compiled programs. If enabled, JAX will store copies of compiled programs on disk, which can save recompilation time when running the same or similar tasks repeatedly.

Usage#

The compilation cache is enabled when the cache-location is set. This should be done prior to the first compilation. Set the location as follows:

import jax

# Make sure this is called before jax runs any operations!
jax.config.update("jax_compilation_cache_dir", "cache-location")

See the sections below for more detail on cache-location.

set_cache_dir() is an alternate way of setting cache-location.
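A sketch of that alternative, assuming set_cache_dir() is exposed from jax.experimental.compilation_cache (the import path has moved between JAX versions, so treat this as an assumption rather than a fixed API):

from jax.experimental.compilation_cache import compilation_cache as cc

# Assumed import path; call this before JAX compiles anything.
cc.set_cache_dir("/tmp/jax-cache")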

Local filesystem#

cache-location can be a directory on the local filesystem. For example:

import jax

jax.config.update("jax_compilation_cache_dir", "/tmp/jax-cache")

Note: the cache does not have an eviction mechanism implemented. If the cache-location is a directory in the local filesystem, its size will continue to grow unless files are manually deleted.

Google Cloud#

When running on Google Cloud, the compilation cache can be placed on a Google Cloud Storage (GCS) bucket. We recommend the following configuration:

  • Create the bucket in the same region as where the workload will run.

  • Create the bucket in the same project as the workload’s VM(s). Ensure that permissions are set so that the VM(s) can write to the bucket.

  • There is no need for replication for smaller workloads. Larger workloads could benefit from replication.

  • Use “Standard” for the default storage class for the bucket.

  • Set the soft delete policy to its shortest: 7 days.

  • Set the object lifecycle to the expected duration of the workload run. For example, if the workload is expected to run for 10 days, set the object lifecycle to 10 days. That should cover restarts that occur during the entire run. Use age for the lifecycle condition and Delete for the action. See Object Lifecycle Management for details. If the object lifecycle is not set, the cache will continue to grow since there is no eviction mechanism implemented.

  • All encryption policies are supported.

Assuming that gs://jax-cache is the GCS bucket, set cache-location as follows:

import jax

jax.config.update("jax_compilation_cache_dir", "gs://jax-cache")

Understanding Jaxprs#

Updated: May 3, 2020 (for commit f1a46fe).

Conceptually, one can think of JAX transformations as first trace-specializing the Python function to be transformed into a small and well-behaved intermediate form that is then interpreted with transformation-specific interpretation rules. One of the reasons JAX can pack so much power into such a small software package is that it starts with a familiar and flexible programming interface (Python with NumPy) and it uses the actual Python interpreter to do most of the heavy lifting to distill the essence of the computation into a simple statically-typed expression language with limited higher-order features. That language is the jaxpr language.

Not all Python programs can be processed this way, but it turns out that many scientific computing and machine learning programs can.

Before we proceed, it is important to point out that not all JAX transformations literally materialize a jaxpr as described above; some, e.g., differentiation or batching, will apply transformations incrementally during tracing. Nevertheless, if one wants to understand how JAX works internally, or to make use of the result of JAX tracing, it is useful to understand jaxprs.

A jaxpr instance represents a function with one or more typed parameters (input variables) and one or more typed results. The results depend only on the input variables; there are no free variables captured from enclosing scopes. The inputs and outputs have types, which in JAX are represented as abstract values. There are two related representations in the code for jaxprs, jax.core.Jaxpr and jax.core.ClosedJaxpr. A jax.core.ClosedJaxpr represents a partially-applied jax.core.Jaxpr, and is what you obtain when you use jax.make_jaxpr() to inspect jaxprs. It has the following fields:

  • jaxpr is a jax.core.Jaxpr representing the actual computation content of the function (described below).

  • consts is a list of constants.

The most interesting part of the ClosedJaxpr is the actual execution content, represented as a jax.core.Jaxpr as printed using the following grammar:

Jaxpr ::= { lambda Var* ; Var+. let
              Eqn*
            in  [Expr+] }
where:
  • The parameters of the jaxpr are shown as two lists of variables separated by ;. The first set of variables are the ones that have been introduced to stand for constants that have been hoisted out. These are called the constvars, and in a jax.core.ClosedJaxpr the consts field holds corresponding values. The second list of variables, called invars, correspond to the inputs of the traced Python function.

  • Eqn* is a list of equations, defining intermediate variables referring to intermediate expressions. Each equation defines one or more variables as the result of applying a primitive on some atomic expressions. Each equation uses only input variables and intermediate variables defined by previous equations.

  • Expr+: is a list of output atomic expressions (literals or variables) for the jaxpr.

Equations are printed as follows:

Eqn  ::= Var+ = Primitive [ Param* ] Expr+
where:
  • Var+ are one or more intermediate variables to be defined as the output of a primitive invocation (some primitives can return multiple values).

  • Expr+ are one or more atomic expressions, each either a variable or a literal constant. A special variable unitvar or literal unit, printed as *, represents a value that is not needed in the rest of the computation and has been elided. That is, units are just placeholders.

  • Param* are zero or more named parameters to the primitive, printed in square brackets. Each parameter is shown as Name = Value.

Most jaxpr primitives are first-order (they take just one or more Expr as arguments):

Primitive := add | sub | sin | mul | ...

The jaxpr primitives are documented in the jax.lax module.

For example, here is the jaxpr produced for the function func1 below

>>> from jax import make_jaxpr
>>> import jax.numpy as jnp
>>> def func1(first, second):
...    temp = first + jnp.sin(second) * 3.
...    return jnp.sum(temp)
...
>>> print(make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a:f32[8] b:f32[8]. let
    c:f32[8] = sin b
    d:f32[8] = mul c 3.0
    e:f32[8] = add a d
    f:f32[] = reduce_sum[axes=(0,)] e
  in (f,) }

Here there are no constvars, a and b are the input variables and they correspond respectively to first and second function parameters. The scalar literal 3.0 is kept inline. The reduce_sum primitive has named parameter axes, in addition to the operand e.

Note that even though execution of a program that calls into JAX builds a jaxpr, Python-level control-flow and Python-level functions execute normally. This means that just because a Python program contains functions and control-flow, the resulting jaxpr does not have to contain control-flow or higher-order features.

For example, when tracing the function func3 JAX will inline the call to inner and the conditional if second.shape[0] > 4, and will produce the same jaxpr as before

>>> def func2(inner, first, second):
...   temp = first + inner(second) * 3.
...   return jnp.sum(temp)
...
>>> def inner(second):
...   if second.shape[0] > 4:
...     return jnp.sin(second)
...   else:
...     assert False
...
>>> def func3(first, second):
...   return func2(inner, first, second)
...
>>> print(make_jaxpr(func3)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a:f32[8] b:f32[8]. let
    c:f32[8] = sin b
    d:f32[8] = mul c 3.0
    e:f32[8] = add a d
    f:f32[] = reduce_sum[axes=(0,)] e
  in (f,) }

Handling PyTrees#

In jaxpr there are no tuple types; instead primitives take multiple inputs and produce multiple outputs. When processing a function that has structured inputs or outputs, JAX will flatten those and in jaxpr they will appear as lists of inputs and outputs. For more details, please see the documentation for PyTrees (Pytrees).

For example, the following code produces an identical jaxpr to what we saw before (with two input vars, one for each element of the input tuple)

>>> def func4(arg):  # Arg is a pair
...   temp = arg[0] + jnp.sin(arg[1]) * 3.
...   return jnp.sum(temp)
...
>>> print(make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8))))
{ lambda ; a:f32[8] b:f32[8]. let
    c:f32[8] = sin b
    d:f32[8] = mul c 3.0
    e:f32[8] = add a d
    f:f32[] = reduce_sum[axes=(0,)] e
  in (f,) }

Constant Vars#

Some values in jaxprs are constants, in that their value does not depend on the jaxpr’s arguments. When these values are scalars they are represented directly in the jaxpr equations; non-scalar array constants are instead hoisted out to the top-level jaxpr, where they correspond to constant variables (“constvars”). These constvars differ from the other jaxpr parameters (“invars”) only as a bookkeeping convention.
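For instance, closing over a concrete NumPy array typically yields a constvar; here is a small sketch (the exact printed output varies across JAX versions):

import numpy as np
from jax import make_jaxpr
import jax.numpy as jnp

const = np.full(3, 2.0)  # a concrete, non-scalar constant

def add_const(x):
  return x + const  # `const` is hoisted out as a constvar

closed = make_jaxpr(add_const)(jnp.zeros(3))
print(closed.jaxpr.constvars)  # the hoisted constant variables
print(closed.consts)           # their corresponding values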

Higher-order primitives#

jaxpr includes several higher-order primitives. They are more complicated because they include sub-jaxprs.

Conditionals#

JAX traces through normal Python conditionals. To capture a conditional expression for dynamic execution, one must use the jax.lax.switch() and jax.lax.cond() constructors, which have the signatures:

lax.switch(index: int, branches: Sequence[A -> B], operand: A) -> B

lax.cond(pred: bool, true_body: A -> B, false_body: A -> B, operand: A) -> B

Both of these will bind a primitive called cond internally. The cond primitive in jaxprs reflects the more general signature of lax.switch(): it takes an integer denoting the index of the branch to execute (clamped into valid indexing range).

For example:

>>> from jax import lax
>>>
>>> def one_of_three(index, arg):
...   return lax.switch(index, [lambda x: x + 1.,
...                             lambda x: x - 2.,
...                             lambda x: x + 3.],
...                     arg)
...
>>> print(make_jaxpr(one_of_three)(1, 5.))
{ lambda ; a:i32[] b:f32[]. let
    c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
    d:i32[] = clamp 0 c 2
    e:f32[] = cond[
      branches=(
        { lambda ; f:f32[]. let g:f32[] = add f 1.0 in (g,) }
        { lambda ; h:f32[]. let i:f32[] = sub h 2.0 in (i,) }
        { lambda ; j:f32[]. let k:f32[] = add j 3.0 in (k,) }
      )
      linear=(False,)
    ] d b
  in (e,) }

The cond primitive has a number of parameters:

  • branches are jaxprs that correspond to the branch functionals. In this example, those functionals each take one input variable, corresponding to x.

  • linear is a tuple of booleans that is used internally by the auto-differentiation machinery to encode which of the input parameters are used linearly in the conditional.

The above instance of the cond primitive takes two operands. The first one (d) is the branch index, then b is the operand (arg) to be passed to whichever jaxpr in branches is selected by the branch index.

Another example, using lax.cond():

>>> from jax import lax
>>>
>>> def func7(arg):
...   return lax.cond(arg >= 0.,
...                   lambda xtrue: xtrue + 3.,
...                   lambda xfalse: xfalse - 3.,
...                   arg)
...
>>> print(make_jaxpr(func7)(5.))
{ lambda ; a:f32[]. let
    b:bool[] = ge a 0.0
    c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    d:f32[] = cond[
      branches=(
        { lambda ; e:f32[]. let f:f32[] = sub e 3.0 in (f,) }
        { lambda ; g:f32[]. let h:f32[] = add g 3.0 in (h,) }
      )
      linear=(False,)
    ] c a
  in (d,) }

In this case, the boolean predicate is converted to an integer index (0 or 1), and branches are jaxprs that correspond to the false and true branch functionals, in that order. Again, each functional takes one input variable, corresponding to xfalse and xtrue respectively.

The following example shows a more complicated situation when the input to the branch functionals is a tuple, and the false branch functional contains a constant jnp.ones(1) that is hoisted as a constvar

>>> def func8(arg1, arg2):  # arg2 is a pair
...   return lax.cond(arg1 >= 0.,
...                   lambda xtrue: xtrue[0],
...                   lambda xfalse: jnp.array([1]) + xfalse[1],
...                   arg2)
...
>>> print(make_jaxpr(func8)(5., (jnp.zeros(1), 2.)))
{ lambda a:i32[1]; b:f32[] c:f32[1] d:f32[]. let
    e:bool[] = ge b 0.0
    f:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
    g:f32[1] = cond[
      branches=(
        { lambda ; h:i32[1] i:f32[1] j:f32[]. let
            k:f32[1] = convert_element_type[new_dtype=float32 weak_type=True] h
            l:f32[1] = add k j
          in (l,) }
        { lambda ; m_:i32[1] n:f32[1] o:f32[]. let  in (n,) }
      )
      linear=(False, False, False)
    ] f a c d
  in (g,) }
While#

Just like for conditionals, Python loops are inlined during tracing. If you want to capture a loop for dynamic execution, you must use one of several special operations, jax.lax.while_loop() (a primitive) and jax.lax.fori_loop() (a helper that generates a while_loop primitive):

lax.while_loop(cond_fun: (C -> bool), body_fun: (C -> C), init: C) -> C
lax.fori_loop(start: int, end: int, body: (int -> C -> C), init: C) -> C

In the above signature, “C” stands for the type of the loop “carry” value. For example, here is a fori loop:

>>> import numpy as np
>>>
>>> def func10(arg, n):
...   ones = jnp.ones(arg.shape)  # A constant
...   return lax.fori_loop(0, n,
...                        lambda i, carry: carry + ones * 3. + arg,
...                        arg + ones)
...
>>> print(make_jaxpr(func10)(np.ones(16), 5))
{ lambda ; a:f32[16] b:i32[]. let
    c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
    d:f32[16] = add a c
    _:i32[] _:i32[] e:f32[16] = while[
      body_jaxpr={ lambda ; f:f32[16] g:f32[16] h:i32[] i:i32[] j:f32[16]. let
          k:i32[] = add h 1
          l:f32[16] = mul f 3.0
          m:f32[16] = add j l
          n:f32[16] = add m g
        in (k, i, n) }
      body_nconsts=2
      cond_jaxpr={ lambda ; o:i32[] p:i32[] q:f32[16]. let
          r:bool[] = lt o p
        in (r,) }
      cond_nconsts=0
    ] c a 0 b d
  in (e,) }

The while primitive takes 5 arguments: c a 0 b d, as follows:

  • 0 constants for cond_jaxpr (since cond_nconsts is 0)

  • 2 constants for body_jaxpr (c, and a)

  • 3 parameters for the initial value of carry

Scan#

JAX supports a special form of loop over the elements of an array (with statically known shape). The fact that there are a fixed number of iterations makes this form of looping easily reverse-differentiable. Such loops are constructed with the jax.lax.scan() function:

lax.scan(body_fun: (C -> A -> (C, B)), init_carry: C, in_arr: Array[A]) -> (C, Array[B])

This is written using a Haskell-style type signature: C is the type of the scan carry, A is the element type of the input array(s), and B is the element type of the output array(s).

As an example, consider the function func11 below:

>>> def func11(arr, extra):
...   ones = jnp.ones(arr.shape)  #  A constant
...   def body(carry, aelems):
...     # carry: running dot-product of the two arrays
...     # aelems: a pair with corresponding elements from the two arrays
...     ae1, ae2 = aelems
...     return (carry + ae1 * ae2 + extra, carry)
...   return lax.scan(body, 0., (arr, ones))
...
>>> print(make_jaxpr(func11)(np.ones(16), 5.))
{ lambda ; a:f32[16] b:f32[]. let
    c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
    d:f32[] e:f32[16] = scan[
      _split_transpose=False
      jaxpr={ lambda ; f:f32[] g:f32[] h:f32[] i:f32[]. let
          j:f32[] = mul h i
          k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] g
          l:f32[] = add k j
          m:f32[] = convert_element_type[new_dtype=float32 weak_type=False] f
          n:f32[] = add l m
        in (n, g) }
      length=16
      linear=(False, False, False, False)
      num_carry=1
      num_consts=1
      reverse=False
      unroll=1
    ] b 0.0 a c
  in (d, e) }

The linear parameter describes for each of the input variables whether they are guaranteed to be used linearly in the body. Once the scan goes through linearization, more arguments will be linear.

The scan primitive takes 4 arguments: b 0.0 a c, of which:

  • one is the free variable for the body

  • one is the initial value of the carry

  • The next 2 are the arrays over which the scan operates.

XLA_call#

The call primitive arises from JIT compilation, and it encapsulates a sub-jaxpr along with parameters that specify the backend and the device on which the computation should run. For example

>>> from jax import jit
>>>
>>> def func12(arg):
...   @jit
...   def inner(x):
...     return x + arg * jnp.ones(1)  # Include a constant in the inner function
...   return arg + inner(arg - 2.)
...
>>> print(make_jaxpr(func12)(1.))  
{ lambda ; a:f32[]. let
    b:f32[] = sub a 2.0
    c:f32[1] = pjit[
      name=inner
      jaxpr={ lambda ; d:f32[] e:f32[]. let
          f:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0
          g:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
          h:f32[1] = mul g f
          i:f32[] = convert_element_type[new_dtype=float32 weak_type=False] e
          j:f32[1] = add i h
        in (j,) }
    ] a b
    k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a
    l:f32[1] = add k c
  in (l,) }
XLA_pmap#

If you use the jax.pmap() transformation, the function to be mapped is captured using the xla_pmap primitive. Consider this example

>>> from jax import pmap
>>>
>>> def func13(arr, extra):
...   def inner(x):
...     # use a free variable "extra" and a constant jnp.ones(1)
...     return (x + extra + jnp.ones(1)) / lax.psum(x, axis_name='rows')
...   return pmap(inner, axis_name='rows')(arr)
...
>>> print(make_jaxpr(func13)(jnp.ones((1, 3)), 5.))
{ lambda ; a:f32[1,3] b:f32[]. let
    c:f32[1,3] = xla_pmap[
      axis_name=rows
      axis_size=1
      backend=None
      call_jaxpr={ lambda ; d:f32[] e:f32[3]. let
          f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
          g:f32[3] = add e f
          h:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0
          i:f32[3] = add g h
          j:f32[3] = psum[axes=('rows',) axis_index_groups=None] e
          k:f32[3] = div i j
        in (k,) }
      devices=None
      donated_invars=(False, False)
      global_axis_size=1
      in_axes=(None, 0)
      is_explicit_global_axis_size=False
      name=inner
      out_axes=(0,)
    ] b a
  in (c,) }

The xla_pmap primitive specifies the name of the axis (parameter axis_name) and the body of the function to be mapped as the call_jaxpr parameter. The value of this parameter is a Jaxpr with 2 input variables.

The parameter in_axes specifies which of the input variables should be mapped and which should be broadcast. In our example, the value of extra is broadcast and the value of arr is mapped.

External Callbacks in JAX#

This guide outlines the uses of various callback functions, which allow JAX runtimes to execute Python code on the host, even while running under jit, vmap, grad, or another transformation.

Why callbacks?#

A callback routine is a way to perform host-side execution of code at runtime. As a simple example, suppose you’d like to print the value of some variable during the course of a computation. Using a simple Python print statement, it looks like this:

import jax

@jax.jit
def f(x):
  y = x + 1
  print("intermediate value: {}".format(y))
  return y * 2

result = f(2)
intermediate value: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>

What is printed is not the runtime value, but the trace-time abstract value (if you’re not familiar with tracing in JAX, a good primer can be found in How To Think In JAX).

To print the value at runtime we need a callback, for example jax.debug.print:

@jax.jit
def f(x):
  y = x + 1
  jax.debug.print("intermediate value: {}", y)
  return y * 2

result = f(2)
intermediate value: 3

This works by passing the runtime value represented by y back to the host process, where the host can print the value.

Flavors of Callback#

In earlier versions of JAX, there was only one kind of callback available, implemented in jax.experimental.host_callback. The host_callback routines had some deficiencies, and are now deprecated in favor of several callbacks designed for different situations:

  • jax.pure_callback(): appropriate for pure functions, i.e. functions with no side effects.

  • jax.experimental.io_callback(): appropriate for impure functions, e.g. functions that read or write data to disk.

  • jax.debug.callback(): appropriate for functions that should reflect the execution behavior of the compiler.

(The jax.debug.print() function we used above is a wrapper around jax.debug.callback()).

From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow.

| callback function | supports return value | jit | vmap | grad | scan/while_loop | guaranteed execution |
|---|---|---|---|---|---|---|
| jax.pure_callback | ✅ | ✅ | ✅ | ❌¹ | ✅ | ❌ |
| jax.experimental.io_callback | ✅ | ✅ | ✅/❌² | ❌ | ✅³ | ✅ |
| jax.debug.callback | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ |

¹ jax.pure_callback can be used with custom_jvp to make it compatible with autodiff

² jax.experimental.io_callback is compatible with vmap only if ordered=False.

³ Note that vmap of scan/while_loop of io_callback has complicated semantics, and its behavior may change in future releases.

Exploring jax.pure_callback#

jax.pure_callback is generally the callback function you should reach for when you want host-side execution of a pure function: i.e. a function that has no side-effects (such as printing values, reading data from disk, updating a global state, etc.).

The function you pass to jax.pure_callback need not actually be pure, but it will be assumed pure by JAX’s transformations and higher-order functions, which means that it may be silently elided or called multiple times.

import jax
import jax.numpy as jnp
import numpy as np

def f_host(x):
  # call a numpy (not jax.numpy) operation:
  return np.sin(x).astype(x.dtype)

def f(x):
  result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
  return jax.pure_callback(f_host, result_shape, x)

x = jnp.arange(5.0)
f(x)
Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32)

Because pure_callback can be elided or duplicated, it is compatible out-of-the-box with transformations like jit and vmap, as well as higher-order primitives like scan and while_loop:

jax.jit(f)(x)
Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32)
jax.vmap(f)(x)
Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32)
def body_fun(_, x):
  return _, f(x)
jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]
Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32)

However, because there is no way for JAX to introspect the content of the callback, pure_callback has undefined autodiff semantics:

%xmode minimal
Exception reporting mode: Minimal
jax.grad(f)(x)
ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.

For an example of using pure_callback with jax.custom_jvp, see Example: pure_callback with custom_jvp below.

By design, functions passed to pure_callback are treated as if they have no side-effects: one consequence of this is that if the output of the function is not used, the compiler may eliminate the callback entirely:

def print_something():
  print('printing something')
  return np.int32(0)

@jax.jit
def f1():
  return jax.pure_callback(print_something, np.int32(0))
f1();
printing something
@jax.jit
def f2():
  jax.pure_callback(print_something, np.int32(0))
  return 1.0
f2();

In f1, the output of the callback is used in the return value of the function, so the callback is executed and we see the printed output. In f2 on the other hand, the output of the callback is unused, and so the compiler notices this and eliminates the function call. These are the correct semantics for a callback to a function with no side-effects.

Exploring jax.experimental.io_callback#

In contrast to jax.pure_callback(), jax.experimental.io_callback() is explicitly meant to be used with impure functions, i.e. functions that do have side-effects.

As an example, here is a callback to a global host-side numpy random generator. This is an impure operation because a side-effect of generating a random number in numpy is that the random state is updated (Please note that this is meant as a toy example of io_callback and not necessarily a recommended way of generating random numbers in JAX!).

from jax.experimental import io_callback
from functools import partial

global_rng = np.random.default_rng(0)

def host_side_random_like(x):
  """Generate a random array like x using the global_rng state"""
  # We have two side-effects here:
  # - printing the shape and dtype
  # - calling global_rng, thus updating its state
  print(f'generating {x.dtype}{list(x.shape)}')
  return global_rng.uniform(size=x.shape).astype(x.dtype)

@jax.jit
def numpy_random_like(x):
  return io_callback(host_side_random_like, x, x)

x = jnp.zeros(5)
numpy_random_like(x)
generating float32[5]
Array([0.6369617 , 0.26978672, 0.04097353, 0.01652764, 0.8132702 ],      dtype=float32)

The io_callback is compatible with vmap by default:

jax.vmap(numpy_random_like)(x)
generating float32[]
generating float32[]
generating float32[]
generating float32[]
generating float32[]
Array([0.91275555, 0.60663575, 0.72949654, 0.543625  , 0.9350724 ],      dtype=float32)

Note, however, that this may execute the mapped callbacks in any order. So, for example, if you ran this on a GPU, the order of the mapped outputs might differ from run to run.

If it is important that the order of callbacks be preserved, you can set ordered=True, in which case attempting to vmap will raise an error:

@jax.jit
def numpy_random_like_ordered(x):
  return io_callback(host_side_random_like, x, x, ordered=True)

jax.vmap(numpy_random_like_ordered)(x)
JaxStackTraceBeforeTransformation: ValueError: Cannot `vmap` ordered IO callback.

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------


The above exception was the direct cause of the following exception:

ValueError: Cannot `vmap` ordered IO callback.

On the other hand, scan and while_loop work with io_callback regardless of whether ordering is enforced:

def body_fun(_, x):
  return _, numpy_random_like_ordered(x)
jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]
generating float32[]
generating float32[]
generating float32[]
generating float32[]
generating float32[]
Array([0.81585354, 0.0027385 , 0.8574043 , 0.03358557, 0.72965544],      dtype=float32)

Like pure_callback, io_callback fails under automatic differentiation if it is passed a differentiated variable:

jax.grad(numpy_random_like)(x)
JaxStackTraceBeforeTransformation: ValueError: IO callbacks do not support JVP.

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------


The above exception was the direct cause of the following exception:

ValueError: IO callbacks do not support JVP.

However, if the callback is not dependent on a differentiated variable, it will execute:

@jax.jit
def f(x):
  io_callback(lambda: print('hello'), None)
  return x

jax.grad(f)(1.0);
hello

Unlike pure_callback, the compiler will not remove the callback execution in this case, even though the output of the callback is unused in the subsequent computation.

Exploring debug.callback#

Both pure_callback and io_callback enforce some assumptions about the purity of the function they’re calling, and limit in various ways what JAX transforms and compilation machinery may do. debug.callback essentially assumes nothing about the callback function, such that the action of the callback reflects exactly what JAX is doing during the course of a program. Further, debug.callback cannot return any value to the program.

from jax import debug

def log_value(x):
  # This could be an actual logging call; we'll use
  # print() for demonstration
  print("log:", x)

@jax.jit
def f(x):
  debug.callback(log_value, x)
  return x

f(1.0);
log: 1.0

The debug callback is compatible with vmap:

x = jnp.arange(5.0)
jax.vmap(f)(x);
log: 0.0
log: 1.0
log: 2.0
log: 3.0
log: 4.0

And is also compatible with grad and other autodiff transformations

jax.grad(f)(1.0);
log: 1.0

This can make debug.callback more useful for general-purpose debugging than either pure_callback or io_callback.

Example: pure_callback with custom_jvp#

One powerful way to take advantage of jax.pure_callback() is to combine it with jax.custom_jvp (see Custom derivative rules for more details on custom_jvp). Suppose we want to create a JAX-compatible wrapper for a scipy or numpy function that is not yet available in the jax.scipy or jax.numpy wrappers.

Here, we’ll consider creating a wrapper for the Bessel function of the first kind, implemented in scipy.special.jv. We can start by defining a straightforward pure_callback:

import jax
import jax.numpy as jnp
import scipy.special

def jv(v, z):
  v, z = jnp.asarray(v), jnp.asarray(z)

  # Require the order v to be integer type: this simplifies
  # the JVP rule below.
  assert jnp.issubdtype(v.dtype, jnp.integer)

  # Promote the input to inexact (float/complex).
  # Note that jnp.result_type() accounts for the enable_x64 flag.
  z = z.astype(jnp.result_type(float, z.dtype))

  # Wrap scipy function to return the expected dtype.
  _scipy_jv = lambda v, z: scipy.special.jv(v, z).astype(z.dtype)

  # Define the expected shape & dtype of output.
  result_shape_dtype = jax.ShapeDtypeStruct(
      shape=jnp.broadcast_shapes(v.shape, z.shape),
      dtype=z.dtype)

  # We use vectorized=True because scipy.special.jv handles broadcasted inputs.
  return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)

This lets us call into scipy.special.jv from transformed JAX code, including when transformed by jit and vmap:

from functools import partial
j1 = partial(jv, 1)
z = jnp.arange(5.0)
print(j1(z))
[ 0.          0.44005057  0.5767248   0.33905897 -0.06604332]

Here is the same result with jit:

print(jax.jit(j1)(z))
[ 0.          0.44005057  0.5767248   0.33905897 -0.06604332]

And here is the same result again with vmap:

print(jax.vmap(j1)(z))
[ 0.          0.44005057  0.5767248   0.33905897 -0.06604332]

However, if we call jax.grad, we see an error because there is no autodiff rule defined for this function:

jax.grad(j1)(z)
ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.

Let’s define a custom gradient rule for this. Looking at the definition of the Bessel Function of the First Kind, we find that there is a relatively straightforward recurrence relationship for the derivative with respect to the argument z:

\[\begin{split} d J_\nu(z) = \left\{ \begin{eqnarray} -J_1(z),\ &\nu=0\\ [J_{\nu - 1}(z) - J_{\nu + 1}(z)]/2,\ &\nu\ne 0 \end{eqnarray}\right. \end{split}\]

The gradient with respect to \(\nu\) is more complicated, but since we’ve restricted the v argument to integer types we don’t need to worry about its gradient for the sake of this example.

We can use jax.custom_jvp to define this automatic differentiation rule for our callback function:

jv = jax.custom_jvp(jv)

@jv.defjvp
def _jv_jvp(primals, tangents):
  v, z = primals
  _, z_dot = tangents  # Note: v_dot is always 0 because v is integer.
  jv_minus_1, jv_plus_1 = jv(v - 1, z), jv(v + 1, z)
  djv_dz = jnp.where(v == 0, -jv_plus_1, 0.5 * (jv_minus_1 - jv_plus_1))
  return jv(v, z), z_dot * djv_dz

Now computing the gradient of our function will work correctly:

j1 = partial(jv, 1)
print(jax.grad(j1)(2.0))
-0.06447162

Further, since we’ve defined our gradient in terms of jv itself, JAX’s architecture means that we get second-order and higher derivatives for free:

jax.hessian(j1)(2.0)
Array(-0.4003078, dtype=float32, weak_type=True)

Keep in mind that although this all works correctly with JAX, each call to our callback-based jv function will result in passing the input data from the device to the host, and passing the output of scipy.special.jv from the host back to the device. When running on accelerators like GPU or TPU, this data movement and host synchronization can lead to significant overhead each time jv is called. However, if you are running JAX on a single CPU (where the “host” and “device” are on the same hardware), JAX will generally do this data transfer in a fast, zero-copy fashion, making this pattern a relatively straightforward way to extend JAX’s capabilities.

Type promotion semantics#

This document describes JAX’s type promotion rules, i.e., the result of jax.numpy.promote_types() for each pair of types. For some background on the considerations that went into the design of what is described below, see Design of Type Promotion Semantics for JAX.

JAX’s type promotion behavior is determined via the following type promotion lattice:

[Figure: the JAX type promotion lattice]

where, for example:

  • b1 means np.bool_,

  • i2 means np.int16,

  • u4 means np.uint32,

  • bf means np.bfloat16,

  • f2 means np.float16,

  • c8 means np.complex64,

  • i* means Python int or weakly-typed int,

  • f* means Python float or weakly-typed float, and

  • c* means Python complex or weakly-typed complex.

(for more about weak types, see Weakly-typed values in JAX below).

Promotion between any two types is given by their join on this lattice, which generates the following binary promotion table:

|     | b1 | u1 | u2 | u4 | u8 | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i* | f* | c* |
|-----|----|----|----|----|----|----|----|----|----|----|----|----|----|----|-----|----|----|----|
| b1  | b1 | u1 | u2 | u4 | u8 | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i* | f* | c* |
| u1  | u1 | u1 | u2 | u4 | u8 | i2 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | u1 | f* | c* |
| u2  | u2 | u2 | u2 | u4 | u8 | i4 | i4 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | u2 | f* | c* |
| u4  | u4 | u4 | u4 | u4 | u8 | i8 | i8 | i8 | i8 | bf | f2 | f4 | f8 | c8 | c16 | u4 | f* | c* |
| u8  | u8 | u8 | u8 | u8 | u8 | f* | f* | f* | f* | bf | f2 | f4 | f8 | c8 | c16 | u8 | f* | c* |
| i1  | i1 | i2 | i4 | i8 | f* | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i1 | f* | c* |
| i2  | i2 | i2 | i4 | i8 | f* | i2 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i2 | f* | c* |
| i4  | i4 | i4 | i4 | i8 | f* | i4 | i4 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i4 | f* | c* |
| i8  | i8 | i8 | i8 | i8 | f* | i8 | i8 | i8 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i8 | f* | c* |
| bf  | bf | bf | bf | bf | bf | bf | bf | bf | bf | bf | f4 | f4 | f8 | c8 | c16 | bf | bf | c8 |
| f2  | f2 | f2 | f2 | f2 | f2 | f2 | f2 | f2 | f2 | f4 | f2 | f4 | f8 | c8 | c16 | f2 | f2 | c8 |
| f4  | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f8 | c8 | c16 | f4 | f4 | c8 |
| f8  | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | c16 | c16 | f8 | f8 | c16 |
| c8  | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c16 | c8 | c16 | c8 | c8 | c8 |
| c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 |
| i*  | i* | u1 | u2 | u4 | u8 | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i* | f* | c* |
| f*  | f* | f* | f* | f* | f* | f* | f* | f* | f* | bf | f2 | f4 | f8 | c8 | c16 | f* | f* | c* |
| c*  | c* | c* | c* | c* | c* | c* | c* | c* | c* | c8 | c8 | c8 | c16 | c8 | c16 | c* | c* | c* |

JAX’s type promotion rules differ from those of NumPy, as given by numpy.promote_types(), in a number of the cells above. There are three key classes of differences:

  • When promoting a weakly typed value against a typed JAX value of the same category, JAX always prefers the precision of the JAX value. For example, jnp.int16(1) + 1 will return int16 rather than promoting to int64 as in NumPy. Note that this applies only to Python scalar values; if the constant is a NumPy array then the above lattice is used for type promotion. For example, jnp.int16(1) + np.array(1) will return int64.

  • When promoting an integer or boolean type against a floating-point or complex type, JAX always prefers the type of the floating-point or complex type.

  • JAX supports the bfloat16 non-standard 16-bit floating point type (jax.numpy.bfloat16), which is useful for neural network training. The only notable promotion behavior is with respect to IEEE-754 float16, with which bfloat16 promotes to a float32.

The differences between NumPy and JAX are motivated by the fact that accelerator devices, such as GPUs and TPUs, either pay a significant performance penalty to use 64-bit floating point types (GPUs) or do not support 64-bit floating point types at all (TPUs). Classic NumPy’s promotion rules are too willing to overpromote to 64-bit types, which is problematic for a system designed to run on accelerators.

JAX uses floating point promotion rules that are more suited to modern accelerator devices and are less aggressive about promoting floating point types. The promotion rules used by JAX for floating-point types are similar to those used by PyTorch.
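Individual entries of the table can be checked programmatically with jax.numpy.promote_types(); the results below match the table above:

import jax.numpy as jnp

print(jnp.promote_types(jnp.bfloat16, jnp.float16))  # float32 (bf x f2 -> f4)
print(jnp.promote_types(jnp.int16, jnp.float16))     # float16 (i2 x f2 -> f2)
print(jnp.promote_types(jnp.uint8, jnp.int8))        # int16   (u1 x i1 -> i2)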

Effects of Python operator dispatch#

Keep in mind that Python operators like + will dispatch based on the Python type of the two values being added. This means that, for example, np.int16(1) + 1 will promote using NumPy rules, whereas jnp.int16(1) + 1 will promote using JAX rules. This can lead to potentially confusing non-associative promotion semantics when the two types of promotion are combined; for example with np.int16(1) + 1 + jnp.int16(1).

Weakly-typed values in JAX#

Weakly-typed values in JAX can in most cases be thought of as having promotion behavior equivalent to that of Python scalars, such as the integer scalar 2 in the following:

>>> x = jnp.arange(5, dtype='int8')
>>> 2 * x
Array([0, 2, 4, 6, 8], dtype=int8)

JAX’s weak type framework is designed to prevent unwanted type promotion within binary operations between JAX values and values with no explicitly user-specified type, such as Python scalar literals. For example, if 2 were not treated as weakly-typed, the expression above would lead to an implicit type promotion:

>>> jnp.int32(2) * x
Array([0, 2, 4, 6, 8], dtype=int32)

When used in JAX, Python scalars are sometimes promoted to DeviceArray objects, for example during JIT compilation. To maintain the desired promotion semantics in this case, DeviceArray objects carry a weak_type flag that can be seen in an array’s string representation:

>>> jnp.asarray(2)
Array(2, dtype=int32, weak_type=True)

If the dtype is specified explicitly, it will instead result in a standard strongly-typed array value:

>>> jnp.asarray(2, dtype='int32')
Array(2, dtype=int32)

Strict dtype promotion#

In some contexts it can be useful to disable implicit type promotion behavior, and instead require all promotions to be explicit. This can be done in JAX by setting the jax_numpy_dtype_promotion flag to 'strict'. Locally, it can be done with a context manager:

>>> x = jnp.float32(1)
>>> y = jnp.int32(1)
>>> with jax.numpy_dtype_promotion('strict'):
...   z = x + y  
...
Traceback (most recent call last):
TypePromotionError: Input dtypes ('float32', 'int32') have no available implicit
dtype promotion path when jax_numpy_dtype_promotion=strict. Try explicitly casting
inputs to the desired output type, or set jax_numpy_dtype_promotion=standard.

For convenience, strict promotion mode will still allow safe weakly-typed promotions, so you can still write code that mixes JAX arrays and Python scalars:

>>> with jax.numpy_dtype_promotion('strict'):
...   z = x + 1
>>> print(z)
2.0

If you would prefer to set the configuration globally, you can do so using the standard configuration update:

jax.config.update('jax_numpy_dtype_promotion', 'strict')

To restore the default standard type promotion, set this configuration to 'standard':

jax.config.update('jax_numpy_dtype_promotion', 'standard')

Pytrees#

What is a pytree?#

In JAX, we use the term pytree to refer to a tree-like structure built out of container-like Python objects. Classes are considered container-like if they are in the pytree registry, which by default includes lists, tuples, and dicts. That is:

  1. any object whose type is not in the pytree container registry is considered a leaf pytree;

  2. any object whose type is in the pytree container registry, and which contains pytrees, is considered a pytree.

For each entry in the pytree container registry, a container-like type is registered with a pair of functions that specify how to convert an instance of the container type to a (children, metadata) pair and how to convert such a pair back to an instance of the container type. Using these functions, JAX can canonicalize any tree of registered container objects into tuples.

Example pytrees:

[1, "a", object()]  # 3 leaves

(1, (2, 3), ())  # 3 leaves

[1, {"k1": 2, "k2": (3, 4)}, 5]  # 5 leaves
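The leaf counts above can be checked directly with jax.tree_util.tree_leaves(), which returns a pytree’s flattened leaves:

import jax

print(jax.tree_util.tree_leaves((1, (2, 3), ())))                  # [1, 2, 3]
print(jax.tree_util.tree_leaves([1, {"k1": 2, "k2": (3, 4)}, 5]))  # [1, 2, 3, 4, 5]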

JAX can be extended to consider other container types as pytrees; see Extending pytrees below.

Pytrees and JAX functions#

Many JAX functions, like jax.lax.scan(), operate over pytrees of arrays. JAX function transformations can be applied to functions that accept as input and produce as output pytrees of arrays.

Applying optional parameters to pytrees#

Some JAX function transformations take optional parameters that specify how certain input or output values should be treated (e.g. the in_axes and out_axes arguments to vmap()). These parameters can also be pytrees, and their structure must correspond to the pytree structure of the corresponding arguments. In particular, to be able to “match up” leaves in these parameter pytrees with values in the argument pytrees, the parameter pytrees are often constrained to be tree prefixes of the argument pytrees.

For example, if we pass the following input to vmap() (note that the input arguments to a function are considered a tuple):

(a1, {"k1": a2, "k2": a3})

We can use the following in_axes pytree to specify that only the k2 argument is mapped (axis=0) and the rest aren’t mapped over (axis=None):

(None, {"k1": None, "k2": 0})

The optional parameter pytree structure must match that of the main input pytree. However, the optional parameters can optionally be specified as a “prefix” pytree, meaning that a single leaf value can be applied to an entire sub-pytree. For example, if we have the same vmap() input as above, but wish to only map over the dictionary argument, we can use:

(None, 0)  # equivalent to (None, {"k1": 0, "k2": 0})

Or, if we want every argument to be mapped, we can simply write a single leaf value that is applied over the entire argument tuple pytree:

0

This happens to be the default in_axes value for vmap()!

The same logic applies to other optional parameters that refer to specific input or output values of a transformed function, e.g. vmap’s out_axes.
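As a small runnable sketch of the prefix rule described above (the function f and array names below are illustrative, not part of the JAX API):

import jax
import jax.numpy as jnp

a1 = jnp.ones(3)        # not mapped
a2 = jnp.ones(3)        # not mapped
a3 = jnp.ones((4, 3))   # mapped over axis 0

def f(x, d):
  return x + d["k1"] + d["k2"]

# in_axes mirrors the (a1, {"k1": a2, "k2": a3}) argument structure.
out = jax.vmap(f, in_axes=(None, {"k1": None, "k2": 0}))(a1, {"k1": a2, "k2": a3})
print(out.shape)  # (4, 3)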

Viewing the pytree definition of an object#

To view the pytree definition of an arbitrary object for debugging purposes, you can use:

from jax.tree_util import tree_structure
print(tree_structure(object))

Developer information#

This is primarily JAX internal documentation; end users should not need to understand this to use JAX, except when registering new user-defined container types with JAX. Some of these details may change.

Internal pytree handling#

JAX flattens pytrees into lists of leaves at the api.py boundary (and also in control flow primitives). This keeps downstream JAX internals simpler: transformations like grad(), jit(), and vmap() can handle user functions that accept and return the myriad different Python containers, while all the other parts of the system can operate on functions that only take (multiple) array arguments and always return a flat list of arrays.

When JAX flattens a pytree it will produce a list of leaves and a treedef object that encodes the structure of the original value. The treedef can then be used to construct a matching structured value after transforming the leaves. Pytrees are tree-like, rather than DAG-like or graph-like, in that we handle them assuming referential transparency and that they can’t contain reference cycles.

Here is a simple example:

from jax.tree_util import tree_flatten, tree_unflatten
import jax.numpy as jnp

# The structured value to be transformed
value_structured = [1., (2., 3.)]

# The leaves in value_flat correspond to the `*` markers in value_tree
value_flat, value_tree = tree_flatten(value_structured)
print(f"{value_flat=}\n{value_tree=}")

# Transform the flat value list using an element-wise numeric transformer
transformed_flat = list(map(lambda v: v * 2., value_flat))
print(f"{transformed_flat=}")

# Reconstruct the structured output, using the original
transformed_structured = tree_unflatten(value_tree, transformed_flat)
print(f"{transformed_structured=}")
value_flat=[1.0, 2.0, 3.0]
value_tree=PyTreeDef([*, (*, *)])
transformed_flat=[2.0, 4.0, 6.0]
transformed_structured=[2.0, (4.0, 6.0)]

By default, pytree containers can be lists, tuples, dicts, namedtuple, None, OrderedDict. Other types of values, including numeric and ndarray values, are treated as leaves:

from collections import namedtuple
Point = namedtuple('Point', ['x', 'y'])

example_containers = [
    (1., [2., 3.]),
    (1., {'b': 2., 'a': 3.}),
    1.,
    None,
    jnp.zeros(2),
    Point(1., 2.)
]
def show_example(structured):
  flat, tree = tree_flatten(structured)
  unflattened = tree_unflatten(tree, flat)
  print(f"{structured=}\n  {flat=}\n  {tree=}\n  {unflattened=}")

for structured in example_containers:
  show_example(structured)
structured=(1.0, [2.0, 3.0])
  flat=[1.0, 2.0, 3.0]
  tree=PyTreeDef((*, [*, *]))
  unflattened=(1.0, [2.0, 3.0])
structured=(1.0, {'b': 2.0, 'a': 3.0})
  flat=[1.0, 3.0, 2.0]
  tree=PyTreeDef((*, {'a': *, 'b': *}))
  unflattened=(1.0, {'a': 3.0, 'b': 2.0})
structured=1.0
  flat=[1.0]
  tree=PyTreeDef(*)
  unflattened=1.0
structured=None
  flat=[]
  tree=PyTreeDef(None)
  unflattened=None
structured=Array([0., 0.], dtype=float32)
  flat=[Array([0., 0.], dtype=float32)]
  tree=PyTreeDef(*)
  unflattened=Array([0., 0.], dtype=float32)
structured=Point(x=1.0, y=2.0)
  flat=[1.0, 2.0]
  tree=PyTreeDef(CustomNode(namedtuple[Point], [*, *]))
  unflattened=Point(x=1.0, y=2.0)
Extending pytrees#

By default, any part of a structured value that is not recognized as an internal pytree node (i.e. container-like) is treated as a leaf:

class Special(object):
  def __init__(self, x, y):
    self.x = x
    self.y = y

  def __repr__(self):
    return "Special(x={}, y={})".format(self.x, self.y)


show_example(Special(1., 2.))
structured=Special(x=1.0, y=2.0)
  flat=[Special(x=1.0, y=2.0)]
  tree=PyTreeDef(*)
  unflattened=Special(x=1.0, y=2.0)

The set of Python types that are considered internal pytree nodes is extensible, through a global registry of types, and values of registered types are traversed recursively. To register a new type, you can use register_pytree_node():

from jax.tree_util import register_pytree_node

class RegisteredSpecial(Special):
  def __repr__(self):
    return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)

def special_flatten(v):
  """Specifies a flattening recipe.

  Params:
    v: the value of registered type to flatten.
  Returns:
    a pair of an iterable with the children to be flattened recursively,
    and some opaque auxiliary data to pass back to the unflattening recipe.
    The auxiliary data is stored in the treedef for use during unflattening.
    The auxiliary data could be used, e.g., for dictionary keys.
  """
  children = (v.x, v.y)
  aux_data = None
  return (children, aux_data)

def special_unflatten(aux_data, children):
  """Specifies an unflattening recipe.

  Params:
    aux_data: the opaque data that was specified during flattening of the
      current treedef.
    children: the unflattened children

  Returns:
    a re-constructed object of the registered type, using the specified
    children and auxiliary data.
  """
  return RegisteredSpecial(*children)

# Global registration
register_pytree_node(
    RegisteredSpecial,
    special_flatten,    # tell JAX what are the children nodes
    special_unflatten   # tell JAX how to pack back into a RegisteredSpecial
)

show_example(RegisteredSpecial(1., 2.))
structured=RegisteredSpecial(x=1.0, y=2.0)
  flat=[1.0, 2.0]
  tree=PyTreeDef(CustomNode(RegisteredSpecial[None], [*, *]))
  unflattened=RegisteredSpecial(x=1.0, y=2.0)

Alternatively, you can define appropriate tree_flatten and tree_unflatten methods on your class and decorate it with register_pytree_node_class():

from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class RegisteredSpecial2(Special):
  def __repr__(self):
    return "RegisteredSpecial2(x={}, y={})".format(self.x, self.y)

  def tree_flatten(self):
    children = (self.x, self.y)
    aux_data = None
    return (children, aux_data)

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children)

show_example(RegisteredSpecial2(1., 2.))
structured=RegisteredSpecial2(x=1.0, y=2.0)
  flat=[1.0, 2.0]
  tree=PyTreeDef(CustomNode(RegisteredSpecial2[None], [*, *]))
  unflattened=RegisteredSpecial2(x=1.0, y=2.0)

When defining unflattening functions, in general children should contain all the dynamic elements of the data structure (arrays, dynamic scalars, and pytrees), while aux_data should contain all the static elements that will be rolled into the treedef structure. JAX sometimes needs to compare treedef for equality, or compute its hash for use in the JIT cache, and so care must be taken to ensure that the auxiliary data specified in the flattening recipe supports meaningful hashing and equality comparisons.
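For example, a container that stores a static string tag in aux_data keeps the treedef hashable and comparable (the Tagged class below is an illustrative sketch, not a JAX API):

from jax.tree_util import register_pytree_node, tree_flatten

class Tagged:
  """Carries a static string tag alongside one dynamic value."""
  def __init__(self, tag, value):
    self.tag = tag      # static: stored in aux_data
    self.value = value  # dynamic: stored in children

register_pytree_node(
    Tagged,
    lambda t: ((t.value,), t.tag),                 # children, aux_data
    lambda tag, children: Tagged(tag, *children),  # rebuild from aux_data + children
)

_, tree1 = tree_flatten(Tagged("weights", 1.0))
_, tree2 = tree_flatten(Tagged("weights", 2.0))
print(tree1 == tree2)  # True: same structure, equal (and hashable) aux_data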

The whole set of functions for operating on pytrees are in jax.tree_util.
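For instance, jax.tree_util.tree_map() applies a function to every leaf while preserving container structure:

import jax

doubled = jax.tree_util.tree_map(lambda x: x * 2, {"a": 1, "b": (2, 3)})
print(doubled)  # {'a': 2, 'b': (4, 6)}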

Custom PyTrees and Initialization#

One common gotcha with user-defined PyTree objects is that JAX transformations occasionally initialize them with unexpected values, so that any input validation done at initialization may fail. For example:

class MyTree:
  def __init__(self, a):
    self.a = jnp.asarray(a)

register_pytree_node(MyTree, lambda tree: ((tree.a,), None),
    lambda _, args: MyTree(*args))

tree = MyTree(jnp.arange(5.0))

jax.vmap(lambda x: x)(tree)      # Error because object() is passed to MyTree.
jax.jacobian(lambda x: x)(tree)  # Error because MyTree(...) is passed to MyTree

In the first case, JAX’s internals use arrays of object() values to infer the structure of the tree; in the second case, the jacobian of a function mapping a tree to a tree is defined as a tree of trees.

For this reason, the __init__ and __new__ methods of custom PyTree classes should generally avoid doing any array conversion or other input validation, or else anticipate and handle these special cases. For example:

class MyTree:
  def __init__(self, a):
    if not (type(a) is object or a is None or isinstance(a, MyTree)):
      a = jnp.asarray(a)
    self.a = a

Another possibility is to structure your tree_unflatten function so that it avoids calling __init__; for example:

def tree_unflatten(aux_data, children):
  del aux_data  # unused in this class
  obj = object.__new__(MyTree)
  obj.a, = children  # unpack the single child stored during flattening
  return obj

If you go this route, make sure that your tree_unflatten function stays in-sync with __init__ if and when the code is updated.

Ahead-of-time lowering and compilation#

JAX offers several transformations, such as jax.jit and jax.pmap, that return a function compiled to run on accelerators or the CPU. As the JIT acronym indicates, all compilation happens just-in-time for execution.

Some situations call for ahead-of-time (AOT) compilation instead. When you want to fully compile prior to execution time, or you want control over when different parts of the compilation process take place, JAX has some options for you.

First, let’s review the stages of compilation. Suppose that f is a function/callable output by jax.jit(), say f = jax.jit(F) for some input callable F. When it is invoked with arguments, say f(x, y) where x and y are arrays, JAX does the following in order:

  1. Stage out a specialized version of the original Python callable F to an internal representation. The specialization reflects a restriction of F to input types inferred from properties of the arguments x and y (usually their shape and element type).

  2. Lower this specialized, staged-out computation to the XLA compiler’s input language, StableHLO.

  3. Compile the lowered HLO program to produce an optimized executable for the target device (CPU, GPU, or TPU).

  4. Execute the compiled executable with the arrays x and y as arguments.

JAX’s AOT API gives you direct control over steps #2, #3, and #4 (but not #1), plus some other features along the way. An example:

>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np

>>> def f(x, y): return 2 * x + y
>>> x, y = 3, 4

>>> lowered = jax.jit(f).lower(x, y)

>>> # Print lowered HLO
>>> print(lowered.as_text())
module @jit_f.0 {
  func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
    %0 = stablehlo.constant dense<2> : tensor<i32>
    %1 = stablehlo.multiply %0, %arg0 : tensor<i32>
    %2 = stablehlo.add %1, %arg1 : tensor<i32>
    return %2 : tensor<i32>
  }
}

>>> compiled = lowered.compile()

>>> # Query for cost analysis, print FLOP estimate
>>> compiled.cost_analysis()[0]['flops']
2.0

>>> # Execute the compiled function!
>>> compiled(x, y)
DeviceArray(10, dtype=int32)

See the jax.stages documentation for more details on what functionality the lowering and compiled functions provide.

In place of jax.jit above, you can also lower(...) the result of jax.pmap(), as well as pjit and xmap (from jax.experimental.pjit and jax.experimental.maps respectively). In each case, you can compile() the result similarly.

All optional arguments to jit—such as static_argnums—are respected in the corresponding lowering, compilation, and execution. Again the same goes for pmap, pjit, and xmap.

In the example above, we can replace the arguments to lower with any objects that have shape and dtype attributes:

>>> i32_scalar = jax.ShapeDtypeStruct((), jnp.dtype('int32'))
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x, y)
DeviceArray(10, dtype=int32)

More generally, lower only needs its arguments to structurally supply what JAX must know for specialization and lowering. For typical array arguments like the ones above, this means shape and dtype fields. For static arguments, by contrast, JAX needs actual array values (more on this below).

Invoking an AOT-compiled function with arguments that are incompatible with its lowering raises an error:

>>> x_1d = y_1d = jnp.arange(3)
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_1d, y_1d)
...
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with int32[3]
Argument 'y' compiled with int32[] and called with int32[3]

>>> x_f = y_f = jnp.float32(72.)
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_f, y_f)
...
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with float32[]
Argument 'y' compiled with int32[] and called with float32[]

Relatedly, AOT-compiled functions cannot be transformed by JAX's transformations, such as jax.jit(), jax.grad(), and jax.vmap().

Lowering with static arguments#

Lowering with static arguments underscores the interaction between options passed to jax.jit, the arguments passed to lower, and the arguments needed to invoke the resulting compiled function. Continuing with our example above:

>>> lowered_with_x = jax.jit(f, static_argnums=0).lower(7, 8)

>>> # Lowered HLO, specialized to the *value* of the first argument (7)
>>> print(lowered_with_x.as_text())
module @jit_f.1 {
  func.func public @main(%arg0: tensor<i32>) -> tensor<i32> {
    %0 = stablehlo.constant dense<14> : tensor<i32>
    %1 = stablehlo.add %0, %arg0 : tensor<i32>
    return %1 : tensor<i32>
  }
}
>>> lowered_with_x.compile()(5)
DeviceArray(19, dtype=int32)

Note that lower here takes two arguments as usual, but the subsequent compiled function accepts only the remaining non-static second argument. The static first argument (value 7) is taken as a constant at lowering time and built into the lowered computation, where it is possibly folded in with other constants. In this case, its multiplication by 2 is simplified, resulting in the constant 14.

Although the second argument to lower above can be replaced by a hollow shape/dtype structure, it is necessary that the static first argument be a concrete value. Otherwise, lowering would err:

>>> jax.jit(f, static_argnums=0).lower(i32_scalar, i32_scalar)
TypeError: unsupported operand type(s) for *: 'int' and 'ShapeDtypeStruct'

>>> jax.jit(f, static_argnums=0).lower(10, i32_scalar).compile()(5)
DeviceArray(25, dtype=int32)

AOT-compiled functions cannot be transformed#

Compiled functions are specialized to a particular set of argument “types,” such as arrays with a specific shape and element type in our running example. From JAX’s internal point of view, transformations such as jax.vmap() alter the type signature of functions in a way that invalidates the compiled-for type signature. As a policy, JAX simply disallows compiled functions to be involved in transformations. Example:

>>> def g(x):
...   assert x.shape == (3, 2)
...   return x @ jnp.ones(2)

>>> def make_z(*shape):
...   return jnp.arange(np.prod(shape)).reshape(shape)

>>> z, zs = make_z(3, 2), make_z(4, 3, 2)

>>> g_jit = jax.jit(g)
>>> g_aot = jax.jit(g).lower(z).compile()

>>> jax.vmap(g_jit)(zs)
DeviceArray([[ 1.,  5.,  9.],
             [13., 17., 21.],
             [25., 29., 33.],
             [37., 41., 45.]], dtype=float32)

>>> jax.vmap(g_aot)(zs)
TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type <class 'jax.interpreters.batching.BatchTracer'>.

A similar error is raised when g_aot is involved in autodiff (e.g. jax.grad()). For consistency, transformation by jax.jit is disallowed as well, even though jit does not meaningfully modify its argument’s type signature.

Debug information and analyses, when available#

In addition to the primary AOT functionality (separate and explicit lowering, compilation, and execution), JAX’s various AOT stages also offer some additional features to help with debugging and gathering compiler feedback.

For instance, as the initial example above shows, lowered functions often offer a text representation. Compiled functions do the same, and also offer cost and memory analyses from the compiler. All of these are provided via methods on the jax.stages.Lowered and jax.stages.Compiled objects (e.g., lowered.as_text() and compiled.cost_analysis() above).

These methods are meant as an aid for manual inspection and debugging, not as a reliably programmable API. Their availability and output vary by compiler, platform, and runtime. This makes for two important caveats:

  1. If some functionality is unavailable on JAX’s current backend, then the method for it returns something trivial (and False-like). For example, if the compiler underlying JAX does not provide a cost analysis, then compiled.cost_analysis() will be None.

  2. If some functionality is available, there are still very limited guarantees on what the corresponding method provides. The return value is not required to be consistent—in type, structure, or value—across JAX configurations, backends/platforms, versions, or even invocations of the method. JAX cannot guarantee that the output of compiled.cost_analysis() on one day will remain the same on the following day.

When in doubt, see the package API documentation for jax.stages.
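
Given these caveats, a defensive pattern like the following sketch may be useful when consuming these analyses. It reuses f from the running example; the list-of-dicts indexing matches the output shown earlier and may differ across JAX versions and backends:

import jax

def f(x, y): return 2 * x + y

compiled = jax.jit(f).lower(3, 4).compile()

# The analysis may be None (or otherwise falsy) if the backend provides none,
# so guard before indexing into it.
cost = compiled.cost_analysis()
if cost:
  print("FLOP estimate:", cost[0].get("flops"))
else:
  print("no cost analysis available on this backend")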

Inspecting staged-out computations#

Stage #1 in the list at the top of this note mentions specialization and staging, prior to lowering. JAX’s internal notion of a function specialized to the types of its arguments is not always a reified data structure in memory. To explicitly construct a view of JAX’s specialization of a function in the internal Jaxpr intermediate language, see jax.make_jaxpr().
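
For instance, a short sketch of inspecting the staged-out form of the running example f; the printed representation is only indicative, since the exact jaxpr formatting varies across JAX versions:

import jax

def f(x, y): return 2 * x + y

# Stage out f, specialized to int arguments, and print its jaxpr.
print(jax.make_jaxpr(f)(3, 4))
# Prints a jaxpr roughly of the form:
#   { lambda ; a:i32[] b:i32[]. let c:i32[] = mul 2 a; d:i32[] = add c b in (d,) }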

JAX Errors#

This page lists a few of the errors you might encounter when using JAX, along with representative examples of how one might fix them.

class jax.errors.ConcretizationTypeError(tracer, context='')#

This error occurs when a JAX Tracer object is used in a context where a concrete value is required (see Different kinds of JAX values for more on what a Tracer is). In some situations, it can be easily fixed by marking problematic values as static; in others, it may indicate that your program is doing operations that are not directly supported by JAX’s JIT compilation model.

Examples:

Traced value where static value is expected

One common cause of this error is using a traced value where a static value is required. For example:

>>> from functools import partial
>>> from jax import jit
>>> import jax.numpy as jnp
>>> @jit
... def func(x, axis):
...   return x.min(axis)
>>> func(jnp.arange(4), 0)  
Traceback (most recent call last):
    ...
ConcretizationTypeError: Abstract tracer value encountered where concrete
value is expected: axis argument to jnp.min().

This can often be fixed by marking the problematic argument as static:

>>> @partial(jit, static_argnums=1)
... def func(x, axis):
...   return x.min(axis)

>>> func(jnp.arange(4), 0)
Array(0, dtype=int32)

Shape depends on Traced Value

Such an error may also arise when a shape in your JIT-compiled computation depends on the values within a traced quantity. For example:

>>> @jit
... def func(x):
...     return jnp.where(x < 0)

>>> func(jnp.arange(4))  
Traceback (most recent call last):
    ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected:
The error arose in jnp.nonzero.

This is an example of an operation that is incompatible with JAX’s JIT compilation model, which requires array sizes to be known at compile-time. Here the size of the returned array depends on the contents of x, and such code cannot be JIT compiled.

In many cases it is possible to work around this by modifying the logic used in the function; for example here is code with a similar issue:

>>> @jit
... def func(x):
...     indices = jnp.where(x > 1)
...     return x[indices].sum()

>>> func(jnp.arange(4))  
Traceback (most recent call last):
    ...
ConcretizationTypeError: Abstract tracer value encountered where concrete
value is expected: The error arose in jnp.nonzero.

And here is how you might express the same operation in a way that avoids creation of a dynamically-sized index array:

>>> @jit
... def func(x):
...   return jnp.where(x > 1, x, 0).sum()

>>> func(jnp.arange(4))
Array(5, dtype=int32)

To understand more subtleties having to do with tracers vs. regular values, and concrete vs. abstract values, you may want to read Different kinds of JAX values.

Parameters:
  • tracer (core.Tracer)

  • context (str)

class jax.errors.KeyReuseError(message)#

This error occurs when a PRNG key is reused in an unsafe manner. Key reuse is checked only when jax_debug_key_reuse is set to True.

Here is a simple example of code that would lead to such an error:

>>> with jax.debug_key_reuse(True):  
...   key = jax.random.key(0)
...   value = jax.random.uniform(key)
...   new_value = jax.random.uniform(key)
...
---------------------------------------------------------------------------
KeyReuseError                             Traceback (most recent call last)
...
KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0

This sort of key reuse is problematic because the JAX PRNG is stateless, and keys must be manually split. For more information, see Sharp Bits: Random Numbers.
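
The fix is to split the key rather than reuse it; a minimal sketch:

>>> key = jax.random.key(0)
>>> key, subkey1, subkey2 = jax.random.split(key, 3)  # split instead of reusing
>>> value = jax.random.uniform(subkey1)
>>> new_value = jax.random.uniform(subkey2)           # independent draw; no key reused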

Parameters:

message (str)

class jax.errors.NonConcreteBooleanIndexError(tracer)#

This error occurs when a program attempts to use non-concrete boolean indices in a traced indexing operation. Under JIT compilation, JAX arrays must have static shapes (i.e. shapes that are known at compile-time) and so boolean masks must be used carefully. Some logic implemented via boolean masking is simply not possible in a jax.jit() function; in other cases, the logic can be re-expressed in a JIT-compatible way, often using the three-argument version of where().

Following are a few examples of when this error might arise.

Constructing arrays via boolean masking

This most commonly arises when attempting to create an array via a boolean mask within a JIT context. For example:

>>> import jax
>>> import jax.numpy as jnp

>>> @jax.jit
... def positive_values(x):
...   return x[x > 0]

>>> positive_values(jnp.arange(-5, 5))  
Traceback (most recent call last):
    ...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])

This function is attempting to return only the positive values in the input array; the size of this returned array cannot be determined at compile-time unless x is marked as static, and so operations like this cannot be performed under JIT compilation.

Reexpressible Boolean Logic

Although creating dynamically sized arrays is not supported directly, in many cases it is possible to re-express the logic of the computation in terms of a JIT-compatible operation. For example, here is another function that fails under JIT for the same reason:

>>> @jax.jit
... def sum_of_positive(x):
...   return x[x > 0].sum()

>>> sum_of_positive(jnp.arange(-5, 5))  
Traceback (most recent call last):
    ...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])

In this case, however, the problematic array is only an intermediate value, and we can instead express the same logic in terms of the JIT-compatible three-argument version of jax.numpy.where():

>>> @jax.jit
... def sum_of_positive(x):
...   return jnp.where(x > 0, x, 0).sum()

>>> sum_of_positive(jnp.arange(-5, 5))
Array(10, dtype=int32)

This pattern of replacing boolean masking with three-argument where() is a common solution to this sort of problem.

Boolean indexing into JAX arrays

The other situation where this error often arises is when using boolean indices, such as with .at[...].set(...). Here is a simple example:

>>> @jax.jit
... def manual_clip(x):
...   return x.at[x < 0].set(0)

>>> manual_clip(jnp.arange(-2, 2))  
Traceback (most recent call last):
    ...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[4])

This function is attempting to set values smaller than zero to a scalar fill value. As above, this can be addressed by re-expressing the logic in terms of where():

>>> @jax.jit
... def manual_clip(x):
...   return jnp.where(x < 0, 0, x)

>>> manual_clip(jnp.arange(-2, 2))
Array([0, 0, 0, 1], dtype=int32)

Parameters:

tracer (core.Tracer)

class jax.errors.TracerArrayConversionError(tracer)#

This error occurs when a program attempts to convert a JAX Tracer object into a standard NumPy array (see Different kinds of JAX values for more on what a Tracer is). It typically occurs in one of a few situations.

Using non-JAX functions in JAX transformations

This error can occur if you attempt to use a non-JAX library like numpy or scipy inside a JAX transformation (jit(), grad(), jax.vmap(), etc.). For example:

>>> from jax import jit
>>> import numpy as np

>>> @jit
... def func(x):
...   return np.sin(x)

>>> func(np.arange(4))  
Traceback (most recent call last):
    ...
TracerArrayConversionError: The numpy.ndarray conversion method
__array__() was called on traced array with shape int32[4]

In this case, you can fix the issue by using jax.numpy.sin() in place of numpy.sin():

>>> import jax.numpy as jnp
>>> @jit
... def func(x):
...   return jnp.sin(x)

>>> func(jnp.arange(4))
Array([0.        , 0.84147096, 0.9092974 , 0.14112   ], dtype=float32)

See also External Callbacks for options for calling back to host-side computations from transformed JAX code.

Indexing a numpy array with a tracer

If this error arises on a line that involves array indexing, it may be that the array being indexed x is a standard numpy.ndarray while the indices idx are traced JAX arrays. For example:

>>> x = np.arange(10)

>>> @jit
... def func(i):
...   return x[i]

>>> func(0)  
Traceback (most recent call last):
    ...
TracerArrayConversionError: The numpy.ndarray conversion method
__array__() was called on traced array with shape int32[0]

Depending on the context, you may fix this by converting the numpy array into a JAX array:

>>> @jit
... def func(i):
...   return jnp.asarray(x)[i]

>>> func(0)
Array(0, dtype=int32)

or by declaring the index as a static argument:

>>> from functools import partial
>>> @partial(jit, static_argnums=(0,))
... def func(i):
...   return x[i]

>>> func(0)
Array(0, dtype=int32)

To understand more subtleties having to do with tracers vs. regular values, and concrete vs. abstract values, you may want to read Different kinds of JAX values.

Parameters:

tracer (core.Tracer)

class jax.errors.TracerBoolConversionError(tracer)#

This error occurs when a traced value in JAX is used in a context where a boolean value is expected (see Different kinds of JAX values for more on what a Tracer is).

The boolean cast may be explicit (e.g. bool(x)) or implicit, through the use of control flow (e.g. if x > 0 or while x), Python boolean operators (e.g. z = x and y, z = x or y, z = not x), or functions that use them (e.g. z = max(x, y), z = min(x, y), etc.).

In some situations, this problem can be easily fixed by marking traced values as static; in others, it may indicate that your program is doing operations that are not directly supported by JAX’s JIT compilation model.

Examples:

Traced value used in control flow

One case where this often arises is when a traced value is used in Python control flow. For example:

>>> from jax import jit
>>> import jax.numpy as jnp
>>> @jit
... def func(x, y):
...   return x if x.sum() < y.sum() else y

>>> func(jnp.ones(4), jnp.zeros(4))  
Traceback (most recent call last):
    ...
TracerBoolConversionError: Attempted boolean conversion of JAX Tracer [...]

We could mark both inputs x and y as static, but that would defeat the purpose of using jax.jit() here. Another option is to re-express the if statement in terms of the three-term jax.numpy.where():

>>> @jit
... def func(x, y):
...   return jnp.where(x.sum() < y.sum(), x, y)

>>> func(jnp.ones(4), jnp.zeros(4))
Array([0., 0., 0., 0.], dtype=float32)

For more complicated control flow including loops, see Control flow operators.

Control flow on traced values

Another common cause of this error is if you inadvertently trace over a boolean flag. For example:

>>> @jit
... def func(x, normalize=True):
...   if normalize:
...     return x / x.sum()
...   return x

>>> func(jnp.arange(5), True)  
Traceback (most recent call last):
    ...
TracerBoolConversionError: Attempted boolean conversion of JAX Tracer ...

Here because the flag normalize is traced, it cannot be used in Python control flow. In this situation, the best solution is probably to mark this value as static:

>>> from functools import partial
>>> @partial(jit, static_argnames=['normalize'])
... def func(x, normalize=True):
...   if normalize:
...     return x / x.sum()
...   return x

>>> func(jnp.arange(5), True)
Array([0. , 0.1, 0.2, 0.3, 0.4], dtype=float32)

For more on static_argnums, see the documentation of jax.jit().

Using non-JAX aware functions

Another common cause of this error is using non-JAX aware functions within JAX code. For example:

>>> @jit
... def func(x):
...   return min(x, 0)
>>> func(2)  
Traceback (most recent call last):
    ...
TracerBoolConversionError: Attempted boolean conversion of JAX Tracer ...

In this case, the error occurs because Python’s built-in min function is not compatible with JAX transformations. This can be fixed by replacing it with jnp.minimum:

>>> @jit
... def func(x):
...   return jnp.minimum(x, 0)
>>> print(func(2))
0

To understand more subtleties having to do with tracers vs. regular values, and concrete vs. abstract values, you may want to read Different kinds of JAX values.

Parameters:

tracer (core.Tracer)

class jax.errors.TracerIntegerConversionError(tracer)#

This error can occur when a JAX Tracer object is used in a context where a Python integer is expected (see Different kinds of JAX values for more on what a Tracer is). It typically occurs in a few situations.

Passing a tracer in place of an integer

This error can occur if you attempt to pass a traced value to a function that requires a static integer argument; for example:

>>> from jax import jit
>>> import numpy as np

>>> @jit
... def func(x, axis):
...   return np.split(x, 2, axis)

>>> func(np.arange(4), 0)  
Traceback (most recent call last):
    ...
TracerIntegerConversionError: The __index__() method was called on
traced array with shape int32[0]

When this happens, the solution is often to mark the problematic argument as static:

>>> from functools import partial
>>> @partial(jit, static_argnums=1)
... def func(x, axis):
...   return np.split(x, 2, axis)

>>> func(np.arange(10), 0)
[Array([0, 1, 2, 3, 4], dtype=int32),
 Array([5, 6, 7, 8, 9], dtype=int32)]

An alternative is to apply the transformation to a closure that encapsulates the arguments to be protected, either manually as below or by using functools.partial():

>>> jit(lambda arr: np.split(arr, 2, 0))(np.arange(4))
[Array([0, 1], dtype=int32), Array([2, 3], dtype=int32)]

Note that a new closure is created at every invocation; this defeats the compilation caching mechanism, which is why static_argnums is preferred.
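
The functools.partial() variant mentioned above might look like this (split_in_two is just an illustrative name):

>>> from functools import partial
>>> split_in_two = jit(partial(np.split, indices_or_sections=2, axis=0))
>>> split_in_two(np.arange(4))
[Array([0, 1], dtype=int32), Array([2, 3], dtype=int32)]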

Indexing a list with a Tracer

This error can occur if you attempt to index a Python list with a traced quantity. For example:

>>> import jax.numpy as jnp
>>> from jax import jit

>>> L = [1, 2, 3]

>>> @jit
... def func(i):
...   return L[i]

>>> func(0)  
Traceback (most recent call last):
    ...
TracerIntegerConversionError: The __index__() method was called on
traced array with shape int32[0]

Depending on the context, you can generally fix this either by converting the list to a JAX array:

>>> @jit
... def func(i):
...   return jnp.array(L)[i]

>>> func(0)
Array(1, dtype=int32)

or by declaring the index as a static argument:

>>> from functools import partial
>>> @partial(jit, static_argnums=0)
... def func(i):
...   return L[i]

>>> func(0)
Array(1, dtype=int32, weak_type=True)

To understand more subtleties having to do with tracers vs. regular values, and concrete vs. abstract values, you may want to read Different kinds of JAX values.

Parameters:

tracer (core.Tracer)

class jax.errors.UnexpectedTracerError(msg)#

This error occurs when you use a JAX value that has leaked out of a function. What does it mean to leak a value? If you use a JAX transformation on a function f that stores, in some scope outside of f, a reference to an intermediate value, that value is considered to have been leaked. Leaking values is a side effect. (Read more about avoiding side effects in Pure Functions)

JAX detects leaks when you then use the leaked value in another operation later on, at which point it raises an UnexpectedTracerError. To fix this, avoid side effects: if a function computes a value needed in an outer scope, return that value from the transformed function explicitly.

Specifically, a Tracer is JAX’s internal representation of a function’s intermediate values during transformations, e.g. within jit(), pmap(), vmap(), etc. Encountering a Tracer outside of a transformation implies a leak.

Life-cycle of a leaked value

Consider the following example of a transformed function which leaks a value to an outer scope:

>>> from jax import jit
>>> import jax.numpy as jnp

>>> outs = []
>>> @jit                   # 1
... def side_effecting(x):
...   y = x + 1            # 3
...   outs.append(y)       # 4

>>> x = 1
>>> side_effecting(x)      # 2
>>> outs[0] + 1            # 5  
Traceback (most recent call last):
    ...
UnexpectedTracerError: Encountered an unexpected tracer.

In this example we leak a Traced value from an inner transformed scope to an outer scope. We get an UnexpectedTracerError when the leaked value is used, not when the value is leaked.

This example also demonstrates the life-cycle of a leaked value:

  1. A function is transformed (in this case, by jit())

  2. The transformed function is called (initiating an abstract trace of the function and turning x into a Tracer)

  3. The intermediate value y, which will later be leaked, is created (an intermediate value of a traced function is also a Tracer)

  4. The value is leaked (appended to a list in an outer scope, escaping the function through a side-channel)

  5. The leaked value is used, and an UnexpectedTracerError is raised.

The UnexpectedTracerError message tries to point to these locations in your code by including information about each stage. Respectively:

  1. The name of the transformed function (side_effecting) and which transform kicked off the trace (jit()).

  2. A reconstructed stack trace of where the leaked Tracer was created, which includes where the transformed function was called. (When the Tracer was created, the final 5 stack frames were...).

  3. From the reconstructed stack trace, the line of code that created the leaked Tracer.

  4. The leak location is not included in the error message because it is difficult to pin down! JAX can only tell you what the leaked value looks like (what shape it has and where it was created) and what boundary it was leaked over (the name of the transformation and the name of the transformed function).

  5. The current error’s stack trace points to where the value is used.

The error can be fixed by returning the value from the transformed function:

>>> from jax import jit
>>> import jax.numpy as jnp

>>> outs = []
>>> @jit
... def not_side_effecting(x):
...   y = x+1
...   return y

>>> x = 1
>>> y = not_side_effecting(x)
>>> outs.append(y)
>>> outs[0] + 1  # all good! no longer a leaked value.
Array(3, dtype=int32, weak_type=True)

Leak checker

As discussed in points 2 and 3 above, JAX shows a reconstructed stack trace that points to where the leaked value was created. This is because JAX only raises an error when the leaked value is used, not when the value is leaked. This is not the most useful place to raise this error, because you need to know the location where the Tracer was leaked to fix the error.

To make this location easier to track down, you can use the leak checker. When the leak checker is enabled, an error is raised as soon as a Tracer is leaked. (To be more exact, it will raise an error when the transformed function from which the Tracer is leaked returns)

To enable the leak checker you can use the JAX_CHECK_TRACER_LEAKS environment variable or the with jax.checking_leaks() context manager.

Note

Note that this tool is experimental and may report false positives. It works by disabling some JAX caches, so it will have a negative effect on performance and should only be used when debugging.

Example usage:

>>> from jax import jit
>>> import jax.numpy as jnp

>>> outs = []
>>> @jit
... def side_effecting(x):
...   y = x+1
...   outs.append(y)

>>> x = 1
>>> with jax.checking_leaks():
...   y = side_effecting(x)  
Traceback (most recent call last):
    ...
Exception: Leaked Trace

Parameters:

msg (str)

Transfer guard#

JAX may transfer data between the host and devices and between devices during type conversion and input sharding. To log or disallow any unintended transfers, the user may configure a JAX transfer guard.

JAX transfer guards distinguish between two types of transfers:

  • Explicit transfers: jax.device_put*() and jax.device_get() calls.

  • Implicit transfers: Other transfers (e.g., printing a DeviceArray).

A transfer guard can take an action based on its guard level:

  • "allow": Silently allow all transfers (default).

  • "log": Log and allow implicit transfers. Silently allow explicit transfers.

  • "disallow": Disallow implicit transfers. Silently allow explicit transfers.

  • "log_explicit": Log and allow all transfers.

  • "disallow_explicit": Disallow all transfers.

JAX will raise a RuntimeError when disallowing a transfer.

The transfer guards use the standard JAX configuration system:

  • A --jax_transfer_guard=GUARD_LEVEL command-line flag and jax.config.update("jax_transfer_guard", GUARD_LEVEL) will set the global option.

  • A with jax.transfer_guard(GUARD_LEVEL): ... context manager will set the thread-local option within the scope of the context manager.

Note that similar to other JAX configuration options, a newly spawned thread will use the global option instead of any active thread-local option of the scope where the thread was spawned.

The transfer guards can also be applied more selectively, based on the direction of transfer. The flag and context manager names are suffixed with the corresponding transfer direction (e.g., --jax_transfer_guard_host_to_device and jax.transfer_guard_host_to_device):

  • "host_to_device": Converting a Python value or NumPy array into a JAX on-device buffer.

  • "device_to_device": Copying a JAX on-device buffer to a different device.

  • "device_to_host": Fetching a JAX on-device buffer.

Fetching a buffer on a CPU device is always allowed regardless of the transfer guard level.

The following shows an example of using the transfer guard.

>>> jax.config.update("jax_transfer_guard", "allow")  # This is default.
>>>
>>> x = jnp.array(1)
>>> y = jnp.array(2)
>>> z = jnp.array(3)
>>>
>>> print("x", x)  # All transfers are allowed.
x 1
>>> with jax.transfer_guard("disallow"):
...   print("x", x)  # x has already been fetched into the host.
...   print("y", jax.device_get(y))  # Explicit transfers are allowed.
...   try:
...     print("z", z)  # Implicit transfers are disallowed.
...     assert False, "This line is expected to be unreachable."
...   except:
...     print("z could not be fetched")  
x 1
y 2
z could not be fetched
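
The direction-specific guards described above work the same way. As a hedged sketch (assuming the suffixed context-manager names follow the convention above; whether a given line is actually logged depends on the backend and configuration):

>>> import numpy as np
>>> host_array = np.arange(4)
>>> with jax.transfer_guard_host_to_device("log"):
...   implicit = jnp.sin(host_array)         # implicit host-to-device transfer: logged
...   explicit = jax.device_put(host_array)  # explicit transfer: silently allowed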

Pallas: a JAX kernel language#

Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU. This section contains tutorials, guides and examples for using Pallas.

Pallas Design#

In this document, we explain the initial Pallas design. This is a snapshot of some of the earlier design decisions made and Pallas’s specific APIs might have changed since.

Introduction#

JAX is being used for a diverse set of workloads, from large scale machine learning to scientific computing. JAX’s success story is as much a success story for XLA, the primary compiler that JAX targets – XLA compiles JAX programs for accelerators and has enabled JAX to scale to the largest ML models. JAX describes logical computations in XLA’s representation, HLO. HLO describes how computations happen logically but not physically. Given a logical HLO computation, XLA decides how that computation is to be executed physically. For a wide variety of ML applications, XLA does a good job of compiling user programs but inevitably some users hit XLA’s limitations. In these cases, we need to provide an “escape hatch” to allow experts to write hand-tuned kernels that outperform XLA at that point in time. Furthermore, advances in ML systems research take some time to be incorporated into XLA and users often want to run ahead with them. Over time, the compiler can incorporate the optimizations that were proven out experimentally through hand-tuned kernels.

XLA does offer the CustomCall mechanism as an escape hatch, but it requires users to write C++ and on GPU it requires users to learn the CUDA programming model. The CUDA programming model is arguably too low-level for many machine learning GPU kernels, like matrix multiplication, and even expert users will have trouble using CUDA to implement efficient matrix multiplication or multi-headed attention. Not only this, JAX users are usually familiar with Python and NumPy-style array programming which doesn’t involve writing any C++ or thinking about GPU parallelism. All popular machine learning frameworks share this idea: manipulating (usually) arrays with high level operations like matmul or convolution. Unfortunately, this means implementing a custom operation via CustomCall is a big investment, involving potentially learning C++ and/or GPU programming.

Triton, a GPU compiler built and maintained by OpenAI, has taken the ML compiler world by storm. Triton offers the best of both worlds: an array-based programming model for GPU kernels. Triton is the primary code generation route for torch.compile in PyTorch 2.0, via the Torch Inductor library. Triton actively hides some aspects of GPU programming in the name of a more accessible programming model that can be used from Python and to generate optimized code from a higher-level representation. While GPUs are more flexible than what Triton offers, in the ML domain, Triton seems to be expressive enough for many applications.

In this document, we describe Pallas, an extension to JAX that enables kernel programming for both GPUs and TPUs using a Triton-like model. A JAX-based kernel language offers several advantages:

  • Although Triton exposes a TPU-like programming model to users, i.e. writing programs for tiles of arrays in L1-cache, it is specialized enough to GPU that we cannot directly compile Triton for TPU. For example, Triton offers atomic operations specifically meant to handle parallel writes that don’t necessarily make sense on TPU. A higher level front end can abstract away details of the platform while surfacing just that tile-based programming model. The kernels will thus be portable across different hardware platforms.

  • JAX as a tracing-based frontend for numerical computing is both mature and well-used. By embedding the kernel programming language in JAX itself, we can re-use JAX’s tracing infrastructure and provide a NumPy-like frontend that’s already familiar to users.

  • JAX transformations are key to its success, allowing users to express simple programs but transform them to achieve complex functionality. We can leverage the same transformations (vmap, jvp, etc.) to transform user-written kernels.

The open question is: is JAX a good fit for a kernel language at all? We think so. Triton demonstrates that an array programming language can be practical for writing GPU kernels and JAX is just that. JAX has also proven to be a flexible front-end for compilers and for program transformations.

We describe Pallas as follows: we first describe the ways in which we extend JAX to support writing custom kernels. We then show how we can lower Pallas to both Triton and Mosaic. We conclude by describing existing and potential ways to transform Pallas kernels via JAX transformations.

Visualization of Pallas lowering paths

Pallas: Extending JAX for kernels#

The key point we’d like to make is that Pallas is just JAX, with some extensions:

  1. Users now use reference types called Refs in their JAX code. This gives users more precise control over memory access, and the layout of arrays in JAX will more closely resemble physical layout.

  2. Users write their JAX programs using a subset of JAX primitives, along with a set of Pallas-specific primitives.

  3. Users embed their Pallas kernels in an outer JAX program via a special pallas_call higher-order function that executes the kernel in a map. It is analogous to pmap or shard_map, except with references to shared memory.

We’ll go over these three extensions one at a time, by example.

Note that these APIs are still experimental and subject to change.

Reference types#

Let’s look at an example Pallas program for adding two vectors:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
  # In this code, `x_ref`, `y_ref` and `o_ref` are (8,)-shaped `Ref`s
  x = x_ref[:]
  y = y_ref[:]
  o_ref[:] = x + y
x, y = jnp.arange(8), jnp.arange(8, 16)
add = pl.pallas_call(add_kernel, out_shape=jax.ShapeDtypeStruct((8,), jnp.int32))
add(x, y)

Unlike a regular JAX program, add_kernel does not receive immutable array arguments. Instead, it’s provided with references that can be read from and updated in-place using NumPy-like syntax. Refs are not a Pallas-specific concept – they were introduced to JAX to represent stateful computations. However, we can leverage them when writing kernels that operate on mutable memory too.

Pallas kernels not only receive Refs corresponding to the inputs to the kernel, but also receive Refs for the outputs as well (specified in pallas_call via out_shape). Refs are special types that cannot be passed into the usual set of JAX primitives without being read from first. When you read from a Ref you get a JAX Array type out, and you must write an Array into a Ref.

Reading from/writing into Refs#

Reading from a Ref corresponds to loading an array into the lowest level of the memory hierarchy (L1-cache on GPU and vector registers on TPU). Writing into a Ref is analogous.

def f(x_ref, o_ref):
  # Using vanilla Python indexing
  x = x_ref[0, 2:5, :]
  # Or via Numpy advanced int indexing
  o_ref[jnp.arange(3), :] = x

# Note that in order to use NumPy advanced int indexing, you need to broadcast the indices against each other into the desired multidimensional shape:
def f(x_ref):
  # Assume x_ref is (8, 4) and we want to read out a (2, 3) slice
  x = x_ref[jnp.arange(2)[..., None], jnp.arange(3)[None, ...]]

Writing to Refs can be done via analogous __setitem__ style indexing.
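
For example, a small sketch mirroring the reads above (a hypothetical kernel body):

def f(x_ref, o_ref):
  # __setitem__-style writes mirror the reads above
  o_ref[0, 2:5, :] = 2 * x_ref[0, 2:5, :]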

Other forms of indexing (for example, dynamic slicing) can be done via pallas.load and pallas.store, new JAX primitives designed to make loading from/storing into memory easier. We’ll discuss these new primitives later.

Extending JAX with new Pallas primitives#

Because JAX was designed with HLO in mind, the set of JAX primitives closely mirrors the set of HLO operations. Targeting a new compiler (e.g. Triton or Mosaic) means we might need to supplement JAX’s primitives with new ones specific to the new compiler. At the same time, we may not be able to lower all JAX primitives, so we need to restrict it to a subset.

Because Pallas was initially designed with Triton in mind, we offer a set of new primitives targeting the Triton programming model. As we’ll show later, we can lower these primitives to Mosaic as well.

pallas.load and pallas.store#

pallas.load and pallas.store are primitives that allow loading from memory and storing into memory. Unlike __getitem__ and __setitem__ they are more flexible at the cost of being more verbose. Specifically, you can use the pallas.dynamic_slice (pallas.ds for short) construct (which should maybe be upstreamed into JAX to be used with Ref __getitem__ and __setitem__).

def f(x_ref, o_ref):
  # Reading from memory via pallas.load
  x = pl.load(x_ref, (0, slice(2, 5), slice(None)))
  # Using integer indexing automatically broadcasts
  x = pl.load(x_ref, (0, 2 + jnp.arange(3), slice(None)))
  # You can also use `pl.dynamic_slice` (`pl.ds` for short) objects as well
  pl.store(o_ref, (0, pl.ds(start=2, size=3), slice(None)), x)

pallas.load and pallas.store also support masking via the mask argument.

def f(x_ref, o_ref):
  # Reading from memory via pallas.load
  idx = jnp.arange(8)
  mask = idx < 5
  x = pl.load(x_ref, (idx,), mask=mask, other=float('-inf'))

Masking is important when doing out-of-bounds loads/stores. The operational semantics of masking can be compiler-determined (if we understand the documentation properly, Triton avoids the read from/write to memory if it’s masked).

pallas.program_id and pallas.num_programs#

As we’ll soon see, we’ll be executing the same Pallas kernels many times (either in parallel or in a pipeline depending on the backend). These new primitives tell us “where” we are in the execution of the kernel.

pallas.program_id takes in an axis argument, which tells us which index in an axis of a multidimensional grid this kernel is currently executing in (analogous to threadId from CUDA programming or lax.axis_index in jax.pmap). Note that we are currently borrowing the “program” terminology from Triton and in the future we might want to change it to something more familiar to JAX users.

def f(x_ref, o_ref):
  i = pl.program_id(axis=0)  # execution index in the first axis of the grid
  o_ref[i] = jnp.exp(x_ref[i])

pallas.num_programs also takes in an axis and returns the grid size for that axis.

Note that while program_id and num_programs are Triton-specific terminology they are easily generalized to make sense on TPU as well.
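
A small sketch using both primitives (a hypothetical kernel, reusing the pl alias from above):

def f(x_ref, o_ref):
  i = pl.program_id(axis=0)
  n = pl.num_programs(0)  # grid size along axis 0
  o_ref[i] = x_ref[i] / n  # e.g. scale each element by the number of programs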

Using a subset of JAX primitives in Pallas#

Because we’re writing kernels, not high-level HLO programs, some JAX primitives may not be able to be represented in our underlying substrate efficiently. However, we know we can support most elementwise operations, simple dot products, and JAX control flow.

While we haven’t yet mapped out exactly all the JAX primitives that we can support in Pallas kernels, we can certainly identify some that are not easy to lower or are unlikely to be useful:

  • conv_general - convolution usually isn’t offered as a primitive in the underlying hardware.

  • gather/scatter - the underlying compiler may not support noncontiguous memory reads and writes.

Executing Pallas kernels with pallas_call#

Now that we’ve written our Pallas kernels (a.k.a. JAX with Refs and the extra Pallas primitives), how do we execute them on a GPU or TPU? We use pallas_call, a higher order function (akin to jax.jit and jax.pmap) that executes the kernel.

The signature of pallas_call is as follows:

def pallas_call(
    kernel: Callable,
    in_specs: Sequence[Spec],
    out_specs: Sequence[Spec],
    out_shapes: Sequence[jax.ShapeDtypeStruct],
    grid: Optional[Tuple[int, ...]] = None) -> Callable:
  ...

When we provide a kernel to pallas_call we provide additional information. The first is out_shape, which tells the kernel what the outputs look like (pallas_call will pass Refs corresponding to these into the kernel to be written to). The rest (in_specs, out_specs, and grid) describes how the kernel will be scheduled on the accelerator.

The (rough) semantics for pallas_call are as follows:

def pallas_call(kernel, in_specs, out_specs, out_shapes, grid):
  def execute(*args):
    outputs = map(empty_ref, out_shapes)
    grid_indices = map(range, grid)
    for indices in itertools.product(*grid_indices): # Could run in parallel!
      local_inputs = [in_spec.transform(arg, indices) for arg, in_spec in
                      zip(args, in_specs)]
      local_outputs = [out_spec.transform(arg, indices) for arg, out_spec  in
                       zip(outputs, out_specs)]
      kernel(*local_inputs, *local_outputs) # writes to outputs
  return execute

Specifically, pallas_call will “loop” over grid iteration space, applying a transformation to the inputs and outputs specified via the in_specs and out_specs. In each iteration, the kernel will be called on the transformed inputs and outputs. Note that the “loop” over the iteration space could be executed in parallel (e.g. on GPU). pallas_call also provides no guarantees on the order of loop iterations over the iteration space, just that every member of the iteration space will be looped over. Compilers like Triton and Mosaic will have more specific operational semantics associated with the grid.

Transformation functions#

The in_specs and out_specs arguments to pallas_call allow inputs and outputs to be transformed in some way. The two options that Pallas offers right now are an identity transformation (where inputs and outputs are left unchanged), and BlockSpecs, which take fixed-size slices of Refs determined by the loop index.

A BlockSpec takes an index_map function and a block_shape. Logically, it takes an array and slices it along each axis into block_shape-sized blocks. The index_map function takes loop indices (from the grid index set) and maps them to block indices. The transform function converts Refs into logical views of the Ref at the corresponding block. When we specify None in an entry in block_shape, that corresponds to “mapping” over that dimension, removing it from the block within the kernel.

class BlockSpec:
  index_map: Callable[[Tuple[Int, ...]], Tuple[Int, ...]]
  block_shape: Tuple[Optional[int], ...]

  def transform(self, ref, *loop_indices):
    block_indices = self.index_map(*loop_indices)
    # Returns a view of `ref` starting at `block_indices` of shape self.block_shape
    ...

We could also imagine other Specs that are used with pallas_call, for example a Spec that corresponds to overlapping windows to, say, implement convolutions.

Immediate benefits of Pallas as a front-end#

By offering a JAX front-end for kernel writing, we can immediately reap some benefits.

More flexible front end#

The first is that JAX users are already accustomed to the benefits (and limitations) of programming with JAX and its tracing-based transformations. This means users can use closures and other familiar Python constructs when writing Pallas kernels. This is unlike the existing AST-parsing-based Triton front end or the MLIR builders for Mosaic. For example, this makes Pallas far more amenable to templating than Triton.

See this example of how we can use higher-order functions in Python to template a kernel.

def make_kernel(eltwise_kernel):
  def add(x_ref, y_ref, o_ref):
    x = pl.load(x_ref, ())
    y = pl.load(y_ref, ())
    pl.store(o_ref, (), eltwise_kernel(x + y))
  return add

kernel1 = make_kernel(lambda x: x * 2)
kernel2 = make_kernel(jnp.exp)

pl.pallas_call(kernel1, out_shape=jax.ShapeDtypeStruct((), jnp.float32), grid=1)(1., 1.)
pl.pallas_call(kernel2, out_shape=jax.ShapeDtypeStruct((), jnp.float32), grid=1)(1., 1.)

Emulation mode#

By representing kernels as programs with JAX primitives and some new Pallas primitives, we can also lower Pallas programs to StableHLO directly and compile/execute them with XLA. Specifically, a pallas_call can be implemented as a lax.scan over the grid. This enables us to develop GPU or TPU kernels on any XLA-supported platform (even CPU!) and debug them using JAX/XLA debugging tools (like jax.debug.print). We can also use the more reliable and better tested XLA numerics to verify the correctness of the Triton and Mosaic compilers. One could also imagine perturbing the scan ordering to simulate the parallel reads and writes that happen on GPU.
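
At the time of writing, pallas_call exposes this emulation via an interpret flag; a hedged sketch reusing the add_kernel defined earlier:

add_interpret = pl.pallas_call(
    add_kernel,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
    interpret=True,  # lower to standard JAX/XLA instead of Triton/Mosaic
)
add_interpret(jnp.arange(8), jnp.arange(8, 16))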

Examples#
add#

We modify our add_kernel example to operate over (2,)-sized blocks using BlockSpecs.

def add_kernel(x_ref, y_ref, o_ref):
  # In this code, `x_ref`, `y_ref` and `o_ref` are (2,)-shaped `Ref`s
  x = x_ref[:]
  y = y_ref[:]
  o_ref[:] = x + y
x, y = jnp.arange(8), jnp.arange(8, 16)
add = pl.pallas_call(
    add_kernel,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
    in_specs=[
      pl.BlockSpec(lambda i: i, (2,)),
      pl.BlockSpec(lambda i: i, (2,))
    ],
    out_specs=pl.BlockSpec(lambda i: i, (2,)),
    grid=(4,))
add(x, y)

Templated matmul#

In this example, we compute tiles of the output by doing an unrolled accumulation over blocks of rows and columns from our input arrays. We inline an activation function into the body of the kernel using a higher order function so we can emit a fused kernel.

def matmul_kernel(x_ref, y_ref, o_ref, *, activation, block_k):
  acc = jnp.zeros((x_ref.shape[0], y_ref.shape[1]), jnp.float32)
  for k in range(x_ref.shape[1] // block_k):
    x = x_ref[:, k*block_k:(k+1)*block_k]
    y = y_ref[k*block_k:(k+1)*block_k, :]
    acc += x @ y
  o_ref[:, :] = activation(acc).astype(o_ref.dtype)

x, y = jnp.ones((512, 256)), jnp.ones((256, 1024))
block_shape = 128, 256, 128

@partial(jax.jit, static_argnames=["block_shape", "activation"])
def matmul(x, y, *, block_shape, activation):
  block_m, block_n, block_k = block_shape
  fused_matmul = pl.pallas_call(
      partial(matmul_kernel, block_k=block_k, activation=activation),
      out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1],), jnp.float32),
      in_specs=[
        pl.BlockSpec(lambda i, j: (i, 0), (block_m, x.shape[1])),
        pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], block_n))
      ],
      out_specs=pl.BlockSpec(lambda i, j: (i, j), (block_m, block_n)),
      grid=(4, 4),
  )
  return fused_matmul(x, y)

z = matmul(x, y, block_shape=block_shape, activation=jax.nn.gelu)

Lowering Pallas#

After users express their Pallas kernels, we lower them to different representations depending on the target backend. On GPUs, we lower Pallas to Triton IR, and on TPU we lower Pallas to Mosaic.

Lowering Pallas to Triton for GPU#

Lowering Pallas to Triton is easy because Pallas was designed with Triton as a target language in mind. The main differences between Pallas and Triton is that Triton doesn’t have a notion of BlockSpecs and also uses pointers when doing memory loads and stores as opposed to indices.

Triton supports pointers as an array element type in its language, and you can load from and store to arrays of pointers. In Pallas, given a (4, 5)-shaped Ref x_ref, an index like x_ref[3, 2] needs to be lowered to a Triton pointer to the appropriate row-major position in x_ref (that is, 5 * 3 + 2 * 1). Similarly, lowering a slice like x_ref[3, :] requires producing an array of pointers 5 * 3 + jnp.arange(5).

Other than that, lowering to Triton is fairly straightforward. JAX dot products can be lowered to Triton dot products and JAX unary primitives are lowered to their Triton equivalents. Triton’s atomic operations are lowered via new Pallas atomic primitives.

Lowering Pallas to Mosaic for TPU#

Mosaic consumes (mostly) standard dialect MLIR and emits LLO to be compiled for TPU. Pallas can be lowered to Mosaic via translating JAX primitives to MLIR (mostly the vector and arith dialects). The BlockSpecs can be converted into pipeline schedules (i.e. the transform_funcs in Mosaic).

Transforming Pallas#

A natural question is how do JAX transformations interact with Pallas kernels? There are two main ways: transformations inside Pallas kernels and transformations outside Pallas kernels.

Transformation inside Pallas kernels should actually “just work”, so long as we are able to lower the transformed code. For example, we could use jax.grad(jnp.sin)(...) inside of a JAX kernel because we can lower a cos to both Triton and Mosaic. However, we might not be able to lower a jax.vmap(lax.dynamic_slice) because it could turn into a gather that we cannot lower.

Transformations of Pallas kernels from the outer JAX programs is perhaps the more interesting case. How do we handle things like vmap(pallas_call) and grad(pallas_call)?

vmap-of-pallas_call#

vmap automatically vectorizes JAX programs. While kernel writers might want precise control over how a batched kernel will behave differently from its unbatched variant, we can offer a reasonable default vmap rule for pallas_call while offering the jax.custom_vmap customization mechanism. When pallas_call is vmap-ed, we augment the pallas_call to have an extra grid dimension corresponding to the new batch dimension and transform the BlockSpecs to handle indexing along that dimension.
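
As a hedged illustration of the default rule, reusing the add pallas_call defined in the Reference types section:

# Batch the earlier `add` pallas_call over a new leading axis via vmap.
batched_x = jnp.arange(24).reshape(3, 8)
batched_y = jnp.arange(24, 48).reshape(3, 8)
jax.vmap(add)(batched_x, batched_y)  # each (8,)-sized row pair is added independently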

grad-of-pallas_call#

grad of pallas_call enables automatic differentiation of kernels. jax.grad breaks down into applications of three distinct transforms: jvp, partial_eval and transpose. In principle, we can re-use most of JAX’s infrastructure when implementing these rules for pallas_call (since it behaves much like existing JAX higher order primitives).

However, automatic differentiation of kernels can result in a performance hit due to how memory access is transposed. If we write a GPU kernel with overlapping-and-parallel reads and disjoint-but-parallel writes, we automatically transpose it into a kernel that has overlapping-but-parallel writes (which are slow when done atomically) and disjoint-and-parallel reads. To emit a kernel that better uses parallelism with shared memory, we would need to reorder loops and change how the kernel is vectorized. Unfortunately, we do not have a program representation amenable to that in Pallas. A potential direction to automatically differentiating kernels efficiently is to explore a different representation, perhaps one like that in Dex. We could also look at how Enzyme approaches this problem. However, AD of Pallas kernels may still be useful for a class of kernels that does transpose efficiently (for example elementwise kernels).

In general, though, jax.custom_vjp is a viable escape hatch to express Pallas kernels that work with jax.grad.

Other transformations#

We could imagine other JAX transformations applying to Pallas kernels that we haven’t explicitly explored yet. For example, checkify is a JAX transformation that does functional error handling. We could imagine using checkify with pallas_call to allow plumbing out error codes from GPU kernels that indicate if OOB access or NaNs were produced.

Another potential transformation to integrate with is custom_partitioning to enable automatically partitionable kernels to be used with pjit.

Pallas Quickstart#

Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU. Pallas allows you to use the same JAX functions and APIs but operates at a lower level of abstraction.

Specifically, Pallas requires users to think about memory access and how to divide up computations across multiple compute units in a hardware accelerator. On GPUs, Pallas lowers to Triton and on TPUs, Pallas lowers to Mosaic.

Let’s dive into some examples.

Note: Pallas is still an experimental API and you may be broken by changes!

Hello world in Pallas#

from functools import partial

import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np

We’ll first write the “hello world” in Pallas, a kernel that adds two vectors.

def add_vectors_kernel(x_ref, y_ref, o_ref):
  x, y = x_ref[...], y_ref[...]
  o_ref[...] = x + y

Ref types

Let’s dissect this function a bit. Unlike most JAX functions you’ve probably written, it does not take in jax.Arrays as inputs and doesn’t return any values. Instead it takes in Ref objects as inputs. Note that we also don’t have any outputs but we are given an o_ref, which corresponds to the desired output.

Reading from Refs

In the body, we are first reading from x_ref and y_ref, indicated by the [...] (the ellipsis means we are reading the whole Ref; alternatively we also could have used x_ref[:]). Reading from a Ref like this returns a jax.Array.

Writing to Refs

We then write x + y to o_ref. Mutation has not historically been supported in JAX – jax.Arrays are immutable! Refs are new (experimental) types that allow mutation under certain circumstances. We can interpret writing to a Ref as mutating its underlying buffer.

So we’ve written what we call a “kernel”, which we define as a program that will run as an atomic unit of execution on an accelerator, without any interaction with the host. How do we invoke it from a JAX computation? We use the pallas_call higher-order function.

@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
  return pl.pallas_call(add_vectors_kernel,
                        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
                        )(x, y)
add_vectors(jnp.arange(8), jnp.arange(8))
Array([ 0,  2,  4,  6,  8, 10, 12, 14], dtype=int32)

pallas_call lifts the Pallas kernel function into an operation that can be called as part of a larger JAX program. But, to do so, it needs a few more details. Here we specify out_shape, an object that has a .shape and .dtype (or a list thereof). out_shape determines the shape/dtype of o_ref in our add_vectors_kernel.

pallas_call returns a function that takes in and returns jax.Arrays.

What’s actually happening here?

Thus far we’ve described how to think about Pallas kernels, but what we’ve actually done is write a function that’s executed very close to the compute units.

On GPU, x_ref corresponds to a value in high-bandwidth memory (HBM) and when we do x_ref[...] we are copying the value from HBM into static RAM (SRAM) (this is a costly operation generally speaking!). We then use GPU vector compute to execute the addition, then copy the resulting value in SRAM back to HBM.

On TPU, we do something slightly different. Before the kernel is ever executed, we fetch the value from HBM into SRAM. x_ref therefore corresponds to a value in SRAM and when we do x_ref[...] we are copying the value from SRAM into a register. We then use TPU vector compute to execute the addition, then copy the resulting value back into SRAM. After the kernel is executed, the SRAM value is copied back into HBM.

We are in the process of writing backend-specific Pallas guides. Coming soon!

Pallas programming model#

In our “hello world” example, we wrote a very simple kernel. It takes advantage of the fact that our 8-sized arrays can comfortably fit inside the SRAM of hardware accelerators. In most real-world applications, this will not be the case!

Part of writing Pallas kernels is thinking about how to take big arrays that live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations that operate on “blocks” of those arrays that can fit in SRAM.

Grids#

To automatically “carve” up the inputs and outputs, you provide a grid and BlockSpecs to pallas_call.

A grid is a tuple of integers (e.g. (), (2, 3, 4), or (8,)) that specifies an iteration space. For example, a grid (4, 5) would have 20 elements: (0, 0), (0, 1), ..., (0, 4), (1, 0), ..., (3, 4). We run the kernel function once for each element, a style of single-program multiple-data (SPMD) programming.

A visualization of a 2D grid

When we provide a grid to pallas_call, the kernel is executed as many times as prod(grid). Each of these invocations is referred to as a “program”. To access which program (i.e., which element of the grid) the kernel is currently executing, we use program_id(axis=...). For example, for invocation (1, 2), program_id(axis=0) returns 1 and program_id(axis=1) returns 2.

Here’s an example kernel that uses a grid and program_id.

def iota_kernel(o_ref):
  i = pl.program_id(0)
  o_ref[i] = i

We now execute it using pallas_call with an additional grid argument.

def iota(size: int):
  return pl.pallas_call(iota_kernel,
                        out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
                        grid=(size,))()
iota(8)
Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)

On GPUs, each program is executed in parallel on separate threads. Thus, we need to think about race conditions on writes to HBM. A reasonable approach is to write our kernels in such a way that different programs write to disjoint locations in HBM, avoiding conflicting writes. On the other hand, parallelizing the computation is how we can execute operations like matrix multiplications really quickly.

On TPUs, programs are executed in a combination of parallel and sequential fashion (depending on the architecture), so there are slightly different considerations.

Block specs#

With grid and program_id in mind, Pallas provides an abstraction that takes care of some common indexing patterns seen in a lot of kernels. To build intuition, let’s try to implement a matrix multiplication.

A simple strategy for implementing a matrix multiplication in Pallas is to implement it recursively. We know our underlying hardware has support for small matrix multiplications (using GPU and TPU tensorcores), so we just express a big matrix multiplication in terms of smaller ones.

Suppose we have input matrices \(X\) and \(Y\) and are computing \(Z = XY\). We first express \(X\) and \(Y\) as block matrices. \(X\) will have “row” blocks and \(Y\) will have “column” blocks.

\[\begin{split} \begin{align*} X = \begin{bmatrix} X_0 \\ X_1 \end{bmatrix} \end{align*} \end{split}\]
\[ \begin{align*} Y = \begin{bmatrix} Y_0 & Y_1 \end{bmatrix} \end{align*} \]
\[\begin{split} \begin{align*} Z &= \begin{bmatrix} X_0 \\ X_1 \end{bmatrix} \begin{matrix} \begin{bmatrix} Y_0 & Y_1 \end{bmatrix} \\ ~ \end{matrix} \\ &= \begin{bmatrix} X_0 Y_0 & X_0 Y_1 \\ X_1 Y_0 & X_1 Y_1 \end{bmatrix} \end{align*} \end{split}\]

Our strategy is that because \(Z\) is also a block matrix, we can assign each of the programs in our Pallas kernel one of the output blocks. Computing each output block corresponds to doing a smaller matrix multiply between a “row” block of \(X\) and a “column” block of \(Y\).

To express this pattern, we use BlockSpecs. A BlockSpec specifies a block shape for each input and output, and an “index map” function, that maps a set of program indices to a block index.

[Figure: a visualization of a BlockSpec]

For a concrete example, let’s say we’d like to multiply two (1024, 1024) matrices x and y together to produce z, and would like to parallelize the computation 4 ways. We split up z into 4 (512, 512) blocks where each block is computed with a (512, 1024) x (1024, 512) matrix multiplication. To express this, we’d first use a (2, 2) grid (one block for each program).

For x, we use BlockSpec(lambda i, j: (i, 0), (512, 1024)) – this carves x up into “row” blocks. To see this, note how both program instances (1, 0) and (1, 1) pick the (1, 0) block of x. For y, we use a transposed version BlockSpec(lambda i, j: (0, j), (1024, 512)). Finally, for z we use BlockSpec(lambda i, j: (i, j), (512, 512)).

These BlockSpecs are passed into pallas_call via in_specs and out_specs.

Underneath the hood, pallas_call will automatically carve up your inputs and outputs into Refs for each block that will be passed into the kernel.

def matmul_kernel(x_ref, y_ref, z_ref):
  z_ref[...] = x_ref[...] @ y_ref[...]

def matmul(x: jax.Array, y: jax.Array):
  return pl.pallas_call(
    matmul_kernel,
    out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),
    grid=(2, 2),
    in_specs=[
      pl.BlockSpec(lambda i, j: (i, 0), (x.shape[0] // 2, x.shape[1])),
      pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], y.shape[1] // 2))
    ],
    out_specs=pl.BlockSpec(
      lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2)
    )
  )(x, y)
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (1024, 1024))
y = jax.random.normal(k2, (1024, 1024))
z = matmul(x, y)
np.testing.assert_allclose(z, x @ y)

Note that this is a very naive implementation of a matrix multiplication, but consider it a starting point for various types of optimizations. Let’s add an additional feature to our matrix multiply: a fused activation. It’s actually really easy! Just pass the activation function into the kernel as an extra argument.

def matmul_kernel(x_ref, y_ref, z_ref, *, activation):
  z_ref[...] = activation(x_ref[...] @ y_ref[...])

def matmul(x: jax.Array, y: jax.Array, *, activation):
  return pl.pallas_call(
    partial(matmul_kernel, activation=activation),
    out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),
    grid=(2, 2),
    in_specs=[
      pl.BlockSpec(lambda i, j: (i, 0), (x.shape[0] // 2, x.shape[1])),
      pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], y.shape[1] // 2))
    ],
    out_specs=pl.BlockSpec(
      lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2)
    ),
  )(x, y)
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (1024, 1024))
y = jax.random.normal(k2, (1024, 1024))
z = matmul(x, y, activation=jax.nn.relu)
np.testing.assert_allclose(z, jax.nn.relu(x @ y))

To conclude, let’s highlight a cool feature of Pallas: it composes with jax.vmap! To turn this matrix multiplication into a batched version, we just need to vmap it.

k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (4, 1024, 1024))
y = jax.random.normal(k2, (4, 1024, 1024))
z = jax.vmap(partial(matmul, activation=jax.nn.relu))(x, y)
np.testing.assert_allclose(z, jax.nn.relu(jax.vmap(jnp.matmul)(x, y)))

Pallas TPU#

TPU specific documentation.

Writing TPU kernels with Pallas#

This page focuses on the details that are important when attempting to run Pallas kernels on Google TPUs. For one, the TPU backend is still in an experimental phase, and only a subset of JAX NumPy will be accepted. Furthermore, writing performant code for TPUs might require thinking carefully about the native capabilities of the hardware. While many patterns that are unnatural to the hardware will be accepted, they might end up requiring software emulation, and can slow down the computation.

Warning

This feature should still be considered experimental as work is still in progress (in particular on improving the error messages).

Note

While all the features described here are experimental, we remain very serious about maintaining their correctness. As such, it is not uncommon to see a “not implemented” error while attempting to write TPU kernels. But, if a kernel is accepted by the compiler, it must return the expected results.

If you see unexpected outputs, please compare them against a kernel run with interpret=True passed in to pallas_call. If the results diverge, please file a bug report.

What is a TPU?#
A TPUv4 board

TPU is a hardware accelerator developed at Google. You can think of TPUs as being like GPUs, but specialized for machine learning workloads. As such, their architecture differs quite significantly. However, we believe that Pallas can make it easy to start writing TPU kernels, even without a full understanding of the underlying hardware. Having said that, understanding the hardware well will certainly make it easier to write performant kernels.

In a nutshell, the main difference between TPUs and GPUs is that TPUs are sequential machines with a very wide vector register (kind of like a CPU!). At the same time, they allow the software to schedule certain operations in the background, making them execute asynchronously with respect to the main instruction stream. This includes things like HBM memory accesses (which cannot be issued directly, but instead have to be prefetched to lower levels of the memory hierarchy by the DMA subunits), matrix multiplies (supported by the MXU unit) or matrix transpositions and permutes (supported by the XLU unit).

If you’re interested in learning more about the TPU architecture in detail, we recommend reading a collection of papers published over the years. While many of them talk about specific TPU generations, many of the ideas described transfer to later generations as well.

Noteworthy properties and restrictions#
BlockSpecs and grid iteration#

BlockSpecs generally behave as expected in Pallas — every invocation of the kernel body gets access to slices of the inputs and is meant to initialize a slice of the output.

Warning

Not all window shapes are supported. If the last two dimensions of your input are larger than 8 and 128 respectively, the window shape in those dimensions must be a multiple of the respective factor. If the input dimension is smaller, the window should span the full dimension.
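For example (shapes assumed purely for illustration, using the BlockSpec(index_map, block_shape) convention from the quickstart): for an f32[1024, 1024] input, a window of (256, 512) satisfies the constraint, while something like (100, 100) does not.

# 256 is a multiple of 8 and 512 is a multiple of 128, so this window shape is accepted.
ok_spec = pl.BlockSpec(lambda i, j: (i, j), (256, 512))
# A window shape such as (100, 100) would violate the constraint and be rejected.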

One interesting aspect of Pallas TPU kernels is the way they handle memory spaces: While the inputs to pallas_call will often reside in HBM (the main TPU memory), the references passed in to the kernel body will point to buffers in lower levels of memory hierarchy (VMEM or SMEM). This enables the kernel body to write and read them at very high speeds, while all the communication with HBM (which has very high latency) is handled by the compiler and overlapped with compute.

What’s more, compared to GPUs, TPUs are actually highly sequential machines. Ergo, the grid is generally not processed in parallel, but sequentially, in lexicographic order (though see the Multicore TPU configurations section for exceptions). This unlocks some interesting capabilities:

  • When two (lexicographically) consecutive grid indices use the same slice of an input, the HBM transfer for the second iteration is skipped, as the data is already available.

  • Multiple invocations of the kernel body can write to the same slice of the output, without any risk of race conditions. However, we do require that all invocations that write to a particular slice are consecutive.

In practice, the “consecutive” restriction on the output usually means that some prefix of the grid dimensions always varies the slice of the output an invocation needs to access, while the output window remains constant over the remaining suffix.

For example, when implementing a Pallas TPU kernel for matrix multiplication, one would generally use a 3-dimensional grid: the first two dimensions correspond to slicing along the first axis of the left operand and the second axis of the right operand. The third and last grid axis would tile the reduction dimension. The grid axis corresponding to the reduction dimension has to be the last one, since the output window does not vary along this axis. The output reference can then be used as an accumulator for partial results.
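Here is a minimal sketch of that pattern (the block sizes bm, bk, bn and the even divisibility of the shapes are assumptions; it uses the BlockSpec(index_map, block_shape) convention of this document and the pl.when helper introduced later in the pipelining guide):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def matmul_acc_kernel(x_ref, y_ref, z_ref):
  # Zero the accumulator on the first step of the reduction axis (the last grid axis).
  @pl.when(pl.program_id(2) == 0)
  def _():
    z_ref[...] = jnp.zeros_like(z_ref)
  # Accumulate the partial product for this grid step.
  z_ref[...] += x_ref[...] @ y_ref[...]

def matmul_acc(x: jax.Array, y: jax.Array, *, bm: int = 256, bk: int = 512, bn: int = 256):
  m, k = x.shape
  _, n = y.shape
  return pl.pallas_call(
      matmul_acc_kernel,
      grid=(m // bm, n // bn, k // bk),  # the reduction axis comes last
      in_specs=[
          pl.BlockSpec(lambda i, j, l: (i, l), (bm, bk)),
          pl.BlockSpec(lambda i, j, l: (l, j), (bk, bn)),
      ],
      # The output window ignores the reduction index l, so it stays constant
      # over the trailing grid axis and can act as an accumulator.
      out_specs=pl.BlockSpec(lambda i, j, l: (i, j), (bm, bn)),
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
  )(x, y)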

Note

VMEM is fairly large (16MB+) for such a low level of the memory hierarchy, making it possible to use large window sizes. And, oftentimes, the larger the window size, the better the eventual hardware utilization will be. However, it is possible to specify a window size that (together with the space necessary to hold spilled vector registers) exceeds the size of VMEM. In this case, you will likely see a low-level compiler error complaining about running out of memory.

Dimension ordering is meaningful#

In JAX programs, the ordering of intermediate arrays inside jax.jit usually has no impact on performance, as the compiler is free to rearrange them. However, as Pallas is meant to expose lower-level capabilities, the dimension order can have great impact on the quality of generated code.

Recall that TPUs perform the bulk of their computation on 2D vector registers. Pallas TPU will only ever consider mapping the last two dimensions of intermediate arrays to those vector register dimensions (sublanes and lanes, respectively). An array of shape (n, 1, 1) is guaranteed to require at least n vector registers to represent. If n becomes too large, this can lead to spills, and potential VMEM OOM errors due to an overly large memory footprint. But it also might not: the low-level compiler is free to rearrange instructions to lower register pressure, and is in fact very good at it. Still, a good rule of thumb is to keep the last two dimensions large (especially the last dimension), while keeping the leading dimensions small.

Multicore TPU configurations#

In newer TPU generations, the two cores on a chip are often abstracted as a single device. To take advantage of multiple cores, Pallas has to break the sequential grid execution guarantees, and will need to parallelize one of the grid axes over cores. This is an opt-in procedure. To allow that, pallas_call requires an extra parameter named dimension_semantics:
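A minimal sketch of what that might look like (the kernel and shapes here are placeholders; in this document the parameter is supplied through compiler_params for the Mosaic backend, as in the pipelining guide below, and the exact spelling may differ between JAX versions):

f = pl.pallas_call(
    some_kernel,  # placeholder kernel
    out_shape=jax.ShapeDtypeStruct((1024, 1024), jnp.float32),
    grid=(8, 4),
    compiler_params=dict(
        mosaic=dict(dimension_semantics=("parallel", "arbitrary"))),
)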

That parameter is a list with as many entries as there are axes in the grid. Only parallel dimensions can be partitioned over cores. As a rule of thumb, a dimension is parallel unless the output window does not vary along it. As such, dimension_semantics is always some number of parallel axes followed by some number of arbitrary axes.

While partitioning a kernel over a 2-core TPU device often leads to a 2x speedup, the speedup can in fact be significantly smaller. This is especially true if different instances of the body have highly varying cost. If all of the expensive steps get mapped to one core, but all the cheap steps are assigned to the other, the second core will sit idle until the first one completes its tasks.

Pallas TPU generally favors partitioning axes of a size that is a multiple of the number of TPU cores, and prefers to partition leading grid axes.

Placing operands in SMEM#

Most of the compute on the TPU will happen on the vector unit. Still, there are many cases where it is useful to perform a number of scalar operations, e.g., to carry out control-flow. For that reason, TPUs come with a separate scalar unit, and a separate scalar memory (SMEM) attached to it. As a rule of thumb, any data used to perform control-flow decisions should be placed in SMEM.

SMEM is a low-latency memory that supports random access, but lets you read and write only 32-bit values with a single instruction (very small compared to the 4 KiB granularity of VMEM transactions, but much more flexible thanks to the lack of alignment requirements!).

The scalar memory is also very useful when implementing kernels that do not access the tiles of inputs in a regular pattern, such as when writing block-sparse kernels. In Pallas, this can be achieved by replacing the grid argument to pallas_call with a grid_spec of PrefetchScalarGridSpec with a non-zero num_scalar_prefetch argument. If num_scalar_prefetch is n, then the first n arguments to pallas_call will be placed in SMEM. No BlockSpecs should be specified for those arguments. But the index maps in the BlockSpecs for all subsequent arguments will receive not only the grid indices, but also the SMEM references to the leading operands.

Note

We are working on implementing examples for this feature. Stay tuned!

Supported data types#

At the moment Pallas TPU only supports the following data types:

  • jnp.float32

  • jnp.bfloat16

  • jnp.int* (all precisions, except for jnp.int4)

  • jnp.uint* (all precisions)

Computation placement#

All scalar (i.e. 0D) arrays will be stored in scalar registers, and operations on them will be executed on the scalar core. All other operations (even those on single-element arrays that are 1D or higher) will be executed on the vector core.

Supported operations#
Matrix multiplication#

Matrix multiplication always produces results in the float32 format. If your inputs are not float32, we recommend using lax.dot with preferred_element_type set to jnp.float32.
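For example, a kernel body along these lines (a sketch; the bfloat16 inputs and the float32 output dtype are assumptions):

def matmul_f32_kernel(x_ref, y_ref, z_ref):
  # x_ref and y_ref hold bfloat16 blocks; accumulate the result in float32
  # (the corresponding out_shape should then use jnp.float32).
  z_ref[...] = jax.lax.dot(
      x_ref[...], y_ref[...], preferred_element_type=jnp.float32)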

When using lax.dot_general, it is possible to fuse transpositions of the last two dimensions of matrix multiplication operands into the operation, which can improve overall kernel performance.

Precision control#

Pallas TPU lowering is aware of jax.default_matmul_precision. For best performance (and lowest precision), use bfloat16. If you care about numerical accuracy, you might want to set the precision to float32.

Warning

Even if you pass in 32-bit operands to a matrix multiplication, they will be rounded to bfloat16 unless float32 precision is requested.
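For instance, one way to request full float32 precision around a Pallas matmul call (a sketch; matmul here stands for any pallas_call-based matrix multiply such as the one in the quickstart):

with jax.default_matmul_precision("float32"):
  z = matmul(x, y)  # 32-bit operands are no longer rounded to bfloat16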

Transposition#

If the value has at least 4 dimensions, arbitrary transpositions of all but the last two axes are free. Otherwise, only the transposition of the last two axes is implemented. Note that some transpositions of the last two dimensions can be fused into matrix multiplication.

Accessing memory#

Arbitrary slices of references can be read or updated, subject to implementation constraints. Currently, no restrictions are placed on inputs that are 32-bit wide, but only some slicing patterns are supported for narrower types. Reads and writes whose start offsets are aligned to, and whose lengths are multiples of, 8 and 128 in the second-to-last and last dimensions respectively are always supported.

Reads and writes to vector memory generally happen on tiles of shape (8, 128). As such, when reading or writing to references that have at least two dimensions, the best performance is achieved when the base offset of the memory access has indices divisible by the tiling, and the size of the read region is a multiple of the tile size.

Elementwise operations#

Many elementwise operations are supported. It is worth noting that the hardware generally only supports elementwise computation using 32-bit types. When loading operands that use lower-precision types, they should generally be upcast to a 32-bit type before applying elementwise ops.
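As a small sketch (the bfloat16 input and output dtypes are assumptions):

def relu_bf16_kernel(x_ref, o_ref):
  # Upcast the bfloat16 block to float32 before the elementwise op,
  # then cast back to bfloat16 for the output.
  x = x_ref[...].astype(jnp.float32)
  o_ref[...] = jnp.maximum(x, 0.0).astype(jnp.bfloat16)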

These operations can vary significantly in cost, however. As such, we outline three categories of supported operations: cheap (🟢), medium (🌕) and expensive (🔴).

Operation              Cost
jnp.add, +             🟢
jnp.sub, -             🟢
jnp.mul, *             🟢
/, //, %               🌕
jnp.max, jnp.min       🟢
jnp.where (select)     🟢
jnp.abs                🟢
|, ^, &, ~             🟢
<<, >>                 🟢
Comparisons (==, …)    🟢
Type casts (.astype)   🟢
jnp.exp                🌕
jnp.tanh               🌕
jnp.pow                🌕
jnp.sin                🔴
jnp.cos                🔴

Many JAX functions are implemented in terms of other JAX primitives, so this list might not be comprehensive. For example, jax.nn.relu is implemented in terms of comparisons and jnp.where, so it will work in Pallas kernels too.

Array constructors#

All constant array constructors are supported (jnp.ones, jnp.zeros, jnp.full). Notably, the jax.random module is not compatible with Pallas as of today.
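If a kernel needs random values, one workaround (a sketch; some_kernel and the shapes are placeholders) is to draw them with jax.random outside the kernel and pass them in as a regular input:

noise = jax.random.normal(jax.random.key(0), (512, 512))
out = pl.pallas_call(
    some_kernel,
    out_shape=jax.ShapeDtypeStruct((512, 512), jnp.float32),
)(noise)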

Reductions#

Sum, maximum and minimum reductions are supported, but only on a single array axis at a time.

Reductions over the last array dimension are generally the slowest. Reductions over the second-to-last dimension are faster, but still slower than over the leading dimensions.

Broadcasting#

The performance characteristics of broadcasting are very similar to those of reductions. Broadcasting along all but the two trailing dimensions is always supported and free. Broadcasting along the second to last dimension is slower, while broadcasting along the last dimension is the slowest.

Reshapes#

As usual, reshapes in all dimensions but the last two dimensions are supported and free.

The only two supported cases in which a reshape can modify the last two dimensions of an array are when (1) some leading dimensions are flattened onto the second-to-last dimension, or (2) the reshape adds a dimension that was just removed by a reduction.
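A couple of illustrative cases (shapes assumed purely for illustration):

x = jnp.ones((4, 16, 128))
a = x.reshape(64, 128)     # OK: leading dimension flattened onto the second-to-last one
s = jnp.sum(x, axis=1)     # shape (4, 128)
b = s.reshape(4, 1, 128)   # OK: re-adds the dimension just removed by the reduction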

Control flow#

The TPU backend features limited support for control flow at the moment. The currently supported functions are cond, fori_loop and for_loop. Note that loop primitives get fully unrolled during compilation, so try to keep loop trip counts reasonably small.

Overusing control flow can lead to significant regressions in low-level code generation, and it is recommended to try to squeeze as many computationally expensive operations into a single basic block as possible.
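Here is a minimal sketch of a kernel using one of the supported loop primitives (the cubing computation is arbitrary; note the deliberately small trip count):

def cube_kernel(x_ref, o_ref):
  x = x_ref[...]
  # fori_loop is fully unrolled at compile time, so keep the trip count small.
  def body(i, acc):
    return acc * x
  o_ref[...] = jax.lax.fori_loop(0, 3, body, jnp.ones_like(x))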

Pipelining and BlockSpecs#

In this guide we’ll cover how memory spaces in TPU work and how to write pipelines in Pallas that overlap memory I/O with compute.

#@title Imports

import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np
TPU and its memory spaces#

A TPU and its TensorCore consist of memory spaces (where arrays can reside), registers (which temporarily store scalar and array values) and compute units (that do computation with values in registers). Below is a diagram of a TPU in which x and y are arrays that live in high-bandwidth memory (HBM):

[Figure: TPU memory space cartoon]

Let’s talk about the components of this diagram in more detail:

  • Memory spaces: A TPU has high-bandwidth memory (HBM) which is what we often think of as “device memory”. There is also vector memory (VMEM), a cache meant for storing vector and array values, and scalar memory (SMEM), a cache designed to store scalar values.

  • Registers: A TensorCore has two main types of registers: vector registers (VREGs) store array values, and scalar registers (SREGs) store scalar values. Values can be loaded into registers from their respective caches (VMEM for VREGs and SMEM for SREGs).

  • Compute units: A TensorCore has a scalar unit, vector unit (VPU) and matrix unit (MXU) that can do numerical computation. Compute units operate on values that live in SREGs and VREGs and output values into those registers as well.

In order to do a vectorized computation on our values x and y that live in HBM, we need to:

  1. Copy the values x and y into VMEM.

  2. Load the values from VMEM into VREGs.

  3. Execute the computation using the VPU or MXU, storing the output in VREGs.

  4. Store the values in the output VREGs into VMEM.

  5. Copy the output values in VMEM back to HBM.

Let’s implement a Pallas function that does just that!

def add_matrices_kernel(x_vmem_ref, y_vmem_ref, z_vmem_ref):
  # Load x and y from VMEM into VREGs
  x_vregs = x_vmem_ref[:, :]
  y_vregs = y_vmem_ref[:, :]
  # Execute a vectorized add
  z_vregs = x_vregs + y_vregs
  # Store the output values in VREGs back into VMEM
  z_vmem_ref[:, :] = z_vregs


def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array:
  # pallas_call will first allocate scratch buffers for `x` and `y` in VMEM.
  # It will then copy `x` and `y` from HBM into VMEM.
  z = pl.pallas_call(
      add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
  )(x, y)
  # pallas_call will also copy the output from VMEM back into HBM.
  return z


x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       ...,
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)

We’ve written two functions: add_matrices_kernel and add_matrices.

add_matrices_kernel operates using Refs that live in VMEM. Loading from a VMEM Ref produces a value that lives in VREGs. Values in VREGs behave like jax.Arrays in that we can use jnp and jax.lax operations on them to produce new values that live in VREGs. When we produce the values we’d like to return, we store them in the output VMEM Ref.

The add_matrices function acts on jax.Arrays and returns a jax.Array. Inside it, we pass x and y into pallas_call. pallas_call is responsible for copying x and y into VMEM and for allocating the VMEM buffers that the kernel operates on (including allocating z_vmem_ref, the output VMEM buffer). After the kernel function is finished running, pallas_call will also copy the value in z_vmem_ref to HBM, resulting in an output jax.Array.

Constraints of using VMEM/SMEM#

Pallas exposes access to lower-level memory spaces like VMEM and SMEM, but writing kernels that utilize them adds some considerations.

  1. Memory capacity. VMEM and SMEM are small! VMEM on v4 TPUs is only 16MiB and SMEM ranges in the tens to hundreds of KiB. If our arrays are too big, we won’t even be able to fit them into VMEM at all. For reference, an f32[2048, 2048] array is 16MiB, so our above kernel won’t scale beyond moderately sized arrays.

  2. Memory bandwidth. Copying to/from HBM and VMEM takes a long time, at least compared to most compute instructions. The add_matrices function above will likely spend more time copying between HBM and VMEM than actually performing the addition itself.

With these two constraints in mind, we’ll have to rethink our strategy for getting performance out of our TPUs.

Primer: Pipelining#

Pipelining our computation offers a way of dealing with both the memory capacity and bandwidth constraints in one fell swoop. What do we mean by pipelining?

The goal is to overlap copies to/from HBM and VMEM with work on our compute units. Naively this is difficult because in our program above we copy all of x and y before we start doing any compute with them, creating a dependence between the copy and the compute.

However, if we can chunk up our computation into several subcomputations (e.g. when we add two matrices, we can express that as addition of “blocks” of the original matrices together), we can now overlap the copies of one of those subcomputations with the compute of the other. Let’s walk through a simple example:

Let’s say we split our arrays x and y into x1, x2 and y1, y2 (for example, splitting along the leading axis, resulting in two (256, 512) arrays for each input). We can now execute the following pipelined computation.

  1. Copy x1 and y1 into VMEM.

  2. Start copying x2 and y2 into VMEM

  3. Load x1, y1 from VMEM into VREGs.

  4. Execute z1 = x1 + y1 using the compute units.

  5. Store z1 into VMEM.

  6. Start copying z1 from VMEM back into HBM.

  7. Wait until x2, y2 have been copied into VMEM.

  8. Load x2, y2 from VMEM into VREGs.

  9. Execute z2 = x2 + y2 using the compute units.

  10. Store z2 into VMEM.

  11. Wait until z1 is copied into HBM.

  12. Start copying z2 from VMEM back into HBM.

  13. Wait until z2 is copied into HBM.

Any time we are doing compute here, we are asynchronously copying something. This means that some of the time spent copying is not wasted.

The two most important numbers for determining how efficient a pipelined computation will be are a) how many floating point operations (FLOPs) we need to execute and b) how many bytes we need to copy to execute that computation. The ratio of these two (FLOPs per byte copied) is called the arithmetic intensity of an operation and determines whether our pipeline will be compute bound or memory bound.
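As a back-of-the-envelope example for the matrix addition above (numbers are illustrative, not measured):

flops = 512 * 512                  # one add per output element
bytes_moved = 3 * 512 * 512 * 4    # two f32 inputs plus one f32 output
intensity = flops / bytes_moved    # ~0.08 FLOPs/byte: firmly memory bound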

Pipelining in Pallas#

How do we implement a pipeline like the one above in Pallas? It seems like a complex sequence of asynchronous data operations and executing kernels that would be a pain to implement manually. Fear not! Pallas offers an API for expressing pipelines without too much boilerplate, namely through grids and BlockSpecs.

grid, a.k.a. kernels in a loop#

See how in the above pipelined example, we are executing the same logic multiple times: steps 3-5 and 8-10 both execute the same operations, only on different inputs. The generalized version of this is a loop in which the same kernel is executed multiple times. pallas_call provides an option to do exactly that.

The number of iterations in the loop is specified via the grid argument to pallas_call. Conceptually:

pl.pallas_call(some_kernel, grid=n)(...)

maps to

for i in range(n):
  # do HBM -> VMEM copies
  some_kernel(...)
  # do VMEM -> HBM copies

Grids can be generalized to be multi-dimensional, corresponding to nested loops. For example,

pl.pallas_call(some_kernel, grid=(n, m))(...)

is equivalent to

for i in range(n):
  for j in range(m):
    # do HBM -> VMEM copies
    some_kernel(...)
    # do VMEM -> HBM copies

This generalizes to any tuple of integers (a length d grid will correspond to d nested loops).

BlockSpec, a.k.a. how to chunk up inputs#

The next piece of information we need to provide Pallas in order to automatically pipeline our computation is information on how to chunk it up. Specifically, we need a mapping from each iteration of the loop to the blocks of our inputs and outputs that it should operate on. A BlockSpec captures exactly this information: a block shape and an index map.

First we pick a block_shape for our inputs. In the pipelining example above, we had (512, 512)-shaped arrays and split them along the leading dimension into two (256, 512)-shaped arrays. In this pipeline, our block_shape would be (256, 512).

We then provide an index_map function that maps the iteration space to the blocks. Specifically, in the aforementioned pipeline, on the first iteration we’d like to select x1 and on the second iteration we’d like to use x2. This can be expressed with the following index_map:

def x_index_map(i):
  return (i, 0)

We’d then construct the BlockSpec:

block_spec = pl.BlockSpec(x_index_map, (256, 512))

The BlockSpecs for y and z will be the same as the one for x.

Putting it together#

We provide these arguments to pallas_call via grid, in_specs and out_specs (in_specs corresponds to the tuple of positional arguments, and out_specs corresponds to the output).

def add_matrices_pipelined(x: jax.Array, y: jax.Array) -> jax.Array:
  block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512))
  return pl.pallas_call(
      add_matrices_kernel,
      out_shape=x,
      in_specs=[block_spec, block_spec],
      out_specs=block_spec,
      grid=(2,))(x, y)

add_matrices_pipelined(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       ...,
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)

We’ve only added a little bit of code to our original function to add automatic pipelining but the BlockSpecs and grid do a lot of heavy lifting!

How does it work? Well, the BlockSpecs provide enough information to start prefetching blocks of our input from HBM into VMEM. For example, if we are starting iteration i of our grid, we can pass i + 1 into the index_map functions to obtain the blocks needed for the next iteration. We can then start an asynchronous copy for those blocks. Similarly for outputs, we can wait for the outputs of the previous iteration to be copied before starting the copy for the current iteration’s outputs.

Parameterizing a pipeline#

It’s common to parameterize the block shapes in our kernel. Block sizes are perhaps the most important parameter to tune when optimizing the performance of Pallas kernels! They give us control over the pipeline (for example, picking smaller blocks adds more iterations to our pipelined loop where each iteration has less work to do).

Furthermore, we could also carve up the inputs and outputs along the 2nd dimension (we are only splitting along the first right now). Let’s write a more general kernel that handles both of these features.

def add_matrices_pipelined_2d(
    x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256
) -> jax.Array:
  m, n = x.shape
  block_spec = pl.BlockSpec(lambda i, j: (i, j), (bm, bn))

  return pl.pallas_call(
      add_matrices_kernel,
      out_shape=x,
      in_specs=[block_spec, block_spec],
      out_specs=block_spec,
      grid=(m // bm, n // bn),
  )(x, y)


np.testing.assert_array_equal(
    add_matrices_pipelined_2d(x, y, bm=256, bn=256), x + y
)
np.testing.assert_array_equal(
    add_matrices_pipelined_2d(x, y, bm=128, bn=128), x + y
)
np.testing.assert_array_equal(
    add_matrices_pipelined_2d(x, y, bm=512, bn=512), x + y
)
Handling reductions#

How would you implement something like jnp.sum using pallas_call? Specifically, we’d like to pipeline across the reduction dimension.

Take the example of reducing a (8, 512, 512)-shaped array to a (512, 512)-shaped one.

x = jnp.ones((8, 512, 512))
jnp.sum(x, axis=0)
Array([[8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       ...,
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.]], dtype=float32)

To do this using pallas_call, we could use a grid of size (8,) and in each iteration i load x[i] into VMEM. Then we could add x[i] to an output VMEM buffer. Let’s implement this naively first.

# Warning: this implementation is incorrect!

def naive_sum_kernel(x_ref, o_ref):
  o_ref[...] += x_ref[...]

def naive_sum(x: jax.Array) -> jax.Array:
  grid, *out_shape = x.shape
  return pl.pallas_call(
      naive_sum_kernel,
      grid=grid,
      # None in `block_shape` means we pick a size of 1 and squeeze it away
      in_specs=[pl.BlockSpec(lambda i: (i, 0, 0), (None, *out_shape))],
      out_specs=pl.BlockSpec(lambda i: (0, 0), out_shape),
      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)
      )(x)
naive_sum(x)
Array([[9., 9., 9., ..., 9., 9., 9.],
       [9., 9., 9., ..., 9., 9., 9.],
       [9., 9., 9., ..., 9., 9., 9.],
       ...,
       [9., 9., 9., ..., 9., 9., 9.],
       [9., 9., 9., ..., 9., 9., 9.],
       [9., 9., 9., ..., 9., 9., 9.]], dtype=float32)

Notice how we’ve set up the BlockSpecs: we’re loading the entirety of the (512, 512) dimensions into VMEM (no pipelining there) but selecting the i-th slice along the leading dimension of x on each iteration in the index_map. We are using a None for that dimension in the block shape, which indicates that we are selecting a singleton dimension from x that we would like to squeeze away in the kernel. Therefore, x_ref is (512, 512)-shaped in VMEM as well.

out_spec uses lambda i: (0, 0) as its index_map, indicating that o_ref is unchanged over the course of the pipeline. This means that we can update its value on each iteration by reading from and writing to it. Or can we? Actually, there is one catch: o_ref is initially garbage, meaning we’ll be accumulating into garbage. This will result in the overall function outputting the incorrect value!

Therefore, whenever we do a reduction in a kernel, we need to make sure to initialize the Ref that is storing the reduced value. We can accomplish this by conditionally writing a value to out_ref when we’re on iteration 0. We can do this with the helper function pl.when, a convenience wrapper around jax.lax.cond, and pl.program_id, which queries the current iteration along a grid axis.

def sum_kernel(x_ref, o_ref):
  @pl.when(pl.program_id(axis=0) == 0)
  def _():
    o_ref[...] = jnp.zeros_like(o_ref)

  o_ref[...] += x_ref[...]

def sum(x: jax.Array) -> jax.Array:
  grid, *out_shape = x.shape
  return pl.pallas_call(
      sum_kernel,
      grid=grid,
      # None in `block_shape` means we pick a size of 1 and squeeze it away
      in_specs=[pl.BlockSpec(lambda i: (i, 0, 0), (None, *out_shape))],
      out_specs=pl.BlockSpec(lambda i: (0, 0), out_shape),
      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)
      )(x)
sum(x)
Array([[8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       ...,
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.]], dtype=float32)

This sum function now outputs the correct values!

One last thing to note about reductions in Pallas is that they must be done in the minormost (rightmost) dimensions of our grid (our grid is 1-dimensional in the above example, so we are reducing over its minormost dimension). This is because the pipeline that Pallas generates from the BlockSpecs, grid and kernel function does not read outputs back from HBM. Once you’ve written an output value back to HBM, you cannot revisit it. Therefore, you cannot accumulate across a grid dimension that requires revisiting an output block, so all reductions need to happen in the rightmost grid dimensions.

TPUs in Megacore configuration#

Some TPU chips have two TensorCores but appear as one device to JAX users. This is called “megacore”. The separate TensorCores have their own separate VMEM, VREGs, SMEM, SREGs and compute units but share HBM.

[Figure: TPU memory space cartoon (Megacore)]

Conceptually, TPUs in Megacore behave like very simple GPUs, i.e. they have only two threads. How do we modify our kernels to utilize both TensorCores simultaneously?

The basic idea is that if we have embarrassingly parallel dimensions in our computation, we can split up those dimensions across the TensorCores. We can indicate which dimensions are parallelizable by providing an annotation to pallas_call called dimension_semantics.

def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:
  block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512))
  return pl.pallas_call(
      add_matrices_kernel,
      out_shape=x,
      in_specs=[block_spec, block_spec],
      out_specs=block_spec,
      grid=(2,),
      compiler_params=dict(mosaic=dict(dimension_semantics=("parallel",))))(
        x, y)

x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices_pipelined_megacore(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       ...,
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)

dimension_semantics should be a tuple of the same length as grid where each entry is either "parallel" or "arbitrary". "parallel" indicates to Pallas that the iterations of the for loop corresponding to that dimension can be executed independently without affecting the correctness of the program. "arbitrary" indicates to Pallas that no assumptions can be made about this grid dimension and it therefore cannot be parallelized.

By specifying dimension_semantics, we now execute the kernel simultaneously on each TensorCore. Pallas will handle splitting up the grid automatically.

Note that Megacore is only currently available on TPU v4 and TPU v5p. Supplying dimension_semantics annotations is a no-op on other platforms, but not specifying it will result in only one TensorCore being used (even if there are more than one available).

Conclusion#

In this guide we covered how to express TPU pipelines using pallas_call, grid and BlockSpecs. We covered how to express nested loops via a multi-dimensional grid and how to handle reductions by initializing our accumulators at the beginning of the reduction. We also learned how to handle Megacore by adding annotations to the kernel.

Exercises left to the reader:

  • Try implementing a sum kernel that pipelines the other dimensions as well

  • Add megacore support to the add kernel and the sum kernel as well.

Advanced Tutorials#

This section contains examples and tutorials on more advanced topics, such as multi-core computation, custom operations, and more in-depth applications.

Copyright 2018 The JAX Authors.

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

Training a Simple Neural Network, with tensorflow/datasets Data Loading#


Forked from neural_network_and_data_loading.ipynb

Let’s combine everything we showed in the quickstart to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use tensorflow/datasets data loading API to load images and labels (because it’s pretty great, and the world doesn’t need yet another data loading library :P).

Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won’t use any neural network libraries or special APIs for building our model.

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

Hyperparameters#

Let’s get a few bookkeeping items out of the way.

# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
  w_key, b_key = random.split(key)
  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 10
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.key(0))

Auto-batching predictions#

Let us first define our prediction function. Note that we’re defining this for a single image example. We’re going to use JAX’s vmap function to automatically handle mini-batches, with no performance penalty.

from jax.scipy.special import logsumexp

def relu(x):
  return jnp.maximum(0, x)

def predict(params, image):
  # per-example predictions
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = relu(outputs)
  
  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits - logsumexp(logits)

Let’s check that our prediction function only works on single images.

# This works on single examples
random_flattened_image = random.normal(random.key(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)
(10,)
# Doesn't work with a batch
random_flattened_images = random.normal(random.key(1), (10, 28 * 28))
try:
  preds = predict(params, random_flattened_images)
except TypeError:
  print('Invalid shapes!')
Invalid shapes!
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None, 0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)
(10, 10)

At this point, we have all the ingredients we need to define our neural network and train it. We’ve built an auto-batched version of predict, which we should be able to use in a loss function. We should be able to use grad to take the derivative of the loss with respect to the neural network parameters. Last, we should be able to use jit to speed up everything.

Utility and loss functions#

def one_hot(x, k, dtype=jnp.float32):
  """Create a one-hot encoding of x of size k."""
  return jnp.array(x[:, None] == jnp.arange(k), dtype)
  
def accuracy(params, images, targets):
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
  preds = batched_predict(params, images)
  return -jnp.mean(preds * targets)

@jit
def update(params, x, y):
  grads = grad(loss)(params, x, y)
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]

Data Loading with tensorflow/datasets#

JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don’t include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let’s just use them instead of reinventing anything. We’ll use the tensorflow/datasets data loader.

import tensorflow as tf
# Ensure TF does not see GPU and grab all GPU memory.
tf.config.set_visible_devices([], device_type='GPU')

import tensorflow_datasets as tfds

data_dir = '/tmp/tfds'

# Fetch full datasets for evaluation
# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
# You can convert them to NumPy arrays (or iterables of NumPy arrays) with tfds.as_numpy
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data['train'], mnist_data['test']
num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c

# Full train set
train_images, train_labels = train_data['image'], train_data['label']
train_images = jnp.reshape(train_images, (len(train_images), num_pixels))
train_labels = one_hot(train_labels, num_labels)

# Full test set
test_images, test_labels = test_data['image'], test_data['label']
test_images = jnp.reshape(test_images, (len(test_images), num_pixels))
test_labels = one_hot(test_labels, num_labels)
print('Train:', train_images.shape, train_labels.shape)
print('Test:', test_images.shape, test_labels.shape)
Train: (60000, 784) (60000, 10)
Test: (10000, 784) (10000, 10)

Training Loop#

import time

def get_train_batches():
  # as_supervised=True gives us the (image, label) as a tuple instead of a dict
  ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
  # You can build up an arbitrary tf.data input pipeline
  ds = ds.batch(batch_size).prefetch(1)
  # tfds.as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
  return tfds.as_numpy(ds)

for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in get_train_batches():
    x = jnp.reshape(x, (len(x), num_pixels))
    y = one_hot(y, num_labels)
    params = update(params, x, y)
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))
Epoch 0 in 28.30 sec
Training set accuracy 0.8400499820709229
Test set accuracy 0.8469000458717346
Epoch 1 in 14.74 sec
Training set accuracy 0.8743667006492615
Test set accuracy 0.8803000450134277
Epoch 2 in 14.57 sec
Training set accuracy 0.8901500105857849
Test set accuracy 0.8957000374794006
Epoch 3 in 14.36 sec
Training set accuracy 0.8991333246231079
Test set accuracy 0.903700053691864
Epoch 4 in 14.20 sec
Training set accuracy 0.9061833620071411
Test set accuracy 0.9087000489234924
Epoch 5 in 14.89 sec
Training set accuracy 0.9113333225250244
Test set accuracy 0.912600040435791
Epoch 6 in 13.95 sec
Training set accuracy 0.9156833291053772
Test set accuracy 0.9176000356674194
Epoch 7 in 13.32 sec
Training set accuracy 0.9192000031471252
Test set accuracy 0.9214000701904297
Epoch 8 in 13.55 sec
Training set accuracy 0.9222500324249268
Test set accuracy 0.9241000413894653
Epoch 9 in 13.40 sec
Training set accuracy 0.9253666996955872
Test set accuracy 0.9269000291824341

We’ve now used most of the JAX API: grad for derivatives, jit for speedups and vmap for auto-vectorization. We used NumPy to specify all of our computation, and borrowed the great data loaders from tensorflow/datasets, and ran the whole thing on the GPU.

Training a Simple Neural Network, with PyTorch Data Loading#


Copyright 2018 The JAX Authors.

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

Let’s combine everything we showed in the quickstart to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch’s data loading API to load images and labels (because it’s pretty great, and the world doesn’t need yet another data loading library).

Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won’t use any neural network libraries or special APIs for building our model.

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

Hyperparameters#

Let’s get a few bookkeeping items out of the way.

# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
  w_key, b_key = random.split(key)
  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 8
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.key(0))

Auto-batching predictions#

Let us first define our prediction function. Note that we’re defining this for a single image example. We’re going to use JAX’s vmap function to automatically handle mini-batches, with no performance penalty.

from jax.scipy.special import logsumexp

def relu(x):
  return jnp.maximum(0, x)

def predict(params, image):
  # per-example predictions
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = relu(outputs)
  
  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits - logsumexp(logits)

Let’s check that our prediction function only works on single images.

# This works on single examples
random_flattened_image = random.normal(random.key(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)
(10,)
# Doesn't work with a batch
random_flattened_images = random.normal(random.key(1), (10, 28 * 28))
try:
  preds = predict(params, random_flattened_images)
except TypeError:
  print('Invalid shapes!')
Invalid shapes!
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None, 0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)
(10, 10)

At this point, we have all the ingredients we need to define our neural network and train it. We’ve built an auto-batched version of predict, which we should be able to use in a loss function. We should be able to use grad to take the derivative of the loss with respect to the neural network parameters. Last, we should be able to use jit to speed up everything.

Utility and loss functions#

def one_hot(x, k, dtype=jnp.float32):
  """Create a one-hot encoding of x of size k."""
  return jnp.array(x[:, None] == jnp.arange(k), dtype)
  
def accuracy(params, images, targets):
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
  preds = batched_predict(params, images)
  return -jnp.mean(preds * targets)

@jit
def update(params, x, y):
  grads = grad(loss)(params, x, y)
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]

Data Loading with PyTorch#

JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don’t include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let’s just use them instead of reinventing anything. We’ll grab PyTorch’s data loader, and make a tiny shim to make it work with NumPy arrays.

!pip install torch torchvision
Requirement already satisfied: torch in /opt/anaconda3/lib/python3.7/site-packages (1.4.0)
Requirement already satisfied: torchvision in /opt/anaconda3/lib/python3.7/site-packages (0.5.0)
Requirement already satisfied: numpy in /opt/anaconda3/lib/python3.7/site-packages (from torchvision) (1.17.2)
Requirement already satisfied: six in /opt/anaconda3/lib/python3.7/site-packages (from torchvision) (1.12.0)
Requirement already satisfied: pillow>=4.1.1 in /opt/anaconda3/lib/python3.7/site-packages (from torchvision) (6.2.0)
import numpy as np
from jax.tree_util import tree_map
from torch.utils import data
from torchvision.datasets import MNIST

def numpy_collate(batch):
  return tree_map(np.asarray, data.default_collate(batch))

class NumpyLoader(data.DataLoader):
  def __init__(self, dataset, batch_size=1,
                shuffle=False, sampler=None,
                batch_sampler=None, num_workers=0,
                pin_memory=False, drop_last=False,
                timeout=0, worker_init_fn=None):
    super(self.__class__, self).__init__(dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=numpy_collate,
        pin_memory=pin_memory,
        drop_last=drop_last,
        timeout=timeout,
        worker_init_fn=worker_init_fn)

class FlattenAndCast(object):
  def __call__(self, pic):
    return np.ravel(np.array(pic, dtype=jnp.float32))
# Define our dataset, using torch datasets
mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz
Extracting /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting /tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw
Processing...
Done!
# Get the full train dataset (for checking accuracy while training)
train_images = np.array(mnist_dataset.train_data).reshape(len(mnist_dataset.train_data), -1)
train_labels = one_hot(np.array(mnist_dataset.train_labels), n_targets)

# Get full test dataset
mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False)
test_images = jnp.array(mnist_dataset_test.test_data.numpy().reshape(len(mnist_dataset_test.test_data), -1), dtype=jnp.float32)
test_labels = one_hot(np.array(mnist_dataset_test.test_labels), n_targets)

/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:55: UserWarning: train_data has been renamed data
  warnings.warn("train_data has been renamed data")
/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:45: UserWarning: train_labels has been renamed targets
  warnings.warn("train_labels has been renamed targets")
/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:60: UserWarning: test_data has been renamed data
  warnings.warn("test_data has been renamed data")
/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:50: UserWarning: test_labels has been renamed targets
  warnings.warn("test_labels has been renamed targets")

Training Loop#

import time

for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in training_generator:
    y = one_hot(y, n_targets)
    params = update(params, x, y)
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))
Epoch 0 in 55.15 sec
Training set accuracy 0.9157500267028809
Test set accuracy 0.9195000529289246
Epoch 1 in 42.26 sec
Training set accuracy 0.9372166991233826
Test set accuracy 0.9384000301361084
Epoch 2 in 44.37 sec
Training set accuracy 0.9491666555404663
Test set accuracy 0.9469000697135925
Epoch 3 in 41.75 sec
Training set accuracy 0.9568166732788086
Test set accuracy 0.9534000158309937
Epoch 4 in 41.16 sec
Training set accuracy 0.9631333351135254
Test set accuracy 0.9577000737190247
Epoch 5 in 38.89 sec
Training set accuracy 0.9675000309944153
Test set accuracy 0.9616000652313232
Epoch 6 in 40.68 sec
Training set accuracy 0.9708333611488342
Test set accuracy 0.9650000333786011
Epoch 7 in 41.50 sec
Training set accuracy 0.973716676235199
Test set accuracy 0.9672000408172607

We’ve now used the whole of the JAX API: grad for derivatives, jit for speedups and vmap for auto-vectorization. We used NumPy to specify all of our computation, and borrowed the great data loaders from PyTorch, and ran the whole thing on the GPU.

Autobatching for Bayesian Inference#


This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.

Inspired by a notebook by @davmre.

import functools
import itertools
import re
import sys
import time

from matplotlib.pyplot import *

import jax

from jax import lax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import random

import numpy as np
import scipy as sp

Generate a fake binary classification dataset#

np.random.seed(10009)

num_features = 10
num_points = 100

true_beta = np.random.randn(num_features).astype(jnp.float32)
all_x = np.random.randn(num_points, num_features).astype(jnp.float32)
y = (np.random.rand(num_points) < sp.special.expit(all_x.dot(true_beta))).astype(jnp.int32)
y
array([0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0,
       1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0,
       1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
       0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1,
       1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0], dtype=int32)

Write the log-joint function for the model#

We’ll write a non-batched version, a manually batched version, and an autobatched version.

Non-batched#
def log_joint(beta):
    result = 0.
    # Note that no `axis` parameter is provided to `jnp.sum`.
    result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.))
    result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta))))
    return result
log_joint(np.random.randn(num_features))
Array(-213.2356, dtype=float32)
# This doesn't work, because we didn't write `log_joint()` to handle batching.
try:
  batch_size = 10
  batched_test_beta = np.random.randn(batch_size, num_features)

  log_joint(np.random.randn(batch_size, num_features))
except ValueError as e:
  print("Caught expected exception " + str(e))
Caught expected exception Incompatible shapes for broadcasting: shapes=[(100,), (100, 10)]
Manually batched#
def batched_log_joint(beta):
    result = 0.
    # Here (and below) `sum` needs an `axis` parameter. At best, forgetting to set axis
    # or setting it incorrectly yields an error; at worst, it silently changes the
    # semantics of the model.
    result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.),
                              axis=-1)
    # Note the multiple transposes. Getting this right is not rocket science,
    # but it's also not totally mindless. (I didn't get it right on the first
    # try.)
    result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta.T).T)),
                              axis=-1)
    return result
batch_size = 10
batched_test_beta = np.random.randn(batch_size, num_features)

batched_log_joint(batched_test_beta)
Array([-147.84033 , -207.02205 , -109.26075 , -243.80833 , -163.0291  ,
       -143.84848 , -160.28773 , -113.771706, -126.60544 , -190.81992 ],      dtype=float32)
Autobatched with vmap#

It just works.

vmap_batched_log_joint = jax.vmap(log_joint)
vmap_batched_log_joint(batched_test_beta)
Array([-147.84033 , -207.02205 , -109.26075 , -243.80833 , -163.0291  ,
       -143.84848 , -160.28773 , -113.771706, -126.60544 , -190.81992 ],      dtype=float32)

Self-contained variational inference example#

A little code is copied from above.

Set up the (batched) log-joint function#
@jax.jit
def log_joint(beta):
    result = 0.
    # Note that no `axis` parameter is provided to `jnp.sum`.
    result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=10.))
    result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta))))
    return result

batched_log_joint = jax.jit(jax.vmap(log_joint))
Define the ELBO and its gradient#
def elbo(beta_loc, beta_log_scale, epsilon):
    beta_sample = beta_loc + jnp.exp(beta_log_scale) * epsilon
    return jnp.mean(batched_log_joint(beta_sample), 0) + jnp.sum(beta_log_scale - 0.5 * np.log(2*np.pi))
 
elbo = jax.jit(elbo)
elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))
Optimize the ELBO using SGD#
def normal_sample(key, shape):
    """Convenience function for quasi-stateful RNG."""
    new_key, sub_key = random.split(key)
    return new_key, random.normal(sub_key, shape)

normal_sample = jax.jit(normal_sample, static_argnums=(1,))

key = random.key(10003)

beta_loc = jnp.zeros(num_features, jnp.float32)
beta_log_scale = jnp.zeros(num_features, jnp.float32)

step_size = 0.01
batch_size = 128
epsilon_shape = (batch_size, num_features)
for i in range(1000):
    key, epsilon = normal_sample(key, epsilon_shape)
    elbo_val, (beta_loc_grad, beta_log_scale_grad) = elbo_val_and_grad(
        beta_loc, beta_log_scale, epsilon)
    beta_loc += step_size * beta_loc_grad
    beta_log_scale += step_size * beta_log_scale_grad
    if i % 10 == 0:
        print('{}\t{}'.format(i, elbo_val))
0	-180.8538818359375
10	-113.06045532226562
20	-102.73727416992188
30	-99.787353515625
40	-98.90898132324219
50	-98.29745483398438
60	-98.18632507324219
70	-97.57972717285156
80	-97.28599548339844
90	-97.46996307373047
100	-97.4771728515625
110	-97.5806655883789
120	-97.4943618774414
130	-97.50271606445312
140	-96.86396026611328
150	-97.44197845458984
160	-97.06941223144531
170	-96.84028625488281
180	-97.21336364746094
190	-97.56503295898438
200	-97.26397705078125
210	-97.11979675292969
220	-97.39595031738281
230	-97.16831970214844
240	-97.118408203125
250	-97.24345397949219
260	-97.29788970947266
270	-96.69286346435547
280	-96.96438598632812
290	-97.30055236816406
300	-96.63591766357422
310	-97.0351791381836
320	-97.52909088134766
330	-97.28811645507812
340	-97.07321166992188
350	-97.15619659423828
360	-97.25881958007812
370	-97.19515228271484
380	-97.13092041015625
390	-97.11726379394531
400	-96.938720703125
410	-97.26676940917969
420	-97.35322570800781
430	-97.21007537841797
440	-97.28434753417969
450	-97.1630859375
460	-97.2612533569336
470	-97.21343994140625
480	-97.23997497558594
490	-97.14913940429688
500	-97.23527526855469
510	-96.93419647216797
520	-97.21209716796875
530	-96.82575988769531
540	-97.01284790039062
550	-96.94175720214844
560	-97.16520690917969
570	-97.29165649414062
580	-97.42941284179688
590	-97.24370574951172
600	-97.15222930908203
610	-97.49844360351562
620	-96.9906997680664
630	-96.88956451416016
640	-96.89968872070312
650	-97.13793182373047
660	-97.43705749511719
670	-96.99235534667969
680	-97.15623474121094
690	-97.1869125366211
700	-97.11160278320312
710	-97.78105163574219
720	-97.23226165771484
730	-97.16206359863281
740	-96.99581909179688
750	-96.6672134399414
760	-97.16795349121094
770	-97.51435089111328
780	-97.28900146484375
790	-96.91226196289062
800	-97.17100524902344
810	-97.29047393798828
820	-97.16242980957031
830	-97.19107055664062
840	-97.56382751464844
850	-97.00194549560547
860	-96.86555480957031
870	-96.76338195800781
880	-96.83660888671875
890	-97.12178039550781
900	-97.09554290771484
910	-97.0682373046875
920	-97.11947631835938
930	-96.87930297851562
940	-97.45624542236328
950	-96.69279479980469
960	-97.29376220703125
970	-97.3353042602539
980	-97.34962463378906
990	-97.09675598144531
Display the results#

Coverage isn’t quite as good as we might like, but it’s not bad, and nobody said variational inference was exact.

figure(figsize=(7, 7))
plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')
plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label=r'Approximated Posterior $2\sigma$ Error Bars')
plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')
plot_scale = 3
plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')
xlabel('True beta')
ylabel('Estimated beta')
legend(loc='best')
<matplotlib.legend.Legend at 0x7fd32443e410>

Using JAX in multi-host and multi-process environments#

Introduction#

This guide explains how to use JAX in environments such as GPU clusters and Cloud TPU pods where accelerators are spread across multiple CPU hosts or JAX processes. We’ll refer to these as “multi-process” environments.

This guide specifically focuses on how to use collective communication operations (e.g. jax.lax.psum() ) in multi-process settings, although other communication methods may be useful too depending on your use case (e.g. RPC, mpi4jax). If you’re not already familiar with JAX’s collective operations, we recommend starting with the Introduction to sharded computation section. An important requirement of multi-process environments in JAX is direct communication links between accelerators, e.g. the high-speed interconnects for Cloud TPUs or NCCL for GPUs. These links allow collective operations to run across multiple processes’ worth of accelerators with high performance.

Multi-process programming model#

Key concepts:

  • You must run at least one JAX process per host.

  • You should initialize the cluster with jax.distributed.initialize().

  • Each process has a distinct set of local devices it can address. The global devices are the set of all devices across all processes.

  • Use standard JAX parallelism APIs like pmap() and xmap(). Each process “sees” local input and output to parallelized functions, but communication inside the computations is global.

  • Make sure all processes run the same parallel computations in the same order.

Launching JAX processes#

Unlike other distributed systems where a single controller node manages many worker nodes, JAX uses a “multi-controller” programming model where each JAX Python process runs independently, sometimes referred to as a Single Program, Multiple Data (SPMD) model. Generally, the same JAX Python program is run in each process, with only slight differences between each process’s execution (e.g. different processes will load different input data). Furthermore, you must manually run your JAX program on each host! JAX doesn’t automatically start multiple processes from a single program invocation.
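
For example, here’s a minimal sketch of how the same program might be launched once per host. The script name, flag names, and argument parsing are hypothetical; in practice these values often come from your cluster scheduler’s environment.

# Hypothetical script `train.py`, started manually (or by a scheduler) on each host, e.g.:
#   host 0:  python train.py --coordinator=192.168.0.1:1234 --num_processes=2 --process_id=0
#   host 1:  python train.py --coordinator=192.168.0.1:1234 --num_processes=2 --process_id=1
import argparse
import jax

parser = argparse.ArgumentParser()
parser.add_argument('--coordinator', required=True)
parser.add_argument('--num_processes', type=int, required=True)
parser.add_argument('--process_id', type=int, required=True)
args = parser.parse_args()

jax.distributed.initialize(coordinator_address=args.coordinator,
                           num_processes=args.num_processes,
                           process_id=args.process_id)
# From here on, every process runs the same (SPMD) JAX program.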

(The requirement for multiple processes is why this guide isn’t offered as a notebook – we don’t currently have a good way to manage multiple Python processes from a single notebook.)

Initializing the cluster#

To initialize the cluster, you should call jax.distributed.initialize() at the start of each process. jax.distributed.initialize() must be called early in the program, before any JAX computations are executed.

The API jax.distributed.initialize() takes several arguments, namely:

  • coordinator_address: the IP address of process 0 in your cluster, together with a port available on that process. Process 0 will start a JAX service exposed via that IP address and port, to which the other processes in the cluster will connect.

  • coordinator_bind_address: the IP address and port to which the JAX service on process 0 in your cluster will bind. By default, it will bind to all available interfaces using the same port as coordinator_address.

  • num_processes: the number of processes in the cluster

  • process_id: the ID number of this process, in the range [0 .. num_processes).

  • local_device_ids: Restricts the visible devices of the current process to local_device_ids.

For example on GPU, a typical usage is:

import jax

jax.distributed.initialize(coordinator_address="192.168.0.1:1234",
                           num_processes=2,
                           process_id=0)

In Cloud TPU, Slurm, and Open MPI environments, you can simply call jax.distributed.initialize() with no arguments; default values for the arguments will be chosen automatically. When running on GPUs with Slurm and Open MPI, it is assumed that one process is started per GPU, i.e. each process will be assigned only one visible local device. Otherwise it is assumed that one process is started per host, i.e. each process will be assigned all local devices. The Open MPI auto-initialization is used only when the JAX processes are launched via mpirun/mpiexec.

import jax

jax.distributed.initialize()

On TPU at present calling jax.distributed.initialize() is optional, but recommended since it enables additional checkpointing and health checking features.

Local vs. global devices#

Before we get to running multi-process computations from your program, it’s important to understand the distinction between local and global devices.

A process’s local devices are those that it can directly address and launch computations on. For example, on a GPU cluster, each host can only launch computations on the directly attached GPUs. On a Cloud TPU pod, each host can only launch computations on the 8 TPU cores attached directly to that host (see the Cloud TPU System Architecture documentation for more details). You can see a process’s local devices via jax.local_devices().

The global devices are the devices across all processes. A computation can span devices across processes and perform collective operations via the direct communication links between devices, as long as each process launches the computation on its local devices. You can see all available global devices via jax.devices(). A process’s local devices are always a subset of the global devices.
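
As a quick illustration, once the cluster has been initialized each process can inspect both sets; this is only a sketch, and the printed values depend on your cluster.

# Run on every process after jax.distributed.initialize().
print(jax.process_index(), jax.process_count())  # this process's index and the total number of processes
print(jax.local_devices())                       # only the devices attached to this process
print(jax.devices())                             # all devices, across every process
assert set(jax.local_devices()) <= set(jax.devices())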

Running multi-process computations#

So how do you actually run a computation involving cross-process communication? Use the same parallel evaluation APIs that you would in a single process!

For example, shard_map() or pmap() can be used to run a parallel computation across multiple processes. (If you’re not already familiar with how to use these APIs to run across multiple devices within a single process, check out the Introduction to sharded computation tutorial.) Each process should call the same mapped function and pass in arguments to be mapped across its local devices (i.e., the mapped axis size is equal to the number of local devices). Similarly, the function will return outputs sharded across local devices only. Inside the function, however, collective communication operations are run across all global devices, across all processes. Conceptually, this can be thought of as running a pmap over a single array sharded across hosts, where each host “sees” only its local shard of the input and output.

Here’s an example of multi-process pmap in action:

# The following is run in parallel on each host on a GPU cluster or TPU pod slice.
>>> import jax
>>> jax.distributed.initialize()  # On GPU, see above for the necessary arguments.
>>> jax.device_count()  # total number of accelerator devices in the cluster
32
>>> jax.local_device_count()  # number of accelerator devices attached to this host
8
# The psum is performed over all mapped devices across the pod slice
>>> xs = jax.numpy.ones(jax.local_device_count())
>>> jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
ShardedDeviceArray([32., 32., 32., 32., 32., 32., 32., 32.], dtype=float32)

xmap() works similarly when using a physical hardware mesh (see the xmap tutorial if you’re not familiar with the single-process version). Like pmap() , the inputs and outputs are local and any parallel communication inside the xmapped function is global. The mesh is also global.

It’s very important that all processes run the same cross-process computations in the same order. Running the same JAX Python program in each process is usually sufficient. Some common pitfalls to look out for that may cause differently-ordered computations despite running the same program:

  • Processes passing differently-shaped inputs to the same parallel function can cause hangs or incorrect return values. Differently-shaped inputs are safe so long as they result in identically-shaped per-device data shards across processes; e.g. passing in different leading batch sizes in order to run on different numbers of local devices per process is ok, but having each process pad its batch to a different max example length is not.

  • “Last batch” issues where a parallel function is called in a (training) loop, and one or more processes exit the loop earlier than the rest. This will cause the rest to hang waiting for the already-finished processes to start the computation.

  • Conditions based on non-deterministic ordering of collections can cause processes to hang. For example, iterating over a set (on any Python version), or over a dict before Python 3.7, may produce a different ordering on different processes, even with the same insertion order (see the sketch below).
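
For example, here is a minimal sketch of the ordering pitfall and a simple fix; parallel_steps and its toy step functions are hypothetical stand-ins for your own parallel computations.

import jax
import jax.numpy as jnp

# Hypothetical mapping from step name to a parallel (pmapped) computation.
parallel_steps = {
    'encode': jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i'),
    'decode': jax.pmap(lambda x: jax.lax.pmean(x, 'i'), axis_name='i'),
}
xs = jnp.ones(jax.local_device_count())

# Iterating over the dict directly could, on old Python versions, visit the steps in a
# different order on different processes and deadlock the collectives. Iterating in an
# explicit, deterministic order keeps every process in lockstep.
for name in sorted(parallel_steps):
    parallel_steps[name](xs)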

Distributed arrays and automatic parallelization#


This tutorial discusses parallelism via jax.Array, the unified array object model available in JAX v0.4.1 and newer.

import os

import functools
from typing import Optional

import numpy as np

import jax
import jax.numpy as jnp

⚠️ WARNING: The notebook requires 8 devices to run.

if len(jax.local_devices()) < 8:
  raise Exception("Notebook requires 8 devices to run")

Intro and a quick example#

By reading this tutorial notebook, you’ll learn about jax.Array, a unified datatype for representing arrays, even with physical storage spanning multiple devices. You’ll also learn about how using jax.Arrays together with jax.jit can provide automatic compiler-based parallelization.

Before we think step by step, here’s a quick example. First, we’ll create a jax.Array sharded across multiple devices:

from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding
# Create a Sharding object to distribute a value across devices:
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
# Create an array of random values:
x = jax.random.normal(jax.random.key(0), (8192, 8192))
# and use jax.device_put to distribute it across devices:
y = jax.device_put(x, sharding.reshape(4, 2))
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘

Next, we’ll apply a computation to it and visualize how the result values are stored across multiple devices too:

z = jnp.sin(y)
jax.debug.visualize_array_sharding(z)
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘

The evaluation of the jnp.sin application was automatically parallelized across the devices on which the input values (and output values) are stored:

# `x` is present on a single device
%timeit -n 5 -r 5 jnp.sin(x).block_until_ready()
The slowest run took 13.32 times longer than the fastest. This could mean that an intermediate result is being cached 
5 loops, best of 5: 9.69 ms per loop
# `y` is sharded across 8 devices.
%timeit -n 5 -r 5 jnp.sin(y).block_until_ready()
5 loops, best of 5: 1.86 ms per loop

Now let’s look at each of these pieces in more detail!

Sharding describes how array values are laid out in memory across devices#

Sharding basics, and the PositionalSharding subclass#

To parallelize computation across multiple devices, we first must lay out input data across multiple devices.

In JAX, Sharding objects describe distributed memory layouts. They can be used with jax.device_put to produce a value with distributed layout.

For example, here’s a value with a single-device Sharding:

import jax
x = jax.random.normal(jax.random.key(0), (8192, 8192))
jax.debug.visualize_array_sharding(x)
┌───────────────────────┐
│                       │
│                       │
│                       │
│                       │
│         TPU 0         │
│                       │
│                       │
│                       │
│                       │
└───────────────────────┘

Here, we’re using the jax.debug.visualize_array_sharding function to show where the value x is stored in memory. All of x is stored on a single device, so the visualization is pretty boring!

But we can shard x across multiple devices by using jax.device_put and a Sharding object. First, we make a numpy.ndarray of Devices using mesh_utils.create_device_mesh, which takes hardware topology into account for the Device order:

from jax.experimental import mesh_utils
devices = mesh_utils.create_device_mesh((8,))

Then, we create a PositionalSharding and use it with device_put:

from jax.sharding import PositionalSharding

sharding = PositionalSharding(devices)

x = jax.device_put(x, sharding.reshape(8, 1))
jax.debug.visualize_array_sharding(x)
┌───────────────────────┐
│         TPU 0         │
├───────────────────────┤
│         TPU 1         │
├───────────────────────┤
│         TPU 2         │
├───────────────────────┤
│         TPU 3         │
├───────────────────────┤
│         TPU 6         │
├───────────────────────┤
│         TPU 7         │
├───────────────────────┤
│         TPU 4         │
├───────────────────────┤
│         TPU 5         │
└───────────────────────┘

Here sharding is a PositionalSharding which acts like an array with sets of devices as elements:

sharding
PositionalSharding([{TPU 0} {TPU 1} {TPU 2} {TPU 3} {TPU 6} {TPU 7} {TPU 4} {TPU 5}])

The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the devices.

By writing PositionalSharding(ndarray_of_devices), we fix the device order and the initial shape. Then we can reshape it:

sharding.reshape(8, 1)
PositionalSharding([[{TPU 0}]
                    [{TPU 1}]
                    [{TPU 2}]
                    [{TPU 3}]
                    [{TPU 6}]
                    [{TPU 7}]
                    [{TPU 4}]
                    [{TPU 5}]])
sharding.reshape(4, 2)
PositionalSharding([[{TPU 0} {TPU 1}]
                    [{TPU 2} {TPU 3}]
                    [{TPU 6} {TPU 7}]
                    [{TPU 4} {TPU 5}]])

To use device_put with a data array x, we can reshape the sharding into a shape that is congruent with x.shape, meaning a shape with the same length as x.shape and where each element evenly divides the corresponding element of x.shape:

from typing import Sequence

def is_congruent(x_shape: Sequence[int], sharding_shape: Sequence[int]) -> bool:
  return (len(x_shape) == len(sharding_shape) and
          all(d1 % d2 == 0 for d1, d2 in zip(x_shape, sharding_shape)))
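
For instance, a quick check with the shapes used in this notebook:

print(is_congruent((8192, 8192), (4, 2)))   # True: both 4 and 2 evenly divide 8192
print(is_congruent((8192, 8192), (8,)))     # False: the ranks differ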

For example, we can reshape sharding to have shape (4, 2), then use it in a device_put:

sharding = sharding.reshape(4, 2)
print(sharding)
PositionalSharding([[{TPU 0} {TPU 1}]
                    [{TPU 2} {TPU 3}]
                    [{TPU 6} {TPU 7}]
                    [{TPU 4} {TPU 5}]])
y = jax.device_put(x, sharding)
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘

Here y represents the same value as x, but its shards (i.e. slices) are stored in different devices’ memories.

Different PositionalSharding shapes result in different distributed layouts (i.e. shardings) of the result:

sharding = sharding.reshape(1, 8)
print(sharding)
PositionalSharding([[{TPU 0} {TPU 1} {TPU 2} {TPU 3} {TPU 6} {TPU 7} {TPU 4} {TPU 5}]])
y = jax.device_put(x, sharding)
jax.debug.visualize_array_sharding(y)
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 6 │ TPU 7 │ TPU 4 │ TPU 5 │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘

In some cases, we don’t just want to store each slice of x in a single device’s memory; we might want to replicate some slices, meaning storing copies of a slice’s values in multiple devices’ memories.

With PositionalSharding, we can express replication by calling the reducer method replicate:

sharding = sharding.reshape(4, 2)
print(sharding.replicate(axis=0, keepdims=True))
PositionalSharding([[{TPU 0, 2, 4, 6} {TPU 1, 3, 5, 7}]])
y = jax.device_put(x, sharding.replicate(axis=0, keepdims=True))
jax.debug.visualize_array_sharding(y)
┌───────────┬───────────┐
│           │           │
│           │           │
│           │           │
│           │           │
│TPU 0,2,4,6│TPU 1,3,5,7│
│           │           │
│           │           │
│           │           │
│           │           │
└───────────┴───────────┘

Here the visualization shows that x is sharded two ways along its second dimension (and not sharded along the first dimension), and each of those shards is replicated four ways (i.e. stored in four device memories).

The replicate method is analogous to the familiar NumPy array reduction methods like .sum() and .prod(). It operates along an axis performing a set union. So if sharding has shape (4, 2), then sharding.replicate(0, keepdims=True) has shape (1, 2), and sharding.replicate(1, keepdims=True) has shape (4, 1). Unlike analogous NumPy methods, keepdims=True is actually the default, so reduced-over axes aren’t squeezed:

print(sharding.replicate(0).shape)
print(sharding.replicate(1).shape)
(1, 2)
(4, 1)
y = jax.device_put(x, sharding.replicate(1))
jax.debug.visualize_array_sharding(y)
┌───────────────────────┐
│        TPU 0,1        │
├───────────────────────┤
│        TPU 2,3        │
├───────────────────────┤
│        TPU 6,7        │
├───────────────────────┤
│        TPU 4,5        │
└───────────────────────┘
NamedSharding gives a way to express shardings with names#

So far we’ve worked with PositionalSharding, but there are alternative ways to express shardings. In fact, Sharding is an interface, and any class that implements that interface can be used with functions like device_put.

Another convenient way to express sharding is with the NamedSharding:

from jax.sharding import Mesh
from jax.sharding import PartitionSpec
from jax.sharding import NamedSharding
from jax.experimental import mesh_utils

P = PartitionSpec

devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('a', 'b'))
y = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘

We can define a helper function to make things simpler:

devices = mesh_utils.create_device_mesh((4, 2))
default_mesh = Mesh(devices, axis_names=('a', 'b'))

def mesh_sharding(
    pspec: PartitionSpec, mesh: Optional[Mesh] = None,
  ) -> NamedSharding:
  if mesh is None:
    mesh = default_mesh
  return NamedSharding(mesh, pspec)
y = jax.device_put(x, mesh_sharding(P('a', 'b')))
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘

Here, we use P('a', 'b') to express that the first and second axes of x should be sharded over the device mesh axes 'a' and 'b', respectively. We can easily switch to P('b', 'a') to shard the axes of x over different devices:

y = jax.device_put(x, mesh_sharding(P('b', 'a')))
jax.debug.visualize_array_sharding(y)
┌───────┬───────┬───────┬───────┐
│       │       │       │       │
│ TPU 0 │ TPU 2 │ TPU 6 │ TPU 4 │
│       │       │       │       │
│       │       │       │       │
├───────┼───────┼───────┼───────┤
│       │       │       │       │
│ TPU 1 │ TPU 3 │ TPU 7 │ TPU 5 │
│       │       │       │       │
│       │       │       │       │
└───────┴───────┴───────┴───────┘
# This `None` means that `x` is not sharded on its second dimension,
# and since the Mesh axis name 'b' is not mentioned, shards are
# replicated across it.
y = jax.device_put(x, mesh_sharding(P('a', None)))
jax.debug.visualize_array_sharding(y)
┌───────────────────────┐
│        TPU 0,1        │
├───────────────────────┤
│        TPU 2,3        │
├───────────────────────┤
│        TPU 6,7        │
├───────────────────────┤
│        TPU 4,5        │
└───────────────────────┘

Here, because P('a', None) doesn’t mention the Mesh axis name 'b', we get replication over the axis 'b'. The None here is just acting as a placeholder to line up against the second axis of the value x, without expressing sharding over any mesh axis. (As a shorthand, trailing Nones can be omitted, so that P('a', None) means the same thing as P('a'). But it doesn’t hurt to be explicit!)
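
As a quick check of that shorthand (reusing mesh_sharding and x from above), dropping the trailing None produces the same layout:

y_shorthand = jax.device_put(x, mesh_sharding(P('a')))
jax.debug.visualize_array_sharding(y_shorthand)   # same picture as for P('a', None) above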

To shard only over the second axis of x, we can use a None placeholder in the PartitionSpec:

y = jax.device_put(x, mesh_sharding(P(None, 'b')))
jax.debug.visualize_array_sharding(y)
┌───────────┬───────────┐
│           │           │
│           │           │
│           │           │
│           │           │
│TPU 0,2,4,6│TPU 1,3,5,7│
│           │           │
│           │           │
│           │           │
│           │           │
└───────────┴───────────┘
y = jax.device_put(x, mesh_sharding(P(None, 'a')))
jax.debug.visualize_array_sharding(y)
┌───────┬───────┬───────┬───────┐
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│TPU 0,1│TPU 2,3│TPU 6,7│TPU 4,5│
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
└───────┴───────┴───────┴───────┘

For a fixed mesh, we can even partition one logical axis of x over multiple device mesh axes:

y = jax.device_put(x, mesh_sharding(P(('a', 'b'), None)))
jax.debug.visualize_array_sharding(y)
┌───────────────────────┐
│         TPU 0         │
├───────────────────────┤
│         TPU 1         │
├───────────────────────┤
│         TPU 2         │
├───────────────────────┤
│         TPU 3         │
├───────────────────────┤
│         TPU 6         │
├───────────────────────┤
│         TPU 7         │
├───────────────────────┤
│         TPU 4         │
├───────────────────────┤
│         TPU 5         │
└───────────────────────┘

Using NamedSharding makes it easy to define a device mesh once and give its axes names, then just refer to those names in PartitionSpecs for each device_put as needed.

Computation follows data sharding and is automatically parallelized#

With sharded input data, the compiler can give us parallel computation. In particular, functions decorated with jax.jit can operate over sharded arrays without copying data onto a single device. Instead, computation follows sharding: based on the sharding of the input data, the compiler decides shardings for intermediates and output values, and parallelizes their evaluation, even inserting communication operations as necessary.

For example, the simplest computation is an elementwise one:

from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
x = jax.device_put(x, sharding.reshape(4, 2))
print('input sharding:')
jax.debug.visualize_array_sharding(x)

y = jnp.sin(x)
print('output sharding:')
jax.debug.visualize_array_sharding(y)
input sharding:
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘
output sharding:
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘

Here for the elementwise operation jnp.sin the compiler chose the output sharding to be the same as the input. Moreover, the compiler automatically parallelized the computation, so that each device computed its output shard from its input shard in parallel.

In other words, even though we wrote the jnp.sin computation as if a single machine were to execute it, the compiler splits up the computation for us and executes it on multiple devices.

We can do the same for more than just elementwise operations too. Consider a matrix multiplication with sharded inputs:

y = jax.device_put(x, sharding.reshape(4, 2).replicate(1))
z = jax.device_put(x, sharding.reshape(4, 2).replicate(0))
print('lhs sharding:')
jax.debug.visualize_array_sharding(y)
print('rhs sharding:')
jax.debug.visualize_array_sharding(z)

w = jnp.dot(y, z)
print('out sharding:')
jax.debug.visualize_array_sharding(w)
lhs sharding:
┌───────────────────────┐
│        TPU 0,1        │
├───────────────────────┤
│        TPU 2,3        │
├───────────────────────┤
│        TPU 6,7        │
├───────────────────────┤
│        TPU 4,5        │
└───────────────────────┘
rhs sharding:
┌───────────┬───────────┐
│           │           │
│           │           │
│           │           │
│           │           │
│TPU 0,2,4,6│TPU 1,3,5,7│
│           │           │
│           │           │
│           │           │
│           │           │
└───────────┴───────────┘
out sharding:
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘

Here the compiler chose the output sharding so that it could maximally parallelize the computation: without needing communication, each device already has the input shards it needs to compute its output shard.

How can we be sure it’s actually running in parallel? We can do a simple timing experiment:

x_single = jax.device_put(x, jax.devices()[0])
jax.debug.visualize_array_sharding(x_single)
┌───────────────────────┐
│                       │
│                       │
│                       │
│                       │
│         TPU 0         │
│                       │
│                       │
│                       │
│                       │
└───────────────────────┘
np.allclose(jnp.dot(x_single, x_single),
            jnp.dot(y, z))
True
%timeit -n 5 -r 5 jnp.dot(x_single, x_single).block_until_ready()
5 loops, best of 5: 19.3 ms per loop
%timeit -n 5 -r 5 jnp.dot(y, z).block_until_ready()
5 loops, best of 5: 3.25 ms per loop

Even copying a sharded Array produces a result with the sharding of the input:

w_copy = jnp.copy(w)
jax.debug.visualize_array_sharding(w_copy)
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘

So computation follows data placement: when we explicitly shard data with jax.device_put, and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding. This policy for sharded data is a generalization of JAX’s policy of following explicit device placement.

When explicit shardings disagree, JAX errors#

But what if two arguments to a computation are explicitly placed on different sets of devices, or with incompatible device orders? In these ambiguous cases, an error is raised:

import textwrap
from termcolor import colored

def print_exception(e):
  name = colored(f'{type(e).__name__}', 'red')
  print(textwrap.fill(f'{name}: {str(e)}'))
sharding1 = PositionalSharding(jax.devices()[:4])
sharding2 = PositionalSharding(jax.devices()[4:])

y = jax.device_put(x, sharding1.reshape(2, 2))
z = jax.device_put(x, sharding2.reshape(2, 2))
try: y + z
except ValueError as e: print_exception(e)
ValueError: Devices of all `Array` inputs and outputs should
be the same. Got array device ids [0, 1, 2, 3] on platform TPU and
another array's device ids [4, 5, 6, 7] on platform TPU
devices = jax.devices()
permuted_devices = [devices[i] for i in [0, 1, 2, 3, 6, 7, 4, 5]]

sharding1 = PositionalSharding(devices)
sharding2 = PositionalSharding(permuted_devices)

y = jax.device_put(x, sharding1.reshape(4, 2))
z = jax.device_put(x, sharding2.reshape(4, 2))
try: y + z
except ValueError as e: print_exception(e)
ValueError: Devices of all `Array` inputs and outputs should
be the same. Got array device ids [0, 1, 2, 3, 4, 5, 6, 7] on platform
TPU and another array's device ids [0, 1, 2, 3, 6, 7, 4, 5] on
platform TPU

We say arrays that have been explicitly placed or sharded with jax.device_put are committed to their device(s), and so won’t be automatically moved. See the device placement FAQ for more information.

When arrays are not explicitly placed or sharded with jax.device_put, they are placed uncommitted on the default device. Unlike committed arrays, uncommitted arrays can be moved and resharded automatically: that is, uncommitted arrays can be arguments to a computation even if other arguments are explicitly placed on different devices.

For example, the outputs of jnp.zeros, jnp.arange, and jnp.array are uncommitted:

y = jax.device_put(x, sharding1.reshape(4, 2))
y + jnp.ones_like(y)
y + jnp.arange(y.size).reshape(y.shape)
print('no error!')
no error!

Constraining shardings of intermediates in jitted code#

While the compiler will attempt to decide how a function’s intermediate values and outputs should be sharded, we can also give it hints using jax.lax.with_sharding_constraint. Using jax.lax.with_sharding_constraint is much like jax.device_put, except we use it inside staged-out (i.e. jit-decorated) functions:

sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
x = jax.random.normal(jax.random.key(0), (8192, 8192))
x = jax.device_put(x, sharding.reshape(4, 2))
@jax.jit
def f(x):
  x = x + 1
  y = jax.lax.with_sharding_constraint(x, sharding.reshape(2, 4))
  return y
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘
┌───────┬───────┬───────┬───────┐
│       │       │       │       │
│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │
│       │       │       │       │
│       │       │       │       │
├───────┼───────┼───────┼───────┤
│       │       │       │       │
│ TPU 6 │ TPU 7 │ TPU 4 │ TPU 5 │
│       │       │       │       │
│       │       │       │       │
└───────┴───────┴───────┴───────┘
@jax.jit
def f(x):
  x = x + 1
  y = jax.lax.with_sharding_constraint(x, sharding.replicate())
  return y
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘
┌───────────────────────┐
│                       │
│                       │
│                       │
│                       │
│  TPU 0,1,2,3,4,5,6,7  │
│                       │
│                       │
│                       │
│                       │
└───────────────────────┘

By adding with_sharding_constraint, we’ve constrained the sharding of the output. In addition to respecting the annotation on a particular intermediate, the compiler will use annotations to decide shardings for other values.

It’s often a good practice to annotate the outputs of computations, for example based on how the values are ultimately consumed.
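
For example, here is a minimal sketch of that practice; row_norms is a hypothetical computation whose (equally hypothetical) consumer wants the result fully replicated on every device.

@jax.jit
def row_norms(x):
  norms = jnp.sqrt(jnp.sum(x * x, axis=1))
  # Annotate the output based on how it will be consumed downstream.
  return jax.lax.with_sharding_constraint(norms, sharding.replicate())

jax.debug.visualize_array_sharding(row_norms(x))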

Examples: neural networks#

⚠️ WARNING: The following is meant to be a simple demonstration of automatic sharding propagation with jax.Array, but it may not reflect best practices for real examples. For instance, real examples may require more use of with_sharding_constraint.

We can use jax.device_put and jax.jit’s computation-follows-sharding features to parallelize computation in neural networks. Here are some simple examples, based on this basic neural network:

import jax
import jax.numpy as jnp
def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.maximum(outputs, 0)
  return outputs

def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
loss_jit = jax.jit(loss)
gradfun = jax.jit(jax.grad(loss))
def init_layer(key, n_in, n_out):
    k1, k2 = jax.random.split(key)
    W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
    b = jax.random.normal(k2, (n_out,))
    return W, b

def init_model(key, layer_sizes, batch_size):
    key, *keys = jax.random.split(key, len(layer_sizes))
    params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))

    key, *keys = jax.random.split(key, 3)
    inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
    targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))

    return params, (inputs, targets)

layer_sizes = [784, 8192, 8192, 8192, 10]
batch_size = 8192

params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)
8-way batch data parallelism#
sharding = PositionalSharding(jax.devices()).reshape(8, 1)
batch = jax.device_put(batch, sharding)
params = jax.device_put(params, sharding.replicate())
loss_jit(params, batch)
Array(23.469475, dtype=float32)
step_size = 1e-5

for _ in range(30):
  grads = gradfun(params, batch)
  params = [(W - step_size * dW, b - step_size * db)
            for (W, b), (dW, db) in zip(params, grads)]

print(loss_jit(params, batch))
10.760101
%timeit -n 5 -r 5 gradfun(params, batch)[0][0].block_until_ready()
5 loops, best of 5: 26.3 ms per loop
batch_single = jax.device_put(batch, jax.devices()[0])
params_single = jax.device_put(params, jax.devices()[0])
%timeit -n 5 -r 5 gradfun(params_single, batch_single)[0][0].block_until_ready()
5 loops, best of 5: 122 ms per loop
4-way batch data parallelism and 2-way model tensor parallelism#
sharding = sharding.reshape(4, 2)
batch = jax.device_put(batch, sharding.replicate(1))
jax.debug.visualize_array_sharding(batch[0])
jax.debug.visualize_array_sharding(batch[1])
┌───────┐
│TPU 0,1│
├───────┤
│TPU 2,3│
├───────┤
│TPU 4,5│
├───────┤
│TPU 6,7│
└───────┘
┌───────┐
│TPU 0,1│
├───────┤
│TPU 2,3│
├───────┤
│TPU 4,5│
├───────┤
│TPU 6,7│
└───────┘
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params

W1 = jax.device_put(W1, sharding.replicate())
b1 = jax.device_put(b1, sharding.replicate())

W2 = jax.device_put(W2, sharding.replicate(0))
b2 = jax.device_put(b2, sharding.replicate(0))

W3 = jax.device_put(W3, sharding.replicate(0).T)
b3 = jax.device_put(b3, sharding.replicate())

W4 = jax.device_put(W4, sharding.replicate())
b4 = jax.device_put(b4, sharding.replicate())

params = (W1, b1), (W2, b2), (W3, b3), (W4, b4)
jax.debug.visualize_array_sharding(W2)
┌───────────┬───────────┐
│           │           │
│           │           │
│           │           │
│           │           │
│TPU 0,2,4,6│TPU 1,3,5,7│
│           │           │
│           │           │
│           │           │
│           │           │
└───────────┴───────────┘
jax.debug.visualize_array_sharding(W3)
┌───────────────────────┐
│                       │
│      TPU 0,2,4,6      │
│                       │
│                       │
├───────────────────────┤
│                       │
│      TPU 1,3,5,7      │
│                       │
│                       │
└───────────────────────┘
print(loss_jit(params, batch))
10.760103
step_size = 1e-5

for _ in range(30):
    grads = gradfun(params, batch)
    params = [(W - step_size * dW, b - step_size * db)
              for (W, b), (dW, db) in zip(params, grads)]
print(loss_jit(params, batch))
10.752466
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params
jax.debug.visualize_array_sharding(W2)
jax.debug.visualize_array_sharding(W3)
┌───────────┬───────────┐
│           │           │
│           │           │
│           │           │
│           │           │
│TPU 0,2,4,6│TPU 1,3,5,7│
│           │           │
│           │           │
│           │           │
│           │           │
└───────────┴───────────┘
┌───────────────────────┐
│                       │
│      TPU 0,2,4,6      │
│                       │
│                       │
├───────────────────────┤
│                       │
│      TPU 1,3,5,7      │
│                       │
│                       │
└───────────────────────┘
%timeit -n 10 -r 10 gradfun(params, batch)[0][0].block_until_ready()
10 loops, best of 10: 30.5 ms per loop

Sharp bits#

Generating random numbers#

JAX comes with a functional, deterministic random number generator. It underlies the various sampling functions in the jax.random module, such as jax.random.uniform.

JAX’s random numbers are produced by a counter-based PRNG, so in principle random number generation is a pure map over counter values. A pure map is trivially partitionable: it should require no cross-device communication, nor any redundant computation across devices.
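
As a quick reminder of that functional determinism, sampling is a pure function of the key, so reusing a key reproduces the same values (a tiny sketch; demo_key is just an illustrative name):

demo_key = jax.random.key(0)
print(jax.random.uniform(demo_key, (3,)))  # sampling is a pure function of the key...
print(jax.random.uniform(demo_key, (3,)))  # ...so the same key always yields the same values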

However, the existing stable RNG implementation is not automatically partitionable, for historical reasons.

Consider the following example, where a function draws random uniform numbers and adds them to the input, elementwise:

@jax.jit
def f(key, x):
  numbers = jax.random.uniform(key, x.shape)
  return x + numbers

key = jax.random.key(42)
x_sharding = jax.sharding.PositionalSharding(jax.devices())
x = jax.device_put(jnp.arange(24), x_sharding)

On a partitioned input, the function f produces output that is also partitioned:

jax.debug.visualize_array_sharding(f(key, x))
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘

But if we inspect the compiled computation for f on this partitioned input, we see that it does involve some communication:

f_exe = f.lower(key, x).compile()
print('Communicating?', 'collective-permute' in f_exe.as_text())
Communicating? True

One way to work around this is to configure JAX with the experimental upgrade flag jax_threefry_partitionable. With the flag on, the “collective permute” operation is now gone from the compiled computation:

jax.config.update('jax_threefry_partitionable', True)
f_exe = f.lower(key, x).compile()
print('Communicating?', 'collective-permute' in f_exe.as_text())
Communicating? False

The output is still partitioned:

jax.debug.visualize_array_sharding(f(key, x))
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘

One caveat to the jax_threefry_partitionable option, however, is that the random values produced may be different than without the flag set, even though they were generated by the same random key:

jax.config.update('jax_threefry_partitionable', False)
print('Stable:')
print(f(key, x))
print()

jax.config.update('jax_threefry_partitionable', True)
print('Partitionable:')
print(f(key, x))
Stable:
[ 0.72503686  1.8532515   2.983416    3.083253    4.0332246   5.4782867
  6.1720605   7.6900277   8.602836    9.810046   10.861367   11.907651
 12.330483   13.456195   14.808557   15.960099   16.067581   17.739723
 18.335474   19.46401    20.390276   21.116539   22.858128   23.223194  ]

Partitionable:
[ 0.48870957  1.6797972   2.6162715   3.561016    4.4506445   5.585866
  6.0748096   7.775133    8.698959    9.818634   10.350306   11.87282
 12.925881   13.86013    14.477554   15.818481   16.711355   17.586697
 18.073738   19.777622   20.404566   21.119123   22.026257   23.63918   ]

In jax_threefry_partitionable mode, the JAX PRNG remains deterministic, but its implementation is new (and under development). The random values generated for a given key will be the same at a given JAX version (or a given commit on the main branch), but may vary across releases.

SPMD multi-device parallelism with shard_map#

shard_map is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or instances, communicate with each other via explicit collective communication operations.

shard_map is complementary to, and composable with, the automatic compiler-based parallelization built into jit. With jit you write code as if for a single device, and the compiler can automatically partition computation over multiple devices, generating per-device code and communication collectives behind the scenes. With shard_map you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed.

If you’re familiar with pmap, think of shard_map as an evolution. It’s more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see a detailed comparison to pmap.)

By reading this tutorial, you’ll learn how to use shard_map to get full control over your multi-device code. You’ll see in detail how it composes with jax.jit’s automatic parallelization and jax.grad’s automatic differentiation. We’ll also give some basic examples of neural network parallelization strategies.

We’ll assume this tutorial is being run in an environment with eight devices:

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' # Use 8 CPU devices

So, let’s see a shard_map!#

Without further ado, here’s a toy example:

from functools import partial

import jax
import jax.numpy as jnp

from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('x', 'y'))

a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 *  4.).reshape(16, 4)

@partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),
         out_specs=P('x', None))
def matmul_basic(a_block, b_block):
  # a_block: f32[2, 8]
  # b_block: f32[8, 4]
  c_partialsum = jnp.dot(a_block, b_block)
  c_block = jax.lax.psum(c_partialsum, 'y')
  # c_block: f32[2, 4]
  return c_block

c = matmul_basic(a, b)   # c: f32[8, 4]

This function computes a matrix multiply in parallel by performing local block matrix multiplies followed by a collective sum operation. We can check the result is correct:

from jax.tree_util import tree_map, tree_all

def allclose(a, b):
  return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))

allclose(c, jnp.dot(a, b))
True

The result is sharded along its rows:

jax.debug.visualize_array_sharding(c)
┌──────────┐
│ CPU 0,1  │
├──────────┤
│ CPU 2,3  │
├──────────┤
│ CPU 4,5  │
├──────────┤
│ CPU 6,7  │
└──────────┘

At a high level, shard_map is kind of like vmap or pmap, in that we’re mapping a function over pieces of array data, but notice that

  • shard_map slices up inputs into blocks (and the output is formed by concatenating result blocks), keeping the rank the same, whereas vmap would reduce the rank by mapping away an axis;

  • the mesh argument lets us control precise device placement of computation and results;

  • we’re mapping over multiple data axes at once, and setting up multiple axis names for collectives (both 'x' and 'y' here);

  • since we’re not using jax.jit yet, everything is eagerly evaluated, and we can even print intermediate values for debugging.

The above code is performing the same computation as this jax.jit automatic parallelization code:

from jax.sharding import NamedSharding

a = jax.device_put(a, NamedSharding(mesh, P('x', 'y')))
b = jax.device_put(b, NamedSharding(mesh, P('y', None)))

@jax.jit
def matmul_reference(a, b):
  c = jnp.dot(a, b)
  return jax.lax.with_sharding_constraint(c, NamedSharding(mesh, P('x', None)))

c_ref = matmul_reference(a, b)
allclose(c_ref, jnp.dot(a, b))
True

We can think of shard_map as performing a device_put or with_sharding_constraint on its inputs according to its mesh and in_specs arguments, so the blocks over which matmul_basic operates are the same as in matmul_reference:

print('a blocks:'); jax.debug.visualize_array_sharding(a)
print('b blocks:'); jax.debug.visualize_array_sharding(b)
print('c blocks:'); jax.debug.visualize_array_sharding(c)
a blocks:
b blocks:
c blocks:
┌──────────┬──────────┐
│  CPU 0   │  CPU 1   │
├──────────┼──────────┤
│  CPU 2   │  CPU 3   │
├──────────┼──────────┤
│  CPU 4   │  CPU 5   │
├──────────┼──────────┤
│  CPU 6   │  CPU 7   │
└──────────┴──────────┘
┌─────────────┐
│ CPU 0,2,4,6 │
├─────────────┤
│ CPU 1,3,5,7 │
└─────────────┘
┌──────────┐
│ CPU 0,1  │
├──────────┤
│ CPU 2,3  │
├──────────┤
│ CPU 4,5  │
├──────────┤
│ CPU 6,7  │
└──────────┘

Slow down, start with the basics!#

Rank-reducing vs rank-preserving maps#

We can think of vmap and pmap as unstacking each array input along an axis (e.g. unpacking a 2D matrix into its 1D rows), applying its body function to each piece, and stacking the results back together, at least when collectives aren’t involved:

def check_vmap(f, xs):
  ans = jax.vmap(f, in_axes=(0,), out_axes=0)(xs)
  expected = jnp.stack([f(x) for x in xs])  # vmap reference semantics
  print(allclose(ans, expected))

check_vmap(lambda x: x @ x, jnp.arange(12).reshape(4, 3))
True

For example, if xs had shape f32[8,5] then each x would have shape f32[5], and if each f(x) had shape f32[3,7] then the final stacked result vmap(f)(xs) would have shape f32[8,3,7]. That is, each application of the body function f takes as argument inputs with one fewer axis than the corresponding argument to vmap(f). We can say these are rank-reducing maps with unstacking/stacking of inputs/outputs.

The number of logical applications of f, or instances of f, is determined by the size of the input axis being mapped over: for example, if we map over an input axis of size 8, semantically we get 8 logical applications of the function.
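
As a quick sanity check of that shape bookkeeping (body is just a stand-in function):

xs = jnp.ones((8, 5))                              # mapped axis of size 8; each x is f32[5]
body = lambda x: jnp.zeros((3, 7)) + jnp.sum(x)    # each body(x) is f32[3, 7]
print(jax.vmap(body)(xs).shape)                    # (8, 3, 7): results are stacked along a new leading axis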

In contrast, shard_map does not have this rank-reducing behavior. Instead, we can think of it as slicing (or “unconcatenating”) along input axes into blocks, applying the body function, and concatenating the results back together (again when collectives aren’t involved):

import numpy as np
devices = np.array(jax.devices()[:4])
mesh = Mesh(devices, ('i',))  # mesh.shape['i'] = 4

def check_shmap(f, y):
  ans = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))(y)
  expected = jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, mesh.shape['i'])])
  print(allclose(ans, expected))

check_shmap(lambda x: x.T @ x, jnp.arange(32).reshape(8, 4))
True

Recall that jnp.split slices its input into equally-sized blocks with the same rank, so that if in the above example y had shape f32[8,5] then each y_blk would have shape f32[2,5], and if each f(y_blk) had shape f32[3,7] then the final concatenated result shard_map(f, ...)(y) would have shape f32[12,7]. So shard_map maps over shards, or blocks, of its inputs. We can say it’s a rank-preserving map with unconcatenating/concatenating of its inputs/outputs.
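
The analogous check for shard_map, reusing the 4-device mesh defined just above (again with a stand-in body function):

ys = jnp.ones((8, 5))                                    # split over 'i' into 4 blocks of shape f32[2, 5]
body = lambda y_blk: jnp.zeros((3, 7)) + jnp.sum(y_blk)  # each block result is f32[3, 7]
print(shard_map(body, mesh, in_specs=P('i'), out_specs=P('i'))(ys).shape)  # (12, 7): blocks are concatenated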

The number of logical applications of f is determined by the mesh size, not by any input axis size: for example, if we have a mesh of total size 4 (i.e. over 4 devices) then semantically we get 4 logical applications of the function, corresponding to the 4 devices physically computing them.

Controlling how each input is split (unconcatenated) and tiled with in_specs#

Each of the in_specs identifies some of the corresponding input array’s axes with mesh axes by name using PartitionSpecs, representing how to split (or unconcatenate) that input into the blocks to which the body function is applied. That identification determines the shard sizes; when an input axis is identified with a mesh axis, the input is split (unconcatenated) along that logical axis into a number of pieces equal to the corresponding mesh axis size. (It’s an error if the corresponding mesh axis size does not evenly divide the input array axis size.) If an input’s pspec does not mention a mesh axis name, then there’s no splitting over that mesh axis. For example:

devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, ('i', 'j'))

@partial(shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))
def f1(x_block):
  print(x_block.shape)  # prints (3, 12)
  return x_block

x1 = jnp.arange(12 * 12).reshape(12, 12)
y = f1(x1)
(3, 12)

Here, because the input pspec did not mention the mesh axis name 'j', no input array axis is split over that mesh axis; similarly, because the second axis of the input array is not identified with (and hence split over) any mesh axis, application of f1 gets a full view of the input along that axis.

When a mesh axis is not mentioned in an input pspec, we can always rewrite to a less efficient program where all mesh axes are mentioned but the caller performs a jnp.tile, for example:

@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', 'j'))
def f2(x_block):
  print(x_block.shape)
  return x_block

x = jnp.arange(12 * 12).reshape(12, 12)
x_ = jnp.tile(x, (1, mesh.shape['j']))  # x_ has shape (12, 24)
y = f2(x_)  # prints (3,12), and f1(x) == f2(x_)
(3, 12)

In other words, because each input pspec can mention each mesh axis name zero or one times, rather than having to mention each name exactly once, we can say that in addition to the jnp.split built into its input, shard_map also has a jnp.tile built into its input, at least logically (though the tiling may not need to be carried out physically, depending on the arguments’ physical sharding layout). The tiling to use is not unique; we could also have tiled along the first axis, and used the pspec P(('j', 'i'), None).
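
Here is a sketch of that alternative tiling (new code, not in the original example; it reuses x and the (4, 2) mesh from above):

x_alt = jnp.tile(x, (mesh.shape['j'], 1))  # shape (24, 12): tile along the first axis instead

@partial(shard_map, mesh=mesh, in_specs=P(('j', 'i'), None), out_specs=P('i', 'j'))
def f2_alt(x_block):
  print(x_block.shape)  # should print (3, 12), as before
  return x_block

y_alt = f2_alt(x_alt)  # expected to equal f2(x_) above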

Physical data movement is possible on inputs, as each device needs to have a copy of the appropriate data.

Controlling how each output is assembled by concatenation, block transposition, and untiling using out_specs#

Analogously to the input side, each of the out_specs identifies some of the corresponding output array’s axes with mesh axes by name, representing how the output blocks (one for each application of the body function, or equivalently one for each physical device) should be assembled back together to form the final output value. For example, in both the f1 and f2 examples above the out_specs indicate we should form the final output by concatenating together the block results along both axes, resulting in both cases an array y of shape (12, 24). (It’s an error if an output shape of the body function, i.e. an output block shape, has a rank too small for the concatenation described by the corresponding output pspec.)

When a mesh axis name is not mentioned in an output pspec, it represents an un-tiling: when the user writes an output pspec which does not mention one of the mesh axis names, they promise that the output blocks are equal along that mesh axis, and so only one block along that axis is used in the output (rather than concatenating all the blocks together along that mesh axis). For example, using the same mesh as above:

x = jnp.array([[3.]])

z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))()
print(z)  # prints the same as jnp.tile(x, (4, 2))

z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', None))()
print(z)  # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,))

z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P(None, None))()
print(z)  # prints the same as jnp.tile(x, (1, 1)), or just x
[[3. 3.]
 [3. 3.]
 [3. 3.]
 [3. 3.]]
[[3.]
 [3.]
 [3.]
 [3.]]
[[3.]]

The body function closing over an array value is equivalent to passing it as an argument with a corresponding input pspec of P(None, None). As another example, following more closely the other examples above:

@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None))
def f3(x_block):
  return jax.lax.psum(x_block, 'j')

x = jnp.arange(12 * 12).reshape(12, 12)
y3 = f3(x)
print(y3.shape)
(12, 6)

The result has a second axis size of 6, half the size of the input’s second axis. In this case, the un-tile expressed by not mentioning the mesh axis name 'j' in the output pspec was safe because of the collective psum, which ensures each output block is equal along the corresponding mesh axis. Here are two more examples where we vary which mesh axes are mentioned in the output pspec:

@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f4(x_block):
  return jax.lax.psum(x_block, 'i')

x = jnp.arange(12 * 12).reshape(12, 12)
y4 = f4(x)
print(y4.shape)  # (3,12)


@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, None))
def f5(x_block):
  return jax.lax.psum(x_block, ('i', 'j'))

y5 = f5(x)
print(y5.shape)  # (3,6)
(3, 12)
(3, 6)

On the physical side, not mentioning a mesh axis name in an output pspec assembles an Array from the output device buffers with replicated layout along that mesh axis.

There is no runtime check that the output blocks are actually equal along a mesh axis to be un-tiled along, or equivalently that the corresponding physical buffers have equal values and thus can be interpreted as a replicated layout for a single logical array. But we can provide a static check mechanism which raises an error on all potentially-incorrect programs.

Because the out_specs can mention mesh axis names zero or one times, and because they can be mentioned in any order, we can say that in addition to the jnp.concatenate built into its output, shard_map also has both an untile and a block transpose built into its output.
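
As a small illustration of the block transpose (new code reusing the (4, 2) mesh and the (12, 12) array x from above): mentioning the mesh axes in a different order in out_specs reassembles the same blocks in a transposed arrangement.

@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('j', 'i'))
def block_transpose_example(x_block):
  return x_block  # blocks of shape (3, 6)

print(block_transpose_example(x).shape)  # expect (6, 24): the block grid's axes are swapped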

Physical data movement is not possible on outputs, no matter the output pspec. Instead, out_specs just encodes how to assemble the block outputs into Arrays, or physically how to interpret the buffers across devices as the physical layout of a single logical Array.

API Specification#

from jax.sharding import Mesh
Specs = PyTree[PartitionSpec]

def shard_map(
    f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs,
    auto: collections.abc.Set[AxisName] = frozenset([]),
    check_rep: bool = True,
) -> Callable:
  ...

where:

  • communication collectives like psum in the body of f can mention the axis names of mesh;

  • mesh encodes devices arranged in an array and with associated axis names, just like it does for sharding.NamedSharding;

  • in_specs and out_specs are PartitionSpecs which can affinely mention axis names from mesh to express slicing/unconcatenation and concatenation of inputs and outputs, respectively, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively;

  • auto is an optional set of axis names corresponding to the subset of names of mesh to treat automatically in the body, as in the caller, rather than manually;

  • check_rep is an optional boolean indicating whether to check statically for any replication errors in out_specs, and also whether to enable a related automatic differentiation optimization (see JEP).

The shapes of the arguments passed to f have the same ranks as the arguments passed to shard_map-of-f, and the shape of an argument to f is computed from the shape (call it shape) of the corresponding argument to shard_map-of-f and the corresponding PartitionSpec (call it spec) as roughly tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec)).
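
As a small worked instance of that rule (block_shape is a hypothetical helper written only for illustration; it ignores tuple entries in specs):

def block_shape(shape, spec, mesh):
  # divide each axis size by the size of the mesh axis it is mapped to (None = unmapped)
  return tuple(sz // (1 if n is None else mesh.shape[n])
               for sz, n in zip(shape, spec))

print(block_shape((12, 12), P('i', None), mesh))  # (3, 12) on the (4, 2) ('i', 'j') mesh above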

Collectives tutorial#

A shard_map need not be a pure map: function applications can communicate with each other via collectives, using axis names defined in the mesh argument.

Recall that shard_map maps a function over shards, or blocks, of input data, so that this:

mesh = Mesh(jax.devices(), ('i',))
x = jnp.arange(16.)
f_shmapped = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))
y = f_shmapped(x)

computes the same values, evaluating applications of f to the same argument values, as this reference function:

def f_shmapped_ref(x):
  x_blocks = jnp.array_split(x, mesh.shape[0])
  y_blocks = [f(x_blk) for x_blk in x_blocks]
  return jnp.concatenate(y_blocks)

We call these applications of f to different argument shards function instances. Each function instance is executed on a different device (or subset of devices).

These reference semantics work when f has no communication collectives in it. But what if we want the function instances to communicate, corresponding to having cross-device communication? That is, what are the reference semantics when f contains a collective? Say f has just one collective, and is of the form

def f(x_blk):
  z_blk = f_part1(x_blk)
  u_blk = collective(z_blk, axis_name)
  v_blk = f_part2(x_blk, z_blk, u_blk)
  return v_blk

where we’re assuming there’s only one mesh axis we’re mapping over, and axis_name is the corresponding name for it. Then the reference semantics would look more like:

def f_shmapped_ref(x):
  x_blocks = jnp.array_split(x, mesh.shape[0])
  z_blocks = [f_part1(x_blk) for x_blk in x_blocks]
  u_blocks = [collective_ref(i, z_blocks) for i in range(len(z_blocks))]
  v_blocks = [f_part2(x_blk, z_blk, u_blk) for x_blk, z_blk, u_blk
              in zip(x_blocks, z_blocks, u_blocks)]
  return jnp.concatenate(v_blocks)

Notice that collective_ref might depend on all the z_blocks. That is, while f_part1 and f_part2 are mapped over blocks independently, a collective introduces some amount of cross-block dependence. Physically, that means communication across devices. Exactly what communication happens, and what values are computed, depend on the collective.

psum#

The simplest collective may be jax.lax.psum, which computes an all-reduce-sum along a device mesh axis (or multiple axes). Here’s a toy example:

Illustration of a psum computation.
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial

from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map
mesh1d = Mesh(jax.devices()[:4], ('i',))

@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P(None))
def f1(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.psum(x_block, 'i')
  print('AFTER:\n', y_block)
  return y_block
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f1(x)
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22 20 12 17]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[22 20 12 17]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[22 20 12 17]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[22 20 12 17]

FINAL RESULT:
 [22 20 12 17]

The prints show that each function application starts with its own chunk of the argument value x_block. After the psum, each function application has the same value of y_block, computed by summing the applications’ x_block values together.

In the case where there’s a single axis name in the computation, we could say that the collective_ref reference implementation for psum is

def psum_ref(_, x_blocks):
  tot = sum(x_blocks)
  return [tot] * len(x_blocks)

Notice also that because f1 returns y_block, the result of a psum over 'i', we can use out_specs=P(None) (as in f1 above; P() is equivalent here) so the caller gets a single logical copy of the result value, rather than a tiled result.

When there is more than one mesh axis, we can perform a psum over each one separately, or over multiple axes at once:

mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j'))

@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f2(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.psum(x_block, 'i')
  print('AFTER:\n', y_block)
  return y_block

y = f2(jnp.arange(16).reshape(4, 4))
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[0 1]
 [4 5]]

On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[2 3]
 [6 7]]

On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8  9]
 [12 13]]

On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[10 11]
 [14 15]]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[ 8 10]
 [16 18]]

On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[12 14]
 [20 22]]

On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8 10]
 [16 18]]

On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[12 14]
 [20 22]]

FINAL RESULT:
 [[ 8 10 12 14]
 [16 18 20 22]]

By applying a psum over mesh axis 'i', we get values of y_block which are equal along axis 'i', but not axis 'j'. (So we can use out_specs=P(None, 'j') to get a single logical result along that axis.)

If we apply the psum over both axes, the y_block value is equal along both axes:

@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None))
def f3(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.psum(x_block, ('i', 'j'))
  print('AFTER:\n', y_block)
  return y_block

y = f3(jnp.arange(16).reshape(4, 4))
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[0 1]
 [4 5]]

On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[2 3]
 [6 7]]

On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8  9]
 [12 13]]

On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[10 11]
 [14 15]]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[20 24]
 [36 40]]

On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[20 24]
 [36 40]]

On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[20 24]
 [36 40]]

On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[20 24]
 [36 40]]

FINAL RESULT:
 [[20 24]
 [36 40]]

In machine learning, we often use psum to compute total losses or, when we have a grad inside the shard_mapped function body, total gradients.

In the sequel, we’ll see how psum can be implemented in terms of other primitives, which gives some intuition about its communication cost.

all_gather#

Another fundamental operation is gathering array shards along an axis, so that each function application has a full copy of the data along that axis:

Illustration of an all_gather computation.
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f4(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.all_gather(x_block, 'i', tiled=True)
  print('AFTER:\n', y_block)
  return y_block

x = jnp.array([3, 9, 5, 2])
y = f4(x)
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[9]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[2]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 9 5 2]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[3 9 5 2]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[3 9 5 2]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[3 9 5 2]

FINAL RESULT:
 [3 9 5 2 3 9 5 2 3 9 5 2 3 9 5 2]

The prints show that each function application again starts with its own chunk of the argument value x_block. After the all_gather, they have a common value, computed by concatenating the values of x_block.

(Notice that we actually can’t set out_specs=P() here. For technical reasons related to automatic differentiation, we consider the output of all_gather not to be guaranteed invariant across devices. If we wanted it to be guaranteed invariant, we could use jax.lax.all_gather_invariant, or in this case we could just avoid doing the all_gather in the function body and instead just use out_specs=P('i') to perform the concatenation.)

When tiled=False (the default), results are stacked along a new axis instead of concatenated:

@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f5(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.all_gather(x_block, 'i', tiled=False)
  print('AFTER:\n', y_block)
  return y_block

y = f5(x)
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[9]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[2]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[[3]
 [9]
 [5]
 [2]]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[[3]
 [9]
 [5]
 [2]]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[[3]
 [9]
 [5]
 [2]]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[[3]
 [9]
 [5]
 [2]]

FINAL RESULT:
 [[3]
 [9]
 [5]
 [2]
 [3]
 [9]
 [5]
 [2]
 [3]
 [9]
 [5]
 [2]
 [3]
 [9]
 [5]
 [2]]

We could write the collective_ref reference semantics function for all_gather as

def all_gather_ref(_, x_blocks, *, tiled=False):
  combine = jnp.concatenate if tiled else jnp.stack
  return [combine(x_blocks)] * len(x_blocks)

In deep learning, we might use all_gathers on parameters in fully sharded data parallelism (FSDP).

psum_scatter#

The jax.lax.psum_scatter collective is a bit less intuitive. It’s like psum except each function instance gets only one shard of the result:

Illustration of a psum_scatter computation.
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f6(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True)
  print('AFTER:\n', y_block)
  return y_block

x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f6(x)
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[20]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[12]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[17]

FINAL RESULT:
 [22 20 12 17]

As shown by the prints, each resulting y_block has a smaller size than the argument x_block, unlike with psum. Moreover, compared to psum, here each y_block only represents a slice of the sum of the x_blocks across function instances. (Even though each function instance gets only one shard of the sum, the final output y is the same as in the psum example because here we use out_specs=P('i') to concatenate each function instance’s output.)

In terms of what values are computed, a collective_ref reference implementation might look like:

def psum_scatter_ref(i, x_blocks, *, tiled=False):
  axis_size = len(x_blocks)
  tot = sum(x_blocks)
  if tiled:
    tot = tot.reshape(axis_size, -1, *tot.shape[1:])  # split leading axis
  return [tot[i] for i in range(tot.shape[0])]

It’s not captured in the semantics reference implementation, but psum_scatter is useful because these results can be computed more efficiently, with less communication, than a full psum. In fact, one way to think of psum_scatter is as “the first half of a psum, before an all_gather”. That is, one way to implement psum is:

def psum(x, axis_name):
  summed_chunk = jax.lax.psum_scatter(x, axis_name)
  return jax.lax.all_gather(summed_chunk, axis_name)

Indeed, this implementation is often used on both TPU and GPU!

The reason psum_scatter can require about half the communication of a full psum is illustrated in the ppermute section.

Another intuition is that we can use psum_scatter to implement a distributed matrix multiplication with inputs and outputs sharded over the same axis. In machine learning, psum_scatter can be used in tensor-parallel matrix multiplies or fully-sharded data parallel gradient accumulation, as shown in the examples to follow.

ppermute#

The jax.lax.ppermute collective provides the most direct way for function instances to send data to one another. Given a mesh axis and a list of (source_index, destination_index) pairs representing indices along that mesh axis, ppermute sends its argument value from each source function instance to each destination:

@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f7(x_block):
  sz = jax.lax.psum(1, 'i')
  print('BEFORE:\n', x_block)
  y_block = jax.lax.ppermute(x_block, 'i', [(i, (i + 1) % sz) for i in range(sz)])
  print('AFTER:\n', y_block)
  return y_block

y = f7(jnp.arange(8))
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[0 1]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[2 3]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[4 5]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[6 7]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[6 7]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[0 1]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[2 3]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[4 5]

FINAL RESULT:
 [6 7 0 1 2 3 4 5]

In this case, with four function instances arranged in a ring, each instance’s value of y_block is the value of x_block from the preceding instance along the ring (a cyclic shift by one position).

Source indices and destination indices can’t be repeated. If an index does not appear as a destination, then the value of the corresponding function instance’s result is an array of zeros.
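
Here is a small sketch of that zero-filling behavior (a new example over the same mesh1d, not from the original text): only index 0 sends, so every instance other than index 1 receives zeros.

@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def partial_shift(x_block):
  return jax.lax.ppermute(x_block, 'i', [(0, 1)])  # only instance 0 sends, to instance 1

print(partial_shift(jnp.arange(8.)))  # expect zeros everywhere except instance 1's block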

A collective_ref reference implementation could look like

def ppermute_ref(i, x_blocks, perm):
  results = [jnp.zeros_like(x_blocks[0])] * len(x_blocks)
  for src, dst in perm:
    results[dst] = x_blocks[src]
  return results

Other collectives can be implemented efficiently, in terms of total communication, using ppermutes where each function passes data only to its neighbors. For example, we could implement psum_scatter using a sequence of ppermutes and local additions this way:

Illustration of a psum_scatter implementation.

Or, with a numerical example:

Illustration of a psum_scatter implementation.

Intuitively, on each iteration each function instance sends 'up' the value it received on the previous iteration, and reduces (adds) the value it receives this iteration. In code, it might look like this:

def psum_scatter(x, axis_name, *, tiled=False):
  size = jax.lax.psum(1, axis_name)
  idx = jax.lax.axis_index(axis_name)  # function instance index along axis_name
  if tiled:
    x = x.reshape(size, -1, *x.shape[1:])  # split leading axis
  shift = partial(jax.lax.ppermute, axis_name=axis_name,
                  perm=[(i, (i - 1) % size) for i in range(size)])
  for i in range(1, size):
    update = shift(x[(idx + i) % size])
    x = x.at[(idx + i + 1) % size].add(update)
  return x[idx]
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f8(x_block):
  print('BEFORE:\n', x_block)
  y_block = psum_scatter(x_block, 'i', tiled=True)
  print('AFTER:\n', y_block)
  return y_block

x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f8(x)
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[20]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[12]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[17]

FINAL RESULT:
 [22 20 12 17]

On TPU, there are higher-dimensional variants of this algorithm to exploit multiple bidirectional physical mesh axes.

Notice that psum_scatter is the transpose of all_gather. Indeed, a way to implement all_gather in terms of ppermute looks like the reverse of the above process:

Illustration of an all_gather implementation.

In deep learning, we might use ppermute when implementing SPMD pipeline parallelism, where we divide our network along its depth into stages and evaluate the applications of stages in parallel. Or we might use ppermute in parallelizing the evaluation of convolutional layers, where we shard over spatial axes and thus devices must communicate “halos” to each other. Or it may be used under-the-hood in tensor-parallel matrix multiplies.

all_to_all#

A final collective is all_to_all, which is essentially a block matrix transpose operating along one positional axis and one cross-device axis:

Illustration of an all_to_all computation.
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f9(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.all_to_all(x_block, 'i', split_axis=0, concat_axis=0,
                               tiled=True)
  print('AFTER:\n', y_block)
  return y_block

x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f9(x)
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 5 5 9]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[1 9 3 7]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[4 2 5 1]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[1 6 8 2]

FINAL RESULT:
 [3 5 5 9 1 9 3 7 4 2 5 1 1 6 8 2]

The split_axis argument indicates which positional axis should be sharded and partitioned across the mesh axis. The concat_axis argument indicates the axis along which the communicated results should be concatenated or stacked.

When tiled=False (the default), the split_axis axis size must equal the size of the mesh axis named axis_name, and a new axis of that size is created at position concat_axis for the stacked results. When tiled=True, the split_axis axis size need only be evenly divisible by the size of the mesh axis, and results are concatenated along the existing axis concat_axis.

The collective_ref reference semantics when split_axis=0 and concat_axis=0 might look like:

def all_to_all_ref(_, x_blocks, *, tiled=False):
  axis_size = len(x_blocks)
  if tiled:
    splits = [jnp.array_split(x, axis_size) for x in x_blocks]
    return [jnp.concatenate(s) for s in zip(*splits)]
  else:
    splits = [list(x) for x in x_blocks]
    return [jnp.stack(s) for s in zip(*splits)]

In deep learning, we might use all_to_all in mixture-of-expert routing, where we first sort our local batch of examples according to which expert they should go to, then apply an all_to_all to redistribute examples to experts.

Toy examples#

How might we use shard_map and collective communication in practice? These examples, while simple, give some idea.

Matrix multiplies#

Parallelizing matrix multiplication is central in scaling up deep learning models, both for training and for inference. When jax.jit automatically parallelizes matrix multiplication, it can use one of several different strategies, depending on matrix sizes, hardware details, and other factors. How might we write some of those parallelized routines more explicitly using shard_map? And how can we optimize them to get better compute/communication overlap and thus improve FLOP utilization?

import jax
import jax.numpy as jnp
from functools import partial

from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map
mesh = Mesh(jax.devices()[:4], ('i',))

def device_put(x, pspec):
  return jax.device_put(x, NamedSharding(mesh, pspec))
Example 1: all-gather on one side#

Consider performing a matrix multiplication where we shard the left-hand side argument (think: parameters) on its leading (non-contracting) dimension:

lhs_spec = P('i', None)
lhs = device_put(jax.random.normal(jax.random.key(0), (8, 8)), lhs_spec)

And we shard the right-hand side argument (think: activations) on its contracting dimension, with a similar sharding for the output:

rhs_spec = P('i', None)
rhs = device_put(jax.random.normal(jax.random.key(1), (8, 4)), rhs_spec)

To perform this matrix multiplication, we can first all-gather the right-hand side and then perform local matrix multiplies against the sharded left-hand side:

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_allgather(lhs_block, rhs_block):
  rhs = jax.lax.all_gather(rhs_block, 'i', tiled=True)
  return lhs_block @ rhs
out = matmul_allgather(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True

That’s great, but we’re not getting any compute/communication overlap here: before we can start the matmul, we need the all_gather to complete. Here’s a profile using the same code, but on larger example shapes ((8192, 8192) for lhs and (8192, 1024) for rhs):

Profile of an all-gather matmul without overlap.

We can get compute/communication overlap if instead of calling all_gather we basically inline our above implementation of all_gather in terms of ppermute, then interleave steps of the gather permutation with local matrix multiplies:

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_allgather_overlapped(lhs_block, rhs_block):
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  shift = partial(jax.lax.ppermute, axis_name='i',
                  perm=[(i, (i + 1) % size) for i in range(size)])

  B = lhs_block.shape[1] // size
  lhs_blocks = lambda i: lax.dynamic_slice_in_dim(lhs_block, i * B, B, 1)

  out_block = lhs_blocks(idx) @ rhs_block
  for i in range(1, size):
    rhs_block = shift(rhs_block)
    out_block += lhs_blocks((idx - i) % size) @ rhs_block
  return out_block
out = matmul_allgather_overlapped(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True

This implementation allows overlap between communication and computation, and also avoids gathering a large intermediate onto each device. But on TPU it uses only half the interconnect bandwidth by permuting in only one direction along the ring. To permute bidirectionally, we just split the blocks in half and send each half in each direction:

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_allgather_overlapped_bidi(lhs_block, rhs_block):
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  shift_up = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i + 1) % size) for i in range(size)])
  shift_dn = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i - 1) % size) for i in range(size)])

  B = lhs_block.shape[1] // size // 2  # half-size blocks
  lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 1)

  rhs_block_lo, rhs_block_hi = jnp.split(rhs_block, 2, axis=0)
  out_block  = lhs_blocks(idx, 0) @ rhs_block_lo
  out_block += lhs_blocks(idx, 1) @ rhs_block_hi
  for i in range(1, size):
    rhs_block_lo = shift_up(rhs_block_lo)
    rhs_block_hi = shift_dn(rhs_block_hi)
    out_block += lhs_blocks((idx - i) % size, 0) @ rhs_block_lo
    out_block += lhs_blocks((idx + i) % size, 1) @ rhs_block_hi
  return out_block
out = matmul_allgather_overlapped_bidi(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
Profile of an all-gather matmul with overlap.

In practice, to reduce compile times we would probably roll this into a jax.lax.fori_loop. We might also have additional axes of parallelism involved.

Example 2: psum_scatter the result#

Another sharding we might start with has both lhs and rhs sharded along their contracting dimensions, with the output sharded like rhs again:

lhs_spec = P(None, 'i')
lhs = device_put(lhs, lhs_spec)

rhs_spec = P('i', None)
rhs = device_put(rhs, rhs_spec)

Here we can use psum_scatter (a reduce-scatter) to perform the contraction sum over shards:

@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_psumscatter(lhs_block, rhs_block):
  out_summand = lhs_block @ rhs_block
  return jax.lax.psum_scatter(out_summand, 'i', tiled=True)

out = matmul_psumscatter(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True

But the scattering communication must wait for the entire local matrix multiply to finish before it can start. To get communication/computation overlap, we can inline an implementation of psum_scatter in terms of ppermute, then interleave the communication steps with local matrix multiplies:

@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_psumscatter_overlapped(lhs_block, rhs_block):
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  shift = partial(jax.lax.ppermute, axis_name='i',
                  perm=[(i, (i - 1) % size) for i in range(size)])
  lhs_block = lhs_block.reshape(size, -1, lhs_block.shape[1])  # split 1st axis

  out_summand = lhs_block[(idx + 1) % size] @ rhs_block
  for i in range(1, size):
    out_summand = shift(out_summand)
    out_summand += lhs_block[(idx + i + 1) % size] @ rhs_block
  return out_summand
out = matmul_psumscatter_overlapped(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True

As in the previous example, to fully utilize interconnects on TPU, we’d run a bidirectional version:

@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_psumscatter_overlapped_bidi(lhs_block, rhs_block):
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  shift_up = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i + 1) % size) for i in range(size)])
  shift_dn = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i - 1) % size) for i in range(size)])

  B = lhs_block.shape[0] // size // 2  # half-size blocks
  lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 0)

  out_summand_lo = lhs_blocks((idx - 1) % size, 0) @ rhs_block
  out_summand_hi = lhs_blocks((idx + 1) % size, 1) @ rhs_block
  for i in range(1, size):
    out_summand_lo = shift_up(out_summand_lo)
    out_summand_hi = shift_dn(out_summand_hi)
    out_summand_lo += lhs_blocks((idx - i - 1) % size, 0) @ rhs_block
    out_summand_hi += lhs_blocks((idx + i + 1) % size, 1) @ rhs_block
  return jnp.concatenate([out_summand_lo, out_summand_hi])
out = matmul_psumscatter_overlapped_bidi(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True

Neural networks#

We can use shard_map to parallelize computation in neural networks, either by itself or in combination with the automatic partitioning in jax.jit. This section has a few examples based on this toy neural network and random data:

import jax
import jax.numpy as jnp

def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jax.nn.relu(outputs)
  return outputs

def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
def init_layer(key, n_in, n_out):
    k1, k2 = jax.random.split(key)
    W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
    b = jax.random.normal(k2, (n_out,))
    return W, b

def init(key, layer_sizes, batch_size):
    key, *keys = jax.random.split(key, len(layer_sizes))
    params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))

    key, *keys = jax.random.split(key, 3)
    inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
    targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))

    return params, (inputs, targets)
layer_sizes = [784, 128, 128, 128, 128, 128, 8]
batch_size = 32

params, batch = init(jax.random.PRNGKey(0), layer_sizes, batch_size)

Compare these examples with the purely automatic partitioning examples in the “Distributed arrays and automatic partitioning” doc. While in those automatic partitioning examples we don’t need to edit the model functions to use different parallelization strategies, with shard_map we often do.

8-way batch data parallelism#

The simplest multi-device parallelism strategy is to shard the batch of inputs and targets over multiple devices, replicate the parameters over those devices, and apply the model in parallel to those shards of data. To evaluate the total loss, the devices need only communicate with a scalar-sized all-reduce-sum at the end. (To evaluate the gradient of the loss, the devices must perform all-reduce-sums of parameter gradients in the backward pass.)

from functools import partial

from jax.sharding import NamedSharding, Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
from jax.experimental import mesh_utils

devices = mesh_utils.create_device_mesh((8,))

# replicate initial params on all devices, shard data batch over devices
mesh = Mesh(devices, ('batch',))
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P()))

# adapt the loss function to sum the losses across devices
def loss_dp(params, batch):
  @partial(shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P())
  def loss_spmd(local_batch):
    inputs, targets = local_batch
    predictions = predict(params, inputs)  # use the reference `predict`
    local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
    return jax.lax.pmean(local_loss, 'batch')
  return loss_spmd(batch)

We can check that the loss and its gradients match the reference (base) model:

print(jax.jit(loss)(params, batch))
print(jax.jit(loss_dp)(params, batch))
22.779888
22.779888
from jax.tree_util import tree_all, tree_map

def allclose(a, b):
  return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))

print(allclose(jax.jit(jax.grad(loss))(params, batch),
               jax.jit(jax.grad(loss_dp))(params, batch)))
True

We can print the compiler IR to inspect the gradient computation and verify that the collective all-reduce-sum operations happen where we’d expect: at the end of the forward pass to compute the loss value, and in the backward pass to compute the total parameter gradients.
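
For example (a hedged sketch; the exact textual IR depends on the backend and JAX version), one way to peek at the optimized IR and search it for all-reduce ops is:

compiled = jax.jit(jax.grad(loss_dp)).lower(params, batch).compile()
print(compiled.as_text()[:2000])  # look for all-reduce / psum ops in the full text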

8-way fully sharded data parallelism (FSDP)#

Another strategy is to additionally shard the parameters over the devices, all-gathering each one when the full value is needed for the jnp.dot or bias addition. Since we only have one full parameter in local device memory at a time, rather than keeping all parameters in all device memories as in the preceding DP example, we free up significant memory that we can use for larger models or larger batch sizes. And because XLA will overlap computation and inter-device communication, the wall-clock time doesn’t suffer.

So now we need collectives in two places: the model prediction function predict needs to all-gather the parameters before they’re used, and as in the DP case the loss function needs to sum the local losses to compute the total loss.

There’s one other ingredient we need: we don’t want to store the fully gathered parameters from the forward pass for use on the backward pass. Instead, we want to gather them again on the backward pass. We can express that by using jax.remat with a custom policy (or a custom_vjp), though XLA typically does that rematerialization automatically.

This general FSDP approach is similar to weight update sharding (WUS) and ZeRO-3.

# shard data batch *and params* over devices
mesh = Mesh(devices, ('batch',))
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P('batch')))

# adapt the prediction function to gather weights just before their use,
# and to re-gather them on the backward pass (rather than saving them)
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict_fsdp(params_frag, inputs):
  for W_frag, b_frag in params_frag:
    W = jax.lax.all_gather(W_frag, 'batch', tiled=True)
    b = jax.lax.all_gather(b_frag, 'batch', tiled=True)
    outputs = jnp.dot(inputs, W) + b
    inputs = jax.nn.relu(outputs)
  return outputs

def loss_fsdp(params, batch):
  @partial(shard_map, mesh=mesh, in_specs=P('batch'), out_specs=P())
  def loss_spmd(local_params, local_batch):
    inputs, targets = local_batch
    predictions = predict_fsdp(local_params, inputs)
    local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
    return jax.lax.pmean(local_loss, 'batch')
  return loss_spmd(params, batch)

Again we can check that the loss and its gradients match the reference model:

print(jax.jit(loss)(params, batch))
print(jax.jit(loss_fsdp)(params, batch))

print(allclose(jax.jit(jax.grad(loss))(params, batch),
               jax.jit(jax.grad(loss_fsdp))(params, batch)))
22.779888
22.779888
True
8-way tensor parallelism (TP)#

Usually we don’t use tensor model parallelism by itself, but seeing it in isolation is a good warmup on parallel matrix multiplication. It’s also a good example of using shard_map in a library function, called in a larger jit-based computation.

The parallelization idea is that we’ll keep the data/activations sharded over its feature axis (rather than its batch axis), and we’ll similarly shard weight matrices over their input-feature axis (and biases over their feature axis). Then to perform the parallel matrix multiplication, we’ll perform local matrix multiplications followed by a psum_scatter to sum the local results and efficiently scatter the result’s shards.

devices = mesh_utils.create_device_mesh((8,))
mesh = Mesh(devices, ('feats',))

batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))
params = jax.device_put(params, NamedSharding(mesh, P('feats')))

def predict_tp(params, inputs):
  for W, b in params:
    outputs = gemm_tp(inputs, W, b)
    inputs = jax.nn.relu(outputs)
  return outputs

@partial(shard_map, mesh=mesh,
         in_specs=(P(None, 'feats'), P('feats', None), P('feats')),
         out_specs=P(None, 'feats'))
def gemm_tp(inputs, W, b):
  block_result = jnp.dot(inputs, W)
  return jax.lax.psum_scatter(block_result, 'feats',
                              scatter_dimension=1, tiled=True) + b

def loss_tp(params, batch):
  inputs, targets = batch
  predictions = predict_tp(params, inputs)
  return jnp.mean(jnp.sum((predictions - targets) ** 2, axis=-1))  # NOTE psum!
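
As a quick sanity check (new code; outputs not shown in the original), the TP loss should match the reference loss up to floating-point error:

print(jax.jit(loss)(params, batch))
print(jax.jit(loss_tp)(params, batch))
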
FSDP + TP, with shard_map at the top level#

We can compose these strategies together, using multiple axes of parallelism.

devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, ('batch', 'feats'))

batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats')))
params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats'))))

# mostly same as previous predict_fsdp definition, except we call gemm_tp
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict_fsdp_tp(params_frag, inputs):
  for W_frag, b_frag in params_frag:
    W = jax.lax.all_gather(W_frag, 'batch', tiled=True)
    b = jax.lax.all_gather(b_frag, 'batch', tiled=True)
    block_result = jnp.dot(inputs, W)
    outputs = jax.lax.psum_scatter(block_result, 'feats',
                                   scatter_dimension=1, tiled=True) + b
    inputs = jax.nn.relu(outputs)
  return outputs

@partial(shard_map, mesh=mesh,
         in_specs=(P(('feats', 'batch')), P('batch', 'feats')),
         out_specs=P())
def loss_fsdp_tp(local_params, local_batch):
  inputs, targets = local_batch
  predictions = predict_fsdp_tp(local_params, inputs)
  sq_err = jax.lax.psum(jnp.sum((predictions - targets)**2, axis=-1), 'feats')
  return jax.lax.pmean(jnp.mean(sq_err), 'batch')

Notice how we have to do two collective reductions: one over 'feats' and one over 'batch'. In the pure TP example, we didn’t write the 'feats' reduction explicitly because we only used shard_map within gemm_tp; in the caller loss_tp, the compiler automatically translated our use of jnp.sum to perform a psum as needed given the sharded result returned by predict_tp.

print(jax.jit(loss)(params, batch))
print(jax.jit(loss_fsdp_tp)(params_, batch_))

print(allclose(jax.jit(jax.grad(loss))(params, batch),
               jax.jit(jax.grad(loss_fsdp_tp))(params_, batch_)))
22.779886
22.779886
True
SPMD pipeline parallelism (PP)#

With pipeline parallelism we aim to parallelize the evaluation of layers at different depths in our network. For example, one device might compute the application of the first layer while another device computes the application of the second; when they finish, the first device passes its results to the second while the second passes its results to the device responsible for the third layer, and the process repeats. In general the number of pipeline stages may be different from the number of layers, as each stage may be responsible for multiple layers.

With SPMD pipelining, we exploit the fact that most layers in the network apply the same computation, just with different parameter values. In particular, we can stack together all the parameters except for those for the first and last layers, then use a shard_map to map over blocks of those layer parameters, where each block of parameters corresponds to a pipeline stage. We then use the jax.lax.ppermute collective to shift data down the parallel pipeline.

This particular pipelining strategy is essentially the GPipe strategy. There are several variants, as well as quite different strategies, and which is appropriate can depend on the speed of the networking between stages and batch sizes. But for this tutorial we’ll focus on just one strategy.

First, we choose some pipeline parameters:

L = len(params) - 2        # num layers, excluding first and last
N = batch_size             # batch size
F = params[0][0].shape[1]  # num features

# choose some pipeline parameters
S = 2      # number of stages
B = 8      # size of each microbatch
assert L % S == 0, "S (number of stages) must divide L (number of inner layers)"

# compute some useful quantities
M, ragged = divmod(N, B)  # M is number of microbatches
assert not ragged, "B (size of each microbatch) must divide total batch size"
K, ragged = divmod(M, S)  # K is microbatches per stage
assert not ragged, "S (number of stages) must divide number of microbatches"
print(f'{S} stages, {L // S} layer(s) per stage, {L} pipelined layers total')
print(f'{B} examples per microbatch, {M} microbatches total')
2 stages, 2 layer(s) per stage, 4 pipelined layers total
8 examples per microbatch, 4 microbatches total
mesh = Mesh(jax.devices()[:S], ('stages',))

def predict_pp(params, inputs):
  (W_first, b_first), inner_params, (W_last, b_last) = params
  inputs = jax.nn.relu(jnp.dot(inputs, W_first) + b_first)
  inputs = spmd_pipeline(lambda Wb, x: jax.nn.relu(x @ Wb[0] + Wb[1]),
                        inner_params, inputs)
  outputs = jnp.dot(inputs, W_last) + b_last
  return outputs

@partial(shard_map, mesh=mesh, in_specs=((P(), P('stages'), P()), P('stages')),
         out_specs=P())
def loss_pp(params, batch):
  inputs, targets = batch
  predictions = predict_pp(params, inputs.reshape(K, B, -1)).reshape(K * B, -1)
  local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
  return jax.lax.pmean(local_loss, 'stages')
def spmd_pipeline(fn, stage_params, inputs):
  stage = jax.lax.axis_index('stages')
  outputs = jnp.zeros_like(inputs) * jnp.nan
  state = jnp.zeros((L // S, B, F)) * jnp.nan
  for i in range(M+L-1):
    state = state.at[0].set(jnp.where(stage == 0, inputs[i % K], state[0]))
    state = jax.vmap(fn)(stage_params, state)
    outputs = outputs.at[(i-L+1) % K].set(jnp.where(stage == S-1, state[-1], outputs[(i-L+1) % K]))
    state, inputs, outputs = shift(i, state, inputs, outputs)
  outputs = jax.lax.ppermute(outputs, 'stages', [(i, (i+1) % S) for i in range(S)])
  return outputs

def shift(i, state, inputs, outputs):
  sh = lambda x, d: jax.lax.ppermute(x, 'stages', [(i, (i+d) % S) for i in range(S)])
  state = jnp.roll(state, +1, axis=0).at[0].set(sh(state[-1], +1))
  if (i % K) == (-1 % K):
    inputs = sh(inputs, +1)
  if ((i-L+1) % K) == (-1 % K):
    outputs = sh(outputs, +1)
  return state, inputs, outputs
first_params, *inner_params, last_params = params
Ws, bs = zip(*inner_params)
params_stacked = jnp.stack(Ws), jnp.stack(bs)
first_params = jax.device_put(first_params, NamedSharding(mesh, P()))
params_stacked = jax.device_put(params_stacked, NamedSharding(mesh, P('stages')))
last_params = jax.device_put(last_params, NamedSharding(mesh, P()))
params_ = first_params, params_stacked, last_params

batch_ = jax.device_put(batch, NamedSharding(mesh, P('stages')))
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_pp)(params_, batch_))
22.779886
22.779884
_ = jax.jit(jax.grad(loss_pp))(params_, batch_)   # don't crash

Named axes and easy-to-revise parallelism with xmap#

UPDATE: xmap is deprecated and will be removed in a future release. The recommended ways to do multi-device programming in JAX are using: 1) jit (automatic partitioning of computation and jax.Array sharding); and/or 2) shard_map (manual data sharding). Learn more in Why don’t pmap or xmap already solve this? in the shard_map JEP document.

This tutorial introduces JAX xmap (jax.experimental.maps.xmap) and the named-axis programming model that comes with it. By reading this, you’ll learn how to write error-avoiding, self-documenting functions using named axes, then control how they’re executed on hardware at any scale, from your laptop CPU to the largest TPU supercomputer.

We start with a toy neural network example.

From positions to names in a toy neural network#

Presentations on JAX often start with a simple neural network prediction function and loss, written in pure NumPy. Here’s a simple network with one hidden layer:

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' # Use 8 CPU devices
import jax.numpy as jnp
from jax import lax
from jax.nn import one_hot, relu
from jax.scipy.special import logsumexp

def predict(w1, w2, images):
  hiddens = relu(jnp.dot(images, w1))
  logits = jnp.dot(hiddens, w2)
  return logits - logsumexp(logits, axis=1, keepdims=True)

def loss(w1, w2, images, labels):
  predictions = predict(w1, w2, images)
  targets = one_hot(labels, predictions.shape[-1])
  losses = jnp.sum(targets * predictions, axis=1)
  return -jnp.mean(losses, axis=0)

We can then initialize inputs with the right shapes and compute the loss value:

w1 = jnp.zeros((784, 512))
w2 = jnp.zeros((512, 10))
images = jnp.zeros((128, 784))
labels = jnp.zeros(128, dtype=jnp.int32)

print(loss(w1, w2, images, labels))

Here’s how we might write the same function using named axes. Don’t worry if you can’t follow the API details. They are not important now and we will explain everything step-by-step afterwards. This is just to show you what you can do with xmap before you learn them!

def named_predict(w1, w2, image):
  hidden = relu(lax.pdot(image, w1, 'inputs'))
  logits = lax.pdot(hidden, w2, 'hidden')
  return logits - logsumexp(logits, 'classes')

def named_loss(w1, w2, images, labels):
  predictions = named_predict(w1, w2, images)
  num_classes = lax.psum(1, 'classes')
  targets = one_hot(labels, num_classes, axis='classes')
  losses = lax.psum(targets * predictions, 'classes')
  return -lax.pmean(losses, 'batch')

This code is simpler: we don’t need to worry about axis order when calling functions like jnp.dot, or remember which axis position to reduce over with logsumexp, jnp.sum, or jnp.mean.

But the real win is that names let us use xmap to control our function’s execution. At its simplest, xmap will just vectorize over all named axes, so that the function is executed just like its positional-axis counterpart:

from jax.experimental.maps import xmap

in_axes = [['inputs', 'hidden', ...],
           ['hidden', 'classes', ...],
           ['batch', 'inputs', ...],
           ['batch', ...]]

loss = xmap(named_loss, in_axes=in_axes, out_axes=[...])
print(loss(w1, w2, images, labels))

But on a whim we can decide to parallelize over the batch axis:

import jax
import numpy as np
from jax.sharding import Mesh

loss = xmap(named_loss, in_axes=in_axes, out_axes=[...],
            axis_resources={'batch': 'x'})

devices = np.array(jax.local_devices())
with Mesh(devices, ('x',)):
  print(loss(w1, w2, images, labels))

Or we might want to perform model parallelism over the hidden axis:

loss = xmap(named_loss, in_axes=in_axes, out_axes=[...],
            axis_resources={'hidden': 'x'})

devices = np.array(jax.local_devices())
with Mesh(devices, ('x',)):
  print(loss(w1, w2, images, labels))

Or we might want to do both model and batch data parallelism at once:

loss = xmap(named_loss, in_axes=in_axes, out_axes=[...],
            axis_resources={'batch': 'x', 'hidden': 'y'})

devices = np.array(jax.local_devices()).reshape((4, 2))
with Mesh(devices, ('x', 'y')):
  print(loss(w1, w2, images, labels))

With xmap, we can revise our parallelism strategy on a dime, without needing to rewrite our neural network function.

Preliminaries#

import jax.numpy as jnp
from jax import lax
from functools import partial
import jax
import numpy as np

To better illustrate the new programming model, we make extensive use of custom type annotations in this notebook. The annotations have no effect on how the code evaluates and will be unchecked for now.

from typing import Any, Callable

class ArrayType:
  def __getitem__(self, idx):
    return Any
f32 = ArrayType()
i32 = ArrayType()

Tensors with named axes#

The NumPy programming model is based around nd-arrays. Each nd-array can be associated with a two-component type:

  • the element type (accessible via the .dtype attribute)

  • shape (a tuple of integers given by .shape).

Using our little type annotation language, we will write these types as dtype[shape_tuple].

For example, a 5x7x4 array of 32-bit floating point numbers will be denoted as f32[(5, 7, 4)].

Here is a small example that shows how the annotations can demonstrate the way shapes propagate through a simple NumPy program:

x: f32[(2, 3)] = np.ones((2, 3), dtype=np.float32)
y: f32[(3, 5)] = np.ones((3, 5), dtype=np.float32)
z: f32[(2, 5)] = x.dot(y)  # matrix multiplication
w: f32[(7, 1, 5)] = np.ones((7, 1, 5), dtype=np.float32)
q: f32[(7, 2, 5)] = z + w  # broadcasting

The extension we propose is to add another component of array type: a named_shape, mapping axis names (arbitrary hashable objects, with strings being a common choice) to integer sizes. Most importantly, because each axis has a name, their order has no meaning. That is, a named shape of {'a': 2, 'b': 5} is indistinguishable from a named shape of {'b': 5, 'a': 2}.

This is not an entirely new idea. Some good examples of where using named axes has been proposed in the past are: Mesh TensorFlow, Tensor Considered Harmful manifesto as well as the xarray and einops packages. Keep in mind that many of those are slightly different in that they do assign an order to the named axes, but they are unordered in JAX.

From now on we will allow the type annotations to have two components, the first one still being the value’s .shape, while the second one will be the .named_shape.

e: f32[(5, 7), {'batch': 20, 'sequence': 30}]
# e.shape == (5, 7)
# e.named_shape == {'batch': 20, 'sequence': 30} == {'sequence': 30, 'batch': 20}

We don’t modify the meaning of .ndim (which is always equal to len(shape)) or .size (equal to the product of shape); we keep them as-is solely for backward-compatibility reasons. The true rank of an array that has non-empty named axes is len(shape) + len(named_shape). The true number of elements stored in such an array is equal to the product of sizes of all dimensions, both positional and named.

Introducing and eliminating named axes#

But how does one create such arrays, if all top-level JAX operations work in the NumPy model with purely positional axes? While this constraint could be lifted at some point, for the time being the only way to introduce named axes is to use xmap.

xmap can be thought of as an adapter that takes in arrays with positional axes, makes some of them named (as specified by in_axes), and calls the function that it wraps. Once the wrapped function returns arrays, all named axes appearing in those are converted back to positional axes (as specified by out_axes).

in_axes should have a structure that matches the signature of the xmapped function arguments, except with all places where array arguments would be replaced by an axis mapping. There are two ways in which axis mappings can be specified:

  • as dictionaries mapping positional axes to axis names (e.g. {0: 'x', 2: 'y'}); and

  • as lists of axis names terminated by the ellipsis object (e.g. ['a', 'b', ...]), indicating that a prefix of positional dimensions are to be mapped to given names.

out_axes are similar, except that their structure has to match the return signature of the xmapped function (but again, with all arrays replaced by axes mappings).

For each array argument, all positional axes mentioned in its respective in_axes axis mapping are converted to named axes. For each array result, all named axes are inserted in the positions indicated by its respective out_axes.

from jax.experimental.maps import xmap

def my_func(x: f32[(5,), {'batch': 20}]) -> f32[(5,), {'batch': 20}]:
  assert x.shape == (5,)
  # assert x.named_shape == {'batch': 20}  # TODO: Implement named_shape
  return x

x: f32[(20, 5)] = jnp.zeros((20, 5), dtype=np.float32)
f = xmap(my_func,
         in_axes={0: 'batch'},   # Name the first axis of the only argument 'batch'
         out_axes={1: 'batch'})  # Place the 'batch' named axis of the output as the second positional axis
y: f32[(5, 20)] = f(x)
assert (y == x.T).all()  # The first dimension was removed from x and then re-inserted as the last dim

While this might seem like a lot to take in at first, if you’ve seen code that uses jnp.einsum you are already familiar with this approach. The einsum function interprets an expression such as nk,km->nm by assigning names (each letter is considered a separate name) to positional axes, performing the necessary broadcasts and reductions, and finally putting the results back into positional axes, in the order given by the right-hand side of the -> separator. While einsum never lets you interact with named axes directly, they do appear naturally in its implementation. xmap is a generalized einsum, because named axes are now first-class and you get to implement the function that manipulates them.

Continuing this analogy, xmap(my_func, ...) from the above example is equivalent to jnp.einsum('bx->xb'). But of course not every xmapped function will have an equivalent einsum.
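We can check that equivalence directly (a quick sketch reusing f and x from the example above):

np.testing.assert_allclose(f(x), jnp.einsum('bx->xb', x))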

One more similarity with einsum is that whenever a name is reused for multiple axes, they do have to have the same size:

x = jnp.arange(5)
y = jnp.arange(7)
try:
  jnp.einsum('i,i->i', x, y)
except Exception as e:
  print('einsum:', e)
try:
  xmap(lambda x, y: x * y,
       in_axes=(['i', ...], ['i', ...]),
       out_axes=['i', ...])(x, y)
except Exception as e:
  print('xmap:', e)

Named axis propagation#

We now know how named axes are introduced and eliminated, but what are they good for? How do they propagate throughout the program? Let’s explore a few examples.

Interactions with positional axes#

First rule: named axes never implicitly interact with positional axes. Any function that’s written without named axes in mind can always be invoked with inputs that have named dimensions. The result is the same as if vmap was applied on a per-named-axis basis.

from jax.scipy.linalg import expm_frechet

# Any other function that does not assume existence of any named axes would do too,
# at least as long as it matches this type signature:
expm_frechet: Callable[[f32[(3, 3)], f32[(3, 3)]], f32[(3, 3)]]
f = partial(expm_frechet, compute_expm=False)

# Each A with each E
batch_A = jnp.ones((5, 3, 3), dtype=np.float32)
batch_E = jnp.ones((5, 3, 3), dtype=np.float32)
batch_AE = xmap(f,
                in_axes=(['b', ...], ['b', ...]),      # Map first axes of both inputs to 'b'
                out_axes=['b', ...])(batch_A, batch_E) # Place 'b' as the first positional axis in the result
for i in range(5):
  np.testing.assert_allclose(batch_AE[i], f(batch_A[i], batch_E[i]))

# All-pairs of As and Es
batch_A = jnp.ones((7, 3, 3), dtype=np.float32)
batch_E = jnp.ones((5, 3, 3), dtype=np.float32)
batch_AE = xmap(f,
                in_axes=(['ba', ...], ['be', ...]),           # Map first axes of inputs to 'ba' and 'be' respectively
                out_axes=['ba', 'be', ...])(batch_A, batch_E) # Prefix all positional dimensions of output with 'ba' and 'be'
for i in range(7):
  for j in range(5):
    np.testing.assert_allclose(batch_AE[i,j], f(batch_A[i], batch_E[j]))
Broadcasting#

Secondly, named axes are broadcast by name, and every existing NumPy (and almost every JAX) operator implicitly broadcasts the named dimensions. Whenever a standard NumPy function is called with arrays with named axes, the NumPy function determines the positional shape of the result array, while the named shape becomes a union of all named shapes of its inputs. Analyze the following example to understand how the axes propagate:

def named_broadcasting(
    x: f32[(2, 1, 1), {'a': 2}],
    y: f32[(1, 3, 1), {'b': 3}],
    z: f32[(1, 1, 5), {'c': 5}]) \
      -> f32[(2, 3, 5), {'a': 2, 'b': 3, 'c': 5}]:
  i: f32[(2, 3, 1), {'a': 2, 'b': 3}] = x + y
  j: f32[(1, 3, 5), {'b': 3, 'c': 5}] = y + z
  k: f32[(2, 3, 5), {'a': 2, 'b': 3, 'c': 5}] = i + j
  return k

x = jnp.ones((2, 2, 1, 1), dtype=np.float32)
y = jnp.ones((3, 1, 3, 1), dtype=np.float32)
z = jnp.ones((5, 1, 1, 5), dtype=np.float32)
k = xmap(named_broadcasting,
         in_axes=(['a', ...], ['b', ...], ['c', ...]),
         out_axes=['a', 'b', 'c', ...])(x, y, z)
assert k.shape == (2, 3, 5, 2, 3, 5)

To recap, the named shape of the result of an expression such as i + j with i having a named shape of {'a': 2, 'b': 3} and j of {'b': 3, 'c': 5} is {'a': 2, 'b': 3, 'c': 5}. The 'b' axis is present in both inputs, so no broadcasting is necessary, while 'a' and 'c' occur in only one of the two inputs, causing the other one to get broadcast along the axis missing in its named shape.

No shape errors can occur when operating over named axes, because xmap enforces that a single name is associated with a single size inside its body.

While the rule for broadcasting named axes might seem like an arbitrary extension of the NumPy model, it is actually consistent with it.

Broadcasting first looks for pairs of dimensions it considers as equivalent in both operands. For all matched pairs, it asserts that both sizes are equal or one of them is 1. All unpaired dimensions are carried over to the result.

Now, in the positional world the way NumPy broadcasting chooses to form the pairs is by right-aligning the shapes. But our axes are named, so there is a straightforward way of finding equivalent axes: just check their names for equality!
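As a quick reminder of that positional rule, here is a small sketch in plain NumPy showing how right-alignment decides which dimensions get paired:

import numpy as np

a = np.ones((2, 3))
b = np.ones((3,))     # right-aligned against the last axis of `a`
print((a + b).shape)  # (2, 3)

c = np.ones((2,))     # right-aligned against the 3-sized axis, so this fails
try:
  a + c
except ValueError as e:
  print('broadcast error:', e)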

Reductions#

But named axes are not only good for batching! In fact, our goal is that named axes should be equivalent to positional axes. In particular, every NumPy function that takes in positional axes as arguments should also accept named axes.

The paragraph above is aspirational and the set of NumPy functions that do accept named axes is relatively limited. At the moment named axes are only supported in:

  • jnp.sum, jnp.max, jnp.min

Reductions are a good example:

def named_broadcast_and_reduce(
    x: f32[(), {'x': 2}],
    y: f32[(5,), {'y': 4}]) \
      -> f32[()]:
  z: f32[(5,), {'x': 2, 'y': 4}] = x + y
  w: f32[()] = jnp.sum(z, axis=(0, 'x', 'y'))
  # We could also reduce in steps:
  # w0 : f32[(), {'x': 2, 'y': 4}] = jnp.sum(z, 0)      # eliminate the positional axis
  # w0x: f32[(), {'y': 4}]         = jnp.sum(w0, 'x')   # eliminate the `x` axis
  # w  : f32[()]                   = jnp.sum(w0x, 'y')  # eliminate the `y` axis
  return w

positional_broadcast_and_reduce: Callable[[f32[(2,)], f32[(5, 4)]], f32[()]]
positional_broadcast_and_reduce = \
  xmap(named_broadcast_and_reduce,
       in_axes=({0: 'x'}, {1: 'y'}),
       out_axes={})
positional_broadcast_and_reduce(jnp.arange(2, dtype=np.float32),
                                jnp.arange(20, dtype=np.float32).reshape((5, 4)))
einsum#

Similarly to how we have extended reductions with support for named axes, we’ve also made it possible to contract over named axes using jnp.einsum.

Operands and results still use a convention of one letter per positional axis, but now it is also possible to mention named axes in curly braces. For example, n{b,k} implies that a value will have a single positional dimension n and named dimensions b and k (their order doesn’t matter). Following the usual einsum semantics, any named axes that appear in inputs, but do not appear in an output will be contracted (summed after all multiplications are performed).

It is acceptable to omit a named dimension from all arguments and the result in which case it will be treated according to the usual broadcasting semantics. However, it is not acceptable to mention a named axis in one argument that has it in its named shape and skip it in another argument that also has it in its named shape. Of course, skipping it in the arguments that don’t have it is required.

NOTE: This invariant is unchecked at the moment (it is still work-in-progress). Such axis skipping will result in undefined behavior.

At the moment jnp.einsum with named axes only supports two inputs and a single result.

def named_batch_matrix_single_matrix(
    x: f32[(5,), {'b': 20, 'k': 7}],
    y: f32[(), {'k': 7, 'm': 11}]) \
      -> f32[(5,), {'b': 20, 'm': 11}]:
  return jnp.einsum('n{b,k},{k,m}->n{b,m}', x, y)

x = jnp.ones((20, 5, 7))
y = jnp.ones((7, 11))
z = jnp.einsum('bnk,km->bnm', x, y)
zx = xmap(named_batch_matrix_single_matrix,
          in_axes=[{0: 'b', 2: 'k'}, ['k', 'm', ...]],
          out_axes={0: 'b', 2: 'm'})(x, y)
np.testing.assert_allclose(z, zx)

The example above is admittedly no clearer than using jnp.einsum directly. But contractions over named axes are a crucial component of larger applications such as Transformer models and this is only meant to be an exercise to show you how the names propagate.

Collectives#

Finally, all collectives that could have been used with pmapped functions also work with named axes. As we’ll show later, xmap can be used as a drop-in replacement for pmap that makes programming for multi-dimensional hardware meshes much easier.

x = jnp.arange(8)
xmap(lambda x: lax.pshuffle(x, 'i', list(reversed(range(8)))),
     in_axes=['i', ...], out_axes=['i', ...])(x)
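For example, lax.psum sums over the named axis, just as it would over a pmapped axis (a small sketch reusing x from the cell above):

xmap(lambda x: lax.psum(x, 'i'),
     in_axes=['i', ...], out_axes=[...])(x)  # reduces over 'i'; returns the scalar sum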

Parallelism support#

While the new programming paradigm can be nice at times, the killer feature of xmap is its ability to parallelize code over supercomputer-scale hardware meshes!

Named axes are the secret sauce that makes all this possible, thanks to the carefully tuned rules that describe their propagation. Good support for partitioning in a purely positional programming model is notoriously difficult. Positional axes are usually disposable and it is hard to keep track of the way axis partitioning propagates through the program. As you’ll see below, named axes enable us to define a straightforward correspondence between their names and hardware resources, making it easy to reason about the way different values end up partitioned.

In all the previous examples, we haven’t said a word about parallelism and for a good reason. By default xmap doesn’t perform any parallelization and vectorizes the computation in the same way vmap does (i.e. it still executes on a single device). To partition the computation over multiple accelerators we have to introduce one more concept: resource axes.

The basic idea is that logical axes (the ones that appear in named shapes) assume that we have abundant hardware and memory, but before the program is to be executed, they have to be placed somewhere. The default (vmap-like) evaluation style pays a high memory cost on the default JAX device. By mapping logical axes to (one or more) resource axes through the axis_resources argument, we can control how xmap evaluates the computation.

x = jnp.ones((2048, 2048))

local_matmul = xmap(jnp.vdot,
                    in_axes=({0: 'left'}, {1: 'right'}),
                    out_axes=['left', 'right', ...])
distr_matmul = xmap(jnp.vdot,
                    in_axes=({0: 'left'}, {1: 'right'}),
                    out_axes=['left', 'right', ...],
                    axis_resources={'left': 'x', 'right': 'y'})

Both local_matmul and distr_matmul implement matrix multiplication, but distr_matmul will additionally partition the left and right logical axes over the x and y resource axes.

But… where do those resource names come from?#

Well, it depends, but one good choice is… a hardware mesh!

For our purposes a mesh is an nd-array of devices with named axes. But, because NumPy doesn’t support named axes (that’s our extension!), the meshes are represented by a pair of an nd-array of JAX device objects (as obtained from jax.devices() or jax.local_devices()) and a tuple of resource axis names of length matching the rank of the array.

How real hardware is represented as an abstract mesh
axis_names = ('x', 'y')
mesh_devices = np.array(jax.devices()).reshape((2, 4))
assert len(axis_names) == mesh_devices.ndim
mesh_def = (mesh_devices, axis_names)
mesh_def

The mesh axis names are exactly the names of resources that named axes can be mapped to. But just creating a mesh definition won’t make the resource names visible to distr_matmul:

try:
  distr_matmul(x, x)
except Exception as e:
  print(e)

To introduce the resources in a scope, use the with Mesh context manager:

from jax.sharding import Mesh

local = local_matmul(x, x)  # The local function doesn't require the mesh definition
with Mesh(*mesh_def):  # Makes the mesh axis names available as resources
  distr = distr_matmul(x, x)
np.testing.assert_allclose(local, distr)

Anyway, the best part of it is that specifying axis_resources never changes program semantics. You are free to experiment with different ways of partitioning your computation (just change the assignment of resources to named axes!) and even how the physical devices are organized in the mesh (by changing the construction of the NumPy array of devices). None of those things should have any significant influence on the results you get back (up to, for example, floating point inaccuracy), though of course some of them will achieve significantly better performance than the others.

xmap doesn’t provide any automatic scheduling options at the moment, because the best schedule often has to be somewhat carefully matched to your program. We’re considering adding support for that in the future, but it will take time.

Once you map a logical axis to a mesh dimension, the size of that logical axis has to be divisible by the mesh dimension size.
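For example (a hedged sketch, assuming at least two devices are available in the mesh), mapping a logical axis of size 5 to a mesh axis of size 2 is rejected:

odd = jnp.ones((5,), dtype=np.float32)
try:
  with Mesh(np.array(jax.devices()[:2]), ('x',)):
    xmap(lambda v: v * 2,
         in_axes=['a', ...], out_axes=['a', ...],
         axis_resources={'a': 'x'})(odd)
except Exception as e:
  print(e)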

Is my data replicated? Or partitioned? Where is it?#

Named axes also give us a neat way of reasoning about partitioning and replication. A value is partitioned over a mesh axis if and only if it has a named axis that has been mapped to that mesh axis in its shape. Otherwise, it will be replicated over all slices along that axis.

For example, assume that we’re in an xmap that had axis_resources={'a': 'x', 'b': 'y'} specified (i.e. we are running the computation over a 2D mesh with x and y axes with sizes 2 and 3 respectively). Then:

  • An array of type f32[(5, 5), {}] is completely replicated over the whole mesh. All devices store a local copy of the value.

  • An array of type f32[(6,), {'a': 8}] is partitioned over mesh axis x, because it has 'a' in its named shape, and 'a' is mapped to x. It is replicated over mesh axis y. To put it differently, all devices in a slice of the mesh with the same x coordinate will store a local copy of a chunk of this array. But, mesh slices with different x coordinates will store different chunks of the data.

  • An array of type f32[(), {'a': 8, 'c': 7}] is partitioned just like in the previous case: split over the x mesh axis and replicated over the y axis. Named dimensions with no resources specified are no different than positional dimensions when considering partitioning, so 'c' has no influence on it.

  • An array of type f32[(), {'a': 8, 'b': 12}] is completely partitioned over the whole mesh. Every device holds a distinct chunk of the data.

An illustration for the above examples

This also highlights one restriction: xmap won’t complain if you specify axis_resources={'a': 'x', 'b': 'x'}, but consider how an array of type f32[(2, 8), {'a': 4, 'b': 12}] would be partitioned. If the size of the x mesh axis is 2, then we only have 2 devices, but we have 4 chunks to place (2 along 'a' and 2 along 'b')! Now we can state it in full: named axes mapped to the same resources can never both appear in the named shape of a single array. But they can appear in named shapes of two distinct arrays, such as in this program:

def sum_two_args(x: f32[(), {'a': 4}], y: f32[(), {'b': 12}]) -> f32[()]:
  return jnp.sum(x, axis='a') + jnp.sum(y, axis='b')

q = jnp.ones((4,), dtype=np.float32)
u = jnp.ones((12,), dtype=np.float32)
with Mesh(np.array(jax.devices()[:4]), ('x',)):
  v = xmap(sum_two_args,
           in_axes=(['a', ...], ['b', ...]),
           out_axes=[...],
           axis_resources={'a': 'x', 'b': 'x'})(q, u)
  print(v)

This program is valid, because jnp.sum eliminates the axes that cannot co-occur before the values are added.

While the final release of xmap will ensure that you don’t accidentally end up doing so, the current implementation doesn’t verify it. Violating this restriction will result in undefined behavior.

Why axis_resources and not a more direct mapping to hardware?#

At this point you might wonder why go through the detour of introducing yet another concept of resource axes into the mix. If all you’re interested in is partitioning your computations over hardware, there is no good reason, but this mental framework is more flexible than that!

For example, there is one additional resource we all deal with: time! Just like a computation can be partitioned over multiple hardware devices, e.g. to lower its memory usage, the same thing can be achieved with a single accelerator that evaluates a chunk of the computation in multiple steps.

So, while hardware meshes are the only source of resource axes in JAX programs at the moment, we are planning to extend the whole system with other sources.

Porting positional code to named code#

In this section we will go over a few more real examples to show how xmap can help you implement and distribute various models.

This section is a work in progress

The Autodiff Cookbook#


alexbw@, mattjj@

JAX has a pretty general automatic differentiation system. In this notebook, we’ll go through a whole bunch of neat autodiff ideas that you can cherry pick for your own work, starting with the basics.

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

key = random.key(0)

Gradients#

Starting with grad#

You can differentiate a function with grad:

grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0))
0.070650816

grad takes a function and returns a function. If you have a Python function f that evaluates the mathematical function \(f\), then grad(f) is a Python function that evaluates the mathematical function \(\nabla f\). That means grad(f)(x) represents the value \(\nabla f(x)\).

Since grad operates on functions, you can apply it to its own output to differentiate as many times as you like:

print(grad(grad(jnp.tanh))(2.0))
print(grad(grad(grad(jnp.tanh)))(2.0))
-0.13621868
0.25265405

Let’s look at computing gradients with grad in a linear logistic regression model. First, the setup:

def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)

# Outputs probability of a label being true.
def predict(W, b, inputs):
    return sigmoid(jnp.dot(inputs, W) + b)

# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12,  0.77],
                   [0.88, -1.08, 0.15],
                   [0.52, 0.06, -1.30],
                   [0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])

# Training loss is the negative log-likelihood of the training examples.
def loss(W, b):
    preds = predict(W, b, inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

# Initialize random model coefficients
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())

Use the grad function with its argnums argument to differentiate a function with respect to positional arguments.

# Differentiate `loss` with respect to the first positional argument:
W_grad = grad(loss, argnums=0)(W, b)
print('W_grad', W_grad)

# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print('W_grad', W_grad)

# But we can choose different values too, and drop the keyword:
b_grad = grad(loss, 1)(W, b)
print('b_grad', b_grad)

# Including tuple values
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print('W_grad', W_grad)
print('b_grad', b_grad)
W_grad [-0.16965583 -0.8774644  -1.4901346 ]
W_grad [-0.16965583 -0.8774644  -1.4901346 ]
b_grad -0.29227245
W_grad [-0.16965583 -0.8774644  -1.4901346 ]
b_grad -0.29227245

This grad API has a direct correspondence to the excellent notation in Spivak’s classic Calculus on Manifolds (1965), also used in Sussman and Wisdom’s Structure and Interpretation of Classical Mechanics (2015) and their Functional Differential Geometry (2013). Both books are open-access. See in particular the “Prologue” section of Functional Differential Geometry for a defense of this notation.

Essentially, when using the argnums argument, if f is a Python function for evaluating the mathematical function \(f\), then the Python expression grad(f, i) evaluates to a Python function for evaluating \(\partial_i f\).

Differentiating with respect to nested lists, tuples, and dicts#

Differentiating with respect to standard Python containers just works, so use tuples, lists, and dicts (and arbitrary nesting) however you like.

def loss2(params_dict):
    preds = predict(params_dict['W'], params_dict['b'], inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

print(grad(loss2)({'W': W, 'b': b}))
{'W': Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32), 'b': Array(-0.29227245, dtype=float32)}

You can register your own container types to work with not just grad but all the JAX transformations (jit, vmap, etc.).
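For example, here is a minimal sketch of registering a hypothetical Point container via jax.tree_util.register_pytree_node, so that grad returns a Point of gradients with the same structure:

import jax
import jax.numpy as jnp
from jax import grad

class Point:
  def __init__(self, x, y):
    self.x, self.y = x, y

def _flatten_point(p):
  return (p.x, p.y), None          # children, auxiliary data

def _unflatten_point(aux, children):
  return Point(*children)

jax.tree_util.register_pytree_node(Point, _flatten_point, _unflatten_point)

def sq_norm(p):
  return p.x**2 + p.y**2

g = grad(sq_norm)(Point(1.0, 2.0))
print(g.x, g.y)                    # 2.0 4.0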

Evaluate a function and its gradient using value_and_grad#

Another convenient function is value_and_grad for efficiently computing both a function’s value as well as its gradient’s value:

from jax import value_and_grad
loss_value, Wb_grad = value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)
print('loss value', loss(W, b))
loss value 3.0519385
loss value 3.0519385
Checking against numerical differences#

A great thing about derivatives is that they’re straightforward to check with finite differences:

# Set a step size for finite differences calculations
eps = 1e-4

# Check b_grad with scalar finite differences
b_grad_numerical = (loss(W, b + eps / 2.) - loss(W, b - eps / 2.)) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))

# Check W_grad with finite differences in a random direction
key, subkey = random.split(key)
vec = random.normal(subkey, W.shape)
unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))
W_grad_numerical = (loss(W + eps / 2. * unitvec, b) - loss(W - eps / 2. * unitvec, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(W, b), unitvec))
b_grad_numerical -0.29325485
b_grad_autodiff -0.29227245
W_dirderiv_numerical -0.2002716
W_dirderiv_autodiff -0.19909117

JAX provides a simple convenience function that does essentially the same thing, but checks up to any order of differentiation that you like:

from jax.test_util import check_grads
check_grads(loss, (W, b), order=2)  # check up to 2nd order derivatives
Hessian-vector products with grad-of-grad#

One thing we can do with higher-order grad is build a Hessian-vector product function. (Later on we’ll write an even more efficient implementation that mixes both forward- and reverse-mode, but this one will use pure reverse-mode.)

A Hessian-vector product function can be useful in a truncated Newton Conjugate-Gradient algorithm for minimizing smooth convex functions, or for studying the curvature of neural network training objectives (e.g. 1, 2, 3, 4).

For a scalar-valued function \(f : \mathbb{R}^n \to \mathbb{R}\) with continuous second derivatives (so that the Hessian matrix is symmetric), the Hessian at a point \(x \in \mathbb{R}^n\) is written as \(\partial^2 f(x)\). A Hessian-vector product function is then able to evaluate

\(\qquad v \mapsto \partial^2 f(x) \cdot v\)

for any \(v \in \mathbb{R}^n\).

The trick is not to instantiate the full Hessian matrix: if \(n\) is large, perhaps in the millions or billions in the context of neural networks, then that might be impossible to store.

Luckily, grad already gives us a way to write an efficient Hessian-vector product function. We just have to use the identity

\(\qquad \partial^2 f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)\),

where \(g(x) = \partial f(x) \cdot v\) is a new scalar-valued function that dots the gradient of \(f\) at \(x\) with the vector \(v\). Notice that we’re only ever differentiating scalar-valued functions of vector-valued arguments, which is exactly where we know grad is efficient.

In JAX code, we can just write this:

def hvp(f, x, v):
    return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)

This example shows that you can freely use lexical closure, and JAX will never get perturbed or confused.

We’ll check this implementation a few cells down, once we see how to compute dense Hessian matrices. We’ll also write an even better version that uses both forward-mode and reverse-mode.

Jacobians and Hessians using jacfwd and jacrev#

You can compute full Jacobian matrices using the jacfwd and jacrev functions:

from jax import jacfwd, jacrev

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

J = jacfwd(f)(W)
print("jacfwd result, with shape", J.shape)
print(J)

J = jacrev(f)(W)
print("jacrev result, with shape", J.shape)
print(J)
jacfwd result, with shape (4, 3)
[[ 0.05981758  0.12883787  0.08857603]
 [ 0.04015916 -0.04928625  0.00684531]
 [ 0.12188288  0.01406341 -0.3047072 ]
 [ 0.00140431 -0.00472531  0.00263782]]
jacrev result, with shape (4, 3)
[[ 0.05981757  0.12883787  0.08857603]
 [ 0.04015916 -0.04928625  0.00684531]
 [ 0.12188289  0.01406341 -0.3047072 ]
 [ 0.00140431 -0.00472531  0.00263782]]

These two functions compute the same values (up to machine numerics), but differ in their implementation: jacfwd uses forward-mode automatic differentiation, which is more efficient for “tall” Jacobian matrices, while jacrev uses reverse-mode, which is more efficient for “wide” Jacobian matrices. For matrices that are near-square, jacfwd probably has an edge over jacrev.
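As a rough, hedged illustration of that rule of thumb (hypothetical sizes; actual timings depend on your backend), compare the two on a very “wide” scalar-valued function, where jacrev needs a single VJP while jacfwd needs one JVP per input dimension:

wide_fun = lambda x: jnp.sum(jnp.tanh(x))       # R^1000 -> R: a 1x1000 ("wide") Jacobian
x_wide = random.normal(random.key(1), (1000,))

%timeit -n10 -r3 jacfwd(wide_fun)(x_wide).block_until_ready()
%timeit -n10 -r3 jacrev(wide_fun)(x_wide).block_until_ready()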

You can also use jacfwd and jacrev with container types:

def predict_dict(params, inputs):
    return predict(params['W'], params['b'], inputs)

J_dict = jacrev(predict_dict)({'W': W, 'b': b}, inputs)
for k, v in J_dict.items():
    print("Jacobian from {} to logits is".format(k))
    print(v)
Jacobian from W to logits is
[[ 0.05981757  0.12883787  0.08857603]
 [ 0.04015916 -0.04928625  0.00684531]
 [ 0.12188289  0.01406341 -0.3047072 ]
 [ 0.00140431 -0.00472531  0.00263782]]
Jacobian from b to logits is
[0.11503381 0.04563541 0.23439017 0.00189771]

For more details on forward- and reverse-mode, as well as how to implement jacfwd and jacrev as efficiently as possible, read on!

Using a composition of two of these functions gives us a way to compute dense Hessian matrices:

def hessian(f):
    return jacfwd(jacrev(f))

H = hessian(f)(W)
print("hessian, with shape", H.shape)
print(H)
hessian, with shape (4, 3, 3)
[[[ 0.02285465  0.04922541  0.03384247]
  [ 0.04922541  0.10602397  0.07289147]
  [ 0.03384247  0.07289147  0.05011288]]

 [[-0.03195215  0.03921401 -0.00544639]
  [ 0.03921401 -0.04812629  0.00668421]
  [-0.00544639  0.00668421 -0.00092836]]

 [[-0.01583708 -0.00182736  0.03959271]
  [-0.00182736 -0.00021085  0.00456839]
  [ 0.03959271  0.00456839 -0.09898177]]

 [[-0.00103524  0.00348343 -0.00194457]
  [ 0.00348343 -0.01172127  0.0065432 ]
  [-0.00194457  0.0065432  -0.00365263]]]

This shape makes sense: if we start with a function \(f : \mathbb{R}^n \to \mathbb{R}^m\), then at a point \(x \in \mathbb{R}^n\) we expect to get the shapes

  • \(f(x) \in \mathbb{R}^m\), the value of \(f\) at \(x\),

  • \(\partial f(x) \in \mathbb{R}^{m \times n}\), the Jacobian matrix at \(x\),

  • \(\partial^2 f(x) \in \mathbb{R}^{m \times n \times n}\), the Hessian at \(x\),

and so on.

To implement hessian, we could have used jacfwd(jacrev(f)) or jacrev(jacfwd(f)) or any other composition of the two. But forward-over-reverse is typically the most efficient. That’s because in the inner Jacobian computation we’re often differentiating a function with a wide Jacobian (maybe like a loss function \(f : \mathbb{R}^n \to \mathbb{R}\)), while in the outer Jacobian computation we’re differentiating a function with a square Jacobian (since \(\nabla f : \mathbb{R}^n \to \mathbb{R}^n\)), which is where forward-mode wins out.

How it’s made: two foundational autodiff functions#

Jacobian-Vector products (JVPs, aka forward-mode autodiff)#

JAX includes efficient and general implementations of both forward- and reverse-mode automatic differentiation. The familiar grad function is built on reverse-mode, but to explain the difference in the two modes, and when each can be useful, we need a bit of math background.

JVPs in math#

Mathematically, given a function \(f : \mathbb{R}^n \to \mathbb{R}^m\), the Jacobian of \(f\) evaluated at an input point \(x \in \mathbb{R}^n\), denoted \(\partial f(x)\), is often thought of as a matrix in \(\mathbb{R}^{m \times n}\):

\(\qquad \partial f(x) \in \mathbb{R}^{m \times n}\).

But we can also think of \(\partial f(x)\) as a linear map, which maps the tangent space of the domain of \(f\) at the point \(x\) (which is just another copy of \(\mathbb{R}^n\)) to the tangent space of the codomain of \(f\) at the point \(f(x)\) (a copy of \(\mathbb{R}^m\)):

\(\qquad \partial f(x) : \mathbb{R}^n \to \mathbb{R}^m\).

This map is called the pushforward map of \(f\) at \(x\). The Jacobian matrix is just the matrix for this linear map in a standard basis.

If we don’t commit to one specific input point \(x\), then we can think of the function \(\partial f\) as first taking an input point and returning the Jacobian linear map at that input point:

\(\qquad \partial f : \mathbb{R}^n \to \mathbb{R}^n \to \mathbb{R}^m\).

In particular, we can uncurry things so that given input point \(x \in \mathbb{R}^n\) and a tangent vector \(v \in \mathbb{R}^n\), we get back an output tangent vector in \(\mathbb{R}^m\). We call that mapping, from \((x, v)\) pairs to output tangent vectors, the Jacobian-vector product, and write it as

\(\qquad (x, v) \mapsto \partial f(x) v\)

JVPs in JAX code#

Back in Python code, JAX’s jvp function models this transformation. Given a Python function that evaluates \(f\), JAX’s jvp is a way to get a Python function for evaluating \((x, v) \mapsto (f(x), \partial f(x) v)\).

from jax import jvp

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

key, subkey = random.split(key)
v = random.normal(subkey, W.shape)

# Push forward the vector `v` along `f` evaluated at `W`
y, u = jvp(f, (W,), (v,))

In terms of Haskell-like type signatures, we could write

jvp :: (a -> b) -> a -> T a -> (b, T b)

where we use T a to denote the type of the tangent space for a. In words, jvp takes as arguments a function of type a -> b, a value of type a, and a tangent vector value of type T a. It gives back a pair consisting of a value of type b and an output tangent vector of type T b.

The jvp-transformed function is evaluated much like the original function, but paired up with each primal value of type a it pushes along tangent values of type T a. For each primitive numerical operation that the original function would have applied, the jvp-transformed function executes a “JVP rule” for that primitive that both evaluates the primitive on the primals and applies the primitive’s JVP at those primal values.

That evaluation strategy has some immediate implications about computational complexity: since we evaluate JVPs as we go, we don’t need to store anything for later, and so the memory cost is independent of the depth of the computation. In addition, the FLOP cost of the jvp-transformed function is about 3x the cost of just evaluating the function (one unit of work for evaluating the original function, for example sin(x); one unit for linearizing, like cos(x); and one unit for applying the linearized function to a vector, like cos_x * v). Put another way, for a fixed primal point \(x\), we can evaluate \(v \mapsto \partial f(x) \cdot v\) for about the same marginal cost as evaluating \(f\).
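Here is a concrete instance of the sin example from the paragraph above (hypothetical inputs x0 and v0): jvp evaluates the primal sin(x) and pushes the tangent cos(x) * v along with it.

x0, v0 = 1.0, 2.0
primal_out, tangent_out = jvp(jnp.sin, (x0,), (v0,))
print(primal_out, jnp.sin(x0))         # both sin(1.0)
print(tangent_out, jnp.cos(x0) * v0)   # both cos(1.0) * 2.0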

That memory complexity sounds pretty compelling! So why don’t we see forward-mode very often in machine learning?

To answer that, first think about how you could use a JVP to build a full Jacobian matrix. If we apply a JVP to a one-hot tangent vector, it reveals one column of the Jacobian matrix, corresponding to the nonzero entry we fed in. So we can build a full Jacobian one column at a time, and to get each column costs about the same as one function evaluation. That will be efficient for functions with “tall” Jacobians, but inefficient for “wide” Jacobians.
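Here is a minimal sketch of that column-by-column strategy (illustrative only; the vmap-based implementation of jacfwd shown later is the efficient way to do this):

def jac_by_columns(f, x):
  # Push one one-hot tangent through jvp per input coordinate;
  # each call reveals one column of the Jacobian.
  columns = []
  for i in range(x.size):
    one_hot = jnp.zeros(x.shape).at[i].set(1.0)
    columns.append(jvp(f, (x,), (one_hot,))[1])
  return jnp.stack(columns, axis=-1)

np.testing.assert_allclose(jac_by_columns(f, W), jacfwd(f)(W), rtol=1e-5)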

If you’re doing gradient-based optimization in machine learning, you probably want to minimize a loss function from parameters in \(\mathbb{R}^n\) to a scalar loss value in \(\mathbb{R}\). That means the Jacobian of this function is a very wide matrix: \(\partial f(x) \in \mathbb{R}^{1 \times n}\), which we often identify with the gradient vector \(\nabla f(x) \in \mathbb{R}^n\). Building that matrix one column at a time, with each call taking a similar number of FLOPs to evaluate the original function, sure seems inefficient! In particular, for training neural networks, where \(f\) is a training loss function and \(n\) can be in the millions or billions, this approach just won’t scale.

To do better for functions like this, we just need to use reverse-mode.

Vector-Jacobian products (VJPs, aka reverse-mode autodiff)#

Where forward-mode gives us back a function for evaluating Jacobian-vector products, which we can then use to build Jacobian matrices one column at a time, reverse-mode is a way to get back a function for evaluating vector-Jacobian products (equivalently Jacobian-transpose-vector products), which we can use to build Jacobian matrices one row at a time.

VJPs in math#

Let’s again consider a function \(f : \mathbb{R}^n \to \mathbb{R}^m\). Starting from our notation for JVPs, the notation for VJPs is pretty simple:

\(\qquad (x, v) \mapsto v \partial f(x)\),

where \(v\) is an element of the cotangent space of \(f\) at \(x\) (isomorphic to another copy of \(\mathbb{R}^m\)). When being rigorous, we should think of \(v\) as a linear map \(v : \mathbb{R}^m \to \mathbb{R}\), and when we write \(v \partial f(x)\) we mean function composition \(v \circ \partial f(x)\), where the types work out because \(\partial f(x) : \mathbb{R}^n \to \mathbb{R}^m\). But in the common case we can identify \(v\) with a vector in \(\mathbb{R}^m\) and use the two almost interchangeably, just like we might sometimes flip between “column vectors” and “row vectors” without much comment.

With that identification, we can alternatively think of the linear part of a VJP as the transpose (or adjoint conjugate) of the linear part of a JVP:

\(\qquad (x, v) \mapsto \partial f(x)^\mathsf{T} v\).

For a given point \(x\), we can write the signature as

\(\qquad \partial f(x)^\mathsf{T} : \mathbb{R}^m \to \mathbb{R}^n\).

The corresponding map on cotangent spaces is often called the pullback of \(f\) at \(x\). The key for our purposes is that it goes from something that looks like the output of \(f\) to something that looks like the input of \(f\), just like we might expect from a transposed linear function.

VJPs in JAX code#

Switching from math back to Python, the JAX function vjp can take a Python function for evaluating \(f\) and give us back a Python function for evaluating the VJP \((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))\).

from jax import vjp

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

y, vjp_fun = vjp(f, W)

key, subkey = random.split(key)
u = random.normal(subkey, y.shape)

# Pull back the covector `u` along `f` evaluated at `W`
v = vjp_fun(u)

In terms of Haskell-like type signatures, we could write

vjp :: (a -> b) -> a -> (b, CT b -> CT a)

where we use CT a to denote the type for the cotangent space for a. In words, vjp takes as arguments a function of type a -> b and a point of type a, and gives back a pair consisting of a value of type b and a linear map of type CT b -> CT a.

This is great because it lets us build Jacobian matrices one row at a time, and the FLOP cost for evaluating \((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))\) is only about three times the cost of evaluating \(f\). In particular, if we want the gradient of a function \(f : \mathbb{R}^n \to \mathbb{R}\), we can do it in just one call. That’s how grad is efficient for gradient-based optimization, even for objectives like neural network training loss functions on millions or billions of parameters.

There’s a cost, though: while the FLOPs are friendly, memory scales with the depth of the computation. Also, the implementation is traditionally more complex than that of forward-mode, though JAX has some tricks up its sleeve (that’s a story for a future notebook!).

For more on how reverse-mode works, see this tutorial video from the Deep Learning Summer School in 2017.

Vector-valued gradients with VJPs#

If you’re interested in taking vector-valued gradients (like tf.gradients):

from jax import vjp

def vgrad(f, x):
  y, vjp_fn = vjp(f, x)
  return vjp_fn(jnp.ones(y.shape))[0]

print(vgrad(lambda x: 3*x**2, jnp.ones((2, 2))))
[[6. 6.]
 [6. 6.]]
Hessian-vector products using both forward- and reverse-mode#

In a previous section, we implemented a Hessian-vector product function just using reverse-mode (assuming continuous second derivatives):

def hvp(f, x, v):
    return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)

That’s efficient, but we can do even better and save some memory by using forward-mode together with reverse-mode.

Mathematically, given a function \(f : \mathbb{R}^n \to \mathbb{R}\) to differentiate, a point \(x \in \mathbb{R}^n\) at which to linearize the function, and a vector \(v \in \mathbb{R}^n\), the Hessian-vector product function we want is

\((x, v) \mapsto \partial^2 f(x) v\)

Consider the helper function \(g : \mathbb{R}^n \to \mathbb{R}^n\) defined to be the derivative (or gradient) of \(f\), namely \(g(x) = \partial f(x)\). All we need is its JVP, since that will give us

\((x, v) \mapsto \partial g(x) v = \partial^2 f(x) v\).

We can translate that almost directly into code:

from jax import jvp, grad

# forward-over-reverse
def hvp(f, primals, tangents):
  return jvp(grad(f), primals, tangents)[1]

Even better, since we didn’t have to call jnp.dot directly, this hvp function works with arrays of any shape and with arbitrary container types (like vectors stored as nested lists/dicts/tuples), and doesn’t even have a dependence on jax.numpy.

Here’s an example of how to use it:

def f(X):
  return jnp.sum(jnp.tanh(X)**2)

key, subkey1, subkey2 = random.split(key, 3)
X = random.normal(subkey1, (30, 40))
V = random.normal(subkey2, (30, 40))

ans1 = hvp(f, (X,), (V,))
ans2 = jnp.tensordot(hessian(f)(X), V, 2)

print(jnp.allclose(ans1, ans2, 1e-4, 1e-4))
True

Another way you might consider writing this is using reverse-over-forward:

# reverse-over-forward
def hvp_revfwd(f, primals, tangents):
  g = lambda primals: jvp(f, primals, tangents)[1]
  return grad(g)(primals)

That’s not quite as good, though, because forward-mode has less overhead than reverse-mode, and since the outer differentiation operator here has to differentiate a larger computation than the inner one, keeping forward-mode on the outside works best:

# reverse-over-reverse, only works for single arguments
def hvp_revrev(f, primals, tangents):
  x, = primals
  v, = tangents
  return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)


print("Forward over reverse")
%timeit -n10 -r3 hvp(f, (X,), (V,))
print("Reverse over forward")
%timeit -n10 -r3 hvp_revfwd(f, (X,), (V,))
print("Reverse over reverse")
%timeit -n10 -r3 hvp_revrev(f, (X,), (V,))

print("Naive full Hessian materialization")
%timeit -n10 -r3 jnp.tensordot(hessian(f)(X), V, 2)
Forward over reverse
4.78 ms ± 202 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over forward
9.14 ms ± 5.06 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over reverse
13.7 ms ± 8.1 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Naive full Hessian materialization
55.8 ms ± 977 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)

Composing VJPs, JVPs, and vmap#

Jacobian-Matrix and Matrix-Jacobian products#

Now that we have jvp and vjp transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX’s vmap transformation to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products.

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

# Pull back the covectors `m_i` along `f`, evaluated at `W`, for all `i`.
# First, use a list comprehension to loop over rows in the matrix M.
def loop_mjp(f, x, M):
    y, vjp_fun = vjp(f, x)
    return jnp.vstack([vjp_fun(mi)[0] for mi in M])  # vjp_fun returns a tuple of cotangents; unpack the single entry

# Now, use vmap to build a computation that does a single fast matrix-matrix
# multiply, rather than an outer loop over vector-matrix multiplies.
def vmap_mjp(f, x, M):
    y, vjp_fun = vjp(f, x)
    outs, = vmap(vjp_fun)(M)
    return outs

key = random.key(0)
num_covecs = 128
U = random.normal(key, (num_covecs,) + y.shape)

loop_vs = loop_mjp(f, W, M=U)
print('Non-vmapped Matrix-Jacobian product')
%timeit -n10 -r3 loop_mjp(f, W, M=U)

print('\nVmapped Matrix-Jacobian product')
vmap_vs = vmap_mjp(f, W, M=U)
%timeit -n10 -r3 vmap_mjp(f, W, M=U)

assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'
Non-vmapped Matrix-Jacobian product
146 ms ± 1.14 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)

Vmapped Matrix-Jacobian product
5.92 ms ± 44 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
def loop_jmp(f, W, M):
    # jvp immediately returns the primal and tangent values as a tuple,
    # so we'll compute and select the tangents in a list comprehension
    return jnp.vstack([jvp(f, (W,), (mi,))[1] for mi in M])

def vmap_jmp(f, W, M):
    _jvp = lambda s: jvp(f, (W,), (s,))[1]
    return vmap(_jvp)(M)

num_vecs = 128
S = random.normal(key, (num_vecs,) + W.shape)

loop_vs = loop_jmp(f, W, M=S)
print('Non-vmapped Jacobian-Matrix product')
%timeit -n10 -r3 loop_jmp(f, W, M=S)
vmap_vs = vmap_jmp(f, W, M=S)
print('\nVmapped Jacobian-Matrix product')
%timeit -n10 -r3 vmap_jmp(f, W, M=S)

assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical'
Non-vmapped Jacobian-Matrix product
289 ms ± 403 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)

Vmapped Jacobian-Matrix product
3.38 ms ± 123 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
The implementation of jacfwd and jacrev#

Now that we’ve seen fast Jacobian-matrix and matrix-Jacobian products, it’s not hard to guess how to write jacfwd and jacrev. We just use the same technique to push-forward or pull-back an entire standard basis (isomorphic to an identity matrix) at once.

from jax import jacrev as builtin_jacrev

def our_jacrev(f):
    def jacfun(x):
        y, vjp_fun = vjp(f, x)
        # Use vmap to do a matrix-Jacobian product.
        # Here, the matrix is the Euclidean basis, so we get all
        # entries in the Jacobian at once. 
        J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))
        return J
    return jacfun

assert jnp.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!'
from jax import jacfwd as builtin_jacfwd

def our_jacfwd(f):
    def jacfun(x):
        _jvp = lambda s: jvp(f, (x,), (s,))[1]
        Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
        return jnp.transpose(Jt)
    return jacfun

assert jnp.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!'

Interestingly, Autograd couldn’t do this. Our implementation of reverse-mode jacobian in Autograd had to pull back one vector at a time with an outer-loop map. Pushing one vector at a time through the computation is much less efficient than batching it all together with vmap.

Another thing that Autograd couldn’t do is jit. Interestingly, no matter how much Python dynamism you use in your function to be differentiated, we could always use jit on the linear part of the computation. For example:

def f(x):
    try:
        if x < 3:
            return 2 * x ** 3
        else:
            raise ValueError
    except ValueError:
        return jnp.pi * x

y, f_vjp = vjp(f, 4.)
print(jit(f_vjp)(1.))
(Array(3.1415927, dtype=float32, weak_type=True),)

Complex numbers and differentiation#

JAX is great at complex numbers and differentiation. To support both holomorphic and non-holomorphic differentiation, it helps to think in terms of JVPs and VJPs.

Consider a complex-to-complex function \(f: \mathbb{C} \to \mathbb{C}\) and identify it with a corresponding function \(g: \mathbb{R}^2 \to \mathbb{R}^2\),

def f(z):
  x, y = jnp.real(z), jnp.imag(z)
  return u(x, y) + v(x, y) * 1j

def g(x, y):
  return (u(x, y), v(x, y))

That is, we’ve decomposed \(f(z) = u(x, y) + v(x, y) i\) where \(z = x + y i\), and identified \(\mathbb{C}\) with \(\mathbb{R}^2\) to get \(g\).

Since \(g\) only involves real inputs and outputs, we already know how to write a Jacobian-vector product for it, say given a tangent vector \((c, d) \in \mathbb{R}^2\), namely

\(\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix}\).

To get a JVP for the original function \(f\) applied to a tangent vector \(c + di \in \mathbb{C}\), we just use the same definition and identify the result as another complex number,

\(\partial f(x + y i)(c + d i) = \begin{matrix} \begin{bmatrix} 1 & i \end{bmatrix} \\ ~ \end{matrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix}\).

That’s our definition of the JVP of a \(\mathbb{C} \to \mathbb{C}\) function! Notice it doesn’t matter whether or not \(f\) is holomorphic: the JVP is unambiguous.

Here’s a check:

def check(seed):
  key = random.key(seed)

  # random coeffs for u and v
  key, subkey = random.split(key)
  a, b, c, d = random.uniform(subkey, (4,))

  def fun(z):
    x, y = jnp.real(z), jnp.imag(z)
    return u(x, y) + v(x, y) * 1j

  def u(x, y):
    return a * x + b * y

  def v(x, y):
    return c * x + d * y

  # primal point
  key, subkey = random.split(key)
  x, y = random.uniform(subkey, (2,))
  z = x + y * 1j

  # tangent vector
  key, subkey = random.split(key)
  c, d = random.uniform(subkey, (2,))
  z_dot = c + d * 1j

  # check jvp
  _, ans = jvp(fun, (z,), (z_dot,))
  expected = (grad(u, 0)(x, y) * c +
              grad(u, 1)(x, y) * d +
              grad(v, 0)(x, y) * c * 1j+
              grad(v, 1)(x, y) * d * 1j)
  print(jnp.allclose(ans, expected))
check(0)
check(1)
check(2)
True
True
True

What about VJPs? We do something pretty similar: for a cotangent vector \(c + di \in \mathbb{C}\) we define the VJP of \(f\) as

\((c + di)^* \; \partial f(x + y i) = \begin{matrix} \begin{bmatrix} c & -d \end{bmatrix} \\ ~ \end{matrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} 1 \\ -i \end{bmatrix}\).

What’s with the negatives? They’re just to take care of complex conjugation, and the fact that we’re working with covectors.

Here’s a check of the VJP rules:

def check(seed):
  key = random.key(seed)

  # random coeffs for u and v
  key, subkey = random.split(key)
  a, b, c, d = random.uniform(subkey, (4,))

  def fun(z):
    x, y = jnp.real(z), jnp.imag(z)
    return u(x, y) + v(x, y) * 1j

  def u(x, y):
    return a * x + b * y

  def v(x, y):
    return c * x + d * y

  # primal point
  key, subkey = random.split(key)
  x, y = random.uniform(subkey, (2,))
  z = x + y * 1j

  # cotangent vector
  key, subkey = random.split(key)
  c, d = random.uniform(subkey, (2,))
  z_bar = jnp.array(c + d * 1j)  # for dtype control

  # check vjp
  _, fun_vjp = vjp(fun, z)
  ans, = fun_vjp(z_bar)
  expected = (grad(u, 0)(x, y) * c +
              grad(v, 0)(x, y) * (-d) +
              grad(u, 1)(x, y) * c * (-1j) +
              grad(v, 1)(x, y) * (-d) * (-1j))
  assert jnp.allclose(ans, expected, atol=1e-5, rtol=1e-5)
check(0)
check(1)
check(2)

What about convenience wrappers like grad, jacfwd, and jacrev?

For \(\mathbb{R} \to \mathbb{R}\) functions, recall we defined grad(f)(x) as being vjp(f, x)[1](1.0), which works because applying a VJP to a 1.0 value reveals the gradient (i.e. Jacobian, or derivative). We can do the same thing for \(\mathbb{C} \to \mathbb{R}\) functions: we can still use 1.0 as the cotangent vector, and we just get out a complex number result summarizing the full Jacobian:

def f(z):
  x, y = jnp.real(z), jnp.imag(z)
  return x**2 + y**2

z = 3. + 4j
grad(f)(z)
Array(6.-8.j, dtype=complex64)
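As a quick check of the definition above (reusing f and z from this example), applying the VJP pullback to a 1.0-valued cotangent recovers the same complex number that grad returned:

out, f_vjp = vjp(f, z)
print(f_vjp(jnp.ones_like(out)))   # matches grad(f)(z) above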

For general \(\mathbb{C} \to \mathbb{C}\) functions, the Jacobian has 4 real-valued degrees of freedom (as in the 2x2 Jacobian matrices above), so we can’t hope to represent all of them within a complex number. But we can for holomorphic functions! A holomorphic function is precisely a \(\mathbb{C} \to \mathbb{C}\) function with the special property that its derivative can be represented as a single complex number. (The Cauchy-Riemann equations ensure that the above 2x2 Jacobians have the special form of a scale-and-rotate matrix in the complex plane, i.e. the action of a single complex number under multiplication.) And we can reveal that one complex number using a single call to vjp with a covector of 1.0.

Because this only works for holomorphic functions, to use this trick we need to promise JAX that our function is holomorphic; otherwise, JAX will raise an error when grad is used for a complex-output function:

def f(z):
  return jnp.sin(z)

z = 3. + 4j
grad(f, holomorphic=True)(z)
Array(-27.034945-3.8511531j, dtype=complex64, weak_type=True)

All the holomorphic=True promise does is disable the error when the output is complex-valued. We can still write holomorphic=True when the function isn’t holomorphic, but the answer we get out won’t represent the full Jacobian. Instead, it’ll be the Jacobian of the function where we just discard the imaginary part of the output:

def f(z):
  return jnp.conjugate(z)

z = 3. + 4j
grad(f, holomorphic=True)(z)  # f is not actually holomorphic!
Array(1.-0.j, dtype=complex64, weak_type=True)

There are some useful upshots for how grad works here:

  1. We can use grad on holomorphic \(\mathbb{C} \to \mathbb{C}\) functions.

  2. We can use grad to optimize \(f : \mathbb{C} \to \mathbb{R}\) functions, like real-valued loss functions of complex parameters x, by taking steps in the direction of the conjugate of grad(f)(x) (see the sketch just after this list).

  3. If we have an \(\mathbb{R} \to \mathbb{R}\) function that just happens to use some complex-valued operations internally (some of which must be non-holomorphic, e.g. FFTs used in convolutions) then grad still works and we get the same result that an implementation using only real values would have given.
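Here is a minimal sketch of item 2, with an assumed toy loss: gradient descent on a real-valued function of a complex parameter, stepping along the conjugate of the gradient.

def complex_loss(w):
  x, y = jnp.real(w), jnp.imag(w)
  return (x - 1.)**2 + (y - 2.)**2   # a C -> R loss, minimized at w = 1 + 2j

w = 0. + 0.j
for _ in range(100):
  w = w - 0.1 * jnp.conj(grad(complex_loss)(w))  # step along the conjugate of the gradient
print(w)   # approaches (1+2j)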

In any case, JVPs and VJPs are always unambiguous. And if we wanted to compute the full Jacobian matrix of a non-holomorphic \(\mathbb{C} \to \mathbb{C}\) function, we can do it with JVPs or VJPs!
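For instance, here is a small sketch (reusing the non-holomorphic f(z) = conjugate(z) and z from above) that recovers the full 2x2 real Jacobian by pushing the two basis tangents through jvp:

_, d_dx = jvp(f, (z,), (1. + 0.j,))   # push the real basis tangent
_, d_dy = jvp(f, (z,), (0. + 1.j,))   # push the imaginary basis tangent
J = jnp.array([[jnp.real(d_dx), jnp.real(d_dy)],
               [jnp.imag(d_dx), jnp.imag(d_dy)]])
print(J)   # [[1, 0], [0, -1]]: the Jacobian of conjugation on R^2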

You should expect complex numbers to work everywhere in JAX. Here’s differentiating through a Cholesky decomposition of a complex matrix:

A = jnp.array([[5.,    2.+3j,    5j],
              [2.-3j,   7.,  1.+7j],
              [-5j,  1.-7j,    12.]])

def f(X):
    L = jnp.linalg.cholesky(X)
    return jnp.sum((L - jnp.sin(L))**2)

grad(f, holomorphic=True)(A)
Array([[-0.7534186  +0.j       , -3.0509028 -10.940544j ,
         5.9896846  +3.5423026j],
       [-3.0509028 +10.940544j , -8.904491   +0.j       ,
        -5.1351523  -6.559373j ],
       [ 5.9896846  -3.5423026j, -5.1351523  +6.559373j ,
         0.01320427 +0.j       ]], dtype=complex64)

More advanced autodiff#

In this notebook, we worked through some easy, and then progressively more complicated, applications of automatic differentiation in JAX. We hope you now feel that taking derivatives in JAX is easy and powerful.

There’s a whole world of other autodiff tricks and functionality out there. Topics we didn’t cover, but hope to in an “Advanced Autodiff Cookbook” include:

  • Gauss-Newton Vector Products, linearizing once

  • Custom VJPs and JVPs

  • Efficient derivatives at fixed-points

  • Estimating the trace of a Hessian using random Hessian-vector products.

  • Forward-mode autodiff using only reverse-mode autodiff.

  • Taking derivatives with respect to custom data types.

  • Checkpointing (binomial checkpointing for efficient reverse-mode, not model snapshotting).

  • Optimizing VJPs with Jacobian pre-accumulation.

Custom derivative rules for JAX-transformable Python functions#


mattjj@ Mar 19 2020, last updated Oct 14 2020

There are two ways to define differentiation rules in JAX:

  1. using jax.custom_jvp and jax.custom_vjp to define custom differentiation rules for Python functions that are already JAX-transformable; and

  2. defining new core.Primitive instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems.

This notebook is about #1. To read instead about #2, see the notebook on adding primitives.

For an introduction to JAX’s automatic differentiation API, see The Autodiff Cookbook. This notebook assumes some familiarity with jax.jvp and jax.grad, and the mathematical meaning of JVPs and VJPs.

TL;DR#

Custom JVPs with jax.custom_jvp#
import jax.numpy as jnp
from jax import custom_jvp

@custom_jvp
def f(x, y):
  return jnp.sin(x) * y

@f.defjvp
def f_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = f(x, y)
  tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
  return primal_out, tangent_out
from jax import jvp, grad

print(f(2., 3.))
y, y_dot = jvp(f, (2., 3.), (1., 0.))
print(y)
print(y_dot)
print(grad(f)(2., 3.))
2.7278922
2.7278922
-1.2484405
-1.2484405
# Equivalent alternative using the defjvps convenience wrapper

@custom_jvp
def f(x, y):
  return jnp.sin(x) * y

f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
          lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)
print(f(2., 3.))
y, y_dot = jvp(f, (2., 3.), (1., 0.))
print(y)
print(y_dot)
print(grad(f)(2., 3.))
2.7278922
2.7278922
-1.2484405
-1.2484405
Custom VJPs with jax.custom_vjp#
from jax import custom_vjp

@custom_vjp
def f(x, y):
  return jnp.sin(x) * y

def f_fwd(x, y):
  # Returns primal output and residuals to be used in backward pass by f_bwd.
  return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
  cos_x, sin_x, y = res # Gets residuals computed in f_fwd
  return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)
print(grad(f)(2., 3.))
-1.2484405

Example problems#

To get an idea of what problems jax.custom_jvp and jax.custom_vjp are meant to solve, let’s go over a few examples. A more thorough introduction to the jax.custom_jvp and jax.custom_vjp APIs is in the next section.

Numerical stability#

One application of jax.custom_jvp is to improve the numerical stability of differentiation.

Say we want to write a function called log1pexp, which computes \(x \mapsto \log ( 1 + e^x )\). We can write that using jax.numpy:

import jax.numpy as jnp

def log1pexp(x):
  return jnp.log(1. + jnp.exp(x))

log1pexp(3.)
Array(3.0485873, dtype=float32, weak_type=True)

Since it’s written in terms of jax.numpy, it’s JAX-transformable:

from jax import jit, grad, vmap

print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))
3.0485873
0.95257413
[0.5       0.7310586 0.8807971]

But there’s a numerical stability problem lurking here:

print(grad(log1pexp)(100.))
nan

That doesn’t seem right! After all, the derivative of \(x \mapsto \log (1 + e^x)\) is \(x \mapsto \frac{e^x}{1 + e^x}\), and so for large values of \(x\) we’d expect the value to be about 1.

We can get a bit more insight into what’s going on by looking at the jaxpr for the gradient computation:

from jax import make_jaxpr

make_jaxpr(grad(log1pexp))(100.)
{ lambda ; a:f32[]. let
    b:f32[] = exp a
    c:f32[] = add 1.0 b
    _:f32[] = log c
    d:f32[] = div 1.0 c
    e:f32[] = mul d b
  in (e,) }

Stepping through how the jaxpr would be evaluated, we can see that the last line would involve multiplying values that floating point math will round to 0 and \(\infty\), respectively, which is never a good idea. That is, we’re effectively evaluating lambda x: (1 / (1 + jnp.exp(x))) * jnp.exp(x) for large x, which effectively turns into 0. * jnp.inf.

Instead of generating such large and small values, hoping for a cancellation that floats can’t always provide, we’d rather express the derivative as a more numerically stable program. In particular, we can write a program that evaluates the mathematically equivalent expression \(1 - \frac{1}{1 + e^x}\), with no cancellation in sight.
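
As a quick check of that claim (a small aside, not part of the original example), we can evaluate both forms of the derivative directly at x = 100: the naive product evaluates to nan (because exp(100.) overflows to inf), while the rearranged expression gives 1.0.

x = 100.
naive = (1. / (1. + jnp.exp(x))) * jnp.exp(x)  # 0. * inf -> nan
stable = 1. - 1. / (1. + jnp.exp(x))           # 1. - 0.  -> 1.0
print(naive, stable)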

This problem is interesting because even though our definition of log1pexp could already be JAX-differentiated (and transformed with jit, vmap, …), we’re not happy with the result of applying standard autodiff rules to the primitives comprising log1pexp and composing the result. Instead, we’d like to specify how the whole function log1pexp should be differentiated, as a unit, and thus arrange those exponentials better.

This is one application of custom derivative rules for Python functions that are already JAX transformable: specifying how a composite function should be differentiated, while still using its original Python definition for other transformations (like jit, vmap, …).

Here’s a solution using jax.custom_jvp:

from jax import custom_jvp

@custom_jvp
def log1pexp(x):
  return jnp.log(1. + jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  ans = log1pexp(x)
  ans_dot = (1 - 1/(1 + jnp.exp(x))) * x_dot
  return ans, ans_dot
print(grad(log1pexp)(100.))
1.0
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))
3.0485873
0.95257413
[0.5       0.7310586 0.8807971]

Here’s a defjvps convenience wrapper to express the same thing:

@custom_jvp
def log1pexp(x):
  return jnp.log(1. + jnp.exp(x))

log1pexp.defjvps(lambda t, ans, x: (1 - 1/(1 + jnp.exp(x))) * t)
print(grad(log1pexp)(100.))
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))
1.0
3.0485873
0.95257413
[0.5       0.7310586 0.8807971]
Enforcing a differentiation convention#

A related application is to enforce a differentiation convention, perhaps at a boundary.

Consider the function \(f : \mathbb{R}_+ \to \mathbb{R}_+\) with \(f(x) = \frac{x}{1 + \sqrt{x}}\), where we take \(\mathbb{R}_+ = [0, \infty)\). We might implement \(f\) as a program like this:

def f(x):
  return x / (1 + jnp.sqrt(x))

As a mathematical function on \(\mathbb{R}\) (the full real line), \(f\) is not differentiable at zero (because the limit defining the derivative doesn’t exist from the left). Correspondingly, autodiff produces a nan value:

print(grad(f)(0.))
nan

But mathematically, if we think of \(f\) as a function on \(\mathbb{R}_+\), then it is differentiable at 0 [Rudin’s Principles of Mathematical Analysis Definition 5.1, or Tao’s Analysis I 3rd ed. Definition 10.1.1 and Example 10.1.6]. Alternatively, we might adopt the convention of taking the directional derivative from the right. So there is a sensible value for the Python function grad(f) to return at 0.0, namely 1.0. By default, JAX’s machinery for differentiation assumes all functions are defined over \(\mathbb{R}\) and thus doesn’t produce 1.0 here.

We can use a custom JVP rule! In particular, we can define the JVP rule in terms of the derivative function \(x \mapsto \frac{\sqrt{x} + 2}{2(\sqrt{x} + 1)^2}\) on \(\mathbb{R}_+\),

@custom_jvp
def f(x):
  return x / (1 + jnp.sqrt(x))

@f.defjvp
def f_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  ans = f(x)
  ans_dot = ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * x_dot
  return ans, ans_dot
print(grad(f)(0.))
1.0

Here’s the convenience wrapper version:

@custom_jvp
def f(x):
  return x / (1 + jnp.sqrt(x))

f.defjvps(lambda t, ans, x: ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * t)
print(grad(f)(0.))
1.0
Gradient clipping#

While in some cases we want to express a mathematical differentiation computation, in other cases we may even want to take a step away from mathematics to adjust the computation autodiff performs. One canonical example is reverse-mode gradient clipping.

For gradient clipping, we can use jnp.clip together with a jax.custom_vjp reverse-mode-only rule:

from functools import partial
from jax import custom_vjp

@custom_vjp
def clip_gradient(lo, hi, x):
  return x  # identity function

def clip_gradient_fwd(lo, hi, x):
  return x, (lo, hi)  # save bounds as residuals

def clip_gradient_bwd(res, g):
  lo, hi = res
  return (None, None, jnp.clip(g, lo, hi))  # use None to indicate zero cotangents for lo and hi

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
import matplotlib.pyplot as plt
from jax import vmap

t = jnp.linspace(0, 10, 1000)

plt.plot(jnp.sin(t))
plt.plot(vmap(grad(jnp.sin))(t))
(Plot: jnp.sin(t) together with its gradient, vmap(grad(jnp.sin))(t), i.e. cos(t).)
def clip_sin(x):
  x = clip_gradient(-0.75, 0.75, x)
  return jnp.sin(x)

plt.plot(clip_sin(t))
plt.plot(vmap(grad(clip_sin))(t))
(Plot: clip_sin(t) together with its gradient, which is cos(t) clipped to the range [-0.75, 0.75].)
Python debugging#

Another application that is motivated by development workflow rather than numerics is to set a pdb debugger trace in the backward pass of reverse-mode autodiff.

When trying to track down the source of a nan runtime error, or just examine carefully the cotangent (gradient) values being propagated, it can be useful to insert a debugger at a point in the backward pass that corresponds to a specific point in the primal computation. You can do that with jax.custom_vjp.

We’ll defer an example until the next section.

Implicit function differentiation of iterative implementations#

This example gets pretty deep in the mathematical weeds!

Another application for jax.custom_vjp is reverse-mode differentiation of functions that are JAX-transformable (by jit, vmap, …) but not efficiently JAX-differentiable for some reason, perhaps because they involve lax.while_loop. (It’s not possible to produce an XLA HLO program that efficiently computes the reverse-mode derivative of an XLA HLO While loop because that would require a program with unbounded memory use, which isn’t possible to express in XLA HLO, at least without side-effecting interactions through infeed/outfeed.)

For example, consider this fixed_point routine which computes a fixed point by iteratively applying a function in a while_loop:

from jax.lax import while_loop

def fixed_point(f, a, x_guess):
  def cond_fun(carry):
    x_prev, x = carry
    return jnp.abs(x_prev - x) > 1e-6

  def body_fun(carry):
    _, x = carry
    return x, f(a, x)

  _, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
  return x_star

This is an iterative procedure for numerically solving the equation \(x = f(a, x)\) for \(x\), by iterating \(x_{t+1} = f(a, x_t)\) until \(x_{t+1}\) is sufficiently close to \(x_t\). The result \(x^*\) depends on the parameters \(a\), and so we can think of there being a function \(a \mapsto x^*(a)\) that is implicitly defined by equation \(x = f(a, x)\).

We can use fixed_point to run iterative procedures to convergence, for example running Newton’s method to calculate square roots while only executing adds, multiplies, and divides:

def newton_sqrt(a):
  update = lambda a, x: 0.5 * (x + a / x)
  return fixed_point(update, a, a)
print(newton_sqrt(2.))
1.4142135

We can vmap or jit the function as well:

print(jit(vmap(newton_sqrt))(jnp.array([1., 2., 3., 4.])))
[1.        1.4142135 1.7320509 2.       ]

We can’t apply reverse-mode automatic differentiation because of the while_loop, but it turns out we wouldn’t want to anyway: instead of differentiating through the implementation of fixed_point and all its iterations, we can exploit the mathematical structure to do something that is much more memory-efficient (and FLOP-efficient in this case, too!). We can instead use the implicit function theorem [Prop A.25 of Bertsekas’s Nonlinear Programming, 2nd ed.], which guarantees (under some conditions) the existence of the mathematical objects we’re about to use. In essence, we linearize at the solution and solve those linear equations iteratively to compute the derivatives we want.

Consider again the equation \(x = f(a, x)\) and the function \(x^*\). We want to evaluate vector-Jacobian products like \(v^\mathsf{T} \mapsto v^\mathsf{T} \partial x^*(a_0)\).

At least in an open neighborhood around the point \(a_0\) at which we want to differentiate, let’s assume that the equation \(x^*(a) = f(a, x^*(a))\) holds for all \(a\). Since the two sides are equal as functions of \(a\), their derivatives must be equal as well, so let’s differentiate both sides:

\(\qquad \partial x^*(a) = \partial_0 f(a, x^*(a)) + \partial_1 f(a, x^*(a)) \partial x^*(a)\).

Setting \(A = \partial_1 f(a_0, x^*(a_0))\) and \(B = \partial_0 f(a_0, x^*(a_0))\), we can write the quantity we’re after more simply as

\(\qquad \partial x^*(a_0) = B + A \partial x^*(a_0)\),

or, by rearranging,

\(\qquad \partial x^*(a_0) = (I - A)^{-1} B\).

That means we can evaluate vector-Jacobian products like

\(\qquad v^\mathsf{T} \partial x^*(a_0) = v^\mathsf{T} (I - A)^{-1} B = w^\mathsf{T} B\),

where \(w^\mathsf{T} = v^\mathsf{T} (I - A)^{-1}\), or equivalently \(w^\mathsf{T} = v^\mathsf{T} + w^\mathsf{T} A\), or equivalently \(w^\mathsf{T}\) is the fixed point of the map \(u^\mathsf{T} \mapsto v^\mathsf{T} + u^\mathsf{T} A\). That last characterization gives us a way to write the VJP for fixed_point in terms of a call to fixed_point! Moreover, after expanding \(A\) and \(B\) back out, we can see we need only to evaluate VJPs of \(f\) at \((a_0, x^*(a_0))\).

Here’s the upshot:

from jax import vjp

@partial(custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, a, x_guess):
  def cond_fun(carry):
    x_prev, x = carry
    return jnp.abs(x_prev - x) > 1e-6

  def body_fun(carry):
    _, x = carry
    return x, f(a, x)

  _, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
  return x_star

def fixed_point_fwd(f, a, x_init):
  x_star = fixed_point(f, a, x_init)
  return x_star, (a, x_star)

def fixed_point_rev(f, res, x_star_bar):
  a, x_star = res
  _, vjp_a = vjp(lambda a: f(a, x_star), a)
  a_bar, = vjp_a(fixed_point(partial(rev_iter, f),
                             (a, x_star, x_star_bar),
                             x_star_bar))
  return a_bar, jnp.zeros_like(x_star)
  
def rev_iter(f, packed, u):
  a, x_star, x_star_bar = packed
  _, vjp_x = vjp(lambda x: f(a, x), x_star)
  return x_star_bar + vjp_x(u)[0]

fixed_point.defvjp(fixed_point_fwd, fixed_point_rev)
print(newton_sqrt(2.))
1.4142135
print(grad(newton_sqrt)(2.))
print(grad(grad(newton_sqrt))(2.))
0.35355338
-0.088388346

We can check our answers by differentiating jnp.sqrt, which uses a totally different implementation:

print(grad(jnp.sqrt)(2.))
print(grad(grad(jnp.sqrt))(2.))
0.35355338
-0.08838835

A limitation of this approach is that the argument f can’t close over any values involved in differentiation. That is, you might notice that we kept the parameter a explicit in the argument list of fixed_point. For this use case, consider using the low-level primitive lax.custom_root, which supports derivatives with respect to closed-over variables for custom root-finding functions.

Basic usage of jax.custom_jvp and jax.custom_vjp APIs#

Use jax.custom_jvp to define forward-mode (and, indirectly, reverse-mode) rules#

Here’s a canonical basic example of using jax.custom_jvp, where the comments use Haskell-like type signatures:

from jax import custom_jvp
import jax.numpy as jnp

# f :: a -> b
@custom_jvp
def f(x):
  return jnp.sin(x)

# f_jvp :: (a, T a) -> (b, T b)
def f_jvp(primals, tangents):
  x, = primals
  t, = tangents
  return f(x), jnp.cos(x) * t

f.defjvp(f_jvp)
<function __main__.f_jvp(primals, tangents)>
from jax import jvp

print(f(3.))

y, y_dot = jvp(f, (3.,), (1.,))
print(y)
print(y_dot)
0.14112
0.14112
-0.9899925

In words, we start with a primal function f that takes inputs of type a and produces outputs of type b. We associate with it a JVP rule function f_jvp that takes a pair of inputs representing the primal inputs of type a and the corresponding tangent inputs of type T a, and produces a pair of outputs representing the primal outputs of type b and tangent outputs of type T b. The tangent outputs should be a linear function of the tangent inputs.

You can also use f.defjvp as a decorator, as in

@custom_jvp
def f(x):
  ...

@f.defjvp
def f_jvp(primals, tangents):
  ...

Even though we defined only a JVP rule and no VJP rule, we can use both forward- and reverse-mode differentiation on f. JAX will automatically transpose the linear computation on tangent values from our custom JVP rule, computing the VJP as efficiently as if we had written the rule by hand:

from jax import grad

print(grad(f)(3.))
print(grad(grad(f))(3.))
-0.9899925
-0.14112

For automatic transposition to work, the JVP rule’s output tangents must be linear as a function of the input tangents. Otherwise a transposition error is raised.
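
For instance, here’s a deliberately broken rule (a hypothetical f_bad, included only to illustrate the error): its output tangent t ** 2 is nonlinear in the input tangent, so forward mode still works, but reverse mode fails when JAX tries to transpose the rule.

@custom_jvp
def f_bad(x):
  return jnp.sin(x)

f_bad.defjvps(lambda t, ans, x: t ** 2)  # nonlinear in the tangent t

print(jvp(f_bad, (3.,), (1.,)))  # forward mode works
try:
  grad(f_bad)(3.)  # reverse mode must transpose t ** 2, which fails
except Exception as e:
  print('ERROR! {}'.format(e))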

Multiple arguments work like this:

@custom_jvp
def f(x, y):
  return x ** 2 * y

@f.defjvp
def f_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = f(x, y)
  tangent_out = 2 * x * y * x_dot + x ** 2 * y_dot
  return primal_out, tangent_out
print(grad(f)(2., 3.))
12.0

The defjvps convenience wrapper lets us define a JVP for each argument separately, and the results are computed separately then summed:

@custom_jvp
def f(x):
  return jnp.sin(x)

f.defjvps(lambda t, ans, x: jnp.cos(x) * t)
print(grad(f)(3.))
-0.9899925

Here’s a defjvps example with multiple arguments:

@custom_jvp
def f(x, y):
  return x ** 2 * y

f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot,
          lambda y_dot, primal_out, x, y: x ** 2 * y_dot)
print(grad(f)(2., 3.))
print(grad(f, 0)(2., 3.))  # same as above
print(grad(f, 1)(2., 3.))
12.0
12.0
4.0

As a shorthand, with defjvps you can pass a None value to indicate that the JVP for a particular argument is zero:

@custom_jvp
def f(x, y):
  return x ** 2 * y

f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot,
          None)
print(grad(f)(2., 3.))
print(grad(f, 0)(2., 3.))  # same as above
print(grad(f, 1)(2., 3.))
12.0
12.0
0.0

Calling a jax.custom_jvp function with keyword arguments, or writing a jax.custom_jvp function definition with default arguments, are both allowed so long as they can be unambiguously mapped to positional arguments based on the function signature retrieved by the standard library inspect.signature mechanism.
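
For example, here’s a small sketch of a keyword call being resolved to its positional slot:

@custom_jvp
def f(x, y):
  return x ** 2 * y

@f.defjvp
def f_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  return f(x, y), 2 * x * y * x_dot + x ** 2 * y_dot

print(grad(f)(2., y=3.))  # the keyword argument y is resolved positionally -> 12.0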

When you’re not performing differentiation, the function f is called just as if it weren’t decorated by jax.custom_jvp:

@custom_jvp
def f(x):
  print('called f!')  # a harmless side-effect
  return jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
  print('called f_jvp!')  # a harmless side-effect
  x, = primals
  t, = tangents
  return f(x), jnp.cos(x) * t
from jax import vmap, jit

print(f(3.))
called f!
0.14112
print(vmap(f)(jnp.arange(3.)))
print(jit(f)(3.))
called f!
[0.         0.84147096 0.9092974 ]
called f!
0.14112

The custom JVP rule is invoked during differentiation, whether forward or reverse:

y, y_dot = jvp(f, (3.,), (1.,))
print(y_dot)
called f_jvp!
called f!
-0.9899925
print(grad(f)(3.))
called f_jvp!
called f!
-0.9899925

Notice that f_jvp calls f to compute the primal outputs. In the context of higher-order differentiation, each application of a differentiation transform will use the custom JVP rule if and only if the rule calls the original f to compute the primal outputs. (This represents a kind of fundamental tradeoff, where we can’t make use of intermediate values from the evaluation of f in our rule and also have the rule apply in all orders of higher-order differentiation.)

grad(grad(f))(3.)
called f_jvp!
called f_jvp!
called f!
Array(-0.14112, dtype=float32, weak_type=True)

You can use Python control flow with jax.custom_jvp:

@custom_jvp
def f(x):
  if x > 0:
    return jnp.sin(x)
  else:
    return jnp.cos(x)

@f.defjvp
def f_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  ans = f(x)
  if x > 0:
    return ans, 2 * x_dot
  else:
    return ans, 3 * x_dot
print(grad(f)(1.))
print(grad(f)(-1.))
2.0
3.0
Use jax.custom_vjp to define custom reverse-mode-only rules#

While jax.custom_jvp suffices for controlling both forward- and, via JAX’s automatic transposition, reverse-mode differentiation behavior, in some cases we may want to directly control a VJP rule, for example in the latter two example problems presented above. We can do that with jax.custom_vjp:

from jax import custom_vjp
import jax.numpy as jnp

# f :: a -> b
@custom_vjp
def f(x):
  return jnp.sin(x)

# f_fwd :: a -> (b, c)
def f_fwd(x):
  return f(x), jnp.cos(x)

# f_bwd :: (c, CT b) -> CT a
def f_bwd(cos_x, y_bar):
  return (cos_x * y_bar,)

f.defvjp(f_fwd, f_bwd)
from jax import grad

print(f(3.))
print(grad(f)(3.))
0.14112
-0.9899925

In words, we again start with a primal function f that takes inputs of type a and produces outputs of type b. We associate with it two functions, f_fwd and f_bwd, which describe how to perform the forward- and backward-passes of reverse-mode autodiff, respectively.

The function f_fwd describes the forward pass, not only the primal computation but also what values to save for use on the backward pass. Its input signature is just like that of the primal function f, in that it takes a primal input of type a. But as output it produces a pair, where the first element is the primal output b and the second element is any “residual” data of type c to be stored for use by the backward pass. (This second output is analogous to PyTorch’s save_for_backward mechanism.)

The function f_bwd describes the backward pass. It takes two inputs, where the first is the residual data of type c produced by f_fwd and the second is the output cotangents of type CT b corresponding to the output of the primal function. It produces an output of type CT a representing the cotangents corresponding to the input of the primal function. In particular, the output of f_bwd must be a sequence (e.g. a tuple) of length equal to the number of arguments to the primal function.

So multiple arguments work like this:

from jax import custom_vjp

@custom_vjp
def f(x, y):
  return jnp.sin(x) * y

def f_fwd(x, y):
  return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
  cos_x, sin_x, y = res
  return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)
print(grad(f)(2., 3.))
-1.2484405

Calling a jax.custom_vjp function with keyword arguments, or writing a jax.custom_vjp function definition with default arguments, are both allowed so long as they can be unambiguously mapped to positional arguments based on the function signature retrieved by the standard library inspect.signature mechanism.

As with jax.custom_jvp, the custom VJP rule defined by f_fwd and f_bwd is not invoked if differentiation is not applied. If the function is evaluated, or transformed with jit, vmap, or other non-differentiation transformations, then only f is called.

@custom_vjp
def f(x):
  print("called f!")
  return jnp.sin(x)

def f_fwd(x):
  print("called f_fwd!")
  return f(x), jnp.cos(x)

def f_bwd(cos_x, y_bar):
  print("called f_bwd!")
  return (cos_x * y_bar,)

f.defvjp(f_fwd, f_bwd)
print(f(3.))
called f!
0.14112
print(grad(f)(3.))
called f_fwd!
called f!
called f_bwd!
-0.9899925
from jax import vjp

y, f_vjp = vjp(f, 3.)
print(y)
called f_fwd!
called f!
0.14112
print(f_vjp(1.))
called f_bwd!
(Array(-0.9899925, dtype=float32, weak_type=True),)

Forward-mode autodiff cannot be used on the jax.custom_vjp function and will raise an error:

from jax import jvp

try:
  jvp(f, (3.,), (1.,))
except TypeError as e:
  print('ERROR! {}'.format(e))
called f_fwd!
called f!
ERROR! can't apply forward-mode autodiff (jvp) to a custom_vjp function.

If you want to use both forward- and reverse-mode, use jax.custom_jvp instead.

We can use jax.custom_vjp together with pdb to insert a debugger trace in the backward pass:

import pdb

@custom_vjp
def debug(x):
  return x  # acts like identity

def debug_fwd(x):
  return x, x

def debug_bwd(x, g):
  import pdb; pdb.set_trace()
  return g

debug.defvjp(debug_fwd, debug_bwd)
def foo(x):
  y = x ** 2
  y = debug(y)  # insert pdb in corresponding backward pass step
  return jnp.sin(y)
jax.grad(foo)(3.)

> <ipython-input-113-b19a2dc1abf7>(12)debug_bwd()
-> return g
(Pdb) p x
Array(9., dtype=float32)
(Pdb) p g
Array(-0.91113025, dtype=float32)
(Pdb) q

More features and details#

Working with list / tuple / dict containers (and other pytrees)#

You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any pytrees are permissible, so long as their structures are consistent according to the type constraints.

Here’s a contrived example with jax.custom_jvp:

from collections import namedtuple
Point = namedtuple("Point", ["x", "y"])

@custom_jvp
def f(pt):
  x, y = pt.x, pt.y
  return {'a': x ** 2,
          'b': (jnp.sin(x), jnp.cos(y))}

@f.defjvp
def f_jvp(primals, tangents):
  pt, = primals
  pt_dot, =  tangents
  ans = f(pt)
  ans_dot = {'a': 2 * pt.x * pt_dot.x,
             'b': (jnp.cos(pt.x) * pt_dot.x, -jnp.sin(pt.y) * pt_dot.y)}
  return ans, ans_dot

def fun(pt):
  dct = f(pt)
  return dct['a'] + dct['b'][0]
pt = Point(1., 2.)

print(f(pt))
{'a': 1.0, 'b': (Array(0.84147096, dtype=float32, weak_type=True), Array(-0.41614684, dtype=float32, weak_type=True))}
print(grad(fun)(pt))
Point(x=Array(2.5403023, dtype=float32, weak_type=True), y=Array(0., dtype=float32, weak_type=True))

And an analogous contrived example with jax.custom_vjp:

@custom_vjp
def f(pt):
  x, y = pt.x, pt.y
  return {'a': x ** 2,
          'b': (jnp.sin(x), jnp.cos(y))}

def f_fwd(pt):
  return f(pt), pt

def f_bwd(pt, g):
  a_bar, (b0_bar, b1_bar) = g['a'], g['b']
  x_bar = 2 * pt.x * a_bar + jnp.cos(pt.x) * b0_bar
  y_bar = -jnp.sin(pt.y) * b1_bar
  return (Point(x_bar, y_bar),)

f.defvjp(f_fwd, f_bwd)

def fun(pt):
  dct = f(pt)
  return dct['a'] + dct['b'][0]
pt = Point(1., 2.)

print(f(pt))
{'a': 1.0, 'b': (Array(0.84147096, dtype=float32, weak_type=True), Array(-0.41614684, dtype=float32, weak_type=True))}
print(grad(fun)(pt))
Point(x=Array(2.5403023, dtype=float32, weak_type=True), y=Array(-0., dtype=float32, weak_type=True))
Handling non-differentiable arguments#

Some use cases, like the final example problem, call for non-differentiable arguments like function-valued arguments to be passed to functions with custom differentiation rules, and for those arguments to also be passed to the rules themselves. In the case of fixed_point, the function argument f was such a non-differentiable argument. A similar situation arises with jax.experimental.odeint.

jax.custom_jvp with nondiff_argnums#

Use the optional nondiff_argnums parameter to jax.custom_jvp to indicate arguments like these. Here’s an example with jax.custom_jvp:

from functools import partial

@partial(custom_jvp, nondiff_argnums=(0,))
def app(f, x):
  return f(x)

@app.defjvp
def app_jvp(f, primals, tangents):
  x, = primals
  x_dot, = tangents
  return f(x), 2. * x_dot
print(app(lambda x: x ** 3, 3.))
27.0
print(grad(app, 1)(lambda x: x ** 3, 3.))
2.0

Notice the gotcha here: no matter where in the argument list these parameters appear, they’re placed at the start of the signature of the corresponding JVP rule. Here’s another example:

@partial(custom_jvp, nondiff_argnums=(0, 2))
def app2(f, x, g):
  return f(g(x))

@app2.defjvp
def app2_jvp(f, g, primals, tangents):
  x, = primals
  x_dot, = tangents
  return f(g(x)), 3. * x_dot
print(app2(lambda x: x ** 3, 3., lambda y: 5 * y))
3375.0
print(grad(app2, 1)(lambda x: x ** 3, 3., lambda y: 5 * y))
3.0
jax.custom_vjp with nondiff_argnums#

A similar option exists for jax.custom_vjp, and, similarly, the convention is that the non-differentiable arguments are passed as the first arguments to the _bwd rule, no matter where they appear in the signature of the original function. The signature of the _fwd rule remains unchanged; it is the same as the signature of the primal function. Here’s an example:

@partial(custom_vjp, nondiff_argnums=(0,))
def app(f, x):
  return f(x)

def app_fwd(f, x):
  return f(x), x

def app_bwd(f, x, g):
  return (5 * g,)

app.defvjp(app_fwd, app_bwd)
print(app(lambda x: x ** 2, 4.))
16.0
print(grad(app, 1)(lambda x: x ** 2, 4.))
5.0

See fixed_point above for another usage example.

You don’t need to use nondiff_argnums with array-valued arguments, for example ones with integer dtype. Instead, nondiff_argnums should only be used for argument values that don’t correspond to JAX types (essentially don’t correspond to array types), like Python callables or strings. If JAX detects that an argument indicated by nondiff_argnums contains a JAX Tracer, then an error is raised. The clip_gradient function above is a good example of not using nondiff_argnums for integer-dtype array arguments.
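
For instance, here’s a minimal sketch (a hypothetical scale_by_power function, in the same spirit as clip_gradient) where an integer-valued argument is passed as an ordinary argument and simply receives a None cotangent in the backward rule:

@custom_vjp
def scale_by_power(x, n):
  return x ** n  # n is an ordinary (integer) argument, not a nondiff_argnums entry

def scale_by_power_fwd(x, n):
  return x ** n, (x, n)

def scale_by_power_bwd(res, g):
  x, n = res
  return (n * x ** (n - 1) * g, None)  # None indicates a zero cotangent for n

scale_by_power.defvjp(scale_by_power_fwd, scale_by_power_bwd)
print(grad(scale_by_power)(3., 2))  # d/dx of x ** 2 at x = 3. is 6.0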

Control autodiff’s saved values with jax.checkpoint (aka jax.remat)#

import jax
import jax.numpy as jnp

TL;DR#

Use the jax.checkpoint decorator (aliased as jax.remat) with jax.grad to control which intermediates are saved on the forward pass versus recomputed on the backward pass, trading off memory and FLOPs.

Don’t miss the practical notes for a discussion about how jax.checkpoint interacts with jax.jit.

Without using jax.checkpoint, the forward pass of jax.grad(f)(x) saves, for use on the backward pass, the values of Jacobian coefficients and other intermediates. We call these saved values residuals:

def g(W, x):
  y = jnp.dot(W, x)
  return jnp.sin(y)

def f(W1, W2, W3, x):
  x = g(W1, x)
  x = g(W2, x)
  x = g(W3, x)
  return x

W1 = jnp.ones((5, 4))
W2 = jnp.ones((6, 5))
W3 = jnp.ones((7, 6))
x = jnp.ones(4)

# Inspect the 'residual' values to be saved on the forward pass
# if we were to evaluate `jax.grad(f)(W1, W2, W3, x)`
from jax.ad_checkpoint import print_saved_residuals
jax.ad_checkpoint.print_saved_residuals(f, W1, W2, W3, x)
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
f32[5] output of cos from <ipython-input-4-f510dde58e22>:3 (g)
f32[6] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
f32[6] output of cos from <ipython-input-4-f510dde58e22>:3 (g)
f32[7] output of cos from <ipython-input-4-f510dde58e22>:3 (g)

By applying jax.checkpoint to sub-functions, as a decorator or at specific application sites, we force JAX not to save any of that sub-function’s residuals. Instead, only the inputs of a jax.checkpoint-decorated function might be saved, and any residuals consumed on the backward pass are re-computed from those inputs as needed:

def f2(W1, W2, W3, x):
  x = jax.checkpoint(g)(W1, x)
  x = jax.checkpoint(g)(W2, x)
  x = jax.checkpoint(g)(W3, x)
  return x

jax.ad_checkpoint.print_saved_residuals(f2, W1, W2, W3, x)
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
f32[6] output of sin from <ipython-input-4-f510dde58e22>:3 (g)

Here the values of two sin applications are saved because they are arguments in subsequent applications of the jax.checkpoint-decorated g function, and inputs to a jax.checkpoint-decorated function may be saved. But no values of cos applications are saved.

To control which values are saveable without having to edit the definition of the function to be differentiated, you can use a rematerialization policy. Here is an example that saves only the results of dot operations with no batch dimensions (since they are often FLOP-bound, and hence worth saving rather than recomputing):

f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
jax.ad_checkpoint.print_saved_residuals(f3, W1, W2, W3, x)
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)
f32[6] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)
f32[7] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)

You can also use policies to refer to intermediate values you name using jax.ad_checkpoint.checkpoint_name:

from jax.ad_checkpoint import checkpoint_name

def f4(W1, W2, W3, x):
  x = checkpoint_name(g(W1, x), name='a')
  x = checkpoint_name(g(W2, x), name='b')
  x = checkpoint_name(g(W3, x), name='c')
  return x

f4 = jax.checkpoint(f4, policy=jax.checkpoint_policies.save_only_these_names('a'))
jax.ad_checkpoint.print_saved_residuals(f4, W1, W2, W3, x)
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] named 'a' from <ipython-input-7-fc0ed1c14b8d>:4 (f4)

When playing around with these toy examples, we can get a closer look at what’s going on using the print_fwd_bwd utility defined in this notebook:

from jax.tree_util import tree_flatten, tree_unflatten

from rich.console import Console
from rich.table import Table
import rich.text

def print_fwd_bwd(f, *args, **kwargs) -> None:
  args, in_tree = tree_flatten((args, kwargs))

  def f_(*args):
    args, kwargs = tree_unflatten(in_tree, args)
    return f(*args, **kwargs)

  fwd = jax.make_jaxpr(lambda *args: jax.vjp(f_, *args))(*args).jaxpr

  y, f_vjp = jax.vjp(f_, *args)
  res, in_tree = tree_flatten(f_vjp)

  def g_(*args):
    *res, y = args
    f_vjp = tree_unflatten(in_tree, res)
    return f_vjp(y)

  bwd = jax.make_jaxpr(g_)(*res, y).jaxpr

  table = Table(show_header=False, show_lines=True, padding=(1, 2, 0, 2), box=None)
  table.add_row("[bold green]forward computation:",
                "[bold green]backward computation:")
  table.add_row(rich.text.Text.from_ansi(str(fwd)),
                rich.text.Text.from_ansi(str(bwd)))
  console = Console(width=240, force_jupyter=True)
  console.print(table)

def _renderable_repr(self):
  return self.html
rich.jupyter.JupyterRenderable._repr_html_ = _renderable_repr
# no use of jax.checkpoint:
print_fwd_bwd(f, W1, W2, W3, x)
                                                                                                                                                                      
  forward computation:                                                        backward computation:                                                                   
                                                                                                                                                                      
  { lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let                   { lambda ; a:f32[7] b:f32[6] c:f32[7,6] d:f32[6] e:f32[5] f:f32[6,5] g:f32[5] h:f32[4]  
      e:f32[5] = dot_general[dimension_numbers=(([1], [0]), ([], []))] a d        i:f32[5,4] j:f32[7]. let                                                            
      f:f32[5] = sin e                                                            k:f32[7] = mul j a                                                                  
      g:f32[5] = cos e                                                            l:f32[6] = dot_general[dimension_numbers=(([0], [0]), ([], []))] k c                
      h:f32[6] = dot_general[dimension_numbers=(([1], [0]), ([], []))] b f        m:f32[7,6] = dot_general[dimension_numbers=(([], []), ([], []))] k b                
      i:f32[6] = sin h                                                            n:f32[6] = mul l d                                                                  
      j:f32[6] = cos h                                                            o:f32[5] = dot_general[dimension_numbers=(([0], [0]), ([], []))] n f                
      k:f32[7] = dot_general[dimension_numbers=(([1], [0]), ([], []))] c i        p:f32[6,5] = dot_general[dimension_numbers=(([], []), ([], []))] n e                
      l:f32[7] = sin k                                                            q:f32[5] = mul o g                                                                  
      m:f32[7] = cos k                                                            r:f32[4] = dot_general[dimension_numbers=(([0], [0]), ([], []))] q i                
    in (l, m, i, c, j, f, b, g, d, a) }                                           s:f32[5,4] = dot_general[dimension_numbers=(([], []), ([], []))] q h                
                                                                                in (s, p, m, r) }                                                                     
# using jax.checkpoint with policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable:
print_fwd_bwd(f3, W1, W2, W3, x)
                                                                                                                                                                             
  forward computation:                                                        backward computation:                                                                          
                                                                                                                                                                             
  { lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let                   { lambda ; a:f32[5] b:f32[6] c:f32[7] d:f32[5,4] e:f32[6,5] f:f32[7,6] g:f32[4] h:f32[7]. let  
      e:f32[5] = dot_general[dimension_numbers=(([1], [0]), ([], []))] a d        i:f32[5,4] j:f32[6,5] k:f32[7,6] l:f32[4] = remat2[                                        
      f:f32[5] = sin e                                                              differentiated=True                                                                      
      g:f32[6] = dot_general[dimension_numbers=(([1], [0]), ([], []))] b f          jaxpr={ lambda ; m:f32[5] n:f32[6] o:f32[7] p:f32[5,4] q:f32[6,5] r:f32[7,6]             
      h:f32[6] = sin g                                                                  s:f32[4] t:f32[7]. let                                                               
      i:f32[7] = dot_general[dimension_numbers=(([1], [0]), ([], []))] c h              u:f32[5] = sin m                                                                     
      j:f32[7] = sin i                                                                  v:f32[5] = cos m                                                                     
    in (j, e, g, i, a, b, c, d) }                                                       w:f32[6] = sin n                                                                     
                                                                                        x:f32[6] = cos n                                                                     
                                                                                        y:f32[7] = cos o                                                                     
                                                                                        z:f32[7] = mul t y                                                                   
                                                                                        ba:f32[6] = dot_general[dimension_numbers=(([0], [0]), ([], []))] z r                
                                                                                        bb:f32[6] = mul ba x                                                                 
                                                                                        bc:f32[5] = dot_general[dimension_numbers=(([0], [0]), ([], []))] bb q               
                                                                                        bd:f32[5] = mul bc v                                                                 
                                                                                        be:f32[4] = dot_general[dimension_numbers=(([0], [0]), ([], []))] bd p               
                                                                                        bf:f32[5,4] = dot_general[dimension_numbers=(([], []), ([], []))] bd s               
                                                                                        bg:f32[6,5] = dot_general[dimension_numbers=(([], []), ([], []))] bb u               
                                                                                        bh:f32[7,6] = dot_general[dimension_numbers=(([], []), ([], []))] z w                
                                                                                      in (bf, bg, bh, be) }                                                                  
                                                                                    policy=<function dot_with_no_batch_dims at 0x7f5e469b1700>                               
                                                                                    prevent_cse=True                                                                         
                                                                                  ] a b c d e f g h                                                                          
                                                                                in (i, j, k, l) }                                                                            

Let’s think step by step#

You might want to first (re)read the Autodiff Cookbook Part 1.

Fundamentals of jax.checkpoint#

In both jax.linearize and jax.vjp there is flexibility in how and when some values are computed. Different choices can trade off memory use against FLOPs. JAX provides control over these choices with jax.checkpoint.

One such choice is whether to perform Jacobian coefficient computations on the forward pass, as soon as the inputs are available, or on the backward pass, just before the coefficients are needed. Consider the example of sin_vjp:

def sin_vjp(x):
  y = jnp.sin(x)
  cos_x = jnp.cos(x)
  return y, lambda y_bar: cos_x * y_bar

Another valid implementation would compute the value of jnp.cos(x) on the backward pass rather than on the forward pass:

def sin_vjp2(x):
  y = jnp.sin(x)
  return y, lambda y_bar: jnp.cos(x) * y_bar

For this particular function, the amount of memory used by the two versions is the same, though we’ve reduced the FLOPs for the primal computation (i.e. the forward pass) and increased the FLOPs for the cotangent computation (i.e. the backward pass).

There’s another choice when it comes to function composition. Recall our VJP rule for a composition of two functions:

def f(x):
  y = g(x)
  z = h(y)
  return z

def f_vjp(x):
  y, g_vjp = jax.vjp(g, x)
  z, h_vjp = jax.vjp(h, y)
  def f_bwd(z_bar):
    y_bar, = h_vjp(z_bar)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd

An alternative is:

def f_vjp_checkpoint(x):
  y = g(x)
  z, h_vjp = jax.vjp(h, y)
  def f_bwd2(z_bar):
    y_bar, = h_vjp(z_bar)
    _, g_vjp = jax.vjp(g, x)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd2

In words, this alternative implementation doesn’t compute g_vjp, or the residual values in its closure, on the forward pass. Instead it only computes them in the backward pass f_bwd2. That means f_vjp_checkpoint requires less memory: if g and h each required similar amounts of memory for their residuals, each much larger than x, then the function produced by f_vjp_checkpoint(x) requires half the memory as that of f_vjp(x)!

The cost we pay is redundant work: in f_bwd2 we must re-evaluate g(x) as part of jax.vjp(g, x) just to discard its value (in the underscore variable on the line _, g_vjp = jax.vjp(g, x)).

We can get this VJP behavior in autodiff, without having to write VJP functions directly, by instead using jax.checkpoint in an alternative definition of the original function f:

def f_checkpoint(x):
  y = jax.checkpoint(g)(x)
  z = h(y)
  return z

In other words, we apply jax.checkpoint to g, the first stage of f, rather than to f itself. This way, when we evaluate jax.grad(f_checkpoint)(x), we’d get a computation like:

  1. run the forward pass of g, discarding residual values;

  2. run the forward pass of h, saving residuals;

  3. run the backward pass of h, consuming residuals from step 2;

  4. re-run the forward pass of g, saving residuals;

  5. run the backward pass of g, consuming residuals from step 4.

That is, by evaluating jax.grad(f_checkpoint)(x) we’d get the same computation as:

def f_checkpoint_grad(x):
  y = g(x)                  # step 1
  _, h_vjp = jax.vjp(h, y)  # step 2
  y_bar, = h_vjp(1.0)       # step 3
  _, g_vjp = jax.vjp(g, x)  # step 4
  x_bar, = g_vjp(y_bar)     # step 5
  return x_bar

In general, jax.checkpoint(foo) is a new function which has the same input-output behavior as foo, but behaves differently under autodiff, particularly under jax.linearize and jax.vjp (and their wrappers, like jax.grad) but not jax.jvp. When differentiated, only the input to a jax.checkpoint-differentiated function is stored on the forward pass; on the backward pass, residuals (i.e. intermediates from foo and its Jacobian coefficient values needed for the backward pass) are recomputed.

Notice that if f = lambda x: h(g(x)) is the function we want to differentiate, i.e. if we want to apply jax.grad(f), we don’t get any memory savings by applying jax.checkpoint to f itself. That’s because evaluating jax.grad(jax.checkpoint(f))(x) would lead to a computation like:

  1. run the forward pass, discarding all residuals;

  2. immediately re-run the forward pass, saving residuals;

  3. run the backward pass, consuming residuals from step 2.

That is, in code we’d have something like:

def f_grad_bad(x):
  _ = f(x)                  # step 1
  _, f_vjp = jax.vjp(f, x)  # step 2
  x_bar, = f_vjp(1.0)       # step 3
  return x_bar

We also wouldn’t get any memory savings by applying jax.checkpoint to h, the second stage of f. That’s because evaluating jax.grad(lambda x: jax.checkpoint(h)(g(x))) would lead to a computation like:

  1. run the forward pass of g, saving residuals;

  2. run the forward pass of h, discarding residuals;

  3. immediately re-run the forward pass of h, saving residuals;

  4. run the backward pass of h, consuming residuals from step 3;

  5. run the backward pass of g, consuming residuals from step 1.

That is, in code we’d have something like:

def f_grad_bad2(x):
  y, g_vjp = jax.vjp(g, x)  # step 1
  z = h(y)                  # step 2
  _, h_vjp = jax.vjp(h, y)  # step 3
  y_bar, = h_vjp(1.0)       # step 4
  x_bar, = g_vjp(y_bar)     # step 5
  return x_bar

Slightly more generally, if we had a chain composition of functions, like f = lambda x: f3(f2(f1(x))), and we were interested in evaluating jax.grad(f), we could say that:

  • we shouldn’t apply jax.checkpoint to the whole function f, since that wouldn’t save any memory (and will perform wasteful recomputation);

  • we shouldn’t apply jax.checkpoint to the last sub-function f3, since that wouldn’t save any memory (and will perform wasteful recomputation);

  • we could apply jax.checkpoint to f1, f2, or their composition lambda x: f2(f1(x)), since any of those might save memory and would express different memory/recompute tradeoffs (see the sketch just after this list).
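
For instance, here’s a minimal sketch of that last option (with jnp.sin, jnp.cos, and jnp.tanh standing in for f1, f2, and f3): we checkpoint the composition of the first two stages, so their residuals are recomputed on the backward pass while f3’s residuals are saved as usual.

f1, f2, f3 = jnp.sin, jnp.cos, jnp.tanh  # stand-ins for the three stages

f_remat = lambda x: f3(jax.checkpoint(lambda x: f2(f1(x)))(x))

print(jax.grad(f_remat)(1.0))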

Custom policies for what’s saveable#

As shown so far, using jax.checkpoint switches from one extreme to another:

  • without jax.checkpoint, JAX’s autodiff tends to compute everything possible on the forward pass and store it for the backward pass;

  • with a jax.checkpoint decorator, we instead compute as little as possible on the forward pass and recompute values as needed on the backward pass.

To operate between these two extremes, saving some things and not others, we can carefully place jax.checkpoint decorators on sub-functions. But that requires editing the function to be differentiated, e.g. model code, which may be inconvenient. It can also be hard to experiment with variations.

So an alternative is to use the policy argument to jax.checkpoint. A policy is a callable (i.e. a function) which takes as input a type-level specification of a first order primitive application and returns a boolean indicating whether the corresponding output value(s) are allowed to be saved as residuals (or instead must be recomputed in the (co)tangent computation as needed). To write robust code, a policy should be selected from the attributes on jax.checkpoint_policies, like jax.checkpoint_policies.dots_with_no_batch_dims_saveable, since the API for writing custom policy callables is considered internal.

For example, consider this function to be differentiated:

def loss(params, x, y):
  return jnp.sum((predict(params, x) - y)**2)

def predict(params, x):
  *Ws, Wlast = params
  for W in Ws:
    x = layer(W, x)
  x = jnp.dot(Wlast, x)
  return x

def layer(W, x):
  return jnp.sin(jnp.dot(W, x))
W1 = W2 = W3 = jnp.ones((4, 4))
params = [W1, W2, W3]
x = jnp.ones(4)
y = jnp.ones(4)
print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] output of sin from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of sin from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of mul from <ipython-input-18-3808b5023c3d>:2 (loss)

Instead of saving so many values on the forward pass, perhaps we only want to save the results of matrix multiplications with no batch dimension (since they may be FLOP- rather than memory-bound). We can do that using the policy jax.checkpoint_policies.dots_with_no_batch_dims_saveable:

loss_checkpoint = jax.checkpoint(loss, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
print_saved_residuals(loss_checkpoint, params, x, y)
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] from the argument 'y'
f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:8 (predict)

Notice also that by providing a policy, we didn’t need to edit the code defining loss, predict, or layer. That is particularly convenient if we want to experiment with policies in calling code (e.g. a training script) without changing library code (e.g. the neural network library).

Some policies can refer to values named with jax.ad_checkpoint.checkpoint_name:

from jax.ad_checkpoint import checkpoint_name

def predict(params, x):
  *Ws, Wlast = params
  for i, W in enumerate(Ws):
    x = layer(W, x)
    x = checkpoint_name(x, name=f'layer{i}_output')
  x = jnp.dot(Wlast, x)
  return x

By itself, checkpoint_name is just an identity function. But because some policy functions know to look for them, we can use the names to control whether certain values output by checkpoint_name are considered saveable:

print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] named 'layer0_output' from <ipython-input-22-e48aedf368ad>:7 (predict)
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] named 'layer1_output' from <ipython-input-22-e48aedf368ad>:7 (predict)
f32[4] output of mul from <ipython-input-18-3808b5023c3d>:2 (loss)
loss_checkpoint2 = jax.checkpoint(loss, policy=jax.checkpoint_policies.save_any_names_but_these('layer1_output'))
print_saved_residuals(loss_checkpoint2, params, x, y)
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] from the argument 'y'

Another policy which refers to names is jax.checkpoint_policies.save_only_these_names.

Some of the policies are:

  • everything_saveable (the default strategy, as if jax.checkpoint were not being used at all)

  • nothing_saveable (i.e. rematerialize everything, as if a custom policy were not being used at all)

  • dots_saveable or its alias checkpoint_dots

  • dots_with_no_batch_dims_saveable or its alias checkpoint_dots_with_no_batch_dims

  • save_anything_but_these_names (save any values except for the output of checkpoint_name with any of the names given)

  • save_any_names_but_these (save only named values, i.e. any outputs of checkpoint_name, except for those with the names given)

  • save_only_these_names (save only named values, and only among the names given)

Policies only indicate what is saveable; a value is only saved if it’s actually needed by the backward pass.

Advanced: recursive jax.checkpoint#

By applying jax.checkpoint in the right way, there are many tradeoffs between memory usage and (re)computation that can be expressed. One surprising example is recursive checkpointing, where we apply jax.checkpoint to a function which itself calls jax.checkpoint-decorated functions in a way so that memory usage from the chain composition of \(D\) functions scales like \(\mathcal{O}(\log_2 D)\) rather than \(\mathcal{O}(D)\).

As a toy example, consider the chain composition of multiple jnp.sin functions:

def chain_compose(funs):
  def f(x):
    for fun in funs:
      x = fun(x)
    return x
  return f

f = chain_compose([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)

In general, the number of stored residuals scales linearly with the length of the chain:

f = chain_compose([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)

But we can apply jax.checkpoint recursively to improve the scaling:

def recursive_checkpoint(funs):
  if len(funs) == 1:
    return funs[0]
  elif len(funs) == 2:
    f1, f2 = funs
    return lambda x: f1(f2(x))
  else:
    f1 = recursive_checkpoint(funs[:len(funs)//2])
    f2 = recursive_checkpoint(funs[len(funs)//2:])
    return lambda x: f1(jax.checkpoint(f2)(x))
f = recursive_checkpoint([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] from the argument 'x'
f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f = recursive_checkpoint([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] from the argument 'x'
f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)

The cost here, as usual, is recomputation: in particular, we end up performing \(\mathcal{O}(\log_2 D)\) times as many FLOPs:

f = chain_compose([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
                                                                                                                                 
  forward computation:                  backward computation:                                                                    
                                                                                                                                 
  { lambda ; a:f32[]. let               { lambda ; a:f32[] b:f32[] c:f32[] d:f32[] e:f32[] f:f32[] g:f32[] h:f32[] i:f32[]. let  
      b:f32[] = sin a                       j:f32[] = mul i a                                                                    
      c:f32[] = cos a                       k:f32[] = mul j b                                                                    
      d:f32[] = sin b                       l:f32[] = mul k c                                                                    
      e:f32[] = cos b                       m:f32[] = mul l d                                                                    
      f:f32[] = sin d                       n:f32[] = mul m e                                                                    
      g:f32[] = cos d                       o:f32[] = mul n f                                                                    
      h:f32[] = sin f                       p:f32[] = mul o g                                                                    
      i:f32[] = cos f                       q:f32[] = mul p h                                                                    
      j:f32[] = sin h                     in (q,) }                                                                              
      k:f32[] = cos h                                                                                                            
      l:f32[] = sin j                                                                                                            
      m:f32[] = cos j                                                                                                            
      n:f32[] = sin l                                                                                                            
      o:f32[] = cos l                                                                                                            
      p:f32[] = sin n                                                                                                            
      q:f32[] = cos n                                                                                                            
    in (p, q, o, m, k, i, g, e, c) }                                                                                             
f = recursive_checkpoint([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
                                                                                                                                        
  forward computation:                                                              backward computation:                               
                                                                                                                                        
  { lambda ; a:f32[]. let                                                           { lambda ; a:f32[] b:f32[] c:f32[] d:f32[]. let     
      b:f32[] = remat2[                                                                 e:f32[] = mul d a                               
        differentiated=False                                                            f:f32[] = mul e b                               
        jaxpr={ lambda ; c:f32[]. let d:f32[] = sin c; e:f32[] = sin d in (e,) }        g:f32[] = remat2[                               
        policy=None                                                                       differentiated=True                           
        prevent_cse=True                                                                  jaxpr={ lambda ; h:f32[] i:f32[]. let         
      ] a                                                                                     j:f32[] = sin h                           
      f:f32[] = sin b                                                                         k:f32[] = cos h                           
      g:f32[] = sin f                                                                         l:f32[] = cos j                           
      h:f32[] = sin g                                                                         m:f32[] = mul i l                         
      i:f32[] = sin h                                                                         n:f32[] = mul m k                         
      j:f32[] = sin i                                                                       in (n,) }                                   
      k:f32[] = cos i                                                                     policy=None                                   
      l:f32[] = sin j                                                                     prevent_cse=True                              
      m:f32[] = cos j                                                                   ] c f                                           
    in (l, m, k, g, a) }                                                                o:f32[] = remat2[                               
                                                                                          differentiated=True                           
                                                                                          jaxpr={ lambda ; p:f32[] q:f32[]. let         
                                                                                              r:f32[] = sin p                           
                                                                                              s:f32[] = sin r                           
                                                                                              t:f32[] = sin s                           
                                                                                              u:f32[] = cos s                           
                                                                                              v:f32[] = cos t                           
                                                                                              w:f32[] = mul q v                         
                                                                                              x:f32[] = mul w u                         
                                                                                              y:f32[] = remat2[                         
                                                                                                differentiated=True                     
                                                                                                jaxpr={ lambda ; z:f32[] ba:f32[]. let  
                                                                                                    bb:f32[] = sin z                    
                                                                                                    bc:f32[] = cos z                    
                                                                                                    bd:f32[] = cos bb                   
                                                                                                    be:f32[] = mul ba bd                
                                                                                                    bf:f32[] = mul be bc                
                                                                                                  in (bf,) }                            
                                                                                                policy=None                             
                                                                                                prevent_cse=True                        
                                                                                              ] p x                                     
                                                                                            in (y,) }                                   
                                                                                          policy=None                                   
                                                                                          prevent_cse=True                              
                                                                                        ] 3.0 g                                         
                                                                                      in (o,) }                                         

Practical notes#

When differentiated functions are staged out to XLA for compilation, for example by applying jax.jit to a function which contains a jax.grad call, XLA will automatically optimize the computation, including decisions about when to compute or rematerialize values. As a result, jax.checkpoint often isn’t needed for differentiated functions under a jax.jit. XLA will optimize things for you.
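
For instance, a typical differentiated-and-jitted function needs no explicit jax.checkpoint. Here is a minimal sketch (the loss function and shapes are illustrative, not taken from the original text):

import jax
import jax.numpy as jnp

def loss(w, x, y):                        # hypothetical example function
  pred = jnp.tanh(x @ w)
  return jnp.mean((pred - y) ** 2)

# Under jax.jit, XLA decides what to save versus recompute for the backward
# pass, so no explicit jax.checkpoint is needed here.
grad_fn = jax.jit(jax.grad(loss))
w, x, y = jnp.ones((3, 2)), jnp.ones((5, 3)), jnp.zeros((5, 2))
print(grad_fn(w, x, y).shape)             # (3, 2)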

One exception is when using staged-out control flow, like jax.lax.scan. Automatic compiler optimizations across multiple control flow primitives, e.g. across a forward-pass scan and the corresponding backward-pass scan, typically aren't as thorough. As a result, it's often a good idea to use jax.checkpoint on the body function passed to jax.lax.scan.

For example, one common pattern in large Transformer models is to express the architecture as a jax.lax.scan over layers so as to reduce compilation times. That is, using a simple fully-connected network as an analogy, instead of writing something like this:

LayerParam = tuple[jnp.ndarray, jnp.ndarray]  # weights, bias pair for a layer
ParamsList = list[LayerParam]

def net(params: ParamsList, x: jnp.ndarray):
  for W, b in params:
    x = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return x

We would instead iterate over the layer application with jax.lax.scan:

StackedWeights = jnp.ndarray  # all weight matrices stacked together
StackedBiases = jnp.ndarray   # all bias vectors stacked together

all_weights = jnp.stack([W for W, _ in params])
all_biases = jnp.stack([b for _, b in params])

def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None

def net(all_weights, all_biases, x):
  x, _ = jax.lax.scan(layer, x, (all_weights, all_biases))
  return x

This scan-over-layers version reduces compile times, but by foiling some compiler optimizations it can lead to inefficient computation of gradients. To mitigate the issue, we would use jax.checkpoint on the scanned function:

from functools import partial

@partial(jax.checkpoint,
         policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None

By using jax.checkpoint this way, we’re manually controlling which values JAX’s autodiff saves between the forward and backward passes, and hence not relying on XLA optimizations to choose for us.
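
To check what actually gets saved for a given function, arguments, and policy, jax.ad_checkpoint.print_saved_residuals reports the residuals that JAX's autodiff would store between the forward and backward passes. A small sketch (the layer function and shapes below are illustrative):

import jax
import jax.numpy as jnp

def layer_fn(W, b, x):                    # hypothetical single layer
  return jnp.maximum(jnp.dot(x, W) + b, 0.)

W, b, x = jnp.ones((4, 4)), jnp.ones(4), jnp.ones((8, 4))

# Residuals saved by default autodiff:
jax.ad_checkpoint.print_saved_residuals(layer_fn, W, b, x)

# Residuals saved under the policy used above:
layer_remat = jax.checkpoint(
    layer_fn, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
jax.ad_checkpoint.print_saved_residuals(layer_remat, W, b, x)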

How JAX primitives work#


necula@google.com, October 2019.

JAX implements certain transformations of Python functions, e.g., jit, grad, vmap, or pmap. The Python functions to be transformed must be JAX-traceable, which means that as the Python function executes the only operations it applies to the data are either inspections of data attributes such as shape or type, or special operations called JAX primitives. In particular, a JAX-traceable function is sometimes invoked by JAX with abstract arguments. An example of a JAX abstract value is ShapedArray(float32[2,2]), which captures the type and the shape of values, but not the concrete data values. JAX primitives know how to operate on both concrete data values and on the JAX abstract values.
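
One public entry point for this kind of abstract execution is jax.eval_shape, which runs a function using only shapes and dtypes; a small illustrative sketch (not part of the original text):

import jax
import jax.numpy as jnp

# Abstract arguments carrying only shape and dtype, no data.
x = jax.ShapeDtypeStruct((2, 2), jnp.float32)
y = jax.ShapeDtypeStruct((2, 2), jnp.float32)

# jnp.dot is evaluated abstractly; no actual arithmetic is performed.
out = jax.eval_shape(jnp.dot, x, y)
print(out.shape, out.dtype)               # (2, 2) float32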

The JAX-transformed functions must themselves be JAX-traceable functions, to ensure that these transformations can be composed, e.g., jit(jacfwd(grad(f))).

There are pre-defined JAX primitives corresponding to most XLA operations, e.g., add, matmul, sin, cos, indexing. JAX comes with an implementation of numpy functions in terms of JAX primitives, which means that Python programs using JAX’s implementation of numpy are JAX-traceable and therefore transformable. Other libraries can be made JAX-traceable by implementing them in terms of JAX primitives.
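
One way to see the primitives behind a jax.numpy program is jax.make_jaxpr, which traces a function and prints the primitive operations it reduces to (a brief sketch, not from the original text):

import jax
import jax.numpy as jnp

# The jnp.sin call and the multiplication become the `sin` and `mul` primitives.
print(jax.make_jaxpr(lambda x: jnp.sin(x) * 2.0)(1.0))
# Prints something like:
#   { lambda ; a:f32[]. let b:f32[] = sin a; c:f32[] = mul b 2.0 in (c,) }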

The set of JAX primitives is extensible. Instead of reimplementing a function in terms of pre-defined JAX primitives, one can define a new primitive that encapsulates the behavior of the function.

The goal of this document is to explain the interface that a JAX primitive must support in order to allow JAX to perform all its transformations.

Consider that we want to add to JAX support for a multiply-add function with three arguments, defined mathematically as “multiply_add(x, y, z) = x * y + z”. This function operates on 3 identically-shaped tensors of floating point values and performs the operations pointwise.

Using existing primitives#

The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other functions that are themselves written using JAX primitives, e.g., those defined in the jax.lax module:

from jax import lax
from jax._src import api

def multiply_add_lax(x, y, z):
  """Implementation of multiply-add using the jax.lax primitives."""
  return lax.add(lax.mul(x, y), z)


def square_add_lax(a, b):
  """A square-add function using the newly defined multiply-add."""
  return multiply_add_lax(a, a, b)

print("square_add_lax = ", square_add_lax(2., 10.))
# Differentiate w.r.t. the first argument
print("grad(square_add_lax) = ", api.grad(square_add_lax, argnums=0)(2.0, 10.))
square_add_lax =  14.0
grad(square_add_lax) =  4.0

To understand how JAX uses these primitives internally, we add some helpers for tracing function calls.

#@title Helper functions (execute this cell)
import functools
import traceback

_indentation = 0
def _trace(msg=None):
    """Print a message at current indentation."""
    if msg is not None:
        print("  " * _indentation + msg)

def _trace_indent(msg=None):
    """Print a message and then indent the rest."""
    global _indentation
    _trace(msg)
    _indentation = 1 + _indentation

def _trace_unindent(msg=None):
    """Unindent then print a message."""
    global _indentation
    _indentation = _indentation - 1
    _trace(msg)

def trace(name):
  """A decorator for functions to trace arguments and results."""

  def trace_func(func):  # pylint: disable=missing-docstring
    def pp(v):
        """Print certain values more succinctly"""
        vtype = str(type(v))
        if "jax._src.xla_bridge._JaxComputationBuilder" in vtype:
            return "<JaxComputationBuilder>"
        elif "jaxlib.xla_extension.XlaOp" in vtype:
            return "<XlaOp at 0x{:x}>".format(id(v))
        elif ("partial_eval.JaxprTracer" in vtype or
              "batching.BatchTracer" in vtype or
              "ad.JVPTracer" in vtype):
            return "Traced<{}>".format(v.aval)
        elif isinstance(v, tuple):
            return "({})".format(pp_values(v))
        else:
            return str(v)
    def pp_values(args):
        return ", ".join([pp(arg) for arg in args])
    
    @functools.wraps(func)
    def func_wrapper(*args):
      _trace_indent("call {}({})".format(name, pp_values(args)))
      res = func(*args)
      _trace_unindent("|<- {} = {}".format(name, pp(res)))
      return res

    return func_wrapper

  return trace_func

class expectNotImplementedError(object):
  """Context manager to check for NotImplementedError."""
  def __enter__(self): pass
  def __exit__(self, type, value, tb):
    global _indentation
    _indentation = 0
    if type is NotImplementedError:
      print("\nFound expected exception:")
      traceback.print_exc(limit=3)
      return True
    elif type is None:  # No exception
      assert False, "Expected NotImplementedError"
    else:
      return False

Instead of using jax.lax primitives directly, we can use other functions that are already written in terms of those primitives, such as those in jax.numpy:

import jax.numpy as jnp
import numpy as np

@trace("multiply_add_numpy")
def multiply_add_numpy(x, y, z):
    return jnp.add(jnp.multiply(x, y), z)

@trace("square_add_numpy")
def square_add_numpy(a, b):
    return multiply_add_numpy(a, a, b)

print("\nNormal evaluation:")  
print("square_add_numpy = ", square_add_numpy(2., 10.))
print("\nGradient evaluation:")
print("grad(square_add_numpy) = ", api.grad(square_add_numpy)(2.0, 10.))
Normal evaluation:
call square_add_numpy(2.0, 10.0)
  call multiply_add_numpy(2.0, 2.0, 10.0)
  |<- multiply_add_numpy = 14.0
|<- square_add_numpy = 14.0
square_add_numpy =  14.0

Gradient evaluation:
call square_add_numpy(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
  call multiply_add_numpy(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
  |<- multiply_add_numpy = Traced<ConcreteArray(14.0, dtype=float32, weak_type=True)>
|<- square_add_numpy = Traced<ConcreteArray(14.0, dtype=float32, weak_type=True)>
grad(square_add_numpy) =  4.0

Notice that in the process of computing grad, JAX invokes square_add_numpy and multiply_add_numpy with special arguments ConcreteArray(...) (described further below in this document). It is important to remember that a JAX-traceable function must be able to operate not only on concrete arguments but also on the special abstract arguments that JAX may use to abstract the function execution.

The JAX traceability property is satisfied as long as the function is written in terms of JAX primitives.
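
Conversely, a function that does something other than apply JAX primitives or inspect shapes and dtypes, for example Python control flow that branches on the value of a traced argument, is not JAX-traceable under jit. A hedged sketch of what goes wrong (the function below is illustrative):

import jax
import jax.numpy as jnp

def not_traceable(x):
  # Branching on the *value* of a traced argument is neither a primitive
  # application nor a shape/dtype inspection, so tracing under jit fails.
  if x > 0:
    return jnp.sin(x)
  return jnp.cos(x)

try:
  jax.jit(not_traceable)(1.0)
except Exception as e:                    # a tracer/concretization error
  print(type(e).__name__)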

Defining new JAX primitives#

The right way to add support for multiply-add is in terms of existing JAX primitives, as shown above. However, in order to demonstrate how JAX primitives work let us pretend that we want to add a new primitive to JAX for the multiply-add functionality.

from jax import core
multiply_add_p = core.Primitive("multiply_add")  # Create the primitive

@trace("multiply_add_prim")
def multiply_add_prim(x, y, z):
  """The JAX-traceable way to use the JAX primitive.
  
  Note that the traced arguments must be passed as positional arguments
  to `bind`. 
  """
  return multiply_add_p.bind(x, y, z)

@trace("square_add_prim")
def square_add_prim(a, b):
  """A square-add function implemented using the new JAX-primitive."""
  return multiply_add_prim(a, a, b)

If we try to call the newly defined functions we get an error, because we have not yet told JAX anything about the semantics of the new primitive.

with expectNotImplementedError():
  square_add_prim(2., 10.)
call square_add_prim(2.0, 10.0)
  call multiply_add_prim(2.0, 2.0, 10.0)

Found expected exception:
Traceback (most recent call last):
  File "/tmp/ipykernel_4031/2844449444.py", line 2, in <module>
    square_add_prim(2., 10.)
  File "/tmp/ipykernel_4031/1393342955.py", line 48, in func_wrapper
    res = func(*args)
  File "/tmp/ipykernel_4031/1308506715.py", line 16, in square_add_prim
    return multiply_add_prim(a, a, b)
NotImplementedError: Evaluation rule for 'multiply_add' not implemented
Primal evaluation rules#
@trace("multiply_add_impl")
def multiply_add_impl(x, y, z):
  """Concrete implementation of the primitive.

  This function does not need to be JAX traceable.
  Args:
    x, y, z: the concrete arguments of the primitive. Will only be called with 
      concrete values.
  Returns:
    the concrete result of the primitive.
  """
  # Note that we can use the original numpy, which is not JAX traceable
  return np.add(np.multiply(x, y), z)

# Now we register the primal implementation with JAX
multiply_add_p.def_impl(multiply_add_impl)
<function __main__.multiply_add_impl(x, y, z)>
assert square_add_prim(2., 10.) == 14.
call square_add_prim(2.0, 10.0)
  call multiply_add_prim(2.0, 2.0, 10.0)
    call multiply_add_impl(2.0, 2.0, 10.0)
    |<- multiply_add_impl = 14.0
  |<- multiply_add_prim = 14.0
|<- square_add_prim = 14.0
JIT#

If we now try to use jit we get a NotImplementedError:

with expectNotImplementedError():
  api.jit(square_add_prim)(2., 10.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)

Found expected exception:
Traceback (most recent call last):
  File "/tmp/ipykernel_4031/1813425700.py", line 2, in <module>
    api.jit(square_add_prim)(2., 10.)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py", line 304, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
NotImplementedError: Abstract evaluation for 'multiply_add' not implemented
Abstract evaluation rules#

In order to JIT the function, and for other transformations as well, JAX first evaluates it abstractly using only the shape and type of the arguments. This abstract evaluation serves multiple purposes:

  • Gets the sequence of JAX primitives that are used in the computation. This sequence will be compiled.

  • Computes the shape and type of all the intermediate values used in the computation.

For example, the abstraction of a vector with 3 elements may be ShapedArray(float32[3]), or ConcreteArray([1., 2., 3.]). In the latter case, JAX uses the actual concrete value wrapped as an abstract value.

from jax import core
@trace("multiply_add_abstract_eval")
def multiply_add_abstract_eval(xs, ys, zs):
  """Abstract evaluation of the primitive.

  This function does not need to be JAX traceable. It will be invoked with
  abstractions of the actual arguments. 
  Args:
    xs, ys, zs: abstractions of the arguments.
  Result:
    a ShapedArray for the result of the primitive.
  """
  assert xs.shape == ys.shape
  assert xs.shape == zs.shape
  return core.ShapedArray(xs.shape, xs.dtype)

# Now we register the abstract evaluation with JAX
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)
<function __main__.multiply_add_abstract_eval(xs, ys, zs)>

If we re-attempt to JIT, we see how the abstract evaluation proceeds, but we get another error, about missing the actual XLA compilation rule:

with expectNotImplementedError():
  api.jit(square_add_prim)(2., 10.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
    call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>

Found expected exception:
Traceback (most recent call last):
  File "/home/docs/.asdf/installs/python/3.10.13/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/docs/.asdf/installs/python/3.10.13/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: MLIR translation rule for primitive 'multiply_add' not found for platform cpu

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/tmp/ipykernel_4031/1813425700.py", line 2, in <module>
    api.jit(square_add_prim)(2., 10.)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py", line 304, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
NotImplementedError: MLIR translation rule for primitive 'multiply_add' not found for platform cpu
XLA Compilation rules#

JAX compilation works by compiling each primitive into a graph of XLA operations.

This is the biggest hurdle to adding new functionality to JAX, because the set of XLA operations is limited, and JAX already has pre-defined primitives for most of them. However, XLA includes a CustomCall operation that can be used to encapsulate arbitrary functionality defined using C++.

from jax._src.lib.mlir.dialects import hlo
@trace("multiply_add_lowering")
def multiply_add_lowering(ctx, xc, yc, zc):
  """The compilation to XLA of the primitive.

  Given an mlir.ir.Value for each argument, return the mlir.ir.Values for
  the results of the function.

  Does not need to be a JAX-traceable function.
  """
  return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result]

# Now we register the lowering rule with JAX
# For GPU see the [Custom operations for GPUs](https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html)
# TODO: TPU?
from jax.interpreters import mlir
mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu')
<function __main__.multiply_add_lowering(ctx, xc, yc, zc)>

Now JIT compilation succeeds. Notice below that JAX first evaluates the function abstractly, which triggers the multiply_add_abstract_eval function, and then compiles the set of primitives it has encountered, including multiply_add. At this point JAX invokes multiply_add_lowering.

assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14.
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
    call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7fbfd03fa890>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7fbfd0229430>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7fbfd02294b0>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7fbfd0229470>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7fbfd194aa40>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7fbfd0461120>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x55fbb990a290>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_4031/1570919344.py":1:0) at callsite("<module>"("/tmp/ipykernel_4031/1570919344.py":1:0) at callsite("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0) at callsite("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0) at callsite("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0) at "_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":129:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7fbfd0490870, file "/tmp/ipykernel_4031/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0)), (<code object func_wrapper at 0x7fc000582b80, file "/tmp/ipykernel_4031/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0)), (<code object square_add_prim at 0x7fbfd0490500, file "/tmp/ipykernel_4031/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0)), (<code object <lambda> at 0x7fbfd020a4a0, file "/tmp/ipykernel_4031/1570919344.py", line 1>, 6): loc("<lambda>"("/tmp/ipykernel_4031/1570919344.py":1:0)), (<code object <module> at 0x7fbfd0208920, file "/tmp/ipykernel_4031/1570919344.py", line 1>, 16): loc("<module>"("/tmp/ipykernel_4031/1570919344.py":1:0)), (<code object run_code at 0x7fc00a010920, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)), (<code object run_ast_nodes at 0x7fc00a0107c0, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3418>, 500): loc("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0)), (<code object run_cell_async at 0x7fc00a010450, file 
"/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3183>, 828): loc("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0)), (<code object _pseudo_sync_runner at 0x7fc009d76fa0, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 120>, 8): loc("_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":129:0))}, canonical_name_cache={'/tmp/ipykernel_4031/1308506715.py': '/tmp/ipykernel_4031/1308506715.py', '/tmp/ipykernel_4031/1393342955.py': '/tmp/ipykernel_4031/1393342955.py', '/tmp/ipykernel_4031/1570919344.py': '/tmp/ipykernel_4031/1570919344.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_4031/1308506715.py': True, '/tmp/ipykernel_4031/1393342955.py': True, '/tmp/ipykernel_4031/1570919344.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, platforms=None, global_constant_computation=False, replace_tokens_with_dummy=True)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7fbfd022c670>, tokens_out=None, axis_size_env=None, dim_var_values=[]), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7fbfd04557f0>]

Below is another use of jit, where we compile only with respect to the first argument by marking the second argument static. Notice in the trace that square_add_prim and multiply_add_prim receive the concrete value 10.0 for the static argument; when the primitive is bound, JAX converts this constant to an abstract value, so multiply_add_abstract_eval is still invoked with ShapedArray arguments.

assert api.jit(lambda x, y: square_add_prim(x, y), 
               static_argnums=1)(2., 10.) == 14.
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 10.0)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 10.0)
    call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7fbfd03fb470>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7fbfd02311b0>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7fbfd0231230>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7fbfd02311f0>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7fbfd194aa40>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7fbfd022d690>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x55fbb9c975c0>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_4031/4165789807.py":1:0) at callsite("<module>"("/tmp/ipykernel_4031/4165789807.py":1:0) at callsite("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0) at callsite("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0) at callsite("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0) at "_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":129:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7fbfd0490870, file "/tmp/ipykernel_4031/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0)), (<code object func_wrapper at 0x7fc000582b80, file "/tmp/ipykernel_4031/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0)), (<code object square_add_prim at 0x7fbfd0490500, file "/tmp/ipykernel_4031/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0)), (<code object <lambda> at 0x7fbfd1967b50, file "/tmp/ipykernel_4031/4165789807.py", line 1>, 6): loc("<lambda>"("/tmp/ipykernel_4031/4165789807.py":1:0)), (<code object <module> at 0x7fbfd1965b00, file "/tmp/ipykernel_4031/4165789807.py", line 1>, 20): loc("<module>"("/tmp/ipykernel_4031/4165789807.py":1:0)), (<code object run_code at 0x7fc00a010920, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)), (<code object run_ast_nodes at 0x7fc00a0107c0, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3418>, 500): loc("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0)), (<code object run_cell_async at 0x7fc00a010450, file 
"/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3183>, 828): loc("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0)), (<code object _pseudo_sync_runner at 0x7fc009d76fa0, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 120>, 8): loc("_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":129:0))}, canonical_name_cache={'/tmp/ipykernel_4031/1308506715.py': '/tmp/ipykernel_4031/1308506715.py', '/tmp/ipykernel_4031/1393342955.py': '/tmp/ipykernel_4031/1393342955.py', '/tmp/ipykernel_4031/4165789807.py': '/tmp/ipykernel_4031/4165789807.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_4031/1308506715.py': True, '/tmp/ipykernel_4031/1393342955.py': True, '/tmp/ipykernel_4031/4165789807.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, platforms=None, global_constant_computation=False, replace_tokens_with_dummy=True)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7fbfd022dc60>, tokens_out=None, axis_size_env=None, dim_var_values=[]), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%0 = "stablehlo.constant"() <{value = dense<1.000000e+01> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7fc00a0cd7f0>]
Forward differentiation#

JAX implements forward differentiation in the form of a Jacobian-vector product (see the JAX autodiff cookbook).
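
As a quick reminder of the interface, jax.jvp takes a function, primal values, and tangent values, and returns the primal outputs together with the output tangents (a tiny sketch, not from the original notebook):

import jax
import jax.numpy as jnp

# For f = sin, the output tangent is cos(x) * x_dot; at x = 0 with x_dot = 1
# this gives a tangent of 1.
primal_out, tangent_out = jax.jvp(jnp.sin, (0.0,), (1.0,))
print(primal_out, tangent_out)            # 0.0 1.0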

If we now attempt to compute the jvp, we get an error, because we have not yet told JAX how to differentiate the multiply_add primitive.

# The second argument, `(2., 10.)`, gives the argument values at which
# we evaluate the Jacobian, and the third, `(1., 1.)`, gives the values
# of the tangents for the arguments.
with expectNotImplementedError():
  api.jvp(square_add_prim, (2., 10.), (1., 1.))
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)
  call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)

Found expected exception:
Traceback (most recent call last):
  File "/tmp/ipykernel_4031/800067577.py", line 5, in <module>
    api.jvp(square_add_prim, (2., 10.), (1., 1.))
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 1906, in jvp
    return _jvp(lu.wrap_init(fun), primals, tangents, has_aux=has_aux)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 1935, in _jvp
    out_primals, out_tangents = ad.jvp(flat_fun).call_wrapped(ps_flat, ts_flat)
NotImplementedError: Differentiation rule for 'multiply_add' not implemented
from jax.interpreters import ad


@trace("multiply_add_value_and_jvp")
def multiply_add_value_and_jvp(arg_values, arg_tangents):
  """Evaluates the primal output and the tangents (Jacobian-vector product).

  Given values of the arguments and perturbation of the arguments (tangents), 
  compute the output of the primitive and the perturbation of the output.

  This method must be JAX-traceable. JAX may invoke it with abstract values 
  for the arguments and tangents.

  Args:
    arg_values: a tuple of arguments
    arg_tangents: a tuple with the tangents of the arguments. The tuple has 
      the same length as the arg_values. Some of the tangents may also be the 
      special value ad.Zero to specify a zero tangent.
  Returns:
     a pair of the primal output and the tangent.
  """
  x, y, z = arg_values
  xt, yt, zt = arg_tangents
  _trace("Primal evaluation:")
  # Now we have a JAX-traceable computation of the output. 
  # Normally, we can use the multiply_add primitive itself to compute the primal output.
  primal_out = multiply_add_prim(x, y, z)
  
  _trace("Tangent evaluation:")
  # We must use a JAX-traceable way to compute the tangent. It turns out that 
  # the output tangent can be computed as (xt * y + x * yt + zt),
  # which we can implement in a JAX-traceable way using the same "multiply_add_prim" primitive.
  
  # We do need to deal specially with Zero. Here we just turn it into a 
  # proper tensor of 0s (of the same shape as 'x'). 
  # An alternative would be to check for Zero and perform algebraic 
  # simplification of the output tangent computation.
  def make_zero(tan):
    return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan  
  
  output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt)))
  return (primal_out, output_tangent)

# Register the forward differentiation rule with JAX 
ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp
# Tangent is: xt*y + x*yt + zt = 1.*2. + 2.*1. + 1. = 5.
assert api.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.)
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)
  call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)
    call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (1.0, 1.0, 1.0))
      Primal evaluation:
      call multiply_add_prim(2.0, 2.0, 10.0)
        call multiply_add_impl(2.0, 2.0, 10.0)
        |<- multiply_add_impl = 14.0
      |<- multiply_add_prim = 14.0
      Tangent evaluation:
      call multiply_add_prim(2.0, 1.0, 1.0)
        call multiply_add_impl(2.0, 1.0, 1.0)
        |<- multiply_add_impl = 3.0
      |<- multiply_add_prim = 3.0
      call multiply_add_prim(1.0, 2.0, 3.0)
        call multiply_add_impl(1.0, 2.0, 3.0)
        |<- multiply_add_impl = 5.0
      |<- multiply_add_prim = 5.0
    |<- multiply_add_value_and_jvp = (14.0, 5.0)
  |<- multiply_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
|<- square_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>

TO EXPLAIN:

  • Why is JAX using ConcreteArray in square_add_prim? There is no abstract evaluation going on here.

  • Not sure how to explain that multiply_add_prim is invoked with ConcreteArray, yet we do not call multiply_add_abstract_eval.

  • I think it would be useful to show the jaxpr here (a sketch follows below).
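
Following up on the last note, one way to see the jaxpr is to trace the JVP computation with jax.make_jaxpr. A minimal sketch (not part of the original notebook), reusing square_add_prim and api.jvp from above:

import jax

# The resulting jaxpr contains three `multiply_add` equations: one for the
# primal output and two for the output tangent, mirroring the three
# multiply_add_prim calls in the trace above.
print(jax.make_jaxpr(
    lambda primals, tangents: api.jvp(square_add_prim, primals, tangents)
)((2., 10.), (1., 1.)))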

JIT of forward differentiation#

We can apply JIT to the forward differentiation function:

assert api.jit(lambda arg_values, arg_tangents: 
                   api.jvp(square_add_prim, arg_values, arg_tangents))(
         (2., 10.), (1., 1.)) == (14., 5.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
    call multiply_add_value_and_jvp((Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>), (Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>))
      Primal evaluation:
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
      Tangent evaluation:
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
    |<- multiply_add_value_and_jvp = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7fbfd027e520>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7fbfd0280230>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7fbfd0280730>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7fbfd02806f0>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7fbfd194aa40>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7fbfd022f0d0>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x55fbb9de4690>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":27:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_4031/2145028508.py":2:0) at "<module>"("/tmp/ipykernel_4031/2145028508.py":1:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7fbfd0490870, file "/tmp/ipykernel_4031/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0)), (<code object func_wrapper at 0x7fc000582b80, file "/tmp/ipykernel_4031/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0)), (<code object multiply_add_value_and_jvp at 0x7fbfd02090b0, file "/tmp/ipykernel_4031/3197095916.py", line 4>, 36): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":27:0)), (<code object square_add_prim at 0x7fbfd0490500, file "/tmp/ipykernel_4031/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0)), (<code object <lambda> at 0x7fbfd0209790, file "/tmp/ipykernel_4031/2145028508.py", line 1>, 10): loc("<lambda>"("/tmp/ipykernel_4031/2145028508.py":2:0)), (<code object <module> at 0x7fbfd0209370, file "/tmp/ipykernel_4031/2145028508.py", line 1>, 16): loc("<module>"("/tmp/ipykernel_4031/2145028508.py":1:0))}, canonical_name_cache={'/tmp/ipykernel_4031/1308506715.py': '/tmp/ipykernel_4031/1308506715.py', '/tmp/ipykernel_4031/1393342955.py': '/tmp/ipykernel_4031/1393342955.py', '/tmp/ipykernel_4031/3197095916.py': '/tmp/ipykernel_4031/3197095916.py', '/tmp/ipykernel_4031/2145028508.py': '/tmp/ipykernel_4031/2145028508.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_4031/1308506715.py': True, '/tmp/ipykernel_4031/1393342955.py': True, '/tmp/ipykernel_4031/3197095916.py': True, 
'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/tmp/ipykernel_4031/2145028508.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False}), lowering_parameters=LoweringParameters(override_lowering_rules=None, platforms=None, global_constant_computation=False, replace_tokens_with_dummy=True)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7fbfd022eda0>, tokens_out=None, axis_size_env=None, dim_var_values=[]), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7fbfd1cce630>]
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7fbfd027e520>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7fbfd0280230>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7fbfd0280730>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7fbfd02806f0>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7fbfd194aa40>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7fbfd022f0d0>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x55fbb9de4690>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":27:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_4031/2145028508.py":2:0) at "<module>"("/tmp/ipykernel_4031/2145028508.py":1:0))))))))))), <jaxlib.xla_extension.Traceback object at 0x55fbb9de48c0>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_4031/2145028508.py":2:0) at "<module>"("/tmp/ipykernel_4031/2145028508.py":1:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7fbfd0490870, file "/tmp/ipykernel_4031/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0)), (<code object func_wrapper at 0x7fc000582b80, file "/tmp/ipykernel_4031/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0)), (<code object multiply_add_value_and_jvp at 0x7fbfd02090b0, file "/tmp/ipykernel_4031/3197095916.py", line 4>, 36): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":27:0)), (<code object square_add_prim at 0x7fbfd0490500, file "/tmp/ipykernel_4031/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0)), (<code object <lambda> at 0x7fbfd0209790, file "/tmp/ipykernel_4031/2145028508.py", line 1>, 10): loc("<lambda>"("/tmp/ipykernel_4031/2145028508.py":2:0)), (<code object <module> at 0x7fbfd0209370, file "/tmp/ipykernel_4031/2145028508.py", line 1>, 16): loc("<module>"("/tmp/ipykernel_4031/2145028508.py":1:0)), (<code object multiply_add_value_and_jvp at 0x7fbfd02090b0, file "/tmp/ipykernel_4031/3197095916.py", line 4>, 86): 
loc("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":41:0))}, canonical_name_cache={'/tmp/ipykernel_4031/1308506715.py': '/tmp/ipykernel_4031/1308506715.py', '/tmp/ipykernel_4031/1393342955.py': '/tmp/ipykernel_4031/1393342955.py', '/tmp/ipykernel_4031/3197095916.py': '/tmp/ipykernel_4031/3197095916.py', '/tmp/ipykernel_4031/2145028508.py': '/tmp/ipykernel_4031/2145028508.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_4031/1308506715.py': True, '/tmp/ipykernel_4031/1393342955.py': True, '/tmp/ipykernel_4031/3197095916.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/tmp/ipykernel_4031/2145028508.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False}), lowering_parameters=LoweringParameters(override_lowering_rules=None, platforms=None, global_constant_computation=False, replace_tokens_with_dummy=True)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7fbfd022dfc0>, tokens_out=None, axis_size_env=None, dim_var_values=[]), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 2), Value(<block argument> of type 'tensor<f32>' at index: 3))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7fbfd04435b0>]
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7fbfd027e520>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7fbfd0280230>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7fbfd0280730>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7fbfd02806f0>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7fbfd194aa40>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7fbfd022f0d0>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x55fbb9de4690>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":27:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_4031/2145028508.py":2:0) at "<module>"("/tmp/ipykernel_4031/2145028508.py":1:0))))))))))), <jaxlib.xla_extension.Traceback object at 0x55fbb9de48c0>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_4031/2145028508.py":2:0) at "<module>"("/tmp/ipykernel_4031/2145028508.py":1:0))))))))))), <jaxlib.xla_extension.Traceback object at 0x55fbb9ee9d20>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_4031/2145028508.py":2:0) at "<module>"("/tmp/ipykernel_4031/2145028508.py":1:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7fbfd0490870, file "/tmp/ipykernel_4031/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0)), (<code object func_wrapper at 0x7fc000582b80, file "/tmp/ipykernel_4031/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0)), (<code object multiply_add_value_and_jvp at 0x7fbfd02090b0, file 
"/tmp/ipykernel_4031/3197095916.py", line 4>, 36): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":27:0)), (<code object square_add_prim at 0x7fbfd0490500, file "/tmp/ipykernel_4031/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0)), (<code object <lambda> at 0x7fbfd0209790, file "/tmp/ipykernel_4031/2145028508.py", line 1>, 10): loc("<lambda>"("/tmp/ipykernel_4031/2145028508.py":2:0)), (<code object <module> at 0x7fbfd0209370, file "/tmp/ipykernel_4031/2145028508.py", line 1>, 16): loc("<module>"("/tmp/ipykernel_4031/2145028508.py":1:0)), (<code object multiply_add_value_and_jvp at 0x7fbfd02090b0, file "/tmp/ipykernel_4031/3197095916.py", line 4>, 86): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":41:0)), (<code object multiply_add_value_and_jvp at 0x7fbfd02090b0, file "/tmp/ipykernel_4031/3197095916.py", line 4>, 88): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":41:0))}, canonical_name_cache={'/tmp/ipykernel_4031/1308506715.py': '/tmp/ipykernel_4031/1308506715.py', '/tmp/ipykernel_4031/1393342955.py': '/tmp/ipykernel_4031/1393342955.py', '/tmp/ipykernel_4031/3197095916.py': '/tmp/ipykernel_4031/3197095916.py', '/tmp/ipykernel_4031/2145028508.py': '/tmp/ipykernel_4031/2145028508.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_4031/1308506715.py': True, '/tmp/ipykernel_4031/1393342955.py': True, '/tmp/ipykernel_4031/3197095916.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/tmp/ipykernel_4031/2145028508.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False}), lowering_parameters=LoweringParameters(override_lowering_rules=None, platforms=None, global_constant_computation=False, replace_tokens_with_dummy=True)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[])], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7fbfd022ee60>, tokens_out=None, axis_size_env=None, dim_var_values=[]), Value(<block argument> of type 'tensor<f32>' at index: 2), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%3 = "stablehlo.add"(%2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7fbfd02323b0>]

Notice that we first evaluate multiply_add_value_and_jvp abstractly, which in turn abstractly evaluates both the primal and the tangent computation (a total of 3 invocations of the ma primitive). Then we compile the 3 occurrences of the primitive.

Reverse differentiation#

If we now attempt to use reverse differentiation, we see that JAX starts by using multiply_add_value_and_jvp to compute the forward differentiation for abstract values, but then runs into a NotImplementedError.

When computing the reverse differentiation JAX first does abstract evaluation of the forward differentiation code multiply_add_value_and_jvp to obtain a trace of primitives that compute the output tangent. Observe that JAX performs this abstract evaluation with concrete values for the differentiation point, and abstract values for the tangents. Observe also that JAX uses the special abstract tangent value Zero for the tangent corresponding to the 3rd argument of ma. This reflects the fact that we do not differentiate w.r.t. the 2nd argument to square_add_prim, which flows to the 3rd argument to multiply_add_prim.

Observe also that during the abstract evaluation of the tangent we pass the value 0.0 as the tangent for the 3rd argument. This is due to the use of the make_zero function in the definition of multiply_add_value_and_jvp.
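
For reference, the make_zero idiom is roughly the following (a paraphrased sketch; the name make_zero_like and the explicit like argument are illustrative, not the exact code used earlier): it materializes a symbolic ad.Zero tangent as a concrete array of zeros so that it can be passed to multiply_add_prim.

from jax import lax
from jax.interpreters import ad

def make_zero_like(tan, like):
  # If the tangent is the symbolic ad.Zero, materialize a concrete array of
  # zeros with the shape and dtype of `like`; otherwise pass it through.
  return lax.zeros_like_array(like) if type(tan) is ad.Zero else tan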

# This is reverse differentiation w.r.t. the first argument of square_add_prim
with expectNotImplementedError():
  api.grad(square_add_prim)(2., 10.)
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
  call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
    call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
      Primal evaluation:
      call multiply_add_prim(2.0, 2.0, 10.0)
        call multiply_add_impl(2.0, 2.0, 10.0)
        |<- multiply_add_impl = 14.0
      |<- multiply_add_prim = 14.0
      Tangent evaluation:
      call multiply_add_prim(2.0, Traced<ShapedArray(float32[], weak_type=True)>, 0.0)
        call multiply_add_abstract_eval(ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[], weak_type=True), ConcreteArray(0.0, dtype=float32, weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, 2.0, Traced<ShapedArray(float32[])>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
    |<- multiply_add_value_and_jvp = (14.0, Traced<ShapedArray(float32[])>)
  |<- multiply_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
|<- square_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>

Found expected exception:
Traceback (most recent call last):
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py", line 283, in get_primitive_transpose
    return primitive_transposes[p]
KeyError: multiply_add

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/docs/.asdf/installs/python/3.10.13/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/docs/.asdf/installs/python/3.10.13/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'multiply_add' not implemented

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/tmp/ipykernel_4031/339076514.py", line 3, in <module>
    api.grad(square_add_prim)(2., 10.)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 621, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'multiply_add' not implemented

The above error arises because JAX is still missing a piece it needs in order to use the forward differentiation code to compute reverse differentiation.

Transposition#

As explained above, when computing reverse differentiation JAX obtains a trace of primitives that compute the tangent using forward differentiation. Then, JAX interprets this trace abstractly backwards and for each primitive it applies a transposition rule.

To understand what is going on, consider for now the simpler function “f(x, y) = x * y + y”. Assume we need to differentiate at the point (2., 4.). JAX will produce the following JVP tangent calculation of ft from the input tangents xt and yt:

   a = xt * 4.
   b = 2. * yt
   c = a + b
   ft = c + yt

By construction, the tangent calculation is always linear in the input tangents. The only non-linear operator that may arise in the tangent calculation is multiplication, but then one of the operands is constant.
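
As a quick sanity check (not part of the original walkthrough), we can ask JAX for exactly this JVP; with input tangents xt = yt = 1. the calculation above gives a = 4., b = 2., c = 6., ft = 7.:

import jax

def f(x, y):
  return x * y + y

# JVP of f at the point (2., 4.) with input tangents (1., 1.)
primal_out, tangent_out = jax.jvp(f, (2., 4.), (1., 1.))
print(primal_out, tangent_out)  # 12.0 7.0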

JAX will produce the reverse differentiation computation by processing the JVP computation backwards. For each operation in the tangent computation, it accumulates the cotangents of the variables used by the operation, using the cotangent of the result of the operation:

  # Initialize cotangents of inputs and intermediate vars
  xct = yct = act = bct = cct = 0.
  # Initialize cotangent of the output
  fct = 1.
  # Process "ft = c + yt"
  cct += fct
  yct += fct
  # Process "c = a + b"
  act += cct
  bct += cct
  # Process "b = 2. * yt"
  yct += 2. * bct
  # Process "a = xt * 4."
  xct += act * 4.

One can verify that this computation produces xct = 4. and yct = 3., which are the partial derivatives of the function f.
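
We can confirm these values directly with jax.grad (again just a sanity check, not part of the original walkthrough):

import jax

f = lambda x, y: x * y + y
print(jax.grad(f, argnums=(0, 1))(2., 4.))  # (Array(4., ...), Array(3., ...))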

JAX knows for each primitive that may appear in a JVP calculation how to transpose it. Conceptually, if the primitive p(x, y, z) is linear in the arguments y and z for a constant value of x, e.g., p(x, y, z) = y*cy + z*cz, then the transposition of the primitive is:

p_transpose(out_ct, x, _, _) = (None, out_ct*cy, out_ct*cz)

Notice that p_transpose takes the cotangent of the output of the primitive and a value corresponding to each argument of the primitive. For the linear arguments, the transposition gets an undefined _ value, and for the other arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value None returned for the constant arguments.

In particular,

 add_transpose(out_ct, _, _) = (out_ct, out_ct)
 mult_transpose(out_ct, x, _) = (None, x * out_ct)
 mult_transpose(out_ct, _, y) = (out_ct * y, None)
@trace("multiply_add_transpose")
def multiply_add_transpose(ct, x, y, z):
  """Evaluates the transpose of a linear primitive.

  This method is only used when computing the backward gradient following 
  value_and_jvp, and is only needed for primitives that are used in the JVP 
  calculation for some other primitive. We need transposition for multiply_add_prim, 
  because we have used multiply_add_prim in the computation of the output_tangent in 
  multiply_add_value_and_jvp.

  In our case, multiply_add is not a linear primitive. However, it is used linearly 
  w.r.t. tangents in multiply_add_value_and_jvp:
       output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))
  
  Always one of the first two multiplicative arguments is a constant.

  Args:
      ct: the cotangent of the output of the primitive.
      x, y, z: values of the arguments. The arguments that are used linearly
        get an ad.UndefinedPrimal value. The other arguments get a constant
        value.
  Returns:
      a tuple with the cotangent of the inputs, with the value None
      corresponding to the constant arguments.
  """
  if not ad.is_undefined_primal(x):
    # This use of multiply_add is with a constant "x"
    assert ad.is_undefined_primal(y)
    ct_y = ad.Zero(y.aval) if type(ct) is ad.Zero else multiply_add_prim(x, ct, lax.zeros_like_array(x))
    res = None, ct_y, ct
  else:
    # This use of multiply_add is with a constant "y"
    assert ad.is_undefined_primal(x)
    ct_x = ad.Zero(x.aval) if type(ct) is ad.Zero else multiply_add_prim(ct, y, lax.zeros_like_array(y))
    res = ct_x, None, ct
  return res


ad.primitive_transposes[multiply_add_p] = multiply_add_transpose

Now we can complete the run of the grad:

assert api.grad(square_add_prim)(2., 10.) == 4.
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
  call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
    call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
      Primal evaluation:
      call multiply_add_prim(2.0, 2.0, 10.0)
        call multiply_add_impl(2.0, 2.0, 10.0)
        |<- multiply_add_impl = 14.0
      |<- multiply_add_prim = 14.0
      Tangent evaluation:
      call multiply_add_prim(2.0, Traced<ShapedArray(float32[], weak_type=True)>, 0.0)
        call multiply_add_abstract_eval(ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[], weak_type=True), ConcreteArray(0.0, dtype=float32, weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, 2.0, Traced<ShapedArray(float32[])>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
    |<- multiply_add_value_and_jvp = (14.0, Traced<ShapedArray(float32[])>)
  |<- multiply_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
|<- square_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
call multiply_add_transpose(1.0, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), 2.0, UndefinedPrimal(ShapedArray(float32[])))
  call multiply_add_prim(1.0, 2.0, 0.0)
    call multiply_add_impl(1.0, 2.0, 0.0)
    |<- multiply_add_impl = 2.0
  |<- multiply_add_prim = 2.0
|<- multiply_add_transpose = (2.0, None, 1.0)
call multiply_add_transpose(1.0, 2.0, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), 0.0)
  call multiply_add_prim(2.0, 1.0, 0.0)
    call multiply_add_impl(2.0, 1.0, 0.0)
    |<- multiply_add_impl = 2.0
  |<- multiply_add_prim = 2.0
|<- multiply_add_transpose = (None, 2.0, 1.0)

Notice the two calls to multiply_add_transpose. They correspond to the two uses of multiply_add_prim in the computation of the output_tangent in multiply_add_value_and_jvp. The first call to transpose corresponds to the last use of multiply_add_prim: multiply_add_prim(xt, y, ...) where y is the constant 2.0.

JIT of reverse differentiation#

Notice that the abstract evaluation of multiply_add_value_and_jvp now uses only abstract values, whereas in the absence of JIT we used ConcreteArray values.

assert api.jit(api.grad(square_add_prim))(2., 10.) == 4.
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
    call multiply_add_value_and_jvp((Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
      Primal evaluation:
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
      Tangent evaluation:
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
    |<- multiply_add_value_and_jvp = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_transpose(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, UndefinedPrimal(ShapedArray(float32[])))
  call multiply_add_prim(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
    call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- multiply_add_transpose = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, None, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_transpose(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
    call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[]), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- multiply_add_transpose = (None, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7fbfd027e250>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7fbfd02812b0>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7fbfd02813b0>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7fbfd0282b30>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7fbfd194aa40>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7fbfd022fd30>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x55fbb9ee9d20>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("<module>"("/tmp/ipykernel_4031/3085343041.py":1:0) at "run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7fbfd0490870, file "/tmp/ipykernel_4031/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0)), (<code object func_wrapper at 0x7fc000582b80, file "/tmp/ipykernel_4031/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0)), (<code object multiply_add_value_and_jvp at 0x7fbfd02090b0, file "/tmp/ipykernel_4031/3197095916.py", line 4>, 88): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":41:0)), (<code object square_add_prim at 0x7fbfd0490500, file "/tmp/ipykernel_4031/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0)), (<code object <module> at 0x7fbfd02199a0, file "/tmp/ipykernel_4031/3085343041.py", line 1>, 18): loc("<module>"("/tmp/ipykernel_4031/3085343041.py":1:0)), (<code object run_code at 0x7fc00a010920, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0))}, canonical_name_cache={'/tmp/ipykernel_4031/1308506715.py': '/tmp/ipykernel_4031/1308506715.py', '/tmp/ipykernel_4031/1393342955.py': '/tmp/ipykernel_4031/1393342955.py', '/tmp/ipykernel_4031/3197095916.py': '/tmp/ipykernel_4031/3197095916.py', '/tmp/ipykernel_4031/3085343041.py': '/tmp/ipykernel_4031/3085343041.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py'}, 
is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_4031/1308506715.py': True, '/tmp/ipykernel_4031/1393342955.py': True, '/tmp/ipykernel_4031/3197095916.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/tmp/ipykernel_4031/3085343041.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, platforms=None, global_constant_computation=False, replace_tokens_with_dummy=True)), name_stack=NameStack(stack=(Scope(name='jit(square_add_prim)'), Scope(name='jit(main)'), Transform(name='transpose'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7fbfd0460ee0>, tokens_out=None, axis_size_env=None, dim_var_values=[]), Value(%0 = "stablehlo.constant"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%1 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7fbfd02812f0>]
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7fbfd027e250>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7fbfd02812b0>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7fbfd02813b0>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7fbfd0282b30>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7fbfd194aa40>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7fbfd022fd30>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x55fbb9ee9d20>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("<module>"("/tmp/ipykernel_4031/3085343041.py":1:0) at "run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0))))))))))), <jaxlib.xla_extension.Traceback object at 0x55fbb9e7b400>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("<module>"("/tmp/ipykernel_4031/3085343041.py":1:0) at "run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7fbfd0490870, file "/tmp/ipykernel_4031/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0)), (<code object func_wrapper at 0x7fc000582b80, file "/tmp/ipykernel_4031/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0)), (<code object multiply_add_value_and_jvp at 0x7fbfd02090b0, file "/tmp/ipykernel_4031/3197095916.py", line 4>, 88): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":41:0)), (<code object square_add_prim at 0x7fbfd0490500, file "/tmp/ipykernel_4031/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0)), (<code object <module> at 0x7fbfd02199a0, file "/tmp/ipykernel_4031/3085343041.py", line 1>, 18): loc("<module>"("/tmp/ipykernel_4031/3085343041.py":1:0)), (<code object run_code at 0x7fc00a010920, file 
"/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)), (<code object multiply_add_value_and_jvp at 0x7fbfd02090b0, file "/tmp/ipykernel_4031/3197095916.py", line 4>, 86): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_4031/3197095916.py":41:0))}, canonical_name_cache={'/tmp/ipykernel_4031/1308506715.py': '/tmp/ipykernel_4031/1308506715.py', '/tmp/ipykernel_4031/1393342955.py': '/tmp/ipykernel_4031/1393342955.py', '/tmp/ipykernel_4031/3197095916.py': '/tmp/ipykernel_4031/3197095916.py', '/tmp/ipykernel_4031/3085343041.py': '/tmp/ipykernel_4031/3085343041.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_4031/1308506715.py': True, '/tmp/ipykernel_4031/1393342955.py': True, '/tmp/ipykernel_4031/3197095916.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/tmp/ipykernel_4031/3085343041.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, platforms=None, global_constant_computation=False, replace_tokens_with_dummy=True)), name_stack=NameStack(stack=(Scope(name='jit(square_add_prim)'), Scope(name='jit(main)'), Transform(name='transpose'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[]), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7fbfd02bc580>, tokens_out=None, axis_size_env=None, dim_var_values=[]), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%4 = "stablehlo.constant"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>), Value(%5 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7fbfd2119c30>]

Batching#

The batching transformation takes a point-wise computation and turns it into a computation on vectors. If we try it right now, we get a NotImplementedError:

# The arguments are two vectors instead of two scalars
with expectNotImplementedError():
  api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),
                                               np.array([10., 20.]))
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
  call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)

Found expected exception:
Traceback (most recent call last):
  File "/tmp/ipykernel_4031/2641678767.py", line 3, in <module>
    api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 1214, in vmap_f
    out_flat = batching.batch(
NotImplementedError: Batching rule for 'multiply_add' not implemented

We need to tell JAX how to evaluate the batched version of the primitive. In this particular case, multiply_add_prim already operates pointwise on inputs of any dimension, so the batched version can reuse the same multiply_add_prim implementation.

from jax.interpreters import batching


@trace("multiply_add_batch")
def multiply_add_batch(vector_arg_values, batch_axes):
  """Computes the batched version of the primitive.
  
  This must be a JAX-traceable function.
  
  Since the multiply_add primitive already operates pointwise on tensors of
  arbitrary dimension, to batch it we can use the primitive itself. This works
  as long as the inputs have the same dimensions and are batched along the
  same axes. The result is batched along the same axis as the inputs.
  
  Args:
    vector_arg_values: a tuple of three arguments, each being a tensor of
      matching shape.
    batch_axes: the axes that are being batched. See vmap documentation.
  Returns:
    a tuple of the result, and the result axis that was batched. 
  """
  assert batch_axes[0] == batch_axes[1]
  assert batch_axes[0] == batch_axes[2]
  _trace("Using multiply_add to compute the batch:")
  res = multiply_add_prim(*vector_arg_values)
  return res, batch_axes[0]


batching.primitive_batchers[multiply_add_p] = multiply_add_batch
assert np.allclose(api.vmap(square_add_prim, in_axes=0, out_axes=0)(
  np.array([2., 3.]),
  np.array([10., 20.])),
  [14., 29.])
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
  call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
    call multiply_add_batch(([2. 3.], [2. 3.], [10. 20.]), (0, 0, 0))
      Using multiply_add to compute the batch:
      call multiply_add_prim([2. 3.], [2. 3.], [10. 20.])
        call multiply_add_impl([2. 3.], [2. 3.], [10. 20.])
        |<- multiply_add_impl = [14. 29.]
      |<- multiply_add_prim = [14. 29.]
    |<- multiply_add_batch = ([14. 29.], 0)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>

JIT of batching#

assert np.allclose(api.jit(api.vmap(square_add_prim, in_axes=0, out_axes=0))
                    (np.array([2., 3.]),
                     np.array([10., 20.])),
                    [14., 29.])
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
  call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
    call multiply_add_batch((Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>), (0, 0, 0))
      Using multiply_add to compute the batch:
      call multiply_add_prim(Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>)
        call multiply_add_abstract_eval(ShapedArray(float32[2]), ShapedArray(float32[2]), ShapedArray(float32[2]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[2])
      |<- multiply_add_prim = Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>
    |<- multiply_add_batch = (Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, 0)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7fbfd02a61b0>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7fbfd0477f70>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7fbfd02c00b0>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7fbfd02c0770>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7fbfd194aa40>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7fbfd02bc3a0>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x55fbb9ceef90>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_batch"("/tmp/ipykernel_4031/184469370.py":25:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0) at callsite("<module>"("/tmp/ipykernel_4031/1392464762.py":1:0) at "run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7fbfd0490870, file "/tmp/ipykernel_4031/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_4031/1308506715.py":11:0)), (<code object func_wrapper at 0x7fc000582b80, file "/tmp/ipykernel_4031/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_4031/1393342955.py":48:0)), (<code object multiply_add_batch at 0x7fbfd021b9f0, file "/tmp/ipykernel_4031/184469370.py", line 4>, 52): loc("multiply_add_batch"("/tmp/ipykernel_4031/184469370.py":25:0)), (<code object square_add_prim at 0x7fbfd0490500, file "/tmp/ipykernel_4031/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_4031/1308506715.py":16:0)), (<code object <module> at 0x7fbfd02192c0, file "/tmp/ipykernel_4031/1392464762.py", line 1>, 48): loc("<module>"("/tmp/ipykernel_4031/1392464762.py":1:0)), (<code object run_code at 0x7fc00a010920, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0))}, canonical_name_cache={'/tmp/ipykernel_4031/1308506715.py': '/tmp/ipykernel_4031/1308506715.py', '/tmp/ipykernel_4031/1393342955.py': '/tmp/ipykernel_4031/1393342955.py', '/tmp/ipykernel_4031/184469370.py': '/tmp/ipykernel_4031/184469370.py', '/tmp/ipykernel_4031/1392464762.py': '/tmp/ipykernel_4031/1392464762.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py'}, 
is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_4031/1308506715.py': True, '/tmp/ipykernel_4031/1393342955.py': True, '/tmp/ipykernel_4031/184469370.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/batching.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/tmp/ipykernel_4031/1392464762.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, platforms=None, global_constant_computation=False, replace_tokens_with_dummy=True)), name_stack=NameStack(stack=(Scope(name='jit(square_add_prim)'), Scope(name='jit(main)'), Transform(name='vmap'))), primitive=multiply_add, avals_in=[ShapedArray(float32[2]), ShapedArray(float32[2]), ShapedArray(float32[2])], avals_out=[ShapedArray(float32[2])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7fbfd02bd240>, tokens_out=None, axis_size_env=None, dim_var_values=[]), Value(<block argument> of type 'tensor<2xf32>' at index: 0), Value(<block argument> of type 'tensor<2xf32>' at index: 0), Value(<block argument> of type 'tensor<2xf32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7fbfd02970b0>]

Writing custom Jaxpr interpreters in JAX#

JAX offers several composable function transformations (jit, grad, vmap, etc.) that enable writing concise, accelerated code.

Here we show how to add your own function transformations to the system by writing a custom Jaxpr interpreter, and we get composability with all the other transformations for free.

This example uses internal JAX APIs, which may break at any time. Anything not in the API Documentation should be assumed internal.

import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, grad, vmap
from jax import random

What is JAX doing?#

JAX provides a NumPy-like API for numerical computing which can be used as is, but JAX’s true power comes from composable function transformations. Take the jit function transformation, which takes in a function and returns a semantically identical function that is lazily compiled by XLA for accelerators.

x = random.normal(random.key(0), (5000, 5000))
def f(w, b, x):
  return jnp.tanh(jnp.dot(x, w) + b)
fast_f = jit(f)

When we call fast_f, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about Jax’s tracing machinery, you can refer to the “How it works” section in the README.
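
For example, the first call to fast_f triggers tracing and compilation, and later calls with arguments of the same shapes and dtypes reuse the compiled executable (the weight shapes below are just illustrative):

w = random.normal(random.key(1), (5000, 100))
b = random.normal(random.key(2), (100,))

out = fast_f(w, b, x)   # first call: trace, compile with XLA, then run
out = fast_f(w, b, x)   # subsequent calls reuse the compiled computation
print(out.shape)        # (5000, 100)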

Jaxpr tracer#

A tracer of special importance in JAX is the Jaxpr tracer, which records ops into a Jaxpr (JAX expression). A Jaxpr is a data structure that can be evaluated like a program in a mini functional programming language, so Jaxprs are a useful intermediate representation for function transformation.

To get a first look at Jaxprs, consider the make_jaxpr transformation. make_jaxpr is essentially a “pretty-printing” transformation: it transforms a function into one that, given example arguments, produces a Jaxpr representation of its computation. make_jaxpr is useful for debugging and introspection. Let’s use it to look at how some example Jaxprs are structured.

def examine_jaxpr(closed_jaxpr):
  jaxpr = closed_jaxpr.jaxpr
  print("invars:", jaxpr.invars)
  print("outvars:", jaxpr.outvars)
  print("constvars:", jaxpr.constvars)
  for eqn in jaxpr.eqns:
    print("equation:", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)
  print()
  print("jaxpr:", jaxpr)

def foo(x):
  return x + 1
print("foo")
print("=====")
examine_jaxpr(jax.make_jaxpr(foo)(5))

print()

def bar(w, b, x):
  return jnp.dot(w, x) + b + jnp.ones(5), x
print("bar")
print("=====")
examine_jaxpr(jax.make_jaxpr(bar)(jnp.ones((5, 10)), jnp.ones(5), jnp.ones(10)))
foo
=====
invars: [Var(id=140015952089984):int32[]]
outvars: [Var(id=140015952089792):int32[]]
constvars: []
equation: [Var(id=140015952089984):int32[], 1] add [Var(id=140015952089792):int32[]] {}

jaxpr: { lambda ; a:i32[]. let b:i32[] = add a 1 in (b,) }

bar
=====
invars: [Var(id=140015952459328):float32[5,10], Var(id=140015952459392):float32[5], Var(id=140015952459456):float32[10]]
outvars: [Var(id=140015952459712):float32[5], Var(id=140015952459456):float32[10]]
constvars: []
equation: [Var(id=140015952459328):float32[5,10], Var(id=140015952459456):float32[10]] dot_general [Var(id=140015952459520):float32[5]] {'dimension_numbers': (((1,), (0,)), ((), ())), 'precision': None, 'preferred_element_type': dtype('float32')}
equation: [Var(id=140015952459520):float32[5], Var(id=140015952459392):float32[5]] add [Var(id=140015952459584):float32[5]] {}
equation: [1.0] broadcast_in_dim [Var(id=140015952459648):float32[5]] {'shape': (5,), 'broadcast_dimensions': ()}
equation: [Var(id=140015952459584):float32[5], Var(id=140015952459648):float32[5]] add [Var(id=140015952459712):float32[5]] {}

jaxpr: { lambda ; a:f32[5,10] b:f32[5] c:f32[10]. let
    d:f32[5] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] a c
    e:f32[5] = add d b
    f:f32[5] = broadcast_in_dim[broadcast_dimensions=() shape=(5,)] 1.0
    g:f32[5] = add e f
  in (g, c) }

  • jaxpr.invars - the invars of a Jaxpr are a list of the input variables to the Jaxpr, analogous to arguments in Python functions.

  • jaxpr.outvars - the outvars of a Jaxpr are the variables that are returned by the Jaxpr. A Jaxpr can have multiple outputs.

  • jaxpr.constvars - the constvars are a list of variables that are also inputs to the Jaxpr, but correspond to constants from the trace (we’ll go over these in more detail later).

  • jaxpr.eqns - a list of equations, which are essentially let-bindings. Each equation is a list of input variables, a list of output variables, and a primitive, which is used to evaluate inputs to produce outputs. Each equation also has params, a dictionary of parameters.

Altogether, a Jaxpr encapsulates a simple program that can be evaluated with inputs to produce an output. We’ll go over how exactly to do this later. The important thing to note now is that a Jaxpr is a data structure that can be manipulated and evaluated in whatever way we want.
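
For example, nothing stops us from walking the equations of a Jaxpr directly (a small illustration):

cj = jax.make_jaxpr(lambda x: jnp.exp(x) + 1.)(1.0)
# Each equation records its primitive together with its inputs and outputs.
print([eqn.primitive.name for eqn in cj.jaxpr.eqns])  # ['exp', 'add']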

Why are Jaxprs useful?#

Jaxprs are simple program representations that are easy to transform. Because JAX lets us stage out Jaxprs from Python functions, it gives us a way to transform numerical programs written in Python.

Your first interpreter: invert#

Let’s try to implement a simple function “inverter”, which takes in the output of the original function and returns the inputs that produced those outputs. For now, let’s focus on simple, unary functions which are composed of other invertible unary functions.

Goal:

def f(x):
  return jnp.exp(jnp.tanh(x))
f_inv = inverse(f)
assert jnp.allclose(f_inv(f(1.0)), 1.0)

The way we’ll implement this is by (1) tracing f into a Jaxpr, then (2) interpreting the Jaxpr backwards. While interpreting the Jaxpr backwards, for each equation we’ll look up the primitive’s inverse in a table and apply it.

1. Tracing a function#

Let’s use make_jaxpr to trace a function into a Jaxpr.

# Importing Jax functions useful for tracing/interpreting.
import numpy as np
from functools import wraps

from jax import core
from jax import lax
from jax._src.util import safe_map

jax.make_jaxpr returns a closed Jaxpr, which is a Jaxpr that has been bundled with the constants (literals) from the trace.

def f(x):
  return jnp.exp(jnp.tanh(x))

closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
print(closed_jaxpr.jaxpr)
print(closed_jaxpr.literals)
{ lambda ; a:f32[5]. let b:f32[5] = tanh a; c:f32[5] = exp b in (c,) }
[]
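
The literals list is empty here because f does not capture any array constants. A function that closes over an array produces a constvar and a corresponding entry in literals, for example (an illustrative sketch):

const = jnp.arange(3.)
closed_jaxpr2 = jax.make_jaxpr(lambda x: x + const)(jnp.ones(3))
print(closed_jaxpr2.jaxpr.constvars)  # one constvar feeding the add
print(closed_jaxpr2.literals)         # [Array([0., 1., 2.], dtype=float32)]
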
2. Evaluating a Jaxpr#

Before we write a custom Jaxpr interpreter, let’s first implement the “default” interpreter, eval_jaxpr, which evaluates the Jaxpr as-is, computing the same values that the original, un-transformed Python function would.

To do this, we first create an environment to store the values for each of the variables, and update the environment with each equation we evaluate in the Jaxpr.

def eval_jaxpr(jaxpr, consts, *args):
  # Mapping from variable -> value
  env = {}
  
  def read(var):
    # Literals are values baked into the Jaxpr
    if type(var) is core.Literal:
      return var.val
    return env[var]

  def write(var, val):
    env[var] = val

  # Bind args and consts to environment
  safe_map(write, jaxpr.invars, args)
  safe_map(write, jaxpr.constvars, consts)

  # Loop through equations and evaluate primitives using `bind`
  for eqn in jaxpr.eqns:
    # Read inputs to equation from environment
    invals = safe_map(read, eqn.invars)  
    # `bind` is how a primitive is called
    outvals = eqn.primitive.bind(*invals, **eqn.params)
    # Primitives may return multiple outputs or not
    if not eqn.primitive.multiple_results: 
      outvals = [outvals]
    # Write the results of the primitive into the environment
    safe_map(write, eqn.outvars, outvals) 
  # Read the final result of the Jaxpr from the environment
  return safe_map(read, jaxpr.outvars) 
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, jnp.ones(5))
[Array([2.1416876, 2.1416876, 2.1416876, 2.1416876, 2.1416876], dtype=float32)]

Notice that eval_jaxpr will always return a flat list even if the original function does not.
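
For instance, a function returning a tuple of two arrays comes back as a flat list of two values (reusing the eval_jaxpr defined above):

def g(x):
  return jnp.sin(x), jnp.cos(x)

cj = jax.make_jaxpr(g)(1.0)
print(eval_jaxpr(cj.jaxpr, cj.literals, 1.0))
# [Array(0.84147096, ...), Array(0.5403023, ...)]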

Furthermore, this interpreter does not handle higher-order primitives (like jit and pmap), which we will not cover in this guide. You can refer to core.eval_jaxpr to see the edge cases that this interpreter does not cover.

Custom inverse Jaxpr interpreter#

An inverse interpreter doesn’t look too different from eval_jaxpr. We’ll first set up the registry which will map primitives to their inverses. We’ll then write a custom interpreter that looks up primitives in the registry.

It turns out that this interpreter will also look similar to the “transpose” interpreter used in JAX’s reverse-mode automatic differentiation.

inverse_registry = {}

We’ll now register inverses for some of the primitives. By convention, primitives in JAX end in _p and many of the popular ones live in lax.

inverse_registry[lax.exp_p] = jnp.log
inverse_registry[lax.tanh_p] = jnp.arctanh

inverse will first trace the function, then custom-interpret the Jaxpr. Let’s set up a simple skeleton.

def inverse(fun):
  @wraps(fun)
  def wrapped(*args, **kwargs):
    # Since we assume unary functions, we won't worry about flattening and
    # unflattening arguments.
    closed_jaxpr = jax.make_jaxpr(fun)(*args, **kwargs)
    out = inverse_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
    return out[0]
  return wrapped

Now we just need to define inverse_jaxpr, which will walk through the Jaxpr backward and invert primitives when it can.

def inverse_jaxpr(jaxpr, consts, *args):
  env = {}
  
  def read(var):
    if type(var) is core.Literal:
      return var.val
    return env[var]

  def write(var, val):
    env[var] = val
  # Args now correspond to Jaxpr outvars
  safe_map(write, jaxpr.outvars, args)
  safe_map(write, jaxpr.constvars, consts)

  # Looping backward
  for eqn in jaxpr.eqns[::-1]:
    #  outvars are now invars 
    invals = safe_map(read, eqn.outvars)
    if eqn.primitive not in inverse_registry:
      raise NotImplementedError(
          f"{eqn.primitive} does not have registered inverse.")
    # Assuming a unary function 
    outval = inverse_registry[eqn.primitive](*invals)
    safe_map(write, eqn.invars, [outval])
  return safe_map(read, jaxpr.invars)

That’s it!

def f(x):
  return jnp.exp(jnp.tanh(x))

f_inv = inverse(f)
assert jnp.allclose(f_inv(f(1.0)), 1.0)

Importantly, you can trace through a Jaxpr interpreter.

jax.make_jaxpr(inverse(f))(f(1.))
{ lambda ; a:f32[]. let b:f32[] = log a; c:f32[] = atanh b in (c,) }

That’s all it takes to add a new transformation to a system, and you get composition with all the others for free! For example, we can use jit, vmap, and grad with inverse!

jit(vmap(grad(inverse(f))))((jnp.arange(5) + 1.) / 5.)
Array([-3.1440797, 15.584931 ,  2.2551253,  1.3155028,  1.       ],      dtype=float32, weak_type=True)

Exercises for the reader#

  • Handle primitives with multiple arguments where inputs are partially known, for example lax.add_p, lax.mul_p (one possible starting point is sketched after this list).

  • Handle xla_call and xla_pmap primitives, which will not work with both eval_jaxpr and inverse_jaxpr as written.
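As a possible starting point for the first exercise (a sketch only, with our own helper names; it is not a complete solution): when a binary primitive has exactly one Literal operand, the remaining operand can be recovered from the output value.

# Sketch: invert a binary equation whose other operand is a known Literal.
partial_inverse_registry = {
    lax.add_p: lambda out, known: out - known,   # out = x + known  =>  x = out - known
    lax.mul_p: lambda out, known: out / known,   # out = x * known  =>  x = out / known
}

def invert_partially_known(eqn, outval):
  """Returns (unknown_invar, recovered_value) for an equation with one unknown input."""
  known = [v.val for v in eqn.invars if type(v) is core.Literal]
  unknown = [v for v in eqn.invars if type(v) is not core.Literal]
  if len(unknown) == 1 and eqn.primitive in partial_inverse_registry:
    return unknown[0], partial_inverse_registry[eqn.primitive](outval, known[0])
  raise NotImplementedError(f"{eqn.primitive} cannot be inverted here.")

Wiring this helper into the backward loop of inverse_jaxpr, in place of the unary lookup, is the remainder of the exercise.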

Custom operations for GPUs with C++ and CUDA#

JAX ships with a large number of built-in operations, but users occasionally run into a situation where they need a new operation that is not supported by JAX.

To accommodate such scenarios, JAX allows users to define custom operations, and this tutorial explains how to define one for GPUs and use it in single-GPU and multi-GPU environments.

This tutorial contains information from Extending JAX with custom C++ and CUDA code and assumes that you are familiar with JAX primitives.

RMS normalization#

For this tutorial, we are going to add the RMS normalization as a custom operation in JAX. Note that the RMS normalization can be expressed with jax.numpy directly. However, we are using it as an example to show the process of creating a custom operation for GPUs. The CUDA code in gpu_ops/rms_norm_kernels.cu for this operation has been borrowed from Apex and adapted to eliminate any dependency on PyTorch.
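For reference, a pure jax.numpy version of the operation might look like the sketch below (our own illustration of the usual RMS normalization formula; it normalizes over the trailing axes spanned by weight and treats the leading axes as the batch, mirroring the kernel used in this tutorial):

import jax.numpy as jnp

def rms_norm_ref(x, weight, eps=1e-05):
    # Normalize over the trailing axes spanned by `weight`; leading axes are the batch.
    axes = tuple(range(x.ndim - weight.ndim, x.ndim))
    mean_sq = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=axes, keepdims=True)
    return (x / jnp.sqrt(mean_sq + eps)).astype(x.dtype) * weight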

High-level steps#

This tutorial shows how to write both a custom operation and its gradient.

In C: You need to follow these steps in C for each new JAX primitive:

  • Have CUDA kernel(s).

  • Create a C function that dispatches the CUDA kernel that will be called by XLA.

  • Create a descriptor to convey information needed for the computation.

    • The types, the shapes and other attributes.

  • Bind C functions to Python

    • To create the descriptor and to call the primitive during execution.

In Python: You need to follow these steps in Python:

  • Define a new JAX primitive (instruction/operation)

  • Write Python functions to build the graph nodes with the primitive.

  • Define its abstract evaluation.

  • Define its lowering to MLIR.

  • [Optional] Define the gradient.

  • [Optional] Use custom_partitioning or shard_map functions for fast multi-GPU.

C code#

See gpu_ops code listing for a complete code listing of C++ and CUDA files. gpu_ops/rms_norm_kernels.cu defines the following functions, which are declared with the XLA custom function signature. These functions are responsible for launching RMS normalization kernels with the given buffers on the specified stream.

namespace gpu_ops {
    
void rms_forward_affine_mixed_dtypes(cudaStream_t stream, void **buffers,
                                     const char *opaque,
                                     std::size_t opaque_len);

void rms_backward_affine(cudaStream_t stream, void **buffers,
                         const char *opaque, std::size_t opaque_len);

} // namespace gpu_ops
  • stream is the CUDA stream to be used to execute any kernel on the GPU.

  • buffers has all pointers to input buffers followed by all pointers to output buffers.

  • opaque is a buffer for any extra information that is being passed to the custom functions and opaque_len is the length of opaque.

For this tutorial, an RMSNormDescriptor object will be passed to these functions as opaque.

namespace gpu_ops {

enum ElementType { BF16, F16, F32, F64 };

struct RMSNormDescriptor {
  int n1;
  int n2;
  double eps;
  ElementType x_type;
  ElementType w_type;
  int part_grad_size;
};

} // namespace gpu_ops

Now, we need to expose these functions as well as ElementType and RMSNormDescriptor as a Python module, gpu_ops, through pybind11.

pybind11::dict RMSNormRegistrations() {
  pybind11::dict dict;
  dict["rms_forward_affine_mixed_dtype"] =
      gpu_ops::EncapsulateFunction(gpu_ops::rms_forward_affine_mixed_dtypes);
  dict["rms_backward_affine"] =
      gpu_ops::EncapsulateFunction(gpu_ops::rms_backward_affine);
  return dict;
}

PYBIND11_MODULE(gpu_ops, m) {
  m.def("get_rms_norm_registrations", &RMSNormRegistrations);
  m.def("create_rms_norm_descriptor",
        [](int n1, int n2, double eps, gpu_ops::ElementType x_type,
           gpu_ops::ElementType w_type, int part_grad_size) {
          return gpu_ops::PackDescriptor(gpu_ops::RMSNormDescriptor{
              n1, n2, eps, x_type, w_type, part_grad_size});
        });

  pybind11::enum_<gpu_ops::ElementType>(m, "ElementType")
      .value("BF16", gpu_ops::ElementType::BF16)
      .value("F16", gpu_ops::ElementType::F16)
      .value("F32", gpu_ops::ElementType::F32)
      .value("F64", gpu_ops::ElementType::F64);

}

Build gpu_ops extension module#

We build the gpu_ops Python extension module with the aforementioned code. (See gpu_ops code listing for a complete code listing of C++ and CUDA files.)

python -m pip install pybind11==2.10.1
mkdir -p build
pybind_include_path=$(python -c "import pybind11; print(pybind11.get_include())")
python_executable=$(python -c 'import sys; print(sys.executable)')


nvcc --threads 4 -Xcompiler -Wall -ldl --expt-relaxed-constexpr -O3 -DNDEBUG -Xcompiler -O3 --generate-code=arch=compute_70,code=[compute_70,sm_70] --generate-code=arch=compute_75,code=[compute_75,sm_75] --generate-code=arch=compute_80,code=[compute_80,sm_80] --generate-code=arch=compute_86,code=[compute_86,sm_86] -Xcompiler=-fPIC -Xcompiler=-fvisibility=hidden -x cu -c gpu_ops/rms_norm_kernels.cu -o build/rms_norm_kernels.cu.o
c++ -I/usr/local/cuda/include -I$pybind_include_path $(${python_executable}-config --cflags) -O3 -DNDEBUG -O3 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o build/gpu_ops.cpp.o -c gpu_ops/gpu_ops.cpp
c++ -fPIC -O3 -DNDEBUG -O3 -flto -shared  -o build/gpu_ops$(${python_executable}-config --extension-suffix) build/gpu_ops.cpp.o build/rms_norm_kernels.cu.o -L/usr/local/cuda/lib64  -lcudadevrt -lcudart_static -lrt -lpthread -ldl
strip build/gpu_ops$(${python_executable}-config --extension-suffix)

Add RMS normalization to JAX as custom call#

gpu_ops is just a Python extension module and we need more work to plug it into JAX.

Create primitives#

We first create primitives, _rms_norm_fwd_p and _rms_norm_bwd_p, which the custom functions can be mapped to. We set the multiple_results attribute to True for these operations, which means that each operation produces multiple outputs as a tuple; when it is set to False, an operation produces a single output without a tuple. For more details, see How JAX primitives work.

from functools import partial

import jax
import jax.numpy as jnp
import jax._src.test_util as jtu
from build import gpu_ops
from jax import core, dtypes
from jax.interpreters import xla
from jax.lib import xla_client


# Create _rms_norm_fwd_p for forward operation.
_rms_norm_fwd_p = core.Primitive("rms_norm_fwd")
_rms_norm_fwd_p.multiple_results = True
_rms_norm_fwd_p.def_impl(partial(xla.apply_primitive, _rms_norm_fwd_p))


def rms_norm_fwd(x, weight, eps=1e-05):
    output, invvar = _rms_norm_fwd_p.bind(x, weight, eps=eps)
    return output


# Create _rms_norm_bwd_p for backward operation.
_rms_norm_bwd_p = core.Primitive("rms_norm_bwd")
_rms_norm_bwd_p.multiple_results = True
_rms_norm_bwd_p.def_impl(partial(xla.apply_primitive, _rms_norm_bwd_p))


def rms_norm_bwd(g, invvar, x, weight, eps):
    grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(
        g, invvar, x, weight, eps=eps
    )
    return grad_input, grad_weight
Lowering to MLIR custom call#

To map the custom functions to the new primitives, _rms_norm_fwd_p and _rms_norm_bwd_p, we need to:

  • Register custom functions as custom call targets with xla_client.register_custom_call_target, and

  • Register lowering functions that lower the primitives to MLIR custom calls with the registered custom call targets.

The functions _rms_norm_fwd_cuda_lowering and _rms_norm_bwd_cuda_lowering below lower the primitives to MLIR custom call operations with the custom targets from gpu_ops. These functions are registered with jax.interpreters.mlir.register_lowering.

Note that an RMSNormDescriptor object is created in the lowering function, and passed to the custom call as opaque.

from functools import reduce

from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jaxlib.hlo_helpers import custom_call


# Register functions defined in gpu_ops as custom call target for GPUs
for _name, _value in gpu_ops.get_rms_norm_registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="gpu")


def element_type_to_descriptor_type_mapping(element_type):
    _element_type_to_descriptor_type_mapping = {
        ir.BF16Type.get(): gpu_ops.ElementType.BF16,
        ir.F16Type.get(): gpu_ops.ElementType.F16,
        ir.F32Type.get(): gpu_ops.ElementType.F32,
        ir.F64Type.get(): gpu_ops.ElementType.F64,
    }
    return _element_type_to_descriptor_type_mapping.get(element_type)


def default_layouts(*shapes):
    return [range(len(shape) - 1, -1, -1) for shape in shapes]


def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps):
    x_type = ir.RankedTensorType(x.type)
    x_shape = x_type.shape
    w_type = ir.RankedTensorType(weight.type)
    w_shape = w_type.shape
    iv_element_type = (
        ir.F32Type.get()
        if x_type.element_type in [ir.F16Type.get(), ir.BF16Type.get()]
        else x_type.element_type
    )

    n2 = reduce(lambda x, y: x * y, w_shape)
    n1 = reduce(lambda x, y: x * y, x_shape) // n2

    opaque = gpu_ops.create_rms_norm_descriptor(
        n1,
        n2,
        eps,
        element_type_to_descriptor_type_mapping(x_type.element_type),
        element_type_to_descriptor_type_mapping(w_type.element_type),
        0,  # unused
    )
    out = custom_call(
        b"rms_forward_affine_mixed_dtype",
        result_types=[
            ir.RankedTensorType.get(x_shape, w_type.element_type),
            ir.RankedTensorType.get((n1,), iv_element_type),
        ],
        operands=[x, weight],
        backend_config=opaque,
        operand_layouts=default_layouts(x_shape, w_shape),
        result_layouts=default_layouts(x_shape, (n1,)),
    ).results
    return out


mlir.register_lowering(
    _rms_norm_fwd_p,
    _rms_norm_fwd_cuda_lowering,
    platform="gpu",
)


def _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps):
    x_type = ir.RankedTensorType(x.type)
    x_shape = x_type.shape
    w_type = ir.RankedTensorType(weight.type)
    w_shape = w_type.shape
    iv_type = ir.RankedTensorType(invvar.type)

    n2 = reduce(lambda x, y: x * y, w_shape)
    n1 = reduce(lambda x, y: x * y, x_shape) // n2

    part_grad_shape = ctx.avals_out[-1].shape

    opaque = gpu_ops.create_rms_norm_descriptor(
        n1,
        n2,
        eps,
        element_type_to_descriptor_type_mapping(x_type.element_type),
        element_type_to_descriptor_type_mapping(w_type.element_type),
        part_grad_shape[0],
    )
    out = custom_call(
        b"rms_backward_affine",
        result_types=[
            ir.RankedTensorType.get(x_shape, x_type.element_type),
            ir.RankedTensorType.get(w_shape, w_type.element_type),
            ir.RankedTensorType.get(part_grad_shape, iv_type.element_type),
        ],
        operands=[grad_output, invvar, x, weight],
        backend_config=opaque,
        operand_layouts=default_layouts(x_shape, (n1,), x_shape, w_shape),
        result_layouts=default_layouts(x_shape, w_shape, part_grad_shape),
    ).results
    return out


mlir.register_lowering(
    _rms_norm_bwd_p,
    _rms_norm_bwd_cuda_lowering,
    platform="gpu",
)

Let’s test it#

per_core_batch_size=4
seq_len=512
emb_dim=512
x = jax.random.normal(
    jax.random.PRNGKey(0),
    shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim),
    dtype=jnp.bfloat16,
)
norm_shape = x.shape[-2:]
weight = jnp.ones(norm_shape, dtype=jnp.bfloat16)
Test forward function#
out = rms_norm_fwd(x, weight)
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In [5], line 1
----> 1 out = rms_norm_fwd(x, weight)

...

NotImplementedError: Abstract evaluation for 'rms_norm_fwd' not implemented

Abstract evaluation#

The test above failed with NotImplementedError: Abstract evaluation for 'rms_norm_fwd' not implemented. Why did the test fail? What does it mean?

As part of the execution, JAX performs abstract evaluation. As JAX has no knowledge about the new primitives, it doesn’t know how to compute the output shapes and output data types, and thus can’t evaluate these operations abstractly.

We need to provide a function for abstract evaluation of each primitive. These abstract evaluation functions compute the shape and the data type of the outputs, but don’t compute actual values for the operations.

These functions are passed to the .def_abstract_eval method to be registered with the corresponding primitives.

See How JAX primitives work for more information on abstract evaluation.

from functools import reduce
from operator import mul

from jax.core import ShapedArray


def _rms_norm_fwd_abstract(x, weight, eps):
    w_dtype = dtypes.canonicalize_dtype(weight.dtype)
    iv_dtype = dtypes.canonicalize_dtype(x.dtype)
    if iv_dtype in [jnp.float16, jnp.bfloat16]:
        iv_dtype = jnp.float32
    n2 = reduce(mul, weight.shape)
    n1 = reduce(mul, x.shape) // n2
    return (
        ShapedArray(x.shape, w_dtype, named_shape=x.named_shape),  # output
        ShapedArray((n1,), iv_dtype, named_shape=x.named_shape),  # invvar
    )


_rms_norm_fwd_p.def_abstract_eval(_rms_norm_fwd_abstract)


def _rms_norm_bwd_abstract(grad_output, invvar, x, weight, eps):
    iv_dtype = dtypes.canonicalize_dtype(invvar.dtype)
    w_dtype = dtypes.canonicalize_dtype(weight.dtype)
    x_dtype = dtypes.canonicalize_dtype(x.dtype)
    n2 = reduce(lambda x, y: x * y, weight.shape)
    n1 = reduce(lambda x, y: x * y, x.shape) // n2
    part_grad_shape = (16, n2)
    assert dtypes.canonicalize_dtype(grad_output.dtype) == w_dtype
    assert grad_output.shape == x.shape
    assert invvar.shape == (n1,)
    assert iv_dtype == (
        jnp.float32 if x_dtype in [jnp.float16, jnp.bfloat16] else x_dtype
    )
    assert grad_output.named_shape == x.named_shape
    weight_named_shape = (
        weight.named_shape if weight.named_shape else x.named_shape
    )
    return (
        ShapedArray(
            x.shape, x_dtype, named_shape=x.named_shape
        ),  # grad input
        ShapedArray(
            weight.shape, w_dtype, named_shape=weight_named_shape
        ),  # grad weight
        ShapedArray(
            part_grad_shape, iv_dtype, named_shape=weight_named_shape
        ),  # part grad
    )


_rms_norm_bwd_p.def_abstract_eval(_rms_norm_bwd_abstract)

Let’s test it again#

Test the forward function#
out = rms_norm_fwd(x, weight)
Test the backward function#

Now let’s test the backward operation using jax.grad and jtu.check_grads.

def loss(x, weight):
    predictions = rms_norm_fwd(x, weight)
    return -jnp.mean(predictions**2)


loss_grad = jax.grad(loss)
out = loss_grad(x, weight)
jtu.check_grads(loss, (x, weight), modes=["rev"], order=1)
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In [8], line 7
      3     return -jnp.mean(predictions**2)
      6 loss_grad = jax.grad(loss)
----> 7 out = loss_grad(x, weight)

...

NotImplementedError: Differentiation rule for 'rms_norm_fwd' not implemented

Differentiation rule#

The backward operation failed with the error NotImplementedError: Differentiation rule for 'rms_norm_fwd' not implemented. It means that, although we have defined rms_norm_fwd and rms_norm_bwd, JAX doesn’t know the relationship between them.

We can teach JAX that rms_norm_bwd is the backward operation for rms_norm_fwd, using jax.custom_vjp and its convention. As the first step, we need to refine the definition of rms_norm_fwd and rms_norm_bwd.

# rms_norm_fwd was previously defined as
#
# def rms_norm_fwd(x, weight, eps=1e-05):
#     output, invvar = _rms_norm_fwd_p.bind(x, weight, eps=eps)
#     return output
#
def rms_norm_fwd(x, weight, eps=1e-05):
    output, invvar = _rms_norm_fwd_p.bind(x, weight, eps=eps)
    return output, (invvar, x, weight)


# rms_norm_bwd was previously defined as
#
# def rms_norm_bwd(g, invvar, x, weight, eps):
#     grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(
#         g, invvar, x, weight, eps=eps
#     )
#     return grad_input, grad_weight
#
def rms_norm_bwd(eps, res, g):
    invvar, x, weight = res
    grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(
        g, invvar, x, weight, eps=eps
    )
    return grad_input, grad_weight

rms_norm_fwd now returns an extra output (invvar, x, weight) for the residual data and rms_norm_bwd takes eps, res, and g as the parameters.

Once the relationship between rms_norm_fwd and rms_norm_bwd is established through jax.custom_vjp, JAX will ensure that the residual data from rms_norm_fwd is passed to rms_norm_bwd as res for backward operation. For non-differentiable parameters such as eps, JAX ensures that they are passed to the backward operation before the residual data. That’s why eps precedes res in the parameter list of rms_norm_bwd.

Now that rms_norm_fwd returns the residual data, which is not needed for a simple RMS normalization operation, we define a wrapper around it, rms_norm. It simply calls rms_norm_fwd and returns only the output. Note that rms_norm is annotated with @partial(jax.custom_vjp, nondiff_argnums=(2,)) and we are passing rms_norm_fwd and rms_norm_bwd to rms_norm.defvjp. This teaches JAX that, when rms_norm is differentiated, rms_norm_fwd is to be used for the forward operation and rms_norm_bwd for the backward operation.

See Custom derivative rules for JAX-transformable Python functions for more information on jax.custom_vjp.

@partial(jax.custom_vjp, nondiff_argnums=(2,))
def rms_norm(x, weight, eps=1e-05):
    output, _ = rms_norm_fwd(x, weight, eps=eps)
    return output


rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd)

With the refinement we have made, the backward operation test works with a modification: loss now calls rms_norm instead of rms_norm_fwd.

def loss(x, weight):
    predictions = rms_norm(x, weight)
    return -jnp.mean(predictions**2)


loss_grad = jax.grad(loss)
out = loss_grad(x, weight)
jtu.check_grads(loss, (x, weight), modes=["rev"], order=1)

Let’s test it on multiple devices#

We are using jax.experimental.pjit.pjit for parallel execution on multiple devices, and we produce reference values with sequential execution on a single device.

Test the forward function#

Let’s first test the forward operation on multiple devices. We are creating a simple 1D mesh and sharding x on all devices.

from jax.sharding import Mesh, PartitionSpec
from jax.experimental.pjit import pjit


mesh = Mesh(jax.local_devices(), ("x",))
ref = rms_norm(x, weight)
pjitted = pjit(
    rms_norm,
    # Shard x by batch dimension and replicate weight on all devices.
    in_shardings=(PartitionSpec("x", None, None), PartitionSpec(None, None)),
    # Shard the output by batch dimension.
    out_shardings=PartitionSpec("x", None, None),
)

with mesh:
    print(pjitted.lower(x, weight).compile().runtime_executable().hlo_modules()[0].to_string())
    out = pjitted(x, weight)

jnp.allclose(ref, out, atol=1e-5, rtol=1e-5)
HloModule pjit_rms_norm, entry_computation_layout={(bf16[4,512,512]{2,1,0},bf16[512,512]{1,0})->bf16[4,512,512]{2,1,0}}

%fused_computation (param_1: bf16[32,512,512], param_1.3: u32[]) -> bf16[4,512,512] {
  %param_1 = bf16[32,512,512]{2,1,0} parameter(0)
  %param_1.3 = u32[] parameter(1)
  %convert.2 = s32[] convert(u32[] %param_1.3), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
  %constant_9 = s32[] constant(4), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
  %multiply.3 = s32[] multiply(s32[] %convert.2, s32[] %constant_9), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
  %constant_8 = s32[] constant(0), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
  ROOT %dynamic-slice.2 = bf16[4,512,512]{2,1,0} dynamic-slice(bf16[32,512,512]{2,1,0} %param_1, s32[] %multiply.3, s32[] %constant_8, s32[] %constant_8), dynamic_slice_sizes={4,512,512}, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
}

ENTRY %main.7_spmd (param: bf16[4,512,512], param.1: bf16[512,512]) -> bf16[4,512,512] {
  %param = bf16[4,512,512]{2,1,0} parameter(0), sharding={devices=[8,1,1]0,1,2,3,4,5,6,7}
  %all-gather = bf16[32,512,512]{2,1,0} all-gather(bf16[4,512,512]{2,1,0} %param), channel_id=1, replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={0}, use_global_device_ids=true, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
  %param.1 = bf16[512,512]{1,0} parameter(1), sharding={replicated}
  %custom-call.0 = (bf16[32,512,512]{2,1,0}, f32[32]{0}) custom-call(bf16[32,512,512]{2,1,0} %all-gather, bf16[512,512]{1,0} %param.1), custom_call_target="rms_forward_affine_mixed_dtype", operand_layout_constraints={bf16[32,512,512]{2,1,0}, bf16[512,512]{1,0}}, api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}, backend_config=" \000\000\000\000\000\004\000\361h\343\210\265\370\344>\000\000\000\000\000\000\000\000\000\000\000\000\255\177\000\000"
  %get-tuple-element = bf16[32,512,512]{2,1,0} get-tuple-element((bf16[32,512,512]{2,1,0}, f32[32]{0}) %custom-call.0), index=0, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
  %partition-id = u32[] partition-id(), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
  ROOT %fusion = bf16[4,512,512]{2,1,0} fusion(bf16[32,512,512]{2,1,0} %get-tuple-element, u32[] %partition-id), kind=kLoop, calls=%fused_computation, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
}
True

The values have been computed correctly for the forward operation; however, the generated HLO modules show an all-gather operation to replicate x on all devices, incurring a large communication overhead.

As XLA does not have enough knowledge about the custom functions to shard input tensors, it decides to replicate them to produce correct values before making the custom call.

To avoid this replication, we can:

  • Use custom_partitioning to make the operation behave like all native JAX operations (but more complicated)

  • Use manual sharding

This example demonstrates the use of custom_partitioning.

Shard the forward function with custom_partitioning#

We first create a helper function to help with all the JAX/XLA callback registration required.

def register_primitive(cls):
    """
    register jax primitive

    The order of calls. Each operation is composed of two primitives: Inner and Outer.

    Inner, only the basic to wrap the custom_call itself.
    - impl to XLA custom_call in C.
    - abstract to know the static shapes
    - lower to StableHLO XLA custom_call.
    Outer, mostly all the rest:
    - impl: Bind to the inner primitive. Not used for real computation, but only for tracing. So we only need to bind.
    - abstract: same
    - lower to StableHLO custom_p. (XLA will call the python callback from it)
    - custom_p
    - vmap: could be added here.
    VJP is based on Outer, but not handled in this function.
    """

    def name_of_wrapper_p():
        return cls.name + "_wrapper"

    inner_p = core.Primitive(cls.name)
    dispatch.prim_requires_devices_during_lowering.add(inner_p)
    inner_p.multiple_results = cls.multiple_results
    inner_p.def_impl(partial(xla.apply_primitive, inner_p))
    inner_p.def_abstract_eval(cls.abstract)
    mlir.register_lowering(inner_p, cls.lowering, platform='cuda')
    cls.inner_primitive = inner_p

    outer_p = core.Primitive(name_of_wrapper_p())
    dispatch.prim_requires_devices_during_lowering.add(outer_p)
    outer_p.multiple_results = cls.multiple_results
    outer_p.def_impl(cls.impl)
    outer_p.def_abstract_eval(cls.abstract)
    batching.primitive_batchers[outer_p] = cls.batcher
    outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)
    outer_p_lower.def_partition(infer_sharding_from_operands=cls.infer_sharding_from_operands,
                                partition=cls.partition)
    mlir.register_lowering(outer_p,
                           mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results))
    cls.outer_primitive = outer_p
...

We define two JAX primitives: an inner primitive that maps to the real kernel we want to wrap in JAX, and an outer primitive that will be used with the custom_partitioning registration and for the gradient. (If you implement the interface to support vmap, it will also be on the outer primitive.)

JAX custom_partitioning implementations are callbacks from XLA to Python during XLA’s sharding logic. XLA sharding goes through two phases: a sharding propagation phase and a partition phase. The propagation phase is when XLA plans the shardings to be created, and the partition phase is when it creates the sharded graph. For XLA to be able to shard our custom operations, it needs us to define two extra functions: infer_sharding_from_operands() and partition(). They are used in the first and second phase respectively.

The infer_sharding_from_operands() function must do what its name says: infer the output sharding from the input sharding.

The partition() function will do a few things:

  • tell which input shardings will be expected; XLA will reshard the inputs if needed.

  • tell the final version of the output shardings.

  • give a function that will create the new instruction from the sharded inputs.

See the code comments for more explanation:

class RmsNormFwdClass:
    name = "rms_forward_affine_mixed_dtype"
    multiple_results = True
    impl_static_args = (2,)    # eps
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def infer_sharding_from_operands(eps : float, mesh : jax.sharding.Mesh,
                                     arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
                                     result_infos : Tuple[jax._src.core.ShapedArray]):
        del eps, result_infos  # Not needed for this example.
        x_info, weight_info = arg_infos
        assert len(x_info.shape) == 3
        assert len(weight_info.shape) == 2
        # partition() will force all dims of all inputs to be replicated except the
        # first dim of x, which will be kept as is.
        # This is because the implementation can only be sharded on the batch dimensions.

        x_spec = arg_infos[0].sharding.spec
        # None means that we replicate on that dimension.
        output_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0], None, None))
        invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0]))
        return (output_sharding, invvar_sharding)

    @staticmethod
    def partition(eps : float, mesh : jax.sharding.Mesh,
                  arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
                  result_infos : Tuple[jax._src.api.ShapeDtypeStruct]):
        del result_infos  # Not needed for this example.
        x_info, weight_info = arg_infos
        assert len(x_info.shape) == 3
        assert len(weight_info.shape) == 2
        x_spec = arg_infos[0].sharding.spec
        # We only support sharding on the batch dimensions.
        # Replicate all other dimensions by setting them to None.
        arg_shardings = (NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)),
                         NamedSharding(mesh, PartitionSpec(None, None)))
        invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0]))
        output_shardings = (arg_shardings[0], invvar_sharding)
        # The sharded impl only accepts positional arguments,
        # and they should be JAX-traceable variables.
        impl = partial(RmsNormFwdClass.impl, eps=eps)

        return mesh, impl, output_shardings, arg_shardings
register_primitive(RmsNormFwdClass)

Next, we define the primitive for the backward pass of RMSNorm.

Shard the backward function with custom_partitioning#
class RmsNormBwdClass:
    name = "rms_norm_bwd"
    multiple_results = True
    impl_static_args = (4,)    # eps
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def infer_sharding_from_operands(eps : float, mesh : jax.sharding.Mesh,
                                     arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
                                     result_infos : Tuple[jax._src.core.ShapedArray]):
        del eps, result_infos  # Not needed for this example.
        g_info, invvar_info, x_info, weight_info = arg_infos
        assert len(g_info.shape) == 3
        assert len(invvar_info.shape) == 1
        assert len(x_info.shape) == 3
        assert len(weight_info.shape) == 2
        # partition() will force all dims to be replicated except the batch dimension.
        x_spec = x_info.sharding.spec
        output_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0], None, None))
        invvar_sharding = NamedSharding(mesh, PartitionSpec(None, None))
        return (output_sharding, invvar_sharding, output_sharding, )

    @staticmethod
    def partition(eps : float, mesh : jax.sharding.Mesh,
                  arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
                  result_infos : Tuple[jax._src.api.ShapeDtypeStruct]):
        del result_infos  # Not needed for this example.
        g_info, invvar_info, x_info, weight_info = arg_infos
        assert len(g_info.shape) == 3
        assert len(invvar_info.shape) == 1
        assert len(x_info.shape) == 3
        assert len(weight_info.shape) == 2

        # We only support sharding on the batch dimensions.
        # Replicate all other dimensions by setting them to None.
        # Also force g, x and invvar to have the same batch sharding/replication.
        x_spec = x_info.sharding.spec
        arg_shardings = (NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)),
                         NamedSharding(mesh, PartitionSpec(x_spec[0],)),
                         NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)),
                         NamedSharding(mesh, PartitionSpec(None, None)))

        output_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0], None, None))
        invvar_sharding = NamedSharding(mesh, PartitionSpec(None, None))
        output_shardings = (output_sharding, invvar_sharding, invvar_sharding)


        # The sharded impl only accepts positional arguments,
        # and they should be JAX-traceable variables.
        def impl(g, invvar, x, weight):
            grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(
                g, invvar, x, weight, eps=eps
            )
            # We need to sum the weight gradient from all partitions.
            global_weight = grad_weight
            if x_spec[0]:
                global_weight = jax.lax.psum(grad_weight, x_spec[0])
            return grad_input, global_weight, part_grad
        return mesh, impl, output_shardings, arg_shardings
register_primitive(RmsNormBwdClass)

Plumbing to establish the forward and backward primitives with a custom_vjp rule as before:

@partial(jax.custom_vjp, nondiff_argnums=(2,))
def custom_p_rms_norm(x, weight, eps=1e-05):
    output, _ = custom_p_rms_norm_fwd(x, weight, eps=eps)
    return output
  
def custom_p_rms_norm_fwd(x, weight, eps=1e-05):
    output, invvar = RmsNormFwdClass.outer_primitive.bind(x, weight, eps=eps)
    return output, (invvar, x, weight)

def custom_p_rms_norm_bwd(eps, res, g):
    invvar, x, weight = res
    grad_input, grad_weight, part_grad = RmsNormBwdClass.outer_primitive.bind(
        g, invvar, x, weight, eps=eps)
    return grad_input, grad_weight

custom_p_rms_norm.defvjp(custom_p_rms_norm_fwd, custom_p_rms_norm_bwd)

With that we have completely defined our custom RMS norm primitive with custom_partitioning. To check for correctness we define the following loss functions: ref_loss is the reference value to compare against, while custom_p_loss uses our new primitive that implements custom_partitioning.

def ref_loss(x, weight):
    predictions = rms_norm(x, weight)
    return -jnp.mean(predictions**2)


ref_out = jax.grad(ref_loss, argnums=(0, 1))(x, weight)

def custom_p_loss(x, weight):
    predictions = custom_p_rms_norm(x, weight)
    return -jnp.mean(predictions**2)

Check for correctness#

with Mesh(jax.local_devices(), ("x",)):
    def run_and_verify(loss):
        pjitted = pjit(
            jax.grad(loss, argnums=(0, 1)),
            # Shard x by batch dimension and replicate weight on all devices.
            in_shardings=(
                PartitionSpec("x", None, None),
                PartitionSpec(None, None),
            ),
            # Shard the output by batch dimension and replicate weight grad on all devices.
            out_shardings=(
                PartitionSpec("x", None, None),
                PartitionSpec(None, None),
            ),
        )
        hlo = pjitted.lower(x, weight).compile().as_text()
        out = pjitted(x, weight)
        print(hlo)
        assert "all-reduce-done" in hlo, "The gradient will produce wrong value!"
        if "all-gather-start" in hlo:
            print("NOT OPTIMIZED, ALL_GATHER in the graph!")
        return out

    custom_p_out = run_and_verify(custom_p_loss)


for r, o in zip(ref_out, custom_p_out):
    print(jnp.allclose(r, o, atol=1e-6, rtol=1e-6))
HloModule pjit_custom_p_loss, is_scheduled=true, entry_computation_layout={(f16[4,512,512]{2,1,0}, f16[512,512]{1,0})->(f16[4,512,512]{2,1,0}, f16[512,512]{1,0})}, allow_spmd_sharding_propagation_to_parameters={false,false}, allow_spmd_sharding_propagation_to_output={false,false}, num_partitions=4, frontend_attributes={fingerprint_before_lhs="d7b9bc40de002332dd665ff2ab537b76"}

%fused_multiply (param_0: f16[4,512,512]) -> f16[4,512,512] {
  %param_0 = f16[4,512,512]{2,1,0} parameter(0)
  %constant_4_1 = f16[] constant(-4.7684e-07)
  %broadcast.8.1 = f16[4,512,512]{2,1,0} broadcast(f16[] %constant_4_1), dimensions={}, metadata={op_name="pjit(custom_p_loss)/jit(main)/mul" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=484}
  ROOT %multiply.5.1 = f16[4,512,512]{2,1,0} multiply(f16[4,512,512]{2,1,0} %param_0, f16[4,512,512]{2,1,0} %broadcast.8.1), metadata={op_name="pjit(custom_p_loss)/jit(main)/mul" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=484}
}

%region_0.9._custom_call_lowering_rule (Arg_0.10.0: f16[], Arg_1.11.0: f16[]) -> f16[] {
  %Arg_1.11.0 = f16[] parameter(1)
  %Arg_0.10.0 = f16[] parameter(0)
  ROOT %add.2.0 = f16[] add(f16[] %Arg_0.10.0, f16[] %Arg_1.11.0), metadata={op_name="jit(main)/add" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=433}
}

ENTRY %main.23_spmd (param.2: f16[4,512,512], param.1.0: f16[512,512]) -> (f16[4,512,512], f16[512,512]) {
  %param.1.0 = f16[512,512]{1,0} parameter(1), sharding={replicated}
  %param.2 = f16[4,512,512]{2,1,0} parameter(0), sharding={devices=[4,1,1]<=[4]}
  %custom-call.3.0 = (f16[4,512,512]{2,1,0}, f32[4]{0}) custom-call(f16[4,512,512]{2,1,0} %param.2, f16[512,512]{1,0} %param.1.0), custom_call_target="rms_forward_affine_mixed_dtype", operand_layout_constraints={f16[4,512,512]{2,1,0}, f16[512,512]{1,0}}, api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormFwdClass.partition at 0x7ff99e3980d0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormFwdClass.infer_sharding_from_operands at 0x7ff99e398040> decode_shardings=True in_tree=PyTreeDef((*, *)) out_tree=PyTreeDef((*, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=440}, backend_config="\004\000\000\000\000\000\004\000\361h\343\210\265\370\344>\001\000\000\000\001\000\000\000\000\000\000\000$V\000\000"
  %get-tuple-element.14 = f16[4,512,512]{2,1,0} get-tuple-element((f16[4,512,512]{2,1,0}, f32[4]{0}) %custom-call.3.0), index=0, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormFwdClass.partition at 0x7ff99e3980d0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormFwdClass.infer_sharding_from_operands at 0x7ff99e398040> decode_shardings=True in_tree=PyTreeDef((*, *)) out_tree=PyTreeDef((*, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=440}
  %loop_multiply_fusion = f16[4,512,512]{2,1,0} fusion(f16[4,512,512]{2,1,0} %get-tuple-element.14), kind=kLoop, calls=%fused_multiply, metadata={op_name="pjit(custom_p_loss)/jit(main)/mul" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=484}
  %get-tuple-element.1.0 = f32[4]{0} get-tuple-element((f16[4,512,512]{2,1,0}, f32[4]{0}) %custom-call.3.0), index=1, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormFwdClass.partition at 0x7ff99e3980d0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormFwdClass.infer_sharding_from_operands at 0x7ff99e398040> decode_shardings=True in_tree=PyTreeDef((*, *)) out_tree=PyTreeDef((*, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=440}
  %custom-call.5.0 = (f16[4,512,512]{2,1,0}, f16[512,512]{1,0}, f32[16,262144]{1,0}) custom-call(f16[4,512,512]{2,1,0} %loop_multiply_fusion, f32[4]{0} %get-tuple-element.1.0, f16[4,512,512]{2,1,0} %param.2, f16[512,512]{1,0} %param.1.0), custom_call_target="rms_backward_affine", operand_layout_constraints={f16[4,512,512]{2,1,0}, f32[4]{0}, f16[4,512,512]{2,1,0}, f16[512,512]{1,0}}, api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}, backend_config="\004\000\000\000\000\000\004\000\361h\343\210\265\370\344>\001\000\000\000\001\000\000\000\020\000\000\000$V\000\000"
  %get-tuple-element.7.0 = f16[512,512]{1,0} get-tuple-element((f16[4,512,512]{2,1,0}, f16[512,512]{1,0}, f32[16,262144]{1,0}) %custom-call.5.0), index=1, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}
  %all-reduce-start = f16[512,512]{1,0} all-reduce-start(f16[512,512]{1,0} %get-tuple-element.7.0), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%region_0.9._custom_call_lowering_rule, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"collective_backend_config":{"is_sync":true,"no_parallel_custom_call":false}}
  %all-reduce-done = f16[512,512]{1,0} all-reduce-done(f16[512,512]{1,0} %all-reduce-start), metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}
  %get-tuple-element.12.0 = f16[4,512,512]{2,1,0} get-tuple-element((f16[4,512,512]{2,1,0}, f16[512,512]{1,0}, f32[16,262144]{1,0}) %custom-call.5.0), index=0, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}
  ROOT %tuple.1.0 = (f16[4,512,512]{2,1,0}, f16[512,512]{1,0}) tuple(f16[4,512,512]{2,1,0} %get-tuple-element.12.0, f16[512,512]{1,0} %all-reduce-done)
}
True
True

Now there are no all-gathers in the HLO: the sharding is respected, and only the weight gradient is accumulated via an all-reduce.

Let’s put it together#

The complete definition of the primitives using custom_partitioning can be found in Custom_Operation_for_GPUs.py, and the corresponding C++ code that defines the Python bindings, in addition to the kernel implementations, can be found below:

gpu_ops code listing#

gpu_ops/kernel_helpers.h
gpu_ops/kernels.h
gpu_ops/pybind11_kernel_helpers.h
gpu_ops/gpu_ops.cpp
gpu_ops/rms_norm_kernels.cu

Generalized Convolutions in JAX#


JAX provides a number of interfaces to compute convolutions across data, including:

  • jax.numpy.convolve() for basic one-dimensional convolution

  • jax.scipy.signal.convolve() for N-dimensional convolution

  • jax.lax.conv_general_dilated() for general batched, multi-dimensional convolution

For basic convolution operations, the jax.numpy and jax.scipy operations are usually sufficient. If you want to do more general batched multi-dimensional convolution, the jax.lax function is where you should start.

Basic One-dimensional Convolution#

Basic one-dimensional convolution is implemented by jax.numpy.convolve(), which provides a JAX interface for numpy.convolve(). Here is a simple example of 1D smoothing implemented via a convolution:

import matplotlib.pyplot as plt

from jax import random
import jax.numpy as jnp
import numpy as np

key = random.key(1701)

x = jnp.linspace(0, 10, 500)
y = jnp.sin(x) + 0.2 * random.normal(key, shape=(500,))

window = jnp.ones(10) / 10
y_smooth = jnp.convolve(y, window, mode='same')

plt.plot(x, y, 'lightgray')
plt.plot(x, y_smooth, 'black');

The mode parameter controls how boundary conditions are treated; here we use mode='same' to ensure that the output is the same size as the input.
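For instance, reusing y and window from above, the three boundary modes differ only in the output length (a quick illustrative check):

print(jnp.convolve(y, window, mode='full').shape)   # (509,) = len(y) + len(window) - 1
print(jnp.convolve(y, window, mode='same').shape)   # (500,) = len(y)
print(jnp.convolve(y, window, mode='valid').shape)  # (491,) = len(y) - len(window) + 1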

For more information, see the jax.numpy.convolve() documentation, or the documentation associated with the original numpy.convolve() function.

Basic N-dimensional Convolution#

For N-dimensional convolution, jax.scipy.signal.convolve() provides a similar interface to that of jax.numpy.convolve(), generalized to N dimensions.

For example, here is a simple approach to de-noising an image based on convolution with a Gaussian filter:

from scipy import misc
import jax.scipy as jsp

fig, ax = plt.subplots(1, 3, figsize=(12, 5))

# Load a sample image; compute mean() to convert from RGB to grayscale.
image = jnp.array(misc.face().mean(-1))
ax[0].imshow(image, cmap='binary_r')
ax[0].set_title('original')

# Create a noisy version by adding random Gaussian noise
key = random.key(1701)
noisy_image = image + 50 * random.normal(key, image.shape)
ax[1].imshow(noisy_image, cmap='binary_r')
ax[1].set_title('noisy')

# Smooth the noisy image with a 2D Gaussian smoothing kernel.
x = jnp.linspace(-3, 3, 7)
window = jsp.stats.norm.pdf(x) * jsp.stats.norm.pdf(x[:, None])
smooth_image = jsp.signal.convolve(noisy_image, window, mode='same')
ax[2].imshow(smooth_image, cmap='binary_r')
ax[2].set_title('smoothed');
/tmp/ipykernel_5100/4118182506.py:7: DeprecationWarning: scipy.misc.face has been deprecated in SciPy v1.10.0; and will be completely removed in SciPy v1.12.0. Dataset methods have moved into the scipy.datasets module. Use scipy.datasets.face instead.
  image = jnp.array(misc.face().mean(-1))

Like in the one-dimensional case, we use mode='same' to specify how we would like edges to be handled. For more information on available options in N-dimensional convolutions, see the jax.scipy.signal.convolve() documentation.

General Convolutions#

For the more general types of batched convolutions often useful in the context of building deep neural networks, JAX and XLA offer the very general N-dimensional conv_general_dilated function, but it’s not very obvious how to use it. We’ll give some examples of the common use-cases.

A survey of the family of convolutional operators, a guide to convolutional arithmetic, is highly recommended reading!

Let’s define a simple diagonal edge kernel:

# 2D kernel - HWIO layout
kernel = jnp.zeros((3, 3, 3, 3), dtype=jnp.float32)
kernel += jnp.array([[1, 1, 0],
                     [1, 0,-1],
                     [0,-1,-1]])[:, :, jnp.newaxis, jnp.newaxis]

print("Edge Conv kernel:")
plt.imshow(kernel[:, :, 0, 0]);
Edge Conv kernel:

And we’ll make a simple synthetic image:

# NHWC layout
img = jnp.zeros((1, 200, 198, 3), dtype=jnp.float32)
for k in range(3):
  x = 30 + 60*k
  y = 20 + 60*k
  img = img.at[0, x:x+10, y:y+10, k].set(1.0)

print("Original Image:")
plt.imshow(img[0]);
Original Image:
lax.conv and lax.conv_with_general_padding#

These are the simple convenience functions for convolutions.

⚠️ The convenience lax.conv and lax.conv_with_general_padding helper functions assume NCHW images and OIHW kernels.

from jax import lax
out = lax.conv(jnp.transpose(img,[0,3,1,2]),    # lhs = NCHW image tensor
               jnp.transpose(kernel,[3,2,0,1]), # rhs = OIHW conv kernel tensor
               (1, 1),  # window strides
               'SAME') # padding mode
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(np.array(out)[0,0,:,:]);
out shape:  (1, 3, 200, 198)
First output channel:
out = lax.conv_with_general_padding(
  jnp.transpose(img,[0,3,1,2]),    # lhs = NCHW image tensor
  jnp.transpose(kernel,[2,3,0,1]), # rhs = IOHW conv kernel tensor
  (1, 1),  # window strides
  ((2,2),(2,2)), # general padding 2x2
  (1,1),  # lhs/image dilation
  (1,1))  # rhs/kernel dilation
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(np.array(out)[0,0,:,:]);
out shape:  (1, 3, 202, 200)
First output channel:
Dimension Numbers define dimensional layout for conv_general_dilated#

The important argument is the 3-tuple of axis layout arguments: (Input Layout, Kernel Layout, Output Layout)

  • N - batch dimension

  • H - spatial height

  • W - spatial width

  • C - channel dimension

  • I - kernel input channel dimension

  • O - kernel output channel dimension

⚠️ To demonstrate the flexibility of dimension numbers we choose a NHWC image and HWIO kernel convention for lax.conv_general_dilated below.

dn = lax.conv_dimension_numbers(img.shape,     # only ndim matters, not shape
                                kernel.shape,  # only ndim matters, not shape 
                                ('NHWC', 'HWIO', 'NHWC'))  # the important bit
print(dn)
ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))
SAME padding, no stride, no dilation#
out = lax.conv_general_dilated(img,    # lhs = image tensor
                               kernel, # rhs = conv kernel tensor
                               (1,1),  # window strides
                               'SAME', # padding mode
                               (1,1),  # lhs/image dilation
                               (1,1),  # rhs/kernel dilation
                               dn)     # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(np.array(out)[0,:,:,0]);
out shape:  (1, 200, 198, 3)
First output channel:
VALID padding, no stride, no dilation#
out = lax.conv_general_dilated(img,     # lhs = image tensor
                               kernel,  # rhs = conv kernel tensor
                               (1,1),   # window strides
                               'VALID', # padding mode
                               (1,1),   # lhs/image dilation
                               (1,1),   # rhs/kernel dilation
                               dn)      # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, "DIFFERENT from above!")
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(np.array(out)[0,:,:,0]);
out shape:  (1, 198, 196, 3) DIFFERENT from above!
First output channel:
SAME padding, 2,2 stride, no dilation#
out = lax.conv_general_dilated(img,    # lhs = image tensor
                               kernel, # rhs = conv kernel tensor
                               (2,2),  # window strides
                               'SAME', # padding mode
                               (1,1),  # lhs/image dilation
                               (1,1),  # rhs/kernel dilation
                               dn)     # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, " <-- half the size of above")
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(np.array(out)[0,:,:,0]);
out shape:  (1, 100, 99, 3)  <-- half the size of above
First output channel:
VALID padding, no stride, rhs kernel dilation ~ Atrous convolution (excessive to illustrate)#
out = lax.conv_general_dilated(img,     # lhs = image tensor
                               kernel,  # rhs = conv kernel tensor
                               (1,1),   # window strides
                               'VALID', # padding mode
                               (1,1),   # lhs/image dilation
                               (12,12), # rhs/kernel dilation
                               dn)      # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape)
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(np.array(out)[0,:,:,0]);
out shape:  (1, 176, 174, 3)
First output channel:
VALID padding, no stride, lhs=input dilation ~ Transposed Convolution#
out = lax.conv_general_dilated(img,               # lhs = image tensor
                               kernel,            # rhs = conv kernel tensor
                               (1,1),             # window strides
                               ((0, 0), (0, 0)),  # padding mode
                               (2,2),             # lhs/image dilation
                               (1,1),             # rhs/kernel dilation
                               dn)                # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, "<-- larger than original!")
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(np.array(out)[0,:,:,0]);
out shape:  (1, 397, 393, 3) <-- larger than original!
First output channel:

We can use the last of these, for instance, to implement transposed convolutions:

# The following is equivalent to tensorflow:
# N,H,W,C = img.shape
# out = tf.nn.conv2d_transpose(img, kernel, (N,2*H,2*W,C), (1,2,2,1))

# transposed conv = 180deg kernel rotation plus LHS dilation
# rotate kernel 180deg:
kernel_rot = jnp.rot90(jnp.rot90(kernel, axes=(0,1)), axes=(0,1))
# need a custom output padding:
padding = ((2, 1), (2, 1))
out = lax.conv_general_dilated(img,     # lhs = image tensor
                               kernel_rot,  # rhs = conv kernel tensor
                               (1,1),   # window strides
                               padding, # padding mode
                               (2,2),   # lhs/image dilation
                               (1,1),   # rhs/kernel dilation
                               dn)      # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, "<-- transposed_conv")
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(np.array(out)[0,:,:,0]);
out shape:  (1, 400, 396, 3) <-- transposed_conv
First output channel:
1D Convolutions#

You aren’t limited to 2D convolutions; a simple 1D demo is below:

# 1D kernel - WIO layout
kernel = jnp.array([[[1, 0, -1], [-1,  0,  1]], 
                    [[1, 1,  1], [-1, -1, -1]]], 
                    dtype=jnp.float32).transpose([2,1,0])
# 1D data - NWC layout
data = np.zeros((1, 200, 2), dtype=jnp.float32)
for i in range(2):
  for k in range(2):
      x = 35*i + 30 + 60*k
      data[0, x:x+30, k] = 1.0

print("in shapes:", data.shape, kernel.shape)

plt.figure(figsize=(10,5))
plt.plot(data[0]);
dn = lax.conv_dimension_numbers(data.shape, kernel.shape,
                                ('NWC', 'WIO', 'NWC'))
print(dn)

out = lax.conv_general_dilated(data,   # lhs = image tensor
                               kernel, # rhs = conv kernel tensor
                               (1,),   # window strides
                               'SAME', # padding mode
                               (1,),   # lhs/image dilation
                               (1,),   # rhs/kernel dilation
                               dn)     # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape)
plt.figure(figsize=(10,5))
plt.plot(out[0]);
in shapes: (1, 200, 2) (3, 2, 2)
ConvDimensionNumbers(lhs_spec=(0, 2, 1), rhs_spec=(2, 1, 0), out_spec=(0, 2, 1))
out shape:  (1, 200, 2)
3D Convolutions#
import matplotlib as mpl

# Random 3D kernel - HWDIO layout
kernel = jnp.array([
  [[0, 0,  0], [0,  1,  0], [0,  0,   0]],
  [[0, -1, 0], [-1, 0, -1], [0,  -1,  0]], 
  [[0, 0,  0], [0,  1,  0], [0,  0,   0]]], 
  dtype=jnp.float32)[:, :, :, jnp.newaxis, jnp.newaxis]

# 3D data - NHWDC layout
data = jnp.zeros((1, 30, 30, 30, 1), dtype=jnp.float32)
x, y, z = np.mgrid[0:1:30j, 0:1:30j, 0:1:30j]
data += (jnp.sin(2*x*jnp.pi)*jnp.cos(2*y*jnp.pi)*jnp.cos(2*z*jnp.pi))[None,:,:,:,None]

print("in shapes:", data.shape, kernel.shape)
dn = lax.conv_dimension_numbers(data.shape, kernel.shape,
                                ('NHWDC', 'HWDIO', 'NHWDC'))
print(dn)

out = lax.conv_general_dilated(data,    # lhs = image tensor
                               kernel,  # rhs = conv kernel tensor
                               (1,1,1), # window strides
                               'SAME',  # padding mode
                               (1,1,1), # lhs/image dilation
                               (1,1,1), # rhs/kernel dilation
                               dn)      # dimension_numbers
print("out shape: ", out.shape)

# Make some simple 3d density plots:
from mpl_toolkits.mplot3d import Axes3D
def make_alpha(cmap):
  my_cmap = cmap(jnp.arange(cmap.N))
  my_cmap[:,-1] = jnp.linspace(0, 1, cmap.N)**3
  return mpl.colors.ListedColormap(my_cmap)
my_cmap = make_alpha(plt.cm.viridis)
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(x.ravel(), y.ravel(), z.ravel(), c=data.ravel(), cmap=my_cmap)
ax.axis('off')
ax.set_title('input')
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(x.ravel(), y.ravel(), z.ravel(), c=out.ravel(), cmap=my_cmap)
ax.axis('off')
ax.set_title('3D conv output');
in shapes: (1, 30, 30, 30, 1) (3, 3, 3, 1, 1)
ConvDimensionNumbers(lhs_spec=(0, 4, 1, 2, 3), rhs_spec=(4, 3, 0, 1, 2), out_spec=(0, 4, 1, 2, 3))
out shape:  (1, 30, 30, 30, 1)
[3D scatter plots of the input volume and the 3D convolution output]

Developer Documentation#

JAX welcomes contributions from the community. See below for various install guides to get set up as a developer, as well as developer-focused resources such as JAX Enhancement Proposals.

Contributing to JAX#

Everyone can contribute to JAX, and we value everyone’s contributions. There are several ways to contribute.

The JAX project follows Google’s Open Source Community Guidelines.

Ways to contribute#

We welcome pull requests, in particular for those issues marked with contributions welcome or good first issue.

For other proposals, we ask that you first open a GitHub Issue or Discussion to seek feedback on your planned contribution.

Contributing code using pull requests#

We do all of our development using git, so basic knowledge is assumed.

Follow these steps to contribute code:

  1. Sign the Google Contributor License Agreement (CLA). For more information, see the Pull Request Checklist below.

  2. Fork the JAX repository by clicking the Fork button on the repository page. This creates a copy of the JAX repository in your own account.

  3. Install Python >= 3.9 locally in order to run tests.

  4. Install your fork from source with pip. This allows you to modify the code and immediately test it out:

    git clone https://github.com/YOUR_USERNAME/jax
    cd jax
    pip install -r build/test-requirements.txt  # Installs all testing requirements.
    pip install -e ".[cpu]"  # Installs JAX from the current directory in editable mode.
    
  5. Add the JAX repo as an upstream remote, so you can use it to sync your changes.

    git remote add upstream https://www.github.com/google/jax
    
  6. Create a branch where you will develop from:

    git checkout -b name-of-change
    

    And implement your changes using your favorite editor (we recommend Visual Studio Code).

  7. Make sure your code passes JAX’s lint and type checks, by running the following from the top of the repository:

    pip install pre-commit
    pre-commit run --all
    

    See Linting and Type-checking for more details.

  8. Make sure the tests pass by running the following command from the top of the repository:

    pytest -n auto tests/
    

    JAX’s test suite is quite large, so if you know the specific test file that covers your changes, you can limit the tests to that; for example:

    pytest -n auto tests/lax_scipy_test.py
    

    You can narrow the tests further by using the pytest -k flag to match particular test names:

    pytest -n auto tests/lax_scipy_test.py -k testLogSumExp
    

    JAX also offers more fine-grained control over which particular tests are run; see Running the tests for more information.

  9. Once you are satisfied with your change, create a commit as follows (see how to write a commit message):

    git add file1.py file2.py ...
    git commit -m "Your commit message"
    

    Then sync your code with the main repo:

    git fetch upstream
    git rebase upstream/main
    

    Finally, push your commit on your development branch and create a remote branch in your fork that you can use to create a pull request from:

    git push --set-upstream origin name-of-change
    

    Please ensure your contribution is a single commit (see Single-change commits and pull requests).

  10. Create a pull request from the JAX repository and send it for review. Check the JAX pull request checklist for considerations when preparing your PR, and consult GitHub Help if you need more information on using pull requests.

JAX pull request checklist#

As you prepare a JAX pull request, here are a few things to keep in mind:

Google contributor license agreement#

Contributions to this project must be accompanied by a Google Contributor License Agreement (CLA). You (or your employer) retain the copyright to your contribution; this simply gives us permission to use and redistribute your contributions as part of the project. Head over to https://cla.developers.google.com/ to see your current agreements on file or to sign a new one.

You generally only need to submit a CLA once, so if you’ve already submitted one (even if it was for a different project), you probably don’t need to do it again. If you’re not certain whether you’ve signed a CLA, you can open your PR and our friendly CI bot will check for you.

Single-change commits and pull requests#

A git commit ought to be a self-contained, single change with a descriptive message. This helps with review and with identifying or reverting changes if issues are uncovered later on.

Pull requests typically comprise a single git commit. (In some cases, for instance for large refactors or internal rewrites, they may contain several.) In preparing a pull request for review, you may need to squash together multiple commits. We ask that you do this prior to sending the PR for review if possible. The git rebase -i command might be useful to this end.

Linting and Type-checking#

JAX uses mypy and ruff to statically test code quality; the easiest way to run these checks locally is via the pre-commit framework:

pip install pre-commit
pre-commit run --all

If your pull request touches documentation notebooks, this will also run some checks on those (See Update notebooks for more details).

Full GitHub test suite#

Your PR will automatically be run through a full test suite on GitHub CI, which covers a range of Python versions, dependency versions, and configuration options. It’s normal for these tests to turn up failures that you didn’t catch locally; to fix the issues you can push new commits to your branch.

Restricted test suite#

Once your PR has been reviewed, a JAX maintainer will mark it as Pull Ready. This will trigger a larger set of tests, including tests on GPU and TPU backends that are not available via standard GitHub CI. Detailed results of these tests are not publicly viewable, but the JAX maintainer assigned to your PR will communicate with you regarding any failures these might uncover; it’s not uncommon, for example, that numerical tests need different tolerances on TPU than on CPU.

Building from source#

First, obtain the JAX source code:

git clone https://github.com/google/jax
cd jax

Building JAX involves two steps:

  1. Building or installing jaxlib, the C++ support library for jax.

  2. Installing the jax Python package.

Building or installing jaxlib#

Installing jaxlib with pip#

If you’re only modifying Python portions of JAX, we recommend installing jaxlib from a prebuilt wheel using pip:

pip install jaxlib

See the JAX readme for full guidance on pip installation (e.g., for GPU and TPU support).

Building jaxlib from source#

To build jaxlib from source, you must also install some prerequisites:

  • a C++ compiler (g++, clang, or MSVC)

    On Ubuntu or Debian you can install the necessary prerequisites with:

    sudo apt install g++ python python3-dev
    

    If you are building on a Mac, make sure Xcode and the Xcode command line tools are installed.

    See below for Windows build instructions.

  • Python packages: numpy, wheel, build.

You can install the necessary Python dependencies using pip:

pip install numpy wheel build

To build jaxlib for CPU or TPU, you can run:

python build/build.py
pip install dist/*.whl  # installs jaxlib (includes XLA)

There are two ways to build jaxlib with CUDA support: (1) use python build/build.py --enable_cuda to generate a jaxlib wheel with cuda support, or (2) use python build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12 to generate three wheels (jaxlib without cuda, jax-cuda-plugin, and jax-cuda-pjrt). You can set gpu_plugin_cuda_version to 11 or 12.

See python build/build.py --help for configuration options, including ways to specify the paths to CUDA and CUDNN, which you must have installed. Here python should be the name of your Python 3 interpreter; on some systems, you may need to use python3 instead. By default, the wheel is written to the dist/ subdirectory of the current directory.

Building jaxlib from source with a modified XLA repository#

JAX depends on XLA, whose source code is in the XLA GitHub repository. By default JAX uses a pinned copy of the XLA repository, but we often want to use a locally-modified copy of XLA when working on JAX. There are two ways to do this:

  • use Bazel’s override_repository feature, which you can pass as a command line flag to build.py as follows:

    python build/build.py --bazel_options=--override_repository=xla=/path/to/xla
    
  • modify the WORKSPACE file in the root of the JAX source tree to point to a different XLA tree.

To contribute changes back to XLA, send PRs to the XLA repository.

The version of XLA pinned by JAX is regularly updated, but is updated in particular before each jaxlib release.

Additional Notes for Building jaxlib from source on Windows#

On Windows, follow Install Visual Studio to set up a C++ toolchain. Visual Studio 2019 version 16.5 or newer is required. If you need to build with CUDA enabled, follow the CUDA Installation Guide to set up a CUDA environment.

JAX builds use symbolic links, which require that you activate Developer Mode.

You can either install Python using its Windows installer, or if you prefer, you can use Anaconda or Miniconda to set up a Python environment.

Some targets of Bazel use bash utilities to do scripting, so MSYS2 is needed. See Installing Bazel on Windows for more details. Install the following packages:

pacman -S patch coreutils

Once coreutils is installed, the realpath command should be present in your shell’s path.

Once everything is installed, open PowerShell and make sure MSYS2 is on the path of the current session. Ensure that bazel, patch, and realpath are accessible. Activate the conda environment. The following command builds with CUDA enabled; adjust it as needed for your setup:

python .\build\build.py `
  --enable_cuda `
  --cuda_path='C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1' `
  --cudnn_path='C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1' `
  --cuda_version='10.1' `
  --cudnn_version='7.6.5'

To build with debug information, add the flag --bazel_options='--copt=/Z7'.

Additional notes for building a ROCM jaxlib for AMD GPUs#

You need several ROCM/HIP libraries installed to build for ROCM. For example, on an Ubuntu machine with AMD’s apt repositories available, you need a number of packages installed:

sudo apt install miopen-hip hipfft-dev rocrand-dev hipsparse-dev hipsolver-dev \
    rccl-dev rccl hip-dev rocfft-dev roctracer-dev hipblas-dev rocm-device-libs

To build jaxlib with ROCM support, you can run the following build command, suitably adjusted for your paths and ROCM version.

python build/build.py --enable_rocm --rocm_path=/opt/rocm-5.7.0

AMD’s fork of the XLA repository may include fixes not present in the upstream XLA repository. If you experience problems with the upstream repository, you can try AMD’s fork, by cloning their repository:

git clone https://github.com/ROCmSoftwarePlatform/xla.git

and override the XLA repository with which JAX is built:

python build/build.py --enable_rocm --rocm_path=/opt/rocm-5.7.0 \
  --bazel_options=--override_repository=xla=/path/to/xla-rocm

Installing jax#

Once jaxlib has been installed, you can install jax by running:

pip install -e .  # installs jax

To upgrade to the latest version from GitHub, just run git pull from the JAX repository root, and rebuild by running build.py or upgrading jaxlib if necessary. You shouldn’t have to reinstall jax because pip install -e sets up symbolic links from site-packages into the repository.

Running the tests#

First, install the dependencies by running pip install -r build/test-requirements.txt.

There are two supported mechanisms for running the JAX tests, either using Bazel or using pytest.

Using Bazel#

First, configure the JAX build by running:

python build/build.py --configure_only

You may pass additional options to build.py to configure the build; see the jaxlib build documentation for details.

By default the Bazel build runs the JAX tests using jaxlib built from source. To run JAX tests, run:

bazel test //tests:cpu_tests //tests:backend_independent_tests

//tests:gpu_tests and //tests:tpu_tests are also available, if you have the necessary hardware.

To use a preinstalled jaxlib instead of building jaxlib from source, run

bazel test --//jax:build_jaxlib=false //tests:cpu_tests //tests:backend_independent_tests

A number of test behaviors can be controlled using environment variables (see below). Environment variables may be passed to JAX tests using the --test_env=FLAG=value flag to Bazel.

Some JAX tests require multiple accelerators (i.e. GPUs or TPUs). When JAX is already installed, you can run the GPU tests like this:

bazel test //tests:gpu_tests --local_test_jobs=4 --test_tag_filters=multiaccelerator --//jax:build_jaxlib=false --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform

You can speed up single-accelerator tests by running them in parallel on multiple accelerators; this also runs multiple concurrent tests per accelerator. For GPUs, you can do it like this:

NB_GPUS=2
JOBS_PER_ACC=4
J=$((NB_GPUS * JOBS_PER_ACC))
MULTI_GPU="--run_under $PWD/build/parallel_accelerator_execute.sh --test_env=JAX_ACCELERATOR_COUNT=${NB_GPUS} --test_env=JAX_TESTS_PER_ACCELERATOR=${JOBS_PER_ACC} --local_test_jobs=$J"
bazel test //tests:gpu_tests //tests:backend_independent_tests --test_env=XLA_PYTHON_CLIENT_PREALLOCATE=false --test_tag_filters=-multiaccelerator $MULTI_GPU

Some test targets, such as //tests:logpcg_tests, optionally use matplotlib, so you may need to pip install matplotlib to run tests via Bazel.

Using pytest#

To run all the JAX tests using pytest, we recommend using pytest-xdist, which can run tests in parallel. It is installed as part of the pip install -r build/test-requirements.txt command.

From the repository root directory run:

pytest -n auto tests
Controlling test behavior#

JAX generates test cases combinatorially, and you can control the number of cases that are generated and checked for each test (default is 10) using the JAX_NUM_GENERATED_CASES environment variable. The automated tests currently use 25 by default.

For example, one might write

# Bazel
bazel test //tests/... --test_env=JAX_NUM_GENERATED_CASES=25

or

# pytest
JAX_NUM_GENERATED_CASES=25 pytest -n auto tests

The automated tests also run with 64-bit floats and ints as the default types (JAX_ENABLE_X64):

JAX_ENABLE_X64=1 JAX_NUM_GENERATED_CASES=25 pytest -n auto tests

You can run a more specific set of tests using pytest’s built-in selection mechanisms, or alternatively you can run a specific test file directly to see more detailed information about the cases being run:

JAX_NUM_GENERATED_CASES=5 python tests/lax_numpy_test.py

You can skip a few tests known to be slow by setting the environment variable JAX_SKIP_SLOW_TESTS=1.

To specify a particular set of tests to run from a test file, you can pass a string or regular expression via the --test_targets flag. For example, you can run all the tests of jax.numpy.pad using:

python tests/lax_numpy_test.py --test_targets="testPad"

The Colab notebooks are tested for errors as part of the documentation build.

Doctests#

JAX uses pytest in doctest mode to test the code examples within the documentation. You can run this using

pytest docs

Additionally, JAX runs pytest in doctest-modules mode to ensure code examples in function docstrings will run correctly. You can run this locally using, for example:

pytest --doctest-modules jax/_src/numpy/lax_numpy.py

Keep in mind that several files are marked to be skipped when the doctest command is run on the full package; you can see the details in ci-build.yaml.

Type checking#

We use mypy to check the type hints. To check types locally the same way as the CI checks them:

pip install mypy
mypy --config=pyproject.toml --show-error-codes jax

Alternatively, you can use the pre-commit framework to run this on all staged files in your git repository, automatically using the same mypy version as in the GitHub CI:

pre-commit run mypy

Linting#

JAX uses the ruff linter to ensure code quality. You can check your local changes by running:

pip install ruff
ruff jax

Alternatively, you can use the pre-commit framework to run this on all staged files in your git repository, automatically using the same ruff version as the GitHub tests:

pre-commit run ruff

Update documentation#

To rebuild the documentation, install several packages:

pip install -r docs/requirements.txt

And then run:

sphinx-build -b html docs docs/build/html -j auto

This can take a long time because it executes many of the notebooks in the documentation source; if you’d prefer to build the docs without executing the notebooks, you can run:

sphinx-build -b html -D nb_execution_mode=off docs docs/build/html -j auto

You can then see the generated documentation in docs/build/html/index.html.

The -j auto option controls the parallelism of the build. You can use a number in place of auto to control how many CPU cores to use.

Update notebooks#

We use jupytext to maintain two synced copies of the notebooks in docs/notebooks: one in ipynb format, and one in md format. The advantage of the former is that it can be opened and executed directly in Colab; the advantage of the latter is that it makes it much easier to track diffs within version control.

Editing ipynb#

For making large changes that substantially modify code and outputs, it is easiest to edit the notebooks in Jupyter or in Colab. To edit notebooks in the Colab interface, open http://colab.research.google.com and upload the notebook from your local repo. Update it as needed, run all cells, then download the ipynb file. You may want to test that it executes properly, using sphinx-build as explained above.

Editing md#

For making smaller changes to the text content of the notebooks, it is easiest to edit the .md versions using a text editor.

Syncing notebooks#

After editing either the ipynb or md versions of the notebooks, you can sync the two versions using jupytext by running jupytext --sync on the updated notebooks; for example:

pip install jupytext==1.16.0
jupytext --sync docs/notebooks/thinking_in_jax.ipynb

The jupytext version should match that specified in .pre-commit-config.yaml.

To check that the markdown and ipynb files are properly synced, you may use the pre-commit framework to perform the same check used by the github CI:

git add docs -u  # pre-commit runs on files in git staging.
pre-commit run jupytext
Creating new notebooks#

If you are adding a new notebook to the documentation and would like to use the jupytext --sync command discussed here, you can set up your notebook for jupytext by using the following command:

jupytext --set-formats ipynb,md:myst path/to/the/notebook.ipynb

This works by adding a "jupytext" metadata field to the notebook file which specifies the desired formats, and which the jupytext --sync command recognizes when invoked.

Notebooks within the Sphinx build#

Some of the notebooks are built automatically as part of the pre-submit checks and as part of the Read the docs build. The build will fail if cells raise errors. If the errors are intentional, you can either catch them, or tag the cell with raises-exceptions metadata (example PR). You have to add this metadata by hand in the .ipynb file. It will be preserved when somebody else re-saves the notebook.

We exclude some notebooks from the build, e.g., because they contain long computations. See exclude_patterns in conf.py.

Documentation building on readthedocs.io#

JAX’s auto-generated documentation is at https://jax.readthedocs.io/.

The documentation building is controlled for the entire project by the readthedocs JAX settings. The current settings trigger a documentation build as soon as code is pushed to the GitHub main branch. For each code version, the building process is driven by the .readthedocs.yml and the docs/conf.py configuration files.

For each automated documentation build you can see the documentation build logs.

If you want to test the documentation generation on Readthedocs, you can push code to the test-docs branch. That branch is also built automatically, and you can see the generated documentation here. If the documentation build fails you may want to wipe the build environment for test-docs.

For a local test, I was able to do it in a fresh directory by replaying the commands I saw in the Readthedocs logs:

mkvirtualenv jax-docs  # A new virtualenv
mkdir jax-docs  # A new directory
cd jax-docs
git clone --no-single-branch --depth 50 https://github.com/google/jax
cd jax
git checkout --force origin/test-docs
git clean -d -f -f
workon jax-docs

python -m pip install --upgrade --no-cache-dir pip
python -m pip install --upgrade --no-cache-dir -I Pygments==2.3.1 setuptools==41.0.1 docutils==0.14 mock==1.0.1 pillow==5.4.1 alabaster>=0.7,<0.8,!=0.7.5 commonmark==0.8.1 recommonmark==0.5.0 'sphinx<2' 'sphinx-rtd-theme<0.5' 'readthedocs-sphinx-ext<1.1'
python -m pip install --exists-action=w --no-cache-dir -r docs/requirements.txt
cd docs
python `which sphinx-build` -T -E -b html -d _build/doctrees-readthedocs -D language=en . _build/html

Internal APIs#

core#

Jaxpr(constvars, invars, outvars, eqns[, ...])

ClosedJaxpr(jaxpr, consts)
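
For a quick look at these objects in practice, you can trace a small function with the public jax.make_jaxpr helper; a minimal sketch (the example function is arbitrary):

import jax
import jax.numpy as jnp

# Trace a toy function to a ClosedJaxpr (a Jaxpr together with its constants).
closed = jax.make_jaxpr(lambda x: jnp.sin(x) * 2.0)(1.0)
print(type(closed).__name__)   # ClosedJaxpr
print(closed.jaxpr)            # the underlying Jaxpr (invars, eqns, outvars, ...)
print(closed.consts)           # constants captured during tracing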


Autodidax: JAX core from scratch#

Ever want to learn how JAX works, but the implementation seemed impenetrable? Well, you’re in luck! By reading this tutorial, you’ll learn every big idea in JAX’s core system. You’ll even get clued into our weird jargon!

This is a work-in-progress draft. There are some important ingredients missing, still to come in parts 5 and 6 (and more?). There are also some simplifications here that we haven’t yet applied to the main system, but we will.

Part 1: Transformations as interpreters: standard evaluation, jvp, and vmap#

We want to transform functions that look like this:

def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

Think of functions like sin and the arithmetic operations underlying the infix operators (mul, add, and neg) as primitive operations, meaning atomic units of processing rather than compositions.

“Transform” means “interpret differently.” Instead of standard interpretation where we apply primitive operations to numerical inputs to produce numerical outputs, we want to override primitive application and let different values flow through our program. For example, we might want to replace the application of every primitive with an application of its JVP rule, and let primal-tangent pairs flow through our program. Moreover, we want to be able to compose multiple transformations, leading to stacks of interpreters.

JAX core machinery#

We can implement stacks of interpreters and even have them all discharge on the fly as we execute the Python function to be transformed. To start, let’s define these primitives so that we can intercept their application:

from typing import NamedTuple

class Primitive(NamedTuple):
  name: str

add_p = Primitive('add')
mul_p = Primitive('mul')
neg_p = Primitive("neg")
sin_p = Primitive("sin")
cos_p = Primitive("cos")
reduce_sum_p = Primitive("reduce_sum")
greater_p = Primitive("greater")
less_p = Primitive("less")
transpose_p = Primitive("transpose")
broadcast_p = Primitive("broadcast")

def add(x, y): return bind1(add_p, x, y)
def mul(x, y): return bind1(mul_p, x, y)
def neg(x): return bind1(neg_p, x)
def sin(x): return bind1(sin_p, x)
def cos(x): return bind1(cos_p, x)
def greater(x, y): return bind1(greater_p, x, y)
def less(x, y): return bind1(less_p, x, y)
def transpose(x, perm): return bind1(transpose_p, x, perm=perm)
def broadcast(x, shape, axes): return bind1(broadcast_p, x, shape=shape, axes=axes)
def reduce_sum(x, axis=None):
  if axis is None:
    axis = tuple(range(np.ndim(x)))
  if type(axis) is int:
    axis = (axis,)
  return bind1(reduce_sum_p, x, axis=axis)

def bind1(prim, *args, **params):
  out, = bind(prim, *args, **params)
  return out

We’ll set up array data types and infix operator methods in a moment.

A Primitive is just an object with a name, to which we attach our interpretation rules (one for each transformation). The bind function is our interception point: it’ll figure out which transformation rule to apply, based on how the arguments are boxed in tracers and what interpreters are active.

The functions that user code calls, like add and sin, are just wrappers around calls to bind. These wrappers let us control how arguments are passed to bind, and in particular we follow a handy internal convention: when we call bind, we pass values representing array data as positional arguments, and we pass metadata like the axis argument to reduce_sum_p via keyword. This calling convention simplifies some core logic (since e.g. instances of the Tracer class to be defined below can only occur in positional arguments to bind). The wrappers can also provide docstrings!
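
For example, in terms of the wrappers defined above, the convention plays out like this (restated as comments only, no new machinery):

# Array data is passed positionally; metadata is passed by keyword:
#   reduce_sum(x, axis=0)      ->  bind1(reduce_sum_p, x, axis=(0,))
#   transpose(x, perm)         ->  bind1(transpose_p, x, perm=perm)
#   broadcast(x, shape, axes)  ->  bind1(broadcast_p, x, shape=shape, axes=axes)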

We represent active interpreters as a stack. The stack is just a simple list, and each element is a container with an integer level (corresponding to the element’s height in the stack), an interpreter type (which we’ll call a trace_type), and an optional field for any global data the interpreter needs. We call each element a MainTrace, though maybe “Interpreter” would be more descriptive.

from collections.abc import Sequence
from contextlib import contextmanager
from typing import Optional, Any

class MainTrace(NamedTuple):
  level: int
  trace_type: type['Trace']
  global_data: Optional[Any]

trace_stack: list[MainTrace] = []
dynamic_trace: Optional[MainTrace] = None  # to be employed in Part 3

@contextmanager
def new_main(trace_type: type['Trace'], global_data=None):
  level = len(trace_stack)
  main = MainTrace(level, trace_type, global_data)
  trace_stack.append(main)

  try:
    yield main
  finally:
    trace_stack.pop()

When we’re about to apply a transformation, we’ll push another interpreter onto the stack using new_main. Then, as we apply primitives in the function, we can think of the bind first being interpreted by the trace at the top of the stack (i.e. with the highest level). If that first interpreter itself binds other primitives in its interpretation rule for the primitive, like how the JVP rule of sin_p might bind cos_p and mul_p, then those bind calls will be handled by the interpreter at the next level down.

What goes at the bottom of the interpreter stack? At the bottom, we know all the transformation interpreters are finished, and we just want to do standard evaluation. So at the bottom we’ll put an evaluation interpreter.
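
As a rough picture (using the interpreter classes introduced below), the stack during a doubly nested forward-mode transformation, like the deriv(deriv(sin)) example later in this section, would look something like this:

# level 2: JVPTrace   (inner jvp; the top of the stack handles each bind first)
# level 1: JVPTrace   (outer jvp)
# level 0: EvalTrace  (standard evaluation; always at the bottom)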

Let’s sketch out the interface for interpreters, which is based on the Trace and Tracer base classes. A Tracer represents a boxed-up value, perhaps carrying some extra context data used by the interpreter. A Trace handles boxing up values into Tracers and also handles primitive application.

class Trace:
  main: MainTrace

  def __init__(self, main: MainTrace) -> None:
    self.main = main

  def pure(self, val): assert False  # must override
  def lift(self, val): assert False  # must override

  def process_primitive(self, primitive, tracers, params):
    assert False  # must override

The first two methods are about boxing up values in Tracers, which are the objects that flow through the Python programs we transform. The last method is the callback we’ll use to interpret primitive application.

The Trace itself doesn’t contain any data, other than a reference to its corresponding MainTrace instance. In fact, multiple instances of a Trace might be created and discarded during an application of a transformation, whereas only a single MainTrace instance is created per application of a transformation.

As for Tracers themselves, each one carries an abstract value (and forwards infix operators to it), and the rest is up to the transformation. (The relationship between Tracers and AbstractValues is that there’s one Tracer per transformation, and at least one AbstractValue per base type, like arrays.)

import numpy as np

class Tracer:
  _trace: Trace

  __array_priority__ = 1000

  @property
  def aval(self):
    assert False  # must override

  def full_lower(self):
    return self  # default implementation

  def __neg__(self): return self.aval._neg(self)
  def __add__(self, other): return self.aval._add(self, other)
  def __radd__(self, other): return self.aval._radd(self, other)
  def __mul__(self, other): return self.aval._mul(self, other)
  def __rmul__(self, other): return self.aval._rmul(self, other)
  def __gt__(self, other): return self.aval._gt(self, other)
  def __lt__(self, other): return self.aval._lt(self, other)
  def __bool__(self): return self.aval._bool(self)
  def __nonzero__(self): return self.aval._nonzero(self)

  def __getattr__(self, name):
    try:
      return getattr(self.aval, name)
    except AttributeError:
      raise AttributeError(f"{self.__class__.__name__} has no attribute {name}")

def swap(f): return lambda x, y: f(y, x)
class ShapedArray:
  array_abstraction_level = 1
  shape: tuple[int, ...]
  dtype: np.dtype

  def __init__(self, shape, dtype):
    self.shape = shape
    self.dtype = dtype

  @property
  def ndim(self):
    return len(self.shape)

  _neg = staticmethod(neg)
  _add = staticmethod(add)
  _radd = staticmethod(swap(add))
  _mul = staticmethod(mul)
  _rmul = staticmethod(swap(mul))
  _gt = staticmethod(greater)
  _lt = staticmethod(less)

  @staticmethod
  def _bool(tracer):
    raise Exception("ShapedArray can't be unambiguously converted to bool")

  @staticmethod
  def _nonzero(tracer):
    raise Exception("ShapedArray can't be unambiguously converted to bool")

  def str_short(self):
    return f'{self.dtype.name}[{",".join(str(d) for d in self.shape)}]'

  def __hash__(self):
    return hash((self.shape, self.dtype))

  def __eq__(self, other):
    return (type(self) is type(other) and
            self.shape == other.shape and self.dtype == other.dtype)

  def __repr__(self):
    return f"ShapedArray(shape={self.shape}, dtype={self.dtype})"

class ConcreteArray(ShapedArray):
  array_abstraction_level = 2
  val: np.ndarray

  def __init__(self, val):
    self.val = val
    self.shape = val.shape
    self.dtype = val.dtype

  @staticmethod
  def _bool(tracer):
    return bool(tracer.aval.val)

  @staticmethod
  def _nonzero(tracer):
    return bool(tracer.aval.val)

def get_aval(x):
  if isinstance(x, Tracer):
    return x.aval
  elif type(x) in jax_types:
    return ConcreteArray(np.asarray(x))
  else:
    raise TypeError(x)

jax_types = {bool, int, float,
             np.bool_, np.int32, np.int64, np.float32, np.float64, np.ndarray}

Notice that we actually have two AbstractValues for arrays, representing different levels of abstraction. A ShapedArray represents the set of all possible arrays with a given shape and dtype. A ConcreteArray represents a singleton set consisting of a single array value.
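
A quick check of the two abstraction levels (array values chosen arbitrarily):

aval = get_aval(np.ones((2, 3)))
print(type(aval).__name__, aval.shape, aval.dtype)       # ConcreteArray (2, 3) float64
print(ShapedArray(aval.shape, aval.dtype).str_short())   # float64[2,3]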

Now that we’ve set up the interpreter stack, the Trace/Tracer API for interpreters, and abstract values, we can come back to implement bind:

def bind(prim, *args, **params):
  top_trace = find_top_trace(args)
  tracers = [full_raise(top_trace, arg) for arg in args]
  outs = top_trace.process_primitive(prim, tracers, params)
  return [full_lower(out) for out in outs]

The main action is that we call find_top_trace to figure out which interpreter should handle this primitive application. We then call that top trace’s process_primitive so that the trace can apply its interpretation rule. The calls to full_raise just ensure that the inputs are boxed in the top trace’s Tracer instances, and the call to full_lower is an optional optimization so that we unbox values out of Tracers as much as possible.

import operator as op

def find_top_trace(xs) -> Trace:
  top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),
                 default=trace_stack[0], key=op.attrgetter('level'))
  if dynamic_trace and dynamic_trace.level > top_main.level:
    top_main = dynamic_trace
  return top_main.trace_type(top_main)

In words, ignoring the dynamic_trace step until Part 3, find_top_trace returns the highest-level interpreter associated with the Tracers on its inputs, and otherwise returns the interpreter at the bottom of the stack (which is always an evaluation trace, at least for now). This is a deviation from the description above, where we always start by running the interpreter at the top of the stack and then work our way down, applying every interpreter in the stack. Instead, we’re only applying an interpreter when the input arguments to a primitive bind are boxed in a Tracer corresponding to that interpreter. This optimization lets us skip irrelevant transformations, but bakes in an assumption that transformations mostly follow data dependence (except for the special bottom-of-the-stack interpreter, which interprets everything).

An alternative would be to have every interpreter in the stack interpret every operation. That’s worth exploring! JAX is designed around data dependence in large part because that’s so natural for automatic differentiation, and JAX’s roots are in autodiff. But it may be over-fit.

def full_lower(val: Any):
  if isinstance(val, Tracer):
    return val.full_lower()
  else:
    return val

def full_raise(trace: Trace, val: Any) -> Tracer:
  if not isinstance(val, Tracer):
    assert type(val) in jax_types
    return trace.pure(val)
  level = trace.main.level
  if val._trace.main is trace.main:
    return val
  elif val._trace.main.level < level:
    return trace.lift(val)
  elif val._trace.main.level > level:
    raise Exception(f"Can't lift level {val._trace.main.level} to {level}.")
  else:  # val._trace.level == level
    raise Exception(f"Different traces at same level: {val._trace}, {trace}.")

The logic in full_raise serves to box values into Tracers for a particular Trace, calling different methods on the Trace based on context: Trace.pure is called on non-Tracer constants, and Trace.lift is called for values that are already Tracers from a lower-level interpreter. These two methods could share the same implementation, but by distinguishing them in the core logic we can provide more information to the Trace subclass.

That’s it for the JAX core! Now we can start adding interpreters.

Evaluation interpreter#

We’ll start with the simplest interpreter: the evaluation interpreter that will sit at the bottom of the interpreter stack.

class EvalTrace(Trace):
  pure = lift = lambda self, x: x  # no boxing in Tracers needed

  def process_primitive(self, primitive, tracers, params):
    return impl_rules[primitive](*tracers, **params)

trace_stack.append(MainTrace(0, EvalTrace, None))  # special bottom of the stack

# NB: in JAX, instead of a dict we attach impl rules to the Primitive instance
impl_rules = {}

impl_rules[add_p] = lambda x, y: [np.add(x, y)]
impl_rules[mul_p] = lambda x, y: [np.multiply(x, y)]
impl_rules[neg_p] = lambda x: [np.negative(x)]
impl_rules[sin_p] = lambda x: [np.sin(x)]
impl_rules[cos_p] = lambda x: [np.cos(x)]
impl_rules[reduce_sum_p] = lambda x, *, axis: [np.sum(x, axis)]
impl_rules[greater_p] = lambda x, y: [np.greater(x, y)]
impl_rules[less_p] = lambda x, y: [np.less(x, y)]
impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)]

def broadcast_impl(x, *, shape, axes):
  for axis in sorted(axes):
    x = np.expand_dims(x, axis)
  return [np.broadcast_to(x, shape)]
impl_rules[broadcast_p] = broadcast_impl

With this interpreter, we can evaluate user functions:

def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

print(f(3.0))
2.7177599838802657

Woo! Like going around in a big circle. But the point of this indirection is that now we can add some real transformations.

Forward-mode autodiff with jvp#

First, a few helper functions:

import builtins

def zeros_like(val):
  aval = get_aval(val)
  return np.zeros(aval.shape, aval.dtype)

def unzip2(pairs):
  lst1, lst2 = [], []
  for x1, x2 in pairs:
    lst1.append(x1)
    lst2.append(x2)
  return lst1, lst2

def map(f, *xs):
  return list(builtins.map(f, *xs))

def zip(*args):
  fst, *rest = args = map(list, args)
  n = len(fst)
  for arg in rest:
    assert len(arg) == n
  return list(builtins.zip(*args))

The Tracer for forward-mode autodiff carries a primal-tangent pair. The Trace applies JVP rules.

class JVPTracer(Tracer):
  def __init__(self, trace, primal, tangent):
    self._trace = trace
    self.primal = primal
    self.tangent = tangent

  @property
  def aval(self):
    return get_aval(self.primal)

class JVPTrace(Trace):
  pure = lift = lambda self, val: JVPTracer(self, val, zeros_like(val))

  def process_primitive(self, primitive, tracers, params):
    primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
    jvp_rule = jvp_rules[primitive]
    primal_outs, tangent_outs = jvp_rule(primals_in, tangents_in, **params)
    return [JVPTracer(self, x, t) for x, t in zip(primal_outs, tangent_outs)]

jvp_rules = {}

Notice both pure and lift package a value into a JVPTracer with the minimal amount of context, which is a zero tangent value.

Let’s add some JVP rules for primitives:

def add_jvp(primals, tangents):
  (x, y), (x_dot, y_dot) = primals, tangents
  return [x + y], [x_dot + y_dot]
jvp_rules[add_p] = add_jvp

def mul_jvp(primals, tangents):
  (x, y), (x_dot, y_dot) = primals, tangents
  return [x * y], [x_dot * y + x * y_dot]
jvp_rules[mul_p] = mul_jvp

def sin_jvp(primals, tangents):
  (x,), (x_dot,) = primals, tangents
  return [sin(x)], [cos(x) * x_dot]
jvp_rules[sin_p] = sin_jvp

def cos_jvp(primals, tangents):
  (x,), (x_dot,) = primals, tangents
  return [cos(x)], [-sin(x) * x_dot]
jvp_rules[cos_p] = cos_jvp

def neg_jvp(primals, tangents):
  (x,), (x_dot,) = primals, tangents
  return [neg(x)], [neg(x_dot)]
jvp_rules[neg_p] = neg_jvp

def reduce_sum_jvp(primals, tangents, *, axis):
  (x,), (x_dot,) = primals, tangents
  return [reduce_sum(x, axis)], [reduce_sum(x_dot, axis)]
jvp_rules[reduce_sum_p] = reduce_sum_jvp

def greater_jvp(primals, tangents):
  (x, y), _ = primals, tangents
  out_primal = greater(x, y)
  return [out_primal], [zeros_like(out_primal)]
jvp_rules[greater_p] = greater_jvp

def less_jvp(primals, tangents):
  (x, y), _ = primals, tangents
  out_primal = less(x, y)
  return [out_primal], [zeros_like(out_primal)]
jvp_rules[less_p] = less_jvp

Finally, we add a transformation API to kick off the trace:

def jvp_v1(f, primals, tangents):
  with new_main(JVPTrace) as main:
    trace = JVPTrace(main)
    tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)]
    out = f(*tracers_in)
    tracer_out = full_raise(trace, out)
    primal_out, tangent_out = tracer_out.primal, tracer_out.tangent
  return primal_out, tangent_out

And with that, we can differentiate!

x = 3.0
y, sin_deriv_at_3 = jvp_v1(sin, (x,), (1.0,))
print(sin_deriv_at_3)
print(cos(3.0))
-0.9899924966004454
-0.9899924966004454
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

x, xdot = 3., 1.
y, ydot = jvp_v1(f, (x,), (xdot,))
print(y)
print(ydot)
2.7177599838802657
2.979984993200891
def deriv(f):
  return lambda x: jvp_v1(f, (x,), (1.,))[1]

print(deriv(sin)(3.))
print(deriv(deriv(sin))(3.))
print(deriv(deriv(deriv(sin)))(3.))
print(deriv(deriv(deriv(deriv(sin))))(3.))
-0.9899924966004454
-0.1411200080598672
0.9899924966004454
0.1411200080598672
def f(x):
  if x > 0.:  # Python control flow
    return 2. * x
  else:
    return x

print(deriv(f)(3.))
print(deriv(f)(-3.))
2.0
1.0

Pytrees and flattening user functions’ inputs and outputs#

A limitation with jvp_v1 is that it assumes the user function accepts arrays as positional arguments and produces a single array as output. What if it produced a list as output? Or accepted nested containers as inputs? It would be a pain to deal with all the possible containers in inputs and outputs at every layer of the stack. Instead, we can wrap the user function so that the wrapped version accepts arrays as inputs and returns a flat list of arrays as output. The wrapper just needs to unflatten its input, call the user function, and flatten the output.

Here’s how we’d like to write jvp, assuming the user always gives us functions that take arrays as inputs and produces a flat list of arrays as outputs:

def jvp_flat(f, primals, tangents):
  with new_main(JVPTrace) as main:
    trace = JVPTrace(main)
    tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)]
    outs = f(*tracers_in)
    tracers_out = [full_raise(trace, out) for out in outs]
    primals_out, tangents_out = unzip2((t.primal, t.tangent) for t in tracers_out)
  return primals_out, tangents_out

To support user functions that have arbitrary containers in the inputs and outputs, here’s how we’d write the user-facing jvp wrapper:

def jvp(f, primals, tangents):
  primals_flat, in_tree = tree_flatten(primals)
  tangents_flat, in_tree2 = tree_flatten(tangents)
  if in_tree != in_tree2: raise TypeError
  f, out_tree = flatten_fun(f, in_tree)
  primals_out_flat, tangents_out_flat = jvp_flat(f, primals_flat, tangents_flat)
  primals_out = tree_unflatten(out_tree(), primals_out_flat)
  tangents_out = tree_unflatten(out_tree(), tangents_out_flat)
  return primals_out, tangents_out

Notice that we had to plumb the tree structure of the user function output back to the caller of flatten_fun. That information isn’t available until we actually run the user function, so flatten_fun just returns a reference to a mutable cell, represented as a thunk. These side-effects are safe because we always run the user function exactly once. (This safe regime is the reason for the “linear” name in linear_util.py, in the sense of linear types.)

All that remains is to write tree_flatten, tree_unflatten, and flatten_fun.

def flatten_fun(f, in_tree):
  store = Store()

  def flat_fun(*args_flat):
    pytree_args = tree_unflatten(in_tree, args_flat)
    out = f(*pytree_args)
    out_flat, out_tree = tree_flatten(out)
    store.set_value(out_tree)
    return out_flat

  return flat_fun, store

class Empty: pass
empty = Empty()

class Store:
  val = empty

  def set_value(self, val):
    assert self.val is empty
    self.val = val

  def __call__(self):
    return self.val
from collections.abc import Hashable, Iterable, Iterator
import itertools as it
from typing import Callable

class NodeType(NamedTuple):
  name: str
  to_iterable: Callable
  from_iterable: Callable

def register_pytree_node(ty: type, to_iter: Callable, from_iter: Callable
                         ) -> None:
  node_types[ty] = NodeType(str(ty), to_iter, from_iter)

node_types: dict[type, NodeType] = {}
register_pytree_node(tuple, lambda t: (None, t), lambda _, xs: tuple(xs))
register_pytree_node(list,  lambda l: (None, l), lambda _, xs:  list(xs))
register_pytree_node(dict,
                     lambda d: map(tuple, unzip2(sorted(d.items()))),
                     lambda keys, vals: dict(zip(keys, vals)))

class PyTreeDef(NamedTuple):
  node_type: NodeType
  node_metadata: Hashable
  child_treedefs: tuple['PyTreeDef', ...]

class Leaf: pass
leaf = Leaf()

def tree_flatten(x: Any) -> tuple[list[Any], PyTreeDef]:
  children_iter, treedef = _tree_flatten(x)
  return list(children_iter), treedef

def _tree_flatten(x: Any) -> tuple[Iterable, PyTreeDef]:
  node_type = node_types.get(type(x))
  if node_type:
    node_metadata, children = node_type.to_iterable(x)
    children_flat, child_trees = unzip2(map(_tree_flatten, children))
    flattened = it.chain.from_iterable(children_flat)
    return flattened, PyTreeDef(node_type, node_metadata, tuple(child_trees))
  else:
    return [x], leaf

def tree_unflatten(treedef: PyTreeDef, xs: list[Any]) -> Any:
  return _tree_unflatten(treedef, iter(xs))

def _tree_unflatten(treedef: PyTreeDef, xs: Iterator) -> Any:
  if treedef is leaf:
    return next(xs)
  else:
    children = (_tree_unflatten(t, xs) for t in treedef.child_treedefs)
    return treedef.node_type.from_iterable(treedef.node_metadata, children)
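
With those helpers in place, here’s a small standalone illustration of flatten_fun; the function g and its inputs are arbitrary:

g = lambda d: {'out': d['a'] + d['b']}
args = ({'a': 1.0, 'b': 2.0},)

args_flat, in_tree = tree_flatten(args)
g_flat, out_tree = flatten_fun(g, in_tree)
print(g_flat(*args_flat))   # [3.0] -- a flat list of outputs
print(out_tree())           # the PyTreeDef of g's output, filled in by the Store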

With this pytree-handling jvp implementation, we can now handle arbitrary input and output containers. That’ll come in handy with future transformations too!

def f(x):
  y = sin(x) * 2.
  z = - y + x
  return {'hi': z, 'there': [x, y]}

x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
{'hi': 2.7177599838802657, 'there': [3.0, 0.2822400161197344]}
{'hi': 2.979984993200891, 'there': [1.0, -1.9799849932008908]}
Vectorized batching with vmap#

First, a couple helper functions, one for producing mapped abstract values from unmapped ones (by removing an axis), and one for moving batch dimensions around:

def mapped_aval(batch_dim, aval):
  shape = list(aval.shape)
  del shape[batch_dim]
  return ShapedArray(tuple(shape), aval.dtype)

def move_batch_axis(axis_size, src, dst, x):
  if src is not_mapped:
    target_shape = list(np.shape(x))
    target_shape.insert(dst, axis_size)
    return broadcast(x, target_shape, [dst])
  elif src == dst:
    return x
  else:
    return moveaxis(x, src, dst)

def moveaxis(x, src: int, dst: int):
  perm = [i for i in range(np.ndim(x)) if i != src]
  perm.insert(dst, src)
  return transpose(x, perm)

The Tracer for vectorized batching carries a batched value and an optional integer indicating which axis (if any) is the batch axis.

from typing import Union

class NotMapped: pass
not_mapped = NotMapped()

BatchAxis = Union[NotMapped, int]

class BatchTracer(Tracer):
  def __init__(self, trace, val, batch_dim: BatchAxis):
    self._trace = trace
    self.val = val
    self.batch_dim = batch_dim

  @property
  def aval(self):
    if self.batch_dim is not_mapped:
      return get_aval(self.val)
    else:
      return mapped_aval(self.batch_dim, get_aval(self.val))

  def full_lower(self):
    if self.batch_dim is not_mapped:
      return full_lower(self.val)
    else:
      return self

class BatchTrace(Trace):
  pure = lift = lambda self, val: BatchTracer(self, val, not_mapped)

  def process_primitive(self, primitive, tracers, params):
    vals_in, bdims_in = unzip2((t.val, t.batch_dim) for t in tracers)
    vmap_rule = vmap_rules[primitive]
    val_outs, bdim_outs = vmap_rule(self.axis_size, vals_in, bdims_in, **params)
    return [BatchTracer(self, x, bd) for x, bd in zip(val_outs, bdim_outs)]

  @property
  def axis_size(self):
    return self.main.global_data

vmap_rules = {}

Here we’ve implemented the optional Tracer.full_lower method, which lets us peel off a batching tracer if it’s not needed because it doesn’t represent a batched value.

For BatchTrace, analogous to JVPTrace, the methods pure and lift just box a value in a BatchTracer with the minimal amount of context, which in this case is a batch_dim taking the sentinel value not_mapped. Notice we use the MainTrace’s interpreter-global data field to store the batch axis size.
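
Here’s a quick check of the move_batch_axis helper from above (shapes and values chosen arbitrarily):

xs = np.arange(6.).reshape(2, 3)                        # batched along axis 1, batch size 3
print(move_batch_axis(3, 1, 0, xs).shape)               # (3, 2): batch axis moved to the front
print(move_batch_axis(3, not_mapped, 0, 2.0).shape)     # (3,): unmapped scalar broadcast along the new batch axis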

Next we can define batching interpreter rules for each primitive:

from functools import partial

def binop_batching_rule(op, axis_size, vals_in, dims_in):
  (x, y), (x_bdim, y_bdim) = vals_in, dims_in
  if x_bdim != y_bdim:
    if x_bdim is not_mapped:
      x = move_batch_axis(axis_size, x_bdim, y_bdim, x)
      x_bdim = y_bdim
    else:
      y = move_batch_axis(axis_size, y_bdim, x_bdim, y)
  return [op(x, y)], [x_bdim]
vmap_rules[add_p] = partial(binop_batching_rule, add)
vmap_rules[mul_p] = partial(binop_batching_rule, mul)

def vectorized_unop_batching_rule(op, axis_size, vals_in, dims_in):
  (x,), (x_bdim,) = vals_in, dims_in
  return [op(x)], [x_bdim]
vmap_rules[sin_p] = partial(vectorized_unop_batching_rule, sin)
vmap_rules[cos_p] = partial(vectorized_unop_batching_rule, cos)
vmap_rules[neg_p] = partial(vectorized_unop_batching_rule, neg)

def reduce_sum_batching_rule(axis_size, vals_in, dims_in, *, axis):
  (x,), (x_bdim,) = vals_in, dims_in
  new_axis = tuple(ax + (x_bdim <= ax) for ax in axis)
  out_bdim = x_bdim - sum(ax < x_bdim for ax in axis)
  return [reduce_sum(x, new_axis)], [out_bdim]
vmap_rules[reduce_sum_p] = reduce_sum_batching_rule

Finally, we add a transformation API to kick off the trace:

def vmap_flat(f, in_axes, *args):
  axis_size, = {x.shape[ax] for x, ax in zip(args, in_axes)
                if ax is not not_mapped}
  with new_main(BatchTrace, axis_size) as main:
    trace = BatchTrace(main)
    tracers_in = [BatchTracer(trace, x, ax) if ax is not None else x
                  for x, ax in zip(args, in_axes)]
    outs = f(*tracers_in)
    tracers_out = [full_raise(trace, out) for out in outs]
    vals_out, bdims_out = unzip2((t.val, t.batch_dim) for t in tracers_out)
  outs_transposed = [move_batch_axis(axis_size, bdim, 0, val_out)
                     for val_out, bdim in zip(vals_out, bdims_out)]
  return outs_transposed

def vmap(f, in_axes):
  def batched_f(*args):
    args_flat, in_tree = tree_flatten(args)
    in_axes_flat, in_tree2 = tree_flatten(in_axes)
    if in_tree != in_tree2: raise TypeError
    f_flat, out_tree = flatten_fun(f, in_tree)
    outs_flat = vmap_flat(f_flat, in_axes_flat, *args_flat)
    return tree_unflatten(out_tree(), outs_flat)
  return batched_f
def add_one_to_a_scalar(scalar):
  assert np.ndim(scalar) == 0
  return 1 + scalar

vector_in = np.arange(3.)
vector_out = vmap(add_one_to_a_scalar, (0,))(vector_in)

print(vector_in)
print(vector_out)
[0. 1. 2.]
[1. 2. 3.]
def jacfwd(f, x):
  pushfwd = lambda v: jvp(f, (x,), (v,))[1]
  vecs_in = np.eye(np.size(x)).reshape(np.shape(x) * 2)
  return vmap(pushfwd, (0,))(vecs_in)

def f(x):
  return sin(x)

jacfwd(f, np.arange(3.))
array([[ 1.        ,  0.        , -0.        ],
       [ 0.        ,  0.54030231, -0.        ],
       [ 0.        ,  0.        , -0.41614684]])

That’s it for jvp and vmap!
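
As one last sanity check that the two transformations compose (reusing deriv and vmap from above):

print(vmap(deriv(sin), (0,))(np.arange(3.)))   # elementwise derivative of sin, i.e. cos([0., 1., 2.])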

Part 2: Jaxprs#

The next transformations on the horizon are jit for just-in-time compilation and vjp for reverse-mode autodiff. (grad is just a small wrapper around vjp.) Whereas jvp and vmap only needed each Tracer to carry a little bit of extra context, for both jit and vjp we need much richer context: we need to represent programs. That is, we need jaxprs!

Jaxprs are JAX’s internal intermediate representation of programs. They are explicitly typed, functional, first-order, and in A-normal form (ANF). We need a program representation for jit because the purpose of jit is to stage computation out of Python. For any computation we want to stage out, we need to be able to represent it as data, and build it up as we trace a Python function. Similarly, vjp needs a way to represent the computation for the backward pass of reverse-mode autodiff. We use the same jaxpr program representation for both needs.

(Building a program representation is the most free kind of trace-transformation, and so except for issues around handling native Python control flow, any transformation could be implemented by first tracing to a jaxpr and then interpreting the jaxpr.)

Jaxpr data structures#

The jaxpr term syntax is roughly:

jaxpr ::=
  { lambda <binder> , ... .
    let <eqn>
        ...
    in ( <atom> , ... ) }

binder ::= <var>:<array_type>
var ::= a | b | c | ...
atom ::= <var> | <literal>
literal ::= <int32> | <int64> | <float32> | <float64>

eqn ::= <binder> , ... = <primitive> [ <params> ] <atom> , ...

The syntax of types is:

jaxpr_type ::= [ <array_type> , ... ] -> [ <array_type> , ... ]
array_type ::= <dtype>[<shape>]
dtype ::= f32 | f64 | i32 | i64
shape ::= <int> , ...
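
For instance, the function f from Part 1 (y = sin(x) * 2; z = -y + x) would correspond to a jaxpr roughly like the following; this rendering is only illustrative, and the variable names are arbitrary:

{ lambda a:f64[] .
  let b:f64[] = sin a
      c:f64[] = mul b 2.0
      d:f64[] = neg c
      e:f64[] = add d a
  in ( e ) }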

How do we represent these as Python data structures? We reuse ShapedArrays to represent types, and we can represent the term syntax with a few Python structs:

class Var:
  aval: ShapedArray
  def __init__(self, aval): self.aval = aval

class Lit:
  val: Any
  aval: ShapedArray

  def __init__(self, val):
    self.aval = aval = raise_to_shaped(get_aval(val))
    self.val = np.array(val, aval.dtype)

Atom = Union[Var, Lit]

class JaxprEqn(NamedTuple):
  primitive: Primitive
  inputs: list[Atom]
  params: dict[str, Any]
  out_binders: list[Var]

class Jaxpr(NamedTuple):
  in_binders: list[Var]
  eqns: list[JaxprEqn]
  outs: list[Atom]

  def __hash__(self): return id(self)
  __eq__ = op.is_

def raise_to_shaped(aval):
  return ShapedArray(aval.shape, aval.dtype)

Type-checking a jaxpr involves checking that there are no unbound variables, that variables are only bound once, and that for each equation the type of the primitive application matches the type of the output binders.

class JaxprType(NamedTuple):
  in_types:  list[ShapedArray]
  out_types: list[ShapedArray]

  def __repr__(self):
    in_types = ', '.join(aval.str_short() for aval in self.in_types)
    out_types = ', '.join(aval.str_short() for aval in self.out_types)
    return f'({in_types}) -> ({out_types})'

def typecheck_jaxpr(jaxpr: Jaxpr) -> JaxprType:
  env: set[Var] = set()

  for v in jaxpr.in_binders:
    if v in env: raise TypeError
    env.add(v)

  for eqn in jaxpr.eqns:
    in_types = [typecheck_atom(env, x) for x in eqn.inputs]
    out_types = abstract_eval_rules[eqn.primitive](*in_types, **eqn.params)
    for out_binder, out_type in zip(eqn.out_binders, out_types):
      if not out_type == out_binder.aval: raise TypeError
    for out_binder in eqn.out_binders:
      if out_binder in env: raise TypeError
      env.add(out_binder)

  in_types = [v.aval for v in jaxpr.in_binders]
  out_types = [typecheck_atom(env, x) for x in jaxpr.outs]
  return JaxprType(in_types, out_types)

def typecheck_atom(env: set[Var], x: Atom) -> ShapedArray:
  if isinstance(x, Var):
    if x not in env: raise TypeError("unbound variable")
    return x.aval
  elif isinstance(x, Lit):
    return raise_to_shaped(get_aval(x.val))
  else:
    assert False

We can apply the function represented by a jaxpr to arguments with a simple interpreter.

def eval_jaxpr(jaxpr: Jaxpr, args: list[Any]) -> list[Any]:
  env: dict[Var, Any] = {}

  def read(x: Atom) -> Any:
    return env[x] if type(x) is Var else x.val

  def write(v: Var, val: Any) -> None:
    assert v not in env  # single-assignment
    env[v] = val

  map(write, jaxpr.in_binders, args)
  for eqn in jaxpr.eqns:
    in_vals = map(read, eqn.inputs)
    outs = bind(eqn.primitive, *in_vals, **eqn.params)
    map(write, eqn.out_binders, outs)
  return map(read, jaxpr.outs)

def jaxpr_as_fun(jaxpr: Jaxpr):
  return lambda *args: eval_jaxpr(jaxpr, args)

By using bind in the interpreter, this interpreter itself is traceable.
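
To make the data structures concrete, here is a tiny jaxpr for sin(x) * 2.0 built by hand and run through the interpreter above; the variable names (a, b, c, my_jaxpr) are arbitrary:

a = Var(ShapedArray((), np.dtype('float64')))
b = Var(ShapedArray((), np.dtype('float64')))
c = Var(ShapedArray((), np.dtype('float64')))
my_jaxpr = Jaxpr(in_binders=[a],
                 eqns=[JaxprEqn(sin_p, [a], {}, [b]),
                       JaxprEqn(mul_p, [b, Lit(2.0)], {}, [c])],
                 outs=[c])
print(jaxpr_as_fun(my_jaxpr)(3.0))   # [2 * sin(3.0)] ~= [0.28224]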

Building jaxprs with tracing#

Now that we have jaxprs as a data structure, we need ways to produce these from tracing Python code. In general there are two variants of how we trace to a jaxpr; jit uses one and vjp uses the other. We’ll start with the one used by jit, which is also used by control flow primitives like lax.cond, lax.while_loop, and lax.scan.

def split_list(lst: list[Any], n: int) -> tuple[list[Any], list[Any]]:
  assert 0 <= n <= len(lst)
  return lst[:n], lst[n:]

def partition_list(bs: list[bool], l: list[Any]) -> tuple[list[Any], list[Any]]:
  assert len(bs) == len(l)
  lists = lst1, lst2 = [], []
  for b, x in zip(bs, l):
    lists[b].append(x)
  return lst1, lst2
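As a quick sanity check of these helpers (not an original cell), note in particular that partition_list sends entries flagged True to the second list:

assert split_list([1, 2, 3, 4], 1) == ([1], [2, 3, 4])
assert partition_list([True, False, True], ['a', 'b', 'c']) == (['b'], ['a', 'c'])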
# NB: the analogous class in JAX is called 'DynamicJaxprTracer'
class JaxprTracer(Tracer):
  __slots__ = ['aval']
  aval: ShapedArray

  def __init__(self, trace, aval):
    self._trace = trace
    self.aval = aval

# NB: the analogous class in JAX is called 'DynamicJaxprTrace'
class JaxprTrace(Trace):
  def new_arg(self, aval: ShapedArray) -> JaxprTracer:
    aval = raise_to_shaped(aval)
    tracer = self.builder.new_tracer(self, aval)
    self.builder.tracer_to_var[id(tracer)] = Var(aval)
    return tracer

  def get_or_make_const_tracer(self, val: Any) -> JaxprTracer:
    tracer = self.builder.const_tracers.get(id(val))
    if tracer is None:
      tracer = self.builder.new_tracer(self, raise_to_shaped(get_aval(val)))
      self.builder.add_const(tracer, val)
    return tracer
  pure = lift = get_or_make_const_tracer

  def process_primitive(self, primitive, tracers, params):
    avals_in = [t.aval for t in tracers]
    avals_out = abstract_eval_rules[primitive](*avals_in, **params)
    out_tracers = [self.builder.new_tracer(self, a) for a in avals_out]
    inputs = [self.builder.getvar(t) for t in tracers]
    outvars = [self.builder.add_var(t) for t in out_tracers]
    self.builder.add_eqn(JaxprEqn(primitive, inputs, params, outvars))
    return out_tracers

  @property
  def builder(self):
    return self.main.global_data

# NB: in JAX, we instead attach abstract eval rules to Primitive instances
abstract_eval_rules = {}

Notice that we keep a builder object as interpreter-global data; it keeps track of the variables, constants, and eqns as we build up the jaxpr.

class JaxprBuilder:
  eqns: list[JaxprEqn]
  tracer_to_var: dict[int, Var]
  const_tracers: dict[int, JaxprTracer]
  constvals: dict[Var, Any]
  tracers: list[JaxprTracer]

  def __init__(self):
    self.eqns = []
    self.tracer_to_var = {}
    self.const_tracers = {}
    self.constvals = {}
    self.tracers = []

  def new_tracer(self, trace: JaxprTrace, aval: ShapedArray) -> JaxprTracer:
    tracer = JaxprTracer(trace, aval)
    self.tracers.append(tracer)
    return tracer

  def add_eqn(self, eqn: JaxprEqn) -> None:
    self.eqns.append(eqn)

  def add_var(self, tracer: JaxprTracer) -> Var:
    assert id(tracer) not in self.tracer_to_var
    var = self.tracer_to_var[id(tracer)] = Var(tracer.aval)
    return var

  def getvar(self, tracer: JaxprTracer) -> Var:
    var = self.tracer_to_var.get(id(tracer))
    assert var is not None
    return var

  def add_const(self, tracer: JaxprTracer, val: Any) -> Var:
    var = self.add_var(tracer)
    self.const_tracers[id(val)] = tracer
    self.constvals[var] = val
    return var

  def build(self, in_tracers: list[JaxprTracer], out_tracers: list[JaxprTracer]
            ) -> tuple[Jaxpr, list[Any]]:
    constvars, constvals = unzip2(self.constvals.items())
    t2v = lambda t: self.tracer_to_var[id(t)]
    in_binders = constvars + [t2v(t) for t in in_tracers]
    out_vars = [t2v(t) for t in out_tracers]
    jaxpr = Jaxpr(in_binders, self.eqns, out_vars)
    typecheck_jaxpr(jaxpr)
    jaxpr, constvals = _inline_literals(jaxpr, constvals)
    return jaxpr, constvals
def _inline_literals(jaxpr: Jaxpr, consts: list[Any]) -> tuple[Jaxpr, list[Any]]:
  const_binders, other_binders = split_list(jaxpr.in_binders, len(consts))
  scalars = [type(x) in jax_types and not get_aval(x).shape for x in consts]
  new_const_binders, lit_binders = partition_list(scalars, const_binders)
  new_consts, lit_vals = partition_list(scalars, consts)
  literals = dict(zip(lit_binders, map(Lit, lit_vals)))
  new_eqns = [JaxprEqn(eqn.primitive, [literals.get(x, x) for x in eqn.inputs],
                       eqn.params, eqn.out_binders) for eqn in jaxpr.eqns]
  new_outs = [literals.get(x, x) for x in jaxpr.outs]
  new_jaxpr = Jaxpr(new_const_binders + other_binders, new_eqns, new_outs)
  typecheck_jaxpr(new_jaxpr)
  return new_jaxpr, new_consts

The rules we need for JaxprTrace.process_primitive are essentially typing rules for primitive applications: given the primitive, its parameters, and types for the inputs, the rule must produce a type for the output, which is then packaged with the output JaxprTracer. We can use abstract evaluation rules for this same purpose, even though they can be more general (since abstract evaluation rules must accept ConcreteArray inputs, and since they need only return an upper bound on the set of possible outputs, they can produce ConcreteArray outputs as well). We’ll reuse these abstract evaluation rules for the other jaxpr-producing trace machinery, where the potential extra generality is useful.

def binop_abstract_eval(x: ShapedArray, y: ShapedArray) -> list[ShapedArray]:
  if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):
    raise TypeError
  if raise_to_shaped(x) != raise_to_shaped(y): raise TypeError
  return [ShapedArray(x.shape, x.dtype)]

abstract_eval_rules[add_p] = binop_abstract_eval
abstract_eval_rules[mul_p] = binop_abstract_eval

def compare_abstract_eval(x: ShapedArray, y: ShapedArray) -> list[ShapedArray]:
  if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):
    raise TypeError
  if x.shape != y.shape: raise TypeError
  return [ShapedArray(x.shape, np.dtype('bool'))]
abstract_eval_rules[greater_p] = compare_abstract_eval
abstract_eval_rules[less_p] = compare_abstract_eval

def vectorized_unop_abstract_eval(x: ShapedArray) -> list[ShapedArray]:
  return [ShapedArray(x.shape, x.dtype)]

abstract_eval_rules[sin_p] = vectorized_unop_abstract_eval
abstract_eval_rules[cos_p] = vectorized_unop_abstract_eval
abstract_eval_rules[neg_p] = vectorized_unop_abstract_eval

def reduce_sum_abstract_eval(x: ShapedArray, *, axis: tuple[int, ...]
                             ) -> list[ShapedArray]:
  axis_ = set(axis)
  new_shape = [d for i, d in enumerate(x.shape) if i not in axis_]
  return [ShapedArray(tuple(new_shape), x.dtype)]
abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval

def broadcast_abstract_eval(x: ShapedArray, *, shape: Sequence[int],
                            axes: Sequence[int]) -> list[ShapedArray]:
  return [ShapedArray(tuple(shape), x.dtype)]
abstract_eval_rules[broadcast_p] = broadcast_abstract_eval

To check our implementation of jaxprs, we can add a make_jaxpr transformation and a pretty-printer:

from functools import lru_cache

@lru_cache()  # ShapedArrays are hashable
def make_jaxpr_v1(f, *avals_in):
  avals_in, in_tree = tree_flatten(avals_in)
  f, out_tree = flatten_fun(f, in_tree)

  builder = JaxprBuilder()
  with new_main(JaxprTrace, builder) as main:
    trace = JaxprTrace(main)
    tracers_in = [trace.new_arg(aval) for aval in avals_in]
    outs = f(*tracers_in)
    tracers_out = [full_raise(trace, out) for out in outs]
    jaxpr, consts = builder.build(tracers_in, tracers_out)
  return jaxpr, consts, out_tree()
from collections import defaultdict
import string

class PPrint:
  lines: list[tuple[int, str]]

  def __init__(self, lines):
    self.lines = lines

  def indent(self, indent: int) -> 'PPrint':
    return PPrint([(indent + orig_indent, s) for orig_indent, s in self.lines])

  def __add__(self, rhs: 'PPrint') -> 'PPrint':
    return PPrint(self.lines + rhs.lines)

  def __rshift__(self, rhs: 'PPrint') -> 'PPrint':
    if not rhs.lines: return self
    if not self.lines: return rhs
    indent, s = self.lines[-1]
    indented_block = rhs.indent(indent + len(s))
    common_line = s + ' ' * rhs.lines[0][0] + rhs.lines[0][1]
    return PPrint(self.lines[:-1]
                  + [(indent, common_line)]
                  + indented_block.lines[1:])

  def __str__(self) -> str:
    return '\n'.join(' ' * indent + s for indent, s in self.lines)

def pp(s: Any) -> PPrint:
  return PPrint([(0, line) for line in str(s).splitlines()])

def vcat(ps: list[PPrint]) -> PPrint:
  return sum(ps, pp(''))

def pp_jaxpr(jaxpr: Jaxpr) -> PPrint:
  namegen = (''.join(s) for r in it.count(1)
             for s in it.permutations(string.ascii_lowercase, r))
  names = defaultdict(lambda: next(namegen))
  in_binders = ', '.join(var_str(names, x) for x in jaxpr.in_binders)
  eqns = vcat([pp_eqn(names, e) for e in jaxpr.eqns])
  outs = ', '.join(names[v] if isinstance(v, Var) else str(v.val)
                   for v in jaxpr.outs)
  return (pp(f'{{ lambda {in_binders} .') +
          ((pp('let ') >> eqns) + pp(f'in ( {outs} ) }}')).indent(2))

def var_str(names: defaultdict[Var, str], v: Var) -> str:
  return f'{names[v]}:{v.aval.str_short()}'

def pp_eqn(names: defaultdict[Var, str], eqn: JaxprEqn) -> PPrint:
  rule = pp_rules.get(eqn.primitive)
  if rule:
    return rule(names, eqn)
  else:
    lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
    rhs = (pp(eqn.primitive.name) >> pp_params(eqn.params) >>
           pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
                       for x in eqn.inputs)))
    return lhs >> pp(' = ') >> rhs

def pp_params(params: dict[str, Any]) -> PPrint:
  items = sorted(params.items())
  if items:
    return pp(' [ ') >> vcat([pp(f'{k}={v}') for k, v in items]) >> pp(' ] ')
  else:
    return pp(' ')

Jaxpr.__repr__ = lambda self: str(pp_jaxpr(self))
pp_rules: dict[Primitive, Callable[..., PPrint]] = {}
jaxpr, consts, _ = make_jaxpr_v1(lambda x: 2. * x, raise_to_shaped(get_aval(3.)))
print(jaxpr)
print(typecheck_jaxpr(jaxpr))
{ lambda a:float64[] .
  let b:float64[] = mul 2.0 a
  in ( b ) }
(float64[]) -> (float64[])

But there’s a limitation here: because of how find_top_trace operates by data dependence, make_jaxpr_v1 can’t stage out all the primitive operations performed by the Python callable it’s given. For example:

jaxpr, consts, _ = make_jaxpr_v1(lambda: mul(2., 2.))
print(jaxpr)
{ lambda  .
  let 
  in ( 4.0 ) }

This is precisely the issue that omnistaging fixed. We want to ensure that the JaxprTrace started by make_jaxpr is always applied, regardless of whether any inputs to bind are boxed in corresponding JaxprTracer instances. We can achieve this by employing the dynamic_trace global defined in Part 1:

@contextmanager
def new_dynamic(main: MainTrace):
  global dynamic_trace
  prev_dynamic_trace, dynamic_trace = dynamic_trace, main
  try:
    yield
  finally:
    dynamic_trace = prev_dynamic_trace

@lru_cache()
def make_jaxpr(f: Callable, *avals_in: ShapedArray,
               ) -> tuple[Jaxpr, list[Any], PyTreeDef]:
  avals_in, in_tree = tree_flatten(avals_in)
  f, out_tree = flatten_fun(f, in_tree)

  builder = JaxprBuilder()
  with new_main(JaxprTrace, builder) as main:
    with new_dynamic(main):
      trace = JaxprTrace(main)
      tracers_in = [trace.new_arg(aval) for aval in avals_in]
      outs = f(*tracers_in)
      tracers_out = [full_raise(trace, out) for out in outs]
      jaxpr, consts = builder.build(tracers_in, tracers_out)
  return jaxpr, consts, out_tree()

jaxpr, consts, _ = make_jaxpr(lambda: mul(2., 2.))
print(jaxpr)
{ lambda  .
  let a:float64[] = mul 2.0 2.0
  in ( a ) }

Using dynamic_trace this way is conceptually the same as stashing the current interpreter stack and starting a new one with the JaxprTrace at the bottom. That is, no interpreters lower in the stack than the dynamic_trace are applied (since JaxprTrace.process_primitive doesn’t call bind), though if the Python callable being traced to a jaxpr itself uses transformations then those can be pushed onto the interpreter stack above the JaxprTrace. But temporarily stashing the interpreter stack would break up the system state. The dynamic_trace tag achieves the same goals while keeping the system state simpler.

That’s it for jaxprs! With jaxprs in hand, we can implement the remaining major JAX features.

Part 3: jit, simplified#

While jit has a transformation-like API in that it accepts a Python callable as an argument, under the hood it’s really a higher-order primitive rather than a transformation. A primitive is higher-order when it’s parameterized by a function.

On-the-fly (“final style”) and staged (“initial style”) processing#

There are two options for how to handle higher-order primitives. Each requires a different approach to tracing and engenders different tradeoffs:

  1. On-the-fly processing, where bind takes a Python callable as an argument. We defer forming a jaxpr until as late as possible, namely until we’re running the final interpreter at the bottom of the interpreter stack. That way we can swap a JaxprTrace in at the bottom of the interpreter stack and thus stage out rather than execute all primitive operations. With this approach, transformations in the stack get applied as we execute the Python callable as usual. This approach can be very tricky to implement, but it’s as general as possible because it allows higher-order primitives not to raise the abstraction level of their arguments and thus allows data-dependent Python control flow. We refer to this approach as using a “final-style higher-order primitive” employing the discharge-at-tracing-time “final-style transformations” we’ve used so far.

  2. Staged processing, where bind takes a jaxpr as an argument. Before we call bind, in the primitive wrapper we can just use make_jaxpr to form a jaxpr up-front and be done with the Python callable entirely. In this case, make_jaxpr puts its JaxprTrace at the top of the interpreter stack, and no transformations lower in the stack, which might enter via closed-over Tracers, are applied to the Python callable as we trace it. (Transformations applied within the Python callable are applied as usual, being added to the stack above the JaxprTrace.) Instead, the transformations lower in the stack are later applied to the call primitive, and the call primitive’s rules must then transform the jaxpr itself. Because we trace to a jaxpr up-front, this approach can’t support data-dependent Python control flow, but it is more straightforward to implement. We refer to this kind of higher-order primitive as an “initial-style higher-order primitive”, and say that its jaxpr-processing transformation rules are “initial-style transformation rules.”

The latter approach fits for jit because we don’t need to support data-dependent Python control flow in the user-provided Python callable, as the whole purpose of jit is to stage computation out of Python to be executed by XLA. (In contrast, custom_jvp is a higher-order primitive in which we want to support data-dependent Python control flow.)

Historically, we started using the “initial-style” and “final-style” terminology after reading the typed tagless final interpreters paper, and jokingly referring to JAX as an implementation of “untyped tagful final interpreters.” We don’t claim to carry over (or understand) any deep meaning behind these terms; we loosely use “initial style” to mean “build an AST and then transform it”, and we use “final style” to mean “transform as we trace.” But it’s just imprecise yet sticky jargon.

With the initial-style approach, here’s the user-facing jit wrapper:

def jit(f):
  def f_jitted(*args):
    avals_in = [raise_to_shaped(get_aval(x)) for x in args]
    jaxpr, consts, out_tree = make_jaxpr(f, *avals_in)
    outs = bind(xla_call_p, *consts, *args, jaxpr=jaxpr, num_consts=len(consts))
    return tree_unflatten(out_tree, outs)
  return f_jitted

xla_call_p = Primitive('xla_call')

With any new primitive, we need to give it transformation rules, starting with its evaluation rule. When we evaluate an application of the xla_call primitive, we want to stage out the computation to XLA. That involves translating the jaxpr to an XLA HLO program, transferring the argument values to the XLA device, executing the XLA program, and transferring back the results. We’ll cache the XLA HLO compilation so that for each jitted function it only needs to be performed once per argument shape and dtype signature.

First, some utilities.

class IDHashable:
  val: Any

  def __init__(self, val):
    self.val = val

  def __hash__(self) -> int:
    return id(self.val)

  def __eq__(self, other):
    return type(other) is IDHashable and id(self.val) == id(other.val)
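As an aside (not one of the original cells), IDHashable exists so that lru_cache can key on object identity, even for values like numpy arrays that aren't hashable themselves:

a = np.arange(3.)                       # np.ndarray is not hashable itself
assert IDHashable(a) == IDHashable(a)   # same underlying object, so equal
assert hash(IDHashable(a)) == id(a)
b = np.arange(3.)
assert IDHashable(a) != IDHashable(b)   # equal contents, different identity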

Next, we’ll define the evaluation rule for xla_call:

from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
xe = xc._xla
xops = xc._xla.ops

def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int):
  consts, args = args[:num_consts], args[num_consts:]
  hashable_consts = tuple(map(IDHashable, consts))
  execute = xla_callable(IDHashable(jaxpr), hashable_consts)
  return execute(*args)
impl_rules[xla_call_p] = xla_call_impl

@lru_cache()
def xla_callable(hashable_jaxpr: IDHashable,
                 hashable_consts: tuple[IDHashable, ...]):
  jaxpr: Jaxpr = hashable_jaxpr.val
  typecheck_jaxpr(jaxpr)
  consts = [x.val for x in hashable_consts]
  in_avals = [v.aval for v in jaxpr.in_binders[len(consts):]]
  c = xc.XlaBuilder('xla_call')
  xla_consts = _xla_consts(c, consts)
  xla_params = _xla_params(c, in_avals)
  outs = jaxpr_subcomp(c, jaxpr, xla_consts + xla_params)
  out = xops.Tuple(c, outs)
  compiled = xb.get_backend(None).compile(
    xc._xla.mlir.xla_computation_to_mlir_module(c.build(out)))
  return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs])

def _xla_consts(c: xe.XlaBuilder, consts: list[Any]) -> list[xe.XlaOp]:
  unique_consts = {id(cnst): cnst for cnst in consts}
  xla_consts = {
      id_: xops.ConstantLiteral(c, cnst) for id_, cnst in unique_consts.items()}
  return [xla_consts[id(cnst)] for cnst in consts]

def _xla_params(c: xe.XlaBuilder, avals_in: list[ShapedArray]) -> list[xe.XlaOp]:
  return [xops.Parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)]

def _xla_shape(aval: ShapedArray) -> xe.Shape:
  return xc.Shape.array_shape(xc.dtype_to_etype(aval.dtype), aval.shape)

The main action is in xla_callable, which compiles a jaxpr into an XLA HLO program using jaxpr_subcomp, then returns a callable which executes the compiled program:

def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: list[xe.XlaOp]
                  ) -> list[xe.XlaOp]:
  env: dict[Var, xe.XlaOp] = {}

  def read(x: Atom) -> xe.XlaOp:
    return env[x] if type(x) is Var else xops.Constant(c, np.asarray(x.val))

  def write(v: Var, val: xe.XlaOp) -> None:
    env[v] = val

  map(write, jaxpr.in_binders, args)
  for eqn in jaxpr.eqns:
    in_avals = [x.aval for x in eqn.inputs]
    in_vals = map(read, eqn.inputs)
    rule = xla_translations[eqn.primitive]
    out_vals = rule(c, in_avals, in_vals, **eqn.params)
    map(write, eqn.out_binders, out_vals)
  return map(read, jaxpr.outs)

def execute_compiled(compiled, out_avals, *args):
  input_bufs = [input_handlers[type(x)](x) for x in args]
  out_bufs = compiled.execute(input_bufs)
  return [handle_result(aval, buf) for aval, buf in zip(out_avals, out_bufs)]

default_input_handler = xb.get_backend(None).buffer_from_pyval
input_handlers = {ty: default_input_handler for ty in
                  [bool, int, float, np.ndarray, np.float64, np.float32]}

def handle_result(aval: ShapedArray, buf):
  del aval  # Unused for now
  return np.asarray(buf)

xla_translations = {}

Notice that jaxpr_subcomp has the structure of a simple interpreter. That’s a common pattern: the way we process jaxprs is usually with an interpreter. And as with any interpreter, we need an interpretation rule for each primitive:

def direct_translation(op, c, in_avals, in_vals):
  del c, in_avals
  return [op(*in_vals)]

xla_translations[add_p] = partial(direct_translation, xops.Add)
xla_translations[mul_p] = partial(direct_translation, xops.Mul)
xla_translations[neg_p] = partial(direct_translation, xops.Neg)
xla_translations[sin_p] = partial(direct_translation, xops.Sin)
xla_translations[cos_p] = partial(direct_translation, xops.Cos)
xla_translations[greater_p] = partial(direct_translation, xops.Gt)
xla_translations[less_p] = partial(direct_translation, xops.Lt)

def reduce_sum_translation(c, in_avals, in_vals, *, axis):
  (x_aval,), (x,) = in_avals, in_vals
  zero = xops.ConstantLiteral(c, np.array(0, x_aval.dtype))
  subc = xc.XlaBuilder('add')
  shape = _xla_shape(ShapedArray((), x_aval.dtype))
  xops.Add(xops.Parameter(subc, 0, shape), xops.Parameter(subc, 1, shape))
  return [xops.Reduce(c, [x], [zero], subc.build(), axis)]
xla_translations[reduce_sum_p] = reduce_sum_translation

def broadcast_translation(c, in_avals, in_vals, *, shape, axes):
  x, = in_vals
  dims_complement = [i for i in range(len(shape)) if i not in axes]
  return [xops.BroadcastInDim(x, shape, dims_complement)]
xla_translations[broadcast_p] = broadcast_translation

With that, we can now use jit to stage out, compile, and execute programs with XLA!

@jit
def f(x, y):
  print('tracing!')
  return sin(x) * cos(y)
z = f(3., 4.)  # 'tracing!' prints the first time
print(z)
tracing!
-0.09224219304455371
z = f(4., 5.)  # 'tracing!' doesn't print, compilation cache hit!
print(z)
-0.21467624978306993
@jit
def f(x):
  return reduce_sum(x, axis=0)

print(f(np.array([1., 2., 3.])))
6.0
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

def deriv(f):
  return lambda x: jvp(f, (x,), (1.,))[1]

print(    deriv(deriv(f))(3.))
print(jit(deriv(deriv(f)))(3.))
0.2822400161197344
0.2822400161197344

Instead of implementing jit to first trace to a jaxpr and then to lower the jaxpr to XLA HLO, it might appear that we could have skipped the jaxpr step and just lowered to HLO while tracing. That is, perhaps we could have instead implemented jit with a Trace and Tracer that appended to the XLA HLO graph incrementally on each primitive bind. That’s correct for now, but won’t be possible when we introduce compiled SPMD computations because there we must know the number of replicas needed before compiling the program.

We haven’t yet defined any transformation rules for xla_call_p other than its evaluation rule. That is, we can’t yet do vmap-of-jit or jvp-of-jit or even jit-of-jit. Instead jit has to be at the “top level.” Let’s fix that!

def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):
  del num_consts  # Unused
  new_jaxpr, new_consts = jvp_jaxpr(jaxpr)
  outs = bind(xla_call_p, *new_consts, *primals, *tangents, jaxpr=new_jaxpr,
              num_consts=len(new_consts))
  n = len(outs) // 2
  primals_out, tangents_out = outs[:n], outs[n:]
  return primals_out, tangents_out
jvp_rules[xla_call_p] = xla_call_jvp_rule

@lru_cache()
def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:
  def jvp_traceable(*primals_and_tangents):
    n = len(primals_and_tangents) // 2
    primals, tangents = primals_and_tangents[:n], primals_and_tangents[n:]
    return jvp(jaxpr_as_fun(jaxpr), primals, tangents)

  in_avals = [v.aval for v in jaxpr.in_binders]
  new_jaxpr, new_consts, _ = make_jaxpr(jvp_traceable, *in_avals, *in_avals)
  return new_jaxpr, new_consts
def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):
  del num_consts  # Unused
  new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, tuple(dims_in))
  outs = bind(xla_call_p, *new_consts, *vals_in, jaxpr=new_jaxpr,
              num_consts=len(new_consts))
  return outs, [0] * len(outs)
vmap_rules[xla_call_p] = xla_call_vmap_rule

@lru_cache()
def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...]
               ) -> tuple[Jaxpr, list[Any]]:
  vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))
  in_avals = [unmapped_aval(axis_size, d, v.aval)
              for v, d in zip(jaxpr.in_binders, bdims_in)]
  new_jaxpr, new_consts, _ = make_jaxpr(vmap_traceable, *in_avals)
  return new_jaxpr, new_consts

def unmapped_aval(axis_size: int, batch_dim: BatchAxis, aval: ShapedArray
                  ) -> ShapedArray:
  if batch_dim is not_mapped:
    return aval
  else:
    shape = list(aval.shape)
    shape.insert(batch_dim, axis_size)
    return ShapedArray(tuple(shape), aval.dtype)
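As a quick check (not an original cell), unmapped_aval re-inserts the batch dimension at the given position, and leaves the aval alone when the input wasn't mapped:

aval = ShapedArray((4,), np.dtype('float64'))
assert unmapped_aval(3, 0, aval) == ShapedArray((3, 4), np.dtype('float64'))
assert unmapped_aval(3, not_mapped, aval) == aval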
def xla_call_abstract_eval_rule(*in_types, jaxpr, num_consts):
  del num_consts  # Unused
  jaxpr_type = typecheck_jaxpr(jaxpr)
  if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):
    raise TypeError
  return jaxpr_type.out_types
abstract_eval_rules[xla_call_p] = xla_call_abstract_eval_rule

def xla_call_translation(c, in_avals, in_vals, *, jaxpr, num_consts):
  del num_consts  # Only used at top-level.
  # Calling jaxpr_subcomp directly would inline. We generate a Call HLO instead.
  subc = xc.XlaBuilder('inner xla_call')
  xla_params = _xla_params(subc, in_avals)
  outs = jaxpr_subcomp(subc, jaxpr, xla_params)
  subc = subc.build(xops.Tuple(subc, outs))
  return destructure_tuple(c, xops.Call(c, subc, in_vals))
xla_translations[xla_call_p] = xla_call_translation

def destructure_tuple(c, tup):
  num_elements = len(c.get_shape(tup).tuple_shapes())
  return [xops.GetTupleElement(tup, i) for i in range(num_elements)]
@jit
def f(x):
  print('tracing!')
  y = sin(x) * 2.
  z = - y + x
  return z

x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
tracing!
2.7177599838802657
2.979984993200891
y, ydot = jvp(f, (x,), (xdot,))  # 'tracing!' not printed
ys = vmap(f, (0,))(np.arange(3.))
print(ys)
[ 0.         -0.68294197  0.18140515]

One piece missing is device memory persistence for arrays. That is, we’ve defined handle_result to transfer results back to CPU memory as NumPy arrays, but it’s often preferable to avoid transferring results just to transfer them back for the next operation. We can do that by introducing an Array class, which can wrap XLA buffers and otherwise duck-type numpy.ndarrays:

def handle_result(aval: ShapedArray, buf):  # noqa: F811
  return Array(aval, buf)

class Array:
  buf: Any
  aval: ShapedArray

  def __init__(self, aval, buf):
    self.aval = aval
    self.buf = buf

  dtype = property(lambda self: self.aval.dtype)
  shape = property(lambda self: self.aval.shape)
  ndim  = property(lambda self: self.aval.ndim)

  def __array__(self): return np.asarray(self.buf)
  def __repr__(self):  return repr(np.asarray(self.buf))
  def __str__(self):   return str(np.asarray(self.buf))

  _neg = staticmethod(neg)
  _add = staticmethod(add)
  _radd = staticmethod(add)
  _mul = staticmethod(mul)
  _rmul = staticmethod(mul)
  _gt = staticmethod(greater)
  _lt = staticmethod(less)
input_handlers[Array] = lambda x: x.buf

jax_types.add(Array)
@jit
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
2.7177599838802657
2.979984993200891
def pprint_xla_call(names: defaultdict[Var, str], eqn: JaxprEqn) -> PPrint:
  lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
  params_without_jaxpr = {k:v for k, v in eqn.params.items() if k != 'jaxpr'}
  rhs = (pp(eqn.primitive.name) >> pp_params(params_without_jaxpr) >>
         pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
                     for x in eqn.inputs)))
  return vcat([lhs >> pp(' = ') >> rhs,
               pp_jaxpr(eqn.params['jaxpr']).indent(2)])
pp_rules[xla_call_p] = pprint_xla_call

Part 4: linearize and vjp (and grad!)#

The linearize and vjp autodiff functions are built on jvp, but involve jaxprs as well. That’s because both involve staging out, or delaying, computation.

linearize#

In the case of linearize, we want to stage out the linear part of a jvp computation. That is, in terms of Haskell-like type signatures, if we have jvp : (a -> b) -> (a, T a) -> (b, T b), then we write linearize : (a -> b) -> a -> (b, T a -o T b), using T a to mean “the tangent type of a” and using the “lollipop” -o rather than the arrow -> to indicate a linear function. We define the semantics of linearize in terms of jvp too:

y, f_lin = linearize(f, x)
y_dot = f_lin(x_dot)

gives the same result for (y, y_dot) as

y, y_dot = jvp(f, (x,), (x_dot,))

where the application of f_lin does not redo any of the linearization work. We’ll represent the delayed linear part f_lin : T a -o T b as a jaxpr.

Tangentially, now that we have linear arrows -o, we can provide a slightly more informative type for jvp:

jvp : (a -> b) -> (UnrestrictedUse a, T a) -o (UnrestrictedUse b, T b)

Here we’re writing UnrestrictedUse just to indicate that we have a special pair where the first element can be used in an unrestricted (nonlinear) way. In conjunction with the linear arrow, this notation is just meant to express that the function jvp f uses its first input in a nonlinear way but its second input in a linear way, producing a corresponding nonlinear output (which can be used in a nonlinear way) paired with a linear output. This more refined type signature encodes the data dependencies in jvp f, which are useful for partial evaluation.

To build the f_lin jaxpr from a JVP, we need to perform partial evaluation: we evaluate all the primal values as we trace, but stage the tangent computations into a jaxpr. This is our second way to build jaxprs. But where make_jaxpr and its underlying JaxprTrace/JaxprTracer interpreters aim to stage out every primitive bind, this second approach stages out only those primitive binds with a data dependence on tangent inputs.

First, some utilities:

def split_half(lst: list[Any]) -> tuple[list[Any], list[Any]]:
  assert not len(lst) % 2
  return split_list(lst, len(lst) // 2)

def merge_lists(which: list[bool], l1: list[Any], l2: list[Any]) -> list[Any]:
  l1, l2 = iter(l1), iter(l2)
  out = [next(l2) if b else next(l1) for b in which]
  assert next(l1, None) is next(l2, None) is None
  return out
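A brief sanity check of these helpers (not an original cell): split_half splits an even-length list down the middle, and merge_lists interleaves elements of l2 wherever which is True:

assert split_half([1, 2, 3, 4]) == ([1, 2], [3, 4])
assert merge_lists([False, True, False], ['a', 'b'], ['X']) == ['a', 'X', 'b']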

Next, we’ll write linearize by combining jvp together with a general partial evaluation transformation, to be added next:

def linearize_flat(f, *primals_in):
  pvals_in = ([PartialVal.known(x) for x in primals_in] +
              [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])
  def f_jvp(*primals_tangents_in):
    primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))
    return [*primals_out, *tangents_out]
  jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)
  primal_pvals, _ = split_half(pvals_out)
  assert all(pval.is_known for pval in primal_pvals)
  primals_out = [pval.const for pval in primal_pvals]
  f_lin = lambda *tangents: eval_jaxpr(jaxpr, [*consts, *tangents])
  return primals_out, f_lin

def linearize(f, *primals_in):
  primals_in_flat, in_tree = tree_flatten(primals_in)
  f, out_tree = flatten_fun(f, in_tree)
  primals_out_flat, f_lin_flat = linearize_flat(f, *primals_in_flat)
  primals_out = tree_unflatten(out_tree(), primals_out_flat)

  def f_lin(*tangents_in):
    tangents_in_flat, in_tree2 = tree_flatten(tangents_in)
    if in_tree != in_tree2: raise TypeError
    tangents_out_flat = f_lin_flat(*tangents_in_flat)
    return tree_unflatten(out_tree(), tangents_out_flat)

  return primals_out, f_lin

def vspace(aval: ShapedArray) -> ShapedArray:
  return raise_to_shaped(aval)  # TODO handle integers?

Now we turn to the general partial evaluation transformation. The goal is to accept a Python callable and a list of inputs, some known and some unknown, and to produce (1) all the outputs which can be computed from the known inputs, together with (2) a jaxpr representing the part of the Python callable’s computation which can only be performed after the remaining inputs are known.

This transformation is tricky to summarize in a type signature. If we assume the input function’s type signature is (a1, a2) -> (b1, b2), where a1 and a2 represent the known and unknown inputs, respectively, and where b1 only has a data dependency on a1 while b2 has some data dependency on a2, then we might write

partial_eval : ((a1, a2) -> (b1, b2)) -> a1 -> exists r. (b1, r, (r, a2) -> b2)

In words, given values for the inputs of type a1, partial_eval produces the outputs of type b1 along with “residual” values of existentially-quantified type r representing the intermediates required to complete the computation in the second stage. It also produces a function of type (r, a2) -> b2 which accepts the residual values as well as the remaining inputs and produces the remaining outputs.

We like to think of partial evaluation as “unzipping” one computation into two. For example, consider this jaxpr:

{ lambda a:float64[] .
  let b:float64[] = sin a
      c:float64[] = neg b
  in ( c ) }

A jaxpr for the JVP would look like:

{ lambda a:float64[] b:float64[] .
  let c:float64[] = sin a
      d:float64[] = cos a
      e:float64[] = mul d b
      f:float64[] = neg c
      g:float64[] = neg e
  in ( f, g ) }

If we imagine applying partial evaluation to this jaxpr with the first input known and the second unknown, we end up ‘unzipping’ the JVP jaxpr into primal and tangent jaxprs:

{ lambda a:float64[] .
  let c:float64[] = sin a
      d:float64[] = cos a
      f:float64[] = neg c
  in ( f, d ) }
{ lambda d:float64[] b:float64[] .
  let e:float64[] = mul d b
      g:float64[] = neg e
  in ( g ) }

This second jaxpr represents the linear computation that we want from linearize.

However, unlike in this jaxpr example, we want the computation on known values to occur while evaluating the input Python callable. That is, rather than forming a jaxpr for the entire function (a1, a2) -> (b1, b2), staging all operations out of Python first before sorting out what can be evaluated now and what must be delayed, we want only to form a jaxpr for those operations that must be delayed due to a dependence on unknown inputs. In the context of automatic differentiation, this is the feature that ultimately enables us to handle functions like grad(lambda x: x**2 if x > 0 else 0.). Python control flow works because partial evaluation keeps the primal computation in Python. As a consequence, our Trace and Tracer subclasses must sort out on the fly which operations can be evaluated immediately and which must be staged out into a jaxpr.

First, we start with a PartialVal class, which represents a value that can be either known or unknown:

class PartialVal(NamedTuple):
  aval: ShapedArray
  const: Optional[Any]

  @classmethod
  def known(cls, val: Any):
    return PartialVal(get_aval(val), val)

  @classmethod
  def unknown(cls, aval: ShapedArray):
    return PartialVal(aval, None)

  is_known   = property(lambda self: self.const is not None)
  is_unknown = property(lambda self: self.const is     None)
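A small illustration (not an original cell) of the two kinds of partial values:

pv_known = PartialVal.known(3.)
pv_unknown = PartialVal.unknown(ShapedArray((), np.dtype('float64')))
assert pv_known.is_known and not pv_known.is_unknown
assert pv_unknown.is_unknown and pv_unknown.const is None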

Partial evaluation will take a list of PartialVals representing inputs, and return a list of PartialVal outputs along with a jaxpr representing the delayed computation:

def partial_eval_flat(f: Callable, pvals_in: list[PartialVal]
                      ) -> tuple[Jaxpr, list[PartialVal], list[Any]]:
  with new_main(PartialEvalTrace) as main:
    trace = PartialEvalTrace(main)
    tracers_in = [trace.new_arg(pval) for pval in pvals_in]
    outs = f(*tracers_in)
    tracers_out = [full_raise(trace, out) for out in outs]
    pvals_out = [t.pval for t in tracers_out]
    unk_tracers_in  = [t for t in tracers_in  if t.pval.is_unknown]
    unk_tracers_out = [t for t in tracers_out if t.pval.is_unknown]
    jaxpr, consts = tracers_to_jaxpr(unk_tracers_in, unk_tracers_out)
  return jaxpr, pvals_out, consts

Next we need to implement PartialEvalTrace and its PartialEvalTracer. This interpreter will build a jaxpr on the fly while tracking data dependencies. To do so, it builds a bipartite directed acyclic graph (DAG) between PartialEvalTracer nodes, representing staged-out values, and JaxprRecipe nodes, representing formulas for how to compute some values from others. One kind of recipe is a JaxprEqnRecipe, corresponding to a JaxprEqn’s primitive application, but we also have recipe types for constants and lambda binders:

from weakref import ref, ReferenceType

class LambdaBindingRecipe(NamedTuple):
  pass

class ConstRecipe(NamedTuple):
  val: Any

class JaxprEqnRecipe(NamedTuple):
  prim: Primitive
  tracers_in: list['PartialEvalTracer']
  params: dict[str, Any]
  avals_out: list[ShapedArray]
  tracer_refs_out: list['ReferenceType[PartialEvalTracer]']

JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]
class PartialEvalTracer(Tracer):
  pval: PartialVal
  recipe: Optional[JaxprRecipe]

  def __init__(self, trace, pval, recipe):
    self._trace = trace
    self.pval = pval
    self.recipe = recipe

  aval = property(lambda self: self.pval.aval)

  def full_lower(self):
    if self.pval.is_known:
      return full_lower(self.pval.const)
    return self

The PartialEvalTrace contains the logic for constructing the graph of JaxprRecipes and PartialEvalTracers. Each argument corresponds to a LambdaBindingRecipe leaf node, and each constant is a ConstRecipe leaf node holding a reference to the constant. All other tracers and recipes come from process_primitive, which forms tracers with JaxprEqnRecipes.

For most primitives, the process_primitive logic is straightforward: if all inputs are known then we can bind the primitive on the known values (evaluating it in Python) and avoid forming tracers corresponding to the output. If instead any input is unknown, we stage out into a JaxprEqnRecipe representing the primitive application. To build the tracers representing unknown outputs, we need avals, which we get from the abstract eval rules. (Notice that tracers reference JaxprEqnRecipes, and JaxprEqnRecipes reference tracers; we avoid circular garbage by using weakrefs.)

That process_primitive logic applies to most primitives, but xla_call_p requires recursive treatment. So we special-case its rule in a partial_eval_rules dict.

class PartialEvalTrace(Trace):
  def new_arg(self, pval: PartialVal) -> Any:
    return PartialEvalTracer(self, pval, LambdaBindingRecipe())

  def lift(self, val: Any) -> PartialEvalTracer:
    return PartialEvalTracer(self, PartialVal.known(val), None)
  pure = lift

  def instantiate_const(self, tracer: PartialEvalTracer) -> PartialEvalTracer:
    if tracer.pval.is_unknown:
      return tracer
    else:
      pval = PartialVal.unknown(raise_to_shaped(tracer.aval))
      return PartialEvalTracer(self, pval, ConstRecipe(tracer.pval.const))

  def process_primitive(self, primitive, tracers, params):
    if all(t.pval.is_known for t in tracers):
      return bind(primitive, *map(full_lower, tracers), **params)
    rule = partial_eval_rules.get(primitive)
    if rule: return rule(self, tracers, **params)
    tracers_in = [self.instantiate_const(t) for t in tracers]
    avals_in = [t.aval for t in tracers_in]
    avals_out = abstract_eval_rules[primitive](*avals_in, **params)
    tracers_out = [PartialEvalTracer(self, PartialVal.unknown(aval), None)
                   for aval in avals_out]
    eqn = JaxprEqnRecipe(primitive, tracers_in, params, avals_out,
                         map(ref, tracers_out))
    for t in tracers_out: t.recipe = eqn
    return tracers_out

partial_eval_rules = {}

Now that we can build graph representations of jaxprs with PartialEvalTrace, we need a mechanism to convert the graph representation to a standard jaxpr. The jaxpr corresponds to a topological sort of the graph.

def tracers_to_jaxpr(tracers_in: list[PartialEvalTracer],
                     tracers_out: list[PartialEvalTracer]):
  tracer_to_var: dict[int, Var] = {id(t): Var(raise_to_shaped(t.aval))
                                   for t in tracers_in}
  constvar_to_val: dict[int, Any] = {}
  constid_to_var: dict[int, Var] = {}
  processed_eqns: set[int] = set()
  eqns: list[JaxprEqn] = []
  for t in toposort(tracers_out, tracer_parents):
    if isinstance(t.recipe, LambdaBindingRecipe):
      assert id(t) in set(map(id, tracers_in))
    elif isinstance(t.recipe, ConstRecipe):
      val = t.recipe.val
      var = constid_to_var.get(id(val))
      if var is None:
        aval = raise_to_shaped(get_aval(val))
        var = constid_to_var[id(val)] = Var(aval)
        constvar_to_val[var] = val
      tracer_to_var[id(t)] = var
    elif isinstance(t.recipe, JaxprEqnRecipe):
      if id(t.recipe) not in processed_eqns:
        eqns.append(recipe_to_eqn(tracer_to_var, t.recipe))
        processed_eqns.add(id(t.recipe))
    else:
      raise TypeError(t.recipe)

  constvars, constvals = unzip2(constvar_to_val.items())
  in_binders = constvars + [tracer_to_var[id(t)] for t in tracers_in]
  out_vars = [tracer_to_var[id(t)] for t in tracers_out]
  jaxpr = Jaxpr(in_binders, eqns, out_vars)
  typecheck_jaxpr(jaxpr)
  return jaxpr, constvals

def recipe_to_eqn(tracer_to_var: dict[int, Var], recipe: JaxprEqnRecipe
                  ) -> JaxprEqn:
  inputs = [tracer_to_var[id(t)] for t in recipe.tracers_in]
  out_binders = [Var(aval) for aval in recipe.avals_out]
  for t_ref, var in zip(recipe.tracer_refs_out, out_binders):
    if t_ref() is not None: tracer_to_var[id(t_ref())] = var
  return JaxprEqn(recipe.prim, inputs, recipe.params, out_binders)

def tracer_parents(t: PartialEvalTracer) -> list[PartialEvalTracer]:
  return t.recipe.tracers_in if isinstance(t.recipe, JaxprEqnRecipe) else []
def toposort(out_nodes: list[Any], parents: Callable[[Any], list[Any]]):
  if not out_nodes: return []
  out_nodes = remove_duplicates(out_nodes)

  child_counts = {}
  stack = list(out_nodes)
  while stack:
    node = stack.pop()
    if id(node) in child_counts:
      child_counts[id(node)] += 1
    else:
      child_counts[id(node)] = 1
      stack.extend(parents(node))
  for node in out_nodes:
    child_counts[id(node)] -= 1

  sorted_nodes = []
  childless_nodes = [node for node in out_nodes if not child_counts[id(node)]]
  while childless_nodes:
    node = childless_nodes.pop()
    sorted_nodes.append(node)
    for parent in parents(node):
      if child_counts[id(parent)] == 1:
        childless_nodes.append(parent)
      else:
        child_counts[id(parent)] -= 1

  sorted_nodes = sorted_nodes[::-1]
  check_toposort(sorted_nodes, parents)
  return sorted_nodes

def remove_duplicates(lst):
  seen = set()
  return [x for x in lst if id(x) not in seen and not seen.add(id(x))]

def check_toposort(nodes: list[Any], parents: Callable[[Any], list[Any]]):
  seen = set()
  for node in nodes:
    assert all(id(parent) in seen for parent in parents(node))
    seen.add(id(node))
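As a quick check of the topological sort helper (not an original cell), we can run it on a tiny hand-built dependency graph; the Node class here is just an illustrative stand-in for the tracer/recipe graph:

class Node:
  def __init__(self, name, parents): self.name, self.parents = name, parents

a = Node('a', []); b = Node('b', [a]); c = Node('c', [a, b])
order = toposort([c], lambda n: n.parents)
assert [n.name for n in order] == ['a', 'b', 'c']  # parents come before children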

Now we can linearize!

y, sin_lin = linearize(sin, 3.)
print(y, sin(3.))
print(sin_lin(1.), cos(3.))
0.1411200080598672 0.1411200080598672
-0.9899924966004454 -0.9899924966004454

To handle linearize-of-jit, we still need to write a partial evaluation rule for xla_call_p. Other than tracer bookkeeping, the main task is to perform partial evaluation of a jaxpr, ‘unzipping’ it into two jaxprs.

There are actually two rules to write: one for trace-time partial evaluation, which we’ll call xla_call_partial_eval, and one for partial evaluation of jaxprs, which we’ll call xla_call_peval_eqn.

def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
  del num_consts  # Unused
  in_unknowns = [not t.pval.is_known for t in tracers]
  jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, in_unknowns)
  known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)
  known_vals = [t.pval.const for t in known_tracers]
  outs1_res = bind(xla_call_p, *known_vals, jaxpr=jaxpr1, num_consts=0)
  outs1, res = split_list(outs1_res, len(jaxpr1.outs) - num_res)
  res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]
  outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)
           for v in jaxpr2.outs]
  eqn = JaxprEqnRecipe(xla_call_p, res_tracers + unknown_tracers,
                       dict(jaxpr=jaxpr2, num_consts=0),
                       [v.aval for v in jaxpr2.outs], map(ref, outs2))
  for t in outs2: t.recipe = eqn
  return merge_lists(out_unknowns, outs1, outs2)
partial_eval_rules[xla_call_p] = xla_call_partial_eval

def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool],
                       instantiate: Optional[list[bool]] = None,
                       ) -> tuple[Jaxpr, Jaxpr, list[bool], int]:
  env: dict[Var, bool] = {}
  residuals: set[Var] = set()

  def read(x: Atom) -> bool:
    return type(x) is Var and env[x]

  def write(unk: bool, v: Var) -> None:
    env[v] = unk

  def new_res(x: Atom) -> Atom:
    if type(x) is Var: residuals.add(x)
    return x

  eqns1, eqns2 = [], []
  map(write, in_unknowns, jaxpr.in_binders)
  for eqn in jaxpr.eqns:
    unks_in = map(read, eqn.inputs)
    rule = partial_eval_jaxpr_rules.get(eqn.primitive)
    if rule:
      eqn1, eqn2, unks_out, res = rule(unks_in, eqn)
      eqns1.append(eqn1); eqns2.append(eqn2); residuals.update(res)
      map(write, unks_out, eqn.out_binders)
    elif any(unks_in):
      inputs = [v if unk else new_res(v) for unk, v in zip(unks_in, eqn.inputs)]
      eqns2.append(JaxprEqn(eqn.primitive, inputs, eqn.params, eqn.out_binders))
      map(partial(write, True), eqn.out_binders)
    else:
      eqns1.append(eqn)
      map(partial(write, False), eqn.out_binders)
  out_unknowns = map(read, jaxpr.outs)
  if instantiate is not None:
    for v, uk, inst in zip(jaxpr.outs, out_unknowns, instantiate):
      if inst and not uk: new_res(v)
    out_unknowns = map(op.or_, out_unknowns, instantiate)

  residuals, num_res = list(residuals), len(residuals)
  assert all(type(v) is Var for v in residuals), residuals

  ins1, ins2 = partition_list(in_unknowns, jaxpr.in_binders)
  outs1, outs2 = partition_list(out_unknowns, jaxpr.outs)

  jaxpr1 = Jaxpr(ins1, eqns1, outs1 + residuals)
  jaxpr2 = Jaxpr(residuals + ins2, eqns2, outs2)
  typecheck_partial_eval_jaxpr(jaxpr, in_unknowns, out_unknowns, jaxpr1, jaxpr2)

  return jaxpr1, jaxpr2, out_unknowns, num_res

def typecheck_partial_eval_jaxpr(jaxpr, unks_in, unks_out, jaxpr1, jaxpr2):
  jaxprty = typecheck_jaxpr(jaxpr)    # (a1,  a2) -> (b1, b2 )
  jaxpr1ty = typecheck_jaxpr(jaxpr1)  #  a1       -> (b1, res)
  jaxpr2ty = typecheck_jaxpr(jaxpr2)  # (res, a2) -> b2

  a1, a2 = partition_list(unks_in, jaxprty.in_types)
  b1, b2 = partition_list(unks_out, jaxprty.out_types)
  b1_, res = split_list(jaxpr1ty.out_types, len(b1))
  res_, a2_ = split_list(jaxpr2ty.in_types, len(res))
  b2_ = jaxpr2ty.out_types

  if jaxpr1ty.in_types != a1: raise TypeError
  if jaxpr2ty.out_types != b2: raise TypeError
  if b1 != b1_: raise TypeError
  if res != res_: raise TypeError
  if a2 != a2_: raise TypeError
  if b2 != b2_: raise TypeError

partial_eval_jaxpr_rules = {}

def xla_call_peval_eqn(unks_in: list[bool], eqn: JaxprEqn,
                       ) -> tuple[JaxprEqn, JaxprEqn, list[bool], list[Var]]:
  jaxpr = eqn.params['jaxpr']
  jaxpr1, jaxpr2, unks_out, num_res = partial_eval_jaxpr(jaxpr, unks_in)
  ins1, ins2 = partition_list(unks_in, eqn.inputs)
  out_binders1, out_binders2 = partition_list(unks_out, eqn.out_binders)
  residuals = [Var(v.aval) for v in jaxpr2.in_binders[:num_res]]
  eqn1 = JaxprEqn(xla_call_p, ins1, dict(jaxpr=jaxpr1, num_consts=0),
                  out_binders1 + residuals)
  eqn2 = JaxprEqn(xla_call_p, residuals + ins2,
                  dict(jaxpr=jaxpr2, num_consts=0), out_binders2)
  return eqn1, eqn2, unks_out, residuals
partial_eval_jaxpr_rules[xla_call_p] = xla_call_peval_eqn

With that, we can compose linearize and jit however we like:

@jit
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

y, f_lin = linearize(f, 3.)
y_dot = f_lin(1.)
print(y, y_dot)
2.7177599838802657 2.979984993200891
@jit
def f(x):
  y = sin(x) * 2.
  z = g(x, y)
  return z

@jit
def g(x, y):
  return cos(x) + y

y, f_lin = linearize(f, 3.)
y_dot = f_lin(1.)
print(y, y_dot)
-0.7077524804807109 -2.121105001260758
vjp and grad#

The vjp transformation works a lot like linearize. Its type signature is analogous:

linearize : (a -> b) -> a -> (b, T a -o T b)
vjp       : (a -> b) -> a -> (b, T b -o T a)

The only difference is that we transpose the linear part of the computation before returning it, so that it goes from type T a -o T b to type T b -o T a. That is, we’ll implement vjp as, essentially,

def vjp(f, x):
  y, f_lin = linearize(f, x)
  f_vjp = lambda y_bar: transpose(f_lin)(y_bar)
  return y, f_vjp

Since we have the linear computation as a jaxpr, not just a Python callable, we can implement the transpose transformation as a jaxpr interpreter.

def vjp_flat(f, *primals_in):
  pvals_in = ([PartialVal.known(x) for x in primals_in] +
              [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])
  primal_pvals_in, tangent_pvals_in = split_half(pvals_in)
  def f_jvp(*primals_tangents_in):
    primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))
    return [*primals_out, *tangents_out]
  jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)  # linearize
  primal_pvals, _ = split_half(pvals_out)
  assert all(pval.is_known for pval in primal_pvals)
  primals_out = [pval.const for pval in primal_pvals]
  transpose_inputs = consts + [UndefPrimal(p.aval) for p in tangent_pvals_in]
  f_vjp = lambda *cts: eval_jaxpr_transposed(jaxpr, transpose_inputs, cts)
  return primals_out, f_vjp

def vjp(f, *primals_in):
  primals_in_flat, in_tree = tree_flatten(primals_in)
  f, out_tree = flatten_fun(f, in_tree)
  primals_out_flat, f_vjp_flat = vjp_flat(f, *primals_in_flat)
  primals_out = tree_unflatten(out_tree(), primals_out_flat)

  def f_vjp(*cotangents_out):
    cotangents_out_flat, _ = tree_flatten(cotangents_out)
    cotangents_in_flat = f_vjp_flat(*cotangents_out_flat)
    return tree_unflatten(in_tree, cotangents_in_flat)

  return primals_out, f_vjp

class UndefPrimal(NamedTuple):
  aval: ShapedArray

register_pytree_node(UndefPrimal,
                     lambda u: (u.aval, ()),
                     lambda aval, _: UndefPrimal(aval))

We use UndefPrimal instances to indicate which arguments we want to transpose with respect to. These arise because, once we're explicit about closed-over values, we generally want to transpose functions of type a -> b -o c to functions of type a -> c -o b. Even more generally, the inputs with respect to which the function is linear could be scattered through the argument list. So we indicate the linear positions using UndefPrimal. We register UndefPrimal as a pytree node because the pytree mechanism gives a handy way to prune these placeholders out of argument lists.
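To see the pruning in action, here's a small illustration (not an original cell): flattening a structure that contains an UndefPrimal yields leaves only for the concrete values:

aval = raise_to_shaped(get_aval(3.))
leaves, _ = tree_flatten((UndefPrimal(aval), 2.0))
print(leaves)  # expect [2.0]; the placeholder contributes no leaves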

Next, we can write eval_jaxpr_transposed, along with transpose rules for all primitives which can be linear in at least one argument:

# NB: the analogous function in JAX is called 'backward_pass'
def eval_jaxpr_transposed(jaxpr: Jaxpr, args: list[Any], cotangents: list[Any]
                          ) -> list[Any]:
  primal_env: dict[Var, Any] = {}
  ct_env: dict[Var, Any] = {}

  def read_primal(x: Atom) -> Any:
    return primal_env.get(x, UndefPrimal(x.aval)) if type(x) is Var else x.val

  def write_primal(v: Var, val: Any) -> None:
    if type(val) is not UndefPrimal:
      primal_env[v] = val

  def read_cotangent(v: Var) -> Any:
    return ct_env.pop(v, np.zeros(v.aval.shape, v.aval.dtype))

  def write_cotangent(x: Atom, val: Any):
    if type(x) is Var and val is not None:
      ct_env[x] = add(ct_env[x], val) if x in ct_env else val

  map(write_primal, jaxpr.in_binders, args)
  map(write_cotangent, jaxpr.outs, cotangents)
  for eqn in jaxpr.eqns[::-1]:
    primals_in = map(read_primal, eqn.inputs)
    cts_in = map(read_cotangent, eqn.out_binders)
    rule = transpose_rules[eqn.primitive]
    cts_out = rule(cts_in, *primals_in, **eqn.params)
    map(write_cotangent, eqn.inputs, cts_out)

  return [read_cotangent(v) for v, x in zip(jaxpr.in_binders, args)
          if type(x) is UndefPrimal]

transpose_rules = {}
def mul_transpose_rule(cts, x, y):
  z_bar, = cts
  assert (type(x) is UndefPrimal) ^ (type(y) is UndefPrimal)
  return [mul(z_bar, y), None] if type(x) is UndefPrimal else [None, mul(x, z_bar)]
transpose_rules[mul_p] = mul_transpose_rule

def neg_transpose_rule(cts, x):
  ybar, = cts
  assert type(x) is UndefPrimal
  return [neg(ybar)]
transpose_rules[neg_p] = neg_transpose_rule

def add_transpose_rule(cts, x, y):
  z_bar, = cts
  return [z_bar, z_bar]
transpose_rules[add_p] = add_transpose_rule

def reduce_sum_transpose_rule(cts, x, *, axis):
  y_bar, = cts
  return [broadcast(y_bar, x.aval.shape, axis)]
transpose_rules[reduce_sum_p] = reduce_sum_transpose_rule

def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):
  del num_consts  # Unused
  undef_primals = [type(x) is UndefPrimal for x in invals]
  transposed_jaxpr, new_consts = transpose_jaxpr(jaxpr, tuple(undef_primals))
  residuals, _ = partition_list(undef_primals, invals)
  outs = bind(xla_call_p, *new_consts, *residuals, *cts,
              jaxpr=transposed_jaxpr, num_consts=len(new_consts))
  outs = iter(outs)
  return [next(outs) if undef else None for undef in undef_primals]
transpose_rules[xla_call_p] = xla_call_transpose_rule

@lru_cache()
def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...]
                    ) -> tuple[Jaxpr, list[Any]]:
  avals_in, avals_out = typecheck_jaxpr(jaxpr)
  traceable = partial(eval_jaxpr_transposed, jaxpr)
  args = [UndefPrimal(a) if u else a for a, u in zip(avals_in, undef_primals)]
  trans_jaxpr, consts, _ = make_jaxpr(traceable, tuple(args), tuple(avals_out))
  typecheck_jaxpr(trans_jaxpr)
  return trans_jaxpr, consts

Now that we can linearize and transpose, we can finally write grad:

def grad(f):
  def gradfun(x, *xs):
    y, f_vjp = vjp(f, x, *xs)
    if np.shape(y) != (): raise TypeError
    x_bar, *_ = f_vjp(np.ones(np.shape(y), np.result_type(y)))
    return x_bar
  return gradfun
y, f_vjp = vjp(sin, 3.)
print(f_vjp(1.), cos(3.))
(-0.9899924966004454,) -0.9899924966004454
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

print(grad(f)(3.))
2.979984993200891
@jit
def f(x):
  y = x * 2.
  z = g(y)
  return z

@jit
def g(x):
  return cos(x) * 2.

print(grad(f)(3.))
1.1176619927957034

Here’s something of a compositionality stress test:

# from core_test.py fun_with_nested_calls_2
def foo(x):
  @jit
  def bar(y):
    def baz(w):
      q = jit(lambda x: y)(x)
      q = q + jit(lambda: y)()
      q = q + jit(lambda y: w + y)(y)
      q = jit(lambda w: jit(sin)(x) * y)(1.0) + q
      return q
    p, t = jvp(baz, (x + 1.0,), (y,))
    return t + (x * p)
  return bar(x)

def assert_allclose(*vals):
  for v1, v2 in zip(vals[:-1], vals[1:]):
    np.testing.assert_allclose(v1, v2)

ans1 = f(3.)
ans2 = jit(f)(3.)
ans3, _ = jvp(f, (3.,), (5.,))
ans4, _ = jvp(jit(f), (3.,), (5.,))
assert_allclose(ans1, ans2, ans3, ans4)

deriv1 = grad(f)(3.)
deriv2 = grad(jit(f))(3.)
deriv3 = jit(grad(jit(f)))(3.)
_, deriv4 = jvp(f, (3.,), (1.,))
_, deriv5 = jvp(jit(f), (3.,), (1.,))
assert_allclose(deriv1, deriv2, deriv3, deriv4, deriv5)

hess1 = grad(grad(f))(3.)
hess2 = grad(grad(jit(f)))(3.)
hess3 = grad(jit(grad(f)))(3.)
hess4 = jit(grad(grad(f)))(3.)
_, hess5 = jvp(grad(f), (3.,), (1.,))
_, hess6 = jvp(jit(grad(f)), (3.,), (1.,))
_, hess7 = jvp(jit(grad(f)), (3.,), (1.,))
assert_allclose(hess1, hess2, hess3, hess4, hess5, hess6, hess7)

Part 5: control flow primitives like cond#

Next we’ll add higher-order primitives for staged-out control flow. These resemble jit from Part 3, another higher-order primitive, but differ in that they are parameterized by multiple callables rather than just one.

Adding cond#

We introduce a cond primitive to represent conditional application of one function or another inside a jaxpr. We write the type of cond as Bool -> (a -> b) -> (a -> b) -> a -> b. In words, cond takes a boolean representing the predicate and two functions of equal types. Depending on the value of the predicate, it applies one function or the other to its final argument.

In Python, we represent it as a function which itself takes two functions as arguments. As with jit, the first step is to call make_jaxpr on its callable arguments to turn them into jaxprs:

def cond(pred, true_fn, false_fn, *operands):
  avals_in = [raise_to_shaped(get_aval(x)) for x in operands]
  true_jaxpr, true_consts, out_tree = make_jaxpr(true_fn, *avals_in)
  false_jaxpr, false_consts, out_tree_ = make_jaxpr(false_fn, *avals_in)
  if out_tree != out_tree_: raise TypeError
  true_jaxpr, false_jaxpr = _join_jaxpr_consts(
      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
  if typecheck_jaxpr(true_jaxpr) != typecheck_jaxpr(false_jaxpr):
    raise TypeError
  outs = bind_cond(pred, *true_consts, *false_consts, *operands,
                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  return tree_unflatten(out_tree, outs)
cond_p = Primitive('cond')

def _join_jaxpr_consts(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int
                       ) -> tuple[Jaxpr, Jaxpr]:
  jaxpr1_type, jaxpr2_type = typecheck_jaxpr(jaxpr1), typecheck_jaxpr(jaxpr2)
  assert jaxpr1_type.in_types[n1:] == jaxpr2_type.in_types[n2:]
  consts1, rest1 = split_list(jaxpr1.in_binders, n1)
  consts2, rest2 = split_list(jaxpr2.in_binders, n2)
  new_jaxpr1 = Jaxpr(consts1 + consts2 + rest1, jaxpr1.eqns, jaxpr1.outs)
  new_jaxpr2 = Jaxpr(consts1 + consts2 + rest2, jaxpr2.eqns, jaxpr2.outs)
  return new_jaxpr1, new_jaxpr2

def bind_cond(pred, *args, true_jaxpr, false_jaxpr):
  assert len(args) == len(true_jaxpr.in_binders) == len(false_jaxpr.in_binders)
  return bind(cond_p, pred, *args, true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)

We require true_jaxpr and false_jaxpr to have the same type, but because they might close over different constants (and because jaxprs can only represent closed terms, i.e. can’t have free variables and are instead closure-converted) we need to use the helper _join_jaxpr_consts to make consistent the input binder lists of the two jaxprs. (To be more economical we could try to identify pairs of constants with the same shapes, but instead we just concatenate the lists of constants.)

Next we can turn to adding interpreter rules for cond. Its evaluation rule is simple:

def cond_impl(pred, *operands, true_jaxpr, false_jaxpr):
  if pred:
    return eval_jaxpr(true_jaxpr, operands)
  else:
    return eval_jaxpr(false_jaxpr, operands)
impl_rules[cond_p] = cond_impl
out = cond(True, lambda: 3, lambda: 4)
print(out)
3
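
To see why _join_jaxpr_consts is needed, here is a small sketch (assuming the definitions so far): each branch closes over a different array constant, so the two branch jaxprs start out with different constant binders, and _join_jaxpr_consts concatenates the constant lists so that both jaxprs end up with identical input types.

c1 = np.arange(3.)           # closed over by the true branch
c2 = np.ones(3) * 10.        # closed over by the false branch
out = cond(True, lambda x: x * c1, lambda x: x + c2, np.ones(3))
print(out)  # [0. 1. 2.], via the true branch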

For its JVP and vmap rules, we only need to call the same jvp_jaxpr and vmap_jaxpr utilities we created for jit, followed by another pass of _join_jaxpr_consts:

def cond_jvp_rule(primals, tangents, *, true_jaxpr, false_jaxpr):
  pred, *primals = primals
  _   , *tangents = tangents
  true_jaxpr , true_consts  = jvp_jaxpr(true_jaxpr)
  false_jaxpr, false_consts = jvp_jaxpr(false_jaxpr)
  true_jaxpr, false_jaxpr = _join_jaxpr_consts(
      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
  assert typecheck_jaxpr(true_jaxpr) == typecheck_jaxpr(false_jaxpr)
  outs = bind_cond(pred, *true_consts, *false_consts, *primals, *tangents,
                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  primals_out, tangents_out = split_half(outs)
  return primals_out, tangents_out
jvp_rules[cond_p] = cond_jvp_rule
out, out_tan = jvp(lambda x: cond(True, lambda: x * x, lambda: 0.), (1.,), (1.,))
print(out_tan)
2.0
def cond_vmap_rule(axis_size, vals_in, dims_in, *, true_jaxpr, false_jaxpr):
  pred    , *vals_in = vals_in
  pred_dim, *dims_in = dims_in
  if pred_dim is not not_mapped: raise NotImplementedError  # TODO
  true_jaxpr, true_consts = vmap_jaxpr(true_jaxpr, axis_size, tuple(dims_in))
  false_jaxpr, false_consts = vmap_jaxpr(false_jaxpr, axis_size, tuple(dims_in))
  true_jaxpr, false_jaxpr = _join_jaxpr_consts(
      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
  assert typecheck_jaxpr(true_jaxpr) == typecheck_jaxpr(false_jaxpr)
  outs = bind_cond(pred, *true_consts, *false_consts, *vals_in,
                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  return outs, [0] * len(outs)
vmap_rules[cond_p] = cond_vmap_rule
xs = np.array([1., 2., 3])
out = vmap(lambda x: cond(True, lambda: x + 1., lambda: 0.), (0,))(xs)
print(out)
[2. 3. 4.]

Notice that we’re not currently supporting the case where the predicate value itself is batched. In mainline JAX, we handle this case by transforming the conditional to a select primitive. That transformation is semantically correct so long as true_fun and false_fun do not involve any side-effecting primitives.
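
For reference, here is how that looks in the real JAX package (a hedged usage sketch, separate from the implementation we are building here): vmap of lax.cond with a batched predicate lowers to a select over both branch outputs.

import jax
import jax.numpy as jnp

def apply_cond(pred, x):
  return jax.lax.cond(pred, lambda x: x * 2., lambda x: x + 1., x)

preds = jnp.array([True, False, True])
xs = jnp.array([1., 2., 3.])
print(jax.vmap(apply_cond)(preds, xs))  # [2. 3. 6.]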

Another thing not represented here, but present in mainline JAX, is that applying transformations to two jaxprs of equal type might result in jaxprs of different types. For example, applying the mainline JAX version of vmap_jaxpr to the identity-function jaxpr

{ lambda a:float32[] .
  let
  in ( a ) }

would result in a jaxpr with a batched output, of type [float32[10]] -> [float32[10]] if the batch size were 10, while applying it to the zero-function jaxpr

{ lambda a:float32[] .
  let
  in ( 0. ) }

would result in a jaxpr with an unbatched output, of type [float32[10]] -> [float32[]]. This is an optimization, aimed at not batching values unnecessarily. But it means that in cond we’d need an extra step of joining the two transformed jaxprs to have consistent output types. We don’t need this step here because we chose vmap_jaxpr always to batch all outputs over the leading axis.

Next we can turn to abstract evaluation and XLA lowering rules:

def cond_abstract_eval(pred_type, *in_types, true_jaxpr, false_jaxpr):
  if pred_type != ShapedArray((), np.dtype('bool')): raise TypeError
  jaxpr_type = typecheck_jaxpr(true_jaxpr)
  if jaxpr_type != typecheck_jaxpr(false_jaxpr):
    raise TypeError
  if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):
    raise TypeError
  return jaxpr_type.out_types
abstract_eval_rules[cond_p] = cond_abstract_eval

def cond_translation(c, in_avals, in_vals, *, true_jaxpr, false_jaxpr):
  del in_avals  # Unused
  pred, *in_vals = in_vals
  flat_vals, in_tree = tree_flatten(in_vals)
  operand = xops.Tuple(c, flat_vals)
  operand_shape = c.get_shape(operand)

  def make_comp(name: str, jaxpr: Jaxpr) -> xe.XlaComputation:
    c = xc.XlaBuilder(name)
    operand = xops.Parameter(c, 0, operand_shape)
    operands = tree_unflatten(in_tree, destructure_tuple(c, operand))
    outs = jaxpr_subcomp(c, jaxpr, operands)
    return c.build(xops.Tuple(c, outs))

  true_comp = make_comp('true_fn', true_jaxpr)
  false_comp = make_comp('false_fn', false_jaxpr)

  int_etype = xc.dtype_to_etype(np.dtype('int32'))
  out = xops.Conditional(xops.ConvertElementType(pred, int_etype),
                         [false_comp, true_comp], [operand] * 2)
  return destructure_tuple(c, out)
xla_translations[cond_p] = cond_translation
out = jit(lambda: cond(False, lambda: 1, lambda: 2))()
print(out)
2

Finally, to support reverse-mode automatic differentiation, we need partial evaluation and transposition rules. For partial evaluation, we need to introduce another jaxpr-munging utility, _join_jaxpr_res, to handle the fact that applying partial evaluation to true_fun and false_fun will in general result in distinct residuals. We use _join_jaxpr_res to make the output types of the transformed jaxprs consistent (while _join_jaxpr_consts dealt with input types).

def cond_partial_eval(trace, tracers, *, true_jaxpr, false_jaxpr):
  pred_tracer, *tracers = tracers
  assert pred_tracer.pval.is_known
  pred = pred_tracer.pval.const
  in_uks = [not t.pval.is_known for t in tracers]

  *jaxprs, out_uks, num_res = _cond_partial_eval(true_jaxpr, false_jaxpr, in_uks)
  t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2 = jaxprs

  known_tracers, unknown_tracers = partition_list(in_uks, tracers)
  known_vals = [t.pval.const for t in known_tracers]
  outs1_res = bind_cond(pred, *known_vals,
                        true_jaxpr=t_jaxpr1, false_jaxpr=f_jaxpr1)
  outs1, res = split_list(outs1_res, len(outs1_res) - num_res)
  pred_tracer_ = trace.instantiate_const(full_raise(trace, pred_tracer))
  res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]
  outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)
           for v in t_jaxpr2.outs]
  eqn = JaxprEqnRecipe(cond_p, [pred_tracer_, *res_tracers, *unknown_tracers],
                       dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),
                       [v.aval for v in t_jaxpr2.outs], map(ref, outs2))
  for t in outs2: t.recipe = eqn
  return merge_lists(out_uks, outs1, outs2)
partial_eval_rules[cond_p] = cond_partial_eval

def _cond_partial_eval(true_jaxpr: Jaxpr, false_jaxpr: Jaxpr, in_uks: list[bool]
                       ) -> tuple[Jaxpr, Jaxpr, Jaxpr, Jaxpr, list[bool], int]:
  _, _, t_out_uks, _ = partial_eval_jaxpr(true_jaxpr , in_uks)
  _, _, f_out_uks, _ = partial_eval_jaxpr(false_jaxpr, in_uks)
  out_uks = map(op.or_, t_out_uks, f_out_uks)

  t_jaxpr1, t_jaxpr2, _, t_nres = partial_eval_jaxpr(true_jaxpr , in_uks, out_uks)
  f_jaxpr1, f_jaxpr2, _, f_nres = partial_eval_jaxpr(false_jaxpr, in_uks, out_uks)

  t_jaxpr1, f_jaxpr1 = _join_jaxpr_res(t_jaxpr1, f_jaxpr1, t_nres, f_nres)
  t_jaxpr2, f_jaxpr2 = _join_jaxpr_consts(t_jaxpr2, f_jaxpr2, t_nres, f_nres)
  assert typecheck_jaxpr(t_jaxpr1) == typecheck_jaxpr(f_jaxpr1)
  assert typecheck_jaxpr(t_jaxpr2) == typecheck_jaxpr(f_jaxpr2)
  num_res = t_nres + f_nres

  return t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2, out_uks, num_res

def _join_jaxpr_res(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int
                    ) -> tuple[Jaxpr, Jaxpr]:
  jaxpr1_type, jaxpr2_type = typecheck_jaxpr(jaxpr1), typecheck_jaxpr(jaxpr2)
  out_types1, _ = split_list(jaxpr1_type.out_types, len(jaxpr1.outs) - n1)
  out_types2, _ = split_list(jaxpr2_type.out_types, len(jaxpr2.outs) - n2)
  assert out_types1 == out_types2
  outs1, res1 = split_list(jaxpr1.outs, len(jaxpr1.outs) - n1)
  outs2, res2 = split_list(jaxpr2.outs, len(jaxpr2.outs) - n2)
  zeros_like1 = [Lit(np.zeros(v.aval.shape, v.aval.dtype)) for v in res1]
  zeros_like2 = [Lit(np.zeros(v.aval.shape, v.aval.dtype)) for v in res2]
  new_jaxpr1 = Jaxpr(jaxpr1.in_binders, jaxpr1.eqns, outs1 + res1 + zeros_like2)
  new_jaxpr2 = Jaxpr(jaxpr2.in_binders, jaxpr2.eqns, outs2 + zeros_like1 + res2)
  return new_jaxpr1, new_jaxpr2
_, f_lin = linearize(lambda x: cond(True, lambda: x, lambda: 0.), 1.)
out = f_lin(3.14)
print(out)
3.14
def cond_peval_eqn(unks_in: list[bool], eqn: JaxprEqn,
                   ) -> tuple[JaxprEqn, JaxprEqn, list[bool], list[Atom]]:
  pred_unk, *unks_in = unks_in
  assert not pred_unk
  true_jaxpr, false_jaxpr = eqn.params['true_jaxpr'], eqn.params['false_jaxpr']
  *jaxprs, unks_out, num_res = _cond_partial_eval(true_jaxpr, false_jaxpr, unks_in)
  t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2 = jaxprs
  ins1, ins2 = partition_list(unks_in, eqn.inputs[1:])
  outs1, outs2 = partition_list(unks_out, eqn.out_binders)
  residuals, _ = split_list(t_jaxpr2.in_binders, num_res)
  eqn1 = JaxprEqn(cond_p, [eqn.inputs[0], *ins1],
                  dict(true_jaxpr=t_jaxpr1, false_jaxpr=f_jaxpr1),
                  outs1 + residuals)
  eqn2 = JaxprEqn(cond_p, [eqn.inputs[0], *residuals, *ins2],
                  dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),
                  outs2)
  res = [eqn.inputs[0], *residuals] if type(eqn.inputs[0]) is Var else residuals
  return eqn1, eqn2, unks_out, res
partial_eval_jaxpr_rules[cond_p] = cond_peval_eqn
_, f_lin = linearize(jit(lambda x: cond(True, lambda: x, lambda: 0.)), 1.)
out = f_lin(3.14)
print(out)
3.14

Transposition is a fairly straightforward application of transpose_jaxpr:

def cond_transpose_rule(cts, pred, *invals, true_jaxpr, false_jaxpr):
  undef_primals = tuple(type(x) is UndefPrimal for x in invals)
  true_jaxpr, true_consts = transpose_jaxpr(true_jaxpr, undef_primals)
  false_jaxpr, false_consts = transpose_jaxpr(false_jaxpr, undef_primals)
  true_jaxpr, false_jaxpr = _join_jaxpr_consts(
      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
  res = [x for x in invals if type(x) is not UndefPrimal]
  outs = bind_cond(pred, *true_consts, *false_consts, *res, *cts,
                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  outs = iter(outs)
  return [None] + [next(outs) if type(x) is UndefPrimal else None for x in invals]
transpose_rules[cond_p] = cond_transpose_rule
out = grad(lambda x: cond(True, lambda: x * x, lambda: 0.))(1.)
print(out)
2.0
def pprint_cond(names: defaultdict[Var, str], eqn: JaxprEqn) -> PPrint:
  true_jaxpr, false_jaxpr = eqn.params['true_jaxpr'], eqn.params['false_jaxpr']
  new_params = {k:v for k, v in eqn.params.items() if not k.endswith('jaxpr')}
  lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
  rhs = (pp(eqn.primitive.name) >> pp_params(new_params) >>
         pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
                     for x in eqn.inputs)))
  return vcat([lhs >> pp(' = ') >> rhs,
               pp_jaxpr(true_jaxpr).indent(2),
               pp_jaxpr(false_jaxpr).indent(2)])
pp_rules[cond_p] = pprint_cond

JAX Enhancement Proposals (JEPs)#

Most changes can be discussed with simple issues/discussions and pull requests.

Some changes, though, are a bit larger in scope or require more discussion, and these should be implemented as a JEP. This allows for writing longer documents that can themselves be discussed in a pull request.

The structure of JEPs is kept as lightweight as possible to start and might be extended later on.

When you should use a JEP#

  • When your change requires a design doc. We prefer collecting the designs as JEPs for better discoverability and further reference.

  • When your change requires extensive discussion. It’s fine to have relatively short discussions on issues or pull requests, but when the discussion gets longer it becomes impractical to digest later. JEPs allow the main document to be updated with a summary of the discussion, and those updates can themselves be discussed in the pull request adding the JEP.

How to start a JEP#

First, create an issue with the JEP label. All pull requests that relate to the JEP (i.e. adding the JEP itself as well as any implementing pull requests) should be linked to this issue.

Then create a pull request that adds a file named %d-{short-title}.md, where %d is the issue number.

JAX PRNG Design#

We want a PRNG design that

  1. is expressive in that it is convenient to use and it doesn’t constrain the user’s ability to write numerical programs with exactly the behavior that they want,

  2. enables reproducible program execution in a backend-independent way,

  3. has semantics that are invariant to @jit compilation boundaries and device backends,

  4. enables vectorization for generating array values using SIMD hardware,

  5. is parallelizable in that it doesn’t add sequencing constraints between random function calls that otherwise would have no data dependence,

  6. scales to multi-replica, multi-core, and distributed computation,

  7. fits with JAX and XLA semantics and design philosophies (which are ultimately motivated by other practical concerns).

As a corollary of these we believe the design should be functional. Another corollary is that, at least given current hardware constraints, we’re going to do the PRNG in software.

TLDR JAX PRNG = Threefry counter PRNG + a functional array-oriented splitting model

Contents#
Three programming models and toy example programs#

Here’s a toy example of a stateful global PRNG like the one often used in Numpy programs:

def foo(): return bar() + baz()
def bar(): return rand(RNG, (3, 4))
def baz(): return rand(RNG, (3, 4))
def main():
  global RNG
  RNG = RandomState(0)
  return foo()

To achieve reproducibility here we would need to control the order of evaluation for bar() and baz() even though there is no explicit data dependence from one to the other. This kind of sequencing requirement stemming from reproducibility (#2) violates parallelizability (#5) and doesn’t fit with JAX or XLA’s functional semantics (#7), in which subexpressions can be evaluated in any order. Even if we didn’t require reproducibility and thus allowed any evaluation order, parallelization across calls (#5) would still be made difficult by the need to update shared state. Moreover, because the same PRNG state would need to be accessed and maintained in both Python and any compiled code, this model would likely lead to engineering challenges to achieve compilation invariance (#3) and scaling to multiple replicas (#6). Finally, the expressiveness is limited (#1) because there is no way for foo() to call bar() or baz() without affecting its own (implicit) PRNG state.

Whether the model supports vectorization (#4) depends on some additional details. In Numpy, PRNG vectorization is limited by a sequential-equivalent guarantee:

In [1]: rng = np.random.RandomState(0)

In [2]: rng.randn(2)
Out[2]: array([1.76405235, 0.40015721])

In [3]: rng = np.random.RandomState(0)

In [4]: np.stack([rng.randn() for _ in range(2)])
Out[4]: array([1.76405235, 0.40015721])

To allow for vectorization (#4) within primitive PRNG function calls that generate arrays (e.g. to rand() with a shape argument), we drop this sequential-equivalent guarantee. This vectorization can be supported by any of the three programming models discussed in this section, though it motivates the implementation in terms of a counter-based PRNG as described in the next section.

The stateful PRNG user programming model is not promising. Here’s an example of a functional model but lacking a key ingredient that we call splitting:

def foo(rng_1):
   y, rng_2 = baz(rng_1)
   z, rng_3 = bar(rng_2)
   return y + z, rng_3

def bar(rng):
  val, new_rng = rand(rng, (3, 4))
  return val, new_rng

def baz(rng):
  val, new_rng = rand(rng, (3, 4))
  return val, new_rng

def main():
  foo(RandomState(0))

This model explicitly threads the PRNG state through all functions (primitive or non-primitive) that generate random values: that is, every random function must both accept and return the state. Now there is an explicit data dependence between the call to baz() and the call to bar() in foo(), so the data flow (and hence sequencing) is made explicit and fits with JAX’s existing semantics (#7), unlike in the previous model. This explicit threading can also make the semantics invariant to compilation boundaries (#3).

Explicit threading is inconvenient for the programmer. But worse, it hasn’t actually improved the expressiveness (#1): there is still no way for foo() to call into bar() or baz() while maintaining its own PRNG state. Without knowledge of their callers or the subroutines they call, functions must defensively pass in and return the rng state everywhere. Moreover, it also doesn’t improve the prospects for parallelization (#5) or scaling to multiple replicas (#6) because everything is still sequential, even if the sequencing is made explicit in the functional programming sense.

In short, making the code functional by explicitly threading state isn’t enough to achieve our expressiveness (#1) and performance (#5, #6) goals.

The key problem in both the previous models is that there’s too much sequencing. To reduce the amount of sequential dependence we use functional splittable PRNGs. Splitting is a mechanism to ‘fork’ a new PRNG state into two PRNG states while maintaining the usual desirable PRNG properties (the two new streams are computationally parallelizable and produce independent random values, i.e. they behave like multistreams).

def foo(rng_1):
   rng_2, rng_3 = split(rng_1, 2)
   return bar(rng_2) + baz(rng_3)

def bar(rng):
  return rand(rng, (3, 4))

def baz(rng):
  return rand(rng, (3, 4))

def main():
  foo(RandomState(0))

Some points to notice:

  1. there is no sequential dependence between the calls to bar() and baz() and they can be evaluated in either order without affecting the value of the result, which solves the remaining performance goals (#5, #6),

  2. functions do not need to return updated versions of PRNGs and it is straightforward to call a random subroutine without affecting existing PRNG states, improving the expressiveness (#1) from the other functional model.

The example doesn’t show it, but as a consequence of the choice (2) the only way to advance the PRNG state is to call split(). That is, we have two ways to achieve (1), and they differ in whether they burden the user program with explicit calls to split(), as in the above example, or instead burden the user program with explicit threading. We prefer the former, i.e. the version with explicit splitting, because we can easily implement the explicit-threading version in terms of it.

Design#

We can use the counter-based PRNG design, and in particular the Threefry hash function, as described in Parallel random numbers: as easy as 1, 2, 3. We use the counter to achieve efficient vectorization: for a given key we can generate an array of values in a vectorized fashion by mapping the hash function over a range of integers [k + 1, …, k + sample_size]. We use the key together with the hash function to implement splittable PRNGs: that is, splitting is a way to generate two new keys from an existing one.

type Sample = Int256
type Key = Sample  -- important identification for splitting
type Count = Int32

hash :: Key -> Count -> Int256  -- output type equal to Key and Sample

split :: Key -> (Key, Key)
split key = (hash key 0, hash key 1)

draw_samples :: Key -> Int -> [Sample]
draw_samples key n = map (hash key) [1..n]

Surprisingly, drawing a sample is very similar to splitting! The key difference is in the type of the output (even though the types are identified): in one case the value is to be used in forming random samples of interest (e.g. turning random bits into a Float representing a random normal), while in the other case the value is to be used as a key for further hashing.

The asymmetry between the hash function’s arguments, of type Key and Count, is that the latter can be advanced by an arbitrary amount trivially and cheaply (we just increase the integer value), while the former can only be advanced by hashing. That’s why we use the count argument for vectorization.
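
For reference, here is how this splitting model surfaces in today’s jax.random API (a hedged sketch; jax.random.key requires a recent JAX version):

import jax

key = jax.random.key(0)
key1, key2 = jax.random.split(key)       # 'fork' two independent streams
samples = jax.random.normal(key1, (4,))  # a vectorized draw from a single key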

More realistic example user programs#

Here’s what a training loop on the host might look like when the step requires a PRNG (maybe for dropout or for VAE training):

rng = lax.rng.new_rng()
for i in range(num_steps):
  rng, rng_input = lax.rng.split(rng)
  params = compiled_update(rng_input, params, next(batches))

Notice that we’re burdening the user with explicit splitting of the rng, but the rng does not need to be returned from the code at all.

Here’s how we can use this PRNG model with the stax neural net builder library to implement dropout:

def Dropout(rate, mode='train'):
  def init_fun(input_shape):
    return input_shape, ()
  def apply_fun(rng, params, inputs):
    if mode == 'train':
      keep = lax.random.bernoulli(rng, rate, inputs.shape)
      return np.where(keep, inputs / rate, 0)
    else:
      return inputs
  return init_fun, apply_fun

The rng value here is just the key used for the hash, not a special object. The rng argument is passed to every apply_fun, and so it needs to be handled in the serial and parallel combinators with splitting:

def serial(*layers):
  init_funs, apply_funs = zip(*layers)
  def init_fun(input_shape):
    ...
  def apply_fun(rng, params, inputs):
    rngs = split(rng, len(layers))
    for rng, param, apply_fun in zip(rngs, params, apply_funs):
      inputs = apply_fun(rng, param, inputs)
    return inputs
  return init_fun, apply_fun

def parallel(*layers):
  init_funs, apply_funs = zip(*layers)
  def init_fun(input_shape):
    ...
  def apply_fun(rng, params, inputs):
    rngs = split(rng, len(layers))
    return [f(r, p, x) for f, r, p, x in zip(apply_funs, rngs, params, inputs)]
  return init_fun, apply_fun

Here we’re using a simple extended version of split that can produce multiple copies.
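
In today’s API this extended split corresponds to passing a count to jax.random.split (a hedged sketch):

import jax

rngs = jax.random.split(jax.random.key(0), 3)  # three independent keys, one per layer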

Tradeoffs and alternatives#
  1. We’re not exploiting any device hardware PRNG

    • We don’t currently have enough control over the hardware PRNG’s state for all backends.

    • Even if we did, it would be backend-dependent and we might have to introduce sequential dependencies between random calls to ensure deterministic ordering and hence reproducibility.

    • We don’t know of any workloads for which the software PRNG should become a bottleneck.

    • We could consider providing an additional API that allows access to a hardware PRNG for users who want to give up other desiderata (like strict reproducibility).

  2. We give up the sequential equivalent guarantee, in which creating a random array in one call produces the same values as creating the flattened array one random element at a time.

    • This property is likely incompatible with vectorization (a high priority).

    • We don’t know of any users or examples for which this property is important.

    • Users could write a layer on top of this API to provide this guarantee.

  3. We can’t follow the numpy.random API exactly.

Custom JVP/VJP rules for JAX-transformable functions#

This is a design document, explaining some of the thinking behind the design and implementation of jax.custom_jvp and jax.custom_vjp. For user-oriented documentation, see the tutorial notebook.

There are two ways to define differentiation rules in JAX:

  1. using jax.custom_jvp and jax.custom_vjp to define custom differentiation rules for Python functions that are already JAX-transformable; and

  2. defining new core.Primitive instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems.

This document is about #1 only.

Contents#
Goals#

We want users to customize the forward- and/or reverse-mode differentiation behavior of their code. This customization

  1. should have a clear and consistent semantics in how it works and how it composes with other JAX transformations; and

  2. should be flexible in supporting use cases and workflows like in Autograd and PyTorch, including cases involving differentiation of Python control flow and workflows for NaN debugging.

As JAX developers we want to write library functions, like logit and expit, that are defined in terms of other primitives, but for the purposes of differentiation have primitive-like behavior in the sense that we want to define custom differentiation rules for them, which may be more numerically stable or performant. In particular, we don’t want to have to specify vmap or jit rules for functions like logit and expit.

As a stretch goal, we’d like to make JAX a great environment for power users looking to add custom differentiation rules for higher-order functions like fixed_point, odeint, etc.; this design doc won’t solve that problem, but we want to be confident we’re not going to preclude good solutions to that problem.

That is, our primary goals are

  1. solve the vmap-removes-custom-jvp semantics problem (#1249), and

  2. allow Python in custom VJPs, e.g. to debug NaNs (#1275).

Secondary goals are 3. clean up and simplify user experience (symbolic zeros, kwargs, etc) 4. make progress towards a world where users can easily add fixed_point, odeint, root, etc.

Overall, we want to close #116, #1097, #1249, #1275, #1366, #1723, #1670, #1875, #1938, and replace the custom_transforms machinery (from #636, #818, and others).

Non-goals#

Here are objectives we’re not aiming to achieve:

  1. The custom_transforms machinery aimed to provide a transformation-generic mechanism for customizing behavior, in principle (though never really used in practice) allowing users to customize rules for any transformation while somehow inheriting the “transparent” behavior for others. We are instead only going to solve the customization problem for differentiation (JVP and VJP, separately). Differentiation is the only case actually requested, and by specializing to differentiation we can reduce complexity and improve flexibility. To control all rules one can just write a primitive.

  2. We’re not going to prioritize mathematical aesthetics over flexibility and clarity on the user side, and simplicity on the implementation side. In particular, while the custom VJP signature a -> (b, CT b --o CT a) is mathematically pleasing, if it’s hard to implement in a Python mechanism because of the closure in the return type, we’re fine doing something that handles residuals more explicitly.

  3. Serialization support, of the form where the staged-out serialized program representation can be loaded and further JAX-transformed as opposed to just evaluated, is currently out of scope for these custom JVP/VJP transformation rules. Serialization may be useful not only for researchers who want to save some representation of their computation (and transform it after loading it), but also for future considerations like having jaxpr transformations implemented outside Python, or having jaxprs as an MLIR dialect. By defining this as a non-goal for the purpose of this design, we have fewer constraints on where we can stash Python callables.

Main problem descriptions#
The vmap-removes-custom-jvp semantics problem#

The vmap-removes-custom-jvp semantics problem is that vmap does not compose properly with differentiation of functions with custom_transforms rules:

# old custom_transforms api to be replaced
@jax.custom_transforms
def f(x):
  return 2. * x

# f_vjp :: a -> (b, CT b --o CT a)
def f_vjp(x):
  return f(x), lambda g: 3. * x  # 3 instead of 2

jax.defvjp_all(f, f_vjp)

grad(f)(1.)  # 3.
vmap(grad(f))(np.ones(4))  # [3., 3., 3., 3.]
grad(lambda x: vmap(f)(x).sum())(np.ones(4))  # [2., 2., 2., 2.]

The last grad-of-vmap line has an unexpected result! In general, applying vmap, or really any non-differentiation transformation, has the effect of removing the custom differentiation rule. (Applying jvp causes a failure when a custom VJP rule is defined.)

The problem exists because transformations are like rewrites, and the vmap transformation effectively rewrites the function to no longer call the newly-introduced primitive for which there is a custom rule (and hence grad then doesn’t produce the custom rule’s result). In more detail, the custom_transforms machinery sets things up so that evaluating f(x) applies the function

{ lambda  ; ; a.
  let b = f_primitive a
  in [b] }

where f_primitive is a new primitive (introduced for every custom_transforms function and in fact for every call of the function) to which the custom VJP rule is associated. When we evaluate grad(f)(x), the differentiation machinery encounters f_primitive and processes it with the custom rule.

However, because f_primitive is transparent to vmap, in the sense that vmap operates on (effectively by inlining) the definition of f_primitive, the function vmap(f) is effectively

{ lambda  ; ; a.
  let b = mul 2. a
  in [b] }

In words, vmap rewrites the function in terms of its underlying primitives and their transformation rules, removing f_primitive entirely.

More generally, because vmap(f) has semantics defined in terms of calls to f, it is semantically inconsistent to remove the custom derivative rule. That is, since we define

vmap(f)(xs) == np.stack([f(x) for x in xs])

we must have

jvp(vmap(f))(xs) == jvp(lambda xs: np.stack([f(x) for x in xs]))

yet this property is not observed when f has a custom derivative rule defined, as the custom derivative rule is used in the right-hand version but not the left-hand one.

This issue isn’t specific to vmap; it applies to all transformations for which the semantics of transforming a function f are defined in terms of calls to the function f, rather than rewriting it into another function. The mask transformation also falls into this class. Differentiation transforms and the hypothetical all-unary-functions-become-cosine transform are not in this class.

(The interaction between additional custom rules, like custom vmap rules, is likely to get even more complex, suggesting the problem framing of custom_transforms is too broad.)

The Python flexibility problem#

In JAX, as in Autograd and PyTorch but not TF1, differentiation of a Python function is performed while the function is being executed and traced. This behavior delights users for a few reasons.

First and most importantly, it enables pdb-based workflows, e.g. for inspecting numerics or catching NaNs. That is, users can employ the standard Python debugger and other Python-native tools to debug their code, even being able to inspect runtime values to understand numerical behavior on examples and to catch fundamentally runtime errors like NaNs. In fact, just while working on the PR corresponding to this design, especially on the odeint primitive, I used runtime value inspection to debug issues many times, increasing my confidence that this is a key user workflow in Python. One especially handy trick, which I’ve used in both JAX and Autograd many times, is the ability to insert a debugger breakpoint in a custom VJP rule to enter a debugger at a specific point in the backward pass.

Second, it allows differentiation of Python native control flow. We’re not sure how often this is used in practice in finalized software artifacts, but when users first poke around JAX or Autograd they’re often impressed by this freedom. There’s a reason we include it at the top of our JAX and Autograd READMEs, slide decks, and demos. Ceding this capability would be a step backward from Autograd. We want JAX to have the best automatic differentiation.

However, the custom_transforms machinery does not provide this Python-support flexibility. That is, because it’s implemented in terms of up-front jaxpr formation from the Python code for both the user function and custom differentiation rules, code like this leads to an abstract value tracing error:

# old custom_transforms api to be replaced
@jax.custom_transforms
def f(x):
  if x > 0:
    return x
  else:
    return 0.

def f_vjp(x):
  return ...

jax.defvjp_all(f, f_vjp)

grad(f)(1.)  # Error!
Solution idea#

The main idea is that dougalm@ already solved these problems with core.call. That is, we can frame the task of specifying a custom JVP rule for a user function in terms of a new Python-level call primitive (not to be added to the jaxpr language; see below). This new call primitive has a user Python function associated with it just like core.call, but additionally has a second Python callable representing the JVP rule. Let’s refer to this new call primitive as custom_jvp_call.

Transformations like vmap interact with custom_jvp_call as with core.call: they effectively pass right through it and are applied to the underlying Python callables. Schematically, writing in terms of curried versions of the primitives for convenience, analogously to how vmap interacts with core.call by applying to the function to be called:

vmap(call(f)) == call(vmap(f))

for the new primitive custom_jvp_call we simply apply vmap to the two functions it entails:

vmap(custom_jvp_call(f, f_jvp)) == custom_jvp_call(vmap(f), vmap(f_jvp))

This behavior means we’ve solved the vmap-removes-custom-jvp semantics problem.

The jvp transformation interacts as one might expect: it just calls f_jvp,

jvp(call(f)) == call(jvp(f))

jvp(custom_jvp_call(f, f_jvp)) == f_jvp

Because custom_jvp_call acts like core.call (and not like xla.xla_call) in that it doesn’t raise the abstraction level of its inputs (because it’s not delaying anything or staging anything out), it means we’ve solved the Python flexibility problem: there are no constraints on the user Python function (above the usual functional programming constraints required by jvp or vjp).

What about evaluation and compilation? These are two ways to “exit” the JAX system, in the sense that no additional transformations can be applied after these steps. As a result, their rules are trivial:

eval(call(f)) == eval(f)
jit(call(f)) == hlo_call(jit(f))

eval(custom_jvp_call(f, f_jvp)) == eval(f)
jit(custom_jvp_call(f, f_jvp)) == hlo_call(jit(f))

In words, if a JVP rule hasn’t already rewritten custom_jvp_call(f, f_jvp) into f_jvp, when we get to the point of evaluation with eval or staging out to XLA with jit, differentiation is never going to be applied, so we just ignore f_jvp and behave just like core.call. However, due to the wrinkle discussed next, the partial eval rule for custom_jvp_call must be a bit more complex, since partial evaluation isn’t just used to stage out to XLA with jit.

The only remaining wrinkle has to do with “initial-style” jaxpr-forming primitives, like lax.scan, and their transformation rules. These represent a different kind of “staging out to a jaxpr” than that for compilation because we can perform additional transformations on the staged-out jaxpr. That is, when lax.scan forms a jaxpr, it does not exit the transformation system, since when we apply a jvp or vmap to a lax.scan we need to apply it to the function represented by the jaxpr.

Another way to state the wrinkle is that initial-style primitives like lax.scan rely on the ability to round-trip to a jaxpr and back to a Python callable while preserving semantics. That must mean preserving custom differentiation rule semantics too.

The solution is to use a bit of dynamic scoping: when we’re staging out to a jaxpr for an initial-style primitive, like those in lax_control_flow.py, we set a bit on the global trace state. When that bit is set, instead of using the final-style custom_jvp_call primitive, we use an initial-style custom_jvp_call_jaxpr primitive, and trace the functions f and f_jvp to jaxprs up-front to make initial-style processing easier. The custom_jvp_call_jaxpr primitive is otherwise similar to the final-style version.

(Footnote: while morally we form jaxprs for both f and f_jvp before binding custom_jvp_call_jaxpr, we need to delay the formation of the jaxpr of f_jvp because it may call the custom-JVP function and thus eager processing would lead to an infinite recursion. We delay that jaxpr formation in a thunk.)

If we gave up on the Python flexibility problem, we could get away with only having custom_jvp_call_jaxpr and not having the separate Python-level primitive custom_jvp_call.

API#

The custom JVP for an a -> b function is specified with an (a, T a) -> (b, T b) function:

# f :: a -> b
@jax.custom_jvp
def f(x):
  return np.sin(x)

# f_jvp :: (a, T a) -> (b, T b)
def f_jvp(primals, tangents):
  x, = primals
  t, = tangents
  return f(x), np.cos(x) * t

f.defjvp(f_jvp)

(Interesting autodiff aside: for the rule to apply to higher-order differentiation, one must call f in the body of f_jvp; that precludes some kinds of work sharing between the internals of f and the tangent calculation.)
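
To connect this API to the Python flexibility goal above, here is a hedged sketch (not from the original document): the primal function uses ordinary Python control flow, which works under plain grad because the primal value is concrete, and the custom rule is picked up by autodiff.

import jax
import jax.numpy as jnp

@jax.custom_jvp
def relu_like(x):
  if x > 0:          # ordinary Python control flow on a concrete value
    return x
  else:
    return 0.

@relu_like.defjvp
def relu_like_jvp(primals, tangents):
  (x,), (t,) = primals, tangents
  return relu_like(x), jnp.where(x > 0, t, 0.)

print(jax.grad(relu_like)(3.))  # 1.0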

The custom VJP for an a -> b function is specified with an a -> (b, c) forward pass function paired with a (c, CT b) -> CT a backward pass function:

# f :: a -> b
@jax.custom_vjp
def f(x):
  return np.sin(x)

# f_fwd :: a -> (b, c)
def f_fwd(x):
  return f(x), np.cos(x)

# f_bwd :: (c, CT b) -> CT a
def f_bwd(cos_x, g):
  return (cos_x * g,)

f.defvjp(f_fwd, f_bwd)
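
Similarly, a hedged usage sketch for the VJP version (assuming np is jax.numpy, as in the snippets above): reverse-mode picks up f_fwd and f_bwd, while forward-mode (jax.jvp) of a custom_vjp-decorated function is disallowed, since only a backward pass was specified.

import jax
import jax.numpy as np  # matching the alias used in the snippets above

print(jax.grad(f)(0.5), np.cos(0.5))  # both are cos(0.5), via f_bwd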

The signature a -> (b, CT b --o CT a) is more aesthetically pleasing, but supporting it would make the implementation more complex and might require compromising expressibility desiderata. The basic reason is that Python callables are opaque (unless we trace them to a jaxpr eagerly, which places expressiveness constraints), and in this case we may be returning a callable with vmap tracers inside its closure that we need to know about during the forward pass.

We could add convenience wrappers, for example to define the JVP rule for a single argument at a time (like we do internally for primitives). But because this proposal is complicated enough as it is, I decided against convenience layers; let’s keep things minimal for now.

There are some other bells and whistles to the API:

  • Inputs and output types a, b, and c can be arbitrary pytrees of jaxtypes.

  • Passing arguments by name (keyword arguments) is supported when they can be resolved to positions using the inspect module. This is a bit of an experiment with Python 3’s improved ability to programmatically inspect argument signatures. I believe it is sound but not complete, which is a fine place to be. (See also #2069.)

  • Arguments can be marked non-differentiable using nondiff_argnums, and as with jit’s static_argnums these arguments don’t have to be JAX types. We need to set a convention for how these arguments are passed to the rules. For a primal function with type signature (d, a) -> b where d represents the non-differentiable type, the JVP rule’s signature is (a, T a, d) -> T b and the VJP rule’s reverse component signature is (d, c, CT b) -> CT a. That is, the non-differentiable arguments are passed in order after primals and tangents for a custom JVP rule, and passed in order preceding the residuals in a custom VJP rule’s reverse function.

Implementation notes#
  • Updated jax.experimental.odeint

    • Since odeint is a pretty complex user of a custom VJP rule, in addition to just updating it to work at all, I wanted to revise it to be a canonical user of the new custom VJP API as a way to test that the API was a good one.

    • Along the way I made other improvements to the odeint implementation:

      • remove raveling/unraveling boilerplate

      • make use of lax.scan to remove the index-update logic

      • speed up by 20+% on the simple pendulum benchmark

  • Added a custom bind method on each transform for the custom derivative call primitives, custom_jvp_call and custom_vjp_call. It’s like core.call_bind, except we don’t process env traces: those are just errors.

  • Added custom_lin primitive, which gets staged out into linear jaxprs to be transposed when using a custom VJP rule.

    • Because our reverse-mode autodiff is decomposed into linearization, partial evaluation, and transposition, our custom VJP rules are processed in two separate steps: one during linearization and one during transposition.

    • The linearization step, i.e. the JVP rule for custom_vjp_call, applies custom_lin to the tangent values; custom_lin carries with it the user’s custom backward-pass function, and as a primitive it only has a transpose rule.

    • This mechanism is described more in #636.

  • To prevent

custom_vjp and nondiff_argnums update guide#

mattjj@ Oct 14 2020

This doc assumes familiarity with jax.custom_vjp, as described in the Custom derivative rules for JAX-transformable Python functions notebook.

What to update#

After JAX PR #4008, the arguments passed into a custom_vjp function’s nondiff_argnums can’t be Tracers (or containers of Tracers), which basically means to allow for arbitrarily-transformable code nondiff_argnums shouldn’t be used for array-valued arguments. Instead, nondiff_argnums should be used only for non-array values, like Python callables or shape tuples or strings.

Wherever we used to use nondiff_argnums for array values, we should just pass those as regular arguments. In the bwd rule, we need to produce values for them, but we can just produce None values to indicate there’s no corresponding gradient value.

For example, here’s the old way to write clip_gradient, which won’t work when hi and/or lo are Tracers from some JAX transformation.

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def clip_gradient(lo, hi, x):
  return x  # identity function

def clip_gradient_fwd(lo, hi, x):
  return x, None  # no residual values to save

def clip_gradient_bwd(lo, hi, _, g):
  return (jnp.clip(g, lo, hi),)

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

Here’s the new, awesome way, which supports arbitrary transformations:

import jax
import jax.numpy as jnp

@jax.custom_vjp  # no nondiff_argnums!
def clip_gradient(lo, hi, x):
  return x  # identity function

def clip_gradient_fwd(lo, hi, x):
  return x, (lo, hi)  # save lo and hi values as residuals

def clip_gradient_bwd(res, g):
  lo, hi = res
  return (None, None, jnp.clip(g, lo, hi))  # return None for lo and hi

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

If you use the old way instead of the new way, you’ll get a loud error in any case where something might go wrong (namely when there’s a Tracer passed into a nondiff_argnums argument).

Here’s a case where we actually need nondiff_argnums with custom_vjp:

from functools import partial
import jax

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def skip_app(f, x):
  return f(x)

def skip_app_fwd(f, x):
  return skip_app(f, x), None

def skip_app_bwd(f, _, g):
  return (g,)

skip_app.defvjp(skip_app_fwd, skip_app_bwd)
Explanation#

Passing Tracers into nondiff_argnums arguments was always buggy. While there were some cases that worked correctly, others would lead to complex and confusing error messages.

The essence of the bug was that nondiff_argnums was implemented in a way that acted very much like lexical closure. But lexical closure over Tracers wasn’t at the time intended to work with custom_jvp/custom_vjp. Implementing nondiff_argnums that way was a mistake!

PR #4008 fixes all lexical closure issues with custom_jvp and custom_vjp. Woohoo! That is, now custom_jvp and custom_vjp functions and rules can close over Tracers to our hearts’ content. For all non-autodiff transformations, things will Just Work. For autodiff transformations, we’ll get a clear error message about why we can’t differentiate with respect to values over which a custom_jvp or custom_vjp closes:

Detected differentiation of a custom_jvp function with respect to a closed-over value. That isn’t supported because the custom JVP rule only specifies how to differentiate the custom_jvp function with respect to explicit input parameters.

Try passing the closed-over value into the custom_jvp function as an argument, and adapting the custom_jvp rule.

In tightening up and robustifying custom_jvp and custom_vjp in this way, we found that allowing custom_vjp to accept Tracers in its nondiff_argnums would take a significant amount of bookkeeping: we’d need to rewrite the user’s fwd function to return the values as residuals, and rewrite the user’s bwd function to accept them as normal residuals (rather than accepting them as special leading arguments, as happens with nondiff_argnums). This seems maybe manageable, until you think through how we have to handle arbitrary pytrees! Moreover, that complexity isn’t necessary: if user code treats array-like non-differentiable arguments just like regular arguments and residuals, everything already works. (Before #4039 JAX might’ve complained about involving integer-valued inputs and outputs in autodiff, but after #4039 those will just work!)

Unlike custom_vjp, it was easy to make custom_jvp work with nondiff_argnums arguments that were Tracers. So these updates only need to happen with custom_vjp.

Omnistaging#

mattjj@ Sept 25 2020

This is more of an upgrade guide than a design doc.

Contents#
tl;dr#
What’s going on?#

A change to JAX’s tracing infrastructure called “omnistaging” (google/jax#3370) was switched on in jax==0.2.0. This change improves memory performance and trace execution time, and simplifies JAX internals, but may cause some existing code to break. Breakage is usually a result of buggy code, so long-term it’s best to fix the bugs, but omnistaging can also be disabled as a temporary workaround. And we’re happy to help you with fixes!

How do I know if omnistaging broke my code?#

The easiest way to tell if omnistaging is responsible is to disable omnistaging and see if the issues go away. See the What issues can arise when omnistaging is switched on? section below.

How can I disable omnistaging for now?#

Note: this applies to JAX versions 0.2.0 through 0.2.11; omnistaging cannot be disabled in JAX versions 0.2.12 and higher

It is temporarily possible to disable omnistaging by

  1. setting the shell environment variable JAX_OMNISTAGING to something falsey;

  2. setting the boolean flag jax_omnistaging to something falsey if your code parses flags with absl;

  3. using this statement near the top of your main file:

jax.config.disable_omnistaging()
How do I fix bugs exposed by omnistaging?#

By far the most common issue with omnistaging is using jax.numpy to compute shape values or other trace-time constants. See the code block below for a quick example, and for full details along with other issues see the section What issues can arise when omnistaging is switched on?.

Instead of this:

@jit
def f(x):
  input_size = jnp.prod(x.shape)
  if input_size > 100:
    ...

do this:

import numpy as np

@jit
def f(x):
  input_size = np.prod(x.shape)
  if input_size > 100:
    ...

Instead of thinking of jax.numpy as a drop-in replacement for numpy, it’s now better to think of using jax.numpy operations only when you want to perform a computation on an accelerator (like your GPU).

What is “omnistaging” and why is it useful?#

Omnistaging is the name for a JAX core upgrade aimed at staging out more computation from op-by-op Python to XLA, and avoiding any “trace-time constant folding” in jit, pmap, and control flow primitives. As a result, omnistaging improves JAX’s memory performance (sometimes dramatically) both by reducing fragmentation during tracing and by producing fewer large compile-time constants for XLA. It can also improve tracing performance by eliminating op-by-op execution at tracing time. Further, omnistaging simplifies JAX core internals, fixing many outstanding bugs and setting the stage for important upcoming features.

The name “omnistaging” means staging out everything possible.

Toy example#

JAX transformations like jit and pmap stage out computations to XLA. That is, we apply them to functions comprising multiple primitive operations so that rather than being executed one at a time from Python, the operations are all part of one end-to-end optimized XLA computation.

But exactly which operations get staged out? Until omnistaging, JAX staged out computation based on data dependence only. Here’s an example function, followed by the XLA HLO program it stages out before the omnistaging change:

from jax import jit
import jax.numpy as jnp

@jit
def f(x):
  y = jnp.add(1, 1)
  return x * y

f(3)
ENTRY jit_f.6 {
  constant.2 = pred[] constant(false)
  parameter.1 = s32[] parameter(0)
  constant.3 = s32[] constant(2)
  multiply.4 = s32[] multiply(parameter.1, constant.3)
  ROOT tuple.5 = (s32[]) tuple(multiply.4)
}

Notice that the add operation is not staged out. Instead, we only see a multiply.

Here’s the HLO generated from this function after the omnistaging change:

ENTRY jit_f.8 {
  constant.2 = pred[] constant(false)
  parameter.1 = s32[] parameter(0)
  constant.3 = s32[] constant(1)
  constant.4 = s32[] constant(1)
  add.5 = s32[] add(constant.3, constant.4)
  multiply.6 = s32[] multiply(parameter.1, add.5)
  ROOT tuple.7 = (s32[]) tuple(multiply.6)
}
Slightly less toy example#

Here’s a less toy example which can arise in practice when we want to create boolean masks:

import numpy as np
import jax.numpy as jnp
from jax import jit, lax

@jit
def select_tril(x):
  mask = jnp.arange(x.shape[0])[:, None] > jnp.arange(x.shape[1])
  return lax.select(mask, x, jnp.zeros_like(x))  # lax.select is like jnp.where

x = np.arange(12).reshape((3, 4))
select_tril(x)

Before omnistaging:

ENTRY jit_select_tril.8 {
  constant.3 = pred[] constant(false)
  constant.1 = pred[3,4]{1,0} constant({...})
  parameter.2 = s32[3,4]{1,0} parameter(0)
  constant.4 = s32[] constant(0)
  broadcast.5 = s32[3,4]{1,0} broadcast(constant.4), dimensions={}
  select.6 = s32[3,4]{1,0} select(constant.1, parameter.2, broadcast.5)
  ROOT tuple.7 = (s32[3,4]{1,0}) tuple(select.6)
}

The select operation is staged out, but the operations for constructing the constant mask are not. Rather than being staged out, the operations that construct mask are executed op-by-op at Python tracing time, and XLA only sees a compile time constant constant.1 representing the value of mask. That’s unfortunate, because if we had staged out the operations for constructing mask, XLA could have fused them into the select and avoided materializing the result at all. As a result we end up wasting memory with a potentially-large constant, wasting time dispatching multiple un-fused op-by-op XLA computations, and potentially even fragmenting memory.

(The broadcast that corresponds to the construction of the zeros array for jnp.zeros_like(x) is staged out because JAX is lazy about very simple expressions from google/jax#1668. After omnistaging, we can remove that lazy sublanguage and simplify JAX internals.)

The reason the creation of mask is not staged out is that, before omnistaging, jit operates based on data dependence. That is, jit stages out only those operations in a function that have a data dependence on an argument. Control flow primitives and pmap behave similarly. In the case of select_tril, the operations to construct the constant mask do not have a data dependence on the argument x, so they are not staged out; only the lax.select call has a data dependence.

With omnistaging all jax.numpy calls in the dynamic context of a jit-transformed function are staged out to XLA. That is, after omnistaging the computation XLA sees for select_tril is

ENTRY jit_select_tril.16 {
  constant.4 = pred[] constant(false)
  iota.1 = s32[3]{0} iota(), iota_dimension=0
  broadcast.5 = s32[3,1]{1,0} broadcast(iota.1), dimensions={0}
  reshape.7 = s32[3]{0} reshape(broadcast.5)
  broadcast.8 = s32[3,4]{1,0} broadcast(reshape.7), dimensions={0}
  iota.2 = s32[4]{0} iota(), iota_dimension=0
  broadcast.6 = s32[1,4]{1,0} broadcast(iota.2), dimensions={1}
  reshape.9 = s32[4]{0} reshape(broadcast.6)
  broadcast.10 = s32[3,4]{1,0} broadcast(reshape.9), dimensions={1}
  compare.11 = pred[3,4]{1,0} compare(broadcast.8, broadcast.10), direction=GT
  parameter.3 = s32[3,4]{1,0} parameter(0)
  constant.12 = s32[] constant(0)
  broadcast.13 = s32[3,4]{1,0} broadcast(constant.12), dimensions={}
  select.14 = s32[3,4]{1,0} select(compare.11, parameter.3, broadcast.13)
  ROOT tuple.15 = (s32[3,4]{1,0}) tuple(select.14)
}
What issues can arise when omnistaging is switched on?#

As a consequence of staging out all jax.numpy operations from Python to XLA when in the dynamic context of a jit or pmap, some code that worked previously can start raising loud errors. As explained below, these behaviors were already buggy before omnistaging, but omnistaging makes them into hard errors.

Using jax.numpy for shape computations#
Example#
from jax import jit
import jax.numpy as jnp

@jit
def ex1(x):
  size = jnp.prod(jnp.array(x.shape))
  return x.reshape((size,))

ex1(jnp.ones((3, 4)))
Error message#
[... full traceback ...]
  File "/home/mattjj/packages/jax/jax/core.py", line 862, in raise_concretization_error
    raise ConcretizationTypeError(msg)
jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

The error arose in jax.numpy.reshape.

While tracing the function ex1 at ex1.py:4, this value became a tracer due to JAX operations on these lines:

  operation c:int32[] = reduce_prod[ axes=(0,) ] b:int32[2]
    from line ex1.py:6 (ex1)

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.

See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>
Explanation#

With omnistaging, we can’t use jax.numpy for shape computations as in the use of jnp.prod above because in the dynamic context of a jit function those operations will be staged out of Python as values to be computed at execution time, yet we need them to be compile-time (and hence trace-time) constants.

Before omnistaging, this code wouldn’t have raised an error, but it was a common performance bug: the jnp.prod computation would have been executed on the device at tracing time, meaning extra compilation, transfers, synchronization, allocations, and potentially memory fragmentation.

Solution#

The solution is simply to use the original numpy for shape calculations like these. Not only do we avoid the error, but also we keep the computations on the host (and with lower overheads).
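
Applied to the ex1 example above, the fix looks like this (a sketch of the same function with the shape arithmetic done in original numpy):

import numpy as np
from jax import jit
import jax.numpy as jnp

@jit
def ex1(x):
  size = np.prod(x.shape)   # a trace-time constant computed on the host
  return x.reshape((size,))

ex1(jnp.ones((3, 4)))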

This issue was common enough in code that we tried to make the error message especially good. In addition to the stack trace showing where an abstract tracer value caused a problem (the jnp.reshape line in the full stack trace, on omni.py:10), we also explain why this value became a tracer in the first place by pointing to the upstream primitive operation that caused it to become an abstract tracer (the reduce_prod from jnp.prod on omni.py:9) and to which jit-decorated function the tracer belongs (ex1 on omni.py:6).

Side-effects#
Example#
from jax import jit
from jax import random

key = random.PRNGKey(0)

def init():
  global key
  key, subkey = random.split(key)
  return random.normal(subkey, ())

print(init())  # -1.2515389
print(init())  # -0.58665067

init = jit(init)
print(init())  # 0.48648298
print(init())  # 0.48648298  !!

That last call has repeated randomness but no hard error, because we aren’t re-executing the Python. But if we look at key, we see an escaped tracer when omnistaging is on:

print(key) # Traced<ShapedArray(uint32[2])>with<DynamicJaxprTrace(level=0/1)>

Before omnistaging, the random.split call would not be staged out and so we wouldn’t get an escaped tracer. The code would still be buggy in that the jitted function wouldn’t be reproducing the semantics of the original function (because of the repeated use of the same PRNG key), ultimately due to the side effect.

With omnistaging on, if we touch key again, we’ll get an escaped tracer error:

random.normal(key, ())
Error message#
[... full stack trace ...]
  File "/home/mattjj/packages/jax/jax/interpreters/partial_eval.py", line 836, in _assert_live
    raise core.escaped_tracer_error(msg)
jax.core.UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state. Detail: tracer created on line example.py:8 (init).
Explanation#

The second largest category of omnistaging issues we found had to do with side-effecting code. This code already voided the JAX warranty by transforming effectful functions, but due to the pre-omnistaging "trace-time constant folding" behavior, some side-effecting functions could nevertheless behave correctly. Omnistaging catches more of these errors.

Solution#

The solution is to identify JAX-transformed functions that rely on side effects, and to rewrite them not to be effectful.
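
For the init example above, one possible rewrite (a minimal sketch, not the only option) is to thread the PRNG key through the function explicitly instead of mutating a global:

from jax import jit
from jax import random

@jit
def init(key):
  # The key comes in as an argument and the new key goes out as a return
  # value, so no tracer is ever written to global state.
  key, subkey = random.split(key)
  return random.normal(subkey, ()), key

key = random.PRNGKey(0)
val1, key = init(key)  # fresh randomness
val2, key = init(key)  # different value, and no escaped tracers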

Small numerical differences based on XLA optimizations#

Because omnistaging stages more computations out to XLA, rather than executing some of them at trace time, floating point operations can be reordered. As a result, we've seen numerical behaviors change in ways that cause tests with overly tight tolerances to fail when omnistaging is switched on.

Dependence on JAX internal APIs that changed#

Omnistaging involved some big revisions to JAX’s core code, including removing or changing internal functions. Any code that relies on such internal JAX APIs can break when omnistaging is switched on, either with build errors (from pytype) or runtime errors.

Triggering XLA compile time bugs#

Because omnistaging involves staging out more code to XLA, we’ve seen it trigger pre-existing XLA compile-time bugs on some backends. The best thing to do with these is to report them so we can work with the XLA teams on fixes.

JEP 9263: Typed keys & pluggable RNGs#

Jake VanderPlas, Roy Frostig

August 2023

Overview#

Going forward, RNG keys in JAX will be more type-safe and customizable. Rather than being represented as a length-2 uint32 array, a single PRNG key will be a scalar array with a special RNG dtype that satisfies jnp.issubdtype(key.dtype, jax.dtypes.prng_key).

For now, old-style RNG keys can still be created with jax.random.PRNGKey():

>>> key = jax.random.PRNGKey(0)
>>> key
Array([0, 0], dtype=uint32)
>>> key.shape
(2,)
>>> key.dtype
dtype('uint32')

Starting now, new-style RNG keys can be created with jax.random.key():

>>> key = jax.random.key(0)
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]
>>> key.shape
()
>>> key.dtype
key<fry>

This (scalar-shaped) array behaves the same as any other JAX array, except that its element type is a key (and associated metadata). We can make non-scalar key arrays as well, for example by applying jax.vmap() to jax.random.key():

>>> key_arr = jax.vmap(jax.random.key)(jnp.arange(4))
>>> key_arr
Array((4,), dtype=key<fry>) overlaying:
[[0 0]
 [0 1]
 [0 2]
 [0 3]]
>>> key_arr.shape
(4,)

Aside from switching to a new constructor, most PRNG-related code should continue to work as expected. You can continue to use keys in jax.random APIs as before; for example:

# split
new_key, subkey = jax.random.split(key)

# random number generation
data = jax.random.uniform(key, shape=(5,))

However, not all numerical operations work on key arrays. They now intentionally raise errors:

>>> key = key + 1
ValueError: dtype=key<fry> is not a valid dtype for JAX type promotion.

If for some reason you need to recover the underlying buffer (the old-style key), you can do so with jax.random.key_data():

>>> jax.random.key_data(key)
Array([0, 0], dtype=uint32)

For old-style keys, key_data() is an identity operation.
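
The inverse operation, jax.random.wrap_key_data(), turns a raw uint32 buffer back into a typed key, so the two can be paired at boundaries that still traffic in raw keys (a minimal sketch):

key = jax.random.key(0)                    # typed key
raw = jax.random.key_data(key)             # uint32 buffer, shape (2,) for threefry
restored = jax.random.wrap_key_data(raw)   # typed key again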

What does this mean for users?#

For JAX users, this change does not require any code changes now, but we hope that you will find the upgrade worthwhile and switch to using typed keys. To try this out, replace uses of jax.random.PRNGKey() with jax.random.key(). This may introduce breakages in your code that fall into one of a few categories:

  • If your code performs unsafe/unsupported operations on keys (such as indexing, arithmetic, transposition, etc; see Type Safety section below), this change will catch it. You can update your code to avoid such unsupported operations, or use jax.random.key_data() and jax.random.wrap_key_data() to manipulate raw key buffers in an unsafe way.

  • If your code includes explicit logic about key.shape, you may need to update this logic to account for the fact that the trailing key buffer dimension is no longer an explicit part of the shape.

  • If your code includes explicit logic about key.dtype, you will need to upgrade it to use the new public APIs for reasoning about RNG dtypes, such as dtypes.issubdtype(dtype, dtypes.prng_key).

  • If you call a JAX-based library which does not yet handle typed PRNG keys, you can use raw_key = jax.random.key_data(key) for now to recover the raw buffer, but please keep a TODO to remove this once the downstream library supports typed RNG keys.

At some point in the future, we plan to deprecate jax.random.PRNGKey() and require the use of jax.random.key().

Detecting new-style typed keys#

To check whether an object is a new-style typed PRNG key, you can use jax.dtypes.issubdtype or jax.numpy.issubdtype:

>>> typed_key = jax.random.key(0)
>>> jax.dtypes.issubdtype(typed_key.dtype, jax.dtypes.prng_key)
True
>>> raw_key = jax.random.PRNGKey(0)
>>> jax.dtypes.issubdtype(raw_key.dtype, jax.dtypes.prng_key)
False
Type annotations for PRNG Keys#

The recommended type annotation for both old and new-style PRNG keys is jax.Array. A PRNG key is distinguished from other arrays based on its dtype, and it is not currently possible to specify dtypes of JAX arrays within a type annotation. Previously it was possible to use jax.random.KeyArray or jax.random.PRNGKeyArray as type annotations, but these have always been aliased to Any under type checking, and so jax.Array has much more specificity.

Note: jax.random.KeyArray and jax.random.PRNGKeyArray were deprecated in JAX version 0.4.16, and removed in JAX version 0.4.24.
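
For example (the function name below is purely illustrative), annotating a key parameter looks the same as annotating any other array argument:

import jax

def sample_noise(key: jax.Array, shape: tuple) -> jax.Array:
  # Accepts both old-style uint32 keys and new-style typed keys.
  return jax.random.normal(key, shape)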

Notes for JAX library authors#

If you maintain a JAX-based library, your users are also JAX users. Know that JAX will continue to support “raw” old-style keys in jax.random for now, so callers may expect them to remain accepted everywhere. If you prefer to require new-style typed keys in your library, then you may want to enforce them with a check along the following lines:

from jax import Array
from jax import dtypes

def ensure_typed_key_array(key: Array) -> Array:
  # A new-style key is identified by its dtype, not by its shape or class.
  if dtypes.issubdtype(key.dtype, dtypes.prng_key):
    return key
  else:
    raise TypeError("New-style typed JAX PRNG keys required")
Motivation#

Two major motivating factors for this change are customizability and safety.

Customizing PRNG implementations#

JAX currently operates with a single, globally configured PRNG algorithm. A PRNG key is a vector of unsigned 32-bit integers, which jax.random APIs consume to produce pseudorandom streams. Any higher-rank uint32 array is interpreted as an array of such key buffers, where the trailing dimension holds the key data.

The drawbacks of this design became clearer as we introduced alternative PRNG implementations, which must be selected by setting a global or local configuration flag. Different PRNG implementations have different size key buffers, and different algorithms for generating random bits. Determining this behavior with a global flag is error-prone, especially when there is more than one key implementation in use process-wide.

Our new approach is to carry the implementation as part of the PRNG key type, i.e. with the element type of the key array. Using the new key API, here is an example of generating pseudorandom values under the default threefry2x32 implementation (which is implemented in pure Python and compiled with JAX), and under the non-default rbg implementation (which corresponds to a single XLA random-bit generation operation):

>>> key = jax.random.key(0, impl='threefry2x32')  # this is the default impl
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]
>>> jax.random.uniform(key, shape=(3,))
Array([0.9653214 , 0.31468165, 0.63302994], dtype=float32)

>>> key = jax.random.key(0, impl='rbg')
>>> key
Array((), dtype=key<rbg>) overlaying:
[0 0 0 0]
>>> jax.random.uniform(key, shape=(3,))
Array([0.39904642, 0.8805201 , 0.73571277], dtype=float32)
Safe PRNG key use#

PRNG keys are really only meant to support a few operations in principle, namely key derivation (e.g. splitting) and random number generation. The PRNG is designed to generate independent pseudorandom numbers, provided that keys are properly split and that every key is consumed only once.

Code that manipulates or consumes key data in other ways often indicates an accidental bug, and representing key arrays as raw uint32 buffers has allowed for easy misuse along these lines. Here are a few example misuses that we’ve encountered in the wild:

Key buffer indexing#

Access to the underlying integer buffers makes it easy to try and derive keys in non-standard ways, sometimes with unexpectedly bad consequences:

# Incorrect
key = random.PRNGKey(999)
new_key = random.PRNGKey(key[1])  # identical to the original key!
# Correct
key = random.PRNGKey(999)
key, new_key = random.split(key)

If this key were a new-style typed key made with random.key(999), indexing into the key buffer would error instead.

Key arithmetic#

Key arithmetic is a similarly treacherous way to derive keys from other keys. Deriving keys in a way that avoids jax.random.split() or jax.random.fold_in() by manipulating key data directly produces a batch of keys that—depending on the PRNG implementation—might then generate correlated random numbers within the batch:

# Incorrect
key = random.PRNGKey(0)
batched_keys = key + jnp.arange(10, dtype=key.dtype)[:, None]
# Correct
key = random.PRNGKey(0)
batched_keys = random.split(key, 10)

New-style typed keys created with random.key(0) address this by disallowing arithmetic operations on keys.

Inadvertent transposing of key buffers#

With “raw” old-style key arrays, it’s easy to accidentally swap batch (leading) dimensions and key buffer (trailing) dimensions. Again this possibly results in keys that produce correlated pseudorandomness. A pattern that we’ve seen over time boils down to this:

# Incorrect
keys = random.split(random.PRNGKey(0))
data = jax.vmap(random.uniform, in_axes=1)(keys)
# Correct
keys = random.split(random.PRNGKey(0))
data = jax.vmap(random.uniform, in_axes=0)(keys)

The bug here is subtle. By mapping over in_axes=1, this code makes new keys by combining a single element from each key buffer in the batch. The resulting keys are different from one another, but are effectively “derived” in a non-standard way. Again, the PRNG is not designed or tested to produce independent random streams from such a key batch.

New-style typed keys created with random.key(0) address this by hiding the buffer representation of individual keys, instead treating keys as opaque elements of a key array. Key arrays have no trailing “buffer” dimension to index, transpose, or map over.
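
For example, with typed keys the batch dimension is the only dimension, so mapping over a batch of keys is unambiguous (a minimal sketch):

keys = jax.random.split(jax.random.key(0), 10)  # shape (10,), dtype key<fry>
# Each element is an opaque key; there is no trailing buffer axis to confuse
# with the batch axis.
data = jax.vmap(lambda k: jax.random.uniform(k, (3,)))(keys)  # shape (10, 3)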

Key reuse#

Unlike state-based PRNG APIs like numpy.random, JAX’s functional PRNG does not implicitly update a key when it has been used.

# Incorrect
key = random.PRNGKey(0)
x = random.uniform(key, (100,))
y = random.uniform(key, (100,))  # Identical values!
# Correct
key = random.PRNGKey(0)
key1, key2 = random.split(key)
x = random.uniform(key1, (100,))
y = random.uniform(key2, (100,))

We’re actively working on tools to detect and prevent unintended key reuse. This is still work in progress, but it relies on typed key arrays. Upgrading to typed keys now sets us up to introduce these safety features as we build them out.

Design of typed PRNG keys#

Typed PRNG keys are implemented as instances of extended dtypes within JAX, of which the new PRNG dtypes are a sub-dtype.

Extended dtypes#

From the user perspective, an extended dtype dt has the following user-visible properties:

  • jax.dtypes.issubdtype(dt, jax.dtypes.extended) returns True: this is the public API that should be used to detect whether a dtype is an extended dtype.

  • It has a class-level attribute dt.type, which returns a typeclass in the hierarchy of numpy.generic. This is analogous to how np.dtype('int32').type returns numpy.int32, which is not a dtype but rather a scalar type, and a subclass of numpy.generic.

  • Unlike numpy scalar types, we do not allow instantiation of dt.type scalar objects: this is in accordance with JAX’s decision to represent scalar values as zero-dimensional arrays.

From a non-public implementation perspective, an extended dtype has the following properties:

  • Its type is a subclass of the private base class jax._src.dtypes.ExtendedDtype, the non-public base class used for extended dtypes. An instance of ExtendedDtype is analogous to an instance of np.dtype, like np.dtype('int32').

  • It has a private _rules attribute which allows the dtype to define how it behaves under particular operations. For example, jax.lax.full(shape, fill_value, dtype) will delegate to dtype._rules.full(shape, fill_value, dtype) when dtype is an extended dtype.

Why introduce extended dtypes in generality, beyond PRNGs? We reuse this same extended dtype mechanism elsewhere internally. For example, the jax._src.core.bint object, a bounded integer type used for experimental work on dynamic shapes, is another extended dtype. In recent JAX versions it satisfies the properties above (See jax/_src/core.py#L1789-L1802).

PRNG dtypes#

PRNG dtypes are defined as a particular case of extended dtypes. Specifically, this change introduces a new public scalar type class jax.dtypes.prng_key, which has the following property:

>>> jax.dtypes.issubdtype(jax.dtypes.prng_key, jax.dtypes.extended)
True

PRNG key arrays then have a dtype with the following properties:

>>> key = jax.random.key(0)
>>> jax.dtypes.issubdtype(key.dtype, jax.dtypes.extended)
True
>>> jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key)
True

And in addition to key.dtype._rules as outlined for extended dtypes in general, PRNG dtypes define key.dtype._impl, which contains the metadata that defines the PRNG implementation. The PRNG implementation is currently defined by the non-public jax._src.prng.PRNGImpl class. For now, PRNGImpl isn’t meant to be a public API, but we might revisit this soon to allow for fully custom PRNG implementations.

Progress#

Following is a non-comprehensive list of key Pull Requests implementing the above design. The main tracking issue is #9263.

  • Implement pluggable PRNG via PRNGImpl: #6899

  • Implement PRNGKeyArray, without dtype: #11952

  • Add a “custom element” dtype property to PRNGKeyArray with _rules attribute: #12167

  • Rename “custom element type” to “opaque dtype”: #12170

  • Refactor bint to use the opaque dtype infrastructure: #12707

  • Add jax.random.key to create typed keys directly: #16086

  • Add impl argument to key and PRNGKey: #16589

  • Rename “opaque dtype” to “extended dtype” & define jax.dtypes.extended: #16824

  • Introduce jax.dtypes.prng_key and unify PRNG dtype with Extended dtype: #16781

  • Add a jax_legacy_prng_key flag to support warning or erroring when using legacy (raw) PRNG keys: #17225

Design of Type Promotion Semantics for JAX#


Jake VanderPlas, December 2021

One of the challenges faced in the design of any numerical computing library is the choice of how to handle operations between values of different types. This document outlines the thought process behind the promotion semantics used by JAX, summarized in JAX Type Promotion Semantics.

Goals of JAX Type Promotion#

JAX’s numerical computing API is modeled after that of NumPy, with a few enhancements including the ability to target accelerators like GPU and TPU. This makes adoption of NumPy’s type promotion system disadvantageous for JAX users: NumPy’s type promotion rules heavily favor 64-bit outputs, which is problematic for computation on accelerators. Devices such as GPUs and TPUs often pay a significant performance penalty to use 64-bit floating point types, and in some cases do not support native 64-bit floating point types at all.

A simple example of this problematic type promotion semantics can be seen in binary operations between 32-bit integers and floats:

import numpy as np
np.dtype(np.int32(1) + np.float32(1))
dtype('float64')

NumPy’s tendency to produce 64-bit values is a long-standing issue with using NumPy’s API for accelerator computations, for which there isn’t yet a good solution. For this reason, JAX has sought to re-think NumPy-style type promotion with accelerators in mind.

Stepping Back: Tables and Lattices#

Before we dive into the details, let's step back and consider how to frame the problem of type promotion. Consider arithmetic operations between built-in numerical types in Python, namely those of type int, float, and complex. With a few lines of code we can generate the type promotion table used by Python for addition between values of these types:

import pandas as pd
types = [int, float, complex]
name = lambda t: t.__name__
pd.DataFrame([[name(type(t1(1) + t2(1))) for t1 in types] for t2 in types],
             index=[name(t) for t in types], columns=[name(t) for t in types])
         int      float    complex
int      int      float    complex
float    float    float    complex
complex  complex  complex  complex

This table enumerates Python’s numerical type promotion behavior, but it turns out there is a complementary representation that is much more compact: a Lattice representation, where the supremum between any two nodes is the type that they promote to. The lattice representation of Python’s promotion table is much simpler:

import networkx as nx
import matplotlib.pyplot as plt
lattice = {'int': ['float'], 'float': ['complex']}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {'int': [0, 0], 'float': [1, 0], 'complex': [2, 0]}
fig, ax = plt.subplots(figsize=(8, 2))
nx.draw(graph, with_labels=True, node_size=4000, node_color='lightgray', pos=pos, ax=ax, arrowsize=20)

This lattice is a compact encoding of the information in the promotion table above. You can find the result of a type promotion for two inputs by tracing the graph to the first common child of the two nodes (including the nodes themselves); mathematically, this common child is known as the supremum, or least upper bound, or join of the pair on the lattice; here we will refer to this operation as the join.

Conceptually, an arrow means that implicit type promotion is allowed between the source and the destination: for example, implicit promotion from integer to float is allowed, but implicit promotion from float to integer is not.
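
To make the join operation concrete, here is a small illustrative sketch (the join helper below is not a library function; it simply walks the lattice dict defined in the code above):

import networkx as nx

lattice = {'int': ['float'], 'float': ['complex']}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)

def join(graph, a, b):
  # The upper bounds of a node are the node itself plus everything reachable
  # from it along the promotion arrows.
  upper_a = {a} | nx.descendants(graph, a)
  upper_b = {b} | nx.descendants(graph, b)
  common = upper_a & upper_b
  # The join is the common upper bound that no other common upper bound can
  # reach, i.e. the minimal element of the set of common upper bounds.
  least = [n for n in common
           if not any(n in nx.descendants(graph, m) for m in common - {n})]
  assert len(least) == 1, "not a lattice for this pair"
  return least[0]

print(join(graph, 'int', 'float'))    # float
print(join(graph, 'int', 'complex'))  # complex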

Keep in mind that in general not every directed acyclic graph (DAG) will satisfy the properties of a lattice. A lattice requires the existence of a unique least upper bound for every pair of nodes; so, for example the following two DAGs are not lattices:

import networkx as nx
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 2, figsize=(10, 2))

lattice = {'A': ['B', 'C']}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {'A': [0, 0], 'B': [1, 0.5], 'C': [1, -0.5]}
nx.draw(graph, with_labels=True, node_size=2000, node_color='lightgray', pos=pos, ax=ax[0], arrowsize=20)
ax[0].set(xlim=[-0.5, 1.5], ylim=[-1, 1])

lattice = {'A': ['C', 'D'], 'B': ['C', 'D']}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {'A': [0, 0.5], 'B': [0, -0.5], 'C': [1, 0.5], 'D': [1, -0.5]}
nx.draw(graph, with_labels=True, node_size=2000, node_color='lightgray', pos=pos, ax=ax[1], arrowsize=20)
ax[1].set(xlim=[-0.5, 1.5], ylim=[-1, 1]);

The left DAG is not a lattice because there exists no upper bound for nodes B and C; the right DAG fails on two counts: first, there exists no upper bound for nodes C and D, and for nodes A and B the least upper bound cannot be uniquely determined: both C and D are candidates, but they are unorderable.

Properties of a Type Promotion Lattice#

Specifying type promotions in terms of a lattice ensures a number of useful properties. Denoting the join on the lattice with the \(\vee\) operator, we have:

Existence: A lattice by definition requires that a unique lattice join exists for every pair of elements: \(\forall (a, b): \exists !(a \vee b)\)

Commutativity: The lattice join is commutative: \(\forall (a, b): a\vee b = b \vee a\).

Associativity: The lattice join is associative: \(\forall (a, b, c): a \vee (b \vee c) = (a \vee b) \vee c\).

On the other hand, these properties imply restrictions on the type promotion systems they can represent; in particular not every type promotion table can be represented by a lattice. A ready example of this is NumPy’s full type promotion table; this can be shown quickly by counterexample: here are three scalar types whose promotion behavior in NumPy is non-associative:

import numpy as np
a, b, c = np.int8(1), np.uint8(1), np.float16(1)
print(np.dtype((a + b) + c))
print(np.dtype(a + (b + c)))
float32
float16

Such a result may come as a surprise to users: we generally expect mathematical expressions to map to mathematical concepts, so, for example, a + b + c should be equivalent to c + b + a; x * (y + z) should be equivalent to x * y + x * z. If type promotion is non-associative or non-commutative, these properties no longer apply.

Further, a lattice-based type promotion system is simpler to conceptualize and understand when compared to a table-based system. For example, JAX recognizes 18 distinct types: a promotion lattice consisting of 18 nodes and sparse, well-motivated connections between them is far easier to hold in one’s mind than a table of 324 entries.

For this reason, we opt to use a lattice-based type promotion system for JAX.

Type Promotion within Categories#

Numerical computing libraries generally provide more than just int, float, and complex; within each of these categories there are a variety of possible precisions, denoted by the number of bits used in the numerical representation. The categories we will consider here are:

  • unsigned integers which include uint8, uint16, uint32 & uint64 (we’ll use u8, u16, u32, u64 for short)

  • signed integers which include int8, int16, int32 & int64 (we’ll use i8, i16, i32, i64 for short)

  • floating point, which include float16, float32 & float64 (we’ll use f16, f32, f64 for short)

  • complex floating point, which include complex64 & complex128 (we’ll use c64, c128 for short)

NumPy's type promotion semantics within each of these four categories are relatively straightforward: the ordered hierarchy of types translates directly to four separate lattices representing the in-category type promotion rules:

import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'u8': ['u16'], 'u16': ['u32'], 'u32': ['u64'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],
  'f16': ['f32'], 'f32': ['f64'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'u8': [0, 0], 'u16': [1, 0], 'u32': [2, 0], 'u64': [3, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [1, 2], 'f32': [2, 2], 'f64': [3, 2],
  'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 4))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)

In terms of promotion of values to 64-bit that JAX seeks to avoid, these same-kind promotion semantics within each type category are unproblematic: the only way to produce a 64-bit output is to have a 64-bit input.

Enter Python Scalars#

Let’s now think about where Python scalars fit into the mix.

In NumPy, promotion behavior differs depending on whether the inputs are arrays or scalars. For example, when operating on two scalars, normal promotion rules apply:

x = np.int8(0)  # int8 scalar
y = 1  # Python int = int64 scalar
(x + y).dtype
dtype('int64')

Here the Python value 1 is treated as an int64, and straightforward within-category rules lead to an int64 result.

In operations between Python scalars and NumPy arrays, however, scalars defer to the dtype of the array. For example:

x = np.zeros(1, dtype='int8')  # int8 array
y = 1  # Python int = int64 scalar
(x + y).dtype
dtype('int8')

Here the bit width of the int64 scalar is ignored, deferring to the bit width of the array.

There is another detail here: when NumPy type promotion involves a scalar, the output dtype is value-dependent: if the Python scalar is too large for the given dtype, it is promoted to a compatible type:

x = np.zeros(1, dtype='int8')  # int8 array
y = 1000  # int64 scalar
(x + y).dtype
dtype('int16')

For the purposes of JAX, value-dependent promotion is a non-starter because of the nature of JIT compilation and other transformations, which act on abstract representations of data without reference to their value.

Ignoring value-dependent effects, the signed integer branch of NumPy’s type promotion can be represented in the following lattice, where we’ll use * to mark scalar dtypes:

import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i8*': ['i16*'], 'i16*': ['i32*'], 'i32*': ['i64*'], 'i64*': ['i8'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i8*': [0, 1], 'i16*': [2, 1], 'i32*': [4, 1], 'i64*': [6, 1],
  'i8': [9, 1], 'i16': [11, 1], 'i32': [13, 1], 'i64': [15, 1],
}
fig, ax = plt.subplots(figsize=(12, 4))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
ax.text(3, 1.6, "Scalar Types", ha='center', fontsize=14)
ax.text(12, 1.6, "Array Types", ha='center', fontsize=14)
ax.set_ylim(-1, 3);

A similar pattern holds within the uint, float, and complex lattices.

For the sake of simplicity, let’s collapse each category of scalar types into a single node, denoted by u*, i*, f*, and c* respectively. Our set of in-category lattices can now be represented like this:

import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'u*': ['u8'], 'u8': ['u16'], 'u16': ['u32'], 'u32': ['u64'],
  'i*': ['i8'], 'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],
  'f*': ['f16'], 'f16': ['f32'], 'f32': ['f64'],
  'c*': ['c64'], 'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'u*': [0, 0], 'u8': [3, 0], 'u16': [5, 0], 'u32': [7, 0], 'u64': [9, 0],
  'i*': [0, 1], 'i8': [3, 1], 'i16': [5, 1], 'i32': [7, 1], 'i64': [9, 1],
  'f*': [0, 2], 'f16': [5, 2], 'f32': [7, 2], 'f64': [9, 2],
  'c*': [0, 3], 'c64': [7, 3], 'c128': [9, 3],
}
fig, ax = plt.subplots(figsize=(6, 4))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)

In some senses, putting scalars at the left is a strange choice: the scalar types may contain values of any width, but when interacting with an array of a given type, the promotion result defers to the array type. The benefit of this is that when you perform an operation like x + 2 for an array x, the type of x will carry to the result no matter its width:

for dtype in [np.int8, np.int16, np.int32, np.int64]:
  x = np.arange(10, dtype=dtype)
  assert (x + 2).dtype == dtype

This behavior gives motivation to our * notation for scalar values: the * is reminiscent of a wildcard that can take on any desired value.

The benefit of these semantics is that you can readily express sequences of operations with clean Python code, without having to explicitly cast scalars to the appropriate type. Imagine if rather than writing this:

3 * (x + 1) ** 2

you had to write this:

np.int32(3) * (x + np.int32(1)) ** np.int32(2)

Although the latter is explicit, such numerical code would be tedious to read and write. With the scalar promotion semantics described above, given an array x of type int32, the types in the second statement are implicit in the first.

Combining Lattices#

Recall that we began our discussion by introducing the lattice representing type promotion within Python: int -> float -> complex. Let’s rewrite this as i* -> f* -> c*, and let’s further allow i* to subsume u* (after all, there is no unsigned integer scalar type in Python).

Putting these all together, we get the following partial lattice representing type promotion between Python scalars and numpy arrays:

import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16'], 'u16': ['u32'], 'u32': ['u64'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],
  'f16': ['f32'], 'f32': ['f64'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],
  'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)

Notice that this is not (yet) a true lattice: there are many pairs of nodes for which a join does not exist. However, we can think of this as a partial lattice, in which some pairs of nodes do not have a defined promotion behavior, and the defined portion of this partial lattice does correctly describe NumPy’s array promotion behavior (leaving aside value-dependent semantics mentioned above).

This sets up a nice framework by which we can think about filling-out these undefined promotion rules, by adding connections on this graph. But which connections to add? Broadly speaking, we want any additional connections to satisfy a few properties:

  1. Promotion should satisfy the commutative and associative properties: in other words, the graph should remain a (partial) lattice.

  2. Promotion should never allow for dropping entire components of data: for example, we should never promote complex to float, as it would discard any imaginary parts.

  3. Promotion should never lead to an unhandled overflow. For example, the maximum possible uint32 is twice as large as the maximum possible int32, so we should not implicitly promote uint32 to int32.

  4. Wherever possible, promotion should avoid loss of precision. For example, an int64 value can require more bits of precision than float64's 52-bit mantissa provides, so promoting int64 to float64 represents a possible loss of precision. However, the maximum representable float64 is larger than the maximum representable int64, so in this case criterion #3 is still satisfied.

  5. Wherever possible, binary promotion should avoid resulting in types that are wider than the inputs. This is to ensure that JAX’s implicit promotions remain friendly to accelerator-based workflows, in which users often want to restrict types to 32-bit (or in some cases 16-bit) values.

Each new connection on the lattice introduces some level of convenience to the user (a new set of types that can interact without explicit casting), but the convenience may become too costly if any of the above criteria are violated. Developing a full promotion lattice involves striking a balance between this convenience and this cost.

Mixed Promotion: Float and Complex#

Let’s begin with what is perhaps the easiest case, that of promotion between float and complex values.

Complex numbers are made up of pairs of floating point numbers, and so we have a natural path of promotion between them: cast float to complex while maintaining the width of the real part. In terms of our partial lattice representation, it would look like this:

import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16'], 'u16': ['u32'], 'u32': ['u64'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],
  'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],
  'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)

This turns out to represent exactly the semantics used by Numpy in mixed float/complex type promotion.
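
A quick NumPy check of this behavior (the dtypes shown in the comments are what NumPy produces for these pairs):

import numpy as np

# float promotes to the complex type whose real part has the same width:
print(np.dtype(np.float32(1) + np.complex64(1)))  # complex64
print(np.dtype(np.float64(1) + np.complex64(1)))  # complex128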

Mixed Promotion: Signed & Unsigned Integers#

For the next case, let’s consider something a bit more difficult: promotion between signed and unsigned integers. For example, when promoting uint8 to a signed integer, how many bits do we need?

At first glance, you might think it natural to promote uint8 to int8; but the largest uint8 numbers are not representable in int8. For this reason, it makes more sense to promote unsigned integers to integers with twice the number of bits; this promotion behavior can be represented by adding the following connections to the promotion lattice:

import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],
  'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],
  'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)

Again, the connections added here are precisely the promotion semantics implemented by Numpy for mixed-integer promotion.
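
This can be verified directly with a minimal NumPy check (outputs shown in the comments):

import numpy as np

# Mixed signed/unsigned promotion widens to a signed type with twice the bits:
print(np.dtype(np.uint8(1) + np.int8(1)))    # int16
print(np.dtype(np.uint16(1) + np.int16(1)))  # int32
print(np.dtype(np.uint32(1) + np.int32(1)))  # int64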

How to handle uint64?#

The approach to mixed signed/unsigned integer promotion leaves out one type: uint64. Following the pattern above, a mixed-integer operation involving uint64 would have to produce int128, but this is not a standard available dtype.

Numpy’s choice here is to promote to float64:

(np.uint64(1) + np.int64(1)).dtype
dtype('float64')

However, this may be a surprising convention: it’s the only case in which promotion of integer types does not result in an integer. For now, we will leave uint64 promotion undefined, and return to it later.

Mixed Promotion: Integer and Floating#

When promoting integers to floating point, we might start with the same thought process as mixed promotion between signed and unsigned integers. A 16-bit signed or unsigned integer cannot be represented at full precision by a 16-bit float, which has only 10 bits of mantissa. Therefore, it might make sense to promote integers to floats represented by twice the number of bits:

import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16', 'i16', 'f16'], 'u16': ['u32', 'i32', 'f32'], 'u32': ['u64', 'i64', 'f64'],
  'i8': ['i16', 'f16'], 'i16': ['i32', 'f32'], 'i32': ['i64', 'f64'],
  'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],
  'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)

This is effectively what Numpy type promotion does, but in doing so it breaks the lattice property of the graph: for example, the pair {i8, u8} no longer has a unique least upper bound: the possibilities are i16 and f16, which are unorderable on the graph. This turns out to be the source of NumPy’s non-associative type promotion highlighted above.

Can we come up with a modification of NumPy’s promotion rules, such that it will satisfy the lattice property, while also giving sensible results for mixed type promotion? There are a few approaches we could take here.

Option 0: Leave integer/floating mixed precision undefined#

To make behavior utterly predictable (at some cost to user convenience), a defensible choice would be to leave as undefined any mixed integer/float promotion beyond Python scalars, stopping with the partial lattice from the previous section. The downside would be the requirement for users to explicitly type-cast when operating between integer and floating-point quantities.

Option 1: Avoiding All Precision Loss#

If our focus is on avoiding precision loss at all costs, we can restore the lattice property by promoting unsigned integers to float via their existing signed integer paths:

import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],
  'i8': ['i16', 'f16'], 'i16': ['i32', 'f32'], 'i32': ['i64', 'f64'],
  'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],
  'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)

A disadvantage of this approach is that it still leaves int64 and uint64 promotion undefined, because there is no standard floating point type with enough bits of mantissa to represent their full range of values. We could relax the precision constraint and complete the lattice by drawing connections from i64->f64 and u64->f64, but those links would run counter to the motivation for this promotion scheme.

A second disadvantage is that this lattice makes it difficult to find a sensible place to insert bfloat16 (see below) while maintaining the lattice property.

A third disadvantage of this approach, more important for JAX’s accelerator backends, is that some operations result in types that are much wider than necessary; for example mixed operations between uint16 and float16 would promote all the way to float64, which is not ideal.

Option 2: Avoid most wider-than-necessary promotions#

To address the unnecessary promotions to wider types, we could accept the possibility of some precision loss in integer/float promotion, promoting signed integers to floats of the same width:

import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],
  'i8': ['i16'], 'i16': ['f16', 'i32'], 'i32': ['f32', 'i64'], 'i64': ['f64'],
  'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [1.5, 2], 'f32': [2.5, 2], 'f64': [3.5, 2],
  'c64': [3, 3], 'c128': [4, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)

While this does allow for precision-losing promotions between integers and floats, these promotions will not mis-represent the magnitude of the result: though the floating point mantissa is not wide enough to represent all values, the exponent is wide enough to approximate them.

This approach also allows a natural promotion path from int64 to float64, though uint64 remains unpromotable in this scheme. That said, a connection from u64 to f64 could be justified more readily here than before.

This promotion scheme still results in some wider than necessary promotion paths; for example operations between float32 and uint32 result in float64. Additionally, this lattice makes it difficult to find a sensible place to insert bfloat16 (see below) while maintaining the lattice property.

Option 3: Avoid all wider-than-necessary promotions#

We can avoid all non-ideal 64-bit promotions if we’re willing to fundamentally change our thinking around integer and float promotions. Just as scalars always defer to the widths of array types, we can make integers always defer to the width of float types:

import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'], 'i64': ['f*'],
  'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [1.5, 2], 'f32': [2.5, 2], 'f64': [3.5, 2],
  'c64': [3, 3], 'c128': [4, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)

This involves a small sleight of hand: previously we had used f* to refer to a scalar type. In this lattice, f* might be applied to the array output of a mixed computation. Instead of thinking of f* as a scalar, we could think of it as a special kind of float value with distinct promotion rules: in JAX we refer to this as a weak float; see below.

The advantage of this approach is that, outside unsigned ints, it avoids all wider-than-necessary promotions: you can never get an f64 output without a 64-bit input, and you can never get an f32 output without a 32-bit input: this results in convenient semantics for working on accelerators while avoiding inadvertent 64-bit values.

This feature of giving primacy to floating point types resembles the type promotion behavior of PyTorch. This lattice also happens to generate a promotion table that very closely resembles JAX’s original ad hoc type promotion scheme, which was not based on a lattice but had the property of giving primacy to floating point types.

This lattice additionally offers a natural location to insert bfloat16, without the need to impose an ordering between bf16 and f16:

import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['u8', 'i8'], 'f*': ['c*', 'f16', 'bf16'], 'c*': ['c64'],
  'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'], 'i64': ['f*'],
  'f16': ['f32'], 'bf16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [1.8, 1.7], 'bf16': [1.8, 2.3], 'f32': [3.0, 2], 'f64': [4.0, 2],
  'c64': [3.5, 3], 'c128': [4.5, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)

This is important because f16 and bf16 are not comparable: they utilize their bits differently. bf16 represents a larger range at lower precision, while f16 represents a smaller range at higher precision.
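
The tradeoff is easy to see from the dtype parameters; here is a minimal check (the values in the comments are approximate):

import jax.numpy as jnp

# bf16 keeps roughly the exponent range of f32 but has fewer mantissa bits:
print(jnp.finfo(jnp.bfloat16).max)  # ~3.4e38
# f16 has more mantissa bits but a much smaller range:
print(jnp.finfo(jnp.float16).max)   # 65504.0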

However, these advantages come with a few tradeoffs:

  • mixed float/integer promotion is very prone to precision loss: for example, int64 (with a maximum value of \(9.2 \times 10^{18}\)) can be promoted to float16 (with a maximum value of \(6.5 \times 10^4\)), meaning most representable values will become inf.

  • as mentioned above, f* can no longer be thought of as a “scalar type”, but as a different flavor of float64. In JAX’s parlance, this is referred to as a weak type, in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values.

Note also that this approach still leaves the uint64 promotion question unanswered, although it is perhaps reasonable to close the lattice by connecting u64 to f*.

Type Promotion in JAX#

In designing the type promotion semantics of JAX, we kept in mind many of these ideas, and leaned heavily on a few things:

  1. We chose to constrain JAX’s type promotion semantics to graphs that satisfy the lattice property: this is to ensure associativity and commutativity, but also to allow the semantics to be compactly described in a DAG, rather than requiring a large table.

  2. We leaned toward semantics that avoid inadvertent promotion to wider types, particularly when it comes to float values, in order to benefit computation on accelerators.

  3. We were fine accepting potential loss of precision (but not loss of magnitude) in mixed type promotion when it was required to maintain (1) and (2).

With this in mind, JAX has adopted Option 3. Or rather, a slightly modified version of Option 3 that draws the connection between u64 and f*, in order to create a true lattice. Rearranging the nodes for clarity, JAX’s type promotion lattice then looks like this:

import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['u8', 'i8'], 'f*': ['c*', 'f16', 'bf16'], 'c*': ['c64'],
  'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'], 'u64': ['f*'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'], 'i64': ['f*'],
  'f16': ['f32'], 'bf16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [4.5, 0.5], 'c*': [5, 1.5],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [5.75, 0.8], 'bf16': [5.75, 0.2], 'f32': [7, 0.5], 'f64': [8, 0.5],
  'c64': [7.5, 1.5], 'c128': [8.5, 1.5],
}
fig, ax = plt.subplots(figsize=(10, 4))
ax.set_ylim(-0.5, 2)
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
# ax.patches[12].set_linestyle((0, (2, 4)))

The behavior resulting from this choice is summarized in JAX Type Promotion Semantics. Notably, aside from the inclusion of larger unsigned types (u16, u32, u64) and some details about the behavior of scalar/weak types (i*, f*, c*), this type promotion scheme turns out to be very close to that chosen by PyTorch.
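
As a quick sanity check of these semantics, here is a minimal sketch; the dtypes in the comments follow from the lattice above:

import jax.numpy as jnp

# Integers defer to the width of the floating point input:
print((jnp.int16(1) + jnp.float16(1)).dtype)    # float16
# No 64-bit result without a 64-bit input:
print((jnp.int32(1) + jnp.float32(1)).dtype)    # float32
# Python scalars are weakly typed and defer to the array dtype:
print((jnp.arange(4, dtype='int8') + 2).dtype)  # int8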

For those interested, the appendix below prints the full promotion tables used by NumPy, Tensorflow, PyTorch, and JAX.

Appendix: Example Type Promotion Tables#

The following are some examples of implicit type promotion tables implemented by various Python array computing libraries.

NumPy Type Promotion#

Note that NumPy does not include the bfloat16 dtype, and that the table below ignores value-dependent effects.


import numpy as np
import pandas as pd
from IPython import display

np_dtypes = {
  'b': np.bool_,
  'u8': np.uint8, 'u16': np.uint16, 'u32': np.uint32, 'u64': np.uint64,
  'i8': np.int8, 'i16': np.int16, 'i32': np.int32, 'i64': np.int64,
  'bf16': 'invalid', 'f16': np.float16, 'f32': np.float32, 'f64': np.float64,
  'c64': np.complex64, 'c128': np.complex128,
  'i*': int, 'f*': float, 'c*': complex}

np_dtype_to_code = {val: key for key, val in np_dtypes.items()}

def make_np_zero(dtype):
  if dtype in {int, float, complex}:
    return dtype(0)
  else:
    return np.zeros(1, dtype=dtype)

def np_result_code(dtype1, dtype2):
  try:
    out = np.add(make_np_zero(dtype1), make_np_zero(dtype2))
  except TypeError:
    return '-'
  else:
    if type(out) in {int, float, complex}:
      return np_dtype_to_code[type(out)]
    else:
      return np_dtype_to_code[out.dtype.type]


grid = [[np_result_code(dtype1, dtype2)
         for dtype2 in np_dtypes.values()]
        for dtype1 in np_dtypes.values()]
table = pd.DataFrame(grid, index=np_dtypes.keys(), columns=np_dtypes.keys())
display.HTML(table.to_html())
b u8 u16 u32 u64 i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i* f* c*
b b u8 u16 u32 u64 i8 i16 i32 i64 - f16 f32 f64 c64 c128 i64 f64 c128
u8 u8 u8 u16 u32 u64 i16 i16 i32 i64 - f16 f32 f64 c64 c128 u8 f64 c128
u16 u16 u16 u16 u32 u64 i32 i32 i32 i64 - f32 f32 f64 c64 c128 u16 f64 c128
u32 u32 u32 u32 u32 u64 i64 i64 i64 i64 - f64 f64 f64 c128 c128 u32 f64 c128
u64 u64 u64 u64 u64 u64 f64 f64 f64 f64 - f64 f64 f64 c128 c128 u64 f64 c128
i8 i8 i16 i32 i64 f64 i8 i16 i32 i64 - f16 f32 f64 c64 c128 i8 f64 c128
i16 i16 i16 i32 i64 f64 i16 i16 i32 i64 - f32 f32 f64 c64 c128 i16 f64 c128
i32 i32 i32 i32 i64 f64 i32 i32 i32 i64 - f64 f64 f64 c128 c128 i32 f64 c128
i64 i64 i64 i64 i64 f64 i64 i64 i64 i64 - f64 f64 f64 c128 c128 i64 f64 c128
bf16 - - - - - - - - - - - - - - - - - -
f16 f16 f16 f32 f64 f64 f16 f32 f64 f64 - f16 f32 f64 c64 c128 f16 f16 c64
f32 f32 f32 f32 f64 f64 f32 f32 f64 f64 - f32 f32 f64 c64 c128 f32 f32 c64
f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 - f64 f64 f64 c128 c128 f64 f64 c128
c64 c64 c64 c64 c128 c128 c64 c64 c128 c128 - c64 c64 c128 c64 c128 c64 c64 c64
c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 - c128 c128 c128 c128 c128 c128 c128 c128
i* i64 u8 u16 u32 u64 i8 i16 i32 i64 - f16 f32 f64 c64 c128 i64 f64 c128
f* f64 f64 f64 f64 f64 f64 f64 f64 f64 - f16 f32 f64 c64 c128 f64 f64 c128
c* c128 c128 c128 c128 c128 c128 c128 c128 c128 - c64 c64 c128 c64 c128 c128 c128 c128
Tensorflow Type Promotion#

Tensorflow avoids defining implicit type promotion, except for Python scalars in limited cases. The table is asymmetric because in tf.add(x, y), the type of y must be coercible to the type of x.


import tensorflow as tf
import pandas as pd
from IPython import display

tf_dtypes = {
  'b': tf.bool,
  'u8': tf.uint8, 'u16': tf.uint16, 'u32': tf.uint32, 'u64': tf.uint64,
  'i8': tf.int8, 'i16': tf.int16, 'i32': tf.int32, 'i64': tf.int64,
  'bf16': tf.bfloat16, 'f16': tf.float16, 'f32': tf.float32, 'f64': tf.float64,
  'c64': tf.complex64, 'c128': tf.complex128,
  'i*': int, 'f*': float, 'c*': complex}

tf_dtype_to_code = {val: key for key, val in tf_dtypes.items()}

def make_tf_zero(dtype):
  if dtype in {int, float, complex}:
    return dtype(0)
  else:
    return tf.zeros(1, dtype=dtype)

def result_code(dtype1, dtype2):
  try:
    out = tf.add(make_tf_zero(dtype1), make_tf_zero(dtype2))
  except (TypeError, tf.errors.InvalidArgumentError):
    return '-'
  else:
    if type(out) in {int, float, complex}:
      return tf_dtype_to_code[type(out)]
    else:
      return tf_dtype_to_code[out.dtype]


grid = [[result_code(dtype1, dtype2)
         for dtype2 in tf_dtypes.values()]
        for dtype1 in tf_dtypes.values()]
table = pd.DataFrame(grid, index=tf_dtypes.keys(), columns=tf_dtypes.keys())
display.HTML(table.to_html())
b u8 u16 u32 u64 i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i* f* c*
b - - - - - - - - - - - - - - - - - -
u8 - u8 - - - - - - - - - - - - - u8 - -
u16 - - u16 - - - - - - - - - - - - u16 - -
u32 - - - u32 - - - - - - - - - - - u32 - -
u64 - - - - u64 - - - - - - - - - - u64 - -
i8 - - - - - i8 - - - - - - - - - i8 - -
i16 - - - - - - i16 - - - - - - - - i16 - -
i32 - - - - - - - i32 - - - - - - - i32 - -
i64 - - - - - - - - i64 - - - - - - i64 - -
bf16 - - - - - - - - - bf16 - - - - - bf16 bf16 -
f16 - - - - - - - - - - f16 - - - - f16 f16 -
f32 - - - - - - - - - - - f32 - - - f32 f32 -
f64 - - - - - - - - - - - - f64 - - f64 f64 -
c64 - - - - - - - - - - - - - c64 - c64 c64 c64
c128 - - - - - - - - - - - - - - c128 c128 c128 c128
i* - - - - - - - i32 - - - - - - - i32 - -
f* - - - - - - - - - - - f32 - - - f32 f32 -
c* - - - - - - - - - - - - - - c128 c128 c128 c128
PyTorch Type Promotion#

Notice that torch does not include unsigned integer types larger than uint8. Aside from this and some details about promotion with scalar/weak types, the table is close to that used by jax.numpy.

import torch
import pandas as pd
from IPython import display

torch_dtypes = {
  'b': torch.bool,
  'u8': torch.uint8, 'u16': 'invalid', 'u32': 'invalid', 'u64': 'invalid',
  'i8': torch.int8, 'i16': torch.int16, 'i32': torch.int32, 'i64': torch.int64,
  'bf16': torch.bfloat16, 'f16': torch.float16, 'f32': torch.float32, 'f64': torch.float64,
  'c64': torch.complex64, 'c128': torch.complex128,
  'i*': int, 'f*': float, 'c*': complex}

torch_dtype_to_code = {val: key for key, val in torch_dtypes.items()}

def make_torch_zero(dtype):
  if dtype in {int, float, complex}:
    return dtype(0)
  else:
    return torch.zeros(1, dtype=dtype)

def torch_result_code(dtype1, dtype2):
  try:
    out = torch.add(make_torch_zero(dtype1), make_torch_zero(dtype2))
  except TypeError:
    return '-'
  else:
    if type(out) in {int, float, complex}:
      return torch_dtype_to_code[type(out)]
    else:
      return torch_dtype_to_code[out.dtype]


grid = [[torch_result_code(dtype1, dtype2)
         for dtype2 in torch_dtypes.values()]
        for dtype1 in torch_dtypes.values()]
table = pd.DataFrame(grid, index=torch_dtypes.keys(), columns=torch_dtypes.keys())
display.HTML(table.to_html())
b u8 u16 u32 u64 i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i* f* c*
b b u8 - - - i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i64 f32 c64
u8 u8 u8 - - - i16 i16 i32 i64 bf16 f16 f32 f64 c64 c128 u8 f32 c64
u16 - - - - - - - - - - - - - - - - - -
u32 - - - - - - - - - - - - - - - - - -
u64 - - - - - - - - - - - - - - - - - -
i8 i8 i16 - - - i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i8 f32 c64
i16 i16 i16 - - - i16 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i16 f32 c64
i32 i32 i32 - - - i32 i32 i32 i64 bf16 f16 f32 f64 c64 c128 i32 f32 c64
i64 i64 i64 - - - i64 i64 i64 i64 bf16 f16 f32 f64 c64 c128 i64 f32 c64
bf16 bf16 bf16 - - - bf16 bf16 bf16 bf16 bf16 f32 f32 f64 c64 c128 bf16 bf16 c64
f16 f16 f16 - - - f16 f16 f16 f16 f32 f16 f32 f64 c64 c128 f16 f16 c64
f32 f32 f32 - - - f32 f32 f32 f32 f32 f32 f32 f64 c64 c128 f32 f32 c64
f64 f64 f64 - - - f64 f64 f64 f64 f64 f64 f64 f64 c128 c128 f64 f64 c128
c64 c64 c64 - - - c64 c64 c64 c64 c64 c64 c64 c128 c64 c128 c64 c64 c64
c128 c128 c128 - - - c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128
i* i64 u8 - - - i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i64 f32 c64
f* f32 f32 - - - f32 f32 f32 f32 bf16 f16 f32 f64 c64 c128 f32 f64 c64
c* c64 c64 - - - c64 c64 c64 c64 c64 c64 c64 c128 c64 c128 c64 c64 c128
JAX Type Promotion: jax.numpy#

jax.numpy follows type promotion rules laid out at https://jax.readthedocs.io/en/latest/type_promotion.html. Here we use i*, f*, c* to indicate both Python scalars and weakly-typed arrays.

# @title
from jax import dtypes
import jax
import jax.numpy as jnp
import pandas as pd
from IPython import display
jax.config.update('jax_enable_x64', True)

jnp_dtypes = {
  'b': jnp.bool_.dtype,
  'u8': jnp.uint8.dtype, 'u16': jnp.uint16.dtype, 'u32': jnp.uint32.dtype, 'u64': jnp.uint64.dtype,
  'i8': jnp.int8.dtype, 'i16': jnp.int16.dtype, 'i32': jnp.int32.dtype, 'i64': jnp.int64.dtype,
  'bf16': jnp.bfloat16.dtype, 'f16': jnp.float16.dtype, 'f32': jnp.float32.dtype, 'f64': jnp.float64.dtype,
  'c64': jnp.complex64.dtype, 'c128': jnp.complex128.dtype,
  'i*': int, 'f*': float, 'c*': complex}


jnp_dtype_to_code = {val: key for key, val in jnp_dtypes.items()}

def make_jnp_zero(dtype):
  if dtype in {int, float, complex}:
    return dtype(0)
  else:
    return jnp.zeros((), dtype=dtype)

def jnp_result_code(dtype1, dtype2):
  try:
    out = jnp.add(make_jnp_zero(dtype1), make_jnp_zero(dtype2))
  except TypeError:
    return '-'
  else:
    if hasattr(out, 'aval') and out.aval.weak_type:
      return out.dtype.kind + '*'
    elif type(out) in {int, float, complex}:
      return jnp_dtype_to_code[type(out)]
    else:
      return jnp_dtype_to_code[out.dtype]

grid = [[jnp_result_code(dtype1, dtype2)
         for dtype2 in jnp_dtypes.values()]
        for dtype1 in jnp_dtypes.values()]
table = pd.DataFrame(grid, index=jnp_dtypes.keys(), columns=jnp_dtypes.keys())
display.HTML(table.to_html())
b u8 u16 u32 u64 i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i* f* c*
b b u8 u16 u32 u64 i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i* f* c*
u8 u8 u8 u16 u32 u64 i16 i16 i32 i64 bf16 f16 f32 f64 c64 c128 u8 f* c*
u16 u16 u16 u16 u32 u64 i32 i32 i32 i64 bf16 f16 f32 f64 c64 c128 u16 f* c*
u32 u32 u32 u32 u32 u64 i64 i64 i64 i64 bf16 f16 f32 f64 c64 c128 u32 f* c*
u64 u64 u64 u64 u64 u64 f* f* f* f* bf16 f16 f32 f64 c64 c128 u64 f* c*
i8 i8 i16 i32 i64 f* i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i8 f* c*
i16 i16 i16 i32 i64 f* i16 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i16 f* c*
i32 i32 i32 i32 i64 f* i32 i32 i32 i64 bf16 f16 f32 f64 c64 c128 i32 f* c*
i64 i64 i64 i64 i64 f* i64 i64 i64 i64 bf16 f16 f32 f64 c64 c128 i64 f* c*
bf16 bf16 bf16 bf16 bf16 bf16 bf16 bf16 bf16 bf16 bf16 f32 f32 f64 c64 c128 bf16 bf16 c64
f16 f16 f16 f16 f16 f16 f16 f16 f16 f16 f32 f16 f32 f64 c64 c128 f16 f16 c64
f32 f32 f32 f32 f32 f32 f32 f32 f32 f32 f32 f32 f32 f64 c64 c128 f32 f32 c64
f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 c128 c128 f64 f64 c128
c64 c64 c64 c64 c64 c64 c64 c64 c64 c64 c64 c64 c64 c128 c64 c128 c64 c64 c64
c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128
i* i* u8 u16 u32 u64 i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i* f* c*
f* f* f* f* f* f* f* f* f* f* bf16 f16 f32 f64 c64 c128 f* f* c*
c* c* c* c* c* c* c* c* c* c* c64 c64 c64 c128 c64 c128 c* c* c*
JAX Type Promotion: jax.lax#

jax.lax is lower-level, and does not do any implicit promotion. Here we use i*, f*, c* to indicate both Python scalars and weakly-typed arrays.

# @title
from jax import dtypes
import jax
import jax.numpy as jnp
import pandas as pd
from IPython import display
jax.config.update('jax_enable_x64', True)

jnp_dtypes = {
  'b': jnp.bool_.dtype,
  'u8': jnp.uint8.dtype, 'u16': jnp.uint16.dtype, 'u32': jnp.uint32.dtype, 'u64': jnp.uint64.dtype,
  'i8': jnp.int8.dtype, 'i16': jnp.int16.dtype, 'i32': jnp.int32.dtype, 'i64': jnp.int64.dtype,
  'bf16': jnp.bfloat16.dtype, 'f16': jnp.float16.dtype, 'f32': jnp.float32.dtype, 'f64': jnp.float64.dtype,
  'c64': jnp.complex64.dtype, 'c128': jnp.complex128.dtype,
  'i*': int, 'f*': float, 'c*': complex}


jnp_dtype_to_code = {val: key for key, val in jnp_dtypes.items()}

def make_jnp_zero(dtype):
  if dtype in {int, float, complex}:
    return dtype(0)
  else:
    return jnp.zeros((), dtype=dtype)

def jnp_result_code(dtype1, dtype2):
  try:
    out = jax.lax.add(make_jnp_zero(dtype1), make_jnp_zero(dtype2))
  except TypeError:
    return '-'
  else:
    if hasattr(out, 'aval') and out.aval.weak_type:
      return out.dtype.kind + '*'
    elif type(out) in {int, float, complex}:
      return jnp_dtype_to_code[type(out)]
    else:
      return jnp_dtype_to_code[out.dtype]

grid = [[jnp_result_code(dtype1, dtype2)
         for dtype2 in jnp_dtypes.values()]
        for dtype1 in jnp_dtypes.values()]
table = pd.DataFrame(grid, index=jnp_dtypes.keys(), columns=jnp_dtypes.keys())
display.HTML(table.to_html())
b u8 u16 u32 u64 i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i* f* c*
b - - - - - - - - - - - - - - - - - -
u8 - u8 - - - - - - - - - - - - - - - -
u16 - - u16 - - - - - - - - - - - - - - -
u32 - - - u32 - - - - - - - - - - - - - -
u64 - - - - u64 - - - - - - - - - - - - -
i8 - - - - - i8 - - - - - - - - - - - -
i16 - - - - - - i16 - - - - - - - - - - -
i32 - - - - - - - i32 - - - - - - - - - -
i64 - - - - - - - - i64 - - - - - - i64 - -
bf16 - - - - - - - - - bf16 - - - - - - - -
f16 - - - - - - - - - - f16 - - - - - - -
f32 - - - - - - - - - - - f32 - - - - - -
f64 - - - - - - - - - - - - f64 - - - f64 -
c64 - - - - - - - - - - - - - c64 - - - -
c128 - - - - - - - - - - - - - - c128 - - c128
i* - - - - - - - - i64 - - - - - - i* - -
f* - - - - - - - - - - - - f64 - - - f* -
c* - - - - - - - - - - - - - - c128 - - c*
Jax and Jaxlib versioning#
Why are jax and jaxlib separate packages?#

We publish JAX as two separate Python wheels, namely jax, which is a pure Python wheel, and jaxlib, which is a mostly-C++ wheel that contains libraries such as:

  • XLA,

  • pieces of LLVM used by XLA,

  • MLIR infrastructure, such as the StableHLO Python bindings,

  • JAX-specific C++ libraries for fast JIT and PyTree manipulation.

We distribute separate jax and jaxlib packages because it makes it easy to work on the Python parts of JAX without having to build C++ code or even having a C++ toolchain installed. jaxlib is a large library that is not easy for many users to build, but most changes to JAX only touch Python code. By allowing the Python pieces to be updated independently of the C++ pieces, we improve the development velocity for Python changes.

In addition, jaxlib is not cheap to build, but we want to be able to iterate on and run the JAX tests in environments without a lot of CPU, for example in GitHub Actions or on a laptop. Many of our CI builds simply use a prebuilt jaxlib, rather than needing to rebuild the C++ pieces of JAX on each PR.

As we will see, distributing jax and jaxlib separately comes with a cost, in that it requires that changes to jaxlib maintain a backward compatible API. However, we believe that on balance it is preferable to make Python changes easy, even if at the cost of making C++ changes slightly harder.

How are jax and jaxlib versioned?#

Summary: jax and jaxlib share the same version number in the JAX source tree, but are released as separate Python packages. When installed, the jax package version must be greater than or equal to jaxlib’s version, and jaxlib’s version must be greater than or equal to the minimum jaxlib version specified by jax.

Both jax and jaxlib releases are numbered x.y.z, where x is the major version, y is the minor version, and z is an optional patch release. Version numbers must follow PEP 440. Version number comparisons are lexicographic comparisons on tuples of integers.

Each jax release has an associated minimum jaxlib version mx.my.mz. The minimum jaxlib version for jax version x.y.z must be no greater than x.y.z.

For jax version x.y.z and jaxlib version lx.ly.lz to be compatible, the following must hold:

  • The jaxlib version (lx.ly.lz) must be greater than or equal to the minimum jaxlib version (mx.my.mz).

  • The jax version (x.y.z) must be greater than or equal to the jaxlib version (lx.ly.lz).

These constraints imply the following rules for releases:

  • jax may be released on its own at any time, without updating jaxlib.

  • If a new jaxlib is released, a jax release must be made at the same time.

These version constraints are currently checked by jax at import time, instead of being expressed as Python package version constraints. jax checks the jaxlib version at runtime rather than using a pip package version constraint because we provide separate jaxlib wheels for a variety of hardware and software versions (e.g., GPU, TPU, etc.). Since we do not know which is the right choice for any given user, we do not want pip to install a jaxlib package for us automatically.
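
The check itself amounts to comparing integer version tuples. Here is a minimal sketch of the kind of comparison described above, assuming plain version strings of the form x.y.z; the function names are illustrative, not JAX's actual internals:

def _as_tuple(version):
  # "0.4.13" -> (0, 4, 13); comparisons are lexicographic on tuples of integers.
  return tuple(int(part) for part in version.split("."))

def check_jaxlib_version(jax_version, jaxlib_version, minimum_jaxlib_version):
  jax_v = _as_tuple(jax_version)
  jaxlib_v = _as_tuple(jaxlib_version)
  minimum_v = _as_tuple(minimum_jaxlib_version)
  if jaxlib_v < minimum_v:
    raise RuntimeError(
        f"jaxlib version {jaxlib_version} is older than the minimum required version {minimum_jaxlib_version}")
  if jax_v < jaxlib_v:
    raise RuntimeError(
        f"jaxlib version {jaxlib_version} is newer than the jax version {jax_version}")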

In the future, we hope to separate out the hardware-specific pieces of jaxlib into separate plugins, at which point the minimum version could be expressed as a Python package dependency. For now, we do provide platform-specific extra requirements that install a compatible jaxlib version, e.g., jax[cuda].

How can I safely make changes to the API of jaxlib?#
  • jax may drop compatibility with older jaxlib releases at any time, so long as the minimum jaxlib version is increased to a compatible version. However, note that the minimum jaxlib, even for unreleased versions of jax, must be a released version! This allows us to use released jaxlib wheels in our CI builds, and allows Python developers to work on jax at HEAD without ever needing to build jaxlib.

    For example, to remove an old backwards compatibility path in the jax Python code, it is sufficient to bump the minimum jaxlib version and then delete the compatibility path.

  • jaxlib may drop compatibility with older jax releases lower than its own release version number. The version constraints enforced by jax would forbid the use of an incompatible jaxlib.

    For example, for jaxlib to drop a Python binding API used by an older jax version, the jaxlib minor or major version number must be incremented.

  • If possible, changes to the jaxlib should be made in a backwards-compatible way.

    In general jaxlib may freely change its API, so long as the rules about jax being compatible with all jaxlibs at least as new as the minimum version are followed. This implies that jax must always be compatible with at least two versions of jaxlib, namely, the last release, and the tip-of-tree version, effectively the next release. This is easier to do if compatibility is maintained, although incompatible changes can be made using version tests from jax; see below.

    For example, it is usually safe to add a new function to jaxlib, but unsafe to remove an existing function or to change its signature if current jax is still using it. Changes to jax must work or degrade gracefully for all jaxlib releases greater than the minimum up to HEAD.

Note that the compatibility rules here only apply to released versions of jax and jaxlib. They do not apply to unreleased versions; that is, it is ok to introduce and then remove an API from jaxlib if it is never released, or if no released jax version uses that API.

How is the source to jaxlib laid out?#

jaxlib is split across two main repositories, namely the jaxlib/ subdirectory in the main JAX repository and the XLA source tree, which lives inside the XLA repository. The JAX-specific pieces inside XLA are primarily in the xla/python subdirectory.

The reason that C++ pieces of JAX, such as Python bindings and runtime components, are inside the XLA tree is partially historical and partially technical.

The historical reason is that originally the xla/python bindings were envisaged as general purpose Python bindings that might be shared with other frameworks. In practice this is increasingly less true, and xla/python incorporates a number of JAX-specific pieces and is likely to incorporate more. So it is probably best to simply think of xla/python as part of JAX.

The technical reason is that the XLA C++ API is not stable. By keeping the XLA:Python bindings in the XLA tree, their C++ implementation can be updated atomically with the C++ API of XLA. It is easier to maintain backward and forward compatibility of Python APIs than C++ ones, so xla/python exposes Python APIs and is responsible for maintaining backward compatibility at the Python level.

jaxlib is built using Bazel out of the jax repository. The pieces of jaxlib from the XLA repository are incorporated into the build as a Bazel submodule. To update the version of XLA used during the build, one must update the pinned version in the Bazel WORKSPACE. This is done manually on an as-needed basis, but can be overridden on a build-by-build basis.

How do we make changes across the jax and jaxlib boundary between releases?#

The jaxlib version is a coarse instrument: it only lets us reason about releases.

However, since the jax and jaxlib code is split across repositories that cannot be updated atomically in a single change, we need to manage compatibility at a finer granularity than our release cycle. To manage fine-grained compatibility, we have additional versioning that is independent of the jaxlib release version numbers.

We maintain an additional version number (_version) in xla_client.py in the XLA repository. The idea is that this version number, which is defined in xla/python together with the C++ parts of JAX, is also accessible to JAX Python as jax._src.lib.xla_extension_version, and must be incremented every time a change is made to the XLA/Python code that has backwards compatibility implications for jax. The JAX Python code can then use this version number to maintain backwards compatibility, e.g.:

from jax._src.lib import xla_extension_version

# 123 is the new version number for _version in xla_client.py
if xla_extension_version >= 123:
  # Use new code path
  ...
else:
  # Use old code path.
  ...

Note that this version number is in addition to the constraints on the released version numbers, that is, this version number exists to help manage compatibility during development for unreleased code. Releases must also follow the compatibility rules given above.

Sequencing side-effects in JAX#

sharadmv@ May 9 2022

Overview#

When we write JAX code, we can usually pretend we’re writing single-threaded, eagerly-executed Python even though underneath the hood, JAX and its runtime may execute it asynchronously in the background. As long as we write pure (side-effect-free) code, these performance optimizations are usually invisible to us and don’t interfere with our single-threaded mental model. Asynchronous execution is great – we get performant, parallel code without having to think about it at all!

However, in the presence of side-effects, the illusion begins to break down and the cracks in our mental model start to show. Specifically, these differences show up when we think about the order in which side-effects happen.

In this design note, we explore the interaction between JAX’s execution model, and the ordering of side-effects. We also provide a way of enforcing a “single-threaded” ordering of effects.

Background#

When we write the following Python code

def f():
  print("hello")
  return 2
def g():
  print("world")
  return 3
f()
g()

we expect "hello" to be printed before "world". This might seem obvious but consider the following JAX code:

@partial(jax.jit, device=<device 0>)
def f():
  return 2

@partial(jax.jit, device=<device 1>)
def g():
  return 3
f()
g()

In many cases, JAX will execute f and g in parallel, dispatching the computations onto different threads – g might actually be executed before f. Parallel execution is a nice performance optimization, especially if copying to and from a device is expensive (see the asynchronous dispatch note for more details). In practice, however, we often don’t need to think about asynchronous dispatch because we’re writing pure functions and only care about the inputs and outputs of functions – we’ll naturally block on future values.

However, now imagine that we have a jax.print function that works inside of JIT-ted JAX functions (host_callback.id_print is an example of this). Let’s return to the previous example except with prints in the mix.

@partial(jax.jit, device=<device 0>)
def f():
  jax.print("hello")
  return 2

@partial(jax.jit, device=<device 1>)
def g():
  jax.print("world")
  return 3
f()
g()

Thanks to asynchronous dispatch, we could actually see "world" being printed before "hello". The reordering of the print side-effects breaks the illusion of a single-threaded execution model.

Another example of where side-effects can “reveal” out-of-order execution is when we compile JAX programs. Consider the following JAX code:

@jax.jit
def f(x):
  jax.print("hello")
  jax.print("world")
  return x

Even though in Python, we’ve written the "hello" print before the "world" print, a compiler like XLA is free to reorder them because there’s no explicit data-dependence between the prints.

Motivation#

We’d like to support “ordered” effects. When we say ordered, we mean that the effects occur in the same order as we would if we were executing a single-threaded Python program. This is our main desideratum. In the presence of explicit parallelism like pmap or user threads, we don’t need to maintain this behavior but at least if the user is not explicitly requesting parallelism, we’d like to preserve a single-threaded ordering.

Before we dive in more, let’s first step back and ask ourselves if it is okay if we reorder effects in the name of performance, and conversely, do we need to enforce an ordering on effects at all? In some cases, we don’t need ordering. Maybe some side-effects shouldn’t adversely affect the performance of a JAX program. However, for other side-effects, we may want to enforce a single-threaded program order so users don’t get counterintuitive behavior. Consider a logging effect.

@jax.jit
def f(x, y):
  log_value(x)
  log_value(y)
f(1, 2)

If log_value is mutating a global list, we might expect that x is added before y. For a stricter effect like this, we may want the option to order the effects.

Enforcing ordered effects#

The main tool we have to enforce the ordering of computations is data-dependence. Simply put, if a function g has an input that is the output of a function f, f must be executed before g.

However, we may have side effects like prints that have no inputs at all so naively we couldn’t sequence them. We thus use tokens as a means of injecting artificial data-dependence into a computation.

What is a token? A token is just a dummy value that can be threaded in and out of a computation. By threading the same token in and out of several computations, we enforce that they have to happen in a certain order. Let’s take the previous print example and see what it would look like with tokens in the mix:

@jax.jit
def f(token, x):
  token = jax.print(token, "hello")
  token = jax.print(token, "world")
  return token, x

If we rewrite jax.print to take in and return a token, we have now sequenced the two prints since the input to the second print depends on the output of the first print. The actual value of token can be anything really, but we’ll see in practice that the tokens are invisible to users.

Runtime tokens vs. compiler tokens#

Here we will actually start talking about implementation details. In practice, we’ll need two separate types of tokens to sequence effects: one for each of the aforementioned sources of reordering. We’ll need runtime tokens to sequence asynchronously dispatched side-effecting computations and we’ll need compiler tokens to sequence effects within computations.

In practice, our computation will be rewritten to look like this:

@jax.jit
def f(runtime_token, x):
  compiler_token = new_compiler_token()
  compiler_token = jax.print(compiler_token, "hello")
  compiler_token = jax.print(compiler_token, "world")
  return runtime_token, x

Notice how the runtime tokens are only used at the JIT boundary and the compiler tokens are only within the compiled code. Compiler tokens are created during “lowering” (we convert Python code to a lower level representation like HLO or StableHLO) but runtime tokens need to be managed in Python since they’re being threaded in and out of JIT-ted functions.

Furthermore, notice that the runtime tokens are “disconnected” from the compiler tokens, meaning there’s no data dependency between them. This could potentially be dangerous, since we could lose the data dependence between the bodies of two dispatched function calls. However, if we assume “strict execution” – i.e. a dispatched function will only start executing when all of its inputs are ready, and all of its outputs will become ready at the same time – we are safe to create a fresh compiler token and return a non-output-dependent runtime token.

Managing runtime tokens#

To manage runtime tokens on behalf of the user, we’ll need to hook into JAX’s dispatch machinery. Whenever we call a JIT-ted function, we eventually bottom out in a function that looks like this:

def _execute(compiled_computation, *args):
  outputs = compiled_computation.execute(*args)
  return outputs

At this point we need to “inject” the runtime tokens into the computation and “extract” them from the computation’s outputs:

def _execute(compiled_computation, *args):
  runtime_token = get_runtime_token() # Grab global token
  runtime_token, *outputs = compiled_computation.execute(runtime_token, *args)
  update_runtime_token(runtime_token) # Update global token
  return outputs

What is runtime_token exactly? Well we need to be able to pass it into a compiled_computation, which means it needs to be some sort of array (for now, since there’s no shared token representation inside and outside compiled JAX code). In practice we can use a (0,)-shaped array to minimize overheads.

We also need to think about the multiple device use case, e.g. the first example where we first call a JIT-ted function on device 0 and then one on device 1. In that case, we need to also copy the runtime token returned from the first computation (which lives on device 0) to device 1 so we can pass it into the second computation. If two subsequent computations share the same device, this copy is not necessary.
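
As a rough illustration, the copy itself can be done with jax.device_put (a real JAX function); the surrounding helper and its use in the dispatch path are hypothetical:

import jax

def move_runtime_token(runtime_token, target_device):
  # Copy the (tiny) token array onto the device of the next computation;
  # if it already lives there, this is effectively a no-op.
  return jax.device_put(runtime_token, target_device)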

Adding compiler tokens#

When we lower Python code to HLO or StableHLO, we need to create a token at the start of the computation and ensure it is available when we have side-effecting computations that need to be ordered. The side-effecting computations will take the token as input and return it as an output.

The implementation of this token threading involves upgrading the JAX lowering machinery to do this bookkeeping automatically. The main challenges involve dealing with higher-order primitives like call primitives and control-flow primitives. We won’t go into details on how to handle those in this design note.

Blocking on output tokens#

Adding support for runtime and compiler tokens for side-effecting computations is important for sequencing, but there’s also another subtle use-case for tokens, which is blocking on side-effecting computations. Even if we don’t want a side-effecting computation to be ordered, we may still want to wait on its completion. Currently we have jax.block_until_ready, which waits until a future value has its result ready. However, with side-effecting computations, we may have functions that don’t have a return value but are still executing a side-effect. Take the simple example here:

@jax.jit
def f():
  jax.print("hello world")
  return
f() # Executed asynchronously

This compiled computation takes no explicit inputs and has no explicit outputs. If it were an ordered print effect, we could block on the returned runtime token. However, since this is an unordered computation, we don’t do any token threading. How do we wait for f() to finish executing when we have no output value to call block_until_ready on? Well, we could apply the same token strategy, except that we only return runtime tokens and don’t take them as inputs. This gives us a value to block on that will only be ready once f() has finished executing. We’ll call these tokens output tokens. We end up with a function that looks like this:

@jax.jit
def f():
  jax.print("hello world")
  return new_runtime_token()
f() # Executed asynchronously

Underneath the hood, we’ll manage the output tokens in the same way we manage the runtime tokens but provide a method for users to block on the current set of output tokens. Unlike runtime tokens, output tokens need to be device-specific. Consider a single device use-case:

@jax.jit
def f():
  jax.print("hello")

@jax.jit
def g():
  jax.print("world")

f()
g()

Since f() and g() are executed on the same device, blocking on g()’s output token effectively blocks on f(), since (as of now!) the JAX runtime does not interleave computations executed on the same device. We’ll have to revise this entire design if that changes, of course.

However, consider the two device use-case:

@partial(jax.jit, device=<device 0>)
def f():
  jax.print("hello")

@partial(jax.jit, device=<device 1>)
def g():
  jax.print("world")

f()
g()

Here we don’t want to explicitly sequence f() and g() but want to wait for both of them to finish. We’ll need one output token for f() and one for g() and we’ll block on both of those tokens:

@partial(jax.jit, device=<device 0>)
def f():
  jax.print("hello")
  return new_runtime_token()

@partial(jax.jit, device=<device 1>)
def g():
  jax.print("world")
  return new_runtime_token()

t0 = f()
t1 = g()
block_until_ready((t0, t1))

We’ll thus need a per-device output token so we can avoid sequencing computations on different devices while offering the ability to block on side-effecting computations. We end up with the following (approximate) change to the JAX dispatch machinery:

def _execute(compiled_computation, *args):
  output_token, *outputs = compiled_computation.execute(runtime_token, *args)
  update_output_token(output_token, compiled_computation.device)
  return outputs

We’ll also need to expose a function that blocks on the output token:

def effects_barrier():
  output_token.block_until_ready()

Note that blocking on output tokens may not be very common, since most JAX computations return a value to block on. However, output tokens are helpful for testing and profiling, and are good to support so that we have a consistent and cohesive effect system.
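
For reference, today's JAX exposes this blocking behavior through jax.effects_barrier(), which pairs naturally with jax.debug.print(); a minimal usage sketch, assuming a reasonably recent jax release:

import jax

@jax.jit
def f():
  jax.debug.print("hello world")  # unordered side-effect; no return value

f()                    # dispatched asynchronously
jax.effects_barrier()  # block until all dispatched side-effects have completed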

Some more details#
  • All of the aforementioned token management infrastructure will be thread-local. This means that each user thread will have its own independent stream of runtime tokens. Sequencing is only promised at the level of a single user thread.

  • In practice, we have one runtime token per effect. Different instances of that effect will be sequenced. This is to avoid sequencing effectful computations that may not have any relation to each other. Technically, though, this goes against our original goal of enforcing a single-threaded Python program ordering, but this is a tradeoff that could be modulated by having both “effect”-specific tokens and “global” tokens.

jax.remat / jax.checkpoint changes: what you need to know#
Contents#
What’s going on?#

As of #11830 we’re switching on a new implementation of jax.checkpoint(), aka jax.remat() (the two names are aliases of one another). For most code, there will be no changes. But there may be some observable differences in edge cases; see What are the possible issues after the upgrade?

How can I disable the change, and go back to the old behavior for now?#

In case you have a problem with this change, through version jax==0.3.16 it is possible to switch off the new implementation by setting the jax_new_checkpoint config option to be False, in any one of these ways:

  1. set the shell environment variable JAX_NEW_CHECKPOINT=0;

  2. execute jax.config.update('jax_new_checkpoint', False);

  3. if you parse flags with absl, pass the --jax_new_checkpoint=False option.

If you need to revert to the old implementation, please reach out on a GitHub issue so that we can make the new implementation work for you.

As of jax==0.3.17 the jax_new_checkpoint config option is no longer available. If you have an issue, please reach out on the issue tracker so we can help fix it!

Why are we doing this?#

At the time of writing, JAX has two parallel implementations of jax.checkpoint. The new one has been used for months (e.g. by Pax and Flaxformer/T5X) on an opt-in basis. But it hasn’t been on-by-default.

We want to switch the new implementation to be on-by-default, and then delete the old implementation. Using the new implementation, and removing the old implementation, gives users several benefits.

User-customizable rematerialization policies#

The main upside of the new implementation is a new feature corresponding to the policy argument. The idea is to give precise user control over what intermediates get saved (versus rematerialized) during the forward pass of automatic differentiation. By exercising this control over the memory-usage vs recomputation tradeoff, users can get significant performance wins, especially in large models and in our LLM MLPerf submission!

The full documentation for this feature is still forthcoming, but here’s a quick example:

from functools import partial
import jax
import jax.numpy as jnp

def apply_layer(W, x):
  return jnp.sin(jnp.dot(W, x))

@partial(jax.checkpoint, policy=jax.checkpoint_policies.checkpoint_dots)
def predict(params, x):
  for W in params[:-1]:
    x = apply_layer(W, x)
  return jnp.dot(params[-1], x)

By applying jax.checkpoint with policy=jax.checkpoint_policies.checkpoint_dots here, we ensure that only the results of matrix multiplies are allowed to be saved during the forward pass. The Jacobian coefficient values from cos applications, and the values of sin applications needed to compute them, are not saved from the forward pass and are instead recomputed during the backward pass. (Policies like this one can be effective on TPUs, where elementwise computations are effectively free but results from the matrix unit are worth saving.)

Ability to rematerialize constants, not just operations with data dependence on arguments#

The old jax.checkpoint implementation couldn’t actually rematerialize computations without a data dependence on arguments to the decorated function. Consider this toy example:

@jax.checkpoint
def f(x):
  a = some_function(jnp.arange(10_000_000))  # `a` does not depend on `x`
  return a * x

The old jax.checkpoint implementation was forced to save the value of a, which could require a lot of memory. The new jax.checkpoint implementation can rematerialize rather than save the value of a.

Significantly less Python overhead in some cases#

The new jax.checkpoint incurs significantly less Python overhead in some cases. Simple overhead benchmarks got 10x faster. These overheads only arise in eager op-by-op execution, so in the common case of using a jax.checkpoint under a jax.jit or similar the speedups aren’t relevant. But still, nice!

Enabling new JAX features by simplifying internals#

This change unlocks big future user benefits too, like custom batching rules (the vmap analogue of custom_vjp) and a forward-differentiable upgrade to custom_vjp. It also significantly reduces complexity in parts of the JAX codebase, which will be good for maintainability and bug-fixing in general.

What are the possible issues after the upgrade?#
Innocuous numerical changes#

Because the new implementation can rematerialize more computations, including those of potentially large constants, some code may see small numerical changes. The magnitude of any numerical changes should be within the range we expect from changing compiler optimizations, like reordering of floating point operations. But some overly tight test tolerances may need to be slightly relaxed.

The concrete=True option is removed.#

The old jax.checkpoint implementation had a boolean concrete option, which allowed tracing on concrete Python values (rather than delaying all computations and only tracing on abstracted values). That option was seldom used, and in the cases where it was used there were much simpler alternatives. So we removed the option in the new jax.checkpoint.

For example, the overwhelmingly common use of concrete=True in Google code was to support passing an argument like is_training:

@partial(jax.checkpoint, concrete=True)  # OLD jax.checkpoint API
def foo(x, is_training):
  if is_training:
    return g(x)
  else:
    return h(x)

With the new jax.checkpoint implementation, we can accomplish the same using the static_argnums option:

@partial(jax.checkpoint, static_argnums=(1,))  # NEW jax.checkpoint API
def foo(x, is_training):
  if is_training:
    ...

If jax.numpy operations need to be performed on static arguments, with their numerical results computed during Python tracing rather than delayed, we can use static_argnums with jax.ensure_compile_time_eval(). But it seems unlikely that you’d need this!
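
For completeness, here is a hypothetical sketch of that combination; the function and its arguments are made up purely for illustration:

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.checkpoint, static_argnums=(1,))
def scaled(x, exponent):
  with jax.ensure_compile_time_eval():
    # exponent is a static Python number, so this jnp computation is evaluated
    # at trace time rather than being staged into the checkpointed computation.
    scale = jnp.exp(exponent)
  return scale * x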

Type Annotation Roadmap for JAX#
  • Author: jakevdp

  • Date: August 2022

Background#

Python 3.0 introduced optional function annotations (PEP 3107), which were later codified for use in static type checking around the release of Python 3.5 (PEP 484). To some degree, type annotations and static type checking have become an integral part of many Python development workflows, and to this end we have added annotations in a number of places throughout the JAX API. The current state of type annotations in JAX is a bit patchwork, and efforts to add more have been hampered by more fundamental design questions. This doc attempts to summarize those issues and generate a roadmap for the goals and non-goals of type annotations in JAX.

Why do we need such a roadmap? Better/more comprehensive type annotations are a frequent request from users, both internally and externally. In addition, we frequently receive pull requests from external users (for example, PR #9917, PR #10322) seeking to improve JAX’s type annotations: it’s not always clear to the JAX team member reviewing the code whether such contributions are beneficial, particularly when they introduce complex Protocols to address the challenges inherent to full-fledged annotation of JAX’s use of Python. This document details JAX’s goals and recommendations for type annotations within the package.

Why type annotations?#

There are a number of reasons that a Python project might wish to annotate their code-base; we’ll summarize them in this document as Level 1, Level 2, and Level 3.

Level 1: Annotations as documentation#

When originally introduced in PEP 3107, type annotations were motivated partly by the ability to use them as concise, inline documentation of function parameter types and return types. JAX has long utilized annotations in this manner; an example is the common pattern of creating type names aliased to Any, as can be seen in lax/slicing.py [source]:

Array = Any
Shape = core.Shape

def slice(operand: Array, start_indices: Sequence[int],
          limit_indices: Sequence[int],
          strides: Optional[Sequence[int]] = None) -> Array:
  ...

For the purposes of static type checking, this use of Array = Any for array type annotations puts no constraint on the argument values (Any is equivalent to no annotation at all), but it does serve as a form of useful in-code documentation for the developer.

For the sake of generated documentation, the name of the alias gets lost (the HTML docs for jax.lax.slice report operand as type Any), so the documentation benefit does not go beyond the source code (though we could enable some sphinx-autodoc options to improve this: See autodoc_type_aliases).

A benefit of this level of type annotation is that it is never wrong to annotate a value with Any, so it will provide a concrete benefit to developers and users in the form of documentation, without added complexity of satisfying the stricter needs of any particular static type checker.

Level 2: Annotations for intelligent autocomplete#

Many modern IDEs take advantage of type annotations as inputs to intelligent code completion systems. One example of this is the Pylance extension for VSCode, which uses Microsoft’s pyright static type checker as a source of information for VSCode’s IntelliSense completions.

This use of type checking requires going further than the simple aliases used above; for example, knowing that the slice function returns an alias of Any named Array does not add any useful information to the code completion engine. However, were we to annotate the function with a DeviceArray return type, the autocomplete would know how to populate the namespace of the result, and thus be able to suggest more relevant autocompletions during the course of development.

JAX has begun to add this level of type annotation in a few places; one example is the jnp.ndarray return type within the jax.random package [source]:

def shuffle(key: KeyArray, x: Array, axis: int = 0) -> jnp.ndarray:
  ...

In this case jnp.ndarray is an abstract base class that forward-declares the attributes and methods of JAX arrays (see source), and so Pylance in VSCode can offer the full set of autocompletions on results from this function. Here is a screenshot showing the result:

VSCode Intellisense Screenshot

Listed in the autocomplete field are all methods and attributes declared by the abstract ndarray class. We’ll discuss further below why it was necessary to create this abstract class rather than annotating with DeviceArray directly.

Level 3: Annotations for static type-checking#

These days, static type-checking often is the first thing people think of when considering the purpose of type annotations in Python code. While Python does not do any runtime checking of types, several mature static type checking tools exist that can do this as part of a CI test suite. The most important ones for JAX are the following:

  • python/mypy is more or less the standard in the open Python world. JAX currently runs mypy on a subset of source files within the GitHub Actions CI checks.

  • google/pytype is Google’s static type checker, and projects which depend on JAX within Google frequently use this.

  • microsoft/pyright is important as the static type checker used within VSCode for the Pylance completions mentioned previously.

Full static type checking is the strictest of all the type annotation applications, because it will surface an error any time your type annotations are not precisely correct. On the one hand, this is nice because your static type analysis may catch faulty type annotations (for example, a case where a DeviceArray method is missing from the jnp.ndarray abstract class).

On the other hand, this strictness can make the type checking process very brittle in packages that often rely on duck-typing rather than strict type-safe APIs. You’ll currently find code comments like #type: ignore (for mypy) or #pytype: disable (for pytype) peppered throughout the JAX codebase in several hundred places. These typically represent cases where typing problems have arisen; they may be inaccuracies in JAX type annotations, or inaccuracies in the static type checker’s ability to correctly follow the control flow in the code. On occasion, they are due to real & subtle bugs in the behavior of pytype or mypy. In rare cases, they may be due to the fact that JAX uses Python patterns that are difficult or even impossible to express in terms of Python’s static type annotation syntax.

Type annotation challenges for JAX#

JAX currently has type annotations that are a mixture of different styles, and aimed at all three levels of type annotation discussed above. Partly, this comes from the fact that JAX’s source code poses a number of unique challenges for Python’s type annotation system. We’ll outline them here.

Challenge 1: pytype, mypy and developer friction#

One challenge JAX currently faces is that package development must satisfy the constraints of two different static type checking systems, pytype (used by internal CI and internal Google projects) and mypy (used by external CI and external dependencies). Although the two type checkers have broad overlap in their behavior, each presents its own unique corner cases, as evidenced by the numerous #type: ignore and #pytype: disable statements throughout the JAX codebase.

This creates friction in development: internal contributors may iterate until tests pass, only to find that on export their pytype-approved code falls afoul of mypy. For external contributors, it’s often the opposite: a recent example is #9596, which had to be rolled back after it failed internal Google pytype checks. Each time we move a type annotation from Level 1 (Any everywhere) to Level 2 or 3 (stricter annotations), it introduces more potential for such frustrating developer experiences.

Challenge 2: array duck-typing#

One particular challenge for annotating JAX code is its heavy use of duck-typing. An input to a function marked Array in general could be one of many different types: a JAX DeviceArray, a NumPy np.ndarray, a NumPy scalar, a Python scalar, a Python sequence, an object with an __array__ attribute, an object with a __jax_array__ attribute, or any flavor of jax.Tracer. For this reason, simple annotations like def func(x: DeviceArray) will not be sufficient, and will lead to false positives for many valid uses. This means that type annotations for JAX functions will not be short or trivial; rather, we would have to effectively develop a set of JAX-specific typing extensions similar to those in the numpy.typing package.

Challenge 3: transformations and decorators#

JAX’s Python API relies heavily on function transformations (jit(), vmap(), grad(), etc.), and this type of API poses a particular challenge for static type analysis. Flexible annotation for decorators has been a long-standing issue in the mypy package, which was only recently resolved by the introduction of ParamSpec, discussed in PEP 612 and added in Python 3.10. Because JAX follows NEP 29, it cannot rely on Python 3.10 features until sometime after mid-2024. In the meantime, Protocols can be used as a partial solution to this (JAX added this for jit and other methods in #9950) and ParamSpec is possible to use via the typing_extensions package (a prototype is in #9999) though this currently reveals fundamental bugs in mypy (see python/mypy#12593). All that to say: it’s not yet clear that the API of JAX’s function transforms can be suitably annotated within the current constraints of Python type annotation tools.
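
To illustrate the Protocol-based partial solution, here is a rough sketch (not JAX's actual type stubs): the decorator's return type is a callable Protocol; this type-checks, but it erases the wrapped function's signature, the limitation that ParamSpec is meant to remove.

from typing import Any, Callable, Protocol

class Wrapped(Protocol):
  # Accepts anything and returns anything: the original signature is erased.
  def __call__(self, *args: Any, **kwargs: Any) -> Any: ...

def jit_like(fun: Callable[..., Any]) -> Wrapped:  # hypothetical stand-in for a transformation like jit
  ...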

Challenge 4: array annotation lack of granularity#

Another challenge here is common to all array-oriented APIs in Python, and has been part of the JAX discussion for several years (see #943). Type annotations have to do with the Python class or type of an object, whereas in array-based languages often the attributes of the class are more important. In the case of NumPy, JAX, and similar packages, often we would wish to annotate particular array shapes and data types.

For example, the arguments to the jnp.linspace function must be scalar values, but in JAX scalars are represented by zero-dimensional arrays. So in order for annotations to not raise false positives, we must allow these arguments to be arbitrary arrays. Another example is the second argument to jax.random.choice, which must have dtype=int when shape=(). Python has a plan to enable type annotations with this level of granularity via Variadic Type Generics (see PEP 646, slated for Python 3.11) but like ParamSpec, support for this feature will take a while to stabilize.

There are some third-party projects that may help in the meantime, in particular google/jaxtyping, but this uses non-standard annotations and may not be suitable for annotating the core JAX library itself. All told, the array-type-granularity challenge is less of an issue than the other challenges, because the main effect is that array-like annotations will be less specific than they otherwise could be.

Challenge 5: imprecise APIs inherited from NumPy#

A large part of JAX’s user-facing API is inherited from NumPy within the jax.numpy submodule. NumPy’s API was developed years before static type checking became part of the Python language, and follows Python’s historic recommendations to use a duck-typing/EAFP coding style, in which strict type-checking at runtime is discouraged. As a concrete example of this, consider the numpy.tile() function, which is defined like this:

def tile(A, reps):
  try:
    tup = tuple(reps)
  except TypeError:
    tup = (reps,)
  d = len(tup)
  ...

Here the intent is that reps would contain either an int or a sequence of int values, but the implementation allows tup to be any iterable. When adding annotations to this kind of duck-typed code, we could take one of two routes:

  1. We may choose to annotate the intent of the function’s API, which here might be something like reps: Union[int, Sequence[int]] (see the sketch after this list).

  2. Conversely, we may choose to annotate the implementation of the function, which here might look something like reps: Union[ConvertibleToInt, Iterable[ConvertibleToInt]] where ConvertibleToInt is a special protocol that covers the exact mechanism by which our function converts the inputs to integers (i.e. via __int__, via __index__, via __array__, etc.). Note also here that in a strict sense, Iterable is not sufficient here because there are objects in Python that duck-type as iterables but do not satisfy a static type check against Iterable (namely, an object that is iterable via __getitem__ rather than __iter__.)
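
A minimal sketch of route 1, applied to the tile() excerpt above (the typing names come from the standard library):

from typing import Sequence, Union

# Annotate the intended API (an int or a sequence of ints) rather than every
# object the duck-typed implementation happens to accept.
def tile(A, reps: Union[int, Sequence[int]]):
  ...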

The advantage of #1, annotating intent, is that the annotations are more useful to the user in communicating the API contract; while for the developer the flexibility leaves room for refactoring when necessary. The down-side (particularly for gradually-typed APIs like JAX’s) is that it’s quite likely that user code exists which runs correctly, but would be flagged as incorrect by a type checker. Gradual typing of an existing duck-typed API means that the current annotation is implicitly Any, so changing this to a stricter type may present to users as a breaking change.

Broadly speaking, annotating intent better serves Level 1 type checking, annotating implementation better serves Level 3, and Level 2 is more of a mixed bag (both intent and implementation are important when it comes to annotations in IDEs).

JAX type annotation roadmap#

With this framing (Level 1/2/3) and JAX-specific challenges in mind, we can begin to develop our roadmap for implementing consistent type annotations across the JAX project.

Guiding Principles#

For JAX type annotation, we will be guided by the following principles:

Purpose of type annotations#

We would like to support full, Level 1, 2, and 3 type annotation as far as possible. In particular, this means that we should have restrictive type annotations on both inputs and outputs to public API functions.

Annotate for intent#

JAX type annotations should in general indicate the intent of APIs, rather than the implementation, so that the annotations become useful to communicate the contract of the API. This means that at times inputs that are valid at runtime may not be recognized as valid by the static type checker (one example might be an arbitrary iterator passed in place of a shape that is annotated as Shape = Sequence[int]).

Inputs should be permissively-typed#

Inputs to JAX functions and methods should be typed as permissively as is reasonable: for example, while shapes are typically tuples, functions that accept a shape should accept arbitrary sequences. Similarly, functions that accept a dtype need not require an instance of class np.dtype, but rather any dtype-convertible object. This might include strings, built-in scalar types, or scalar object constructors such as np.float64 and jnp.float64. In order to make this as uniform as possible across the package, we will add a jax.typing module with common type specifications, starting with broad categories such as:

  • ArrayLike would be a union of anything that can be implicitly converted into an array: for example, jax arrays, numpy arrays, JAX tracers, and python or numpy scalars

  • DTypeLike would be a union of anything that can be implicitly converted into a dtype: for example, numpy dtypes, numpy dtype objects, jax dtype objects, strings, and built-in types.

  • ShapeLike would be a union of anything that could be converted into a shape: for example, sequences of integer or integer-like objects.

  • etc.

Note that these will in general be simpler than the equivalent protocols used in numpy.typing. For example, in the case of DTypeLike, JAX does not support structured dtypes, so JAX can use a simpler implementation. Similarly, in ArrayLike, JAX generally does not support list or tuple inputs in place of arrays, so the type definition will be simpler than the NumPy analog.
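
To make the categories above concrete, a rough sketch of what such aliases might look like (purely illustrative; the eventual definitions in jax.typing may differ, and jax.Array is used here per the array-type discussion below):

from typing import Sequence, Union
import jax
import numpy as np

# Purely illustrative definitions, not JAX's actual ones.
ArrayLike = Union[jax.Array, np.ndarray, np.bool_, np.number,
                  bool, int, float, complex]
DTypeLike = Union[str, type, np.dtype]
ShapeLike = Sequence[int]  # in practice, integer-like objects too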

Outputs should be strictly-typed#

Conversely, outputs of functions and methods should be typed as strictly as possible: for example, for a JAX function that returns an array, the output should be annotated with something similar to jnp.ndarray rather than ArrayLike. Functions returning a dtype should always be annotated np.dtype, and functions returning a shape should always be Tuple[int, ...] or a strictly-typed NamedShape equivalent. For this purpose, we will implement in jax.typing several strictly-typed analogs of the permissive types mentioned above, namely:

  • Array or NDArray (see below) for type annotation purposes is effectively equivalent to Union[Tracer, jnp.ndarray] and should be used to annotate array outputs.

  • DType is an alias of np.dtype, perhaps with the ability to also represent key types and other generalizations used within JAX.

  • Shape is essentially Tuple[int, ...], perhaps with some additional flexibility to account for dynamic shapes.

  • NamedShape is an extension of Shape that allows for named shapes as used internally in JAX.

  • etc.

We will also explore whether the current implementation of jax.numpy.ndarray can be dropped in favor of making ndarray an alias of Array or similar.
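
Putting the two guidelines together, an annotated function might look something like the following sketch, which assumes the proposed names are available as jax.typing.ArrayLike and jax.Array:

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike  # assumes the proposed jax.typing module

def double(x: ArrayLike) -> jax.Array:
  # Permissive on input (scalars, numpy arrays, jax arrays, tracers, ...),
  # strict on output (always a JAX array).
  return 2 * jnp.asarray(x)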

Err toward simplicity#

Aside from common typing protocols gathered in jax.typing, we should err on the side of simplicity. We should avoid constructing overly-complex protocols for arguments passed to API functions, and instead use simple unions such as Union[simple_type, Any] in the case that the full type specification of the API cannot be succinctly specified. This is a compromise that achieves the goals of Level 1 and 2 annotations, while punting on Level 3 in favor of avoiding unnecessary complexity.

Avoid unstable typing mechanisms#

In order to not add undue development friction (due to the internal/external CI differences), we would like to be conservative in the type annotation constructs we use: in particular, when it comes to recently-introduced mechanisms such as ParamSpec (PEP 612) and Variadic Type Generics (PEP 646), we would like to wait until support in mypy and other tools matures and stabilizes before relying on them.

One impact of this is that for the time being, when functions are decorated by JAX transformations like jit, vmap, grad, etc. JAX will effectively strip all annotations from the decorated function. While this is unfortunate, at the time of this writing mypy has a laundry-list of incompatibilities with the potential solution offered by ParamSpec (see ParamSpec mypy bug tracker), and we therefore judge it as not ready for full adoption in JAX at this time. We will revisit this question in the future once support for such features stabilizes.

Similarly, for the time being we will avoid adding the more complex & granular array type annotations offered by the jaxtyping project. This is a decision we could revisit at a future date.

Array Type Design Considerations#

As mentioned above, type annotation of arrays in JAX poses a unique challenge because of JAX’s extensive use of duck-typing, i.e. passing and returning Tracer objects in place of actual arrays within jax transformations. This becomes increasingly confusing because objects used for type annotation often overlap with objects used for runtime instance checking, and may or may not correspond to the actual type hierarchy of the objects in question. For JAX, we need to provide duck-typed objects for use in two contexts: static type annotations and runtime instance checks.

The following discussion will assume that jax.Array is the runtime type of on-device arrays, which is not yet the case but will be once the work in #12016 is complete.

Static type annotations#

We need to provide an object that can be used for duck-typed type annotations. Assuming for the moment that we call this object ArrayAnnotation, we need a solution which satisfies mypy and pytype for a case like the following:

@jit
def f(x: ArrayAnnotation) -> ArrayAnnotation:
  assert isinstance(x, core.Tracer)
  return x

This could be accomplished via a number of approaches, for example:

  • Use a type union: ArrayAnnotation = Union[Array, Tracer]

  • Create an interface file that declares Tracer and Array should be treated as subclasses of ArrayAnnotation.

  • Restructure Array and Tracer so that ArrayAnnotation is a true base class of both.

Runtime instance checks#

We also must provide an object that can be used for duck-typed runtime isinstance checks. Assuming for the moment that we call this object ArrayInstance, we need a solution that passes the following runtime check:

def f(x):
  return isinstance(x, ArrayInstance)
x = jnp.array([1, 2, 3])
assert f(x)       # x will be an array
assert jit(f)(x)  # x will be a tracer

Again, there are a couple mechanisms that could be used for this:

  • override type(ArrayInstance).__instancecheck__ to return True for both Array and Tracer objects; this is how jnp.ndarray is currently implemented (source).

  • define ArrayInstance as an abstract base class and dynamically register it to Array and Tracer (see the sketch after this list)

  • restructure Array and Tracer so that ArrayInstance is a true base class of both Array and Tracer
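
A minimal sketch of the second mechanism, using stand-in class names rather than JAX's real ones:

import abc

class ArrayInstance(abc.ABC):
  """Used only for isinstance checks; never instantiated directly."""

class _OnDeviceArray:  # stand-in for the concrete array class
  pass

class _Tracer:  # stand-in for jax's Tracer
  pass

# Dynamic registration: isinstance() succeeds without true inheritance.
ArrayInstance.register(_OnDeviceArray)
ArrayInstance.register(_Tracer)

assert isinstance(_OnDeviceArray(), ArrayInstance)
assert isinstance(_Tracer(), ArrayInstance)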

A decision we need to make is whether ArrayAnnotation and ArrayInstance should be the same or different objects. There is some precedent here; for example in the core Python language spec, typing.Dict and typing.List exist for the sake of annotation, while the built-in dict and list serve the purposes of instance checks. However, Dict and List are deprecated in newer Python versions in favor of using dict and list for both annotation and instance checks.

Following NumPy’s lead#

In NumPy’s case, np.typing.NDArray serves the purpose of type annotations, while np.ndarray serves the purpose of instance checks (as well as array type identity). Given this, it may be reasonable to conform to NumPy’s precedent and implement the following:

  • jax.Array is the actual type of on-device arrays.

  • jax.typing.NDArray is the object used for duck-typed array annotations.

  • jax.numpy.ndarray is the object used for duck-typed array instance checks.

This might feel somewhat natural to NumPy power-users; however, this trifurcation would likely be a source of confusion: the choice of which to use for instance checks and annotations is not immediately clear.

Unifying instance checks and annotation#

Another approach would be to unify type checking and annotation via override mechanisms mentioned above.

Option 1: Partial unification#

A partial unification might look like this:

  • jax.Array is the actual type of on-device arrays.

  • jax.typing.Array is the object used for duck-typed array annotations (via .pyi interfaces on Array and Tracer).

  • jax.typing.Array is also the object used for duck-typed instance checks (via an __instancecheck__ override in its metaclass)

In this approach, jax.numpy.ndarray would become a simple alias of jax.typing.Array for backward compatibility.

Option 2: Full unification via overrides#

Alternatively, we could opt for full unification via overrides:

  • jax.Array is the actual type of on-device arrays.

  • jax.Array is also the object used for duck-typed array annotations (via a .pyi interface on Tracer)

  • jax.Array is also the object used for duck-typed instance checks (via an __instancecheck__ override in its metaclass)

Here, jax.numpy.ndarray would become a simple alias of jax.Array for backward compatibility.

Option 3: Full unification via class hierarchy#

Finally, we could opt for full unification via restructuring of the class hierarchy and replacing duck-typing with OOP object hierarchies:

  • jax.Array is the actual type of on-device arrays

  • jax.Array is also the object used for array type annotations, by ensuring that Tracer inherits from jax.Array

  • jax.Array is also the object used for instance checks, via the same mechanism

Here jnp.ndarray could be an alias for jax.Array. This final approach is in some senses the most pure, but it is somewhat forced from an OOP design standpoint (Tracer is an Array?).

Option 4: Partial unification via class hierarchy#

We could make the class hierarchy more sensible by making Tracer and the class for on-device arrays inherit from a common base class. So, for example:

  • jax.Array is a base class for Tracer as well as the actual type of on-device arrays, which might be jax._src.ArrayImpl or similar.

  • jax.Array is the object used for array type annotations

  • jax.Array is also the object used for instance checks

Here jnp.ndarray would be an alias for Array. This may be purer from an OOP perspective, but compared to Options 2 and 3 it drops the notion that type(x) is jax.Array will evaluate to True.
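A rough sketch of the hierarchy Option 4 describes (class names and module locations here are illustrative only, not the actual implementation):

class Array:
  """Public base class: the single object used for annotations and isinstance checks."""

class ArrayImpl(Array):
  """Concrete type of on-device arrays; would live in a private module like jax._src."""

class Tracer(Array):
  """Abstract values used inside transformations such as jit and grad."""

def f(x: Array) -> Array:
  assert isinstance(x, Array)  # true for both ArrayImpl instances and Tracers
  return x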

Evaluation#

Considering the overall strengths and weaknesses of each potential approach:

  • From a user perspective, the unified approaches (options 2 and 3) are arguably best, because they remove the cognitive overhead involved in remembering which objects to use for instance checks or annotations: jax.Array is all you need to know.

  • However, both options 2 and 3 introduce some strange and/or confusing behavior. Option 2 depends on potentially confusing overrides of instance checks, which are not well supported for classes defined in pybind11. Option 3 requires Tracer to be a subclass of Array. This breaks the inheritance model, because it would require Tracer objects to carry all the baggage of Array objects (data buffers, sharding, devices, etc.)

  • Option 4 is purer in an OOP sense, and avoids the need for any overrides of typical instance check or type annotation behavior. The tradeoff is that the actual type of on-device arrays becomes something separate (here jax._src.ArrayImpl). But the vast majority of users would never have to touch this private implementation directly.

There are different tradeoffs here, but after discussion we’ve landed on Option 4 as our way forward.

Implementation Plan#

To move forward with type annotations, we will do the following:

  1. Iterate on this JEP doc until developers and stakeholders are bought-in.

  2. Create a private jax._src.typing (not providing any public APIs for now) and put in it the first level of simple types mentioned above:

    • Alias Array = Any for the time being, as this will take a bit more thought.

    • ArrayLike: a Union of types valid as inputs to normal jax.numpy functions

    • DType / DTypeLike (Note: numpy uses camel-cased DType; we should follow this convention for ease of use)

    • Shape / NamedShape / ShapeLike

    The beginnings of this are done in #12300.

  3. Begin work on a jax.Array base class that follows Option 4 from the previous section. Initially this will be defined in Python, and use the dynamic registration mechanism currently found in the jnp.ndarray implementation to ensure correct behavior of isinstance checks. A pyi override for each tracer and array-like class would ensure correct behavior for type annotations. jnp.ndarray could then be made into an alias of jax.Array.

  4. As a test, use these new typing definitions to comprehensively annotate functions within jax.lax according to the guidelines above.

  5. Continue adding additional annotations one module at a time, focusing on public API functions.

  6. In parallel, begin re-implementing a jax.Array base class in pybind11, so that ArrayImpl and Tracer can inherit from it. Use a pyi definition to ensure static type checkers recognize the appropriate attributes of the class.

  7. Once jax.Array and jax._src.ArrayImpl have fully landed, remove these temporary Python implementations.

  8. When all is finalized, create a public jax.typing module that makes the above types available to users, along with documentation of annotation best practices for code using JAX.

We will track this work in #12049, from which this JEP gets its number.

shmap (shard_map) for simple per-device code#

sholto@, sharadmv@, jekbradbury@, zhangqiaorjc@, mattjj@

January 2023

Motivation#

JAX supports two schools of thought for multi-device programming:

  1. Compiler, take the wheel! Let the compiler automatically partition bulk array functions over devices.

  2. Just let me write what I mean, damnit! Give me per-device code and explicit communication collectives.

We need great APIs for both, and rather than being mutually exclusive alternatives, they need to compose with each other.

With pjit (now just jit) we have a next-gen API for the first school. But we haven’t quite leveled-up the second school. pmap follows the second school, but over time we found it has fatal flaws. xmap solved those flaws, but it doesn’t quite give us per-device shapes, and it includes several other big ideas too. Meanwhile, new demands for per-device explicit-collectives programming have emerged, like in Efficiently Scaling Transformer Inference.

We can level-up the second school with shmap. shmap is:

  • a simple multi-device parallelism API which lets us write per-device code with explicit collectives, where logical shapes match per-device physical buffer shapes and collectives correspond exactly to cross-device communication;

  • a specialization of xmap with scaled-back features and a few tweaks;

  • a fairly direct surfacing of the XLA SPMD Partitioner’s ‘manual’ mode;

  • a fun-to-say Seussian name which could stand for shard_map, shpecialized_xmap, sholto_map, or sharad_map.

For pjit users, shmap is a complementary tool. It can be used inside a pjit computation to drop temporarily into a “manual collectives” mode, like an escape hatch from the compiler’s automatic partitioning. That way, users get the convenience and familiar just-NumPy programming model of pjit for most of their code, along with the ability to hand-optimize collective communication with shmap wherever it’s needed. It’s the best of both worlds!

For pmap users, shmap is a strict upgrade. It’s more expressive, performant, and composable with other JAX APIs, without making basic batch data parallelism any harder.

For more on practical use, you can jump to When should you use shmap and when should you use pjit?. If you’re wondering why we need a new thing at all, or what the problems with pmap are, jump to Why don’t pmap or xmap already solve this?. Or keep reading the next section to see some shmap examples and the API spec.

So, let’s see shmap!#
TL;DR example (with a more detailed explanation to follow)#

Sho shick:

from functools import partial

import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map

devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('i', 'j'))

a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 32.).reshape(16, 32)

@partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P('j', None)),
         out_specs=P('i', None))
def matmul_basic(a_block, b_block):
  # a_block: f32[2, 8]
  # b_block: f32[8, 32]
  z_partialsum = jnp.dot(a_block, b_block)
  z_block = jax.lax.psum(z_partialsum, 'j')
  return z_block

c = matmul_basic(a, b)  # c: f32[8, 32]
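As a quick sanity check (an added line continuing the example above, assuming a machine with 8 devices so the 4x2 mesh can be created), the sharded result matches an ordinary matmul:

assert jnp.allclose(c, a @ b)  # the sharded result equals the unsharded matmul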

Notice:

  • no nesting needed (or axis_index_groups) for multiple axes of parallelism, unlike pmap;

  • no reshapes in the caller, unlike pmap and hard-xmap, and logical shapes correspond to per-device physical shapes, unlike (non-hard) xmap;

  • precise device placement control by using mesh, unlike pmap;

  • there’s only one set of axis names for logical and physical, unlike xmap;

  • the result is a jax.Array which could be efficiently passed to a pjit, unlike pmap;

  • this same code works efficiently inside a pjit/jit, unlike pmap;

  • this code works eagerly, so we can pdb in the middle and print values, unlike xmap’s current implementation (though by design xmap without the sequential schedule can in principle work eagerly too).

Here’s another matmul variant with a fully sharded result:

@partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P('j', None)),
         out_specs=P('i', 'j'))
def matmul_reduce_scatter(a_block, b_block):
  # c_partialsum: f32[2, 32]
  c_partialsum = jnp.matmul(a_block, b_block)
  # c_block: f32[2, 16]
  c_block = jax.lax.psum_scatter(c_partialsum, 'j', scatter_dimension=1, tiled=True)
  return c_block

c = matmul_reduce_scatter(a, b)
Slow down, start with the basics!#
Rank-reducing vs rank-preserving maps over array axes#

We can think of pmap (and vmap and xmap) as unstacking each array input along an axis (e.g. unpacking a 2D matrix into its 1D rows), applying its body function to each piece, and stacking the results back together, at least when collectives aren’t involved:

pmap(f, in_axes=[0], out_axes=0)(xs) == jnp.stack([f(x) for x in xs])

For example, if xs had shape f32[8,5] then each x has shape f32[5], and if each f(x) has shape f32[3,7] then the final stacked result pmap(f)(xs) has shape f32[8,3,7]. That is, each application of the body function f takes as argument inputs with one fewer axis than the corresponding argument to pmap(f). We can say these are rank-reducing maps with unstacking/stacking of inputs/outputs.

The number of logical applications of f is determined by the size of the input axis being mapped over: for example, if we map over an input axis of size 8, semantically we get 8 logical applications of the function, which for pmap always correspond to 8 devices physically computing them.
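Since the text above says vmap behaves analogously here, a small single-device illustration of these rank-reducing shapes (an added example, not from the original) is:

import jax
import jax.numpy as jnp

def f(x):                     # x: f32[5]
  return jnp.zeros((3, 7))    # f(x): f32[3,7]

xs = jnp.ones((8, 5))
print(jax.vmap(f)(xs).shape)  # (8, 3, 7): one stacked result per logical application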

In contrast, shmap does not have this rank-reducing behavior. Instead, we can think of it as slicing (or “unconcatenating”) along input axes into blocks, applying the body function, and concatenating the results back together (again when collectives aren’t involved):

devices = np.array(jax.devices()[:4])
m = Mesh(devices, ('i',))  # mesh.shape['i'] = 4

shard_map(f, m, in_specs=P('i'), out_specs=P('i'))(y)
==
jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, 4)])

Recall that jnp.split slices its input into equally-sized blocks with the same rank, so that if in the above example y has shape f32[8,5] then each y_blk has shape f32[2,5], and if each f(y_blk) has shape f32[3,7] then the final concatenated result shard_map(f, ...)(y) has shape f32[12,7]. So shmap (shard_map) maps over shards, or blocks, of its inputs. We can say it’s a rank-preserving map with unconcatenating/concatenating of its inputs/outputs.

The number of logical applications of f is determined by the mesh size, not by any input axis size: for example, if we have a mesh of total size 4 (i.e. over 4 devices) then semantically we get 4 logical applications of the function, corresponding to the 4 devices physically computing them.

Controlling how each input is split (unconcatenated) and tiled with in_specs#

Each of the in_specs identifies some of the corresponding input array’s axes with mesh axes by name using PartitionSpecs, representing how to split (or unconcatenate) that input into the blocks to which the body function is applied. That identification determines the shard sizes; when an input axis is identified with a mesh axis, the input is split (unconcatenated) along that logical axis into a number of pieces equal to the corresponding mesh axis size. (It’s an error if the corresponding mesh axis size does not evenly divide the input array axis size.) If an input’s pspec does not mention a mesh axis name, then there’s no splitting over that mesh axis. For example:

devices = np.array(jax.devices())
m = Mesh(devices.reshape(4, 2), ('i', 'j'))

@partial(shard_map, mesh=m, in_specs=P('i', None), out_specs=P('i', 'j'))
def f1(x_block):
  print(x_block.shape)
  return x_block

x1 = np.arange(12 * 12).reshape(12, 12)
y = f1(x1)  # prints (3,12)

Here, because the input pspec did not mention the mesh axis name 'j', no input array axis is split over that mesh axis; similarly, because the second axis of the input array is not identified with (and hence split over) any mesh axis, application of f1 gets a full view of the input along that axis.

When a mesh axis is not mentioned in an input pspec, we can always rewrite to a less efficient program where all mesh axes are mentioned but the caller performs a jnp.tile, for example:

@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P('i', 'j'))
def f2(x_block):
  print(x_block.shape)
  return x_block

x = np.arange(12 * 12).reshape(12, 12)
x_ = jnp.tile(x, (1, m.shape['j']))  # x_ has shape (12, 24)
y = f2(x_)  # prints (3,12), and f1(x) == f2(x_)

In other words, because each input pspec can mention each mesh axis name zero or one times, rather than having to mention each name exactly once, we can say that in addition to the jnp.split built into its input, shard_map also has a jnp.tile built into its input, at least logically (though the tiling may not need to be carried out physically, depending on the arguments’ physical sharding layout). The tiling to use is not unique; we could also have tiled along the first axis, and used the pspec P(('j', 'i'), None).

Physical data movement is possible on inputs, as each device needs to have a copy of the appropriate data.

Controlling how each output is assembled by concatenation, block transposition, and untiling using out_specs#

Analogously to the input side, each of the out_specs identifies some of the corresponding output array’s axes with mesh axes by name, representing how the output blocks (one for each application of the body function, or equivalently one for each physical device) should be assembled back together to form the final output value. For example, in both the f1 and f2 examples above the out_specs indicate we should form the final output by concatenating together the block results along both axes, resulting in both cases in an array y of shape (12, 24). (It’s an error if an output shape of the body function, i.e. an output block shape, has a rank too small for the concatenation described by the corresponding output pspec.)

When a mesh axis name is not mentioned in an output pspec, it represents an un-tiling: when the user writes an output pspec which does not mention one of the mesh axis names, they promise that the output blocks are equal along that mesh axis, and so only one block along that axis is used in the output (rather than concatenating all the blocks together along that mesh axis). For example, using the same mesh as above:

x = jnp.array([[3.]])

z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P('i', 'j'))()
print(z)  # prints the same as jnp.tile(x, (4, 2))

z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P('i', None))()
print(z)  # prints the same as jnp.tile(x, (4, 1))

z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P(None, None))()
print(z)  # prints the same as jnp.tile(x, (1, 1)), or just x

Notice that the body function closing over an array value is equivalent to passing it as an argument with a corresponding input pspec of P(None, None). As another example, following the other examples above more closely:

@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P('i', None))
def f3(x_block):
  return jax.lax.psum(x_block, 'j')

x = np.arange(12 * 12).reshape(12, 12)
y3 = f3(x)
print(y3.shape)  # (12,6)

Notice that the result has a second axis size of 6, half the size of the input’s second axis. In this case, the un-tile expressed by not mentioning the mesh axis name 'j' in the output pspec was safe because of the collective psum, which ensures each output block is equal along the corresponding mesh axis. Here are two more examples where we vary which mesh axes are mentioned in the output pspec:

@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f4(x_block):
  return jax.lax.psum(x_block, 'i')

x = np.arange(12 * 12).reshape(12, 12)
y4 = f4(x)
print(y4.shape)  # (3,12)


@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P(None, None))
def f5(x_block):
  return jax.lax.psum(x_block, ('i', 'j'))

y5 = f5(x)
print(y5.shape)  # (3,6)

On the physical side, not mentioning a mesh axis name in an output pspec assembles an Array from the output device buffers with replicated layout along that mesh axis.

There is no runtime check that the output blocks are actually equal along a mesh axis to be un-tiled along, or equivalently that the corresponding physical buffers have equal values and thus can be interpreted as a replicated layout for a single logical array. But we can provide a static check mechanism which raises an error on all potentially-incorrect programs.

Because the out_specs can mention mesh axis names zero or one times, and because they can be mentioned in any order, we can say that in addition to the jnp.concatenate built into its output, shard_map also has both an untile and a block transpose built into its output.

Physical data movement is not possible on outputs, no matter the output pspec. Instead, out_specs just encodes how to assemble the block outputs into Arrays, or physically how to interpret the buffers across devices as the physical layout of a single logical Array.

API Specification#
from jax.sharding import Mesh
Specs = PyTree[PartitionSpec]

def shard_map(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs
          ) -> Callable:
  ...

where:

  • mesh encodes devices arranged in an array and with associated axis names, just like it does for xmap and for sharding.NamedSharding;

  • in_specs and out_specs are PartitionSpecs which can affinely mention axis names from mesh (not separate logical names as in xmap) to express slicing/unconcatenation and concatenation of inputs and outputs, respectively (not unstacking and stacking like pmap and xmap do), with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively;

  • the shapes of the arguments passed to f have the same ranks as the arguments passed to shard_map-of-f (unlike pmap and xmap where the ranks are reduced), and the shape of an argument to f is computed from the shape (call it shape) of the corresponding argument to shard_map-of-f and the corresponding PartitionSpec (call it spec) as roughly tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec)); a small illustration follows this list;

  • the body of f can apply collectives using names from mesh.
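As a small illustration of the rough shape rule above, here is a standalone helper (not part of the API) applied to the TL;DR example's shapes:

def block_shape(shape, spec, mesh_shape):
  # mesh_shape: mapping from mesh axis name to size, e.g. {'i': 4, 'j': 2}
  return tuple(sz // (1 if n is None else mesh_shape[n]) for sz, n in zip(shape, spec))

print(block_shape((8, 16), ('i', 'j'),   {'i': 4, 'j': 2}))  # (2, 8): a_block above
print(block_shape((16, 32), ('j', None), {'i': 4, 'j': 2}))  # (8, 32): b_block above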

shmap is eager by default, meaning that we dispatch computations primitive-by-primitive, so that the user can employ Python control flow on fully replicated values and interactive pdb debugging to print any values. To stage out and end-to-end compile a shmapped function, just put a jit around it. A consequence is that shmap doesn’t have its own dispatch and compilation paths like xmap and pmap currently do; it’s just the jit path.

When it’s staged out by e.g. an enclosing jit, the lowering of shmap to StableHLO is trivial: it just involves switching into ‘manual SPMD mode’ on the inputs, and switching back on the outputs. (We don’t currently plan to support partially-manual-partially-automatic modes.)

The interaction with effects is the same as with pmap.

The interaction with autodiff is also just like pmap (rather than attempting the new semantics that xmap did, corresponding to having unmapped intermediates and hence grad’s reduce_axes as well as making psum transpose to pbroadcast rather than psum). But it thus inherits an unsolved problem from pmap: in some cases, instead of transposing psum to psum, and thus performing a backward pass psum corresponding to the forward pass psum, it can be beneficial to move the backward pass psum to elsewhere in the backward pass, exploiting linearity. Many advanced pmap users addressed this challenge by using custom_vjp to implement psum_idrev and id_psumrev functions, but since it’s easy to accidentally leave those imbalanced, that technique is a foot-cannon. We have some ideas on how to provide this functionality in a safer way.

When should you use shmap and when should you use pjit?#

One philosophy is: it is almost always simpler to write a program in jit==pjit — but if a given part of the program is less optimized by the compiler than it could be, drop into shmap!

A realistic transformer example#

In fact, we can implement a simple version of the “collective matmul” algorithm recently introduced in XLA to overlap communication and computation using shmap and 30 lines of Python. The basic idea of the algorithm can be grasped with a simple example.

Suppose we want to compute C = A @ B where A is sharded by a 1D mesh on the 0-th dimension while B and C are replicated.

M, K, N = 4096, 2048, 1024
A = jnp.arange(np.prod((M, K))).reshape((M, K))
B = jnp.arange(np.prod((K, N))).reshape((K, N))

mesh = Mesh(np.array(jax.devices()), axis_names=('i',))
A_x = jax.device_put(A, jax.sharding.NamedSharding(mesh, P('i', None)))

@jax.jit
def f(lhs, rhs):
  return lhs @ rhs

C = f(A_x, B)

A profile shows the blocking all-gather across 8 devices before the matmul can start. This is suboptimal because A is sharded on a non-contracting dimension, and each shard of A can be matmul’ed with B independently and this chunked computation can be overlapped with fetching of the next shard of A from another device.

[profile image]

This overlap can be implemented using shmap and explicit collectives.

def collective_matmul_allgather_lhs_non_contracting(lhs, rhs):
  # lhs is the looped operand; rhs is the local operand
  axis_size = jax.lax.psum(1, axis_name='i')
  axis_index = jax.lax.axis_index(axis_name='i')
  chunk_size = lhs.shape[0]

  def f(i, carrys):
    accum, lhs = carrys
    # matmul for a chunk
    update = lhs @ rhs
    # circular shift to the left
    lhs = jax.lax.ppermute(
        lhs,
        axis_name='i',
        perm=[(j, (j - 1) % axis_size) for j in range(axis_size)]
    )
    # device 0 computes chunks 0, 1, ...
    # device 1 computes chunks 1, 2, ...
    update_index = (((axis_index + i) % axis_size) * chunk_size, 0)
    accum = jax.lax.dynamic_update_slice(accum, update, update_index)
    return accum, lhs

  accum = jnp.zeros((lhs.shape[0] * axis_size, rhs.shape[1]), dtype=lhs.dtype)
  # fori_loop cause a crash: hlo_sharding.cc:817 Check failed: !IsManual()
  # accum, lhs = jax.lax.fori_loop(0, axis_size - 1, f, (accum, lhs))
  for i in range(0, axis_size - 1):
    accum, lhs = f(i, (accum, lhs))

  # compute the last chunk, without the ppermute
  update = lhs @ rhs
  i = axis_size - 1
  update_index = (((axis_index + i) % axis_size) * chunk_size, 0)
  accum = jax.lax.dynamic_update_slice(accum, update, update_index)

  return accum
jit_sharded_f = jax.jit(shard_map(
  collective_matmul_allgather_lhs_non_contracting, mesh,
  in_specs=(P('i', None), P()), out_specs=P()))
C = jit_sharded_f(A_x, B)
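Before looking at the profile, a quick correctness check (an added line continuing the example above, again assuming 8 devices) is:

import numpy as np
np.testing.assert_array_equal(np.asarray(C), np.asarray(A @ B))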

A profile shows that the all-gather is gone, and replaced with overlapped matmul with async collective permute. This profile matches very closely with the collective matmul paper result.

[profile image]

This collective matmul technique can be used to speed up feedforward blocks in transformer layers. This typically consists of two matrix multiplications followed by a ReduceScatter (to resolve partial sums from a parallelized matrix multiplication) and preceded by an AllGather (to collect the sharded dimensions along some axes and allow partial sum computation). Together, the ReduceScatter from one layer and the AllGather for the next amount to an AllReduce.

In a typical profile, the two matmuls will be followed by an AllReduce, and they will not be overlapped. Collective matmul can be used to achieve the overlap, but it is difficult to trigger, has a minimum slice size, and does not yet cover all topologies, tensor shapes, and variants of collective matmul (i.e. latency- and throughput-optimized variants). In a recent paper, we found a ~40% gain in many circumstances from manually implementing collective matmul variants in shmap style.

But it isn’t always more complex! We expect this to be a much more natural way to think about pipelined computation, and plan to do some demos of that soon!

Another realistic example#

Here’s how shmap might look in a transformer layer pass with a 2D weight gathered pattern (paper, Sec 3.2.3 on p. 5):

def matmul_2D_wg_manual(xnorm, q_wi, layer):
  '''Calls a custom manual implementation of matmul_reducescatter'''
  # [batch, maxlen, embed.X] @ [heads.YZ, embed.X, q_wi_per_head]
  # -> (matmul)
  # -> [batch, maxlen, heads.YZ, q_wi_per_head]{x unreduced}
  # -> (reducescatter over x into X heads, B batches)
  # -> [batch, maxlen, heads.YZX, q_wi_per_head]
  with jax.named_scope('q_wi'):
    xnorm = intermediate_dtype(xnorm)
    q_wi = matmul_reducescatter(
        'bte,hed->bthd',
        xnorm,
        params.q_wi,
        scatter_dimension=(0, 2),
        axis_name='i',
        layer=layer)
  return q_wi


import partitioning.logical_to_physical as l2phys

def pjit_transformer_layer(
    hparams: HParams, layer: int, params: weights.Layer, sin: jnp.ndarray,
    cos: jnp.ndarray, kv_caches: Sequence[attention.KVCache],
    x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
  """Forward pass through a single layer, returning output, K, V."""

  def my_layer(t, axis=0):
    """Gets the parameters corresponding to a given layer."""
    return lax.dynamic_index_in_dim(t, layer, axis=axis, keepdims=False)

  # 2D: [batch.Z, time, embed.XY]
  x = _with_sharding_constraint(
      x, ('residual_batch', 'residual_time', 'residual_embed'))
  xnorm = _layernorm(x)
  # 2D: [batch, time, embed.X]
  xnorm = _with_sharding_constraint(
      xnorm, ('post_norm_batch', 'time', 'post_norm_embed'))
  # jump into manual mode where you want to optimise
  if manual:
    q_wi = shard_map(matmul_2D_wg_manual, mesh,
                in_specs=(l2phys('post_norm_batch', 'time', 'post_norm_embed'),
                          l2phys('layers', 'heads', 'embed', 'q_wi_per_head')),
                out_specs=l2phys('post_norm_batch', 'time', 'heads', 'q_wi_per_head'))(xnorm, q_wi, layer)
  else:
    q_wi = jnp.einsum('bte,hed->bthd', xnorm, my_layer(params.q_wi))
    # 2D: [batch, time, heads.YZX, None]
    q_wi = _with_sharding_constraint(q_wi,
                                   ('post_norm_batch', 'time', 'heads', 'qkv'))
  q = q_wi[:, :, :, :hparams.qkv]
  q = _rope(sin, cos, q)
  # unlike in https://arxiv.org/pdf/2002.05202.pdf, PaLM implements
  # swiGLU with full d_ff dimension, rather than 2/3 scaled
  wi0 = q_wi[:, :, :, hparams.qkv:hparams.qkv + (hparams.ff // hparams.heads)]
  wi1 = q_wi[:, :, :, hparams.qkv + (hparams.ff // hparams.heads):]
  kv = jnp.einsum('bte,ezd->btzd', xnorm, my_layer(params.kv))
  k = kv[:, :, 0, :hparams.qkv]
  v = kv[:, :, 0, hparams.qkv:]
  k = _rope(sin, cos, k)

  y_att = jnp.bfloat16(attention.attend(q, k, v, kv_caches, layer))

  y_mlp = special2.swish2(wi0) * wi1
  # 2D: [batch, time, heads.YZX, None]
  y_mlp = _with_sharding_constraint(y_mlp,
                                    ('post_norm_batch', 'time', 'heads', None))

  y_fused = jnp.concatenate([y_att, y_mlp], axis=-1)
  # do the second half of the mlp and the self-attn projection in parallel
  y_out = jnp.einsum('bthd,hde->bte', y_fused, my_layer(params.o_wo))
  # 2D: [batch.Z, time, embed.XY]
  y_out = _with_sharding_constraint(
      y_out, ('residual_batch', 'residual_time', 'residual_embed'))
  z = y_out + x
  z = _with_sharding_constraint(
      z, ('residual_batch', 'residual_time', 'residual_embed'))
  return z, k, v

In the profile below, both the first and second matmul were replaced by manually lowered versions, where the compute (fusions) are fully overlapped with the communication (ppermute)! One fun hint that we are using a latency optimised variant is that the ppermute pixels are jittered, because there are two overlapping ppermutes using opposite ICI axes at the same time!

All-to-all is much harder to overlap, so was left on the table.

[profile image]
Why don’t pmap or xmap already solve this?#

pmap was our first multi-device parallelism API. It follows the per-device-code-and-explicit-collectives school. But it had major shortcomings which make it unsuitable for today’s programs:

  • Mapping multiple axes required nested pmaps. Not only are nested pmaps cumbersome to write, but also they make it difficult to control (or even predict) the device placement of data and computation, and difficult to preserve data sharding (see the next two bullets). Today’s programs require multiple axes of parallelism.

  • Controlling device placement was impossible. Especially with multiple axes of parallelism, programmers need to control how those axes are aligned with hardware resources and their communication topologies. But (nested) pmap doesn’t offer control over how mapped program instances are placed on hardware; there’s just an automatic device order which the user can’t control. (Gopher’s use of axis_index_groups and a single un-nested pmap was essentially a hack to get around this by flattening multiple axes of parallelism down to one.)

  • jit/pjit composability. jit-of-pmap is a performance footgun, as is nesting pmaps, as is e.g. scan-of-pmap, because sharding is not preserved when returning from an inner pmap. To preserve sharding we would need pattern matching on jaxprs to ensure we’re working with perfectly nested pmaps, or a pmap just inside a jit. Moreover, pjit was no help here because pmap targets XLA replicas while pjit targets the XLA SPMD Partitioner, and composing those two is hard.

  • jax.Array compatibility (and hence pjit compatibility). Because the sharding of pmap outputs can’t be expressed as Shardings / OpShardings, due to pmap’s stacking rather than concatenative semantics, the output of a pmap computation can’t currently be passed to a pjit computation without bouncing to host (or dispatching a reshaping computation).

  • Multi-controller semantics (and hence pjit compatibility). Multi-controller pmap concatenates values across controllers, which works well but differs from single-controller pmap’s stacking semantics. More practically, it precludes the use of non-fully-addressable jax.Array inputs and outputs as we use with multi-controller pjit.

  • Eager mode. We didn’t make pmap eager-first, and though we eventually (after 4+ years!) added eager operation with disable_jit(), the fact that pmap has jit fused into it means it has its own compilation and dispatch path (actually two dispatch paths: in Python for handling Tracers, and in C++ for performance on raw Array inputs!), a heavy implementation burden.

  • Reshapes needed in the caller. A typical use case with pmap on 8 devices might look like starting with a batch axis of size 128, reshaping it to split into two axes with sizes (8, 16), and then pmapping over the first. These reshapes are awkward, and the compiler often interprets them as copies instead of views, increasing memory and time usage.

These shortcomings aren’t so bad when only doing batch data parallelism. But when more parallelism is involved, pmap just can’t cut it!

xmap paved the way as a next-gen evolution of pmap and solved (almost) all these issues. shmap follows in xmap’s footsteps and solves these problems in essentially the same ways; indeed, shmap is like a specialized subset of xmap (what some call the “hard xmap” subset), with a few tweaks.

For the initial prototype, we chose to implement shmap as a separate primitive from xmap, because limiting the set of features it supports makes it easier to focus on the core functionality. For example, shmap doesn’t allow unmapped intermediates, making it easier not to worry about the interactions between named axes and autodiff. Furthermore, not having to reason about interactions of all pairs of features makes it easier to add capabilities beyond what’s implemented in xmap today, such as support for eager mode.

Both shmap and xmap share significant portions of the lowering code. We could consider merging both in the future, or even focusing solely on shmap, depending on how the usage will evolve.

jax.extend: a module for extensions#

@froystig, @sharadmv, @jakevdp, @yashk2810

May 2023

import jax.extend as jex

Several projects depend on JAX’s codebase internals, often to use its core machinery (e.g. to write a transformation over its IR) or to extend it (e.g. to define new primitives). Two challenges for these dependencies are (a) that our internals aren’t all solidly designed for external use, and (b) that circumventing JAX’s public API is unsupported. In other words, our internals are often used like a library, but are neither structured nor updated like one.

This proposal considers introducing a jax.extend module that defines a library view of some of JAX’s internal components. We would treat this as a second-tier API, still guaranteeing essentially no compatibility policy, but hopefully making it easier to spot changes when they happen.

The audience for jax.extend includes JAX-adjacent Python libraries like Oryx, jax-triton, and many others, as well as projects experimenting with function transformations, autodiff systems, compiler frontends for numerical programming, etc.

This note gives an overview of how jax.extend might look, now and eventually. It doesn’t lay things out in great detail, instead proposing that we begin iteratively developing the module.

Note that jax.extend differs from jax.experimental, which is a staging ground for new features and ideas in progress. Typically, work in jax.experimental eventually makes its way into another JAX module or is removed altogether.

No compatibility policy#

To keep development overhead low, jax.extend would not follow the public API compatibility policy. It would promise no deprecation windows nor backwards compatibility between releases. Every release may break existing callers without simple recourse (e.g. without a flag reintroducing prior behavior). We would rely on the changelog to call out such changes.

Callers of jax.extend that need to upgrade their code regularly alongside JAX releases might find it useful to pin JAX versions as an intermediate step between releases. This is a common habit among projects that rely on JAX’s internals today. The difference is that it would now come with the help of changelog announcements and better intentions regarding library design and naming.

Iterative development#

Having no compatibility policy makes it easier to get started on implementation: on day one, we can move a handful of symbols over from internal packages such as jax._src and today’s jax.core and jax.interpreters. Then we can iterate to improve things from there.

Possible module overview#

We can imagine that eventually jax.extend would include the following modules:

  • core – primitives, the Jaxpr IR, etc.

  • interpreters – core transformations (e.g. autodiff, batching) and lowerings.

  • random – random bit generation, key splitting and folding, key arrays.

  • sharding – extra functionality around distributed arrays.

We might also have other symbols in the module at first, such as jex.api_util, as we work to remove or replace them. Others will be decided in time. For instance, jex.lib could offer an entry point to jaxlib (and would do so in the immediate term), but it’s not clear whether we want to keep it for long.

Some preliminary thoughts on what each of these might comprise follow.

jax.extend.core#

This should enable callers at least to define new JAX primitives and to process the Jaxpr IR (the output of jax.make_jaxpr(...)). Supporting this might involve providing:

  • Access to existing core system primitives, such as today’s jax._src.lax.add_p.

  • Access to IR types, such as the current jax._src.core.ShapedArray.

  • Functions for checking and pretty-printing jaxprs.

  • Functions for building jaxprs explicitly, rather than by staging Python functions via jax.make_jaxpr (or not!).

At initialization, this module will contain many more symbols than what’s needed to define primitives and rules, including various names used in setting up “final-style transformations”, such as the current jax._src.core.Trace and Tracer classes. We can revisit whether jex.core should also support final-style extensions alongside initial style approaches, and whether it can do so by a more narrow API than exposing Trace and Tracer entirely. Oryx might help guide these decisions.

We can also consider relocating make_jaxpr itself to jex.core.
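For reference, processing the Jaxpr IR with today's public entry point already looks like the following (an added illustration using only jax.make_jaxpr and standard jaxpr attributes, not any proposed jex.core API):

import jax
import jax.numpy as jnp

closed_jaxpr = jax.make_jaxpr(lambda x: jnp.sin(x) * 2.0)(1.0)
print(closed_jaxpr)                   # pretty-printed IR
for eqn in closed_jaxpr.jaxpr.eqns:   # walk the equations
  print(eqn.primitive.name)           # sin, mul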

jax.extend.interpreters#

This module would provide a means of registering various transformation rules for primitives—defining their behavior under AD, batching, lowering, etc.

It would initially reflect jax._src.interpreters in providing the modules ad, batching, partial_eval (for staging Python to Jaxpr, and for linearization in AD), mlir, pxla, and xla. The first three might be replaceable by a single primitive extension API in jex.core. The latter three, used for lowering, could be simplified into one module, maybe.

Today, to write transformation rules, e.g. for AD and batching, callers may need symbols relating to tracers, e.g. JVPTracer and BatchTracer. This may be avoidable later on, and allow us to remove tracer types from jex.

This module plus jex.core ought to suffice for replicating today’s custom primitive tutorials (e.g. ours and dfm’s). For instance, defining a primitive and its behavior under jax.jit would be possible as follows (in the immediate term):

from jax.extend import core	         # Previously: from jax import core
from jax.extend.interpreters import mlir        # ... and similarly

mul_add_p = core.Primitive('mul_add')
mul_add_p.def_impl(lambda x, y, z: x * y + z)

@mul_add_p.def_abstract_eval
def mul_add_abstract(x_sa, y_sa, z_sa):
  return core.ShapedArray(x_sa.shape, x_sa.dtype)

def mul_add_mlir(ctx, xc, yc, zc):
  add = mlir.hlo.AddOp
  mul = mlir.hlo.MulOp
  return add(mul(xc, yc), zc).results

mlir.register_lowering(mul_add_p, mul_add_mlir)

import jax
print(mul_add_p.bind(2, 3, 4))            # -> 10
print(jax.jit(mul_add_p.bind)(2, 3, 4))   # -> Array(10, dtype=int32)
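Continuing the sketch above, registering autodiff and batching rules for the same primitive might look roughly like this. It uses today's jax.interpreters locations, which the proposed jex.interpreters would initially mirror; for simplicity it ignores symbolic-zero tangents and assumes all operands are batched along the same axis:

from jax.interpreters import ad, batching   # today's locations for these registries

def mul_add_jvp(primals, tangents):
  x, y, z = primals
  dx, dy, dz = tangents
  # d(x*y + z) = dx*y + x*dy + dz
  return mul_add_p.bind(x, y, z), dx * y + x * dy + dz
ad.primitive_jvps[mul_add_p] = mul_add_jvp

def mul_add_batch(batched_args, batch_dims):
  x, y, z = batched_args
  bdim, *rest = batch_dims
  assert all(b == bdim for b in rest)       # simplifying assumption for this sketch
  return mul_add_p.bind(x, y, z), bdim
batching.primitive_batchers[mul_add_p] = mul_add_batch

print(jax.jvp(mul_add_p.bind, (2., 3., 4.), (1., 1., 1.)))  # primal 10.0, tangent 6.0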
jax.extend.random#

This module could expose our mechanism for defining new RNG implementations, and functions for working with PRNG key internals (see issue #9263), such as the current jax._src.prng.random_wrap and random_unwrap.

It could also expose the keyed hash functions that underlie the built-in RNG implementations, such as jax._src.prng.threefry_2x32.

jax.extend.sharding#

This module could expose low-level utilities for sharding distributed arrays.

We have only one item in mind for now. The XLA compiler’s array sharding format is more expressive than those provided by JAX. We could provide this as jex.sharding.XlaOpShardingProto, corresponding to today’s jax._src.lib.xla_client.OpSharding internally.

Efficient transposition of replication-inducing collectives#

mattjj@, dougalm@

August 2023

Motivation#

We have an efficiency problem in automatically transposing shmaps containing certain collectives. The issue arises with psum and all_gather, specifically when the output of the collective is returned to the caller as an unmapped output. And it’s not an edge case: for example, it arises when applying grad to a shmap-based batch data parallel neural network loss function which uses psum to compute the total loss.

We’ve known about this problem for some time. An analogous issue exists with pmap, though it’s been worked around by keeping grad inside pmap rather than outside. A primary goal of the incomplete avals-with-names work was to address a version of this transpose efficiency problem. This doc draws on those ideas, while extending and revising them to handle more cases and to be much easier to land. Indeed the solution proposed here only affects the shmap implementation. The rest of the system need not be changed (yet).

The main purpose of this doc is to define this transpose efficiency problem and propose an easy-to-land solution.

This doc is not about:

  • logical axis names on arrays (the only axis names here are just like in shmap and OG pmap);

  • changing autodiff semantics (all the numbers and (non)errors are staying the same, we’re just making things more efficient);

  • allowing user code to reflect on any new information, or really affecting user code at all.

Problem: efficient transpose of psum or all_gather depends on whether cotangents are invariant across devices#

Consider this semi-realistic example, meant to resemble a replicated-parameter batch data parallel loss function:

devices = jax.devices()  # 8 devices

@partial(shmap, mesh=Mesh(devices, ('batch',)),
         in_specs=(P(None, None), P('batch', None)),
         out_specs=P())
def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  local_loss = jnp.mean(jnp.sum(predictions - targets, -1))
  global_loss = lax.pmean(local_loss, 'batch')
  return global_loss

Notice the out_specs=P(), which indicates an unmapped output. If you’re not familiar with the notion of unmapped outputs, see the appendix at the bottom of this document.

Most of the details in the loss example aren’t important. All that matters for our purposes is that we’re applying psum (or rather pmean = lambda x, name: psum(x, name) / psum(1, name)) at the end. So a distilled version looks like this:

# Example 1: shmap involving psum and unmapped output with inefficient transpose
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())

We even simplified notation by suppressing the mesh argument. In the examples to follow it can be inferred from context.

What does the transpose look like? Writing t to mean function transpose, we could evaluate t(f1)(ybar) for any ybar efficiently by applying the function ¿f1_transpose? below:

# An efficient "transpose" of Example 1 (but don't transpose this again!)
¿f1_transpose? = shmap(t(g), in_specs=P(), out_specs=P('i'))

But that’s not the transpose we currently get as t(f1).

Instead, the current recipe for transposition is roughly that we switch in_specs and out_specs, do some division rescaling for unmapped outputs, and transpose the body. Because psum is its own transpose (as an all-reduce sum), we end up producing this transpose:

# The transpose we currently get for Example 1 (which is fine to transpose again)
t(f1) = shmap(lambda ybar: t(g)(psum(ybar / 8, 'i')),
              in_specs=P(), out_specs=P('i'))

This transpose gets the numbers right, but it’s wasteful. We know statically from the transpose’s in_specs=P() that ybar has the same value for each function instance, i.e. that its value is device-invariant for devices along the mesh axis named i, and yet we apply a psum to it! That uses expensive communication just to multiply the value on each device by 8. (Here 8 refers to the size of axis i. The division by 8 comes from the original function’s out_specs=P(); it and the trivial psum basically cancel each other out.)

What are we doing wrong? We’re not exploiting the fact that cotangents ybar corresponding to f1’s unmapped outputs are guaranteed to be device-invariant; instead, we’re defensively psumming them as if they weren’t because psum’s transpose can’t be sure given the local information it has. Sometimes the psum is necessary, as in transposing f2 with respect to its first argument:

# Example 2: shmap involving psum and *mapped* output with efficient transpose
f2 = shmap(lambda x, y: psum(g(x), 'i') * y,
          in_specs=(P('i'), P('i')), out_specs=P('i'))

# The transpose we currently get for Example 2 is efficient
t(f2, 0) = shmap(lambda y, zbar: t(g)(psum(zbar * y, 'i')),
                in_specs=(P('i'), P('i')), out_specs=P('i'))

Intuitively, if our transpose machinery could tell the difference between Example 1 and Example 2, we could do better by avoiding the psum and division where possible.

The inefficient examples can be even smaller. Consider transposing this cursed identity function:

# Example 3: cursed identity
cursed_identity = shmap(lambda x: x, P(), P())

# Currently we get these inefficient transposes
t(cursed_identity) = shmap(lambda x: psum(x / 8, 'i'), P(), P())
t(t(cursed_identity)) = shmap(lambda x: psum(psum(x / 8 / 8, 'i'), 'i'), P(), P())
...

It keeps getting bigger the more we transpose. How embarrassing!

And psum isn’t the only culprit. Something analogous holds true for all_gather:

# Example 4: all_gather to an unmapped output
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())

# Currently we get this inefficient transpose
t(f4) = shmap(lambda ybar: psum_scatter(ybar / 8, 'i'), P(), P('i'))

This program is a bit artificial. Why do an all_gather and feed the result into an unmapped output, rather than skipping the all_gather in the body and just using out_specs=P('i') to collect the results? But even though it’s cooked-up, this example nevertheless exhibits a transpose which unnecessarily performs communication (we could have just performed a non-communicating slice), analogous to Example 1 for psum.

Also analogously to the psum examples, the defensive psum_scatter is necessary in some cases:

# Example 5: all_gather to a mapped output
f5 = shmap(lambda x, y: all_gather(x, 'i') * y,
           in_specs=(P('i'), P('i')), out_specs=P('i'))

# Currently we get this efficient transpose
t(f5, 0) = shmap(lambda y, zbar: psum_scatter(zbar * y, 'i'),
                 in_specs=(P('i'), P('i')), out_specs=P('i'))

So how do we avoid these inefficient transposes?

Solutions#

Here are two solution ideas. They aren’t mutually exclusive. But (spoilers) the second one is better, and it’s all we need.

Partial solution “P-sum”: build the ability to express a psum into out_specs#

This solution is a bit of a strawperson because it would offer only an awkward way to write programs. And it wouldn’t even fix everything! But it’s worth considering, if only to motivate a more complete solution.

Example 4 above is artificial because we could have just used out_specs instead of an all_gather in the body:

# Example 4 again
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())

# Why didn't we just write it like this?
f4_better = shmap(lambda x: x, P('i'), P('i'))

The f4_better version doesn’t have any transposition problems, since the transpose problems arise from collectives in the body.

Analogously, we could fix Example 1 by extending out_specs so that they can express summing:

# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())

# What if we could write an output sum like this?
f1_better = shmap(g, in_specs=P('i'), out_specs=P(sum='i'))  # sum='i' means sum over that axis

# Then it could transpose like this:
t(f1_better) = shmap(t(g), in_specs=P(), out_specs=P('i'))
t(t(f1_better)) = shmap(t(t(g)), in_specs=P('i'), out_specs=P(sum='i'))

So offering psums built into out_specs fixes the transpose problem of Example 1. But it doesn’t fully fix the cursed identity transpose in Example 3:

# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())

# How it would transpose with the P-sum partial solution:
t(cursed_identity) = shmap(lambda x: x / 8, P(), P(sum='i'))
t(t(cursed_identity)) = shmap(lambda x: x / 8, P(), P(sum='i'))

It’s an improvement since the program doesn’t continue to get bigger as we keep transposing, but we’re still doing wasteful communication.

Full solution: statically track device-varying vs device-invariant intermediates, plus new primitives#

This solution has two components:

  1. track when values are guaranteed to be device-invariant vs device-varying over particular mesh axes, and

  2. decompose psum into a two-step process, introducing a new pbroadcast primitive, and introduce new primitives for all_gather and its transposes.

Morally, the tracking of device-invariant vs device-varying information is a type-level consideration. But for the expedience of our first implementation, we don’t need to literally add the information to abstract values or jaxpr types. Before we get to implementation, we’ll first introduce the idea using types.

Also to follow is a discussion of making the user API convenient and backward compatible. But to first introduce the idea, we’ll ignore convenience and instead write code that is as explicit as possible.

Tracking device invariance in avals (a.k.a. avals-with-names, revived)#

We can sometimes tell from static information alone that the values of some intermediate variables in the body of a shmap are guaranteed to be invariant along a mesh axis, in the sense that the function instances (and their corresponding devices) along the mesh axis must all be computing with the same value. We’ll call such values device-invariant. For values that are not device-invariant, we’ll say they’re device-varying, though really we mean potentially device-varying from the point of view of the type system.

To encode device variance in types, we’ll extend the syntax of types for arrays. We’ll write things like x:f32[3,4]{i} to indicate that x is (potentially) device-varying along mesh axis i (and device-invariant over any other mesh axes of the shmap). More generally, we’ll say the grammar for array type syntax is something like

shaped_array ::= <dtype>[<int_literal>, ...]<device_variance_type>
device_variance_type ::= {<axis_name>, ...}

We’ll also update the typing rules to handle device variance types:

  • for first-order primitives other than collectives

    • for multi-arity primitives, the operand device variance types must be equal where shapes must be equal, e.g. mul x:f32[s1]{r1} y:f32[s2]{r2} requires r1 == r2 in addition to s1 == s2

    • the output device variance type must be the same as the operand(s)

  • for higher-order primitives

    • we just instantiate any type variables including the device variance type (and checking types for equality checks their device variance types are equal)

    • (when performing type inference, e.g. for branches of a cond, we take the union of the sets of axis names in device variance types)

  • for first-order collectives

    • a collective can either accept a device-varying or device-invariant input (along a mesh axis corresponding to its axis name parameter); it’s an error to pass a device-invariant operand to a collective which accepts device-varying operands and vice-versa

    • a collective can either produce a device-varying or device-invariant output

    • see the table below

As a side benefit, whatever logic implements this type checking can subsume shmap’s “static analysis” check for whether a shmap body function is compatible with any unmapped out_specs.

Here’s a table summarizing the device variance typing for collective primitives:

| Name | Device variance type | Example | Lowers to HLO | Transpose |
| --- | --- | --- | --- | --- |
| psum2 | Varying -> Invariant | y:f32[3]{j} = psum(x:f32[3]{i,j}, axis='i') | AllReduceSum (communication) | pbroadcast |
| pbroadcast | Invariant -> Varying | y:f32[3]{i} = pbroadcast(x:f32[3], 'i') | no-op (no communication) | psum |
| all_to_all | Varying -> Varying | y:f32[16]{i} = all_to_all(x:f32[16]{i}, 'i', 0, 0) | AllToAll (communication) | all_to_all |
| axis_index | () -> Varying | idx:i32[]{i} = axis_index('i') | ReplicaId and some arithmetic (no communication) | n/a |
| psum_scatter | Varying -> Varying | y:f32[2]{i} = psum_scatter(x:f32[16]{i}, 'i') | ReduceScatterSum (communication) | all_gather |
| all_gather | Varying -> Varying | y:f32[16]{i} = all_gather(x:f32[2]{i}, 'i') | AllGather (communication) | psum_scatter |
| pscatter | Invariant -> Varying | y:f32[2]{i} = pscatter(x:f32[16], 'i') | lambda x: x[axis_index('i'), None] (no communication) | all_gather_invariant |
| all_gather_invariant | Varying -> Invariant | y:f32[16] = all_gather_invariant(x:f32[2]{i}, 'i') | AllGather (communication) | pscatter |

There are some surprising things here!

  • We introduced several new primitives, including

    • pbroadcast, which interestingly lowers to a no-op

    • all_gather_invariant, which lowers to the same thing as all_gather but has a different device variance type (essentially all_gather has a pbroadcast fused into it, whereas all_gather_invariant does not)

    • pscatter which is the dual (transpose) of all_gather_invariant

  • all_gather has a device-varying result

Intuitively, the reason to introduce pbroadcast (other than to make the typing rules work) is so that psum can transpose to a physical no-op. The reason we need all_gather to have a device-varying result is so that we can transpose it to psum_scatter; if we instead left it with a device-invariant result, we might need a downstream pbroadcast, and that composition would transpose to an inefficient psum followed by slicing / pscatter. So instead we have a pbroadcast “fused into” the all_gather, thus allowing for an efficient transpose into psum_scatter. We provide all_gather_invariant and its transpose pscatter mainly for completeness; it’s unlikely users will need it (it corresponds to the situation in Example 4, which is easy to write differently using out_specs).

Interestingly, the psum and pbroadcast transpose pair correspond to the psum_idrev and id_psumrev that users introduced while training LLMs with pmap.

How this system solves the inefficient transpose examples#

Consider again the simplified motivating example:

# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())

# Example 1 with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1(x: f32[3,4]{i}):
  w:f32[]{i} = g(x)
  y:f32[]{} = psum(w, 'i')
  return y

With these new rules, the transpose is:

# Example 1 transpose using device variance types (go ahead and transpose this again!)
t(f1) = shmap(lambda ybar: t(g)(pbroadcast(ybar, 'i')),
              in_specs=P(), out_specs=P('i'))

# Example 1 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P(), out_specs=P('i'))
def f1_transpose(ybar: f32[]):
  wbar:f32[]{i} = pbroadcast(ybar, 'i')
  xbar:f32[3,4]{i} = transpose(g)(wbar)
  return xbar

where evaluating the pbroadcast application involves no communication or FLOPs at all; it’s a no-op. Notice that if we keep transposing the body does not grow in size; indeed t(t(f1)) == f1. Efficiency achieved!

And we wouldn’t mess up the other examples either, so long as we pbroadcast to make the types check where needed:

# Example 2 rewritten with explicit pbroadcast
f2 = shmap(lambda x, y: pbroadcast(psum(g(x), 'i'), 'i') * y,
           in_specs=(P('i'), P('i')), out_specs=P('i'))

# Example 2 transpose using device variance types
t(f2, 0) = shmap(lambda y, zbar: t(g)(pbroadcast(psum(zbar * y, 'i'), 'i')),
                 in_specs=(P('i'), P('i')), out_specs=P('i'))


# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())
# Notice here the body is `f32[...] -> f32[...]`, i.e. no device varying type.

# Example 3 transpose using device variance types
t(cursed_identity) = shmap(lambda x: x, P(), P())
t(t(cursed_identity)) = shmap(lambda x: x, P(), P())

Intuitively, in Example 1 we now only have “half the original psum”, whereas in Example 2 we get both “halves”. For Example 3 we never need any operations in the body at all.

For the all_gather examples, Example 4 would need to use all_gather_invariant to have an efficient transpose (though it’d be better to use out_specs instead of the collective in the body):

# Example 4 rewritten with explicit all_gather_invariant
f4 = shmap(lambda x: all_gather_invariant(x, 'i'), P('i'), P())

# Example 4 with intermediate device variance types annotated
@partial(shmap, P('i'), P())
def f4(x:f32[1]{i}):
  y:f32[8]{} = all_gather_invariant(x, 'i')
  return y

# Example 4 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P(), out_specs=P('i'))
def f4_transpose(ybar:f32[8]):
  xbar:f32[1]{i} = pscatter(ybar, 'i')
  return xbar

For Example 5, using the device-varying all_gather works as we’d want:

# Example 5 with intermediate device variance types annotated
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5(x:f32[1]{i}, y:f32[8]{i}):
  z:f32[8]{i} = all_gather(x, 'i')
  w:f32[8]{i} = z * y
  return w

# Transpose with respect to first argument
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5_transpose(y:f32[8]{i}, wbar:f32[8]{i}):
  zbar:f32[8]{i} = wbar * y
  xbar:f32[1]{i} = psum_scatter(zbar, 'i')
  return xbar
How to make the API convenient for users (and backward compatible)#

But what user wants to write pbroadcasts? And what developer wants to break lots of existing user code involving psums which are not fed into unmapped outputs? Not me!

Instead we can automatically insert the pbroadcasts. It’s a bit analogous to how we do automatic rank promotion at the jax.numpy layer, inserting broadcasts to avoid rank mismatch errors in binary operators. But it’s much simpler since we don’t need to contend with shape tuples. The typical rule is: whenever we see a multi-arity operation where the operands disagree in their device variance types, take the union of operands’ device variance types’ axis name sets and insert pbroadcasts to lift each operand to the resulting device variance type.
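A rough standalone sketch of that insertion rule (pbroadcast here is a stub standing in for the proposed primitive; none of this is the actual implementation):

def pbroadcast(x, axis_names):
  return x   # stub: the real primitive would lower to a device-level no-op

def lift_operands(operands, variance_sets):
  # variance_sets: one set of mesh axis names per operand (its device variance type)
  target = set().union(*variance_sets)
  lifted = []
  for x, axes in zip(operands, variance_sets):
    missing = target - axes
    lifted.append(pbroadcast(x, tuple(sorted(missing))) if missing else x)
  return lifted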

Automatically inserting pbroadcasts just before they’re needed may mean we apply the same pbroadcast to the same operand multiple times, creating common subexpressions. When we transpose, those could turn into a sum-of-psums rather than a psum-of-sum. We’ll rely on the compiler to clean that up as appropriate. If it’s a problem then we could add some simple memoization to the pbroadcast-insertion pass.

The user API for all_gather will mean all_gather_p by default (not all_gather_invariant_p), covering the common case and meaning no pbroadcasts must be inserted.

We can provide an option on shmap to disable this automatic insertion of pbroadcasts, in which case it’ll be up to the user to ensure type-correctness. This explicit option may be appealing to some who want to be explicit about where the psums occur in the backward pass.

How to implement the solution#

The key to making the implementation lightweight is that we aren’t going to add these types to avals or jaxprs. At least, not at first. That can be expensive because it requires updating the rest of JAX, e.g. all consumers of avals and jaxprs may need to handle the new types. We’re not falling for that again!

Instead we’re going to keep these extended types as metadata internal to shmap, just like the current “replication checking for out_specs” machinery is internal to shmap. Indeed this solution amounts to a relatively small extension to that existing machinery: it was already tracking the same information; now we’re just adding the pbroadcasts.

We have at least two options for where to perform the pbroadcast insertion:

  1. just before transposition, in the transpose rule, where we have a jaxpr of the computation to be transposed;

  2. in every shmap body, whether eagerly executed or staged out, like the current “replication checking for out_specs” machinery.

The former may end up being easier since we only have to handle the jaxpr case, and only linear primitives. But we’ll start by trying the latter so the implementation here is a strict revision/extension to the existing replication-checking logic.

Appendix: defining and motivating maps with unmapped inputs and outputs#

For concreteness, we’ll mostly focus on shmap, though these same ideas apply to e.g. pmap and probably xmap.

An argument/input is unmapped along a mesh axis when the corresponding entry of in_specs doesn’t mention that mesh axis’s name. Logically it means that each function instance along that mesh axis gets the same value for the argument. To the caller, each operand is sliced according to the mesh axes over which the operand is mapped, whereas there is no slicing for mesh axes over which the operand is unmapped.

An output is unmapped along a mesh axis when the corresponding entry of out_specs doesn’t mention that mesh axis’s name. Logically it means each function instance along that mesh axis must return the same value. To the caller, each result of the shmap is formed by concatenating the return values of every function instance along which the outputs are mapped, whereas for mesh axes over which the output is unmapped only one copy of the value is used.

See the shmap JEP for examples of unmapped inputs and outputs. For comparison, in vmap unmapped inputs/outputs are indicated by using in_axes / out_axes of None (rather than an int).
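For concreteness, here is a minimal runnable sketch, assuming the jax.experimental.shard_map implementation of shmap and a one-axis mesh named 'i'. The first argument below is unmapped along 'i' because its in_specs entry is P(), while the second is mapped; the vmap analogue uses an in_axes entry of None.

from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), ('i',))

@partial(shard_map, mesh=mesh, in_specs=(P(), P('i')), out_specs=P('i'))
def scale(w, x):
  # Every function instance along 'i' sees the same full `w` (unmapped),
  # but only its own shard of `x` (mapped).
  return w * x

x = jnp.arange(8.0 * len(jax.devices()))
out = scale(jnp.float32(2.0), x)

# The vmap analogue of an unmapped input: an in_axes entry of None.
out_vmap = jax.vmap(lambda w, xi: w * xi, in_axes=(None, 0))(2.0, x)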

Here are reasons we like unmapped inputs and outputs for shmap:

  • Same expressiveness as pjit. Anything pjit can do, the shmap escape hatch should be able to do too. Or else we’d have a lacking escape hatch! If we didn’t have unmapped outputs in shmap then we couldn’t express the same batch-parallel loss function computations as pjit.

  • Closed-over inputs. Closed-over inputs essentially correspond to unmapped inputs, and…

  • Closure under transposition. Once we have unmapped inputs, it’s natural to be able to transpose to unmapped outputs.

So unmapped outputs are both canonical and useful!

JEP 18137: Scope of JAX NumPy & SciPy Wrappers#

Jake VanderPlas

October 2023

Until now, the intended scope of jax.numpy and jax.scipy has been relatively ill-defined. This document proposes a well-defined scope for these packages to better guide and evaluate future contributions, and to motivate the removal of some out-of-scope code.

Background#

From the beginning, JAX has aimed to provide a NumPy-like API for executing code in XLA, and a big part of the project’s development has been building out the jax.numpy and jax.scipy namespaces as JAX-based implementations of NumPy and SciPy APIs. There has always been an implicit understanding that some parts of numpy and scipy are out-of-scope for JAX, but this scope has not been well defined. This can lead to confusion and frustration for contributors, because there’s no clear answer to whether potential jax.numpy and jax.scipy contributions will be accepted into JAX.

Why Limit the Scope?#

To avoid leaving this unsaid, we should be explicit: it is a fact that any code included in a project like JAX incurs a small but nonzero ongoing maintenance burden for the developers. The success of a project over time directly relates to the ability of maintainers to continue this maintenance for the sum of all the project’s parts: documenting functionality, responding to questions, fixing bugs, etc. For long-term success and sustainability of any software tool, it’s vital that maintainers carefully weigh whether any particular contribution will be a net positive for the project given its goals and resources.

Evaluation Rubric#

This document proposes a rubric of six axes along which the scope of any particular numpy or scipy API can be judged for inclusion into JAX. An API which is strong along all axes is an excellent candidate for inclusion in the JAX package; a strong weakness along any of the six axes is a good argument against inclusion in JAX.

Axis 1: XLA alignment#

The first axis we consider is the degree to which the proposed API aligns with native XLA operations. For example, jax.numpy.exp() is a function that more-or-less directly mirrors jax.lax.exp. A large number of functions in numpy, scipy.special, numpy.linalg, scipy.linalg, and others meet this criterion: such functions pass the XLA-alignment check when considering their inclusion into JAX.

On the other end, there are functions like numpy.unique(), which do not directly correspond to any XLA operation, and in some cases are fundamentally incompatible with JAX’s current computational model, which requires statically-shaped arrays (e.g. unique returns a value-dependent dynamic array shape). Such functions do not pass the XLA alignment check when considering their inclusion into JAX.

We also consider as part of this axis the need for pure function semantics. For example, numpy.random is built on an implicitly-updated state-based RNG, which is fundamentally incompatible with JAX’s computational model built on XLA.
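To make the contrast concrete, here is a small sketch, using only standard NumPy and jax.random APIs, of the hidden-state RNG versus JAX’s explicit-key approach:

import numpy as np
import jax

np.random.seed(0)
a = np.random.uniform()   # implicitly advances hidden global RNG state
b = np.random.uniform()   # different value because of that hidden update

key = jax.random.key(0)
k1, k2 = jax.random.split(key)     # state handled explicitly and functionally
x = jax.random.uniform(k1)
y = jax.random.uniform(k2)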

Axis 2: Array API Alignment#

The second axis we consider focuses on the Python Array API Standard: this is in some senses a community-driven outline of which array operations are central to array-oriented programming across a wide range of user communities. If an API in numpy or scipy is listed within the Array API standard, it is a strong signal that JAX should include it. Using the example from above, the Array API standard includes several variants of numpy.unique() (unique_all, unique_counts, unique_inverse, unique_values) which suggests that, despite the function not being precisely aligned with XLA, it is important enough to the Python user community that JAX should perhaps implement it.

Axis 3: Existence of Downstream Implementations#

For functionality that does not align with Axis 1 or 2, an important consideration for inclusion into JAX is whether there exist well-supported downstream packages that supply the functionality in question. A good example of this is scipy.optimize: while JAX does include a minimal set of wrappers of scipy.optimize functionality, a much more complete treatment exists in the JAXopt package, which is actively maintained by JAX collaborators. In cases like this, we should lean toward pointing users and contributors to these specialized packages rather than re-implementing such APIs in JAX itself.

Axis 4: Complexity & Robustness of Implementation#

For functionality that does not align with XLA, one consideration is the degree of complexity of the proposed implementation. This aligns to some degree with Axis 1, but nevertheless is important to call out. A number of functions have been contributed to JAX which have relatively complex implementations which are difficult to validate and introduce outsized maintenance burdens; an example is jax.scipy.special.bessel_jn(): as of the writing of this JEP, its current implementation is a non-straightforward iterative approximation that has convergence issues in some domains, and proposed fixes introduce further complexity. Had we more carefully weighed the complexity and robustness of the implementation when accepting the contribution, we might have chosen not to accept it into the package.

Axis 5: Functional vs. Object-Oriented APIs#

JAX works best with functional APIs rather than object-oriented APIs. Object-oriented APIs can often hide impure semantics, making them difficult to implement well. NumPy and SciPy generally stick to functional APIs, but sometimes provide object-oriented convenience wrappers.

An example of this is numpy.polynomial.Polynomial, which wraps lower-level operations like numpy.polyadd(), numpy.polydiv(), etc. In general, when both functional and object-oriented APIs are available, JAX should avoid providing wrappers for the object-oriented APIs and instead provide wrappers for the functional APIs.

In cases where only the object-oriented APIs exist, JAX should avoid providing wrappers unless the case is strong along other axes.

Axis 6: General “Importance” to JAX Users & Stakeholders#

The decision to include a NumPy/SciPy API in JAX should also take into account the importance of the algorithm to the general user community. It is admittedly difficult to quantify who is a “stakeholder” and how this importance should be measured; but we include this to make clear that any decision about what to include in JAX’s NumPy and SciPy wrappers will involve some amount of discretion that cannot be easily quantified.

For existing APIs, searching for usage on GitHub may be useful in establishing importance or lack thereof; as an example, we might return to jax.scipy.special.bessel_jn() discussed above: a search shows that this function has only a handful of uses on GitHub, probably partly due to the previously mentioned accuracy issues.

Evaluation: what’s in scope?#

In this section, we’ll attempt to evaluate the NumPy and SciPy APIs, including some examples from the current JAX API, in light of the above rubric. This will not be a comprehensive listing of all existing functions and classes, but rather a more general discussion by submodule and topic, with relevant examples.

NumPy APIs#
numpy namespace#

We consider the functions in the main numpy namespace to be essentially all in-scope for JAX, due to its general alignment with XLA (Axis 1) and the Python Array API (Axis 2), as well as its general importance to the JAX user community (Axis 6). Some functions are perhaps borderline (functions like numpy.intersect1d(), np.setdiff1d(), np.union1d() arguably fail parts of the rubric) but for simplicity we declare that all array functions in the main numpy namespace are in-scope for JAX.

numpy.linalg & numpy.fft#

The numpy.linalg and numpy.fft submodules contain many functions that broadly align with functionality provided by XLA. Others have complicated device-specific lowerings, but represent a case where importance to stakeholders (Axis 6) outweighs complexity. For this reason, we deem both of these submodules in-scope for JAX.

numpy.random#

numpy.random is out-of-scope for JAX, because state-based RNGs are fundamentally incompatible with JAX’s computation model. We instead focus on jax.random, which offers similar functionality using a counter-based PRNG.

numpy.ma & numpy.polynomial#

The numpy.ma and numpy.polynomial submodules are mostly concerned with providing object-oriented interfaces to computations that can be expressed via other functional means (Axis 5); for this reason, we deem them out-of-scope for JAX.

numpy.testing#

NumPy’s testing functionality only really makes sense for host-side computation, and so we don’t include any wrappers for it in JAX. That said, JAX arrays are compatible with numpy.testing, and JAX makes frequent use of it throughout the JAX test suite.
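For instance, a tiny sketch of that compatibility:

import numpy as np
import jax.numpy as jnp

# numpy.testing assertions accept JAX arrays directly.
np.testing.assert_allclose(jnp.arange(3.0), np.array([0.0, 1.0, 2.0]))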

SciPy APIs#

SciPy has no functions in the top-level namespace, but includes a number of submodules. We consider each below, leaving out modules which have been deprecated.

scipy.cluster#

The scipy.cluster module includes tools for hierarchical clustering, k-means, and related algorithms. These are weak along several axes, and would be better served by a downstream package. One function already exists within JAX (jax.scipy.cluster.vq.vq()) but has no obvious usage on github: this suggests that clustering is not broadly important to JAX users.

Recommendation: deprecate and remove jax.scipy.cluster.vq.vq().

scipy.constants#

The scipy.constants module includes mathematical and physical constants. These constants can be used directly with JAX, and so there is no reason to re-implement this in JAX.

scipy.datasets#

The scipy.datasets module includes tools to fetch and load datasets. These fetched datasets can be used directly with JAX, and so there is no reason to re-implement this in JAX.

scipy.fft#

The scipy.fft module contains functions that broadly align with functionality provided by XLA, and fare well along other axes as well. For this reason, we deem them in-scope for JAX.

scipy.integrate#

The scipy.integrate module contains functions for numerical integration. The more sophisticated of these (quad, dblquad, ode) are out-of-scope for JAX by axes 1 & 4, since they tend to be loopy algorithms based on dynamic numbers of evaluations. jax.experimental.ode.odeint() is related, but rather limited and not under any active development.

JAX does currently include jax.scipy.integrate.trapezoid(), but this was added only because numpy.trapz() was recently deprecated in favor of scipy.integrate.trapezoid(). For any particular input, its implementation could be replaced with one line of jax.numpy expressions, so it’s not a particularly useful API to provide.
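For example, a one-line jax.numpy version of the trapezoid rule for 1D inputs might look like this (a sketch, not the jax.scipy implementation):

import jax.numpy as jnp

def trapezoid_1d(y, x):
  # Sum of trapezoid areas: (y[i] + y[i+1]) / 2 * (x[i+1] - x[i])
  return jnp.sum((y[1:] + y[:-1]) * jnp.diff(x)) / 2.0

trapezoid_1d(jnp.array([0.0, 1.0, 4.0]), jnp.array([0.0, 1.0, 2.0]))  # 3.0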

Based on Axes 1, 2, 4, and 6, scipy.integrate should be considered out-of-scope for JAX.

Recommendation: remove jax.scipy.integrate.trapezoid(), which was added in JAX 0.4.14.

scipy.interpolate#

The scipy.interpolate module provides both low-level and object-oriented routines for interpolating in one or more dimensions. These APIs rate poorly along a number of the axes above: they are class-based rather than low-level, and none but the simplest methods can be expressed efficiently in terms of XLA operations.

JAX does currently have wrappers for scipy.interpolate.RegularGridInterpolator. Were we considering this contribution today, we would probably reject it by the above criteria. But this code has been fairly stable so there is not much downside to continuing to maintain it.

Going forward, we should consider other members of scipy.interpolate to be out-of-scope for JAX.

scipy.io#

The scipy.io submodule has to do with file input/output. There is no reason to re-implement this in JAX.

scipy.linalg#

The scipy.linalg submodule contains functions that broadly align with functionality provided by XLA, and fast linear algebra is broadly important to the JAX user community. For this reason, we deem it in-scope for JAX.

scipy.ndimage#

The scipy.ndimage submodule contains a set of tools for working on image data. Many of these overlap with tools in scipy.signal (e.g. convolutions and filtering). JAX currently provides one scipy.ndimage API, in jax.scipy.ndimage.map_coordinates(). Additionally, JAX provides some image-related tools in the jax.image module. The deepmind ecosystem includes dm-pix, a more full-featured set of tools for image manipulation in JAX. Given all these factors, I’d suggest that scipy.ndimage should be considered out-of-scope for JAX core; we can point interested users and contributors to dm-pix. We can consider moving map_coordinates to dm-pix or to another appropriate package.

scipy.odr#

The scipy.odr module provides an object-oriented wrapper around ODRPACK for performing orthogonal distance regressions. It is not clear that this could be cleanly expressed using existing JAX primitives, and so we deem it out of scope for JAX itself.

scipy.optimize#

The scipy.optimize module provides high-level and low-level interfaces for optimization. Such functionality is important to a lot of JAX users, and very early on JAX created jax.scipy.optimize wrappers. However, developers of these routines soon realized that the scipy.optimize API was too constraining, and different teams began working on the JAXopt package and the Optimistix package, each of which contains a much more comprehensive and better-tested set of optimization routines in JAX.

Because of these well-supported external packages, we now consider scipy.optimize to be out-of-scope for JAX.

Recommendation: deprecate jax.scipy.optimize and/or make it a lightweight wrapper around an optional JAXopt or Optimistix dependency.

🟡 scipy.signal#

The scipy.signal module is mixed: some functions are squarely in-scope for JAX (e.g. correlate and convolve, which are more user-friendly wrappers of lax.conv_general_dilated), while many others are squarely out-of-scope (domain-specific tools with no viable lowering path to XLA). Potential contributions to jax.scipy.signal will have to be weighed on a case-by-case basis.

🟡 scipy.sparse#

The scipy.sparse submodule mainly contains data structures for storing and operating on sparse matrices and arrays in a variety of formats. Additionally, scipy.sparse.linalg contains a number of matrix-free solvers, suitable for use with sparse matrices, dense matrices, and linear operators.

The scipy.sparse array and matrix data structures are out-of-scope for JAX, because they do not align with JAX’s computational model (e.g. many operations depend on dynamically-sized buffers). JAX has developed the jax.experimental.sparse module as an alternative set of data structures that are more in-line with JAX’s computational constraints. For these reasons, we consider the data structures in scipy.sparse to be out-of-scope for JAX.

On the other hand, scipy.sparse.linalg has proven to be an interesting area, and jax.scipy.sparse.linalg includes the bicgstab, cg, and gmres solvers. These are useful to the JAX user community (Axis 6) but aside from this do not fare well along other axes. They would be very suitable for moving into a downstream library; one potential option may be Lineax, which features a number of linear solvers built on JAX.

Recommendation: explore moving sparse solvers into Lineax, and otherwise treat scipy.sparse as out-of-scope for JAX.

scipy.spatial#

The scipy.spatial module contains mainly object-oriented interfaces to spatial/distance computations and nearest neighbor searches. It is mostly out-of-scope for JAX.

The scipy.spatial.transform submodule provides tools for manipulating three-dimensional spatial rotations. It is a relatively complicated object-oriented interface, and could perhaps be better served by a downstream project. JAX currently contains partial implementations of Rotation and Slerp within jax.scipy.spatial.transform; these are object-oriented wrappers of otherwise basic functions, which introduce a very large API surface and have very few users. It is our judgment that they are out-of-scope for JAX itself, with users better-served by a hypothetical downstream project.

The scipy.spatial.distance submodule contains a useful collection of distance metrics, and it might be tempting to provide JAX wrappers for these. That said, with jit and vmap it would be straightforward for a user to define efficient versions of most of these from scratch if needed, so adding them to JAX is not particularly beneficial.
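For instance, a user-defined pairwise Euclidean distance needs only a few lines (a minimal sketch; cdist_like is an illustrative name, not a JAX API):

import jax
import jax.numpy as jnp

@jax.jit
def cdist_like(xs, ys):
  # Nested vmap builds the full (len(xs), len(ys)) distance matrix.
  dist = lambda a, b: jnp.sqrt(jnp.sum((a - b) ** 2))
  return jax.vmap(lambda a: jax.vmap(lambda b: dist(a, b))(ys))(xs)

d = cdist_like(jnp.ones((5, 3)), jnp.zeros((7, 3)))   # shape (5, 7)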

Recommendation: consider deprecating and removing the Rotation and Slerp APIs, and consider scipy.spatial as a whole out-of-scope for future contributions.

scipy.special#

The scipy.special module includes implementations of a number of more specialized functions. In many cases, these functions are squarely in scope: for example, functions like gammaln, betainc, digamma, and many others correspond directly to available XLA primitives, and are clearly in-scope by Axis 1 and others.

Other functions require more complicated implementations; one example mentioned above is bessel_jn. Despite not aligning on Axes 1 and 2, these functions tend to be very strong along Axis 6: scipy.special provides fundamental functions necessary for computation in a variety of domains, so even functions with complicated implementations should lean toward in-scope, so long as the implementations are well-designed and robust.

There are a few existing function wrappers that we should take a closer look at; for example:

  • jax.scipy.special.lpmn(): this generates Legendre polynomials via a complicated fori_loop, in a way that does not match the scipy API (e.g. for scipy, z must be a scalar, while for JAX, z must be a 1D array). The function has few discoverable uses, making it a weak candidate along Axes 1, 2, 4, and 6.

  • jax.scipy.special.lpmn_values(): this has similar weaknesses to lpmn above.

  • jax.scipy.special.sph_harm(): this is built on lpmn, and similarly has an API that diverges from the corresponding scipy function.

  • jax.scipy.special.bessel_jn(): as discussed under Axis 4 above, this has weaknesses in terms of robustness of implementation and little usage. We might consider replacing it with a new, more robust implementation (e.g. #17038).

Recommendation: refactor and improve robustness & test coverage for bessel_jn. Consider deprecating lpmn, lpmn_values, and sph_harm if they cannot be modified to more closely match the scipy APIs.

scipy.stats#

The scipy.stats module contains a wide range of statistical functions, including discrete and continuous distributions, summary statistics, and hypothesis testing. JAX currently wraps a number of these in jax.scipy.stats, primarily including 20 or so statistical distributions, along with a few other functions (mode, rankdata, gaussian_kde). In general these are well-aligned with JAX: distributions usually are expressible in terms of efficient XLA operations, and the APIs are clean and functional.

We don’t currently have any wrappers for hypothesis testing functions, probably because these are less useful to the primary user-base of JAX.

Regarding distributions, in some cases, tensorflow_probability provides similar functionality, and in the future we might consider whether to deprecate the scipy.stats distributions in favor of that implementation.

Recommendation: going forward, we should treat statistical distributions and summary statistics as in-scope, and consider hypothesis tests and related functionality generally out-of-scope.


Investigating a regression#

So you updated JAX and you hit a speed regression? You have a little bit of time and are ready to investigate this? Let’s first file a JAX issue. And if you can pinpoint the commit that triggered the regression, it will really help us.

This document explains how we identified the commit that caused a 15% performance regression.

Steps#

This can be done easily if the reproducer is quick enough. The approach below is a brute-force search rather than a bisection, but if the reproducer is fast it works well. It makes sure that you always test XLA and JAX commits that are compatible, and it also limits XLA recompilation.

Here is a suggested investigation strategy:

  1. Do a brute-force test of the nightly containers between the two releases.

  2. Do an hourly investigation between the identified days, keeping XLA and JAX in sync.

  3. Final verification: maybe a manual check of a few commits (or a git bisect).

Nightly investigation.#

This can be done by using JAX-Toolbox nightly containers.

  • Some days, bugs prevent the container from being built, or there are temporary regressions. Just discard those days.

  • So you should end up with a specific day or a few days where the regression happens.

  • To automate this, you need 2 shell scripts:

    • test_runner.sh: starts the containers and the test.

    • test.sh: installs missing dependencies and runs the test.

Here are real example scripts used for the issue: https://github.com/google/jax/issues/17686

  • test_runner.sh:

  for m in 7 8 9; do
    for d in `seq -w 1 30`; do
      docker run -v $PWD:/dir --gpus=all ghcr.io/nvidia/jax:nightly-2023-0${m}-${d} /bin/bash /dir/test.sh &> OUT-0${m}-${d}
    done
  done
  • test.sh:

  pip install jmp pyvista numpy matplotlib Rtree trimesh termcolor orbax
  git clone https://github.com/Autodesk/XLB
  cd XLB
  export PYTHONPATH=.
  export CUDA_VISIBLE_DEVICES=0 # only 1 GPU is needed

  python3 examples/performance/MLUPS3d.py 256 200

Then you can grep each output to see when the regression happens: grep MLUPS OUT*. Here are the results we got:

OUT-07-06:MLUPS: 587.9240990200157
OUT-07-07:MLUPS: 587.8907972116419
OUT-07-08:MLUPS: 587.3186499464459
OUT-07-09:MLUPS: 587.3130127722537
OUT-07-10:MLUPS: 587.8526619429658
OUT-07-17:MLUPS: 570.1631097290182
OUT-07-18:MLUPS: 570.2819775617064
OUT-07-19:MLUPS: 570.1672213357352
OUT-07-20:MLUPS: 587.437153685251
OUT-07-21:MLUPS: 587.6702557143142
OUT-07-25:MLUPS: 577.3063618431178
OUT-07-26:MLUPS: 577.2362978080912
OUT-07-27:MLUPS: 577.2101850145785
OUT-07-28:MLUPS: 577.0716349809895
OUT-07-29:MLUPS: 577.4223280707176
OUT-07-30:MLUPS: 577.2255967221336
OUT-08-01:MLUPS: 577.277685388252
OUT-08-02:MLUPS: 577.0137874289354
OUT-08-03:MLUPS: 577.1333281553946
OUT-08-04:MLUPS: 577.305012020407
OUT-08-05:MLUPS: 577.2143988866626
OUT-08-06:MLUPS: 577.2409145495443
OUT-08-07:MLUPS: 577.2602819927345
OUT-08-08:MLUPS: 577.2823738293221
OUT-08-09:MLUPS: 577.3453199728248
OUT-08-11:MLUPS: 577.3161423260563
OUT-08-12:MLUPS: 577.1697775786824
OUT-08-13:MLUPS: 577.3049883393633
OUT-08-14:MLUPS: 576.9051978525331
OUT-08-15:MLUPS: 577.5331743016213
OUT-08-16:MLUPS: 577.5117505070573
OUT-08-18:MLUPS: 577.5930698237612
OUT-08-19:MLUPS: 577.3539885757353
OUT-08-20:MLUPS: 577.4190113959127
OUT-08-21:MLUPS: 577.300394253605
OUT-08-22:MLUPS: 577.4263792037783
OUT-08-23:MLUPS: 577.4087536357031
OUT-08-24:MLUPS: 577.1094728438082
OUT-08-25:  File "/XLB/examples/performance/MLUPS3d.py", line 5, in <module>
OUT-08-26:MLUPS: 537.0164618489928
OUT-08-27:MLUPS: 536.9545448661609
OUT-08-28:MLUPS: 536.2887650464874
OUT-08-29:MLUPS: 536.7178471720636
OUT-08-30:MLUPS: 536.6978912984252
OUT-09-01:MLUPS: 536.7030899164106
OUT-09-04:MLUPS: 536.5339818238837
OUT-09-05:MLUPS: 536.6507808565617
OUT-09-06:MLUPS: 536.7144494518315
OUT-09-08:MLUPS: 536.7376612408998
OUT-09-09:MLUPS: 536.7798324141778
OUT-09-10:MLUPS: 536.726157440174
OUT-09-11:MLUPS: 536.7446210750584
OUT-09-12:MLUPS: 536.6707332269023
OUT-09-13:MLUPS: 536.6777936517823
OUT-09-14:MLUPS: 536.7581523280307
OUT-09-15:MLUPS: 536.6156273667873
OUT-09-16:MLUPS: 536.7320935035265
OUT-09-17:MLUPS: 536.7104991444398
OUT-09-18:MLUPS: 536.7492269469092
OUT-09-19:MLUPS: 536.6760131792959
OUT-09-20:MLUPS: 536.7361260076634

This found that 8-24 was good, but 8-26 was bad. On 8-25 there was another issue that prevented us from getting results. So we need to investigate hourly between 8-24 and 8-26. There was a smaller slowdown earlier; let’s ignore it for this example, as it would only be another hourly investigation between those dates.

Hourly investigation.#

This does a checkout of JAX and XLA at each hour between the two dates, rebuilds everything, and runs the test. The scripts are structured differently: we start the working container and keep it, and then inside it we trigger only incremental XLA builds (except for the first build), so it is much faster after the first iteration.

  • test_runner2.sh:

  # Execute this script inside the container:
  # docker run -v $PWD:/dir --gpus=all ghcr.io/nvidia/jax:nightly-2023-08-24 /bin/bash
  cd /opt/xla-source
  git remote update
  cd /opt/jax-source
  git remote update
  pip install jmp pyvista numpy matplotlib Rtree trimesh termcolor orbax
  cd /tmp
  git clone https://github.com/Autodesk/XLB
  cd XLB

  for d in `seq -w 24 26`; do
      for h in `seq -w 0 23`; do
          echo $d $h
          /bin/bash /dir/test2.sh Aug $d 2023 $h:00:00 &> OUT-08-${d}-$h
      done
  done
  • test2.sh:

  echo "param: $@"
  cd /opt/xla-source
  git checkout `git rev-list -1 --before="$*" origin/main`
  git show -q
  cd /opt/jax-source
  git checkout `git rev-list -1 --before="$*" origin/main`
  git show -q

  rm /opt/jax-source/dist/jax*.whl
  build-jax.sh # The script is in the nightly container

  export PYTHONPATH=.
  export CUDA_VISIBLE_DEVICES=0 # only 1 GPU is needed

  python3 examples/performance/MLUPS3d.py 256 200

Now you can run the same grep command on the new output files to see between which hours the issue appeared.

Final verification#

With this, you need to check the JAX and XLA history between those hours. Maybe there are a few commits to test. You can use git bisect if you want to be fancy.

Can this be improved?#

Yes! If it was a crash regression, being able to do a bisect would be useful. But it would be more complicated. If someone wants to contribute such instructions, please submit a PR ;)

For speed regressions, a bisect can hide some information. We wouldn’t see as easily that there were two regressions here.

Building on JAX#

A great way to learn advanced JAX usage is to see how other libraries use JAX: how they integrate it into their APIs, what functionality it adds mathematically, and how they use it for computational speedup.

Below are examples of how JAX’s features can be used to define accelerated computation across numerous domains and software packages.

Gradient Computation#

Easy gradient calculation is a key feature of JAX. In the JAXopt library, jax.value_and_grad() is used directly in multiple optimization algorithms in its source code.

Similarly, the Dynamax and Optax pairing mentioned below is an example of gradients enabling estimation methods that were historically challenging, such as Maximum Likelihood Estimation using Optax.
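As a minimal sketch of that gradient machinery (the loss and parameters here are illustrative, not taken from JAXopt or Dynamax):

import jax
import jax.numpy as jnp

def loss(w, x, y):
  return jnp.mean((x @ w - y) ** 2)   # toy least-squares loss

w = jnp.zeros(3)
x = jnp.ones((10, 3))
y = jnp.ones(10)

loss_val, grads = jax.value_and_grad(loss)(w, x, y)
w = w - 0.1 * grads                   # one gradient-descent step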

Computational Speedup on a Single Core across Multiple Devices#

Models defined in JAX can be compiled for single-computation speedup through JIT compilation. The same compiled code can then be sent to a CPU device, or to a GPU or TPU device, for additional speedup, typically with no additional changes needed. This allows for a smooth workflow from development into production. In Dynamax the computationally expensive portion of a Linear State Space Model solver has been jitted. A more complex example comes from PyTensor, which compiles a JAX function dynamically and then jits the constructed function.
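A minimal sketch of that pattern (the model here is a toy, not Dynamax or PyTensor code):

import jax
import jax.numpy as jnp

@jax.jit
def predict(params, x):
  w, b = params
  return jnp.tanh(x @ w + b)

params = (jnp.ones((3, 2)), jnp.zeros(2))
y = predict(params, jnp.ones((4, 3)))   # compiled on first call, cached afterwards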

Single and Multi Computer Speedup Using Parallelization#

Another benefit of JAX is the simplicity of parallelizing computation using pmap and vmap function calls or decorators. In Dynamax, state space models are parallelized with a vmap decorator; a practical example of this use case is multi-object tracking.
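As a small sketch of the vmap pattern (track_one is an illustrative per-object function, not Dynamax code):

import jax
import jax.numpy as jnp

def track_one(z):
  return jnp.cumsum(z)                 # toy per-object computation

batched = jax.vmap(track_one)          # maps over the leading (object) axis
out = batched(jnp.ones((16, 100)))     # 16 objects, 100 timesteps each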

Incorporating JAX code into your, or your users’, workflows#

JAX is quite composable and can be used in multiple ways. JAX can be used with a standalone pattern, where the user defines all the calculations themselves. However, other patterns are also common, such as using libraries built on JAX that provide specific functionality. These can be libraries that define specific types of models, such as neural networks or state space models, or libraries that provide specific functionality such as optimization. Here are more specific examples of each pattern.

Direct Usage#

JAX can be directly imported and utilized to build models “from scratch”, as shown across this website, for example in JAX Tutorials or Neural Network with JAX. This may be the best option if you are unable to find prebuilt code for your particular challenge, or if you’re looking to reduce the number of dependencies in your codebase.

Composable Domain Specific Libraries with JAX exposed#

Another common approach is to use packages that provide prebuilt functionality, whether it be model definition or computation of some type. Combinations of these packages can then be mixed and matched for a full end-to-end workflow where a model is defined and its parameters are estimated.

One example is Flax which simplifies the construction of Neural Networks. Flax is then typically paired with Optax where Flax defines the neural network architecture and Optax supplies the optimization & model-fitting capabilities.

Another is Dynamax, which allows easy definition of state space models. With Dynamax, parameters can be estimated using maximum likelihood with Optax, or a full Bayesian posterior can be estimated using MCMC from Blackjax.

JAX Totally Hidden from Users#

Other libraries opt to completely wrap JAX in their model-specific API. An example is PyMC and PyTensor, in which a user may never “see” JAX directly; instead, JAX functions are wrapped with a PyMC-specific API.

Notes#

This section contains shorter notes on topics relevant to using JAX; see also the longer design discussions in JAX Enhancement Proposals (JEPs).

  • Dependencies and version compatibility: API compatibility; Python and NumPy version support policy

  • Migrations and deprecations: jax.Array migration

  • Memory and computation usage: Asynchronous dispatch; Concurrency; GPU memory allocation

  • Programmer guardrails: Rank promotion warning

API compatibility#

JAX is constantly evolving, and we want to be able to make improvements to its APIs. That said, we want to minimize churn for the JAX user community, and we try to make breaking changes rarely.

JAX follows a 3 month deprecation policy. When an incompatible change is made to an API, we will make our best effort to obey the following procedure:

  • the change will be announced in CHANGELOG.md and in the doc string for the deprecated API, and the old API will issue a DeprecationWarning.

  • three months after the jax release that deprecated an API, we may remove the deprecated API at any time. Note that three months is a lower bound, and is intentionally chosen to be faster than that of many more mature projects. In practice, deprecations may take considerably longer, particularly if there are many users of a feature. If a three month deprecation period becomes problematic, please raise this with us.

We reserve the right to change this policy at any time.

What is covered?#

Only public JAX APIs are covered, which includes the following modules:

  • jax

  • jax.dlpack

  • jax.image

  • jax.lax

  • jax.nn

  • jax.numpy

  • jax.ops

  • jax.profiler

  • jax.random (see details below)

  • jax.scipy

  • jax.tree_util

  • jax.test_util

Not everything in these modules is public. Over time, we are working to separate public and private APIs. Public APIs are documented in the JAX documentation. Additionally, our goal is that all non-public APIs should have names prefixed with underscores, although we do not entirely comply with this yet.

What is not covered?#

  • anything prefixed with an underscore.

  • jax._src

  • jax.core

  • jax.linear_util

  • jax.lib

  • jax.prng

  • jax.interpreters

  • jax.experimental

  • jax.example_libraries

  • jax.extend (see details)

This list is not exhaustive.

Numerics and randomness#

The exact values of numerical operations are not guaranteed to be stable across JAX releases. In fact, exact numerics are not necessarily stable at a given JAX version, across accelerator platforms, within or without jax.jit, and more.

For a fixed PRNG key input, the outputs of pseudorandom functions in jax.random may vary across JAX versions. The compatibility policy applies only to the output distribution. For example, the expression jax.random.gumbel(jax.random.key(72)) may return a different value across JAX releases, but jax.random.gumbel will remain a pseudorandom generator for the Gumbel distribution.

We try to make such changes to pseudorandom values infrequently. When they happen, the changes are announced in the changelog, but do not follow a deprecation cycle. In some situations, JAX might expose a transient configuration flag that reverts the new behavior, to help users diagnose and update affected code. Such flags will last a deprecation window’s amount of time.

Python and NumPy version support policy#

For NumPy and SciPy version support, JAX follows the Python scientific community’s SPEC 0.

For Python version support, we have heard from users that a 36-month support window can be too short, for example due to the delays in propagation of new CPython releases to Linux vendor releases. For this reason JAX supports Python versions for at least nine months longer than SPEC-0 recommends.

This means we support at least:

  • All minor Python releases in the 45 months prior to each JAX release. For example:

    • Python 3.9 was released October 2020, and will be supported in new JAX releases at least until July 2024.

    • Python 3.10 was released October 2021, and will be supported in new JAX releases at least until July 2025.

    • Python 3.11 was released October 2022, and will be supported in new JAX releases at least until July 2026.

  • All minor NumPy releases in the 24 months prior to each JAX release. For example:

    • NumPy 1.22 was released December 2021, and will be supported in new JAX releases at least until December 2023.

    • NumPy 1.23 was released June 2022, and will be supported in new JAX releases at least until June 2024.

    • NumPy 1.24 was released December 2022, and will be supported in new JAX releases at least until December 2024.

  • All minor SciPy releases in the 24 months prior to each JAX release, starting with SciPy version 1.9. For example:

    • Scipy 1.9 was released July 2022, and will be supported in new JAX releases at least until July 2024.

    • Scipy 1.10 was released January 2023, and will be supported in new JAX releases at least until January 2025.

    • Scipy 1.11 was released June 2023, and will be supported in new JAX releases at least until June 2025.

JAX releases may support older versions of Python, NumPy, and SciPy than strictly required by this policy, but support for older versions may be dropped at any time beyond the listed dates.

jax.Array migration#

yashkatariya@

TL;DR#

JAX switched its default array implementation to the new jax.Array as of version 0.4.1. This guide explains the reasoning behind this, the impact it might have on your code, and how to (temporarily) switch back to the old behavior.

What’s going on?#

jax.Array is a unified array type that subsumes DeviceArray, ShardedDeviceArray, and GlobalDeviceArray types in JAX. The jax.Array type helps make parallelism a core feature of JAX, simplifies and unifies JAX internals, and allows us to unify jit and pjit. If your code doesn’t mention DeviceArray vs ShardedDeviceArray vs GlobalDeviceArray, no changes are needed. But code that depends on details of these separate classes may need to be tweaked to work with the unified jax.Array.

After the migration is complete jax.Array will be the only type of array in JAX.

This doc explains how to migrate existing codebases to jax.Array. For more information on using jax.Array and JAX parallelism APIs, see the Distributed arrays and automatic parallelization tutorial.

How to enable jax.Array?#

You can enable jax.Array by:

  • setting the shell environment variable JAX_ARRAY to something true-like (e.g., 1);

  • setting the boolean flag jax_array to something true-like if your code parses flags with absl;

  • using this statement at the top of your main file:

    import jax
    jax.config.update('jax_array', True)
    
How do I know if jax.Array broke my code?#

The easiest way to tell if jax.Array is responsible for any problems is to disable jax.Array and see if the issues go away.

How can I disable jax.Array for now?#

Through March 15, 2023 it will be possible to disable jax.Array by:

  • setting the shell environment variable JAX_ARRAY to something falsey (e.g., 0);

  • setting the boolean flag jax_array to something falsey if your code parses flags with absl;

  • using this statement at the top of your main file:

    import jax
    jax.config.update('jax_array', False)
    

Why create jax.Array?#

Currently JAX has three array types: DeviceArray, ShardedDeviceArray, and GlobalDeviceArray. jax.Array merges these three types and cleans up JAX’s internals while adding new parallelism features.

We also introduce a new Sharding abstraction that describes how a logical Array is physically sharded out across one or more devices, such as TPUs or GPUs. The change also upgrades, simplifies and merges the parallelism features of pjit into jit. Functions decorated with jit will be able to operate over sharded arrays without copying data onto a single device.

Features you get with jax.Array:

  • C++ pjit dispatch path

  • Op-by-op parallelism (even if the array is distributed across multiple devices on multiple hosts)

  • Simpler batch data parallelism with pjit/jit.

  • Ways to create Shardings that do not necessarily consist of a mesh and partition spec. You can fully utilize the flexibility of OpSharding, or any other Sharding that you want.

  • and many more

Example:

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
import numpy as np
x = jnp.arange(8)

# Let's say there are 8 devices in jax.devices()
mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x'))

sharded_x = jax.device_put(x, sharding)

# `matmul_sharded_x` and `sin_sharded_x` are sharded. `jit` is able to operate over a
# sharded array without copying data to a single device.
matmul_sharded_x = sharded_x @ sharded_x.T
sin_sharded_x = jnp.sin(sharded_x)

# Even jnp.copy preserves the sharding on the output.
copy_sharded_x = jnp.copy(sharded_x)

# double_out is also sharded
double_out = jax.jit(lambda x: x * 2)(sharded_x)

What issues can arise when jax.Array is switched on?#

New public type named jax.Array#

All isinstance(..., jnp.DeviceArray) or isinstance(..., jax.xla.DeviceArray) checks, and other variants of DeviceArray, should be switched to using isinstance(..., jax.Array).

Since jax.Array can represent DA, SDA and GDA, you can differentiate those 3 types in jax.Array via:

  • x.is_fully_addressable and len(x.sharding.device_set) == 1 – this means that jax.Array is like a DA

  • x.is_fully_addressable and len(x.sharding.device_set) > 1 – this means that jax.Array is like a SDA

  • not x.is_fully_addressable – this means that jax.Array is like a GDA and spans across multiple processes

For ShardedDeviceArray, you can move isinstance(..., pxla.ShardedDeviceArray) to isinstance(..., jax.Array) and x.is_fully_addressable and len(x.sharding.device_set) > 1.

In general it is not possible to differentiate a ShardedDeviceArray on 1 device from any other kind of single-device Array.
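Putting the checks above together, a hypothetical helper might look like this (kind_of_array is an illustrative name, not a JAX API):

import jax

def kind_of_array(x: jax.Array) -> str:
  if not x.is_fully_addressable:
    return "GDA-like"   # spans multiple processes
  if len(x.sharding.device_set) > 1:
    return "SDA-like"   # sharded across multiple local devices
  return "DA-like"      # committed to a single device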

GDA’s API name changes#

GDA’s local_shards and local_data have been deprecated.

Please use addressable_shards and addressable_data which are compatible with jax.Array and GDA.

Creating jax.Array#

All JAX functions will output jax.Array when the jax_array flag is True. If you were using GlobalDeviceArray.from_callback or make_sharded_device_array or make_device_array functions to explicitly create the respective JAX data types, you will need to switch them to use jax.make_array_from_callback() or jax.make_array_from_single_device_arrays().

For GDA:

GlobalDeviceArray.from_callback(shape, mesh, pspec, callback) can become jax.make_array_from_callback(shape, jax.sharding.NamedSharding(mesh, pspec), callback) in a 1:1 switch.
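A small self-contained sketch of that 1:1 switch (the mesh, shapes, and host data here are illustrative):

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('x',))
sharding = NamedSharding(mesh, P('x'))

global_shape = (8 * len(jax.devices()), 2)
host_data = np.arange(np.prod(global_shape)).reshape(global_shape)

# The callback receives each shard's index (a tuple of slices) and returns the
# corresponding data; here we just slice a host-side NumPy array.
arr = jax.make_array_from_callback(
    global_shape, sharding, lambda idx: host_data[idx])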

If you were using the raw GDA constructor to create GDAs, then do this:

GlobalDeviceArray(shape, mesh, pspec, buffers) can become jax.make_array_from_single_device_arrays(shape, jax.sharding.NamedSharding(mesh, pspec), buffers)

For SDA:

make_sharded_device_array(aval, sharding_spec, device_buffers, indices) can become jax.make_array_from_single_device_arrays(shape, sharding, device_buffers).

To decide what the sharding should be, it depends on why you were creating the SDAs:

If it was created to give as an input to pmap, then sharding can be: jax.sharding.PmapSharding(devices, sharding_spec).

If it was created to give as an input to pjit, then sharding can be jax.sharding.NamedSharding(mesh, pspec).

Breaking change for pjit after switching to jax.Array for host local inputs#

If you are exclusively using GDA arguments to pjit, you can skip this section! 🎉

With jax.Array enabled, all inputs to pjit must be globally shaped. This is a breaking change from the previous behavior where pjit would concatenate process-local arguments into a global value; this concatenation no longer occurs.

Why are we making this breaking change? Each array now says explicitly how its local shards fit into a global whole, rather than leaving it implicit. The more explicit representation also unlocks additional flexibility, for example the use of non-contiguous meshes with pjit which can improve efficiency on some TPU models.

Running multi-process pjit computation and passing host-local inputs when jax.Array is enabled can lead to an error similar to this:

Example:

Mesh = {'x': 2, 'y': 2, 'chips': 2} and host local input shape == (4,) and pspec = P(('x', 'y', 'chips'))

Since pjit doesn’t lift host local shapes to global shapes with jax.Array, you get the following error:

Note: You will only see this error if your host local shape is smaller than the shape of the mesh.

ValueError: One of pjit arguments was given the sharding of
NamedSharding(mesh={'x': 2, 'y': 2, 'chips': 2}, partition_spec=PartitionSpec(('x', 'y', 'chips'),)),
which implies that the global size of its dimension 0 should be divisible by 8,
but it is equal to 4

The error makes sense because you can’t shard dimension 0, 8 ways when the value on dimension 0 is 4.

How can you migrate if you still pass host local inputs to pjit? We are providing transitional APIs to help you migrate:

Note: You don’t need these utilities if you run your pjitted computation on a single process.

from jax.experimental import multihost_utils

global_inps = multihost_utils.host_local_array_to_global_array(
    local_inputs, mesh, in_pspecs)

global_outputs = pjit(f, in_shardings=in_pspecs,
                      out_shardings=out_pspecs)(global_inps)

local_outs = multihost_utils.global_array_to_host_local_array(
    global_outputs, mesh, out_pspecs)

host_local_array_to_global_array is a type cast that looks at a value with only local shards and changes its local shape to the shape that pjit would have previously assumed if that value was passed before the change.

Passing in fully replicated inputs, i.e. the same shape on each process, with P(None) as in_axis_resources is still supported. In this case you do not have to use host_local_array_to_global_array because the shape is already global.

key = jax.random.PRNGKey(1)

# As you can see, using host_local_array_to_global_array is not required since in_axis_resources says
# that the input is fully replicated via P(None)
pjit(f, in_shardings=None, out_shardings=None)(key)

# Mixing inputs
global_inp = multihost_utils.host_local_array_to_global_array(
    local_inp, mesh, P('data'))
global_out = pjit(f, in_shardings=(P(None), P('data')),
                  out_shardings=...)(key, global_inp)
FROM_GDA and jax.Array#

If you were using FROM_GDA in the in_axis_resources argument to pjit, then with jax.Array there is no need to pass anything to in_axis_resources, as jax.Array will follow “computation follows sharding” semantics.

For example:

pjit(f, in_shardings=FROM_GDA, out_shardings=...) can be replaced by pjit(f, out_shardings=...)

If you have PartitionSpecs mixed in with FROM_GDA for inputs like numpy arrays, etc, then use host_local_array_to_global_array to convert them to jax.Array.

For example:

If you had this:

pjitted_f = pjit(
    f, in_shardings=(FROM_GDA, P('x'), FROM_GDA, P(None)),
    out_shardings=...)
pjitted_f(gda1, np_array1, gda2, np_array2)

then you can replace it with:


pjitted_f = pjit(f, out_shardings=...)

array2, array4 = multihost_utils.host_local_array_to_global_array(
    (np_array1, np_array2), mesh, (P('x'), P(None)))

pjitted_f(gda1, array2, gda2, array4)
live_buffers replaced with live_arrays#

live_buffers attribute on jax Device has been deprecated. Please use jax.live_arrays() instead which is compatible with jax.Array.

Handling of host local inputs to pjit like batch, etc#

If you are passing host local inputs to pjit in a multi-process environment, then please use multihost_utils.host_local_array_to_global_array to convert the batch to a global jax.Array and then pass that to pjit.

The most common example of such a host local input is a batch of input data.

This will work for any host local input (not just a batch of input data).

from jax.experimental import multihost_utils

batch = multihost_utils.host_local_array_to_global_array(
    batch, mesh, batch_partition_spec)

See the pjit section above for more details about this change and more examples.

RecursionError: Recursively calling jit#

This happens when some part of your code has jax.Array disabled and you enable it only for some other part. For example, if you use some third-party code which has jax.Array disabled and you get a DeviceArray from that library, and then you enable jax.Array in your library and pass that DeviceArray to JAX functions, it will lead to a RecursionError.

This error should go away when jax.Array is enabled by default so that all libraries return jax.Array unless they explicitly disable it.

Asynchronous dispatch#

JAX uses asynchronous dispatch to hide Python overheads. Consider the following program:

>>> import numpy as np
>>> import jax.numpy as jnp
>>> from jax import random
>>> x = random.uniform(random.key(0), (1000, 1000))
>>> # Printing the result (i.e. evaluating `repr(result)` or `str(result)`)
>>> # will block until the value is ready.
>>> jnp.dot(x, x) + 3.  
Array([[258.01971436, 249.64862061, 257.13372803, ...,
        236.67948914, 250.68939209, 241.36853027],
        [265.65979004, 256.28912354, 262.18252563, ...,
        242.03181458, 256.16757202, 252.44122314],
        [262.38916016, 255.72747803, 261.23059082, ...,
        240.83563232, 255.41094971, 249.62471008],
        ...,
        [259.15814209, 253.09197998, 257.72174072, ...,
        242.23876953, 250.72680664, 247.16642761],
        [271.22662354, 261.91204834, 265.33398438, ...,
        248.26651001, 262.05389404, 261.33700562],
        [257.16134644, 254.7543335, 259.08300781, ..., 241.59848022,
        248.62597656, 243.22348022]], dtype=float32)

When an operation such as jnp.dot(x, x) is executed, JAX does not wait for the operation to complete before returning control to the Python program. Instead, JAX returns a jax.Array value, which is a future, i.e., a value that will be produced in the future on an accelerator device but isn’t necessarily available immediately. We can inspect the shape or type of a jax.Array without waiting for the computation that produced it to complete, and we can even pass it to another JAX computation, as we do with the addition operation here. Only if we actually inspect the value of the array from the host, for example by printing it or by converting it into a plain old numpy.ndarray, will JAX force the Python code to wait for the computation to complete.

Asynchronous dispatch is useful since it allows Python code to “run ahead” of an accelerator device, keeping Python code out of the critical path. Provided the Python code enqueues work on the device faster than it can be executed, and provided that the Python code does not actually need to inspect the output of a computation on the host, then a Python program can enqueue arbitrary amounts of work and avoid having the accelerator wait.

Asynchronous dispatch has a slightly surprising consequence for microbenchmarks.

>>> %time jnp.dot(x, x)  
CPU times: user 267 µs, sys: 93 µs, total: 360 µs
Wall time: 269 µs
Array([[255.01972961, 246.64862061, 254.13371277, ...,
        233.67948914, 247.68939209, 238.36853027],
        [262.65979004, 253.28910828, 259.18252563, ...,
        239.03181458, 253.16757202, 249.44122314],
        [259.38916016, 252.72747803, 258.23059082, ...,
        237.83563232, 252.41094971, 246.62471008],
        ...,
        [256.15814209, 250.09197998, 254.72172546, ...,
        239.23876953, 247.72680664, 244.16642761],
        [268.22662354, 258.91204834, 262.33398438, ...,
        245.26651001, 259.05389404, 258.33700562],
        [254.16134644, 251.7543335, 256.08300781, ..., 238.59848022,
        245.62597656, 240.22348022]], dtype=float32)

269µs is a surprisingly small time for a 1000x1000 matrix multiplication on CPU! However it turns out that asynchronous dispatch is misleading us and we are not timing the execution of the matrix multiplication, only the time to dispatch the work. To measure the true cost of the operation we must either read the value on the host (e.g., convert it to a plain old host-side numpy array), or use the block_until_ready() method on a jax.Array value to wait for the computation that produced it to complete.

>>> %time np.asarray(jnp.dot(x, x))  
CPU times: user 61.1 ms, sys: 0 ns, total: 61.1 ms
Wall time: 8.09 ms
array([[255.01973, 246.64862, 254.13371, ..., 233.67949, 247.68939,
        238.36853],
       [262.6598 , 253.28911, 259.18253, ..., 239.03181, 253.16757,
        249.44122],
       [259.38916, 252.72748, 258.2306 , ..., 237.83563, 252.41095,
        246.62471],
       ...,
       [256.15814, 250.09198, 254.72173, ..., 239.23877, 247.7268 ,
        244.16643],
       [268.22662, 258.91205, 262.33398, ..., 245.26651, 259.0539 ,
        258.337  ],
       [254.16135, 251.75433, 256.083  , ..., 238.59848, 245.62598,
        240.22348]], dtype=float32)
>>> %time jnp.dot(x, x).block_until_ready()  
CPU times: user 50.3 ms, sys: 928 µs, total: 51.2 ms
Wall time: 4.92 ms
Array([[255.01972961, 246.64862061, 254.13371277, ...,
        233.67948914, 247.68939209, 238.36853027],
        [262.65979004, 253.28910828, 259.18252563, ...,
        239.03181458, 253.16757202, 249.44122314],
        [259.38916016, 252.72747803, 258.23059082, ...,
        237.83563232, 252.41094971, 246.62471008],
        ...,
        [256.15814209, 250.09197998, 254.72172546, ...,
        239.23876953, 247.72680664, 244.16642761],
        [268.22662354, 258.91204834, 262.33398438, ...,
        245.26651001, 259.05389404, 258.33700562],
        [254.16134644, 251.7543335, 256.08300781, ..., 238.59848022,
        245.62597656, 240.22348022]], dtype=float32)

Blocking without transferring the result back to Python is usually faster, and is often the best choice when writing microbenchmarks of computation times.
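For example, a microbenchmark along those lines might look like this (a sketch using the standard timeit module; the shapes and repeat counts are arbitrary):

import timeit
import jax
import jax.numpy as jnp

x = jax.random.uniform(jax.random.key(0), (1000, 1000))

# block_until_ready() ensures we time the computation, not just the dispatch.
t = timeit.timeit(lambda: jnp.dot(x, x).block_until_ready(), number=10) / 10
print(f"average matmul time: {t * 1e3:.2f} ms")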

Concurrency#

JAX has limited support for Python concurrency.

Clients may call JAX APIs (e.g., jit() or grad()) concurrently from separate Python threads.

It is not permitted to manipulate JAX trace values concurrently from multiple threads. In other words, while it is permissible to call functions that use JAX tracing (e.g., jit()) from multiple threads, you must not use threading to manipulate JAX values inside the implementation of the function f that is passed to jit(). The most likely outcome if you do this is a mysterious error from JAX.
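A minimal sketch of the supported pattern (calling a jitted function from several Python threads, while never manipulating traced values across threads inside the traced function itself):

from concurrent.futures import ThreadPoolExecutor
import jax
import jax.numpy as jnp

f = jax.jit(lambda x: jnp.sin(x) ** 2)

inputs = [jnp.ones(100) * i for i in range(8)]
with ThreadPoolExecutor(max_workers=4) as pool:
    results = list(pool.map(f, inputs))   # concurrent calls into a jitted function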

GPU memory allocation#

JAX will preallocate 75% of the total GPU memory when the first JAX operation is run. Preallocating minimizes allocation overhead and memory fragmentation, but can sometimes cause out-of-memory (OOM) errors. If your JAX process fails with OOM, the following environment variables can be used to override the default behavior:

XLA_PYTHON_CLIENT_PREALLOCATE=false

This disables the preallocation behavior. JAX will instead allocate GPU memory as needed, potentially decreasing the overall memory usage. However, this behavior is more prone to GPU memory fragmentation, meaning a JAX program that uses most of the available GPU memory may OOM with preallocation disabled.

XLA_PYTHON_CLIENT_MEM_FRACTION=.XX

If preallocation is enabled, this makes JAX preallocate XX% of the total GPU memory, instead of the default 75%. Lowering the amount preallocated can fix OOMs that occur when the JAX program starts.

XLA_PYTHON_CLIENT_ALLOCATOR=platform

This makes JAX allocate exactly what is needed on demand, and deallocate memory that is no longer needed (note that this is the only configuration that will deallocate GPU memory, instead of reusing it). This is very slow, so is not recommended for general use, but may be useful for running with the minimal possible GPU memory footprint or debugging OOM failures.
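These variables can be exported in the shell before launching the program, or, as in the sketch below, set from Python before anything triggers JAX backend initialization.

import os

# Must be set before the first JAX operation initializes the backend.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Alternatively, keep preallocation but lower the fraction:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

import jax
print(jax.devices())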

Common causes of OOM failures#

Running multiple JAX processes concurrently.

Either use XLA_PYTHON_CLIENT_MEM_FRACTION to give each process an appropriate amount of memory, or set XLA_PYTHON_CLIENT_PREALLOCATE=false.

Running JAX and GPU TensorFlow concurrently.

TensorFlow also preallocates by default, so this is similar to running multiple JAX processes concurrently.

One solution is to use CPU-only TensorFlow (e.g. if you’re only doing data loading with TF). You can prevent TensorFlow from using the GPU with the command tf.config.experimental.set_visible_devices([], "GPU")

Alternatively, use XLA_PYTHON_CLIENT_MEM_FRACTION or XLA_PYTHON_CLIENT_PREALLOCATE. TensorFlow has similar options for configuring its GPU memory allocation: in TF1, gpu_memory_fraction and allow_growth, which should be set in a tf.ConfigProto passed to tf.Session; in TF2, see Using GPUs: Limiting GPU memory growth.

Running JAX on the display GPU.

Use XLA_PYTHON_CLIENT_MEM_FRACTION or XLA_PYTHON_CLIENT_PREALLOCATE; a sketch combining these settings with the TensorFlow workaround appears below.
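A sketch combining these settings for the JAX-plus-TensorFlow case, assuming TensorFlow is used only for data loading:

import os

# Give JAX only half of the GPU memory, leaving room for other processes.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

import tensorflow as tf

# Hide the GPU from TensorFlow so it only uses the CPU for data loading.
tf.config.experimental.set_visible_devices([], "GPU")

import jax
print(jax.devices())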

Rank promotion warning#

NumPy broadcasting rules allow the automatic promotion of arguments from one rank (number of array axes) to another. This behavior can be convenient when intended but can also lead to surprising bugs where a silent rank promotion masks an underlying shape error.

Here’s an example of rank promotion:

>>> import numpy as np
>>> x = np.arange(12).reshape(4, 3)
>>> y = np.array([0, 1, 0])
>>> x + y
array([[ 0,  2,  2],
       [ 3,  5,  5],
       [ 6,  8,  8],
       [ 9, 11, 11]])

To avoid potential surprises, jax.numpy is configurable so that expressions requiring rank promotion lead to a warning, an error, or are allowed just as in regular NumPy. The configuration option is named jax_numpy_rank_promotion and it takes the string values allow, warn, and raise. The default setting is allow, which permits rank promotion without warning or error; raise raises an error on rank promotion; and warn issues a warning on the first occurrence of rank promotion.

The rank promotion behavior can be configured locally with the jax.numpy_rank_promotion() context manager:

with jax.numpy_rank_promotion("warn"):
  z = x + y

This configuration can also be set globally in several ways. One is by using jax.config in your code:

import jax
jax.config.update("jax_numpy_rank_promotion", "warn")

You can also set the option using the environment variable JAX_NUMPY_RANK_PROMOTION, for example as JAX_NUMPY_RANK_PROMOTION='warn'. Finally, when using absl-py the option can be set with a command-line flag.
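As an illustrative sketch, the raise setting turns a silent rank promotion between jax.numpy arrays into an error, while an explicitly reshaped operand still broadcasts:

import jax
import jax.numpy as jnp

x = jnp.arange(12.0).reshape(4, 3)
y = jnp.array([0.0, 1.0, 0.0])

with jax.numpy_rank_promotion("raise"):
    try:
        x + y                # (4, 3) + (3,) needs rank promotion, so this raises
    except ValueError as e:
        print("caught:", e)
    z = x + y[None, :]       # (4, 3) + (1, 3): both rank 2, so no promotion needed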

Public API: jax package#

Subpackages#

jax.numpy module#

Implements the NumPy API, using the primitives in jax.lax.

While JAX tries to follow the NumPy API as closely as possible, sometimes JAX cannot follow NumPy exactly.

  • Notably, since JAX arrays are immutable, NumPy APIs that mutate arrays in-place cannot be implemented in JAX. However, JAX often provides a purely functional alternative. For example, instead of the in-place array update x[i] = y, JAX provides the pure indexed update x.at[i].set(y) (see ndarray.at).

  • Relatedly, some NumPy functions often return views of arrays when possible (examples are transpose() and reshape()). JAX versions of such functions return copies instead, although such copies are often optimized away by XLA when sequences of operations are compiled using jax.jit().

  • NumPy is very aggressive at promoting values to float64 type. JAX is sometimes less aggressive about type promotion (see Type promotion semantics).

  • Some NumPy routines have data-dependent output shapes (examples include unique() and nonzero()). Because the XLA compiler requires array shapes to be known at compile time, such operations are not compatible with JIT. For this reason, JAX adds an optional size argument to such functions, which may be specified statically in order to use them with JIT; a brief sketch follows this list.
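A brief sketch of both points, using the pure .at update syntax and a static size argument for nonzero() under jit() (the concrete values are illustrative):

import jax
import jax.numpy as jnp

x = jnp.zeros(5)
# Instead of the in-place NumPy update x[2] = 1.0, use the pure .at syntax:
y = x.at[2].set(1.0)        # x is unchanged; y is a new array

@jax.jit
def first_nonzero_indices(a):
    # nonzero has a data-dependent output shape, so a static size is required
    # under jit; missing entries are padded with fill_value.
    return jnp.nonzero(a, size=3, fill_value=-1)[0]

print(first_nonzero_indices(jnp.array([0.0, 2.0, 0.0, 5.0])))  # [ 1  3 -1]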

Nearly all applicable NumPy functions are implemented in the jax.numpy namespace; they are listed below.

ndarray.at

Helper property for index update functionality.

abs(x, /)

Calculate the absolute value element-wise.

absolute(x, /)

Calculate the absolute value element-wise.

acos(x, /)

Trigonometric inverse cosine, element-wise.

acosh(x, /)

Inverse hyperbolic cosine, element-wise.

add(x1, x2, /)

Add arguments element-wise.

all(a[, axis, out, keepdims, where])

Test whether all array elements along a given axis evaluate to True.

allclose(a, b[, rtol, atol, equal_nan])

Returns True if two arrays are element-wise equal within a tolerance.

amax(a[, axis, out, keepdims, initial, where])

Return the maximum of an array or maximum along an axis.

amin(a[, axis, out, keepdims, initial, where])

Return the minimum of an array or minimum along an axis.

angle(z[, deg])

Return the angle of the complex argument.

any(a[, axis, out, keepdims, where])

Test whether any array element along a given axis evaluates to True.

append(arr, values[, axis])

Append values to the end of an array.

apply_along_axis(func1d, axis, arr, *args, ...)

Apply a function to 1-D slices along the given axis.

apply_over_axes(func, a, axes)

Apply a function repeatedly over multiple axes.

arange(start[, stop, step, dtype])

Return evenly spaced values within a given interval.

arccos(x, /)

Trigonometric inverse cosine, element-wise.

arccosh(x, /)

Inverse hyperbolic cosine, element-wise.

arcsin(x, /)

Inverse sine, element-wise.

arcsinh(x, /)

Inverse hyperbolic sine element-wise.

arctan(x, /)

Trigonometric inverse tangent, element-wise.

arctan2(x1, x2, /)

Element-wise arc tangent of x1/x2 choosing the quadrant correctly.

arctanh(x, /)

Inverse hyperbolic tangent element-wise.

argmax(a[, axis, out, keepdims])

Returns the indices of the maximum values along an axis.

argmin(a[, axis, out, keepdims])

Returns the indices of the minimum values along an axis.

argpartition(a, kth[, axis])

Perform an indirect partition along the given axis.

argsort(a[, axis, kind, order, stable, ...])

Returns the indices that would sort an array.

argwhere(a, *[, size, fill_value])

Find the indices of nonzero array elements

around(a[, decimals, out])

Round an array to the given number of decimals.

array(object[, dtype, copy, order, ndmin])

Create an array.

array_equal(a1, a2[, equal_nan])

True if two arrays have the same shape and elements, False otherwise.

array_equiv(a1, a2)

Returns True if input arrays are shape consistent and all elements equal.

array_repr(arr[, max_line_width, precision, ...])

Return the string representation of an array.

array_split(ary, indices_or_sections[, axis])

Split an array into multiple sub-arrays.

array_str(a[, max_line_width, precision, ...])

Return a string representation of the data in an array.

asarray(a[, dtype, order, copy])

Convert the input to an array.

asin(x, /)

Inverse sine, element-wise.

asinh(x, /)

Inverse hyperbolic sine element-wise.

astype(x, dtype, /, *[, copy, device])

This is implemented via jax.lax.convert_element_type(), which may have slightly different behavior than numpy.astype() in some cases.

atan(x, /)

Trigonometric inverse tangent, element-wise.

atanh(x, /)

Inverse hyperbolic tangent element-wise.

atan2(x1, x2, /)

Element-wise arc tangent of x1/x2 choosing the quadrant correctly.

atleast_1d()

Convert inputs to arrays with at least one dimension.

atleast_2d()

View inputs as arrays with at least two dimensions.

atleast_3d()

View inputs as arrays with at least three dimensions.

average()

Compute the weighted average along the specified axis.

bartlett(M)

Return the Bartlett window.

bincount(x[, weights, minlength, length])

Count number of occurrences of each value in array of non-negative ints.

bitwise_and(x1, x2, /)

Compute the bit-wise AND of two arrays element-wise.

bitwise_count(x, /)

bitwise_invert(x, /)

Compute bit-wise inversion, or bit-wise NOT, element-wise.

bitwise_left_shift(x1, x2, /)

Shift the bits of an integer to the left.

bitwise_not(x, /)

Compute bit-wise inversion, or bit-wise NOT, element-wise.

bitwise_or(x1, x2, /)

Compute the bit-wise OR of two arrays element-wise.

bitwise_right_shift(x1, x2, /)

Shift the bits of an integer to the right.

bitwise_xor(x1, x2, /)

Compute the bit-wise XOR of two arrays element-wise.

blackman(M)

Return the Blackman window.

block(arrays)

Assemble an nd-array from nested lists of blocks.

bool_(x)

broadcast_arrays(*args)

Broadcast any number of arrays against each other.

broadcast_shapes()

Broadcast the input shapes into a single shape.

broadcast_to(array, shape)

Broadcast an array to a new shape.

c_

Concatenate slices, scalars and array-like objects along the last axis.

can_cast(from_, to[, casting])

Returns True if cast between data types can occur according to the casting rule.

cbrt(x, /)

Return the cube-root of an array, element-wise.

cdouble

alias of complex128

ceil(x, /)

Return the ceiling of the input, element-wise.

character()

Abstract base class of all character string scalar types.

choose(a, choices[, out, mode])

Construct an array from an index array and a list of arrays to choose from.

clip([x, min, max, a, a_min, a_max])

Clip (limit) the values in an array.

column_stack(tup)

Stack 1-D arrays as columns into a 2-D array.

complex_

alias of complex128

complex128(x)

complex64(x)

complexfloating()

Abstract base class of all complex number scalar types that are made up of floating-point numbers.

ComplexWarning

The warning raised when casting a complex dtype to a real dtype.

compress(condition, a[, axis, size, ...])

Compress an array along a given axis using a boolean condition.

concat(arrays, /, *[, axis])

concatenate(arrays[, axis, dtype])

Join a sequence of arrays along an existing axis.

conj(x, /)

Return the complex conjugate, element-wise.

conjugate(x, /)

Return the complex conjugate, element-wise.

convolve(a, v[, mode, precision, ...])

Returns the discrete, linear convolution of two one-dimensional sequences.

copy(a[, order])

Return an array copy of the given object.

copysign(x1, x2, /)

Change the sign of x1 to that of x2, element-wise.

corrcoef(x[, y, rowvar])

Return Pearson product-moment correlation coefficients.

correlate(a, v[, mode, precision, ...])

Cross-correlation of two 1-dimensional sequences.

cos(x, /)

Cosine element-wise.

cosh(x, /)

Hyperbolic cosine, element-wise.

count_nonzero(a[, axis, keepdims])

Counts the number of non-zero values in the array a.

cov(m[, y, rowvar, bias, ddof, fweights, ...])

Estimate a covariance matrix, given data and weights.

cross(a, b[, axisa, axisb, axisc, axis])

Return the cross product of two (arrays of) vectors.

csingle

alias of complex64

cumprod(a[, axis, dtype, out])

Return the cumulative product of elements along a given axis.

cumsum(a[, axis, dtype, out])

Return the cumulative sum of the elements along a given axis.

cumulative_sum(x, /, *[, axis, dtype, ...])

deg2rad(x, /)

Convert angles from degrees to radians.

degrees(x, /)

Convert angles from radians to degrees.

delete(arr, obj[, axis, assume_unique_indices])

Delete entry or entries from an array.

diag(v[, k])

Extract a diagonal or construct a diagonal array.

diag_indices(n[, ndim])

Return the indices to access the main diagonal of an array.

diag_indices_from(arr)

Return the indices to access the main diagonal of an n-dimensional array.

diagflat(v[, k])

Create a two-dimensional array with the flattened input as a diagonal.

diagonal(a[, offset, axis1, axis2])

Return specified diagonals.

diff(a[, n, axis, prepend, append])

Calculate the n-th discrete difference along the given axis.

digitize(x, bins[, right])

Return the indices of the bins to which each value in input array belongs.

divide(x1, x2, /)

Divide arguments element-wise.

divmod(x1, x2, /)

Return element-wise quotient and remainder simultaneously.

dot(a, b, *[, precision, preferred_element_type])

Compute the dot product of two arrays.

double

alias of float64

dsplit(ary, indices_or_sections)

Split array into multiple sub-arrays along the 3rd axis (depth).

dstack(tup[, dtype])

Stack arrays in sequence depth wise (along third axis).

dtype(dtype[, align, copy])

Create a data type object.

ediff1d(ary[, to_end, to_begin])

The differences between consecutive elements of an array.

einsum()

Evaluates the Einstein summation convention on the operands.

einsum_path(subscripts, *operands[, optimize])

Evaluates the lowest cost contraction order for an einsum expression.

empty(shape[, dtype, device])

Return a new array of given shape and type, without initializing entries.

empty_like(prototype[, dtype, shape, device])

Return a new array with the same shape and type as a given array.

equal(x1, x2, /)

Return (x1 == x2) element-wise.

exp(x, /)

Calculate the exponential of all elements in the input array.

exp2(x, /)

Calculate 2**p for all p in the input array.

expand_dims(a, axis)

Expand the shape of an array.

expm1(x, /)

Calculate exp(x) - 1 for all elements in the array.

extract(condition, arr, *[, size, fill_value])

Return the elements of an array that satisfy a condition.

eye(N[, M, k, dtype])

Return a 2-D array with ones on the diagonal and zeros elsewhere.

fabs(x, /)

Compute the absolute values element-wise.

fill_diagonal(a, val[, wrap, inplace])

Fill the main diagonal of the given array of any dimensionality.

finfo(dtype)

Machine limits for floating point types.

fix(x[, out])

Round to nearest integer towards zero.

flatnonzero(a, *[, size, fill_value])

Return indices of nonzero elements in a flattened array

flexible()

Abstract base class of all scalar types without predefined length.

flip(m[, axis])

Reverse the order of elements in an array along the given axis.

fliplr(m)

Reverse the order of elements along axis 1 (left/right).

flipud(m)

Reverse the order of elements along axis 0 (up/down).

float_

alias of float64

float_power(x1, x2, /)

First array elements raised to powers from second array, element-wise.

float16(x)

float32(x)

float64(x)

floating()

Abstract base class of all floating-point scalar types.

floor(x, /)

Return the floor of the input, element-wise.

floor_divide(x1, x2, /)

Return the largest integer smaller or equal to the division of the inputs.

fmax(x1, x2)

Element-wise maximum of array elements.

fmin(x1, x2)

Element-wise minimum of array elements.

fmod(x1, x2, /)

Returns the element-wise remainder of division.

frexp(x, /)

Decompose the elements of x into mantissa and twos exponent.

frombuffer(buffer[, dtype, count, offset])

Interpret a buffer as a 1-dimensional array.

fromfile(*args, **kwargs)

Unimplemented JAX wrapper for jnp.fromfile.

fromfunction(function, shape, *[, dtype])

Construct an array by executing a function over each coordinate.

fromiter(*args, **kwargs)

Unimplemented JAX wrapper for jnp.fromiter.

frompyfunc(func, /, nin, nout, *[, identity])

Create a JAX ufunc from an arbitrary JAX-compatible scalar function.

fromstring(string[, dtype, count])

A new 1-D array initialized from text data in a string.

from_dlpack(x, /, *[, device, copy])

Create a NumPy array from an object implementing the __dlpack__ protocol.

full(shape, fill_value[, dtype, device])

Return a new array of given shape and type, filled with fill_value.

full_like(a, fill_value[, dtype, shape, device])

Return a full array with the same shape and type as a given array.

gcd(x1, x2)

Returns the greatest common divisor of |x1| and |x2|

generic()

Base class for numpy scalar types.

geomspace(start, stop[, num, endpoint, ...])

Return numbers spaced evenly on a log scale (a geometric progression).

get_printoptions()

Return the current print options.

gradient(f, *varargs[, axis, edge_order])

Return the gradient of an N-dimensional array.

greater(x1, x2, /)

Return the truth value of (x1 > x2) element-wise.

greater_equal(x1, x2, /)

Return the truth value of (x1 >= x2) element-wise.

hamming(M)

Return the Hamming window.

hanning(M)

Return the Hanning window.

heaviside(x1, x2, /)

Compute the Heaviside step function.

histogram(a[, bins, range, weights, density])

Compute the histogram of a dataset.

histogram_bin_edges(a[, bins, range, weights])

Calculate only the edges of the bins used by the histogram function.

histogram2d(x, y[, bins, range, weights, ...])

Compute the bi-dimensional histogram of two data samples.

histogramdd(sample[, bins, range, weights, ...])

Compute the multidimensional histogram of some data.

hsplit(ary, indices_or_sections)

Split an array into multiple sub-arrays horizontally (column-wise).

hstack(tup[, dtype])

Stack arrays in sequence horizontally (column wise).

hypot(x1, x2, /)

Given the "legs" of a right triangle, return its hypotenuse.

i0

Modified Bessel function of the first kind, order 0.

identity(n[, dtype])

Return the identity array.

iinfo(int_type)

imag(val, /)

Return the imaginary part of the complex argument.

index_exp

A nicer way to build up index tuples for arrays.

indices()

Return an array representing the indices of a grid.

inexact()

Abstract base class of all numeric scalar types with a (potentially) inexact representation of the values in its range, such as floating-point numbers.

inner(a, b, *[, precision, ...])

Compute the inner product of two arrays.

insert(arr, obj, values[, axis])

Insert values along the given axis before the given indices.

int_

alias of int64

int16(x)

int32(x)

int64(x)

int8(x)

integer()

Abstract base class of all integer scalar types.

interp(x, xp, fp[, left, right, period])

One-dimensional linear interpolation for monotonically increasing sample points.

intersect1d(ar1, ar2[, assume_unique, ...])

Find the intersection of two arrays.

invert(x, /)

Compute bit-wise inversion, or bit-wise NOT, element-wise.

isclose(a, b[, rtol, atol, equal_nan])

Returns a boolean array where two arrays are element-wise equal within a tolerance.

iscomplex(x)

Returns a bool array, where True if input element is complex.

iscomplexobj(x)

Check for a complex type or an array of complex numbers.

isdtype(dtype, kind)

Returns a boolean indicating whether a provided dtype is of a specified kind.

isfinite(x, /)

Test element-wise for finiteness (not infinity and not Not a Number).

isin(element, test_elements[, ...])

Calculates element in test_elements, broadcasting over element only.

isinf(x, /)

Test element-wise for positive or negative infinity.

isnan(x, /)

Test element-wise for NaN and return result as a boolean array.

isneginf(x, /[, out])

Test element-wise for negative infinity, return result as bool array.

isposinf(x, /[, out])

Test element-wise for positive infinity, return result as bool array.

isreal(x)

Returns a bool array, where True if input element is real.

isrealobj(x)

Return True if x is a not complex type or an array of complex numbers.

isscalar(element)

Returns True if the type of element is a scalar type.

issubdtype(arg1, arg2)

Returns True if first argument is a typecode lower/equal in type hierarchy.

iterable(y)

Check whether or not an object can be iterated over.

ix_(*args)

Return a multi-dimensional grid (open mesh) from N one-dimensional sequences.

kaiser(M, beta)

Return the Kaiser window.

kron(a, b)

Kronecker product of two arrays.

lcm(x1, x2)

Returns the lowest common multiple of |x1| and |x2|

ldexp(x1, x2, /)

Returns x1 * 2**x2, element-wise.

left_shift(x1, x2, /)

Shift the bits of an integer to the left.

less(x1, x2, /)

Return the truth value of (x1 < x2) element-wise.

less_equal(x1, x2, /)

Return the truth value of (x1 <= x2) element-wise.

lexsort(keys[, axis])

Perform an indirect stable sort using a sequence of keys.

linspace()

Return evenly spaced numbers over a specified interval.

load(*args, **kwargs)

Load arrays or pickled objects from .npy, .npz or pickled files.

log(x, /)

Natural logarithm, element-wise.

log10(x, /)

Return the base 10 logarithm of the input array, element-wise.

log1p(x, /)

Return the natural logarithm of one plus the input array, element-wise.

log2(x, /)

Base-2 logarithm of x.

logaddexp

Logarithm of the sum of exponentiations of the inputs.

logaddexp2

Logarithm of the sum of exponentiations of the inputs in base-2.

logical_and(*args)

Compute the truth value of x1 AND x2 element-wise.

logical_not(*args)

Compute the truth value of NOT x element-wise.

logical_or(*args)

Compute the truth value of x1 OR x2 element-wise.

logical_xor(*args)

Compute the truth value of x1 XOR x2, element-wise.

logspace(start, stop[, num, endpoint, base, ...])

Return numbers spaced evenly on a log scale.

mask_indices(*args, **kwargs)

Return the indices to access (n, n) arrays, given a masking function.

matmul(a, b, *[, precision, ...])

Perform a matrix multiplication.

matrix_transpose(x, /)

Transpose the last two dimensions of an array.

max(a[, axis, out, keepdims, initial, where])

Return the maximum of an array or maximum along an axis.

maximum(x1, x2, /)

Element-wise maximum of array elements.

mean(a[, axis, dtype, out, keepdims, where])

Compute the arithmetic mean along the specified axis.

median(a[, axis, out, overwrite_input, keepdims])

Compute the median along the specified axis.

meshgrid(*xi[, copy, sparse, indexing])

Return a list of coordinate matrices from coordinate vectors.

mgrid

Return dense multi-dimensional "meshgrid".

min(a[, axis, out, keepdims, initial, where])

Return the minimum of an array or minimum along an axis.

minimum(x1, x2, /)

Element-wise minimum of array elements.

mod(x1, x2, /)

Returns the element-wise remainder of division.

modf(x, /[, out])

Return the fractional and integral parts of an array, element-wise.

moveaxis(a, source, destination)

Move axes of an array to new positions.

multiply(x1, x2, /)

Multiply arguments element-wise.

nan_to_num(x[, copy, nan, posinf, neginf])

Replace NaN with zero and infinity with large finite numbers.

nanargmax(a[, axis, out, keepdims])

Return the indices of the maximum values in the specified axis, ignoring NaNs.

nanargmin(a[, axis, out, keepdims])

Return the indices of the minimum values in the specified axis, ignoring NaNs.

nancumprod(a[, axis, dtype, out])

Return the cumulative product of array elements over a given axis, treating NaNs as one.

nancumsum(a[, axis, dtype, out])

Return the cumulative sum of array elements over a given axis, treating NaNs as zero.

nanmax(a[, axis, out, keepdims, initial, where])

Return the maximum of an array or maximum along an axis, ignoring any NaNs.

nanmean(a[, axis, dtype, out, keepdims, where])

Compute the arithmetic mean along the specified axis, ignoring NaNs.

nanmedian(a[, axis, out, overwrite_input, ...])

Compute the median along the specified axis, while ignoring NaNs.

nanmin(a[, axis, out, keepdims, initial, where])

Return minimum of an array or minimum along an axis, ignoring any NaNs.

nanpercentile(a, q[, axis, out, ...])

Compute the qth percentile of the data along the specified axis, ignoring NaN values.

nanprod(a[, axis, dtype, out, keepdims, ...])

Return the product of array elements over a given axis, treating NaNs as one.

nanquantile(a, q[, axis, out, ...])

Compute the qth quantile of the data along the specified axis, ignoring NaN values.

nanstd(a[, axis, dtype, out, ddof, ...])

Compute the standard deviation along the specified axis, while ignoring NaNs.

nansum(a[, axis, dtype, out, keepdims, ...])

Return the sum of array elements over a given axis, treating NaNs as zero.

nanvar(a[, axis, dtype, out, ddof, ...])

Compute the variance along the specified axis, while ignoring NaNs.

ndarray

alias of Array

ndim(a)

Return the number of dimensions of an array.

negative(x, /)

Numerical negative, element-wise.

nextafter(x1, x2, /)

Return the next floating-point value after x1 towards x2, element-wise.

nonzero(a, *[, size, fill_value])

Return indices of nonzero elements of an array.

not_equal(x1, x2, /)

Return (x1 != x2) element-wise.

number()

Abstract base class of all numeric scalar types.

object_

Any Python object.

ogrid

Return open multi-dimensional "meshgrid".

ones(shape[, dtype, device])

Return a new array of given shape and type, filled with ones.

ones_like(a[, dtype, shape, device])

Return an array of ones with the same shape and type as a given array.

outer(a, b[, out])

Compute the outer product of two vectors.

packbits(a[, axis, bitorder])

Packs the elements of a binary-valued array into bits in a uint8 array.

pad(array, pad_width[, mode])

Pad an array.

partition(a, kth[, axis])

Return a partitioned copy of an array.

percentile(a, q[, axis, out, ...])

Compute the q-th percentile of the data along the specified axis.

permute_dims(a, /, axes)

piecewise(x, condlist, funclist, *args, **kw)

Evaluate a piecewise-defined function.

place(arr, mask, vals, *[, inplace])

Change elements of an array based on conditional and input values.

poly(seq_of_zeros)

Find the coefficients of a polynomial with the given sequence of roots.

polyadd(a1, a2)

Find the sum of two polynomials.

polyder(p[, m])

Return the derivative of the specified order of a polynomial.

polydiv(u, v, *[, trim_leading_zeros])

Returns the quotient and remainder of polynomial division.

polyfit(x, y, deg[, rcond, full, w, cov])

Least squares polynomial fit.

polyint(p[, m, k])

Return an antiderivative (indefinite integral) of a polynomial.

polymul(a1, a2, *[, trim_leading_zeros])

Find the product of two polynomials.

polysub(a1, a2)

Difference (subtraction) of two polynomials.

polyval(p, x, *[, unroll])

Evaluate a polynomial at specific values.

positive(x, /)

Numerical positive, element-wise.

pow(x1, x2, /)

First array elements raised to powers from second array, element-wise.

power(x1, x2, /)

First array elements raised to powers from second array, element-wise.

printoptions(*args, **kwargs)

Context manager for setting print options.

prod(a[, axis, dtype, out, keepdims, ...])

Return the product of array elements over a given axis.

promote_types(a, b)

Returns the type to which a binary operation should cast its arguments.

ptp(a[, axis, out, keepdims])

Range of values (maximum - minimum) along an axis.

put(a, ind, v[, mode, inplace])

Replaces specified elements of an array with given values.

quantile(a, q[, axis, out, overwrite_input, ...])

Compute the q-th quantile of the data along the specified axis.

r_

Concatenate slices, scalars and array-like objects along the first axis.

rad2deg(x, /)

Convert angles from radians to degrees.

radians(x, /)

Convert angles from degrees to radians.

ravel(a[, order])

Flatten array into a 1-dimensional shape.

ravel_multi_index(multi_index, dims[, mode, ...])

Converts a tuple of index arrays into an array of flat indices.

real(val, /)

Return the real part of the complex argument.

reciprocal(x, /)

Return the reciprocal of the argument, element-wise.

remainder(x1, x2, /)

Returns the element-wise remainder of division.

repeat(a, repeats[, axis, total_repeat_length])

Repeat each element of an array after itself.

reshape(a, newshape[, order])

Return a reshaped copy of an array.

resize(a, new_shape)

Return a new array with the specified shape.

result_type(*args)

Returns the type that results from applying the NumPy type promotion rules to the arguments.

right_shift(x1, x2, /)

Shift the bits of an integer to the right.

rint(x, /)

Round elements of the array to the nearest integer.

roll(a, shift[, axis])

Roll array elements along a given axis.

rollaxis(a, axis[, start])

Roll the specified axis backwards, until it lies in a given position.

roots(p, *[, strip_zeros])

Return the roots of a polynomial with coefficients given in p.

rot90(m[, k, axes])

Rotate an array by 90 degrees in the plane specified by axes.

round(a[, decimals, out])

Round an array to the given number of decimals.

round_(a[, decimals, out])

Round an array to the given number of decimals.

s_

A nicer way to build up index tuples for arrays.

save(file, arr[, allow_pickle, fix_imports])

Save an array to a binary file in NumPy .npy format.

savez(file, *args, **kwds)

Save several arrays into a single file in uncompressed .npz format.

searchsorted(a, v[, side, sorter, method])

Find indices where elements should be inserted to maintain order.

select(condlist, choicelist[, default])

Return an array drawn from elements in choicelist, depending on conditions.

set_printoptions([precision, threshold, ...])

Set printing options.

setdiff1d(ar1, ar2[, assume_unique, size, ...])

Find the set difference of two arrays.

setxor1d(ar1, ar2[, assume_unique])

Find the set exclusive-or of two arrays.

shape(a)

Return the shape of an array.

sign(x, /)

Returns an element-wise indication of the sign of a number.

signbit(x, /)

Returns element-wise True where signbit is set (less than zero).

signedinteger()

Abstract base class of all signed integer scalar types.

sin(x, /)

Trigonometric sine, element-wise.

sinc(x, /)

Return the normalized sinc function.

single

alias of float32

sinh(x, /)

Hyperbolic sine, element-wise.

size(a[, axis])

Return the number of elements along a given axis.

sort(a[, axis, kind, order, stable, descending])

Return a sorted copy of an array.

sort_complex(a)

Sort a complex array using the real part first, then the imaginary part.

split(ary, indices_or_sections[, axis])

Split an array into multiple sub-arrays.

sqrt(x, /)

Return the non-negative square-root of an array, element-wise.

square(x, /)

Return the element-wise square of the input.

squeeze(a[, axis])

Remove axes of length one from a.

stack(arrays[, axis, out, dtype])

Join a sequence of arrays along a new axis.

std(a[, axis, dtype, out, ddof, keepdims, where])

Compute the standard deviation along the specified axis.

subtract(x1, x2, /)

Subtract arguments, element-wise.

sum(a[, axis, dtype, out, keepdims, ...])

Sum of array elements over a given axis.

swapaxes(a, axis1, axis2)

Interchange two axes of an array.

take(a, indices[, axis, out, mode, ...])

Take elements from an array along an axis.

take_along_axis(arr, indices, axis[, mode])

Take values from the input array by matching 1d index and data slices.

tan(x, /)

Compute tangent element-wise.

tanh(x, /)

Compute hyperbolic tangent element-wise.

tensordot(a, b[, axes, precision, ...])

Compute the tensor dot product of two N-dimensional arrays.

tile(A, reps)

Construct an array by repeating A the number of times given by reps.

trace(a[, offset, axis1, axis2, dtype, out])

Return the sum along diagonals of the array.

trapezoid(y[, x, dx, axis])

Integrate along the given axis using the composite trapezoidal rule.

transpose(a[, axes])

Return a transposed version of an N-dimensional array.

tri(N[, M, k, dtype])

An array with ones at and below the given diagonal and zeros elsewhere.

tril(m[, k])

Lower triangle of an array.

tril_indices(n[, k, m])

Return the indices for the lower-triangle of an (n, m) array.

tril_indices_from(arr[, k])

Return the indices for the lower-triangle of arr.

trim_zeros(filt[, trim])

Trim the leading and/or trailing zeros from a 1-D array or sequence.

triu(m[, k])

Upper triangle of an array.

triu_indices(n[, k, m])

Return the indices for the upper-triangle of an (n, m) array.

triu_indices_from(arr[, k])

Return the indices for the upper-triangle of arr.

true_divide(x1, x2, /)

Divide arguments element-wise.

trunc(x)

Return the truncated value of the input, element-wise.

ufunc(func, /, nin, nout, *[, name, nargs, ...])

Functions that operate element-by-element on whole arrays.

uint

alias of uint64

uint16(x)

uint32(x)

uint64(x)

uint8(x)

union1d(ar1, ar2, *[, size, fill_value])

Find the union of two arrays.

unique(ar[, return_index, return_inverse, ...])

Find the unique elements of an array.

unique_all(x, /)

unique_counts(x, /)

unique_inverse(x, /)

unique_values(x, /)

unpackbits(a[, axis, count, bitorder])

Unpacks elements of a uint8 array into a binary-valued output array.

unravel_index(indices, shape)

Converts a flat index or array of flat indices into a tuple of coordinate arrays.

unstack(x, /, *[, axis])

unsignedinteger()

Abstract base class of all unsigned integer scalar types.

unwrap(p[, discont, axis, period])

Unwrap by taking the complement of large deltas with respect to the period.

vander(x[, N, increasing])

Generate a Vandermonde matrix.

var(a[, axis, dtype, out, ddof, keepdims, where])

Compute the variance along the specified axis.

vdot(a, b, *[, precision, ...])

Perform a conjugate multiplication of two 1D vectors.

vecdot(x1, x2, /, *[, axis, precision, ...])

Perform a conjugate multiplication of two batched vectors.

vectorize(pyfunc, *[, excluded, signature])

Define a vectorized function with broadcasting.

vsplit(ary, indices_or_sections)

Split an array into multiple sub-arrays vertically (row-wise).

vstack(tup[, dtype])

Stack arrays in sequence vertically (row wise).

where()

Select elements from two arrays based on a condition.

zeros(shape[, dtype, device])

Return a new array of given shape and type, filled with zeros.

zeros_like(a[, dtype, shape, device])

Return an array of zeros with the same shape and type as a given array.

jax.numpy.fft#

fft(a[, n, axis, norm])

Compute the one-dimensional discrete Fourier Transform.

fft2(a[, s, axes, norm])

Compute the 2-dimensional discrete Fourier Transform.

fftfreq(n[, d, dtype])

Return the Discrete Fourier Transform sample frequencies.

fftn(a[, s, axes, norm])

Compute the N-dimensional discrete Fourier Transform.

fftshift(x[, axes])

Shift the zero-frequency component to the center of the spectrum.

hfft(a[, n, axis, norm])

Compute the FFT of a signal that has Hermitian symmetry, i.e., a real spectrum.

ifft(a[, n, axis, norm])

Compute the one-dimensional inverse discrete Fourier Transform.

ifft2(a[, s, axes, norm])

Compute the 2-dimensional inverse discrete Fourier Transform.

ifftn(a[, s, axes, norm])

Compute the N-dimensional inverse discrete Fourier Transform.

ifftshift(x[, axes])

The inverse of fftshift.

ihfft(a[, n, axis, norm])

Compute the inverse FFT of a signal that has Hermitian symmetry.

irfft(a[, n, axis, norm])

Computes the inverse of rfft.

irfft2(a[, s, axes, norm])

Computes the inverse of rfft2.

irfftn(a[, s, axes, norm])

Computes the inverse of rfftn.

rfft(a[, n, axis, norm])

Compute the one-dimensional discrete Fourier Transform for real input.

rfft2(a[, s, axes, norm])

Compute the 2-dimensional FFT of a real array.

rfftfreq(n[, d, dtype])

Return the Discrete Fourier Transform sample frequencies (for use with rfft and irfft).

rfftn(a[, s, axes, norm])

Compute the N-dimensional discrete Fourier Transform for real input.

jax.numpy.linalg#

cholesky(a, *[, upper])

Compute the Cholesky decomposition of a matrix.

cond(x[, p])

Compute the condition number of a matrix.

cross(x1, x2, /, *[, axis])

Compute the cross-product of two 3D vectors.

det

Computes the determinant of an array.

diagonal(x, /, *[, offset])

Extract the diagonal of a matrix or stack of matrices.

eig(a)

Computes the eigenvalues and eigenvectors of a square array.

eigh(a[, UPLO, symmetrize_input])

Computes the eigenvalues and eigenvectors of a Hermitian matrix.

eigvals(a)

Computes the eigenvalues of a general matrix.

eigvalsh(a[, UPLO])

Computes the eigenvalues of a Hermitian matrix.

inv(a)

Return the inverse of a square matrix

lstsq(a, b[, rcond, numpy_resid])

Return the least-squares solution to a linear equation.

matmul(x1, x2, /)

Perform a matrix multiplication.

matrix_norm(x, /, *[, keepdims, ord])

Compute the norm of a matrix or stack of matrices.

matrix_power(a, n)

Raise a square matrix to an integer power.

matrix_rank(M[, tol])

Compute the rank of a matrix.

matrix_transpose(x, /)

Transpose a matrix or stack of matrices.

multi_dot(arrays, *[, precision])

Efficiently compute matrix products between a sequence of arrays.

norm(x[, ord, axis, keepdims])

Compute the norm of a matrix or vector.

outer(x1, x2, /)

Compute the outer product of two 1-dimensional arrays.

pinv

Compute the (Moore-Penrose) pseudo-inverse of a matrix.

qr()

Compute the QR decomposition of an array

slogdet(a, *[, method])

Computes the sign and (natural) logarithm of the determinant of an array.

solve(a, b)

Solve a linear system of equations

svd()

Compute the singular value decomposition.

svdvals(x, /)

Compute the singular values of a matrix.

tensordot(x1, x2, /, *[, axes])

Compute the tensor dot product of two N-dimensional arrays.

tensorinv(a[, ind])

Compute the tensor inverse of an array.

tensorsolve(a, b[, axes])

Solve the tensor equation a x = b for x.

vector_norm(x, /, *[, axis, keepdims, ord])

Computes the vector norm of a vector or batch of vectors.

vecdot(x1, x2, /, *[, axis])

Compute the (batched) vector conjugate dot product of two arrays.

JAX Array#

The JAX Array (along with its alias, jax.numpy.ndarray) is the core array object in JAX: you can think of it as JAX’s equivalent of a numpy.ndarray. Like numpy.ndarray, most users will not need to instantiate Array objects manually, but rather will create them via jax.numpy functions like array(), arange(), linspace(), and others listed above.
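A brief sketch:

import jax
import jax.numpy as jnp

# Arrays are created via jax.numpy functions rather than instantiated directly.
x = jnp.linspace(0.0, 1.0, 5)
print(isinstance(x, jax.Array))  # True
print(x.shape, x.dtype)          # (5,) float32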

Copying and Serialization#

JAX Array objects are designed to work seamlessly with Python standard library tools where appropriate.

With the built-in copy module, when copy.copy() or copy.deepcopy() encounters an Array, it is equivalent to calling the copy() method, which creates a copy of the buffer on the same device as the original array. This works correctly within traced/JIT-compiled code, though copy operations may be elided by the compiler in this context.

When the built-in pickle module encounters an Array, it will be serialized via a compact bit representation in a similar manner to pickled numpy.ndarray objects. When unpickled, the result will be a new Array object on the default device. This is because in general, pickling and unpickling may take place in different runtime environments, and there is no general way to map the device IDs of one runtime to the device IDs of another. If pickle is used in traced/JIT-compiled code, it will result in a ConcretizationTypeError.
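A minimal sketch of both behaviors, outside of traced code:

import copy
import pickle
import jax.numpy as jnp

x = jnp.arange(4.0)

# copy.deepcopy produces a new buffer on the same device as x.
y = copy.deepcopy(x)

# pickle serializes the values; unpickling places the result on the default device.
z = pickle.loads(pickle.dumps(x))

print(bool(jnp.all(x == y)), bool(jnp.all(x == z)))  # True True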

jax.scipy module#

jax.scipy.cluster#

vq(obs, code_book[, check_finite])

Assign codes from a code book to a set of observations.

jax.scipy.fft#

dct(x[, type, n, axis, norm])

Computes the discrete cosine transform of the input

dctn(x[, type, s, axes, norm])

Computes the multidimensional discrete cosine transform of the input

idct(x[, type, n, axis, norm])

Computes the inverse discrete cosine transform of the input

idctn(x[, type, s, axes, norm])

Computes the multidimensional inverse discrete cosine transform of the input

jax.scipy.integrate#

trapezoid(y[, x, dx, axis])

Integrate along the given axis using the composite trapezoidal rule.

jax.scipy.linalg#

block_diag(*arrs)

Create a block diagonal matrix from input arrays.

cho_factor(a[, lower, overwrite_a, check_finite])

Factorization for Cholesky-based linear solves

cho_solve(c_and_lower, b[, overwrite_b, ...])

Solve a linear system using a Cholesky factorization

cholesky(a[, lower, overwrite_a, check_finite])

Compute the Cholesky decomposition of a matrix.

det(a[, overwrite_a, check_finite])

Compute the determinant of a matrix

eigh()

Compute eigenvalues and eigenvectors for a Hermitian matrix

eigh_tridiagonal(d, e, *[, eigvals_only, ...])

Solve the eigenvalue problem for a symmetric real tridiagonal matrix

expm(A, *[, upper_triangular, max_squarings])

Compute the matrix exponential

expm_frechet()

Compute the Frechet derivative of the matrix exponential.

funm(A, func[, disp])

Evaluate a matrix-valued function

hessenberg()

Compute the Hessenberg form of the matrix

hilbert(n)

Create a Hilbert matrix of order n.

inv(a[, overwrite_a, check_finite])

Return the inverse of a square matrix

lu()

Compute the LU decomposition

lu_factor(a[, overwrite_a, check_finite])

Factorization for LU-based linear solves

lu_solve(lu_and_piv, b[, trans, ...])

Solve a linear system using an LU factorization

polar(a[, side, method, eps, max_iterations])

Computes the polar decomposition.

qr()

Compute the QR decomposition of an array

rsf2csf(T, Z[, check_finite])

Convert real Schur form to complex Schur form.

schur(a[, output])

Compute the Schur decomposition

solve(a, b[, lower, overwrite_a, ...])

Solve a linear system of equations

solve_triangular(a, b[, trans, lower, ...])

Solve a triangular linear system of equations

sqrtm(A[, blocksize])

Compute the matrix square root

svd()

Compute the singular value decomposition.

toeplitz(c[, r])

Construct a Toeplitz matrix

jax.scipy.ndimage#

map_coordinates(input, coordinates, order[, ...])

Map the input array to new coordinates using interpolation.

jax.scipy.optimize#

minimize(fun, x0[, args, tol, options])

Minimization of scalar function of one or more variables.

OptimizeResults(x, success, status, fun, ...)

Object holding optimization results.

jax.scipy.signal#

fftconvolve(in1, in2[, mode, axes])

Convolve two N-dimensional arrays using Fast Fourier Transform (FFT).

convolve(in1, in2[, mode, method, precision])

Convolution of two N-dimensional arrays.

convolve2d(in1, in2[, mode, boundary, ...])

Convolution of two 2-dimensional arrays.

correlate(in1, in2[, mode, method, precision])

Cross-correlation of two N-dimensional arrays.

correlate2d(in1, in2[, mode, boundary, ...])

Cross-correlation of two 2-dimensional arrays.

csd(x, y[, fs, window, nperseg, noverlap, ...])

Estimate cross power spectral density (CSD) using Welch's method.

detrend(data[, axis, type, bp, overwrite_data])

Remove linear or piecewise linear trends from data.

istft(Zxx[, fs, window, nperseg, noverlap, ...])

Perform the inverse short-time Fourier transform (ISTFT).

stft(x[, fs, window, nperseg, noverlap, ...])

Compute the short-time Fourier transform (STFT).

welch(x[, fs, window, nperseg, noverlap, ...])

Estimate power spectral density (PSD) using Welch's method.

jax.scipy.spatial.transform#

Rotation(quat)

Rotation in 3 dimensions.

Slerp(times, timedelta, rotations, rotvecs)

Spherical Linear Interpolation of Rotations.

jax.scipy.sparse.linalg#

bicgstab(A, b[, x0, tol, atol, maxiter, M])

Use Bi-Conjugate Gradient Stable iteration to solve Ax = b.

cg(A, b[, x0, tol, atol, maxiter, M])

Use Conjugate Gradient iteration to solve Ax = b.

gmres(A, b[, x0, tol, atol, restart, ...])

GMRES solves the linear system A x = b for x, given A and b.

jax.scipy.special#

bernoulli(n)

Generate the first N Bernoulli numbers.

beta(x, y)

The beta function

betainc(a, b, x)

The regularized incomplete beta function.

betaln(a, b)

Natural log of the absolute value of the beta function

digamma(x)

The digamma function

entr(x)

The entropy function

erf(x)

The error function

erfc(x)

The complement of the error function

erfinv(x)

The inverse of the error function

exp1(x)

Exponential integral function.

expi

Exponential integral function.

expit(x)

The logistic sigmoid (expit) function

expn

Generalized exponential integral function.

factorial(n[, exact])

Factorial function

gamma(x)

The gamma function.

gammainc(a, x)

The regularized lower incomplete gamma function.

gammaincc(a, x)

The regularized upper incomplete gamma function.

gammaln(x)

Natural log of the absolute value of the gamma function.

gammasgn(x)

Sign of the gamma function.

hyp1f1

The 1F1 hypergeometric function.

i0(x)

Modified bessel function of zeroth order.

i0e(x)

Exponentially scaled modified bessel function of zeroth order.

i1(x)

Modified bessel function of first order.

i1e(x)

Exponentially scaled modified bessel function of first order.

log_ndtr

Log Normal distribution function.

logit

The logit function

logsumexp()

Log-sum-exp reduction.

lpmn(m, n, z)

The associated Legendre functions (ALFs) of the first kind.

lpmn_values(m, n, z, is_normalized)

The associated Legendre functions (ALFs) of the first kind.

multigammaln(a, d)

The natural log of the multivariate gamma function.

ndtr(x)

Normal distribution function.

ndtri(p)

The inverse of the CDF of the Normal distribution function.

poch

The Pochammer symbol.

polygamma(n, x)

The polygamma function.

spence(x)

Spence's function, also known as the dilogarithm for real values.

sph_harm(m, n, theta, phi[, n_max])

Computes the spherical harmonics.

xlog1py

Compute x*log(1 + y), returning 0 for x=0.

xlogy

Compute x*log(y), returning 0 for x=0.

zeta

The Hurwitz zeta function.

kl_div(p, q)

The Kullback-Leibler divergence.

rel_entr(p, q)

The relative entropy function.

jax.scipy.stats#

mode(a[, axis, nan_policy, keepdims])

Compute the mode (most common value) along an axis of an array.

rankdata(a[, method, axis, nan_policy])

Compute the rank of data along an array axis.

sem(a[, axis, ddof, nan_policy, keepdims])

Compute the standard error of the mean.

jax.scipy.stats.bernoulli#

logpmf(k, p[, loc])

Bernoulli log probability mass function.

pmf(k, p[, loc])

Bernoulli probability mass function.

cdf(k, p)

Bernoulli cumulative distribution function.

ppf(q, p)

Bernoulli percent point function.

jax.scipy.stats.beta#

logpdf(x, a, b[, loc, scale])

Beta log probability distribution function.

pdf(x, a, b[, loc, scale])

Beta probability distribution function.

cdf(x, a, b[, loc, scale])

Beta cumulative distribution function

logcdf(x, a, b[, loc, scale])

Beta log cumulative distribution function.

sf(x, a, b[, loc, scale])

Beta distribution survival function.

logsf(x, a, b[, loc, scale])

Beta distribution log survival function.

jax.scipy.stats.betabinom#

logpmf(k, n, a, b[, loc])

Beta-binomial log probability mass function.

pmf(k, n, a, b[, loc])

Beta-binomial probability mass function.

jax.scipy.stats.binom#

logpmf(k, n, p[, loc])

Binomial log probability mass function.

pmf(k, n, p[, loc])

Binomial probability mass function.

jax.scipy.stats.cauchy#

logpdf(x[, loc, scale])

Cauchy log probability distribution function.

pdf(x[, loc, scale])

Cauchy probability distribution function.

cdf(x[, loc, scale])

Cauchy cumulative distribution function.

logcdf(x[, loc, scale])

Cauchy log cumulative distribution function.

sf(x[, loc, scale])

Cauchy distribution survival function.

logsf(x[, loc, scale])

Cauchy distribution log survival function.

isf(q[, loc, scale])

Cauchy distribution inverse survival function.

ppf(q[, loc, scale])

Cauchy distribution percent point function.

jax.scipy.stats.chi2#

logpdf(x, df[, loc, scale])

Chi-square log probability distribution function.

pdf(x, df[, loc, scale])

Chi-square probability distribution function.

cdf(x, df[, loc, scale])

Chi-square cumulative distribution function.

logcdf(x, df[, loc, scale])

Chi-square log cumulative distribution function.

sf(x, df[, loc, scale])

Chi-square survival function.

logsf(x, df[, loc, scale])

Chi-square log survival function.

jax.scipy.stats.dirichlet#

logpdf(x, alpha)

Dirichlet log probability distribution function.

pdf(x, alpha)

Dirichlet probability distribution function.

jax.scipy.stats.expon#

logpdf(x[, loc, scale])

Exponential log probability distribution function.

pdf(x[, loc, scale])

Exponential probability distribution function.

jax.scipy.stats.gamma#

logpdf(x, a[, loc, scale])

Gamma log probability distribution function.

pdf(x, a[, loc, scale])

Gamma probability distribution function.

cdf(x, a[, loc, scale])

Gamma cumulative distribution function.

logcdf(x, a[, loc, scale])

Gamma log cumulative distribution function.

sf(x, a[, loc, scale])

Gamma survival function.

logsf(x, a[, loc, scale])

Gamma log survival function.

jax.scipy.stats.gennorm#

cdf(x, beta)

Generalized normal cumulative distribution function.

logpdf(x, beta)

Generalized normal log probability distribution function.

pdf(x, beta)

Generalized normal probability distribution function.

jax.scipy.stats.geom#

logpmf(k, p[, loc])

Geometric log probability mass function.

pmf(k, p[, loc])

Geometric probability mass function.

jax.scipy.stats.laplace#

cdf(x[, loc, scale])

Laplace cumulative distribution function.

logpdf(x[, loc, scale])

Laplace log probability distribution function.

pdf(x[, loc, scale])

Laplace probability distribution function.

jax.scipy.stats.logistic#

cdf(x[, loc, scale])

Logistic cumulative distribution function.

isf(x[, loc, scale])

Logistic distribution inverse survival function.

logpdf(x[, loc, scale])

Logistic log probability distribution function.

pdf(x[, loc, scale])

Logistic probability distribution function.

ppf(x[, loc, scale])

Logistic distribution percent point function.

sf(x[, loc, scale])

Logistic distribution survival function.

jax.scipy.stats.multinomial#

logpmf(x, n, p)

Multinomial log probability mass function.

pmf(x, n, p)

Multinomial probability mass function.

jax.scipy.stats.multivariate_normal#

logpdf(x, mean, cov[, allow_singular])

Multivariate normal log probability distribution function.

pdf(x, mean, cov)

Multivariate normal probability distribution function.

jax.scipy.stats.nbinom#

logpmf(k, n, p[, loc])

Negative-binomial log probability mass function.

pmf(k, n, p[, loc])

Negative-binomial probability mass function.

jax.scipy.stats.norm#

logpdf(x[, loc, scale])

Normal log probability distribution function.

pdf(x[, loc, scale])

Normal probability distribution function.

cdf(x[, loc, scale])

Normal cumulative distribution function.

logcdf(x[, loc, scale])

Normal log cumulative distribution function.

ppf(q[, loc, scale])

Normal distribution percent point function.

sf(x[, loc, scale])

Normal distribution survival function.

logsf(x[, loc, scale])

Normal distribution log survival function.

isf(q[, loc, scale])

Normal distribution inverse survival function.

jax.scipy.stats.pareto#

logpdf(x, b[, loc, scale])

Pareto log probability distribution function.

pdf(x, b[, loc, scale])

Pareto probability distribution function.

jax.scipy.stats.poisson#

logpmf(k, mu[, loc])

Poisson log probability mass function.

pmf(k, mu[, loc])

Poisson probability mass function.

cdf(k, mu[, loc])

Poisson cumulative distribution function.

jax.scipy.stats.t#

logpdf(x, df[, loc, scale])

Student's T log probability distribution function.

pdf(x, df[, loc, scale])

Student's T probability distribution function.

jax.scipy.stats.truncnorm#

cdf(x, a, b[, loc, scale])

Truncated normal cumulative distribution function.

logcdf(x, a, b[, loc, scale])

Truncated normal log cumulative distribution function.

logpdf(x, a, b[, loc, scale])

Truncated normal log probability distribution function.

logsf(x, a, b[, loc, scale])

Truncated normal distribution log survival function.

pdf(x, a, b[, loc, scale])

Truncated normal probability distribution function.

sf(x, a, b[, loc, scale])

Truncated normal distribution survival function.

jax.scipy.stats.uniform#

logpdf(x[, loc, scale])

Uniform log probability distribution function.

pdf(x[, loc, scale])

Uniform probability distribution function.

cdf(x[, loc, scale])

Uniform cumulative distribution function.

ppf(q[, loc, scale])

Uniform distribution percent point function.

jax.scipy.stats.gaussian_kde#

gaussian_kde(dataset[, bw_method, weights])

Gaussian Kernel Density Estimator

gaussian_kde.evaluate(points)

Evaluate the Gaussian KDE on the given points.

gaussian_kde.integrate_gaussian(mean, cov)

Integrate the distribution weighted by a Gaussian.

gaussian_kde.integrate_box_1d(low, high)

Integrate the distribution over the given limits.

gaussian_kde.integrate_kde(other)

Integrate the product of two Gaussian KDE distributions.

gaussian_kde.resample(key[, shape])

Randomly sample a dataset from the estimated pdf

gaussian_kde.pdf(x)

Probability density function

gaussian_kde.logpdf(x)

Log probability density function

jax.scipy.stats.vonmises#

logpdf(x, kappa)

von Mises log probability distribution function.

pdf(x, kappa)

von Mises probability distribution function.

jax.scipy.stats.wrapcauchy#

logpdf(x, c)

Wrapped Cauchy log probability distribution function.

pdf(x, c)

Wrapped Cauchy probability distribution function.

jax.lax module#

jax.lax is a library of primitive operations that underpins libraries such as jax.numpy. Transformation rules, such as JVP and batching rules, are typically defined as transformations on jax.lax primitives.

Many of the primitives are thin wrappers around equivalent XLA operations, described by the XLA operation semantics documentation. In a few cases JAX diverges from XLA, usually to ensure that the set of operations is closed under the operation of JVP and transpose rules.

Where possible, prefer to use libraries such as jax.numpy instead of using jax.lax directly. The jax.numpy API follows NumPy, and is therefore more stable and less likely to change than the jax.lax API.
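For instance (a small sketch of the stricter semantics), jax.numpy follows NumPy's implicit type promotion, while the corresponding jax.lax primitive requires the argument dtypes to match:

import jax.numpy as jnp
from jax import lax

print(jnp.add(1, 1.5))               # jax.numpy promotes the int to float: 2.5

# jax.lax requires an explicit cast when dtypes differ:
print(lax.add(jnp.float32(1), 1.5))  # 2.5
# lax.add(1, 1.5) would raise a TypeError because the argument dtypes differ.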

Operators#

abs(x)

Elementwise absolute value: \(|x|\).

acos(x)

Elementwise arc cosine: \(\mathrm{acos}(x)\).

acosh(x)

Elementwise inverse hyperbolic cosine: \(\mathrm{acosh}(x)\).

add(x, y)

Elementwise addition: \(x + y\).

after_all(*operands)

Merges one or more XLA token values.

approx_max_k(operand, k[, ...])

Returns max k values and their indices of the operand in an approximate manner.

approx_min_k(operand, k[, ...])

Returns min k values and their indices of the operand in an approximate manner.

argmax(operand, axis, index_dtype)

Computes the index of the maximum element along axis.

argmin(operand, axis, index_dtype)

Computes the index of the minimum element along axis.

asin(x)

Elementwise arc sine: \(\mathrm{asin}(x)\).

asinh(x)

Elementwise inverse hyperbolic sine: \(\mathrm{asinh}(x)\).

atan(x)

Elementwise arc tangent: \(\mathrm{atan}(x)\).

atan2(x, y)

Elementwise arc tangent of two variables: \(\mathrm{atan}({x \over y})\).

atanh(x)

Elementwise inverse hyperbolic tangent: \(\mathrm{atanh}(x)\).

batch_matmul(lhs, rhs[, precision])

Batch matrix multiplication.

bessel_i0e(x)

Exponentially scaled modified Bessel function of order 0: \(\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)\)

bessel_i1e(x)

Exponentially scaled modified Bessel function of order 1: \(\mathrm{i1e}(x) = e^{-|x|} \mathrm{i1}(x)\)

betainc(a, b, x)

Elementwise regularized incomplete beta integral.

bitcast_convert_type(operand, new_dtype)

Elementwise bitcast.

bitwise_and(x, y)

Elementwise AND: \(x \wedge y\).

bitwise_not(x)

Elementwise NOT: \(\neg x\).

bitwise_or(x, y)

Elementwise OR: \(x \vee y\).

bitwise_xor(x, y)

Elementwise exclusive OR: \(x \oplus y\).

population_count(x)

Elementwise popcount, count the number of set bits in each element.

broadcast(operand, sizes)

Broadcasts an array, adding new leading dimensions

broadcast_in_dim(operand, shape, ...)

Wraps XLA's BroadcastInDim operator.

broadcast_shapes()

Returns the shape that results from NumPy broadcasting of shapes.

broadcast_to_rank(x, rank)

Adds leading dimensions of 1 to give x rank rank.

broadcasted_iota(dtype, shape, dimension)

Convenience wrapper around iota.

cbrt(x)

Elementwise cube root: \(\sqrt[3]{x}\).

ceil(x)

Elementwise ceiling: \(\left\lceil x \right\rceil\).

clamp(min, x, max)

Elementwise clamp.

clz(x)

Elementwise count-leading-zeros.

collapse(operand, start_dimension[, ...])

Collapses dimensions of an array into a single dimension.

complex(x, y)

Elementwise make complex number: \(x + jy\).

concatenate(operands, dimension)

Concatenates a sequence of arrays along dimension.

conj(x)

Elementwise complex conjugate function: \(\overline{x}\).

conv(lhs, rhs, window_strides, padding[, ...])

Convenience wrapper around conv_general_dilated.

convert_element_type(operand, new_dtype)

Elementwise cast.

conv_dimension_numbers(lhs_shape, rhs_shape, ...)

Converts convolution dimension_numbers to a ConvDimensionNumbers.

conv_general_dilated(lhs, rhs, ...[, ...])

General n-dimensional convolution operator, with optional dilation.

conv_general_dilated_local(lhs, rhs, ...[, ...])

General n-dimensional unshared convolution operator with optional dilation.

conv_general_dilated_patches(lhs, ...[, ...])

Extract patches subject to the receptive field of conv_general_dilated.

conv_transpose(lhs, rhs, strides, padding[, ...])

Convenience wrapper for calculating the N-d convolution "transpose".

conv_with_general_padding(lhs, rhs, ...[, ...])

Convenience wrapper around conv_general_dilated.

cos(x)

Elementwise cosine: \(\mathrm{cos}(x)\).

cosh(x)

Elementwise hyperbolic cosine: \(\mathrm{cosh}(x)\).

cumlogsumexp(operand[, axis, reverse])

Computes a cumulative logsumexp along axis.

cummax(operand[, axis, reverse])

Computes a cumulative maximum along axis.

cummin(operand[, axis, reverse])

Computes a cumulative minimum along axis.

cumprod(operand[, axis, reverse])

Computes a cumulative product along axis.

cumsum(operand[, axis, reverse])

Computes a cumulative sum along axis.

digamma(x)

Elementwise digamma: \(\psi(x)\).

div(x, y)

Elementwise division: \(x \over y\).

dot(lhs, rhs[, precision, ...])

Vector/vector, matrix/vector, and matrix/matrix multiplication.

dot_general(lhs, rhs, dimension_numbers[, ...])

General dot product/contraction operator.

dynamic_index_in_dim(operand, index[, axis, ...])

Convenience wrapper around dynamic_slice to perform int indexing.

dynamic_slice(operand, start_indices, ...)

Wraps XLA's DynamicSlice operator.

dynamic_slice_in_dim(operand, start_index, ...)

Convenience wrapper around lax.dynamic_slice() applied to one dimension.

dynamic_update_index_in_dim(operand, update, ...)

Convenience wrapper around dynamic_update_slice() to update a slice of size 1 in a single axis.

dynamic_update_slice(operand, update, ...)

Wraps XLA's DynamicUpdateSlice operator.

dynamic_update_slice_in_dim(operand, update, ...)

Convenience wrapper around dynamic_update_slice() to update a slice in a single axis.

eq(x, y)

Elementwise equals: \(x = y\).

erf(x)

Elementwise error function: \(\mathrm{erf}(x)\).

erfc(x)

Elementwise complementary error function: \(\mathrm{erfc}(x) = 1 - \mathrm{erf}(x)\).

erf_inv(x)

Elementwise inverse error function: \(\mathrm{erf}^{-1}(x)\).

exp(x)

Elementwise exponential: \(e^x\).

expand_dims(array, dimensions)

Insert any number of size 1 dimensions into an array.

expm1(x)

Elementwise \(e^{x} - 1\).

fft(x, fft_type, fft_lengths)

floor(x)

Elementwise floor: \(\left\lfloor x \right\rfloor\).

full(shape, fill_value[, dtype, sharding])

Returns an array of shape filled with fill_value.

full_like(x, fill_value[, dtype, shape, ...])

Create a full array like np.full based on the example array x.

gather(operand, start_indices, ...[, ...])

Gather operator.

ge(x, y)

Elementwise greater-than-or-equals: \(x \geq y\).

gt(x, y)

Elementwise greater-than: \(x > y\).

igamma(a, x)

Elementwise regularized incomplete gamma function.

igammac(a, x)

Elementwise complementary regularized incomplete gamma function.

imag(x)

Elementwise extract imaginary part: \(\mathrm{Im}(x)\).

index_in_dim(operand, index[, axis, keepdims])

Convenience wrapper around lax.slice() to perform int indexing.

index_take(src, idxs, axes)

integer_pow(x, y)

Elementwise power: \(x^y\), where \(y\) is a fixed integer.

iota(dtype, size)

Wraps XLA's Iota operator.

is_finite(x)

Elementwise \(\mathrm{isfinite}\).

le(x, y)

Elementwise less-than-or-equals: \(x \leq y\).

lgamma(x)

Elementwise log gamma: \(\mathrm{log}(\Gamma(x))\).

log(x)

Elementwise natural logarithm: \(\mathrm{log}(x)\).

log1p(x)

Elementwise \(\mathrm{log}(1 + x)\).

logistic(x)

Elementwise logistic (sigmoid) function: \(\frac{1}{1 + e^{-x}}\).

lt(x, y)

Elementwise less-than: \(x < y\).

max(x, y)

Elementwise maximum: \(\mathrm{max}(x, y)\)

min(x, y)

Elementwise minimum: \(\mathrm{min}(x, y)\)

mul(x, y)

Elementwise multiplication: \(x \times y\).

ne(x, y)

Elementwise not-equals: \(x \neq y\).

neg(x)

Elementwise negation: \(-x\).

nextafter(x1, x2)

Returns the next representable value after x1 in the direction of x2.

pad(operand, padding_value, padding_config)

Applies low, high, and/or interior padding to an array.

polygamma(m, x)

Elementwise polygamma: \(\psi^{(m)}(x)\).

population_count(x)

Elementwise popcount, count the number of set bits in each element.

pow(x, y)

Elementwise power: \(x^y\).

random_gamma_grad(a, x)

Elementwise derivative of samples from Gamma(a, 1).

real(x)

Elementwise extract real part: \(\mathrm{Re}(x)\).

reciprocal(x)

Elementwise reciprocal: \(1 \over x\).

reduce(operands, init_values, computation, ...)

Wraps XLA's Reduce operator.

reduce_precision(operand, exponent_bits, ...)

Wraps XLA's ReducePrecision operator.

reduce_window(operand, init_value, ...[, ...])

Wraps XLA's ReduceWindowWithGeneralPadding operator.

rem(x, y)

Elementwise remainder: \(x \bmod y\).

reshape(operand, new_sizes[, dimensions])

Wraps XLA's Reshape operator.

rev(operand, dimensions)

Wraps XLA's Rev operator.

rng_bit_generator(key, shape[, dtype, algorithm])

Stateless PRNG bit generator.

rng_uniform(a, b, shape)

Stateful PRNG generator.

round(x[, rounding_method])

Elementwise round.

rsqrt(x)

Elementwise reciprocal square root: \(1 \over \sqrt{x}\).

scatter(operand, scatter_indices, updates, ...)

Scatter-update operator.

scatter_add(operand, scatter_indices, ...[, ...])

Scatter-add operator.

scatter_apply(operand, scatter_indices, ...)

Scatter-apply operator.

scatter_max(operand, scatter_indices, ...[, ...])

Scatter-max operator.

scatter_min(operand, scatter_indices, ...[, ...])

Scatter-min operator.

scatter_mul(operand, scatter_indices, ...[, ...])

Scatter-multiply operator.

shift_left(x, y)

Elementwise left shift: \(x \ll y\).

shift_right_arithmetic(x, y)

Elementwise arithmetic right shift: \(x \gg y\).

shift_right_logical(x, y)

Elementwise logical right shift: \(x \gg y\).

sign(x)

Elementwise sign.

sin(x)

Elementwise sine: \(\mathrm{sin}(x)\).

sinh(x)

Elementwise hyperbolic sine: \(\mathrm{sinh}(x)\).

slice(operand, start_indices, limit_indices)

Wraps XLA's Slice operator.

slice_in_dim(operand, start_index, limit_index)

Convenience wrapper around lax.slice() applying to only one dimension.

sort()

Wraps XLA's Sort operator.

sort_key_val(keys, values[, dimension, ...])

Sorts keys along dimension and applies the same permutation to values.

sqrt(x)

Elementwise square root: \(\sqrt{x}\).

square(x)

Elementwise square: \(x^2\).

squeeze(array, dimensions)

Squeeze any number of size 1 dimensions from an array.

sub(x, y)

Elementwise subtraction: \(x - y\).

tan(x)

Elementwise tangent: \(\mathrm{tan}(x)\).

tanh(x)

Elementwise hyperbolic tangent: \(\mathrm{tanh}(x)\).

tie_in(x, y)

Deprecated.

top_k(operand, k)

Returns top k values and their indices along the last axis of operand.

transpose(operand, permutation)

Wraps XLA's Transpose operator.

zeros_like_array(x)

zeta(x, q)

Elementwise Hurwitz zeta function: \(\zeta(x, q)\)

Control flow operators#

associative_scan(fn, elems[, reverse, axis])

Performs a scan with an associative binary operation, in parallel.

cond(pred, true_fun, false_fun, *operands[, ...])

Conditionally apply true_fun or false_fun.

fori_loop(lower, upper, body_fun, init_val, *)

Loop from lower to upper by reduction to jax.lax.while_loop().

map(f, xs)

Map a function over leading array axes.

scan(f, init[, xs, length, reverse, unroll, ...])

Scan a function over leading array axes while carrying along state.

select(pred, on_true, on_false)

Selects between two branches based on a boolean predicate.

select_n(which, *cases)

Selects array values from multiple cases.

switch(index, branches, *operands[, operand])

Apply exactly one of the branches given by index.

while_loop(cond_fun, body_fun, init_val)

Call body_fun repeatedly in a loop while cond_fun is True.
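
As a minimal sketch of how these combinators compose (using lax.scan and lax.cond from the list above; the step function and threshold are illustrative only, not part of the API):

import jax.numpy as jnp
from jax import lax

# Carry a running sum through lax.scan, emitting each partial sum.
def step(carry, x):
  new_carry = carry + x
  return new_carry, new_carry          # (carry for next step, per-step output)

total, partials = lax.scan(step, 0.0, jnp.arange(1.0, 5.0))
# total == 10.0, partials == [1., 3., 6., 10.]

# Branch on a traced predicate with lax.cond.
doubled_if_large = lax.cond(total > 5.0, lambda t: t * 2.0, lambda t: t, total)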

Custom gradient operators#

stop_gradient(x)

Stops gradient computation.

custom_linear_solve(matvec, b, solve[, ...])

Perform a matrix-free linear solve with implicitly defined gradients.

custom_root(f, initial_guess, solve, ...[, ...])

Differentiably solve for the roots of a function.
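
For example, stop_gradient treats part of a computation as a constant under differentiation; a small sketch (the loss function here is made up for illustration):

import jax
import jax.numpy as jnp

def loss(x):
  target = jax.lax.stop_gradient(2.0 * x)   # no gradient flows through this branch
  return jnp.sum((x - target) ** 2)

jax.grad(loss)(jnp.array([1.0, 2.0]))       # == [-2., -4.]: only the first x is differentiated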

Parallel operators#

all_gather(x, axis_name, *[, ...])

Gather values of x across all replicas.

all_to_all(x, axis_name, split_axis, ...[, ...])

Materialize the mapped axis and map a different axis.

pdot(x, y, axis_name[, pos_contract, ...])

psum(x, axis_name, *[, axis_index_groups])

Compute an all-reduce sum on x over the pmapped axis axis_name.

psum_scatter(x, axis_name, *[, ...])

Like psum(x, axis_name) but each device retains only part of the result.

pmax(x, axis_name, *[, axis_index_groups])

Compute an all-reduce max on x over the pmapped axis axis_name.

pmin(x, axis_name, *[, axis_index_groups])

Compute an all-reduce min on x over the pmapped axis axis_name.

pmean(x, axis_name, *[, axis_index_groups])

Compute an all-reduce mean on x over the pmapped axis axis_name.

ppermute(x, axis_name, perm)

Perform a collective permutation according to the permutation perm.

pshuffle(x, axis_name, perm)

Convenience wrapper of jax.lax.ppermute with alternate permutation encoding

pswapaxes(x, axis_name, axis, *[, ...])

Swap the pmapped axis axis_name with the unmapped axis axis.

axis_index(axis_name)

Return the index along the mapped axis axis_name.
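
A minimal sketch of how a collective such as psum is used inside jax.pmap (the axis name 'i' is arbitrary; it only has to match between pmap and the collective):

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.pmap, axis_name='i')
def normalize(v):
  # All-reduce sum over the mapped axis named 'i'.
  return v / jax.lax.psum(v, axis_name='i')

x = 1.0 + jnp.arange(float(jax.local_device_count()))
out = normalize(x)   # each entry divided by the sum across devices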

Linear algebra operators (jax.lax.linalg)#

cholesky(x, *[, symmetrize_input])

Cholesky decomposition.

eig(x, *[, compute_left_eigenvectors, ...])

Eigendecomposition of a general matrix.

eigh(x, *[, lower, symmetrize_input, ...])

Eigendecomposition of a Hermitian matrix.

hessenberg(a)

Reduces a square matrix to upper Hessenberg form.

lu(x)

LU decomposition with partial pivoting.

householder_product(a, taus)

Product of elementary Householder reflectors.

qdwh(x, *[, is_hermitian, max_iterations, ...])

QR-based dynamically weighted Halley iteration for polar decomposition.

qr(x, *[, full_matrices])

QR decomposition.

schur(x, *[, compute_schur_vectors, ...])

svd()

Singular value decomposition.

triangular_solve(a, b, *[, left_side, ...])

Triangular solve.

tridiagonal(a, *[, lower])

Reduces a symmetric/Hermitian matrix to tridiagonal form.

tridiagonal_solve(dl, d, du, b)

Computes the solution of a tridiagonal linear system.
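
A brief sketch of calling these decompositions directly from jax.lax.linalg (for most code, the jax.numpy.linalg and jax.scipy.linalg wrappers are the more typical entry points):

import jax.numpy as jnp
from jax.lax import linalg

a = jnp.array([[4.0, 1.0],
               [1.0, 3.0]])

l = linalg.cholesky(a)                      # lower-triangular factor: a == l @ l.T (up to rounding)
q, r = linalg.qr(a, full_matrices=False)    # thin QR decomposition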

Argument classes#
class jax.lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)[source]#

Describes batch, spatial, and feature dimensions of a convolution.

Parameters:
  • lhs_spec (Sequence[int]) – a tuple of nonnegative integer dimension numbers containing (batch dimension, feature dimension, spatial dimensions…).

  • rhs_spec (Sequence[int]) – a tuple of nonnegative integer dimension numbers containing (out feature dimension, in feature dimension, spatial dimensions…).

  • out_spec (Sequence[int]) – a tuple of nonnegative integer dimension numbers containing (batch dimension, feature dimension, spatial dimensions…).
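
For illustration, here is one way such dimension numbers are commonly constructed and passed to conv_general_dilated (a sketch using the NHWC/HWIO layout; the shapes are arbitrary):

import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((1, 8, 8, 3))     # (batch, height, width, features)
rhs = jnp.ones((3, 3, 3, 16))    # (height, width, in features, out features)

# Build a ConvDimensionNumbers from string specs describing the layouts above.
dnums = lax.conv_dimension_numbers(lhs.shape, rhs.shape, ('NHWC', 'HWIO', 'NHWC'))

out = lax.conv_general_dilated(lhs, rhs, window_strides=(1, 1),
                               padding='SAME', dimension_numbers=dnums)
# out.shape == (1, 8, 8, 16)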

jax.lax.ConvGeneralDilatedDimensionNumbers#

alias of tuple[str, str, str] | ConvDimensionNumbers | None

class jax.lax.GatherDimensionNumbers(offset_dims, collapsed_slice_dims, start_index_map)[source]#

Describes the dimension number arguments to XLA’s Gather operator. See the XLA documentation for more details on what the dimension numbers mean.

Parameters:
  • offset_dims (tuple[int, ...]) – the set of dimensions in the gather output that offset into an array sliced from operand. Must be a tuple of integers in ascending order, each representing a dimension number of the output.

  • collapsed_slice_dims (tuple[int, ...]) – the set of dimensions i in operand that have slice_sizes[i] == 1 and that should not have a corresponding dimension in the output of the gather. Must be a tuple of integers in ascending order.

  • start_index_map (tuple[int, ...]) – for each dimension in start_indices, gives the corresponding dimension in the operand that is to be sliced. Must be a tuple of integers with size equal to start_indices.shape[-1].

Unlike XLA’s GatherDimensionNumbers structure, index_vector_dim is implicit; there is always an index vector dimension and it must always be the last dimension. To gather scalar indices, add a trailing dimension of size 1.
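
As a small sketch of these arguments in action, gathering two rows of a matrix (the shapes and values are arbitrary):

import jax.numpy as jnp
from jax import lax

operand = jnp.arange(12.0).reshape(3, 4)
start_indices = jnp.array([[0], [2]])        # trailing index-vector dimension of size 1

dnums = lax.GatherDimensionNumbers(
    offset_dims=(1,),             # output dimension 1 holds the gathered row contents
    collapsed_slice_dims=(0,),    # the size-1 slice along operand dimension 0 is dropped
    start_index_map=(0,),         # indices index into operand dimension 0
)
out = lax.gather(operand, start_indices, dimension_numbers=dnums, slice_sizes=(1, 4))
# out has shape (2, 4): rows 0 and 2 of operand.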

class jax.lax.GatherScatterMode(value)[source]#

Describes how to handle out-of-bounds indices in a gather or scatter.

Possible values are:

CLIP:

Indices will be clamped to the nearest in-range value, i.e., such that the entire window to be gathered is in-range.

FILL_OR_DROP:

If any part of a gathered window is out of bounds, the entire window that is returned, even those elements that were otherwise in-bounds, will be filled with a constant. If any part of a scattered window is out of bounds, the entire window will be discarded.

PROMISE_IN_BOUNDS:

The user promises that indices are in bounds. No additional checking will be performed. In practice, with the current XLA implementation this means that out-of-bounds gathers will be clamped but out-of-bounds scatters will be discarded. Gradients will not be correct if indices are out-of-bounds.
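
The string aliases 'clip', 'fill' (or 'drop'), and 'promise_in_bounds' accepted by jax.numpy indexing correspond to these modes; a small sketch:

import jax.numpy as jnp

x = jnp.arange(5.0)
idx = jnp.array([1, 9])                       # 9 is out of bounds

x.at[idx].get(mode='clip')                    # CLIP: uses the nearest in-range index, x[4]
x.at[idx].get(mode='fill', fill_value=-1.0)   # FILL_OR_DROP: out-of-bounds gather returns -1.0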

class jax.lax.Precision(value)[source]#

Precision enum for lax functions

The precision argument to JAX functions generally controls the tradeoff between speed and accuracy for array computations on accelerator backends (i.e. TPU and GPU). Members are:

DEFAULT:

Fastest mode, but least accurate. Performs computations in bfloat16. Aliases: 'default', 'fastest', 'bfloat16'.

HIGH:

Slower but more accurate. Performs float32 computations in 3 bfloat16 passes, or using tensorfloat32 where available. Aliases: 'high', 'bfloat16_3x', 'tensorfloat32'.

HIGHEST:

Slowest but most accurate. Performs computations in float32 or float64 as applicable. Aliases: 'highest', 'float32'.

jax.lax.PrecisionLike#

alias of str | Precision | tuple[str, str] | tuple[Precision, Precision] | None
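
Any of the PrecisionLike forms can be passed via the precision argument; for example (a sketch with arbitrary shapes):

import jax.numpy as jnp
from jax import lax

x = jnp.ones((128, 128), dtype=jnp.float32)
y = jnp.ones((128, 128), dtype=jnp.float32)

fast = lax.dot(x, y, precision=lax.Precision.DEFAULT)   # enum member
accurate = lax.dot(x, y, precision='highest')           # string alias
# A pair applies a possibly different precision to each operand:
mixed = lax.dot(x, y, precision=(lax.Precision.HIGH, lax.Precision.HIGHEST))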

class jax.lax.RoundingMethod(value)[source]#

An enumeration.

class jax.lax.ScatterDimensionNumbers(update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims)[source]#

Describes the dimension number arguments to XLA’s Scatter operator. See the XLA documentation for more details on what the dimension numbers mean.

Parameters:
  • update_window_dims (Sequence[int]) – the set of dimensions in the updates that are window dimensions. Must be a tuple of integers in ascending order, each representing a dimension number.

  • inserted_window_dims (Sequence[int]) – the set of size 1 window dimensions that must be inserted into the shape of updates. Must be a tuple of integers in ascending order, each representing a dimension number of the output. These are the mirror image of collapsed_slice_dims in the case of gather.

  • scatter_dims_to_operand_dims (Sequence[int]) – for each dimension in scatter_indices, gives the corresponding dimension in operand. Must be a sequence of integers with size equal to scatter_indices.shape[-1].

Unlike XLA’s ScatterDimensionNumbers structure, index_vector_dim is implicit; there is always an index vector dimension and it must always be the last dimension. To scatter scalar indices, add a trailing dimension of size 1.
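
A small sketch mirroring the gather example above, scatter-adding one row of updates into each indexed row (the shapes and values are arbitrary):

import jax.numpy as jnp
from jax import lax

operand = jnp.zeros((3, 4))
scatter_indices = jnp.array([[0], [2]])   # trailing index-vector dimension of size 1
updates = jnp.ones((2, 4))                # one row of updates per index

dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(1,),              # updates dimension 1 is the window (row) dimension
    inserted_window_dims=(0,),            # a size-1 window dimension is inserted along operand dim 0
    scatter_dims_to_operand_dims=(0,),    # indices index into operand dimension 0
)
out = lax.scatter_add(operand, scatter_indices, updates, dnums)
# Rows 0 and 2 of out are ones; row 1 remains zero.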

jax.random module#

Utilities for pseudo-random number generation.

The jax.random package provides a number of routines for deterministic generation of sequences of pseudorandom numbers.

Basic usage#
>>> seed = 1701
>>> num_steps = 100
>>> key = jax.random.key(seed)
>>> for i in range(num_steps):
...   key, subkey = jax.random.split(key)
...   params = compiled_update(subkey, params, next(batches))  
PRNG keys#

Unlike the stateful pseudorandom number generators (PRNGs) that users of NumPy and SciPy may be accustomed to, JAX random functions all require an explicit PRNG state to be passed as a first argument. The random state is described by a special array element type that we call a key, usually generated by the jax.random.key() function:

>>> from jax import random
>>> key = random.key(0)
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]

This key can then be used in any of JAX’s random number generation routines:

>>> random.uniform(key)
Array(0.41845703, dtype=float32)

Note that using a key does not modify it, so reusing the same key will lead to the same result:

>>> random.uniform(key)
Array(0.41845703, dtype=float32)

If you need a new random number, you can use jax.random.split() to generate new subkeys:

>>> key, subkey = random.split(key)
>>> random.uniform(subkey)
Array(0.10536897, dtype=float32)

Note

Typed key arrays, with element types such as key<fry> above, were introduced in JAX v0.4.16. Before then, keys were conventionally represented in uint32 arrays, whose final dimension represented the key’s bit-level representation.

Both forms of key array can still be created and used with the jax.random module. New-style typed key arrays are made with jax.random.key(). Legacy uint32 key arrays are made with jax.random.PRNGKey().

To convert between the two, use jax.random.key_data() and jax.random.wrap_key_data(). The legacy key format may be needed when interfacing with systems outside of JAX (e.g. exporting arrays to a serializable format), or when passing keys to JAX-based libraries that assume the legacy format.

Otherwise, typed keys are recommended. Caveats of legacy keys relative to typed ones include:

  • They have an extra trailing dimension.

  • They have a numeric dtype (uint32), allowing for operations that are typically not meant to be carried out over keys, such as integer arithmetic.

  • They do not carry information about the RNG implementation. When legacy keys are passed to jax.random functions, a global configuration setting determines the RNG implementation (see “Advanced RNG configuration” below).

To learn more about this upgrade, and the design of key types, see JEP 9263.
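
For example, converting between the two key representations looks like this (a short sketch):

import jax

typed_key = jax.random.key(42)          # new-style typed key array
legacy_key = jax.random.PRNGKey(42)     # legacy uint32 key array of shape (2,)

raw_bits = jax.random.key_data(typed_key)        # recover the underlying uint32 bits
roundtrip = jax.random.wrap_key_data(raw_bits)   # wrap them back into a typed key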

Advanced#
Design and background#

TLDR: JAX PRNG = Threefry counter PRNG + a functional array-oriented splitting model

See docs/jep/263-prng.md for more details.

To summarize, among other requirements, the JAX PRNG aims to:

  1. ensure reproducibility,

  2. parallelize well, both in terms of vectorization (generating array values) and multi-replica, multi-core computation. In particular it should not use sequencing constraints between random function calls.

Advanced RNG configuration#

JAX provides several PRNG implementations. A specific one can be selected with the optional impl keyword argument to jax.random.key. When no impl option is passed to the key constructor, the implementation is determined by the global jax_default_prng_impl configuration flag.

  • default, “threefry2x32”: A counter-based PRNG built around the Threefry hash function.

  • experimental: A PRNG that thinly wraps the XLA Random Bit Generator (RBG) algorithm. See the TF documentation.

    • “rbg” uses ThreeFry for splitting, and XLA RBG for data generation.

    • “unsafe_rbg” exists only for demonstration purposes, using RBG both for splitting (with an untested, made-up algorithm) and for generating.

    The random streams generated by these experimental implementations haven’t been subject to any empirical randomness testing (e.g. Big Crush). The random bits generated may change between JAX versions.

The possible reasons not to use the default RNG are:

  1. it may be slow to compile (specifically for Google Cloud TPUs)

  2. it’s slower to execute on TPUs

  3. it doesn’t support efficient automatic sharding / partitioning

Here is a short summary:

Property                             Threefry  Threefry*  rbg  unsafe_rbg  rbg**  unsafe_rbg**

Fastest on TPU                                            yes  yes         yes    yes

efficiently shardable (w/ pjit)                yes                         yes    yes

identical across shardings           yes       yes        yes  yes

identical across CPU/GPU/TPU         yes       yes

identical across JAX/XLA versions    yes       yes

(*): with jax_threefry_partitionable=1 set

(**): with XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1 set

The difference between “rbg” and “unsafe_rbg” is that while “rbg” uses a less robust/studied hash function for random value generation (but not for jax.random.split or jax.random.fold_in), “unsafe_rbg” additionally uses less robust hash functions for jax.random.split and jax.random.fold_in. It is therefore less safe in the sense that the quality of the random streams it generates from different keys is less well understood.

For more about jax_threefry_partitionable, see https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
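
For example, selecting an implementation per key or globally (a sketch; 'rbg' is one of the implementation names listed above):

import jax

default_key = jax.random.key(0)              # default "threefry2x32" implementation
rbg_key = jax.random.key(0, impl='rbg')      # per-key selection of the experimental RBG PRNG

# Or set the global default via the configuration flag:
jax.config.update('jax_default_prng_impl', 'rbg')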

API Reference#
Key Creation & Manipulation#

PRNGKey(seed, *[, impl])

Create a pseudo-random number generator (PRNG) key given an integer seed.

key(seed, *[, impl])

Create a pseudo-random number generator (PRNG) key given an integer seed.

key_data(keys)

Recover the bits of key data underlying a PRNG key array.

wrap_key_data(key_bits_array, *[, impl])

Wrap an array of key data bits into a PRNG key array.

fold_in(key, data)

Folds in data to a PRNG key to form a new PRNG key.

split(key[, num])

Splits a PRNG key into num new keys by adding a leading axis.

clone(key)

Clone a key for reuse

Random Samplers#

ball(key, d[, p, shape, dtype])

Sample uniformly from the unit Lp ball.

bernoulli(key[, p, shape])

Sample Bernoulli random values with given shape and mean.

beta(key, a, b[, shape, dtype])

Sample Beta random values with given shape and float dtype.

binomial(key, n, p[, shape, dtype])

Sample Binomial random values with given shape and float dtype.

bits(key[, shape, dtype])

Sample uniform bits in the form of unsigned integers.

categorical(key, logits[, axis, shape])

Sample random values from categorical distributions.

cauchy(key[, shape, dtype])

Sample Cauchy random values with given shape and float dtype.

chisquare(key, df[, shape, dtype])

Sample Chisquare random values with given shape and float dtype.

choice(key, a[, shape, replace, p, axis])

Generates a random sample from a given array.

dirichlet(key, alpha[, shape, dtype])

Sample Dirichlet random values with given shape and float dtype.

double_sided_maxwell(key, loc, scale[, ...])

Sample from a double sided Maxwell distribution.

exponential(key[, shape, dtype])

Sample Exponential random values with given shape and float dtype.

f(key, dfnum, dfden[, shape, dtype])

Sample F-distribution random values with given shape and float dtype.

gamma(key, a[, shape, dtype])

Sample Gamma random values with given shape and float dtype.

generalized_normal(key, p[, shape, dtype])

Sample from the generalized normal distribution.

geometric(key, p[, shape, dtype])

Sample Geometric random values with given shape and float dtype.

gumbel(key[, shape, dtype])

Sample Gumbel random values with given shape and float dtype.

laplace(key[, shape, dtype])

Sample Laplace random values with given shape and float dtype.

loggamma(key, a[, shape, dtype])

Sample log-gamma random values with given shape and float dtype.

logistic(key[, shape, dtype])

Sample logistic random values with given shape and float dtype.

lognormal(key[, sigma, shape, dtype])

Sample lognormal random values with given shape and float dtype.

maxwell(key[, shape, dtype])

Sample from a one sided Maxwell distribution.

multivariate_normal(key, mean, cov[, shape, ...])

Sample multivariate normal random values with given mean and covariance.

normal(key[, shape, dtype])

Sample standard normal random values with given shape and float dtype.

orthogonal(key, n[, shape, dtype])

Sample uniformly from the orthogonal group O(n).

pareto(key, b[, shape, dtype])

Sample Pareto random values with given shape and float dtype.

permutation(key, x[, axis, independent])

Returns a randomly permuted array or range.

poisson(key, lam[, shape, dtype])

Sample Poisson random values with given shape and integer dtype.

rademacher(key, shape[, dtype])

Sample from a Rademacher distribution.

randint(key, shape, minval, maxval[, dtype])

Sample uniform random values in [minval, maxval) with given shape/dtype.

rayleigh(key, scale[, shape, dtype])

Sample Rayleigh random values with given shape and float dtype.

t(key, df[, shape, dtype])

Sample Student's t random values with given shape and float dtype.

triangular(key, left, mode, right[, shape, ...])

Sample Triangular random values with given shape and float dtype.

truncated_normal(key, lower, upper[, shape, ...])

Sample truncated standard normal random values with given shape and dtype.

uniform(key[, shape, dtype, minval, maxval])

Sample uniform random values in [minval, maxval) with given shape/dtype.

wald(key, mean[, shape, dtype])

Sample Wald random values with given shape and float dtype.

weibull_min(key, scale, concentration[, ...])

Sample from a Weibull distribution.

jax.sharding module#

Classes#
class jax.sharding.Sharding#

Describes how a jax.Array is laid out across devices.

property addressable_devices: set[Device]#

The set of devices in the Sharding that are addressable by the current process.

addressable_devices_indices_map(global_shape)[source]#

A mapping from addressable devices to the slice of array data each contains.

addressable_devices_indices_map contains that part of devices_indices_map that applies to the addressable devices.

Parameters:

global_shape (tuple[int, ...])

Return type:

Mapping[Device, tuple[slice, …] | None]

property device_set: set[Device][source]#

The set of devices that this Sharding spans.

In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.

devices_indices_map(global_shape)[source]#

Returns a mapping from devices to the array slices each contains.

The mapping includes all global devices, i.e., including non-addressable devices from other processes.

Parameters:

global_shape (tuple[int, ...])

Return type:

Mapping[Device, tuple[slice, …] | None]

is_equivalent_to(other, ndim)[source]#

Returns True if two shardings are equivalent.

Two shardings are equivalent if they place the same logical array shards on the same devices.

For example, a NamedSharding may be equivalent to a PositionalSharding if both place the same shards of the array on the same devices.

Parameters:
Return type:

bool

property is_fully_addressable: bool[source]#

Is this sharding fully addressable?

A sharding is fully addressable if the current process can address all of the devices named in the Sharding. is_fully_addressable is equivalent to “is_local” in multi-process JAX.

property is_fully_replicated: bool[source]#

Is this sharding fully replicated?

A sharding is fully replicated if each device has a complete copy of the entire data.

property memory_kind: str | None[source]#

Returns the memory kind of the sharding.

shard_shape(global_shape)[source]#

Returns the shape of the data on each device.

The shard shape returned by this function is calculated from global_shape and the properties of the sharding.

Parameters:

global_shape (tuple[int, ...])

Return type:

tuple[int, …]

with_memory_kind(kind)[source]#

Returns a new Sharding instance with the specified memory kind.

Parameters:

kind (str)

Return type:

Sharding

class jax.sharding.XLACompatibleSharding#

Bases: Sharding

A Sharding that describes shardings expressible to XLA.

Subclasses of XLACompatibleSharding work with all JAX APIs and transformations that use XLA.

devices_indices_map(global_shape)[source]#

Returns a mapping from devices to the array slices each contains.

The mapping includes all global devices, i.e., including non-addressable devices from other processes.

Parameters:

global_shape (tuple[int, ...])

Return type:

Mapping[Device, tuple[slice, …]]

is_equivalent_to(other, ndim)[source]#

Returns True if two shardings are equivalent.

Two shardings are equivalent if they place the same logical array shards on the same devices.

For example, a NamedSharding may be equivalent to a PositionalSharding if both place the same shards of the array on the same devices.

Parameters:
Return type:

bool

shard_shape(global_shape)[source]#

Returns the shape of the data on each device.

The shard shape returned by this function is calculated from global_shape and the properties of the sharding.

Parameters:

global_shape (tuple[int, ...])

Return type:

tuple[int, …]

class jax.sharding.SingleDeviceSharding#

Bases: XLACompatibleSharding

A Sharding that places its data on a single device.

Parameters:

device – A single Device.

Example

>>> single_device_sharding = jax.sharding.SingleDeviceSharding(
...     jax.devices()[0])
property device_set: set[Device][source]#

The set of devices that this Sharding spans.

In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.

devices_indices_map(global_shape)[source]#

Returns a mapping from devices to the array slices each contains.

The mapping includes all global devices, i.e., including non-addressable devices from other processes.

Parameters:

global_shape (tuple[int, ...])

Return type:

Mapping[Device, tuple[slice, …]]

property is_fully_addressable: bool[source]#

Is this sharding fully addressable?

A sharding is fully addressable if the current process can address all of the devices named in the Sharding. is_fully_addressable is equivalent to “is_local” in multi-process JAX.

property is_fully_replicated: bool[source]#

Is this sharding fully replicated?

A sharding is fully replicated if each device has a complete copy of the entire data.

property memory_kind: str | None[source]#

Returns the memory kind of the sharding.

with_memory_kind(kind)[source]#

Returns a new Sharding instance with the specified memory kind.

Parameters:

kind (str)

Return type:

SingleDeviceSharding

class jax.sharding.NamedSharding#

Bases: XLACompatibleSharding

A NamedSharding expresses sharding using named axes.

A NamedSharding is a pair of a Mesh of devices and PartitionSpec which describes how to shard an array across that mesh.

A Mesh is a multidimensional NumPy array of JAX devices, where each axis of the mesh has a name, e.g. 'x' or 'y'.

A PartitionSpec is a tuple, whose elements can be a None, a mesh axis, or a tuple of mesh axes. Each element describes how an input dimension is partitioned across zero or more mesh dimensions. For example, PartitionSpec('x', 'y') says that the first dimension of data is sharded across x axis of the mesh, and the second dimension is sharded across y axis of the mesh.

The Distributed arrays and automatic parallelization (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names) tutorial has more details and diagrams that explain how Mesh and PartitionSpec are used.

Parameters:

Example

>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> spec = P('x', 'y')
>>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
property addressable_devices: set[Device][source]#

The set of devices in the Sharding that are addressable by the current process.

property device_set: set[Device][source]#

The set of devices that this Sharding spans.

In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.

property is_fully_addressable: bool[source]#

Is this sharding fully addressable?

A sharding is fully addressable if the current process can address all of the devices named in the Sharding. is_fully_addressable is equivalent to “is_local” in multi-process JAX.

property is_fully_replicated: bool#

Is this sharding fully replicated?

A sharding is fully replicated if each device has a complete copy of the entire data.

property memory_kind: str | None[source]#

Returns the memory kind of the sharding.

property mesh#

(self) -> object

property spec#

(self) -> object

with_memory_kind(kind)[source]#

Returns a new Sharding instance with the specified memory kind.

Parameters:

kind (str)

Return type:

NamedSharding

class jax.sharding.PositionalSharding(devices, *, memory_kind=None)[source]#

Bases: XLACompatibleSharding

Parameters:
  • devices (Sequence[xc.Device] | np.ndarray)

  • memory_kind (str | None)

property device_set: set[Device]#

The set of devices that this Sharding spans.

In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.

property is_fully_addressable: bool#

Is this sharding fully addressable?

A sharding is fully addressable if the current process can address all of the devices named in the Sharding. is_fully_addressable is equivalent to “is_local” in multi-process JAX.

property is_fully_replicated: bool#

Is this sharding fully replicated?

A sharding is fully replicated if each device has a complete copy of the entire data.

property memory_kind: str | None[source]#

Returns the memory kind of the sharding.

with_memory_kind(kind)[source]#

Returns a new Sharding instance with the specified memory kind.

Parameters:

kind (str)

Return type:

PositionalSharding

class jax.sharding.PmapSharding#

Bases: XLACompatibleSharding

Describes a sharding used by jax.pmap().

classmethod default(shape, sharded_dim=0, devices=None)[source]#

Creates a PmapSharding which matches the default placement used by jax.pmap().

Parameters:
  • shape (tuple[int, ...]) – The shape of the input array.

  • sharded_dim (int) – Dimension the input array is sharded on. Defaults to 0.

  • devices (Sequence[Device] | None) – Optional sequence of devices to use. If omitted, the implicit device order used by pmap is used, which is the order of jax.local_devices().

Return type:

PmapSharding

property device_set: set[Device]#

The set of devices that this Sharding spans.

In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.

property devices#

(self) -> ndarray

devices_indices_map(global_shape)[source]#

Returns a mapping from devices to the array slices each contains.

The mapping includes all global devices, i.e., including non-addressable devices from other processes.

Parameters:

global_shape (tuple[int, ...])

Return type:

Mapping[Device, tuple[slice, …]]

is_equivalent_to(other, ndim)[source]#

Returns True if two shardings are equivalent.

Two shardings are equivalent if they place the same logical array shards on the same devices.

For example, a NamedSharding may be equivalent to a PositionalSharding if both place the same shards of the array on the same devices.

Parameters:
Return type:

bool

property is_fully_addressable: bool#

Is this sharding fully addressable?

A sharding is fully addressable if the current process can address all of the devices named in the Sharding. is_fully_addressable is equivalent to “is_local” in multi-process JAX.

property is_fully_replicated: bool#

Is this sharding fully replicated?

A sharding is fully replicated if each device has a complete copy of the entire data.

property memory_kind: str | None[source]#

Returns the memory kind of the sharding.

shard_shape(global_shape)[source]#

Returns the shape of the data on each device.

The shard shape returned by this function is calculated from global_shape and the properties of the sharding.

Parameters:

global_shape (tuple[int, ...])

Return type:

tuple[int, …]

property sharding_spec#

(self) -> jax::ShardingSpec

with_memory_kind(kind)[source]#

Returns a new Sharding instance with the specified memory kind.

Parameters:

kind (str)

class jax.sharding.GSPMDSharding#

Bases: XLACompatibleSharding

property device_set: set[Device]#

The set of devices that this Sharding spans.

In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.

devices_indices_map(global_shape)[source]#

Returns a mapping from devices to the array slices each contains.

The mapping includes all global devices, i.e., including non-addressable devices from other processes.

Parameters:

global_shape (tuple[int, ...])

Return type:

Mapping[Device, tuple[slice, …]]

property is_fully_addressable: bool#

Is this sharding fully addressable?

A sharding is fully addressable if the current process can address all of the devices named in the Sharding. is_fully_addressable is equivalent to “is_local” in multi-process JAX.

property is_fully_replicated: bool#

Is this sharding fully replicated?

A sharding is fully replicated if each device has a complete copy of the entire data.

property memory_kind: str | None[source]#

Returns the memory kind of the sharding.

with_memory_kind(kind)[source]#

Returns a new Sharding instance with the specified memory kind.

Parameters:

kind (str)

Return type:

GSPMDSharding

class jax.sharding.PartitionSpec(*partitions)[source]#

Tuple describing how to partition an array across a mesh of devices.

Each element is either None, a string, or a tuple of strings. See the documentation of jax.sharding.NamedSharding for more details.

This class exists so JAX’s pytree utilities can distinguish partition specifications from tuples that should be treated as pytrees.

class jax.sharding.Mesh(devices, axis_names)[source]#

Declare the hardware resources available in the scope of this manager.

In particular, all axis_names become valid resource names inside the managed block and can be used e.g. in the in_axis_resources argument of jax.experimental.pjit.pjit(). Also see JAX’s multi-process programming model (https://jax.readthedocs.io/en/latest/multi_process.html) and the Distributed arrays and automatic parallelization tutorial (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)

If you are compiling in multiple threads, make sure that the with Mesh context manager is inside the function that the threads will execute.

Parameters:
  • devices (ndarray) – A NumPy ndarray object containing JAX device objects (as obtained e.g. from jax.devices()).

  • axis_names (tuple[Any, ...]) – A sequence of resource axis names to be assigned to the dimensions of the devices argument. Its length should match the rank of devices.

Example

>>> from jax.experimental.pjit import pjit
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> inp = np.arange(16).reshape((8, 2))
>>> devices = np.array(jax.devices()).reshape(4, 2)
...
>>> # Declare a 2D mesh with axes `x` and `y`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> # Use the mesh object directly as a context manager.
>>> with global_mesh:
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # Initialize the Mesh and use the mesh as the context manager.
>>> with Mesh(devices, ('x', 'y')) as global_mesh:
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # Also you can use it as `with ... as ...`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> with global_mesh as m:
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # You can also use it as `with Mesh(...)`.
>>> with Mesh(devices, ('x', 'y')):
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)

jax.debug module#

Runtime value debugging utilities#

The jax.debug.print and jax.debug.breakpoint guide describes how to make use of JAX’s runtime value debugging features.

callback(callback, *args[, ordered])

Calls a stageable Python callback.

print(fmt, *args[, ordered])

Prints values and works in staged out JAX functions.

breakpoint(*[, backend, filter_frames, ...])

Enters a breakpoint at a point in a program.
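
For instance, jax.debug.print works inside staged-out functions, where an ordinary Python print would only run at trace time (a minimal sketch):

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  jax.debug.print("x = {x}", x=x)   # printed at runtime, even under jit
  return x * 2

f(jnp.arange(3))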

Sharding debugging utilities#

Functions that enable inspecting and visualizing array shardings inside (and outside) staged functions.

inspect_array_sharding(value, *, callback)

Enables inspecting array sharding inside JIT-ted functions.

visualize_array_sharding(arr, **kwargs)

Visualizes an array's sharding.

visualize_sharding(shape, sharding, *[, ...])

Visualizes a Sharding using rich.

jax.dlpack module#

from_dlpack(external_array[, device, copy])

Returns an Array representation of a DLPack tensor.

to_dlpack(x[, stream, src_device, ...])

Returns a DLPack tensor that encapsulates an Array x.

jax.distributed module#

initialize([coordinator_address, ...])

Initializes the JAX distributed system.

shutdown()

Shuts down the distributed system.

jax.dtypes module#

bfloat16

bfloat16 floating-point values

canonicalize_dtype(dtype[, allow_extended_dtype])

Convert from a dtype to a canonical dtype based on config.x64_enabled.

float0

DType class corresponding to the scalar type and dtype of the same name.

issubdtype(a, b)

Returns True if first argument is a typecode lower/equal in type hierarchy.

prng_key()

Scalar class for PRNG Key dtypes.

result_type(*args[, return_weak_type_flag])

Convenience function to apply JAX argument dtype promotion.

scalar_type_of(x)

Return the scalar type associated with a JAX value.

jax.flatten_util module#

List of Functions#

ravel_pytree(pytree)

Ravel (flatten) a pytree of arrays down to a 1D array.

jax.image module#

Image manipulation functions.

More image manipulation functions can be found in libraries built on top of JAX, such as PIX.

Image manipulation functions#

resize(image, shape, method[, antialias, ...])

Image resize.

scale_and_translate(image, shape, ...[, ...])

Apply a scale and translation to an image.

Argument classes#
class jax.image.ResizeMethod(value)[source]#

Image resize method.

Possible values are:

NEAREST:

Nearest-neighbor interpolation.

LINEAR:

Linear interpolation.

LANCZOS3:

Lanczos resampling, using a kernel of radius 3.

LANCZOS5:

Lanczos resampling, using a kernel of radius 5.

CUBIC:

Cubic interpolation, using the Keys cubic kernel.
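
Methods can be passed either as the enum or as their lower-case string names; for example (a sketch with an arbitrary image shape):

import jax.numpy as jnp
import jax.image

img = jnp.zeros((32, 32, 3))
up = jax.image.resize(img, shape=(64, 64, 3), method='linear')
up2 = jax.image.resize(img, shape=(64, 64, 3), method=jax.image.ResizeMethod.LANCZOS3)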

jax.nn module#

jax.nn.initializers module#

Common neural network layer initializers, consistent with definitions used in Keras and Sonnet.

Initializers#

This module provides common neural network layer initializers, consistent with definitions used in Keras and Sonnet.

An initializer is a function that takes three arguments: (key, shape, dtype) and returns an array with dimensions shape and data type dtype. Argument key is a PRNG key (e.g. from jax.random.key()), used to generate random numbers to initialize the array.
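
For example (a minimal sketch; the weight shape is arbitrary):

import jax
from jax.nn import initializers

init = initializers.glorot_normal()   # returns an initializer function
key = jax.random.key(0)
w = init(key, (128, 64))              # array of shape (128, 64), float32 by default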

constant(value[, dtype])

Builds an initializer that returns arrays full of a constant value.

delta_orthogonal([scale, column_axis, dtype])

Builds an initializer for delta orthogonal kernels.

glorot_normal([in_axis, out_axis, ...])

Builds a Glorot normal initializer (aka Xavier normal initializer).

glorot_uniform([in_axis, out_axis, ...])

Builds a Glorot uniform initializer (aka Xavier uniform initializer).

he_normal([in_axis, out_axis, batch_axis, dtype])

Builds a He normal initializer (aka Kaiming normal initializer).

he_uniform([in_axis, out_axis, batch_axis, ...])

Builds a He uniform initializer (aka Kaiming uniform initializer).

lecun_normal([in_axis, out_axis, ...])

Builds a Lecun normal initializer.

lecun_uniform([in_axis, out_axis, ...])

Builds a Lecun uniform initializer.

normal([stddev, dtype])

Builds an initializer that returns real normally-distributed random arrays.

ones(key, shape[, dtype])

An initializer that returns a constant array full of ones.

orthogonal([scale, column_axis, dtype])

Builds an initializer that returns uniformly distributed orthogonal matrices.

truncated_normal([stddev, dtype, lower, upper])

Builds an initializer that returns truncated-normal random arrays.

uniform([scale, dtype])

Builds an initializer that returns real uniformly-distributed random arrays.

variance_scaling(scale, mode, distribution)

Initializer that adapts its scale to the shape of the weights tensor.

zeros(key, shape[, dtype])

An initializer that returns a constant array full of zeros.

Common functions for neural network libraries.

Activation functions#

relu

Rectified linear unit activation function.

relu6

Rectified Linear Unit 6 activation function.

sigmoid(x)

Sigmoid activation function.

softplus(x)

Softplus activation function.

sparse_plus(x)

Sparse plus function.

soft_sign(x)

Soft-sign activation function.

silu(x)

SiLU (aka swish) activation function.

swish(x)

SiLU (aka swish) activation function.

log_sigmoid(x)

Log-sigmoid activation function.

leaky_relu(x[, negative_slope])

Leaky rectified linear unit activation function.

hard_sigmoid(x)

Hard Sigmoid activation function.

hard_silu(x)

Hard SiLU (swish) activation function

hard_swish(x)

Hard SiLU (swish) activation function

hard_tanh(x)

Hard \(\mathrm{tanh}\) activation function.

elu(x[, alpha])

Exponential linear unit activation function.

celu(x[, alpha])

Continuously-differentiable exponential linear unit activation.

selu(x)

Scaled exponential linear unit activation.

gelu(x[, approximate])

Gaussian error linear unit activation function.

glu(x[, axis])

Gated linear unit activation function.

squareplus(x[, b])

Squareplus activation function.

mish(x)

Mish activation function.

Other functions#

softmax(x[, axis, where, initial])

Softmax function.

log_softmax(x[, axis, where, initial])

Log-Softmax function.

logsumexp()

Log-sum-exp reduction.

standardize(x[, axis, mean, variance, ...])

Normalizes an array by subtracting mean and dividing by \(\sqrt{\mathrm{variance}}\).

one_hot(x, num_classes, *[, dtype, axis])

One-hot encodes the given indices.

jax.ops module#

The functions jax.ops.index_update, jax.ops.index_add, etc., which were deprecated in JAX 0.2.22, have been removed. Please use the jax.numpy.ndarray.at property on JAX arrays instead.
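
For example, what used to be written with jax.ops.index_add is now expressed through the .at property (a short sketch):

import jax.numpy as jnp

x = jnp.zeros(5)
y = x.at[2].add(10.0)    # functional update: x is unchanged, y has 10.0 at index 2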

Segment reduction operators#

segment_max(data, segment_ids[, ...])

Computes the maximum within segments of an array.

segment_min(data, segment_ids[, ...])

Computes the minimum within segments of an array.

segment_prod(data, segment_ids[, ...])

Computes the product within segments of an array.

segment_sum(data, segment_ids[, ...])

Computes the sum within segments of an array.
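
A minimal sketch of a segment reduction (the values and segment ids are arbitrary):

import jax.numpy as jnp
from jax import ops

data = jnp.array([1.0, 2.0, 3.0, 4.0])
segment_ids = jnp.array([0, 0, 1, 1])
ops.segment_sum(data, segment_ids)    # -> [3., 7.]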

jax.profiler module#

Tracing and time profiling#

Profiling JAX programs describes how to make use of JAX’s tracing and time profiling features.

start_server(port)

Starts the profiler server on port port.

start_trace(log_dir[, create_perfetto_link, ...])

Starts a profiler trace.

stop_trace()

Stops the currently-running profiler trace.

trace(log_dir[, create_perfetto_link, ...])

Context manager to take a profiler trace.

annotate_function(func[, name])

Decorator that generates a trace event for the execution of a function.

TraceAnnotation

Context manager that generates a trace event in the profiler.

StepTraceAnnotation(name, **kwargs)

Context manager that generates a step trace event in the profiler.
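
For example, the trace context manager wraps a region of the program (a sketch; the log directory path is arbitrary):

import jax
import jax.numpy as jnp

with jax.profiler.trace("/tmp/jax-trace"):
  x = jnp.ones((1000, 1000))
  (x @ x).block_until_ready()   # ensure the work happens inside the traced region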

Device memory profiling#

See Device Memory Profiling for an introduction to JAX’s device memory profiling features.

device_memory_profile([backend])

Captures a JAX device memory profile as pprof-format protocol buffer.

save_device_memory_profile(filename[, backend])

Collects a device memory profile and writes it to a file.

jax.stages module#

Interfaces to stages of the compiled execution process.

JAX transformations that compile just in time for execution, such as jax.jit and jax.pmap, also support a common means of explicit lowering and compilation ahead of time. This module defines types that represent the stages of this process.

For more, see the AOT walkthrough.
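
A compact sketch of the explicit lowering-and-compilation path described above:

import jax
import jax.numpy as jnp

def f(x):
  return x * 2

wrapped = jax.jit(f)                        # jax.stages.Wrapped
lowered = wrapped.lower(jnp.arange(4.0))    # jax.stages.Lowered
print(lowered.as_text())                    # compiler input (e.g. StableHLO) as text
compiled = lowered.compile()                # jax.stages.Compiled
compiled(jnp.arange(4.0))                   # execute the ahead-of-time compiled function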

Classes#
class jax.stages.Wrapped(*args, **kwargs)[source]#

A function ready to be specialized, lowered, and compiled.

This protocol reflects the output of functions such as jax.jit. Calling it results in JIT (just-in-time) lowering, compilation, and execution. It can also be explicitly lowered prior to compilation, and the result compiled prior to execution.

__call__(*args, **kwargs)[source]#

Executes the wrapped function, lowering and compiling as needed.

lower(*args, **kwargs)[source]#

Lower this function explicitly for the given arguments.

A lowered function is staged out of Python and translated to a compiler’s input language, possibly in a backend-dependent manner. It is ready for compilation but not yet compiled.

Returns:

A Lowered instance representing the lowering.

Return type:

Lowered

class jax.stages.Lowered(lowering, args_info, out_tree, no_kwargs=False)[source]#

Lowering of a function specialized to argument types and values.

A lowering is a computation ready for compilation. This class carries a lowering together with the remaining information needed to later compile and execute it. It also provides a common API for querying properties of lowered computations across JAX’s various lowering paths (jit(), pmap(), etc.).

Parameters:
  • lowering (XlaLowering)

  • args_info (Any)

  • out_tree (PyTreeDef)

  • no_kwargs (bool)

as_text(dialect=None)[source]#

A human-readable text representation of this lowering.

Intended for visualization and debugging purposes. This need not be a valid nor reliable serialization. It is relayed directly to external callers.

Parameters:

dialect (str | None) – Optional string specifying a lowering dialect (e.g. “stablehlo”)

Return type:

str

compile(compiler_options=None)[source]#

Compile, returning a corresponding Compiled instance.

Parameters:

compiler_options (dict[str, str | bool] | None)

Return type:

Compiled

compiler_ir(dialect=None)[source]#

An arbitrary object representation of this lowering.

Intended for debugging purposes. This is not a valid nor reliable serialization. The output has no guarantee of consistency across invocations.

Returns None if unavailable, e.g. based on backend, compiler, or runtime.

Parameters:

dialect (str | None) – Optional string specifying a lowering dialect (e.g. “stablehlo”)

Return type:

Any | None

cost_analysis()[source]#

A summary of execution cost estimates.

Intended for visualization and debugging purposes. The object output by this is some simple data structure that can easily be printed or serialized (e.g. nested dicts, lists, and tuples with numeric leaves). However, its structure can be arbitrary: it may be inconsistent across versions of JAX and jaxlib, or even across invocations.

Returns None if unavailable, e.g. based on backend, compiler, or runtime.

Return type:

Any | None

property in_tree: PyTreeDef[source]#

Tree structure of the pair (positional arguments, keyword arguments).

class jax.stages.Compiled(executable, args_info, out_tree, no_kwargs=False)[source]#

Compiled representation of a function specialized to types/values.

A compiled computation is associated with an executable and the remaining information needed to execute it. It also provides a common API for querying properties of compiled computations across JAX’s various compilation paths and backends.

Parameters:
  • args_info (Any)

  • out_tree (PyTreeDef)

__call__(*args, **kwargs)[source]#

Call self as a function.

as_text()[source]#

A human-readable text representation of this executable.

Intended for visualization and debugging purposes. This is not a valid nor reliable serialization.

Returns None if unavailable, e.g. based on backend, compiler, or runtime.

Return type:

str | None

cost_analysis()[source]#

A summary of execution cost estimates.

Intended for visualization and debugging purposes. The object output by this is some simple data structure that can easily be printed or serialized (e.g. nested dicts, lists, and tuples with numeric leaves). However, its structure can be arbitrary: it may be inconsistent across versions of JAX and jaxlib, or even across invocations.

Returns None if unavailable, e.g. based on backend, compiler, or runtime.

Return type:

Any | None

property in_tree: PyTreeDef[source]#

Tree structure of the pair (positional arguments, keyword arguments).

memory_analysis()[source]#

A summary of estimated memory requirements.

Intended for visualization and debugging purposes. The object output by this is some simple data structure that can easily be printed or serialized (e.g. nested dicts, lists, and tuples with numeric leaves). However, its structure can be arbitrary: it may be inconsistent across versions of JAX and jaxlib, or even across invocations.

Returns None if unavailable, e.g. based on backend, compiler, or runtime.

Return type:

Any | None

runtime_executable()[source]#

An arbitrary object representation of this executable.

Intended for debugging purposes. This is not a valid nor reliable serialization. The output has no guarantee of consistency across invocations.

Returns None if unavailable, e.g. based on backend, compiler, or runtime.

Return type:

Any | None

jax.tree module#

Utilities for working with tree-like container data structures.

The jax.tree namespace contains aliases of utilities from jax.tree_util.

List of Functions#

all(tree)

Call all() over the leaves of a tree.

flatten(tree[, is_leaf])

Flattens a pytree.

leaves(tree[, is_leaf])

Gets the leaves of a pytree.

map(f, tree, *rest[, is_leaf])

Maps a multi-input function over pytree args to produce a new pytree.

reduce(function, tree[, initializer, is_leaf])

Call reduce() over the leaves of a tree.

structure(tree[, is_leaf])

Gets the treedef for a pytree.

transpose(outer_treedef, inner_treedef, ...)

Transform a tree having tree structure (outer, inner) into one having structure (inner, outer).

unflatten(treedef, leaves)

Reconstructs a pytree from the treedef and the leaves.
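
For instance, jax.tree.map applies a function to every leaf of a nested container (a short sketch):

import jax

params = {"w": [1.0, 2.0], "b": 3.0}
jax.tree.map(lambda v: v * 2, params)    # doubles every leaf: b -> 6.0, w -> [2.0, 4.0]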

jax.tree_util module#

Utilities for working with tree-like container data structures.

This module provides a small set of utility functions for working with tree-like data structures, such as nested tuples, lists, and dicts. We call these structures pytrees. They are trees in that they are defined recursively (any non-pytree is a pytree, i.e. a leaf, and any pytree of pytrees is a pytree) and can be operated on recursively (object identity equivalence is not preserved by mapping operations, and the structures cannot contain reference cycles).

The set of Python types that are considered pytree nodes (e.g. that can be mapped over, rather than treated as leaves) is extensible. There is a single module-level registry of types, and class hierarchy is ignored. By registering a new pytree node type, that type in effect becomes transparent to the utility functions in this file.

The primary purpose of this module is to enable the interoperability between user defined data structures and JAX transformations (e.g. jit). This is not meant to be a general purpose tree-like data structure handling library.

See the JAX pytrees note for examples.
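
As an illustrative sketch of extending the pytree registry (the Point class here is made up for the example):

import jax
from jax import tree_util

@tree_util.register_pytree_node_class
class Point:
  def __init__(self, x, y):
    self.x, self.y = x, y

  def tree_flatten(self):
    return (self.x, self.y), None            # (children, auxiliary data)

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children)

p = Point(1.0, 2.0)
doubled = tree_util.tree_map(lambda v: v * 2, p)   # Point(2.0, 4.0)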

List of Functions#

Partial(func, *args, **kw)

A version of functools.partial that works in pytrees.

all_leaves(iterable[, is_leaf])

Tests whether all elements in the given iterable are all leaves.

build_tree(treedef, xs)

register_pytree_node(nodetype, flatten_func, ...)

Extends the set of types that are considered internal nodes in pytrees.

register_pytree_node_class(cls)

Extends the set of types that are considered internal nodes in pytrees.

register_pytree_with_keys(nodetype, ...[, ...])

Extends the set of types that are considered internal nodes in pytrees.

register_pytree_with_keys_class(cls)

Extends the set of types that are considered internal nodes in pytrees.

tree_all(tree)

Call all() over the leaves of a tree.

tree_flatten(tree[, is_leaf])

Flattens a pytree.

tree_flatten_with_path(tree[, is_leaf])

Flattens a pytree like tree_flatten, but also returns each leaf's key path.

tree_leaves(tree[, is_leaf])

Gets the leaves of a pytree.

tree_leaves_with_path(tree[, is_leaf])

Gets the leaves of a pytree like tree_leaves and returns each leaf's key path.

tree_map(f, tree, *rest[, is_leaf])

Maps a multi-input function over pytree args to produce a new pytree.

tree_map_with_path(f, tree, *rest[, is_leaf])

Maps a multi-input function over pytree key path and args to produce a new pytree.

tree_reduce()

Call reduce() over the leaves of a tree.

tree_structure(tree[, is_leaf])

Gets the treedef for a pytree.

tree_transpose(outer_treedef, inner_treedef, ...)

Transform a tree having tree structure (outer, inner) into one having structure (inner, outer).

tree_unflatten(treedef, leaves)

Reconstructs a pytree from the treedef and the leaves.

treedef_children(treedef)

treedef_is_leaf(treedef)

treedef_tuple(treedefs)

Makes a tuple treedef from an iterable of child treedefs.

keystr(keys)

Helper to pretty-print a tuple of keys.

jax.typing module#

The JAX typing module is where JAX-specific static type annotations live. This submodule is a work in progress; to see the proposal behind the types exported here, see https://jax.readthedocs.io/en/latest/jep/12049-type-annotations.html.

The currently-available types are:

  • jax.Array: annotation for any JAX array or tracer (i.e. representations of arrays within JAX transforms).

  • jax.typing.ArrayLike: annotation for any value that is safe to implicitly cast to a JAX array; this includes jax.Array, numpy.ndarray, as well as Python builtin numeric values (e.g. int, float, etc.) and numpy scalar values (e.g. numpy.int32, numpy.float64, etc.)

  • jax.typing.DTypeLike: annotation for any value that can be cast to a JAX-compatible dtype; this includes strings (e.g. ‘float32’, ‘int32’), scalar types (e.g. float, np.float32), dtypes (e.g. np.dtype(‘float32’)), or objects with a dtype attribute (e.g. jnp.float32, jnp.int32).

We may add additional types here in future releases.

JAX Typing Best Practices#

When annotating JAX arrays in public API functions, we recommend using ArrayLike for array inputs, and Array for array outputs.

For example, your function might look like this:

import numpy as np
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike

def my_function(x: ArrayLike) -> Array:
  # Runtime type validation, Python 3.10 or newer:
  if not isinstance(x, ArrayLike):
    raise TypeError(f"Expected arraylike input; got {x}")
  # Runtime type validation, any Python version:
  if not (isinstance(x, (np.ndarray, Array)) or np.isscalar(x)):
    raise TypeError(f"Expected arraylike input; got {x}")

  # Convert input to jax.Array:
  x_arr = jnp.asarray(x)

  # ... do some computation; JAX functions will return Array types:
  result = x_arr.sum(0) / x_arr.shape[0]

  # return an Array
  return result

Most of JAX’s public APIs follow this pattern. Note in particular that we recommend that JAX functions not accept sequences such as list or tuple in place of arrays, as this can cause extra overhead in JAX transforms like jit() and can behave in unexpected ways with batch-wise transforms like vmap() or jax.pmap(). For more information on this, see Non-array inputs: NumPy vs. JAX.
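
For instance, with the my_function sketch above, an array input passes the runtime check while a Python list does not (a hypothetical usage, not part of the public API):

import jax.numpy as jnp

out = my_function(jnp.arange(4.0))    # Array(1.5, dtype=float32)
# my_function([1.0, 2.0]) would raise TypeError from the runtime check above.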

List of Members#

ArrayLike

Type annotation for JAX array-like objects.

DTypeLike

alias of str | type[Any] | dtype | SupportsDType

jax.extend module#

Modules for JAX extensions.

The jax.extend package provides modules for access to JAX internal machinery. See JEP #15856.

API policy#

Unlike the public API, this package offers no compatibility guarantee across releases. Breaking changes will be announced via the JAX project changelog.

Modules#
jax.extend.linear_util module#

StoreException

WrappedFun(f, transforms, stores, params, ...)

Represents a function f to which transforms are to be applied.

cache(call, *[, explain])

Memoization decorator for functions taking a WrappedFun as first argument.

merge_linear_aux(aux1, aux2)

transformation

Adds one more transformation to a WrappedFun.

transformation_with_aux

Adds one more transformation with auxiliary output to a WrappedFun.

wrap_init(f[, params])

Wraps function f as a WrappedFun, suitable for transformation.

jax.extend.mlir module#

dialects

ir

passmanager

jax.extend.random module#

define_prng_impl(*, key_shape, seed, split, ...)

seed_with_impl(impl, seed)

threefry2x32_p

threefry_2x32(keypair, count)

Apply the Threefry 2x32 hash.

threefry_prng_impl

Specifies PRNG key shape and operations.

rbg_prng_impl

Specifies PRNG key shape and operations.

unsafe_rbg_prng_impl

Specifies PRNG key shape and operations.

jax.example_libraries module#

JAX provides some small, experimental libraries for machine learning. These libraries are in part about providing tools and in part about serving as examples for how to build such libraries using JAX. Each one is under 300 source lines of code, so take a look inside and adapt them as you need!

Note

Each mini-library is meant to be an inspiration, but not a prescription.

To serve that purpose, it is best to keep their code samples minimal; so we generally will not merge PRs adding new features. Instead, please send your lovely pull requests and design ideas to more fully-featured libraries like Haiku or Flax.

jax.example_libraries.optimizers module#

Examples of how to write optimizers with JAX.

You likely do not mean to import this module! The optimizers in this library are intended as examples only. If you are looking for a fully featured optimizer library, two good options are JAXopt and Optax.

This module contains some convenient optimizer definitions, specifically initialization and update functions, which can be used with ndarrays or arbitrarily-nested tuple/list/dicts of ndarrays.

An optimizer is modeled as an (init_fun, update_fun, get_params) triple of functions, where the component functions have these signatures:

init_fun(params)

Args:
  params: pytree representing the initial parameters.

Returns:
  A pytree representing the initial optimizer state, which includes the
  initial parameters and may also include auxiliary values like initial
  momentum. The optimizer state pytree structure generally differs from that
  of `params`.
update_fun(step, grads, opt_state)

Args:
  step: integer representing the step index.
  grads: a pytree with the same structure as `get_params(opt_state)`
    representing the gradients to be used in updating the optimizer state.
  opt_state: a pytree representing the optimizer state to be updated.

Returns:
  A pytree with the same structure as the `opt_state` argument representing
  the updated optimizer state.
get_params(opt_state)

Args:
  opt_state: pytree representing an optimizer state.

Returns:
  A pytree representing the parameters extracted from `opt_state`, such that
  the invariant `params == get_params(init_fun(params))` holds true.

Notice that an optimizer implementation has a lot of flexibility in the form of opt_state: it just has to be a pytree of JaxTypes (so that it can be passed to the JAX transforms defined in api.py) and it has to be consumable by update_fun and get_params.

Example Usage:

import jax
from jax.example_libraries import optimizers

# `loss_fn`, `params`, `learning_rate`, and `num_steps` are assumed to be
# defined elsewhere by the user.
opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
opt_state = opt_init(params)

def step(step, opt_state):
  value, grads = jax.value_and_grad(loss_fn)(get_params(opt_state))
  opt_state = opt_update(step, grads, opt_state)
  return value, opt_state

for i in range(num_steps):
  value, opt_state = step(i, opt_state)

class jax.example_libraries.optimizers.JoinPoint(subtree)[source]#

Bases: object

Marks the boundary between two joined (nested) pytrees.

class jax.example_libraries.optimizers.Optimizer(init_fn, update_fn, params_fn)[source]#

Bases: NamedTuple

Parameters:
init_fn: Callable[[Any], OptimizerState]#

Alias for field number 0

params_fn: Callable[[OptimizerState], Any]#

Alias for field number 2

update_fn: Callable[[int, Any, OptimizerState], OptimizerState]#

Alias for field number 1

class jax.example_libraries.optimizers.OptimizerState(packed_state, tree_def, subtree_defs)#

Bases: tuple

packed_state#

Alias for field number 0

subtree_defs#

Alias for field number 2

tree_def#

Alias for field number 1

jax.example_libraries.optimizers.adagrad(step_size, momentum=0.9)[source]#

Construct optimizer triple for Adagrad.

Adaptive Subgradient Methods for Online Learning and Stochastic Optimization: http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf

Parameters:
  • step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.

  • momentum – optional, a positive scalar value for momentum

Returns:

An (init_fun, update_fun, get_params) triple.

jax.example_libraries.optimizers.adam(step_size, b1=0.9, b2=0.999, eps=1e-08)[source]#

Construct optimizer triple for Adam.

Parameters:
  • step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.

  • b1 – optional, a positive scalar value for beta_1, the exponential decay rate for the first moment estimates (default 0.9).

  • b2 – optional, a positive scalar value for beta_2, the exponential decay rate for the second moment estimates (default 0.999).

  • eps – optional, a positive scalar value for epsilon, a small constant for numerical stability (default 1e-8).

Returns:

An (init_fun, update_fun, get_params) triple.

jax.example_libraries.optimizers.adamax(step_size, b1=0.9, b2=0.999, eps=1e-08)[source]#

Construct optimizer triple for AdaMax (a variant of Adam based on infinity norm).

Parameters:
  • step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.

  • b1 – optional, a positive scalar value for beta_1, the exponential decay rate for the first moment estimates (default 0.9).

  • b2 – optional, a positive scalar value for beta_2, the exponential decay rate for the second moment estimates (default 0.999).

  • eps – optional, a positive scalar value for epsilon, a small constant for numerical stability (default 1e-8).

Returns:

An (init_fun, update_fun, get_params) triple.

jax.example_libraries.optimizers.clip_grads(grad_tree, max_norm)[source]#

Clip gradients stored as a pytree of arrays to maximum norm max_norm.

jax.example_libraries.optimizers.constant(step_size)[source]#
Return type:

Callable[[int], float]

jax.example_libraries.optimizers.exponential_decay(step_size, decay_steps, decay_rate)[source]#
jax.example_libraries.optimizers.inverse_time_decay(step_size, decay_steps, decay_rate, staircase=False)[source]#
jax.example_libraries.optimizers.l2_norm(tree)[source]#

Compute the l2 norm of a pytree of arrays. Useful for weight decay.

jax.example_libraries.optimizers.make_schedule(scalar_or_schedule)[source]#
Parameters:

scalar_or_schedule (float | Callable[[int], float])

Return type:

Callable[[int], float]

jax.example_libraries.optimizers.momentum(step_size, mass)[source]#

Construct optimizer triple for SGD with momentum.

Parameters:
  • step_size (Callable[[int], float]) – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.

  • mass (float) – positive scalar representing the momentum coefficient.

Returns:

An (init_fun, update_fun, get_params) triple.

jax.example_libraries.optimizers.nesterov(step_size, mass)[source]#

Construct optimizer triple for SGD with Nesterov momentum.

Parameters:
  • step_size (Callable[[int], float]) – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.

  • mass (float) – positive scalar representing the momentum coefficient.

Returns:

An (init_fun, update_fun, get_params) triple.

jax.example_libraries.optimizers.optimizer(opt_maker)[source]#

Decorator to make an optimizer defined for arrays generalize to containers.

With this decorator, you can write init, update, and get_params functions that each operate only on single arrays, and convert them to corresponding functions that operate on pytrees of parameters. See the optimizers defined in optimizers.py for examples.

Parameters:

opt_maker (Callable[[...], tuple[Callable[[Any], Any], Callable[[int, Any, Any], Any], Callable[[Any], Any]]]) –

a function that returns an (init_fun, update_fun, get_params) triple of functions that might only work with ndarrays, as per

init_fun :: ndarray -> OptStatePytree ndarray
update_fun :: OptStatePytree ndarray -> OptStatePytree ndarray
get_params :: OptStatePytree ndarray -> ndarray

Returns:

An (init_fun, update_fun, get_params) triple of functions that work on arbitrary pytrees, as per

init_fun :: ParameterPytree ndarray -> OptimizerState
update_fun :: OptimizerState -> OptimizerState
get_params :: OptimizerState -> ParameterPytree ndarray

The OptimizerState pytree type used by the returned functions is isomorphic to ParameterPytree (OptStatePytree ndarray), but may store the state instead as e.g. a partially-flattened data structure for performance.

Return type:

Callable[[…], Optimizer]
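
As a minimal sketch of what the decorator does, here is a hypothetical SGD variant written for single arrays and generalized to pytrees by @optimizer (it mirrors, but is not, the library's own sgd):

from jax.example_libraries import optimizers

@optimizers.optimizer
def simple_sgd(step_size):
  step_size = optimizers.make_schedule(step_size)

  def init(x0):
    return x0                       # per-array state is just the parameter

  def update(i, g, x):
    return x - step_size(i) * g     # one gradient step on a single array

  def get_params(x):
    return x

  return init, update, get_params

# The decorated version works on arbitrary pytrees of parameters:
init_fun, update_fun, get_params = simple_sgd(1e-3)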

jax.example_libraries.optimizers.pack_optimizer_state(marked_pytree)[source]#

Converts a marked pytree to an OptimizerState.

The inverse of unpack_optimizer_state. Converts a marked pytree with the leaves of the outer pytree represented as JoinPoints back into an OptimizerState. This function is intended to be useful when deserializing optimizer states.

Parameters:

marked_pytree – A pytree containing JoinPoint leaves that hold more pytrees.

Returns:

An equivalent OptimizerState to the input argument.

jax.example_libraries.optimizers.piecewise_constant(boundaries, values)[source]#
Parameters:
  • boundaries (Any)

  • values (Any)

jax.example_libraries.optimizers.polynomial_decay(step_size, decay_steps, final_step_size, power=1.0)[source]#
jax.example_libraries.optimizers.rmsprop(step_size, gamma=0.9, eps=1e-08)[source]#

Construct optimizer triple for RMSProp.

Parameters:
  • step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.

  • gamma – Decay parameter.

  • eps – Epsilon parameter.

Returns:

An (init_fun, update_fun, get_params) triple.

jax.example_libraries.optimizers.rmsprop_momentum(step_size, gamma=0.9, eps=1e-08, momentum=0.9)[source]#

Construct optimizer triple for RMSProp with momentum.

This optimizer is separate from the rmsprop optimizer because it needs to keep track of additional parameters.

Parameters:
  • step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.

  • gamma – Decay parameter.

  • eps – Epsilon parameter.

  • momentum – Momentum parameter.

Returns:

An (init_fun, update_fun, get_params) triple.

jax.example_libraries.optimizers.sgd(step_size)[source]#

Construct optimizer triple for stochastic gradient descent.

Parameters:

step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.

Returns:

An (init_fun, update_fun, get_params) triple.

jax.example_libraries.optimizers.sm3(step_size, momentum=0.9)[source]#

Construct optimizer triple for SM3.

Memory-Efficient Adaptive Optimization for Large-Scale Learning. https://arxiv.org/abs/1901.11150

Parameters:
  • step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.

  • momentum – optional, a positive scalar value for momentum

Returns:

An (init_fun, update_fun, get_params) triple.

jax.example_libraries.optimizers.unpack_optimizer_state(opt_state)[source]#

Converts an OptimizerState to a marked pytree.

Converts an OptimizerState to a marked pytree with the leaves of the outer pytree represented as JoinPoints to avoid losing information. This function is intended to be useful when serializing optimizer states.

Parameters:

opt_state – An OptimizerState

Returns:

A pytree with JoinPoint leaves that contain a second level of pytrees.
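
A small sketch of a round trip through these two helpers, e.g. for serialization (assuming the sgd optimizer from this module):

import jax.numpy as jnp
from jax.example_libraries import optimizers

opt_init, opt_update, get_params = optimizers.sgd(1e-2)
opt_state = opt_init({"w": jnp.zeros(3)})

marked = optimizers.unpack_optimizer_state(opt_state)   # plain pytree with JoinPoint leaves
restored = optimizers.pack_optimizer_state(marked)      # back to an OptimizerState
assert (get_params(restored)["w"] == get_params(opt_state)["w"]).all()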

jax.example_libraries.stax module#

Stax is a small but flexible neural net specification library built from scratch.

You likely do not mean to import this module! Stax is intended as an example library only. There are a number of other much more fully-featured neural network libraries for JAX, including Flax from Google, and Haiku from DeepMind.
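
A minimal sketch of building and applying a network with the combinators and layer constructors documented below (it assumes the stax.Relu and stax.LogSoftmax elementwise layers, which are defined in the stax source):

import jax
import jax.numpy as jnp
from jax.example_libraries import stax

init_fun, apply_fun = stax.serial(
    stax.Dense(512), stax.Relu,
    stax.Dense(10), stax.LogSoftmax)

rng = jax.random.PRNGKey(0)
out_shape, params = init_fun(rng, (-1, 784))      # -1 marks the batch dimension
logits = apply_fun(params, jnp.ones((32, 784)))   # shape (32, 10)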

jax.example_libraries.stax.AvgPool(window_shape, strides=None, padding='VALID', spec=None)[source]#

Layer construction function for a pooling layer.

jax.example_libraries.stax.BatchNorm(axis=(0, 1, 2), epsilon=1e-05, center=True, scale=True, beta_init=<function zeros>, gamma_init=<function ones>)[source]#

Layer construction function for a batch normalization layer.

jax.example_libraries.stax.Conv(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)#

Layer construction function for a general convolution layer.

jax.example_libraries.stax.Conv1DTranspose(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)#

Layer construction function for a general transposed-convolution layer.

jax.example_libraries.stax.ConvTranspose(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)#

Layer construction function for a general transposed-convolution layer.

jax.example_libraries.stax.Dense(out_dim, W_init=<function variance_scaling.<locals>.init>, b_init=<function normal.<locals>.init>)[source]#

Layer constructor function for a dense (fully-connected) layer.

jax.example_libraries.stax.Dropout(rate, mode='train')[source]#

Layer construction function for a dropout layer with given rate.

jax.example_libraries.stax.FanInConcat(axis=-1)[source]#

Layer construction function for a fan-in concatenation layer.

jax.example_libraries.stax.FanOut(num)[source]#

Layer construction function for a fan-out layer.

jax.example_libraries.stax.GeneralConv(dimension_numbers, out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)[source]#

Layer construction function for a general convolution layer.

jax.example_libraries.stax.GeneralConvTranspose(dimension_numbers, out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)[source]#

Layer construction function for a general transposed-convolution layer.

jax.example_libraries.stax.MaxPool(window_shape, strides=None, padding='VALID', spec=None)[source]#

Layer construction function for a pooling layer.

jax.example_libraries.stax.SumPool(window_shape, strides=None, padding='VALID', spec=None)[source]#

Layer construction function for a pooling layer.

jax.example_libraries.stax.elementwise(fun, **fun_kwargs)[source]#

Layer that applies a scalar function elementwise on its inputs.

jax.example_libraries.stax.parallel(*layers)[source]#

Combinator for composing layers in parallel.

The layer resulting from this combinator is often used with the FanOut and FanInSum layers.

Parameters:

*layers – a sequence of layers, each an (init_fun, apply_fun) pair.

Returns:

A new layer, meaning an (init_fun, apply_fun) pair, representing the parallel composition of the given sequence of layers. In particular, the returned layer takes a sequence of inputs and returns a sequence of outputs with the same length as the argument layers.

jax.example_libraries.stax.serial(*layers)[source]#

Combinator for composing layers in serial.

Parameters:

*layers – a sequence of layers, each an (init_fun, apply_fun) pair.

Returns:

A new layer, meaning an (init_fun, apply_fun) pair, representing the serial composition of the given sequence of layers.

jax.example_libraries.stax.shape_dependent(make_layer)[source]#

Combinator to delay layer constructor pair until input shapes are known.

Parameters:

make_layer – a one-argument function that takes an input shape as an argument (a tuple of positive integers) and returns an (init_fun, apply_fun) pair.

Returns:

A new layer, meaning an (init_fun, apply_fun) pair, representing the same layer as returned by make_layer but with its construction delayed until input shapes are known.

jax.experimental module#

jax.experimental.optix has been moved into its own Python package (deepmind/optax).

jax.experimental.ann has been moved into jax.lax.

Experimental Modules#
jax.experimental.array_api module#

This module includes experimental JAX support for the Python array API standard. The implementation is not yet fully complete.

Example Usage:

>>> from jax.experimental import array_api as xp

>>> xp.__array_api_version__
'2023.12'

>>> arr = xp.arange(1000)

>>> arr.sum()
Array(499500, dtype=int32)

The xp namespace is the array API compliant analog of jax.numpy, and implements most of the API listed in the standard.

jax.experimental.checkify module#
API#

checkify(f[, errors])

Functionalize check calls in fun, and optionally add run-time error checks.

check(pred, msg, *fmt_args, **fmt_kwargs)

Check a predicate, add an error with msg if predicate is False.

check_error(error)

Raise an Exception if error represents a failure.

Error(_pred, _code, _metadata, _payload)

JaxRuntimeError

user_checks

nan_checks

index_checks

div_checks

float_checks

automatic_checks

all_checks

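A minimal usage sketch of the API listed above, assuming the current checkify interface:

import jax
import jax.numpy as jnp
from jax.experimental import checkify

def f(x, i):
  checkify.check(i >= 0, "index must be non-negative, got {i}", i=i)
  return x[i]

# Functionalize the check; user_checks tracks only explicit checkify.check calls.
checked_f = checkify.checkify(f, errors=checkify.user_checks)
err, out = jax.jit(checked_f)(jnp.arange(3.0), 1)
err.throw()   # no-op here; raises an error if a check failed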

jax.experimental.host_callback module#

Primitives for calling Python functions on the host from JAX accelerator code.

Warning

The host_callback APIs are deprecated as of March 20, 2024. The functionality is subsumed by the new JAX external callbacks. See google/jax#20385.

This module introduces the host callback functions call(), id_tap(), and id_print(), that send their arguments from the device to the host and invoke user-defined Python functions on the host, optionally returning results back to the device computation.

We show below how these functions can be used. We start with call(), and we discuss examples of calling from JAX to arbitrary Python functions on the CPU, e.g., to use NumPy CPU custom kernels. Then we show uses of id_tap() and id_print(), which have the restriction that they cannot return values from the host to the device. These primitives are generally faster because they are executed asynchronously with the device code. In particular, they can be used to tap into and to debug JAX code.

Using call() to call a host function and return results to device#

Use call() to invoke a computation on the host and return NumPy arrays to the device computation. Host computation is useful, e.g., when a device computation needs some data that requires I/O on the host, or it needs a library that is available on the host and you do not want to code it in JAX. For example, eigen decomposition for general matrices in JAX does not work on TPU. We can call the NumPy implementation from any JAX accelerator computation, using a host computation:

import numpy as np
import jax
from jax.experimental import host_callback as hcb

# This function runs on the host
def host_eig(m: np.ndarray) -> np.ndarray:
  return np.linalg.eigvals(m)

# This function is used in JAX
def device_fun(m):
  # We send "m" to the host, asking it to call "host_eig" and return the result.
  # We have to specify the result shape and dtype, either in the form of an
  # example return value or any object that has `shape` and `dtype` attributes,
  # e.g., a NumPy array or a `jax.ShapeDtypeStruct`.
  return hcb.call(host_eig, m,
                  # Given an input of shape (..., d, d), eig output has shape (..., d)
                  result_shape=jax.ShapeDtypeStruct(m.shape[:-1], m.dtype))

The call() function and the Python host function both take a single argument and return a single result, but those can be pytrees. Note that we must tell the call() what shape and dtype to expect from the host invocation, using the result_shape keyword argument. This is important because the device code is compiled with that expectation. There will be an error raised at runtime if the actual invocation produces a different result shape. In general, such errors and also exceptions raised by the host computation may be difficult to debug. See the Debugging section below. This is a problem for call() but not for id_tap() because for the latter the device code does not expect a returned value.

The call() API can be used inside a jit or pmap computation or inside cond/scan/while control flow. When used inside jax.pmap(), there will be separate calls to the host from each of the participating devices:

def host_sin(x, *, device):
  # The ``device`` argument is passed due to ``call_with_device=True`` below.
  print(f"Invoking host_sin with {x.shape} on {device}")
  return np.sin(x)

# Use pmap to run the computation on two devices
jax.pmap(lambda x: hcb.call(host_sin, x,
                            result_shape=x,
                            # Ask that the `host_sin` function be passed `device=dev`
                            call_with_device=True))(
         np.ones((2, 4), dtype=np.float32))

# prints (in arbitrary order)
# Invoking host_sin with (4,) on cpu:0
# Invoking host_sin with (4,) on cpu:1

Note that call() does not support any JAX transformations, but as we show below one can make use of the existing support for Custom differentiation in JAX.

Using id_tap() to call a Python function on the host, with no returned values#

The id_tap() and id_print() functions are special cases of call(), for when you just want the side effects of your Python callback. These functions have the advantage that once the arguments have been sent to the host, the device computation can proceed without waiting for the Python callback to return. For id_tap() you can specify your Python callback to be called, while id_print() uses a built-in callback that prints the arguments to stdout on the host. The Python function passed to id_tap() takes two positional arguments (the value tapped from the device computation along with a transforms tuple, described below). Optionally, the function may be passed a keyword argument device with the Device from which the value was tapped.

A few examples:

def host_func(arg, transforms):
   ...do something with arg...

# calls host_func(2x, []) on host
id_tap(host_func, 2 * x)

# calls host_func((2x, 3x), [])
id_tap(host_func, (2 * x, 3 * x))  # The argument can be a pytree

# calls host_func(2x, [], device=jax.devices()[0])
id_tap(host_func, 2 * x, tap_with_device=True)  # Pass the device to the tap

# calls host_func(2x, [], what='activation')
id_tap(functools.partial(host_func, what='activation'), 2 * x)

# calls host_func(dict(x=x, y=y), what='data')
id_tap(lambda tap, transforms: host_func(tap, what='data'), dict(x=x, y=y))

The above examples can all be adapted to use id_print() instead, with the difference that id_print() prints on the host the positional argument, along with any additional kwargs and the automatic kwarg transforms.

Using barrier_wait() to wait until all callbacks have executed#

If your Python callbacks have side-effects you may need to wait until the computation has finished to ensure that the side-effects have been observed. You can use the barrier_wait() function for that purpose:

accumulator = []
def host_log(arg, transforms):
  # We just record the arguments in a list
  accumulator.append(arg)


def device_fun(x):
  id_tap(host_log, x)
  id_tap(host_log, 2. * x)

jax.jit(device_fun)(1.)
jax.jit(device_fun)(1.)

# At this point, we have started two computations, each with two
# taps, but they may not have yet executed.
barrier_wait()
# Now we know that all the computations started before `barrier_wait`
# on all devices, have finished, and all the callbacks have finished
# executing.

Note that barrier_wait() will start one tiny computation with one tap on each of the jax.local_devices() and will wait for all these taps to be received.

An alternative to using barrier_wait() is to just wait for the end of the computation, if all the callbacks are call():

accumulator = []
def host_log(arg):
  # We just record the arguments in a list
  accumulator.append(arg)
  return 0.  #  return something


def device_fun(x):
  y = call(host_log, x, result_shape=jax.ShapeDtypeStruct((), np.float32))
  z = call(host_log, 2. * x, result_shape=jax.ShapeDtypeStruct((), np.float32))
  return y + z  # return something that uses both results

res1 = jax.jit(device_fun)(1.)
res2 = jax.jit(device_fun)(1.)
res1.block_until_ready()
res2.block_until_ready()

Behavior under parallelization transformations#

In the presence of jax.pmap() the code will run on multiple devices and each device will tap its values independently. It may be helpful to use the tap_with_device option for id_print() or id_tap(), so that you see which device is sending which data:

jax.pmap(power3, devices=jax.local_devices()[:2])(np.array([3., 4.]))
# device=cpu:0 what=x,x^2: (3., 9.)  # from the first device
# device=cpu:1 what=x,x^2: (4., 16.)  # from the second device

When using jax.pmap() with multiple devices on multiple hosts, every host will receive callbacks from all of its local devices, with an operand that corresponds to each device slice. For a call(), the callback must return to each device only the slice of the result that pertains to the corresponding device.

When using the experimental pjit.pjit() the code will run on multiple devices on different shards of the input. The current implementation of host callbacks will ensure that a single device will collect and outfeed the entire operand, in a single callback. The callback function is supposed to return the entire array, which will then be sent in a single infeed to the same device that issued the outfeed. This device is then responsible for sending the required shards to the other devices:

with jax.sharding.Mesh(jax.local_devices()[:2], ["d"]):
  pjit.pjit(power3, in_shardings=(P("d"),),
            out_shardings=(P("d"),))(np.array([3., 4.]))

# device=TPU:0 what=x,x^2: ( [3., 4.],
#                            [9., 16.] )

Note that the collection of the operand on one device may result in OOM if the operand was sharded across devices.

When using pjit.pjit() with multiple devices on multiple hosts, only the host for the device 0 (w.r.t. the mesh) will receive the callback, with the operand collected from all participating devices on all hosts. For a call(), the callback must return the entire array for all devices on all hosts.

Behavior under JAX autodiff transformations#

When used under a JAX autodiff transformation, the host callback functions operate on the primal values only. Consider the following example:

def power3(x):
  y = x * x
  # Print both 'x' and 'x^2'. Must pack as a tuple.
  hcb.id_print((x, y), what="x,x^2")
  return y * x

power3(3.)
# what: x,x^2 : (3., 9.)

(You can see these examples tested in host_callback_test.HostCallbackTapTest.test_tap_transforms.)

When used under jax.jvp() there will be one callback with the primal values only:

jax.jvp(power3, (3.,), (0.1,))
# what: x,x^2 : (3., 9.)

Similarly for jax.grad(), we get a callback from the forward computation only:

jax.grad(power3)(3.)
# what: x,x^2 : (3., 9.)

If you want to invoke the callback on the tangents during a jax.jvp(), you can use a custom_jvp. For example, you can define a function that does nothing interesting except that its custom_jvp will print the tangents:

@jax.custom_jvp
def print_tangents(arg):
  return None

@print_tangents.defjvp
def print_tangents_jvp(primals, tangents):
  arg_dot, = tangents
  hcb.id_print(arg_dot, what="tangents")
  return primals, tangents

Then you use this function in the places where you want to tap the tangents:

def power3_with_tangents(x):
  y = x * x
  # Print both 'x' and 'x^2'. Must pack as a tuple.
  hcb.id_print((x, y), what="x,x^2")
  print_tangents((x, y))
  return y * x

jax.jvp(power3_with_tangents, (3.,), (0.1,))
# what: x,x^2 : (3., 9.)
# what: tangents : (0.1, 0.6)

You can do a similar thing for the cotangents during jax.grad(). This time you must be careful to use in the rest of the computation the values whose cotangents you want to tap. Hence we make the print_cotangents return its argument:

@jax.custom_vjp
def print_cotangents(arg):
  # Must return the argument for which we want the cotangent.
  return arg

# f_fwd: a -> (b, residual)
def print_cotangents_fwd(arg):
  return print_cotangents(arg), None
# f_bwd: (residual, CT b) -> [CT a]
def print_cotangents_bwd(residual, ct_b):
  hcb.id_print(ct_b, what="cotangents", output_stream=testing_stream)
  return ct_b,

print_cotangents.defvjp(print_cotangents_fwd, print_cotangents_bwd)

def power3_with_cotangents(x):
  y = x * x
  # Print both 'x' and 'x^2'. Must pack as a tuple.
  hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream)
  (x1, y1) = print_cotangents((x, y))
  # Must use the output of print_cotangents
  return y1 * x1

jax.grad(power3_with_cotangents)(3.)
# what: x,x^2 : (3., 9.)
# what: cotangents : (9., 3.)

If you use ad_checkpoint.checkpoint() to rematerialize the residuals for the backward pass, then the callbacks from the primal computation will be called twice:

jax.grad(lambda x: power3(ad_checkpoint.checkpoint(power3)(x)))(3.)
# what: x,x^2 : (3., 9.)
# what: x,x^2 : (27., 729.)
# what: x,x^2 : (3., 9.)

The callbacks are, in order from: the primal computation of the inner power3, the primal computation of the outer power3, and the rematerialization of the residuals for the inner power3.

Behavior under jax.vmap#

The host callback functions id_print() and id_tap() support the vectorization transformation jax.vmap().

For jax.vmap() the arguments to the callback are batched, and the callback function is passed an additional special transforms argument containing a list of transformation descriptors of the form ("batch", {"batch_dims": ...}), where ... denotes the batched dimensions for the tapped values (one entry per argument; None denotes an argument that was broadcast).

jax.vmap(power3)(np.array([2., 3.]))
# transforms: [('batch', {'batch_dims': (0, 0)})]  what: x,x^2 : ([2., 3.], [4., 9.])

See documentation for id_tap(), id_print(), and call().

For more usage examples, see tests/host_callback_test.py.

Using call() to call a TensorFlow function, with reverse-mode autodiff support#

Another possible use for host computation is to invoke a library written for another framework, such as TensorFlow. In this case it becomes interesting to support JAX autodiff for host callbacks by deferring to the autodiff mechanism in TensorFlow, using the jax.custom_vjp() mechanism.

This is relatively easy to do, once one understands both the JAX custom VJP and the TensorFlow autodiff mechanisms. The code for how this can be done is shown in the call_tf_full_ad function in host_callback_to_tf_test.py. This example supports arbitrary higher-order differentiation as well.

Note that if you just want to call TensorFlow functions from JAX, you can also use the jax2tf.call_tf function.

Using call() to call a JAX function on another device, with reverse-mode autodiff support#

It should not be surprising that we can use host computation to invoke a JAX computation on another device. The arguments are sent from the accelerator to the host, and then to the outside device on which the JAX host computation will run, and then the results are sent back to the original accelerator.

The code for how this can be done is shown in the call_jax_other_device function in host_callback_test.py.

Low-level details and debugging#

The host callback functions will be executed for each device in the order in which the send operations were performed on the device.

The host callback functions for multiple devices may be interleaved. The data from the devices is received by separate threads managed by the JAX runtime (one thread per device). The runtime maintains a buffer of configurable size (see the flag --jax_host_callback_max_queue_byte_size). When the buffer is full, all the receiving threads are paused which eventually pauses the computation on devices. The runtime has one additional thread for each device to invoke the Python user functions with the received data. If the processing of the callbacks is slow, it may actually lead to the runtime buffer filling up, and eventually pausing the computation on the devices when they need to send something. For more details on the outfeed receiver runtime mechanism see runtime code.

In order to pause the execution until all data from computations already started on devices has arrived and has been processed, use barrier_wait().

Exceptions from the user-defined callback functions are logged along with their stack traces, but the receiving threads are not stopped. Instead the last exception is recorded and the subsequent barrier_wait() will raise CallbackException if any exception had occurred in one of the tap functions. This exception will include the text and the stack trace of the last exception encountered.

One further complication arises for callback functions that must return results to the call origin device, such as call(). This is handled differently on CPU/GPU devices compared to TPU devices.

On CPU/GPU devices, in order to avoid the device computation being stuck waiting for a result that will never arrive, in case of any error during the processing of the callback (whether raised by the user code itself or due to a mismatch between the returned value and the expected result_shape) we send the device a “fake” result of shape int8[12345]. This will make the device computation abort because the received data is different from what it expects. On CPU the runtime will crash with a distinctive error message:

Check failed: buffer->length() == buffer_length (12345 vs. ...)

On GPU, the failure is more user-friendly and will be surfaced to the Python program as:

RET_CHECK failure ... Mismatch between infeed source buffer shape s8[12345] ...

To debug the underlying cause for these messages, see the Debugging section.

On TPU devices, there is currently no shape check for infeed, so we take the safer route of not sending this fake result in case of errors. This means that the computation will hang, and no exception will be raised (but any exceptions in the callback functions will still appear in the logs).

The current implementation uses the outfeed mechanism provided by XLA. The mechanism itself is quite primitive in the sense that a receiver must know exactly the shape of each incoming packet, and how many packets are expected. This makes it hard to use for multiple kinds of data in the same computation, and it is practically impossible to use it under conditionals or in loops of non-constant iteration count. Furthermore, code that uses the outfeed mechanism directly cannot be transformed by JAX. All these limitations are addressed by the host callback functions. The tapping API introduced here makes it easy to share the outfeed mechanism for multiple purposes, while supporting all transformations.

Note that after you have used the host callback functions, you cannot use lax.outfeed directly. You may want to stop_outfeed_receiver() if you later need to use lax.outfeed.

Since the actual calls to your callback functions are made from the C++ receiver, it may be hard to debug the calls. In particular, the stack trace will not include the calling code. You can use the flag jax_host_callback_inline (or the environment variable JAX_HOST_CALLBACK_INLINE) to ensure that the calls to the callbacks are inlined. This works only if the calls are outside a staging context (jit() or a control-flow primitive).

The C++ receiver is started automatically on the first call to id_tap(). In order to stop it properly, upon start an atexit handler is registered to call barrier_wait() with the logging name “at_exit”.

There are a few environment variables that you can use to turn on logging for the C++ outfeed receiver backend.

  • TF_CPP_MIN_LOG_LEVEL=0: will turn on INFO logging, needed for all below.

  • TF_CPP_MIN_VLOG_LEVEL=3: will make all VLOG logging up to level 3 behave like INFO logs. This may be too much, but you will see which modules are logging relevant info, and then you can select which modules to log from.

  • TF_CPP_VMODULE=<module_name>=3 (the module name can be either C++ or Python, without the extension).

You should also use the --verbosity=2 flag so that you see the logs from Python.

For example, you can try to enable logging in the host_callback module: TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=host_callback=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple

If you want to enable logging in lower-level implementation modules try: TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,infeed_thunk=3,cpu_transfer_manager=3,cpu_runtime=3,xfeed_manager=3,pjrt_client=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple

(For bazel tests, use --test_arg=--vmodule=….)

Still to do:
  • More performance tests.

  • Explore implementation with outside compilation for TPU.

  • Explore implementation with XLA CustomCall for CPU and GPU.

API#

id_tap(tap_func, arg, *[, result, ...])

Host-callback tap primitive, like identity function with a call to tap_func.

id_print(arg, *[, result, tap_with_device, ...])

Like id_tap() with a printing tap function.

call(callback_func, arg, *[, result_shape, ...])

Make a call to the host, and expect a result.

barrier_wait([logging_name])

Blocks the calling thread until all current outfeed is processed.

CallbackException

Signals that some callback function had exceptions.

jax.experimental.maps module#
API#

xmap(fun, in_axes, out_axes, *[, ...])

Assign a positional signature to a program that uses named array axes.

jax.experimental.pjit module#
API#
jax.experimental.pjit.pjit(fun, in_shardings=UnspecifiedValue, out_shardings=UnspecifiedValue, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None)[source]#

Makes fun compiled and automatically partitioned across multiple devices.

NOTE: This function is now equivalent to jax.jit; please use that instead.

The returned function has semantics equivalent to those of fun, but is compiled to an XLA computation that runs across multiple devices (e.g. multiple GPUs or multiple TPU cores). This can be useful if the jitted version of fun would not fit in a single device’s memory, or to speed up fun by running each operation in parallel across multiple devices.

The partitioning over devices happens automatically based on the propagation of the input partitioning specified in in_shardings and the output partitioning specified in out_shardings. The resources specified in those two arguments must refer to mesh axes, as defined by the jax.sharding.Mesh() context manager. Note that the mesh definition at pjit() application time is ignored, and the returned function will use the mesh definition available at each call site.

Inputs to a pjit()’d function will be automatically partitioned across devices if they’re not already correctly partitioned based on in_shardings. In some scenarios, ensuring that the inputs are already correctly pre-partitioned can increase performance. For example, if passing the output of one pjit()’d function to another pjit()’d function (or the same pjit()’d function in a loop), make sure the relevant out_shardings match the corresponding in_shardings.

Note

Multi-process platforms: On multi-process platforms such as TPU pods, pjit() can be used to run computations across all available devices across processes. To achieve this, pjit() is designed to be used in SPMD Python programs, where every process is running the same Python code such that all processes run the same pjit()’d function in the same order.

When running in this configuration, the mesh should contain devices across all processes. However, any input argument dimensions partitioned over multi-process mesh axes should be of size equal to the corresponding local mesh axis size, and outputs will be similarly sized according to the local mesh. fun will still be executed across all devices in the mesh, including those from other processes, and will be given a global view of the data spread across multiple processes as a single array. However, outside of pjit() every process only “sees” its local piece of the input and output, corresponding to its local sub-mesh.

This means that each process’s participating local devices must form a _contiguous_ local sub-mesh within the full global mesh. A contiguous sub-mesh is one where all of its devices are adjacent within the global mesh, and form a rectangular prism.

The SPMD model also requires that the same multi-process pjit()’d functions must be run in the same order on all processes, but they can be interspersed with arbitrary operations running in a single process.

Parameters:
  • fun (Callable) – Function to be compiled. Should be a pure function, as side-effects may only be executed once. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. Positional arguments indicated by static_argnums can be anything at all, provided they are hashable and have an equality operation defined. Static arguments are included as part of a compilation cache key, which is why hash and equality operators must be defined.

  • in_shardings

    Pytree of structure matching that of arguments to fun, with all actual arguments replaced by resource assignment specifications. It is also valid to specify a pytree prefix (e.g. one value in place of a whole subtree), in which case the leaves get broadcast to all values in that subtree.

    The in_shardings argument is optional. JAX will infer the shardings from the input jax.Array’s, and defaults to replicating the input if the sharding cannot be inferred.

    The valid resource assignment specifications are:

    • XLACompatibleSharding, which will decide how the value will be partitioned. With this, using a mesh context manager is not required.

    • None is a special case whose semantics are:
      • if the mesh context manager is not provided, JAX has the freedom to choose whatever sharding it wants. For in_shardings, JAX will mark it as replicated, but this behavior can change in the future. For out_shardings, we will rely on the XLA GSPMD partitioner to determine the output shardings.

      • If the mesh context manager is provided, None will imply that the value will be replicated on all devices of the mesh.

    • For backwards compatibility, in_shardings still supports ingesting PartitionSpec. This option can only be used with the mesh context manager.

      • PartitionSpec, a tuple of length at most equal to the rank of the partitioned value. Each element can be a None, a mesh axis or a tuple of mesh axes, and specifies the set of resources assigned to partition the value’s dimension matching its position in the spec.

    The size of every dimension has to be a multiple of the total number of resources assigned to it.

  • out_shardings – Like in_shardings, but specifies resource assignment for function outputs. The out_shardings argument is optional. If not specified, jax.jit() will use GSPMD’s sharding propagation to determine how to shard the outputs.

  • static_argnums (int | Sequence[int] | None) –

    An optional int or collection of ints that specify which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded in Python (during tracing), and so the corresponding argument values can be any Python object.

    Static arguments should be hashable, meaning both __hash__ and __eq__ are implemented, and immutable. Calling the jitted function with different values for these constants will trigger recompilation. Arguments that are not arrays or containers thereof must be marked as static.

    If static_argnums is not provided, no arguments are treated as static.

  • static_argnames (str | Iterable[str] | None) – An optional string or collection of strings specifying which named arguments to treat as static (compile-time constant). See the comment on static_argnums for details. If not provided but static_argnums is set, the default is based on calling inspect.signature(fun) to find corresponding named arguments.

  • donate_argnums (int | Sequence[int] | None) –

    Specify which positional argument buffers are “donated” to the computation. It is safe to donate argument buffers if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation; JAX will raise an error if you try to. By default, no argument buffers are donated.

    If neither donate_argnums nor donate_argnames is provided, no arguments are donated. If donate_argnums is not provided but donate_argnames is, or vice versa, JAX uses inspect.signature(fun) to find any positional arguments that correspond to donate_argnames (or vice versa). If both donate_argnums and donate_argnames are provided, inspect.signature is not used, and only actual parameters listed in either donate_argnums or donate_argnames will be donated.

    For more details on buffer donation see the FAQ.

  • donate_argnames (str | Iterable[str] | None) – An optional string or collection of strings specifying which named arguments are donated to the computation. See the comment on donate_argnums for details. If not provided but donate_argnums is set, the default is based on calling inspect.signature(fun) to find corresponding named arguments.

  • keep_unused (bool) – If False (the default), arguments that JAX determines to be unused by fun may be dropped from resulting compiled XLA executables. Such arguments will not be transferred to the device nor provided to the underlying executable. If True, unused arguments will not be pruned.

  • device (Device | None) – This argument is deprecated. Please put your arguments on the device you want before passing them to jit. Optional, the Device the jitted function will run on. (Available devices can be retrieved via jax.devices().) The default is inherited from XLA’s DeviceAssignment logic and is usually to use jax.devices()[0].

  • backend (str | None) – This argument is deprecated. Please put your arguments on the backend you want before passing them to jit. Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.

  • inline (bool)

  • abstracted_axes (Any | None)

Returns:

A wrapped version of fun, set up for just-in-time compilation and automatically partitioned by the mesh available at each call site.

Return type:

JitWrapped

For example, a convolution operator can be automatically partitioned over an arbitrary set of devices by a single pjit() application:

>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> from jax.sharding import Mesh, PartitionSpec
>>> from jax.experimental.pjit import pjit
>>>
>>> x = jnp.arange(8, dtype=jnp.float32)
>>> f = pjit(lambda x: jax.numpy.convolve(x, jnp.asarray([0.5, 1.0, 0.5]), 'same'),
...         in_shardings=None, out_shardings=PartitionSpec('devices'))
>>> with Mesh(np.array(jax.devices()), ('devices',)):
...   print(f(x))  
[ 0.5  2.   4.   6.   8.  10.  12.  10. ]

jax.experimental.sparse module#

The jax.experimental.sparse module includes experimental support for sparse matrix operations in JAX. It is under active development, and the API is subject to change. The primary interfaces made available are the BCOO sparse array type, and the sparsify() transform.

Batched-coordinate (BCOO) sparse matrices#

The main high-level sparse object currently available in JAX is the BCOO, or batched coordinate sparse array, which offers a compressed storage format compatible with JAX transformations, in particular JIT (e.g. jax.jit()), batching (e.g. jax.vmap()) and autodiff (e.g. jax.grad()).

Here is an example of creating a sparse array from a dense array:

>>> from jax.experimental import sparse
>>> import jax.numpy as jnp
>>> import numpy as np
>>> M = jnp.array([[0., 1., 0., 2.],
...                [3., 0., 0., 0.],
...                [0., 0., 4., 0.]])
>>> M_sp = sparse.BCOO.fromdense(M)
>>> M_sp
BCOO(float32[3, 4], nse=4)

Convert back to a dense array with the todense() method:

>>> M_sp.todense()
Array([[0., 1., 0., 2.],
       [3., 0., 0., 0.],
       [0., 0., 4., 0.]], dtype=float32)

The BCOO format is a somewhat modified version of the standard COO format, and the sparse representation can be seen in the data and indices attributes:

>>> M_sp.data  # Explicitly stored data
Array([1., 2., 3., 4.], dtype=float32)
>>> M_sp.indices # Indices of the stored data
Array([[0, 1],
       [0, 3],
       [1, 0],
       [2, 2]], dtype=int32)

BCOO objects have familiar array-like attributes, as well as sparse-specific attributes:

>>> M_sp.ndim
2
>>> M_sp.shape
(3, 4)
>>> M_sp.dtype
dtype('float32')
>>> M_sp.nse  # "number of specified elements"
4

BCOO objects also implement a number of array-like methods, to allow you to use them directly within jax programs. For example, here we compute the transposed matrix-vector product:

>>> y = jnp.array([3., 6., 5.])
>>> M_sp.T @ y
Array([18.,  3., 20.,  6.], dtype=float32)
>>> M.T @ y  # Compare to dense version
Array([18.,  3., 20.,  6.], dtype=float32)

BCOO objects are designed to be compatible with JAX transforms, including jax.jit(), jax.vmap(), jax.grad(), and others. For example:

>>> from jax import grad, jit
>>> def f(y):
...   return (M_sp.T @ y).sum()
...
>>> jit(grad(f))(y)
Array([3., 3., 4.], dtype=float32)

Note, however, that under normal circumstances jax.numpy and jax.lax functions do not know how to handle sparse matrices, so attempting to compute things like jnp.dot(M_sp.T, y) will result in an error (however, see the next section).

Sparsify transform#

An overarching goal of the JAX sparse implementation is to provide a means to switch from dense to sparse computation seamlessly, without having to modify the dense implementation. This sparse experiment accomplishes this through the sparsify() transform.

Consider this function, which computes a more complicated result from a matrix and a vector input:

>>> def f(M, v):
...   return 2 * jnp.dot(jnp.log1p(M.T), v) + 1
...
>>> f(M, y)
Array([17.635532,  5.158883, 17.09438 ,  7.591674], dtype=float32)

Were we to pass a sparse matrix to this directly, it would result in an error, because jnp functions do not recognize sparse inputs. However, with sparsify(), we get a version of this function that does accept sparse matrices:

>>> f_sp = sparse.sparsify(f)
>>> f_sp(M_sp, y)
Array([17.635532,  5.158883, 17.09438 ,  7.591674], dtype=float32)

Support for sparsify() includes a large number of the most common primitives, including:

  • generalized (batched) matrix products & einstein summations (dot_general_p)

  • zero-preserving elementwise binary operations (e.g. add_p, mul_p, etc.)

  • zero-preserving elementwise unary operations (e.g. abs_p, jax.lax.neg_p, etc.)

  • summation reductions (reduce_sum_p)

  • general indexing operations (slice_p, lax.dynamic_slice_p, lax.gather_p)

  • concatenation and stacking (concatenate_p)

  • transposition & reshaping (transpose_p, reshape_p, squeeze_p, broadcast_in_dim_p)

  • some higher-order functions (cond_p, while_p, scan_p)

  • some simple 1D convolutions (conv_general_dilated_p)

Nearly any jax.numpy function that lowers to these supported primitives can be used within a sparsify transform to operate on sparse arrays. This set of primitives is enough to enable relatively sophisticated sparse workflows, as the next section will show.

Example: sparse logistic regression#

As an example of a more complicated sparse workflow, let’s consider a simple logistic regression implemented in JAX. Notice that the following implementation has no reference to sparsity:

>>> import functools
>>> from sklearn.datasets import make_classification
>>> from jax.scipy import optimize
>>> def sigmoid(x):
...   return 0.5 * (jnp.tanh(x / 2) + 1)
...
>>> def y_model(params, X):
...   return sigmoid(jnp.dot(X, params[1:]) + params[0])
...
>>> def loss(params, X, y):
...   y_hat = y_model(params, X)
...   return -jnp.mean(y * jnp.log(y_hat) + (1 - y) * jnp.log(1 - y_hat))
...
>>> def fit_logreg(X, y):
...   params = jnp.zeros(X.shape[1] + 1)
...   result = optimize.minimize(functools.partial(loss, X=X, y=y),
...                              x0=params, method='BFGS')
...   return result.x
>>> X, y = make_classification(n_classes=2, random_state=1701)
>>> params_dense = fit_logreg(X, y)
>>> print(params_dense)  
[-0.7298445   0.29893667  1.0248291  -0.44436368  0.8785025  -0.7724008
 -0.62893456  0.2934014   0.82974285  0.16838408 -0.39774987 -0.5071844
  0.2028872   0.5227761  -0.3739224  -0.7104083   2.4212713   0.6310087
 -0.67060554  0.03139788 -0.05359547]

This returns the best-fit parameters of a dense logistic regression problem. To fit the same model on sparse data, we can apply the sparsify() transform:

>>> Xsp = sparse.BCOO.fromdense(X)  # Sparse version of the input
>>> fit_logreg_sp = sparse.sparsify(fit_logreg)  # Sparse-transformed fit function
>>> params_sparse = fit_logreg_sp(Xsp, y)
>>> print(params_sparse)  
[-0.72971725  0.29878938  1.0246326  -0.44430563  0.8784217  -0.77225566
 -0.6288222   0.29335397  0.8293481   0.16820715 -0.39764675 -0.5069753
  0.202579    0.522672   -0.3740134  -0.7102678   2.4209507   0.6310593
 -0.670236    0.03132951 -0.05356663]

Sparse API Reference#

sparsify(f[, use_tracer])

Experimental sparsification transform.

grad(fun[, argnums, has_aux])

Sparse-aware version of jax.grad()

value_and_grad(fun[, argnums, has_aux])

Sparse-aware version of jax.value_and_grad()

empty(shape[, dtype, index_dtype, sparse_format])

Create an empty sparse array.

eye(N[, M, k, dtype, index_dtype, sparse_format])

Create 2D sparse identity matrix.

todense(arr)

Convert input to a dense matrix.

random_bcoo(key, shape, *[, dtype, ...])

Generate a random BCOO matrix.

JAXSparse(args, *, shape)

Base class for high-level JAX sparse objects.

BCOO Data Structure#

BCOO is the Batched COO format, and is the main sparse data structure implemented in jax.experimental.sparse. Its operations are compatible with JAX’s core transformations, including batching (e.g. jax.vmap()) and autodiff (e.g. jax.grad()).
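A minimal sketch of this (the 2x2 array values below are arbitrary, chosen only for illustration): construct a BCOO matrix from a dense array, convert it back, and differentiate through a sparse matrix-vector product.

>>> import jax
>>> import jax.numpy as jnp
>>> from jax.experimental import sparse
>>> M = sparse.BCOO.fromdense(jnp.array([[1., 0.], [0., 2.]]))
>>> M.todense()
Array([[1., 0.],
       [0., 2.]], dtype=float32)
>>> jax.grad(lambda v: (M @ v).sum())(jnp.ones(2))  # autodiff through the sparse matmul
Array([1., 2.], dtype=float32)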

BCOO(args, *, shape[, indices_sorted, ...])

Experimental batched COO matrix implemented in JAX

bcoo_broadcast_in_dim(mat, *, shape, ...)

Expand the size and rank of a BCOO array by duplicating the data.

bcoo_concatenate(operands, *, dimension)

Sparse implementation of jax.lax.concatenate()

bcoo_dot_general(lhs, rhs, *, dimension_numbers)

A general contraction operation.

bcoo_dot_general_sampled(A, B, indices, *, ...)

A contraction operation with output computed at given sparse indices.

bcoo_dynamic_slice(mat, start_indices, ...)

Sparse implementation of jax.lax.dynamic_slice().

bcoo_extract(sparr, arr, *[, assume_unique])

Extract values from a dense array according to the sparse array's indices.

bcoo_fromdense(mat, *[, nse, n_batch, ...])

Create BCOO-format sparse matrix from a dense matrix.

bcoo_gather(operand, start_indices, ...[, ...])

BCOO version of lax.gather.

bcoo_multiply_dense(sp_mat, v)

An element-wise multiplication between a sparse and a dense array.

bcoo_multiply_sparse(lhs, rhs)

An element-wise multiplication of two sparse arrays.

bcoo_update_layout(mat, *[, n_batch, ...])

Update the storage layout (i.e. n_batch & n_dense) of a BCOO matrix.

bcoo_reduce_sum(mat, *, axes)

Sum array element over given axes.

bcoo_reshape(mat, *, new_sizes[, dimensions])

Sparse implementation of jax.lax.reshape().

bcoo_slice(mat, *, start_indices, limit_indices)

Sparse implementation of jax.lax.slice().

bcoo_sort_indices(mat)

Sort indices of a BCOO array.

bcoo_squeeze(arr, *, dimensions)

Sparse implementation of jax.lax.squeeze().

bcoo_sum_duplicates(mat[, nse])

Sums duplicate indices within a BCOO array, returning an array with sorted indices.

bcoo_todense(mat)

Convert batched sparse matrix to a dense matrix.

bcoo_transpose(mat, *, permutation)

Transpose a BCOO-format array.

BCSR Data Structure#

BCSR is the Batched Compressed Sparse Row format, and is under development. Its operations are compatible with JAX’s core transformations, including batching (e.g. jax.vmap()) and autodiff (e.g. jax.grad()).

BCSR(args, *, shape[, indices_sorted, ...])

Experimental batched CSR matrix implemented in JAX.

bcsr_dot_general(lhs, rhs, *, dimension_numbers)

A general contraction operation.

bcsr_extract(indices, indptr, mat)

Extract values from a dense matrix at given BCSR (indices, indptr).

bcsr_fromdense(mat, *[, nse, n_batch, ...])

Create BCSR-format sparse matrix from a dense matrix.

bcsr_todense(mat)

Convert batched sparse matrix to a dense matrix.

Other Sparse Data Structures#

Other sparse data structures include COO, CSR, and CSC. These are reference implementations of simple sparse structures with a few core operations implemented. Their operations are generally compatible with autodiff transformations such as jax.grad(), but not with batching transforms like jax.vmap().
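As a brief sketch (array values arbitrary), a COO matrix can be built with COO.fromdense() and multiplied against a dense vector with coo_matvec():

>>> import jax.numpy as jnp
>>> from jax.experimental import sparse
>>> A = sparse.COO.fromdense(jnp.array([[1., 0.], [0., 2.]]))
>>> out = sparse.coo_matvec(A, jnp.array([3., 4.]))
>>> print(out)
[3. 8.]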

COO(args, *, shape[, rows_sorted, cols_sorted])

Experimental COO matrix implemented in JAX.

CSC(args, *, shape)

Experimental CSC matrix implemented in JAX; API subject to change.

CSR(args, *, shape)

Experimental CSR matrix implemented in JAX.

coo_fromdense(mat, *[, nse, index_dtype])

Create a COO-format sparse matrix from a dense matrix.

coo_matmat(mat, B, *[, transpose])

Product of COO sparse matrix and a dense matrix.

coo_matvec(mat, v[, transpose])

Product of COO sparse matrix and a dense vector.

coo_todense(mat)

Convert a COO-format sparse matrix to a dense matrix.

csr_fromdense(mat, *[, nse, index_dtype])

Create a CSR-format sparse matrix from a dense matrix.

csr_matmat(mat, B, *[, transpose])

Product of CSR sparse matrix and a dense matrix.

csr_matvec(mat, v[, transpose])

Product of CSR sparse matrix and a dense vector.

csr_todense(mat)

Convert a CSR-format sparse matrix to a dense matrix.

jax.experimental.sparse.linalg#

Sparse linear algebra routines.

spsolve(data, indices, indptr, b[, tol, reorder])

A sparse direct solver using QR factorization.

lobpcg_standard(A, X[, m, tol])

Compute the top-k standard eigenvalues using the LOBPCG routine.

jax.experimental.jet module#

Jet is an experimental module for higher-order automatic differentiation that does not rely on repeated first-order automatic differentiation.

How? Through the propagation of truncated Taylor polynomials. Consider a function \(f = g \circ h\), some point \(x\) and some offset \(v\). First-order automatic differentiation (such as jax.jvp()) computes the pair \((f(x), \partial f(x)[v])\) from the pair \((h(x), \partial h(x)[v])\).

jet() implements the higher-order analogue: Given the tuple

\[(h_0, \ldots, h_K) := (h(x), \partial h(x)[v], \partial^2 h(x)[v, v], \ldots, \partial^K h(x)[v, \ldots, v]),\]

which represents a \(K\)-th order Taylor approximation of \(h\) at \(x\), jet() returns a \(K\)-th order Taylor approximation of \(f\) at \(x\),

\[(f_0, \ldots, f_K) := (f(x), \partial f(x)[v], \partial^2 f(x)[v, v], \ldots, \partial^K f(x)[v, \ldots, v]).\]

More specifically, jet() computes

\[f_0, (f_1, \ldots, f_K) = \texttt{jet}(f, h_0, (h_1, \ldots, h_K))\]

and can thus be used for high-order automatic differentiation of \(f\). Details are explained in these notes.

Note

Help improve jet() by contributing outstanding primitive rules.

API#
jax.experimental.jet.jet(fun, primals, series)[source]#

Taylor-mode higher-order automatic differentiation.

Parameters:
  • fun – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.

  • primals – The primal values at which the Taylor approximation of fun should be evaluated. Should be either a tuple or a list of arguments, and its length should be equal to the number of positional parameters of fun.

  • series – Higher order Taylor-series-coefficients. Together, primals and series make up a truncated Taylor polynomial. Should be either a tuple or a list of tuples or lists, and its length dictates the degree of the truncated Taylor polynomial.

Returns:

A (primals_out, series_out) pair, where primals_out is fun(*primals), and together, primals_out and series_out are a truncated Taylor polynomial of \(f(h(\cdot))\). The primals_out value has the same Python tree structure as primals, and the series_out value the same Python tree structure as series.

For example:

>>> import jax
>>> import jax.numpy as np
>>> from jax.experimental.jet import jet

Consider the function \(h(z) = z^3\), \(x = 0.5\), and the first few Taylor coefficients \(h_0=x^3\), \(h_1=3x^2\), and \(h_2=6x\). Let \(f(y) = \sin(y)\).

>>> h0, h1, h2 = 0.5**3., 3.*0.5**2., 6.*0.5
>>> f, df, ddf = np.sin, np.cos, lambda *args: -np.sin(*args)

jet() returns the Taylor coefficients of \(f(h(z)) = \sin(z^3)\) according to Faà di Bruno’s formula:

>>> f0, (f1, f2) = jet(f, (h0,), ((h1, h2),))
>>> print(f0,  f(h0))
0.12467473 0.12467473
>>> print(f1, df(h0) * h1)
0.7441479 0.74414825
>>> print(f2, ddf(h0) * h1 ** 2 + df(h0) * h2)
2.9064622 2.9064634
jax.experimental.custom_partitioning module#
API#
jax.experimental.custom_partitioning.custom_partitioning(fun, static_argnums=())[source]#

Inserts a CustomCallOp into the XLA graph with custom SPMD lowering rules.

@custom_partitioning
def f(*args):
  return ...

def propagate_user_sharding(mesh, user_shape):
  '''Update the sharding of the op from a user's shape.sharding.'''
  user_sharding = jax.tree.map(lambda x: x.sharding, user_shape)

def partition(mesh, arg_shapes, result_shape):
  def lower_fn(*args):
    ... builds computation on per-device shapes ...
  result_shardings = jax.tree.map(lambda x: x.sharding, result_shape)
  arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
  # result_shardings and arg_shardings may optionally be modified and the
  # partitioner will insert collectives to reshape.
  return mesh, lower_fn, result_shardings, arg_shardings

def infer_sharding_from_operands(mesh, arg_shapes, shape):
  '''Compute the result sharding from the sharding of the operands.'''
  arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)


f.def_partition(partition, propagate_user_sharding, infer_sharding_from_operands)

The args to def_partition are as follows:

  • propagate_user_sharding: Callable which takes the sharding of a user (in the dag) and returns a suggestion for a new NamedSharding. The default implementation is just to return the suggested sharding.

  • partition: Callable which takes the SPMD suggested partition shapes and partition specs and returns the mesh, a per-shard lowering function, and the final input and output sharding specs (the SPMD partitioner will repartition the inputs to match). The mesh is returned to allow configuring axis_names for collectives when no mesh is provided.

  • infer_sharding_from_operands: Callable which computes an output NamedSharding from the NamedSharding chosen for each argument.

  • decode_shardings: When set to True, convert input GSPMDShardings to NamedSharding if possible. This may not be possible if the user does not provide a contextual mesh.

Positional arguments can be specified as static using static_argnums. JAX uses inspect.signature(fun) to resolve these positional arguments.

Example

As an example, assume we want to enhance the existing jax.numpy.fft.fft. This function computes the discrete Fourier transform of an N-dimensional input along the last dimension, and is batched along the first N-1 dimensions. By default, however, it ignores the sharding of the input and gathers the input on all devices. Since jax.numpy.fft.fft is batched along the first N-1 dimensions, this gathering is unnecessary. We will create a new my_fft op that, instead, does not alter the sharding along the first N-1 dimensions, and only gathers the input along the last dimension if needed.

import jax
from jax.sharding import NamedSharding
from jax.experimental.custom_partitioning import custom_partitioning
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P
from jax.sharding import Mesh
from jax.numpy.fft import fft
import re
import numpy as np

# Pattern to detect all-gather or dynamic-slice in the generated HLO
_PATTERN = '(dynamic-slice|all-gather)'

# For an N-D input, keeps sharding along the first N-1 dimensions
# but replicates along the last dimension
def supported_sharding(sharding, shape):
    rank = len(shape.shape)
    max_shared_dims = min(len(sharding.spec), rank-1)
    names = tuple(sharding.spec[:max_shared_dims]) + tuple(None for _ in range(rank - max_shared_dims))
    return NamedSharding(sharding.mesh, P(*names))

def partition(mesh, arg_shapes, result_shape):
    result_shardings = jax.tree.map(lambda x: x.sharding, result_shape)
    arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
    return (mesh, fft,
            supported_sharding(arg_shardings[0], arg_shapes[0]),
            (supported_sharding(arg_shardings[0], arg_shapes[0]),))

def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
    arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
    return supported_sharding(arg_shardings[0], arg_shapes[0])

@custom_partitioning
def my_fft(x):
    return fft(x)

my_fft.def_partition(
    infer_sharding_from_operands=infer_sharding_from_operands,
    partition=partition)

Now create a 2D array sharded along the first axis, pass it through my_fft and notice how it is still sharded as expected, and identical to the output of fft. However, inspecting the HLO (using lower(x).compile().runtime_executable().hlo_modules()) reveals that my_fft does not create any all-gather or dynamic-slice, while fft does.

with Mesh(np.array(jax.devices()), ('x',)):
  x = np.asarray(np.random.randn(32*1024, 1024), dtype=np.complex64)
  y = pjit(lambda x: x, in_shardings=None, out_shardings=P('x'))(x)
  pjit_my_fft = pjit(my_fft, in_shardings=P('x'), out_shardings=P('x'))
  pjit_fft    = pjit(fft,    in_shardings=P('x'), out_shardings=P('x'))
  print(pjit_my_fft(y))
  print(pjit_fft(y))
  # dynamic-slice or all-gather are not present in the HLO for my_fft, because x is a 2D array
  assert(re.search(_PATTERN, pjit_my_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is None)
  # dynamic-slice or all-gather are present in the HLO for fft
  assert(re.search(_PATTERN, pjit_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string())    is not None)
# my_fft
[[-38.840824   +0.j        -40.649452  +11.845365j
...
  -1.6937828  +0.8402481j  15.999859   -4.0156755j]]

# jax.numpy.fft.fft
[[-38.840824   +0.j        -40.649452  +11.845365j
  ...
  -1.6937828  +0.8402481j  15.999859   -4.0156755j]]

Because of the logic in supported_sharding, my_fft also works on 1-dimensional arrays. However, in this case, the HLO of my_fft does show a dynamic-slice, since the last dimension is the dimension along which FFTs are calculated and needs to be replicated on all devices before the computation can be done.

with Mesh(np.array(jax.devices()), ('x',)):
  x = np.asarray(np.random.randn(32*1024*1024), dtype=np.complex64)
  y = pjit(lambda x: x, in_shardings=None, out_shardings=P('x'))(x)
  pjit_my_fft = pjit(my_fft, in_shardings=P('x'), out_shardings=P('x'))
  pjit_fft    = pjit(fft,    in_shardings=P('x'), out_shardings=P('x'))
  print(pjit_my_fft(y))
  print(pjit_fft(y))
  # dynamic-slice or all-gather are present in the HLO for my_fft, because x is a 1D array
  assert(re.search(_PATTERN, pjit_my_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is not None)
  # dynamic-slice or all-gather are present in the HLO for fft
  assert(re.search(_PATTERN, pjit_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string())    is not None)
# my_fft
[    7.217285   +0.j     -3012.4937  +4287.635j   -405.83594 +3042.984j
...  1422.4502  +7271.4297j  -405.84033 -3042.983j
-3012.4963  -4287.6343j]

# jax.numpy.fft.fft
[    7.217285   +0.j     -3012.4937  +4287.635j   -405.83594 +3042.984j
...  1422.4502  +7271.4297j  -405.84033 -3042.983j
-3012.4963  -4287.6343j]
jax.experimental.multihost_utils module#

Utilities for synchronization and communication across multiple hosts.

Multihost Utils API Reference#

broadcast_one_to_all(in_tree[, is_source])

Broadcast data from a source host (host 0 by default) to all other hosts.

sync_global_devices(name)

Creates a barrier across all hosts/devices.

process_allgather(in_tree[, tiled])

Gather data from across processes.

assert_equal(in_tree[, fail_message])

Verifies that all the hosts have the same tree of values.

host_local_array_to_global_array(...)

Converts a host local value to a globally sharded jax.Array.

global_array_to_host_local_array(...)

Converts a global jax.Array to a host local jax.Array.
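As a hedged sketch of typical usage (the values are placeholders; with a single process these calls are effectively no-ops, but in a multi-host job they become cross-process collectives):

import jax.numpy as jnp
from jax.experimental import multihost_utils

# Broadcast a value from process 0 so every process agrees on it.
seed = multihost_utils.broadcast_one_to_all(jnp.int32(42))

# Gather each process's local data onto all processes.
gathered = multihost_utils.process_allgather(jnp.arange(4))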

jax.experimental.compilation_cache module#

JAX disk compilation cache.

API#
jax.experimental.compilation_cache.compilation_cache.is_initialized()[source]#

Deprecated.

Return whether the cache is enabled. Initialization can be deferred, so initialized status is not checked. The name is retained for backwards compatibility.

Return type:

bool

jax.experimental.compilation_cache.compilation_cache.initialize_cache(path)[source]#

This API is deprecated; use set_cache_dir instead.

Set the path. To take effect, should be called prior to any calls to get_executable_and_time() and put_executable_and_time().

Return type:

None

jax.experimental.compilation_cache.compilation_cache.set_cache_dir(path)[source]#

Sets the persistent compilation cache directory.

After calling this, jit-compiled functions are saved to path, so they do not need to be recompiled if the process is restarted or otherwise run again. This also tells JAX where to look for compiled functions before compiling.

Return type:

None
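For example, a minimal sketch of enabling the persistent cache (the directory path here is illustrative):

import jax
import jax.numpy as jnp
from jax.experimental.compilation_cache import compilation_cache as cc

cc.set_cache_dir("/tmp/jax_cache")  # illustrative cache location

@jax.jit
def f(x):
    return jnp.sin(x) * 2

f(1.0)  # the compiled executable is written to the cache directory for reuse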

jax.experimental.compilation_cache.compilation_cache.reset_cache()[source]#

Get back to pristine, uninitialized state.

Return type:

None

jax.experimental.key_reuse module#
Experimental Key Reuse Checking#

This module contains experimental functionality for detecting reuse of random keys within JAX programs. It is under active development and the APIs here are likely to change. The usage below requires JAX version 0.4.26 or newer.

Key reuse checking can be enabled using the jax_debug_key_reuse configuration. This can be set globally using:

>>> jax.config.update('jax_debug_key_reuse', True)  

Or it can be enabled locally with the jax.debug_key_reuse() context manager. When enabled, using the same key twice will result in a KeyReuseError:

>>> import jax
>>> with jax.debug_key_reuse(True):
...   key = jax.random.key(0)
...   val1 = jax.random.normal(key)
...   val2 = jax.random.normal(key)  
Traceback (most recent call last):
 ...
KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0

The key reuse checker is currently experimental, but in the future we will likely enable it by default.

jax.experimental.mesh_utils module#

Utils for building a device mesh.

API#

create_device_mesh(mesh_shape[, devices, ...])

Creates a performant device mesh for jax.sharding.Mesh.

create_hybrid_device_mesh(mesh_shape, ...[, ...])

Creates a device mesh for hybrid (e.g., ICI and DCN) parallelism.
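A short sketch of how create_device_mesh() is typically combined with jax.sharding.Mesh (the 1D mesh shape is illustrative; the product of mesh_shape must equal the number of devices):

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

# Arrange all available devices into a 1D mesh with a named axis.
devices = mesh_utils.create_device_mesh((jax.device_count(),))
mesh = Mesh(devices, axis_names=('x',))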

Experimental APIs#

enable_x64([new_val])

Experimental context manager to temporarily enable X64 mode.

disable_x64()

Experimental context manager to temporarily disable X64 mode.

jax.experimental.checkify.checkify(f[, errors])

Functionalize check calls in fun, and optionally add run-time error checks.

jax.experimental.checkify.check(pred, msg, ...)

Check a predicate, add an error with msg if predicate is False.

jax.experimental.checkify.check_error(error)

Raise an Exception if error represents a failure.
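A minimal checkify sketch (the function and message are illustrative): checkify() functionalizes the check so the error comes back as a value that can be raised explicitly.

import jax.numpy as jnp
from jax.experimental import checkify

def f(x):
    checkify.check(x > 0, "x must be positive, got {x}", x=x)
    return jnp.log(x)

checked_f = checkify.checkify(f)
err, out = checked_f(-1.0)   # the check fails, but no exception is raised here
err.throw()                  # raises the functionalized error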

jax.lib module#

The jax.lib package is a set of internal tools and types for bridging between JAX’s Python frontend and its XLA backend.

jax.lib.xla_bridge#

default_backend()

Returns the platform name of the default XLA backend.

get_backend([platform])

get_compile_options(num_replicas, num_partitions)

Returns the compile options to use, as derived from flag values.

jax.lib.xla_client#

Configuration#

config

check_tracer_leaks

Context manager for jax_check_tracer_leaks config option.

checking_leaks

Context manager for jax_check_tracer_leaks config option.

debug_nans

Context manager for jax_debug_nans config option.

debug_infs

Context manager for jax_debug_infs config option.

default_device

Context manager for jax_default_device config option.

default_matmul_precision

Context manager for jax_default_matmul_precision config option.

default_prng_impl

Context manager for jax_default_prng_impl config option.

enable_checks

Context manager for jax_enable_checks config option.

enable_custom_prng

Context manager for jax_enable_custom_prng config option (transient).

enable_custom_vjp_by_custom_transpose

Context manager for jax_enable_custom_vjp_by_custom_transpose config option (transient).

log_compiles

Context manager for jax_log_compiles config option.

numpy_rank_promotion

Context manager for jax_numpy_rank_promotion config option.

transfer_guard(new_val)

A contextmanager to control the transfer guard level for all transfers.
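For example, a config option can be set globally with jax.config.update() or scoped with the corresponding context manager listed above; a small sketch using jax_debug_nans:

import jax
import jax.numpy as jnp

# Set an option globally...
jax.config.update("jax_debug_nans", True)

# ...or override it only within a scope via its context manager.
with jax.debug_nans(False):
    jnp.add(1.0, 2.0)  # NaN checking temporarily disabled here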

Just-in-time compilation (jit)#

jit(fun[, in_shardings, out_shardings, ...])

Sets up fun for just-in-time compilation with XLA.

disable_jit([disable])

Context manager that disables jit() behavior under its dynamic context.

ensure_compile_time_eval()

Context manager to ensure evaluation at trace/compile time (or error).

xla_computation(fun[, static_argnums, ...])

Creates a function that produces its XLA computation given example args.

make_jaxpr([axis_env, return_shape, ...])

Creates a function that produces its jaxpr given example args.

eval_shape(fun, *args, **kwargs)

Compute the shape/dtype of fun without any FLOPs.

ShapeDtypeStruct(shape, dtype[, ...])

A container for the shape, dtype, and other static attributes of an array.

device_put(x[, device, src])

Transfers x to device.

device_put_replicated(x, devices)

Transfer array(s) to each specified device and form Array(s).

device_put_sharded(shards, devices)

Transfer array shards to specified devices and form Array(s).

device_get(x)

Transfer x to host.

default_backend()

Returns the platform name of the default XLA backend.

named_call(fun, *[, name])

Adds a user specified name to a function when staging out JAX computations.

named_scope(name)

A context manager that adds a user specified name to the JAX name stack.

block_until_ready(x)

Tries to call a block_until_ready method on pytree leaves.
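A small sketch tying a few of these together (the function is arbitrary): jit-compile a function, inspect its jaxpr, and compute output shapes without running it.

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sin(x) ** 2

f(1.0)                                        # compiled on first call, cached afterwards
print(jax.make_jaxpr(f)(1.0))                 # staged-out jaxpr, no execution
print(jax.eval_shape(f, jnp.zeros((3, 4))))   # ShapeDtypeStruct, no FLOPs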

Automatic differentiation#

grad(fun[, argnums, has_aux, holomorphic, ...])

Creates a function that evaluates the gradient of fun.

value_and_grad(fun[, argnums, has_aux, ...])

Create a function that evaluates both fun and the gradient of fun.

jacfwd(fun[, argnums, has_aux, holomorphic])

Jacobian of fun evaluated column-by-column using forward-mode AD.

jacrev(fun[, argnums, has_aux, holomorphic, ...])

Jacobian of fun evaluated row-by-row using reverse-mode AD.

hessian(fun[, argnums, has_aux, holomorphic])

Hessian of fun as a dense array.

jvp(fun, primals, tangents[, has_aux])

Computes a (forward-mode) Jacobian-vector product of fun.

linearize()

Produces a linear approximation to fun using jvp() and partial eval.

linear_transpose(fun, *primals[, reduce_axes])

Transpose a function that is promised to be linear.

vjp(fun, *primals[, has_aux, reduce_axes])

Compute a (reverse-mode) vector-Jacobian product of fun.

custom_jvp(fun[, nondiff_argnums])

Set up a JAX-transformable function for a custom JVP rule definition.

custom_vjp(fun[, nondiff_argnums])

Set up a JAX-transformable function for a custom VJP rule definition.

custom_gradient(fun)

Convenience function for defining custom VJP rules (aka custom gradients).

closure_convert(fun, *example_args)

Closure conversion utility, for use with higher-order custom derivatives.

checkpoint(fun, *[, prevent_cse, policy, ...])

Make fun recompute internal linearization points when differentiated.
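For example, a quick sketch of grad() and value_and_grad() on an arbitrary scalar-valued function:

import jax
import jax.numpy as jnp

def loss(w):
    return jnp.sum(w ** 2)

w = jnp.arange(3.0)
g = jax.grad(loss)(w)                  # gradient: [0., 2., 4.]
val, g2 = jax.value_and_grad(loss)(w)  # value 5.0 and the same gradient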

jax.Array (jax.Array)#

Array()

Array base class for JAX

make_array_from_callback(shape, sharding, ...)

Returns a jax.Array via data fetched from data_callback.

make_array_from_single_device_arrays(shape, ...)

Returns a jax.Array from a sequence of jax.Arrays each on a single device.
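A hedged sketch of make_array_from_callback() (the shape, data, and single-device sharding are placeholders):

import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding

# Build a jax.Array by fetching each addressable shard from a callback.
sharding = SingleDeviceSharding(jax.devices()[0])
data = jnp.arange(4.0)
arr = jax.make_array_from_callback((4,), sharding, lambda idx: data[idx])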

Vectorization (vmap)#

vmap(fun[, in_axes, out_axes, axis_name, ...])

Vectorizing map.

numpy.vectorize(pyfunc, *[, excluded, signature])

Define a vectorized function with broadcasting.
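For instance, a minimal vmap() sketch (shapes arbitrary) that batches a scalar-producing dot product over a leading axis:

import jax
import jax.numpy as jnp

def dot(a, b):
    return jnp.dot(a, b)

batched_dot = jax.vmap(dot)     # maps over the leading axis of both inputs
a = jnp.ones((8, 3))
b = jnp.ones((8, 3))
print(batched_dot(a, b).shape)  # (8,)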

Parallelization (pmap)#

pmap(fun[, axis_name, in_axes, out_axes, ...])

Parallel map with support for collective operations.

devices([backend])

Returns a list of all devices for a given backend.

local_devices([process_index, backend, host_id])

Like jax.devices(), but only returns devices local to a given process.

process_index([backend])

Returns the integer process index of this process.

device_count([backend])

Returns the total number of devices.

local_device_count([backend])

Returns the number of devices addressable by this process.

process_count([backend])

Returns the number of JAX processes associated with the backend.
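A brief pmap() sketch (the per-device computation is arbitrary); the leading axis of the input must match the number of participating devices:

import jax
import jax.numpy as jnp

n = jax.local_device_count()
x = jnp.arange(n * 4.0).reshape(n, 4)  # one row per local device
y = jax.pmap(lambda v: v * 2)(x)       # runs in parallel across devices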

Callbacks#

pure_callback(callback, result_shape_dtypes, ...)

Calls a pure Python callback.

experimental.io_callback(callback, ...[, ...])

Calls an impure Python callback.

debug.callback(callback, *args[, ordered])

Calls a stageable Python callback.

debug.print(fmt, *args[, ordered])

Prints values and works in staged out JAX functions.
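For example, a small sketch of jax.debug.print(), which keeps working inside staged-out (jit-compiled) code where Python's print would only run at trace time:

import jax

@jax.jit
def f(x):
    jax.debug.print("x = {x}", x=x)
    return x + 1

f(1.0)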

Miscellaneous#

Device

A descriptor of an available device.

print_environment_info([return_string])

Returns a string containing local environment & JAX installation information.

live_arrays([platform])

Return all live arrays in the backend for platform.

clear_caches()

Clear all compilation and staging caches.

Change log#


jax 0.4.28#

jaxlib 0.4.28#

jax 0.4.27 (May 7, 2024)#

  • New Functionality

    • Added jax.numpy.unstack() and jax.numpy.cumulative_sum(), following their addition in the array API 2023 standard, soon to be adopted by NumPy.

    • Added a new config option jax_cpu_collectives_implementation to select the implementation of cross-process collective operations used by the CPU backend. Choices available are 'none' (default), 'gloo', and 'mpi' (requires jaxlib 0.4.26). If set to 'none', cross-process collective operations are disabled.

  • Changes

    • jax.pure_callback(), jax.experimental.io_callback() and jax.debug.callback() now use jax.Array instead of np.ndarray. You can recover the old behavior by transforming the arguments via jax.tree.map(np.asarray, args) before passing them to the callback.

    • complex_arr.astype(bool) now follows the same semantics as NumPy, returning False where complex_arr is equal to 0 + 0j, and True otherwise.

    • core.Token is now a non-trivial class which wraps a jax.Array. It can be created and threaded in and out of computations to build up dependencies. The singleton object core.token has been removed; users should now create and use fresh core.Token objects instead.

    • On GPU, the Threefry PRNG implementation no longer lowers to a kernel call by default. This choice can improve runtime memory usage at a compile-time cost. Prior behavior, which produces a kernel call, can be recovered with jax.config.update('jax_threefry_gpu_kernel_lowering', True). If the new default causes issues, please file a bug. Otherwise, we intend to remove this flag in a future release.

  • Deprecations & Removals

    • Pallas now exclusively uses XLA for compiling kernels on GPU. The old lowering pass via Triton Python APIs has been removed and the JAX_TRITON_COMPILE_VIA_XLA environment variable no longer has any effect.

    • jax.numpy.clip() has a new argument signature: a, a_min, and a_max are deprecated in favor of x (positional only), min, and max (#20550).

    • The device() method of JAX arrays has been removed, after being deprecated since JAX v0.4.21. Use arr.devices() instead.

    • The initial argument to jax.nn.softmax() and jax.nn.log_softmax() is deprecated; empty inputs to softmax are now supported without setting this.

    • In jax.jit(), passing invalid static_argnums or static_argnames now leads to an error rather than a warning.

    • The minimum jaxlib version is now 0.4.23.

    • The jax.numpy.hypot() function now issues a deprecation warning when passing complex-valued inputs to it. This will raise an error when the deprecation is completed.

    • Scalar arguments to jax.numpy.nonzero(), jax.numpy.where(), and related functions now raise an error, following a similar change in NumPy.

    • The config option jax_cpu_enable_gloo_collectives is deprecated. Use jax.config.update('jax_cpu_collectives_implementation', 'gloo') instead.

    • The jax.Array.device_buffer and jax.Array.device_buffers methods have been removed after being deprecated in JAX v0.4.22. Instead use jax.Array.addressable_shards and jax.Array.addressable_data().

    • The condition, x, and y parameters of jax.numpy.where are now positional-only, following deprecation of the keywords in JAX v0.4.21.

    • Non-array arguments to functions in jax.lax.linalg now must be specified by keyword. Previously, this raised a DeprecationWarning.

    • Array-like arguments are now required in several jax.numpy APIs, including apply_along_axis(), apply_over_axes(), inner(), outer(), cross(), kron(), and lexsort().

  • Bug fixes

    • jax.numpy.astype() will now always return a copy when copy=True. Previously, no copy would be made when the output array would have the same dtype as the input array. This may result in some increased memory usage. The default value is set to copy=False to preserve backwards compatibility.

jaxlib 0.4.27 (May 7, 2024)#

jax 0.4.26 (April 3, 2024)#

  • New Functionality

  • Changes

    • Complex-valued jax.numpy.geomspace() now chooses the logarithmic spiral branch consistent with that of NumPy 2.0.

    • The behavior of lax.rng_bit_generator, and in turn the 'rbg' and 'unsafe_rbg' PRNG implementations, under jax.vmap has changed so that mapping over keys results in random generation only from the first key in the batch.

    • Docs now use jax.random.key for construction of PRNG key arrays rather than jax.random.PRNGKey.

  • Deprecations & Removals

    • jax.tree_map() is deprecated; use jax.tree.map instead, or for backward compatibility with older JAX versions, use jax.tree_util.tree_map().

    • jax.clear_backends() is deprecated as it does not necessarily do what its name suggests and can lead to unexpected consequences, e.g., it will not destroy existing backends and release corresponding owned resources. Use jax.clear_caches() if you only want to clean up compilation caches. For backward compatibility, or if you really need to switch/reinitialize the default backend, use jax.extend.backend.clear_backends().

    • The jax.experimental.maps module and jax.experimental.maps.xmap are deprecated. Use jax.experimental.shard_map or jax.vmap with the spmd_axis_name argument for expressing SPMD device-parallel computations.

    • The jax.experimental.host_callback module is deprecated. Instead, use the new JAX external callbacks. Added a JAX_HOST_CALLBACK_LEGACY flag to assist in the transition to the new callbacks. See #20385 for a discussion.

    • Passing arguments to jax.numpy.array_equal() and jax.numpy.array_equiv() that cannot be converted to a JAX array now results in an exception.

    • The deprecated flag jax_parallel_functions_output_gda has been removed. This flag was long deprecated and did nothing; its use was a no-op.

    • The previously-deprecated imports jax.interpreters.ad.config and jax.interpreters.ad.source_info_util have now been removed. Use jax.config and jax.extend.source_info_util instead.

    • JAX export does not support older serialization versions anymore. Version 9 has been supported since October 27th, 2023 and has become the default since February 1, 2024. See a description of the versions. This change could break clients that set a specific JAX serialization version lower than 9.

jaxlib 0.4.26 (April 3, 2024)#

  • Changes

    • JAX now supports CUDA 12.1 or newer only. Support for CUDA 11.8 has been dropped.

    • JAX now supports NumPy 2.0.

jax 0.4.25 (Feb 26, 2024)#

  • New Features

  • Changes

    • Pallas now uses XLA instead of the Triton Python APIs to compile Triton kernels. You can revert to the old behavior by setting the JAX_TRITON_COMPILE_VIA_XLA environment variable to "0".

    • Several deprecated APIs in jax.interpreters.xla that were removed in v0.4.24 have been re-added in v0.4.25, including backend_specific_translations, translations, register_translation, xla_destructure, TranslationRule, TranslationContext, and XLAOp. These are still considered deprecated, and will be removed again in the future when better replacements are available. Refer to #19816 for discussion.

  • Deprecations & Removals

    • jax.numpy.linalg.solve() now shows a deprecation warning for batched 1D solves with b.ndim > 1. In the future these will be treated as batched 2D solves.

    • Conversion of a non-scalar array to a Python scalar now raises an error, regardless of the size of the array. Previously a deprecation warning was raised in the case of non-scalar arrays of size 1. This follows a similar deprecation in NumPy.

    • The previously deprecated configuration APIs have been removed following a standard 3-month deprecation cycle (see API compatibility). These include

      • the jax.config.config object and

      • the define_*_state and DEFINE_* methods of jax.config.

    • Importing the jax.config submodule via import jax.config is deprecated. To configure JAX use import jax and then reference the config object via jax.config.

    • The minimum jaxlib version is now 0.4.20.

jaxlib 0.4.25 (Feb 26, 2024)#

jax 0.4.24 (Feb 6, 2024)#

  • Changes

    • JAX lowering to StableHLO does not depend on physical devices anymore. If your primitive wraps custom_partitioning or JAX callbacks in the lowering rule (i.e. the function passed to the rule parameter of mlir.register_lowering), then add your primitive to the jax._src.dispatch.prim_requires_devices_during_lowering set. This is needed because custom_partitioning and JAX callbacks need physical devices to create Shardings during lowering. This is a temporary state until we can create Shardings without physical devices.

    • jax.numpy.argsort() and jax.numpy.sort() now support the stable and descending arguments.

    • Several changes to the handling of shape polymorphism (used in jax.experimental.jax2tf and jax.experimental.export):

      • cleaner pretty-printing of symbolic expressions (#19227)

      • added the ability to specify symbolic constraints on the dimension variables. This makes shape polymorphism more expressive, and gives a way to workaround limitations in the reasoning about inequalities. See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.

      • with the addition of symbolic constraints (#19235) we now consider dimension variables from different scopes to be different, even if they have the same name. Symbolic expressions from different scopes cannot interact, e.g., in arithmetic operations. Scopes are introduced by jax.experimental.jax2tf.convert(), jax.experimental.export.symbolic_shape(), jax.experimental.export.symbolic_args_specs(). The scope of a symbolic expression e can be read with e.scope and passed into the above functions to direct them to construct symbolic expressions in a given scope. See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.

      • simplified and faster equality comparisons, where we consider two symbolic dimensions to be equal if the normalized form of their difference reduces to 0 (#19231; note that this may result in user-visible behavior changes)

      • improved the error messages for inconclusive inequality comparisons (#19235).

      • the core.non_negative_dim API (introduced recently) was deprecated and core.max_dim and core.min_dim were introduced (#18953) to express max and min for symbolic dimensions. You can use core.max_dim(d, 0) instead of core.non_negative_dim(d).

      • the shape_poly.is_poly_dim is deprecated in favor of export.is_symbolic_dim (#19282).

      • the export.args_specs is deprecated in favor of export.symbolic_args_specs (#19283).

      • the shape_poly.PolyShape and jax2tf.PolyShape are deprecated, use strings for polymorphic shapes specifications (#19284).

      • JAX default native serialization version is now 9. This is relevant for jax.experimental.jax2tf and jax.experimental.export. See description of version numbers.

    • Refactored the API for jax.experimental.export. Instead of from jax.experimental.export import export you should use now from jax.experimental import export. The old way of importing will continue to work for a deprecation period of 3 months.

    • Added jax.scipy.stats.sem().

    • jax.numpy.unique() with return_inverse = True returns inverse indices reshaped to the dimension of the input, following a similar change to numpy.unique() in NumPy 2.0.

    • jax.numpy.sign() now returns x / abs(x) for nonzero complex inputs. This is consistent with the behavior of numpy.sign() in NumPy version 2.0.

    • jax.scipy.special.logsumexp() with return_sign=True now uses the NumPy 2.0 convention for the complex sign, x / abs(x). This is consistent with the behavior of scipy.special.logsumexp() in SciPy v1.13.

    • JAX now supports the bool DLPack type for both import and export. Previously bool values could not be imported and were exported as integers.

  • Deprecations & Removals

    • A number of previously deprecated functions have been removed, following a standard 3+ month deprecation cycle (see API compatibility). This includes:

      • From jax.core: TracerArrayConversionError, TracerIntegerConversionError, UnexpectedTracerError, as_hashable_function, collections, dtypes, lu, map, namedtuple, partial, pp, ref, safe_zip, safe_map, source_info_util, total_ordering, traceback_util, tuple_delete, tuple_insert, and zip.

      • From jax.lax: dtypes, itertools, naryop, naryop_dtype_rule, standard_abstract_eval, standard_naryop, standard_primitive, standard_unop, unop, and unop_dtype_rule.

      • The jax.linear_util submodule and all its contents.

      • The jax.prng submodule and all its contents.

      • From jax.random: PRNGKeyArray, KeyArray, default_prng_impl, threefry_2x32, threefry2x32_key, threefry2x32_p, rbg_key, and unsafe_rbg_key.

      • From jax.tree_util: register_keypaths, AttributeKeyPathEntry, and GetItemKeyPathEntry.

      • from jax.interpreters.xla: backend_specific_translations, translations, register_translation, xla_destructure, TranslationRule, TranslationContext, axis_groups, ShapedArray, ConcreteArray, AxisEnv, backend_compile, and XLAOp.

      • from jax.numpy: NINF, NZERO, PZERO, row_stack, issubsctype, trapz, and in1d.

      • from jax.scipy.linalg: tril and triu.

    • The previously-deprecated method PRNGKeyArray.unsafe_raw_array has been removed. Use jax.random.key_data() instead.

    • bool(empty_array) now raises an error rather than returning False. This previously raised a deprecation warning, and follows a similar change in NumPy.

    • Support for the mhlo MLIR dialect has been deprecated. JAX no longer uses the mhlo dialect, in favor of stablehlo. APIs that refer to “mhlo” will be removed in the future. Use the “stablehlo” dialect instead.

    • jax.random: passing batched keys directly to random number generation functions, such as bits(), gamma(), and others, is deprecated and will emit a FutureWarning. Use jax.vmap for explicit batching.

    • jax.lax.tie_in() is deprecated: it has been a no-op since JAX v0.2.0.

jaxlib 0.4.24 (Feb 6, 2024)#

  • Changes

    • JAX now supports CUDA 12.3 and CUDA 11.8. Support for CUDA 12.2 has been dropped.

    • cost_analysis now works with cross-compiled Compiled objects (i.e. when using .lower().compile() with a topology object, e.g., to compile for Cloud TPU from a non-TPU computer).

    • Added CUDA Array Interface import support (requires jax 0.4.25).

jax 0.4.23 (Dec 13, 2023)#

jaxlib 0.4.23 (Dec 13, 2023)#

  • Fixed a bug that caused verbose logging from the GPU compiler during compilation.

jax 0.4.22 (Dec 13, 2023)#

  • Deprecations

    • The device_buffer and device_buffers properties of JAX arrays are deprecated. Explicit buffers have been replaced by the more flexible array sharding interface, but the previous outputs can be recovered this way:

      • arr.device_buffer becomes arr.addressable_data(0)

      • arr.device_buffers becomes [x.data for x in arr.addressable_shards]

jaxlib 0.4.22 (Dec 13, 2023)#

jax 0.4.21 (Dec 4 2023)#

  • New Features

  • Changes

    • The minimum jaxlib version is now 0.4.19.

    • Released wheels are built now with clang instead of gcc.

    • Enforce that the device backend has not been initialized prior to calling jax.distributed.initialize().

    • Automate arguments to jax.distributed.initialize() in cloud TPU environments.

  • Deprecations

    • The previously-deprecated sym_pos argument has been removed from jax.scipy.linalg.solve(). Use assume_a='pos' instead.

    • Passing None to jax.array() or jax.asarray(), either directly or within a list or tuple, is deprecated and now raises a FutureWarning. It currently is converted to NaN, and in the future will raise a TypeError.

    • Passing the condition, x, and y parameters to jax.numpy.where by keyword arguments has been deprecated, to match numpy.where.

    • Passing arguments to jax.numpy.array_equal() and jax.numpy.array_equiv() that cannot be converted to a JAX array is deprecated and now raises a DeprecationWarning. Currently the functions return False, in the future this will raise an exception.

    • The device() method of JAX arrays is deprecated. Depending on the context, it may be replaced with one of the following:

      • jax.Array.devices() returns the set of all devices used by the array.

      • jax.Array.sharding gives the sharding configuration used by the array.

jaxlib 0.4.21 (Dec 4 2023)#

  • Changes

    • In preparation for adding distributed CPU support, JAX now treats CPU devices identically to GPU and TPU devices, that is:

      • jax.devices() includes all devices present in a distributed job, even those not local to the current process. jax.local_devices() still only includes devices local to the current process, so if the change to jax.devices() breaks you, you most likely want to use jax.local_devices() instead.

      • CPU devices now receive a globally unique ID number within a distributed job; previously CPU devices would receive a process-local ID number.

      • The process_index of each CPU device will now match any GPU or TPU devices within the same process; previously the process_index of a CPU device was always 0.

    • On NVIDIA GPU, JAX now prefers a Jacobi SVD solver for matrices up to 1024x1024. The Jacobi solver appears faster than the non-Jacobi version.

  • Bug fixes

    • Fixed error/hang when an array with non-finite values is passed to a non-symmetric eigendecomposition (#18226). Arrays with non-finite values now produce arrays full of NaNs as outputs.

jax 0.4.20 (Nov 2, 2023)#

jaxlib 0.4.20 (Nov 2, 2023)#

  • Bug fixes

    • Fixed some type confusion between E4M3 and E5M2 float8 types.

jax 0.4.19 (Oct 19, 2023)#

  • New Features

    • Added jax.typing.DTypeLike, which can be used to annotate objects that are convertible to JAX dtypes.

    • Added jax.numpy.fill_diagonal.

  • Changes

    • JAX now requires SciPy 1.9 or newer.

  • Bug fixes

    • Only process 0 in a multicontroller distributed JAX program will write persistent compilation cache entries. This fixes write contention if the cache is placed on a network file system such as GCS.

    • The version check for cusolver and cufft no longer considers the patch versions when determining if the installed version of these libraries is at least as new as the versions against which JAX was built.

jaxlib 0.4.19 (Oct 19, 2023)#

  • Changes

    • jaxlib will now always prefer pip-installed NVIDIA CUDA libraries (nvidia-… packages) over any other CUDA installation if they are installed, including installations named in LD_LIBRARY_PATH. If this causes problems and the intent is to use a system-installed CUDA, the fix is to remove the pip installed CUDA library packages.

jax 0.4.18 (Oct 6, 2023)#

jaxlib 0.4.18 (Oct 6, 2023)#

  • Changes

    • CUDA jaxlibs now depend on the user to install a compatible NCCL version. If using the recommended cuda12_pip installation, NCCL should be installed automatically. Currently, NCCL 2.16 or newer is required.

    • We now provide Linux aarch64 wheels, both with and without NVIDIA GPU support.

    • jax.Array.item() now supports optional index arguments.

  • Deprecations

    • A number of internal utilities and inadvertent exports in jax.lax have been deprecated, and will be removed in a future release.

      • jax.lax.dtypes: use jax.dtypes instead.

      • jax.lax.itertools: use itertools instead.

      • naryop, naryop_dtype_rule, standard_abstract_eval, standard_naryop, standard_primitive, standard_unop, unop, and unop_dtype_rule are internal utilities, now deprecated without replacement.

  • Bug fixes

    • Fixed Cloud TPU regression where compilation would OOM due to smem.

jax 0.4.17 (Oct 3, 2023)#

  • New features

  • Deprecations

    • Removed the deprecated module jax.abstract_arrays and all its contents.

    • Named key constructors in jax.random are deprecated. Pass the impl argument to jax.random.PRNGKey() or jax.random.key() instead:

      • random.threefry2x32_key(seed) becomes random.PRNGKey(seed, impl='threefry2x32')

      • random.rbg_key(seed) becomes random.PRNGKey(seed, impl='rbg')

      • random.unsafe_rbg_key(seed) becomes random.PRNGKey(seed, impl='unsafe_rbg')

  • Changes:

    • CUDA: JAX now verifies that the CUDA libraries it finds are at least as new as the CUDA libraries that JAX was built against. If older libraries are found, JAX raises an exception since that is preferable to mysterious failures and crashes.

    • Removed the “No GPU/TPU found” warning. Instead warn if, on Linux, an NVIDIA GPU or a Google TPU are found but not used and --jax_platforms was not specified.

    • jax.scipy.stats.mode() now returns a 0 count if the mode is taken across a size-0 axis, matching the behavior of scipy.stats.mode in SciPy 1.11.

    • Most jax.numpy functions and attributes now have fully-defined type stubs. Previously many of these were treated as Any by static type checkers like mypy and pytype.

jaxlib 0.4.17 (Oct 3, 2023)#

  • Changes:

    • Python 3.12 wheels were added in this release.

    • The CUDA 12 wheels now require CUDA 12.2 or newer and cuDNN 8.9.4 or newer.

  • Bug fixes:

    • Fixed log spam from ABSL when the JAX CPU backend was initialized.

jax 0.4.16 (Sept 18, 2023)#

  • Changes

    • Added jax.numpy.ufunc, as well as jax.numpy.frompyfunc(), which can convert any scalar-valued function into a numpy.ufunc()-like object, with methods such as outer(), reduce(), accumulate(), at(), and reduceat() (#17054).

    • Added jax.scipy.integrate.trapezoid().

    • When not running under IPython: when an exception is raised, JAX now filters out the entirety of its internal frames from tracebacks. (Without the “unfiltered stack trace” that previously appeared.) This should produce much friendlier-looking tracebacks. See here for an example. This behavior can be changed by setting JAX_TRACEBACK_FILTERING=remove_frames (for two separate unfiltered/filtered tracebacks, which was the old behavior) or JAX_TRACEBACK_FILTERING=off (for one unfiltered traceback).

    • jax2tf default serialization version is now 7, which introduces new shape safety assertions.

    • Devices passed to jax.sharding.Mesh should be hashable. This specifically applies to mock devices or user created devices. jax.devices() are already hashable.

  • Breaking changes:

    • jax2tf now uses native serialization by default. See the jax2tf documentation for details and for mechanisms to override the default.

    • The option --jax_coordination_service has been removed. It is now always True.

    • jax.jaxpr_util has been removed from the public JAX namespace.

    • JAX_USE_PJRT_C_API_ON_TPU no longer has an effect (i.e. it always defaults to true).

    • The backwards compatibility flag --jax_host_callback_ad_transforms, introduced in December 2021, has been removed.

  • Deprecations:

    • Several jax.numpy APIs have been deprecated following NumPy NEP-52:

      • jax.numpy.NINF has been deprecated. Use -jax.numpy.inf instead.

      • jax.numpy.PZERO has been deprecated. Use 0.0 instead.

      • jax.numpy.NZERO has been deprecated. Use -0.0 instead.

      • jax.numpy.issubsctype(x, t) has been deprecated. Use jax.numpy.issubdtype(x.dtype, t).

      • jax.numpy.row_stack has been deprecated. Use jax.numpy.vstack instead.

      • jax.numpy.in1d has been deprecated. Use jax.numpy.isin instead.

      • jax.numpy.trapz has been deprecated. Use jax.scipy.integrate.trapezoid instead.

    • jax.scipy.linalg.tril and jax.scipy.linalg.triu have been deprecated, following SciPy. Use jax.numpy.tril and jax.numpy.triu instead.

    • jax.lax.prod has been removed after being deprecated in JAX v0.4.11. Use the built-in math.prod instead.

    • A number of exports from jax.interpreters.xla related to defining HLO lowering rules for custom JAX primitives have been deprecated. Custom primitives should be defined using the StableHLO lowering utilities in jax.interpreters.mlir instead.

    • The following previously-deprecated functions have been removed after a three-month deprecation period:

      • jax.abstract_arrays.ShapedArray: use jax.core.ShapedArray.

      • jax.abstract_arrays.raise_to_shaped: use jax.core.raise_to_shaped.

      • jax.numpy.alltrue: use jax.numpy.all.

      • jax.numpy.sometrue: use jax.numpy.any.

      • jax.numpy.product: use jax.numpy.prod.

      • jax.numpy.cumproduct: use jax.numpy.cumprod.

  • Deprecations/removals:

    • The internal submodule jax.prng is now deprecated. Its contents are available at jax.extend.random.

    • The internal submodule path jax.linear_util has been deprecated. Use jax.extend.linear_util instead (Part of jax.extend: a module for extensions)

    • jax.random.PRNGKeyArray and jax.random.KeyArray are deprecated. Use jax.Array for type annotations, and jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key) for runtime detection of typed prng keys.

    • The method PRNGKeyArray.unsafe_raw_array is deprecated. Use jax.random.key_data() instead.

    • jax.experimental.pjit.with_sharding_constraint is deprecated. Use jax.lax.with_sharding_constraint instead.

    • The internal utilities jax.core.is_opaque_dtype and jax.core.has_opaque_dtype have been removed. Opaque dtypes have been renamed to Extended dtypes; use jnp.issubdtype(dtype, jax.dtypes.extended) instead (available since jax v0.4.14).

    • The utility jax.interpreters.xla.register_collective_primitive has been removed. This utility did nothing useful in recent JAX releases and calls to it can be safely removed.

    • The internal submodule path jax.linear_util has been deprecated. Use jax.extend.linear_util instead (Part of jax.extend: a module for extensions)

jaxlib 0.4.16 (Sept 18, 2023)#

  • Changes:

    • Sparse CSR matrix multiplications via the experimental jax sparse APIs no longer use a deterministic algorithm on NVIDIA GPUs. This change was made to improve compatibility with CUDA 12.2.1.

  • Bug fixes:

    • Fixed a crash on Windows due to a fatal LLVM error related to out-of-order sections and IMAGE_REL_AMD64_ADDR32NB relocations (https://github.com/openxla/xla/commit/cb732a921f0c4184995cbed82394931011d12bd4).

jax 0.4.14 (July 27, 2023)#

  • Changes

    • jax.jit takes donate_argnames as an argument. Its semantics are similar to static_argnames. If neither donate_argnums nor donate_argnames is provided, no arguments are donated. If donate_argnums is not provided but donate_argnames is, or vice versa, JAX uses inspect.signature(fun) to find any positional arguments that correspond to donate_argnames (or vice versa). If both donate_argnums and donate_argnames are provided, inspect.signature is not used, and only actual parameters listed in either donate_argnums or donate_argnames will be donated.

    • jax.random.gamma() has been re-factored to a more efficient algorithm with more robust endpoint behavior (#16779). This means that the sequence of values returned for a given key will change between JAX v0.4.13 and v0.4.14 for gamma and related samplers (including jax.random.ball(), jax.random.beta(), jax.random.chisquare(), jax.random.dirichlet(), jax.random.generalized_normal(), jax.random.loggamma(), jax.random.t()).

  • Deletions

    • in_axis_resources and out_axis_resources have been deleted from pjit since it has been more than 3 months since their deprecation. Please use in_shardings and out_shardings as the replacement. This is a safe and trivial name replacement. It does not change any of the current pjit semantics and doesn’t break any code. You can still pass in PartitionSpecs to in_shardings and out_shardings.

  • Deprecations

    • Python 3.8 support has been dropped as per https://jax.readthedocs.io/en/latest/deprecation.html

    • JAX now requires NumPy 1.22 or newer as per https://jax.readthedocs.io/en/latest/deprecation.html

    • Passing optional arguments to jax.numpy.ndarray.at() by position is no longer supported, after being deprecated in JAX version 0.4.7. For example, instead of x.at[i].get(True), use x.at[i].get(indices_are_sorted=True)

    • The following jax.Array methods have been removed, after being deprecated in JAX v0.4.5:

    • The following APIs have been removed after previous deprecation:

      • jax.ad: use jax.interpreters.ad.

      • jax.curry: use curry = lambda f: partial(partial, f).

      • jax.partial_eval: use jax.interpreters.partial_eval.

      • jax.pxla: use jax.interpreters.pxla.

      • jax.xla: use jax.interpreters.xla.

      • jax.ShapedArray: use jax.core.ShapedArray.

      • jax.interpreters.pxla.device_put: use jax.device_put().

      • jax.interpreters.pxla.make_sharded_device_array: use jax.make_array_from_single_device_arrays().

      • jax.interpreters.pxla.ShardedDeviceArray: use jax.Array.

      • jax.numpy.DeviceArray: use jax.Array.

      • jax.stages.Compiled.compiler_ir: use jax.stages.Compiled.as_text().

  • Breaking changes

    • JAX now requires ml_dtypes version 0.2.0 or newer.

    • To fix a corner case, calls to jax.lax.cond() with five arguments will always resolve to the “common operands” cond behavior (as documented) if the second and third arguments are callable, even if other operands are callable as well. See #16413.

    • The deprecated config options jax_array and jax_jit_pjit_api_merge, which did nothing, have been removed. These options have been true by default for many releases.

  • New features

    • JAX now supports a configuration flag --jax_serialization_version and a JAX_SERIALIZATION_VERSION environment variable to control the serialization version (#16746).

    • jax2tf in presence of shape polymorphism now generates code that checks certain shape constraints, if the serialization version is at least 7. See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism.

jaxlib 0.4.14 (July 27, 2023)#

  • Deprecations

    • Python 3.8 support has been dropped as per https://jax.readthedocs.io/en/latest/deprecation.html

jax 0.4.13 (June 22, 2023)#

  • Changes

    • jax.jit now allows None to be passed to in_shardings and out_shardings. The semantics are as follows:

      • For in_shardings, JAX will mark it as replicated but this behavior can change in the future.

      • For out_shardings, we will rely on the XLA GSPMD partitioner to determine the output shardings.

    • jax.experimental.pjit.pjit also allows None to be passed to in_shardings and out_shardings. The semantics are as follows:

      • If the mesh context manager is not provided, JAX has the freedom to choose whatever sharding it wants.

        • For in_shardings, JAX will mark it as replicated but this behavior can change in the future.

        • For out_shardings, we will rely on the XLA GSPMD partitioner to determine the output shardings.

      • If the mesh context manager is provided, None will imply that the value will be replicated on all devices of the mesh.

    • Executable.cost_analysis() works on Cloud TPU

    • Added a warning if a non-allowlisted jaxlib plugin is in use.

    • Added jax.tree_util.tree_leaves_with_path.

    • None is not a valid input to jax.experimental.multihost_utils.host_local_array_to_global_array or jax.experimental.multihost_utils.global_array_to_host_local_array. Please use jax.sharding.PartitionSpec() if you want to replicate your input.

  • Bug fixes

    • Fixed incorrect wheel name in CUDA 12 releases (#16362); the correct wheel is named cudnn89 instead of cudnn88.

  • Deprecations

    • The native_serialization_strict_checks parameter to jax.experimental.jax2tf.convert() is deprecated in favor of the new native_serialization_disabled_checks (#16347).

jaxlib 0.4.13 (June 22, 2023)#

  • Changes

    • Added Windows CPU-only wheels to the jaxlib Pypi release.

  • Bug fixes

    • __cuda_array_interface__ was broken in previous jaxlib versions and is now fixed (#16440).

    • Concurrent CUDA kernel tracing is now enabled by default on NVIDIA GPUs.

jax 0.4.12 (June 8, 2023)#

  • Changes

  • Deprecations

    • jax.abstract_arrays and its contents are now deprecated. See related functionality in jax.core.

    • jax.numpy.alltrue: use jax.numpy.all. This follows the deprecation of numpy.alltrue in NumPy version 1.25.0.

    • jax.numpy.sometrue: use jax.numpy.any. This follows the deprecation of numpy.sometrue in NumPy version 1.25.0.

    • jax.numpy.product: use jax.numpy.prod. This follows the deprecation of numpy.product in NumPy version 1.25.0.

    • jax.numpy.cumproduct: use jax.numpy.cumprod. This follows the deprecation of numpy.cumproduct in NumPy version 1.25.0.

    • jax.sharding.OpShardingSharding has been removed since it has been 3 months since it was deprecated.

jaxlib 0.4.12 (June 8, 2023)#

  • Changes

    • Includes PTX/SASS for Hopper (SM version 9.0+) GPUs. Previous versions of jaxlib should work on Hopper but would have a long JIT-compilation delay the first time a JAX operation was executed.

  • Bug fixes

    • Fixes incorrect source line information in JAX-generated Python tracebacks under Python 3.11.

    • Fixes crash when printing local variables of frames in JAX-generated Python tracebacks (#16027).

jax 0.4.11 (May 31, 2023)#

  • Deprecations

    • The following APIs have been removed after a 3 month deprecation period, in accordance with the API compatibility policy:

      • jax.experimental.PartitionSpec: use jax.sharding.PartitionSpec.

      • jax.experimental.maps.Mesh: use jax.sharding.Mesh

      • jax.experimental.pjit.NamedSharding: use jax.sharding.NamedSharding.

      • jax.experimental.pjit.PartitionSpec: use jax.sharding.PartitionSpec.

      • jax.experimental.pjit.FROM_GDA. Instead pass sharded jax.Array objects as input and remove the optional in_shardings argument to pjit.

      • jax.interpreters.pxla.PartitionSpec: use jax.sharding.PartitionSpec.

      • jax.interpreters.pxla.Mesh: use jax.sharding.Mesh

      • jax.interpreters.xla.Buffer: use jax.Array.

      • jax.interpreters.xla.Device: use jax.Device.

      • jax.interpreters.xla.DeviceArray: use jax.Array.

      • jax.interpreters.xla.device_put: use jax.device_put.

      • jax.interpreters.xla.xla_call_p: use jax.experimental.pjit.pjit_p.

      • axis_resources argument of with_sharding_constraint is removed. Please use shardings instead.

jaxlib 0.4.11 (May 31, 2023)#

  • Changes

    • Added memory_stats() method to Devices. If supported, this returns a dict of string stat names with int values, e.g. "bytes_in_use", or None if the platform doesn’t support memory statistics. The exact stats returned may vary across platforms. Currently only implemented on Cloud TPU. A usage sketch follows this list.

    • Readded support for the Python buffer protocol (memoryview) on CPU devices.
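
      A possible usage sketch for memory_stats(); the stat names and availability depend on the platform, and the method may return None where unsupported (at this release, only Cloud TPU implements it):

        import jax

        for device in jax.devices():
            stats = device.memory_stats()   # dict of int-valued stats, or None if unsupported
            if stats is not None:
                print(device, stats.get("bytes_in_use"))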

jax 0.4.10 (May 11, 2023)#

jaxlib 0.4.10 (May 11, 2023)#

  • Changes

    • Fixed the 'apple-m1 is not a recognized processor for this target (ignoring processor)' error that prevented the previous release from running on Mac M1.

jax 0.4.9 (May 9, 2023)#

  • Changes

    • The flags experimental_cpp_jit, experimental_cpp_pjit and experimental_cpp_pmap have been removed. They are now always on.

    • Accuracy of singular value decomposition (SVD) on TPU has been improved (requires jaxlib 0.4.9).

  • Deprecations

    • jax.experimental.gda_serialization is deprecated and has been renamed to jax.experimental.array_serialization. Please change your imports to use jax.experimental.array_serialization.

    • The in_axis_resources and out_axis_resources arguments of pjit have been deprecated. Please use in_shardings and out_shardings respectively.

    • The function jax.numpy.msort has been removed. It has been deprecated since JAX v0.4.1. Use jnp.sort(a, axis=0) instead.

    • in_parts and out_parts arguments have been removed from jax.xla_computation since they were only used with sharded_jit and sharded_jit is long gone.

    • instantiate_const_outputs argument has been removed from jax.xla_computation since it has been unused for a very long time.

jaxlib 0.4.9 (May 9, 2023)#

jax 0.4.8 (March 29, 2023)#

  • Breaking changes

    • A major component of the Cloud TPU runtime has been upgraded. This enables new features on Cloud TPU, but note the following caveats:

      jax.experimental.host_callback() is no longer supported on Cloud TPU with the new runtime component. Please file an issue on the JAX issue tracker if the new jax.debug APIs are insufficient for your use case.

      The old runtime component will be available for at least the next three months by setting the environment variable JAX_USE_PJRT_C_API_ON_TPU=false. If you find you need to disable the new runtime for any reason, please let us know on the JAX issue tracker.

  • Changes

    • The minimum jaxlib version has been bumped from 0.4.6 to 0.4.7.

  • Deprecations

    • CUDA 11.4 support has been dropped. JAX GPU wheels only support CUDA 11.8 and CUDA 12. Older CUDA versions may work if jaxlib is built from source.

    • global_arg_shapes argument of pmap only worked with sharded_jit and has been removed from pmap. Please migrate to pjit and remove global_arg_shapes from pmap.

jax 0.4.7 (March 27, 2023)#

  • Changes

    • As per https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration jax.config.jax_array cannot be disabled anymore.

    • jax.config.jax_jit_pjit_api_merge cannot be disabled anymore.

    • jax.experimental.jax2tf.convert() now supports the native_serialization parameter to use JAX’s native lowering to StableHLO to obtain a StableHLO module for the entire JAX function instead of lowering each JAX primitive to a TensorFlow op. This simplifies the internals and increases the confidence that what you serialize matches the JAX native semantics. See documentation. As part of this change the config flag --jax2tf_default_experimental_native_lowering has been renamed to --jax2tf_native_serialization.

    • JAX now depends on ml_dtypes, which contains definitions of NumPy types like bfloat16. These definitions were previously internal to JAX, but have been split into a separate package to facilitate sharing them with other projects.

    • JAX now requires NumPy 1.21 or newer and SciPy 1.7 or newer.

  • Deprecations

    • The type jax.numpy.DeviceArray is deprecated. Use jax.Array instead, for which it is an alias.

    • The type jax.interpreters.pxla.ShardedDeviceArray is deprecated. Use jax.Array instead.

    • Passing additional arguments to jax.numpy.ndarray.at() by position is deprecated. For example, instead of x.at[i].get(True), use x.at[i].get(indices_are_sorted=True)

    • jax.interpreters.xla.device_put is deprecated. Please use jax.device_put.

    • jax.interpreters.pxla.device_put is deprecated. Please use jax.device_put.

    • jax.experimental.pjit.FROM_GDA is deprecated. Please pass in sharded jax.Arrays as input and remove the in_shardings argument to pjit since it is optional.

jaxlib 0.4.7 (March 27, 2023)#

Changes:

  • jaxlib now depends on ml_dtypes, which contains definitions of NumPy types like bfloat16. These definitions were previously internal to JAX, but have been split into a separate package to facilitate sharing them with other projects.

jax 0.4.6 (Mar 9, 2023)#

  • Changes

    • jax.tree_util now contains a set of APIs that allow users to define keys for their custom pytree nodes (see the sketch at the end of this list). This includes:

      • tree_flatten_with_path, which flattens a tree and returns not only each leaf but also its key path.

      • tree_map_with_path, which maps a function that takes the key path as an additional argument.

      • register_pytree_with_keys, to register how the key paths and leaves should look for a custom pytree node.

      • keystr, which pretty-prints a key path.

    • jax2tf.call_tf() has a new parameter output_shape_dtype (default None) that can be used to declare the output shape and type of the result. This enables jax2tf.call_tf() to work in the presence of shape polymorphism. (#14734).
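
      A minimal sketch of the new key-path helpers mentioned above; the params pytree is just an illustrative example:

        import jax

        params = {"weights": [1.0, 2.0], "bias": 3.0}

        # tree_flatten_with_path returns (key_path, leaf) pairs plus the treedef.
        leaves_with_paths, treedef = jax.tree_util.tree_flatten_with_path(params)
        for path, leaf in leaves_with_paths:
            print(jax.tree_util.keystr(path), "->", leaf)   # keystr pretty-prints the path

        # tree_map_with_path passes the key path as the first argument.
        labeled = jax.tree_util.tree_map_with_path(
            lambda path, x: (jax.tree_util.keystr(path), x), params)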

  • Deprecations

    • The old key-path APIs in jax.tree_util are deprecated and will be removed 3 months from Mar 10 2023:

jaxlib 0.4.6 (Mar 9, 2023)#

jax 0.4.5 (Mar 2, 2023)#

  • Deprecations

    • jax.sharding.OpShardingSharding has been renamed to jax.sharding.GSPMDSharding. jax.sharding.OpShardingSharding will be removed in 3 months from Feb 17, 2023.

    • The following jax.Array methods are deprecated and will be removed 3 months from Feb 23 2023:

jax 0.4.4 (Feb 16, 2023)#

  • Changes

    • The implementation of jit and pjit has been merged. Merging jit and pjit changes the internals of JAX without affecting the public API of JAX. Before, jit was a final style primitive. Final style means that the creation of jaxpr was delayed as much as possible and transformations were stacked on top of each other. With the jit-pjit implementation merge, jit becomes an initial style primitive which means that we trace to jaxpr as early as possible. For more information see this section in autodidax. Moving to initial style should simplify JAX’s internals and make development of features like dynamic shapes, etc easier. You can disable it only via the environment variable i.e. os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'. The merge must be disabled via an environment variable since it affects JAX at import time so it needs to be disabled before jax is imported.

    • axis_resources argument of with_sharding_constraint is deprecated. Please use shardings instead. There is no change needed if you were using axis_resources as an arg. If you were using it as a kwarg, then please use shardings instead. axis_resources will be removed after 3 months from Feb 13, 2023.

    • added the jax.typing module, with tools for type annotations of JAX functions.

    • The following names have been deprecated:

      • jax.xla.Device and jax.interpreters.xla.Device: use jax.Device.

      • jax.experimental.maps.Mesh. Use jax.sharding.Mesh instead.

      • jax.experimental.pjit.NamedSharding: use jax.sharding.NamedSharding.

      • jax.experimental.pjit.PartitionSpec: use jax.sharding.PartitionSpec.

      • jax.interpreters.pxla.Mesh: use jax.sharding.Mesh.

      • jax.interpreters.pxla.PartitionSpec: use jax.sharding.PartitionSpec.

  • Breaking Changes

    • The initial argument to reduction functions like jax.numpy.sum() is now required to be a scalar, consistent with the corresponding NumPy API. The previous behavior of broadcasting the output against non-scalar initial values was an unintentional implementation detail (#14446).
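
      For illustration (values chosen arbitrarily), a scalar initial still works, while a non-scalar initial now raises an error:

        import jax.numpy as jnp

        x = jnp.arange(6.0).reshape(2, 3)
        print(jnp.sum(x, axis=0, initial=10.0))    # scalar initial: fine
        # jnp.sum(x, axis=0, initial=jnp.ones(3))  # non-scalar initial now errors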

jaxlib 0.4.4 (Feb 16, 2023)#

  • Breaking changes

    • Support for NVIDIA Kepler series GPUs has been removed from the default jaxlib builds. If Kepler support is needed, it is still possible to build jaxlib from source with Kepler support (via the --cuda_compute_capabilities=sm_35 option to build.py), however note that CUDA 12 has completely dropped support for Kepler GPUs.

jax 0.4.3 (Feb 8, 2023)#

jaxlib 0.4.3 (Feb 8, 2023)#

  • jax.Array now has the non-blocking is_ready() method, which returns True if the array is ready (see also jax.block_until_ready()).
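
    A small sketch of the non-blocking check; because JAX dispatches computations asynchronously, is_ready() may return False until the result has actually been computed:

      import jax.numpy as jnp

      y = jnp.arange(1_000_000.0) * 2.0   # dispatched asynchronously
      if not y.is_ready():                # non-blocking readiness check
          y.block_until_ready()           # blocks until the computation finishes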

jax 0.4.2 (Jan 24, 2023)#

  • Breaking changes

    • Deleted jax.experimental.callback

    • Operations with dimensions in presence of jax2tf shape polymorphism have been generalized to work in more scenarios, by converting the symbolic dimension to JAX arrays. Operations involving symbolic dimensions and np.ndarray now can raise errors when the result is used as a shape value (#14106).

    • jaxpr objects now raise an error on attribute setting in order to avoid problematic mutations (#14102)

  • Changes

    • jax2tf.call_tf() has a new parameter has_side_effects (default True) that can be used to declare whether an instance can be removed or replicated by JAX optimizations such as dead-code elimination (#13980).

    • Added more support for floordiv and mod for jax2tf shape polymorphism. Previously, certain division operations resulted in errors in presence of symbolic dimensions (#14108).

jaxlib 0.4.2 (Jan 24, 2023)#

  • Changes

    • Set JAX_USE_PJRT_C_API_ON_TPU=1 to enable new Cloud TPU runtime, featuring automatic device memory defragmentation.

jax 0.4.1 (Dec 13, 2022)#

  • Changes

    • Support for Python 3.7 has been dropped, in accordance with JAX’s Python and NumPy version support policy.

    • We introduce jax.Array which is a unified array type that subsumes DeviceArray, ShardedDeviceArray, and GlobalDeviceArray types in JAX. The jax.Array type helps make parallelism a core feature of JAX, simplifies and unifies JAX internals, and allows us to unify jit and pjit. jax.Array has been enabled by default in JAX 0.4 and makes some breaking change to the pjit API. The jax.Array migration guide can help you migrate your codebase to jax.Array. You can also look at the Distributed arrays and automatic parallelization tutorial to understand the new concepts.

    • PartitionSpec and Mesh are now out of experimental. The new API endpoints are jax.sharding.PartitionSpec and jax.sharding.Mesh. jax.experimental.maps.Mesh and jax.experimental.PartitionSpec are deprecated and will be removed in 3 months.

    • with_sharding_constraint's new public endpoint is jax.lax.with_sharding_constraint.

    • If using ABSL flags together with jax.config, the ABSL flag values are no longer read or written after the JAX configuration options are initially populated from the ABSL flags. This change improves performance of reading jax.config options, which are used pervasively in JAX.

    • For TF lowering, the jax2tf.call_tf function now uses the first TF device of the same platform as the embedding JAX computation. Before, it used the 0th device of the JAX-default backend.

    • A number of jax.numpy functions now have their arguments marked as positional-only, matching NumPy.

    • jnp.msort is now deprecated, following the deprecation of np.msort in numpy 1.24. It will be removed in a future release, in accordance with the API compatibility policy. It can be replaced with jnp.sort(a, axis=0).

jaxlib 0.4.1 (Dec 13, 2022)#

  • Changes

    • Support for Python 3.7 has been dropped, in accordance with JAX’s Python and NumPy version support policy.

    • The behavior of XLA_PYTHON_CLIENT_MEM_FRACTION=.XX has been changed to allocate XX% of the total GPU memory instead of the previous behavior of using currently available GPU memory to calculate preallocation. Please refer to GPU memory allocation for more details.

    • The deprecated method .block_host_until_ready() has been removed. Use .block_until_ready() instead.

jax 0.4.0 (Dec 12, 2022)#

  • The release was yanked.

jaxlib 0.4.0 (Dec 12, 2022)#

  • The release was yanked.

jax 0.3.25 (Nov 15, 2022)#

  • Changes

  • Breaking Changes

    • Deleted the jax_experimental_name_stack config option.

    • A string axis_names argument to the jax.experimental.maps.Mesh constructor is now converted into a singleton tuple, instead of being unpacked into a sequence of single-character axis names.

jaxlib 0.3.25 (Nov 15, 2022)#

  • Changes

    • Added support for tridiagonal reductions on CPU and GPU.

    • Added support for upper Hessenberg reductions on CPU.

  • Bugs

    • Fixed a bug that meant that frames in tracebacks captured by JAX were incorrectly mapped to source lines under Python 3.10+

jax 0.3.24 (Nov 4, 2022)#

  • Changes

    • JAX should be faster to import. We now import scipy lazily, which accounted for a significant fraction of JAX’s import time.

    • Setting the env var JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=$N can be used to limit the number of cache entries written to the persistent cache. By default, computations that take 1 second or more to compile will be cached.

    • The default device order used by pmap on TPU if no order is specified now matches jax.devices() for single-process jobs. Previously the two orderings differed, which could lead to unnecessary copies or out-of-memory errors. Requiring the orderings to agree simplifies matters.

  • Breaking Changes

  • Deprecations

    • jax.sharding.MeshPspecSharding has been renamed to jax.sharding.NamedSharding. jax.sharding.MeshPspecSharding name will be removed in 3 months.

jaxlib 0.3.24 (Nov 4, 2022)#

  • Changes

    • Buffer donation now works on CPU. This may break code that marked buffers for donation on CPU but relied on donation not being implemented.

jax 0.3.23 (Oct 12, 2022)#

  • Changes

    • Update Colab TPU driver version for new jaxlib release.

jax 0.3.22 (Oct 11, 2022)#

  • Changes

    • Add JAX_PLATFORMS=tpu,cpu as default setting in TPU initialization, so JAX will raise an error if TPU cannot be initialized instead of falling back to CPU. Set JAX_PLATFORMS='' to override this behavior and automatically choose an available backend (the original default), or set JAX_PLATFORMS=cpu to always use CPU regardless of if the TPU is available.
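
      A sketch of overriding the platform selection from Python; the variable must be set before jax is imported (setting it in the shell works equally well):

        import os

        os.environ["JAX_PLATFORMS"] = "cpu"   # force CPU; "" restores automatic selection

        import jax
        print(jax.devices())                  # CPU devices only, even if a TPU is attached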

  • Deprecations

    • Several test utilities deprecated in JAX v0.3.8 are now removed from jax.test_util.

jaxlib 0.3.22 (Oct 11, 2022)#

jax 0.3.21 (Sep 30, 2022)#

  • GitHub commits.

  • Changes

    • The persistent compilation cache will now warn instead of raising an exception on error (#12582), so program execution can continue if something goes wrong with the cache. Set JAX_RAISE_PERSISTENT_CACHE_ERRORS=true to revert this behavior.

jax 0.3.20 (Sep 28, 2022)#

  • Bug fixes:

    • Adds .pyi files that were missing from the previous release (#12536).

    • Fixes an incompatibility between jax 0.3.19 and the libtpu version it pinned (#12550). Requires jaxlib 0.3.20.

    • Fix incorrect pip url in setup.py comment (#12528).

jaxlib 0.3.20 (Sep 28, 2022)#

  • GitHub commits.

  • Bug fixes

    • Fixes support for limiting the visible CUDA devices via jax_cuda_visible_devices in distributed jobs. This functionality is needed for the JAX/SLURM integration on GPU (#12533).

jax 0.3.19 (Sep 27, 2022)#

jax 0.3.18 (Sep 26, 2022)#

  • GitHub commits.

  • Changes

    • Ahead-of-time lowering and compilation functionality (tracked in #7733) is stable and public. See the overview and the API docs for jax.stages.

    • Introduced jax.Array, intended to be used for both isinstance checks and type annotations for array types in JAX. Notice that this included some subtle changes to how isinstance works for jax.numpy.ndarray for jax-internal objects, as jax.numpy.ndarray is now a simple alias of jax.Array.
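
      For illustration, jax.Array can now be used both as an isinstance target and as a type annotation (the function below is hypothetical):

        import jax
        import jax.numpy as jnp

        def double(x: jax.Array) -> jax.Array:   # type annotation
            return 2 * x

        x = jnp.ones(3)
        assert isinstance(x, jax.Array)          # isinstance check
        assert isinstance(x, jnp.ndarray)        # jnp.ndarray is an alias of jax.Array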

  • Breaking changes

    • jax._src is no longer imported into the public jax namespace. This may break users that were using JAX internals.

    • jax.soft_pmap has been deleted. Please use pjit or xmap instead. jax.soft_pmap is undocumented. If it were documented, a deprecation period would have been provided.

jax 0.3.17 (Aug 31, 2022)#

  • GitHub commits.

  • Bugs

    • Fix corner case issue in gradient of lax.pow with an exponent of zero (#12041)

  • Breaking changes

    • jax.checkpoint(), also known as jax.remat(), no longer supports the concrete option, following the previous version’s deprecation; see JEP 11830.

  • Changes

    • Added jax.pure_callback() that enables calling back to pure Python functions from compiled functions (e.g. functions decorated with jax.jit or jax.pmap).
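
      A minimal sketch of calling back into host-side NumPy from a jitted function; host_log is an illustrative name, and jax.ShapeDtypeStruct describes the expected result shape and dtype:

        import jax
        import jax.numpy as jnp
        import numpy as np

        def host_log(x):
            return np.log(x)   # ordinary NumPy, executed in Python on the host

        @jax.jit
        def f(x):
            result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
            return jax.pure_callback(host_log, result_shape, x)

        print(f(jnp.arange(1.0, 4.0)))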

  • Deprecations:

    • The deprecated DeviceArray.tile() method has been removed. Use jax.numpy.tile() (#11944).

    • DeviceArray.to_py() has been deprecated. Use np.asarray(x) instead.

jax 0.3.16#

jax 0.3.15 (July 22, 2022)#

jaxlib 0.3.15 (July 22, 2022)#

jax 0.3.14 (June 27, 2022)#

  • GitHub commits.

  • Breaking changes

    • jax.experimental.compilation_cache.initialize_cache() no longer supports the max_cache_size_bytes argument and will not accept it as an input.

    • JAX_PLATFORMS now raises an exception when platform initialization fails.

  • Changes

    • Fixed compatibility problems with NumPy 1.23.

    • jax.numpy.linalg.slogdet() now accepts an optional method argument that allows selection between an LU-decomposition based implementation and an implementation based on QR decomposition.

    • jax.numpy.linalg.qr() now supports mode="raw".

    • pickle, copy.copy, and copy.deepcopy now have more complete support when used on jax arrays (#10659). In particular:

      • pickle and deepcopy previously returned np.ndarray objects when used on a DeviceArray; now DeviceArray objects are returned. For deepcopy, the copied array is on the same device as the original. For pickle the deserialized array will be on the default device.

      • Within function transformations (i.e. traced code), deepcopy and copy previously were no-ops. Now they use the same mechanism as DeviceArray.copy().

      • Calling pickle on a traced array now results in an explicit ConcretizationTypeError.

    • The implementation of singular value decomposition (SVD) and symmetric/Hermitian eigendecomposition should be significantly faster on TPU, especially for matrices above 1000x1000 or so. Both now use a spectral divide-and-conquer algorithm for eigendecomposition (QDWH-eig).

    • jax.numpy.ldexp() no longer silently promotes all inputs to float64, instead it promotes to float32 for integer inputs of size int32 or smaller (#10921).

    • Add a create_perfetto_link option to jax.profiler.start_trace() and jax.profiler.trace(). When used, the profiler will generate a link to the Perfetto UI to view the trace.

    • Changed the semantics of jax.profiler.start_server(...) to store the keepalive globally, rather than requiring the user to keep a reference to it.

    • Added jax.random.generalized_normal().

    • Added jax.random.ball().

    • Added jax.default_device().

    • Added a python -m jax.collect_profile script to manually capture program traces as an alternative to the Tensorboard UI.

    • Added a jax.named_scope context manager that adds profiler metadata to Python programs (similar to jax.named_call).

    • In scatter-update operations (i.e. jax.numpy.ndarray.at), unsafe implicit dtype casts are deprecated, and now result in a FutureWarning. In a future release, this will become an error. An example of an unsafe implicit cast is jnp.zeros(4, dtype=int).at[0].set(1.5), in which 1.5 previously was silently truncated to 1.

    • jax.experimental.compilation_cache.initialize_cache() now supports a GCS bucket path as input.

    • Added jax.scipy.stats.gennorm().

    • jax.numpy.roots() is now better behaved when strip_zeros=False when coefficients have leading zeros (#11215).

jaxlib 0.3.14 (June 27, 2022)#

  • GitHub commits.

    • x86-64 Mac wheels now require Mac OS 10.14 (Mojave) or newer. Mac OS 10.14 was released in 2018, so this should not be a very onerous requirement.

    • The bundled version of NCCL was updated to 2.12.12, fixing some deadlocks.

    • The Python flatbuffers package is no longer a dependency of jaxlib.

jax 0.3.13 (May 16, 2022)#

jax 0.3.12 (May 15, 2022)#

jax 0.3.11 (May 15, 2022)#

  • GitHub commits.

  • Changes

    • jax.lax.eigh() now accepts an optional sort_eigenvalues argument that allows users to opt out of eigenvalue sorting on TPU.

  • Deprecations

    • Non-array arguments to functions in jax.lax.linalg are now marked keyword-only. As a backward-compatibility step passing keyword-only arguments positionally yields a warning, but in a future JAX release passing keyword-only arguments positionally will fail. However, most users should prefer to use jax.numpy.linalg instead.

    • jax.scipy.linalg.polar_unitary(), which was a JAX extension to the scipy API, is deprecated. Use jax.scipy.linalg.polar() instead.

jax 0.3.10 (May 3, 2022)#

jaxlib 0.3.10 (May 3, 2022)#

  • GitHub commits.

  • Changes

    • An upstream TensorFlow commit fixes an issue in the MHLO canonicalizer that caused constant folding to take a long time or crash for certain programs.

jax 0.3.9 (May 2, 2022)#

  • GitHub commits.

  • Changes

    • Added support for fully asynchronous checkpointing for GlobalDeviceArray.

jax 0.3.8 (April 29 2022)#

  • GitHub commits.

  • Changes

    • jax.numpy.linalg.svd() on TPUs uses a qdwh-svd solver.

    • jax.numpy.linalg.cond() on TPUs now accepts complex input.

    • jax.numpy.linalg.pinv() on TPUs now accepts complex input.

    • jax.numpy.linalg.matrix_rank() on TPUs now accepts complex input.

    • jax.scipy.cluster.vq.vq() has been added.

    • jax.experimental.maps.mesh has been deleted. Please use jax.experimental.maps.Mesh. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh for more information.

    • jax.scipy.linalg.qr() now returns a length-1 tuple rather than the raw array when mode='r', in order to match the behavior of scipy.linalg.qr (#10452)

    • jax.numpy.take_along_axis() now takes an optional mode parameter that specifies the behavior of out-of-bounds indexing. By default, invalid values (e.g., NaN) will be returned for out-of-bounds indices. In previous versions of JAX, invalid indices were clamped into range. The previous behavior can be restored by passing mode="clip".

    • jax.numpy.take() now defaults to mode="fill", which returns invalid values (e.g., NaN) for out-of-bounds indices (see the sketch at the end of this list).

    • Scatter operations, such as x.at[...].set(...), now have "drop" semantics. This has no effect on the scatter operation itself, but it means that when differentiated the gradient of a scatter will yield zero cotangents for out-of-bounds indices. Previously out-of-bounds indices were clamped into range for the gradient, which was not mathematically correct.

    • jax.numpy.take_along_axis() now raises a TypeError if its indices are not of an integer type, matching the behavior of numpy.take_along_axis(). Previously non-integer indices were silently cast to integers.

    • jax.numpy.ravel_multi_index() now raises a TypeError if its dims argument is not of an integer type, matching the behavior of numpy.ravel_multi_index(). Previously non-integer dims was silently cast to integers.

    • jax.numpy.split() now raises a TypeError if its axis argument is not of an integer type, matching the behavior of numpy.split(). Previously non-integer axis was silently cast to integers.

    • jax.numpy.indices() now raises a TypeError if its dimensions are not of an integer type, matching the behavior of numpy.indices(). Previously non-integer dimensions were silently cast to integers.

    • jax.numpy.diag() now raises a TypeError if its k argument is not of an integer type, matching the behavior of numpy.diag(). Previously non-integer k was silently cast to integers.

    • Added jax.random.orthogonal().
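
      To illustrate the out-of-bounds indexing changes above (array values are arbitrary):

        import jax.numpy as jnp

        x = jnp.array([1.0, 2.0, 3.0])
        idx = jnp.array([0, 5])                # index 5 is out of bounds

        print(jnp.take(x, idx))                # default mode="fill": out of bounds -> NaN
        print(jnp.take(x, idx, mode="clip"))   # restores the old clamping behavior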

  • Deprecations

    • Many functions and objects available in jax.test_util are now deprecated and will raise a warning on import. This includes cases_from_list, check_close, check_eq, device_under_test, format_shape_dtype_string, rand_uniform, skip_on_devices, with_config, xla_bridge, and _default_tolerance (#10389). These, along with previously-deprecated JaxTestCase, JaxTestLoader, and BufferDonationTestCase, will be removed in a future JAX release. Most of these utilities can be replaced by calls to standard python & numpy testing utilities found in e.g. unittest, absl.testing, numpy.testing, etc. JAX-specific functionality such as device checking can be replaced through the use of public APIs such as jax.devices(). Many of the deprecated utilities will still exist in jax._src.test_util, but these are not public APIs and as such may be changed or removed without notice in future releases.

jax 0.3.7 (April 15, 2022)#

jaxlib 0.3.7 (April 15, 2022)#

  • Changes:

    • Linux wheels are now built conforming to the manylinux2014 standard, instead of manylinux2010.

jax 0.3.6 (April 12, 2022)#

  • GitHub commits.

  • Changes:

    • Upgraded libtpu wheel to a version that fixes a hang when initializing a TPU pod. Fixes #10218.

  • Deprecations:

    • jax.experimental.loops is being deprecated. See #10278 for an alternative API.

jax 0.3.5 (April 7, 2022)#

jaxlib 0.3.5 (April 7, 2022)#

  • Bug fixes

    • Fixed a bug where double-precision complex-to-real IRFFTs would mutate their input buffers on GPU (#9946).

    • Fixed incorrect constant-folding of complex scatters (#10159)

jax 0.3.4 (March 18, 2022)#

jax 0.3.3 (March 17, 2022)#

jax 0.3.2 (March 16, 2022)#

  • GitHub commits.

  • Changes:

    • The functions jax.ops.index_update, jax.ops.index_add, which were deprecated in 0.2.22, have been removed. Please use the .at property on JAX arrays instead, e.g., x.at[idx].set(y).

    • Moved jax.experimental.ann.approx_*_k into jax.lax. These functions are optimized alternatives to jax.lax.top_k.

    • jax.numpy.broadcast_arrays() and jax.numpy.broadcast_to() now require scalar or array-like inputs, and will fail if they are passed lists (part of #7737).

    • The standard jax[tpu] install can now be used with Cloud TPU v4 VMs.

    • pjit now works on CPU (in addition to previous TPU and GPU support).

jaxlib 0.3.2 (March 16, 2022)#

  • Changes

    • XlaComputation.as_hlo_text() now supports printing large constants by passing boolean flag print_large_constants=True.

  • Deprecations:

    • The .block_host_until_ready() method on JAX arrays has been deprecated. Use .block_until_ready() instead.

jax 0.3.1 (Feb 18, 2022)#

jax 0.3.0 (Feb 10, 2022)#

jaxlib 0.3.0 (Feb 10, 2022)#

  • Changes

    • Bazel 5.0.0 is now required to build jaxlib.

    • jaxlib version has been bumped to 0.3.0. Please see the design doc for the explanation.

jax 0.2.28 (Feb 1, 2022)#

  • GitHub commits.

    • jax.jit(f).lower(...).compiler_ir() now defaults to the MHLO dialect if no dialect= is passed.

    • The jax.jit(f).lower(...).compiler_ir(dialect='mhlo') now returns an MLIR ir.Module object instead of its string representation.

jaxlib 0.1.76 (Jan 27, 2022)#

  • New features

    • Includes precompiled SASS for NVidia compute capability 8.0 GPUs (e.g. A100). Removes precompiled SASS for compute capability 6.1 so as not to increase the number of compute capabilities: GPUs with compute capability 6.1 can use the 6.0 SASS.

    • With jaxlib 0.1.76, JAX uses the MHLO MLIR dialect as its primary target compiler IR by default.

  • Breaking changes

    • Support for NumPy 1.18 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.

  • Bug fixes

    • Fixed a bug where apparently identical pytreedef objects constructed by different routes do not compare as equal (#9066).

    • The JAX jit cache requires two static arguments to have identical types for a cache hit (#9311).

jax 0.2.27 (Jan 18 2022)#

  • GitHub commits.

  • Breaking changes:

    • Support for NumPy 1.18 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.

    • The host_callback primitives have been simplified to drop the special autodiff handling for hcb.id_tap and id_print. From now on, only the primals are tapped. The old behavior can be obtained (for a limited time) by setting the JAX_HOST_CALLBACK_AD_TRANSFORMS environment variable, or the --jax_host_callback_ad_transforms flag. Additionally, added documentation for how to implement the old behavior using JAX custom AD APIs (#8678).

    • Sorting now matches the behavior of NumPy for 0.0 and NaN regardless of the bit representation. In particular, 0.0 and -0.0 are now treated as equivalent, where previously -0.0 was treated as less than 0.0. Additionally all NaN representations are now treated as equivalent and sorted to the end of the array. Previously negative NaN values were sorted to the front of the array, and NaN values with different internal bit representations were not treated as equivalent, and were sorted according to those bit patterns (#9178).

    • jax.numpy.unique() now treats NaN values in the same way as np.unique in NumPy versions 1.21 and newer: at most one NaN value will appear in the uniquified output (#9184).

  • Bug fixes:

    • host_callback now supports ad_checkpoint.checkpoint (#8907).

  • New features:

    • Added jax.block_until_ready() (#8941).

    • Added a new debugging flag/environment variable JAX_DUMP_IR_TO=/path. If set, JAX dumps the MHLO/HLO IR it generates for each computation to a file under the given path.

    • Added jax.ensure_compile_time_eval to the public api (#7987).

    • jax2tf now supports a flag jax2tf_associative_scan_reductions to change the lowering for associative reductions, e.g., jnp.cumsum, to behave like JAX on CPU and GPU (to use an associative scan). See the jax2tf README for more details (#9189).

jaxlib 0.1.75 (Dec 8, 2021)#

  • New features:

    • Support for python 3.10.

jax 0.2.26 (Dec 8, 2021)#

  • GitHub commits.

  • Bug fixes:

    • Out-of-bounds indices to jax.ops.segment_sum will now be handled with FILL_OR_DROP semantics, as documented. This primarily affects the reverse-mode derivative, where gradients corresponding to out-of-bounds indices will now be returned as 0. (#8634).

    • jax2tf will force the converted code to use XLA for the code fragments under jax.jit, e.g., most jax.numpy functions (#7839).

jaxlib 0.1.74 (Nov 17, 2021)#

  • Enabled peer-to-peer copies between GPUs. Previously, GPU copies were bounced via the host, which is usually slower.

  • Added experimental MLIR Python bindings for use by JAX.

jax 0.2.25 (Nov 10, 2021)#

  • GitHub commits.

  • New features:

    • (Experimental) jax.distributed.initialize exposes multi-host GPU backend.

    • jax.random.permutation supports new independent keyword argument (#8430)

  • Breaking changes

    • Moved jax.experimental.stax to jax.example_libraries.stax

    • Moved jax.experimental.optimizers to jax.example_libraries.optimizers

  • New features:

    • Added jax.lax.linalg.qdwh.

jax 0.2.24 (Oct 19, 2021)#

  • GitHub commits.

  • New features:

    • jax.random.choice and jax.random.permutation now support multidimensional arrays and an optional axis argument (#8158)
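
      A small sketch of the new multidimensional and axis support (the array, shape, and axis values are arbitrary):

        import jax
        import jax.numpy as jnp

        key = jax.random.PRNGKey(0)
        x = jnp.arange(12).reshape(3, 4)

        shuffled_cols = jax.random.permutation(key, x, axis=1)    # permute along axis 1
        two_rows = jax.random.choice(key, x, shape=(2,), axis=0)  # sample 2 rows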

  • Breaking changes:

    • jax.numpy.take and jax.numpy.take_along_axis now require array-like inputs (see #7737)

jaxlib 0.1.73 (Oct 18, 2021)#

  • Multiple cuDNN versions are now supported for jaxlib GPU cuda11 wheels.

    • cuDNN 8.2 or newer. We recommend using the cuDNN 8.2 wheel if your cuDNN installation is new enough, since it supports additional functionality.

    • cuDNN 8.0.5 or newer.

  • Breaking changes:

    • The install commands for GPU jaxlib are as follows:

      pip install --upgrade pip
      
      # Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
      pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
      
      # Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
      pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
      
      # Installs the wheel compatible with CUDA 11 and cuDNN 8.0.5 or newer.
      pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
      

jax 0.2.22 (Oct 12, 2021)#

  • GitHub commits.

  • Breaking Changes

    • Static arguments to jax.pmap must now be hashable.

      Unhashable static arguments have long been disallowed on jax.jit, but they were still permitted on jax.pmap; jax.pmap compared unhashable static arguments using object identity.

      This behavior is a footgun, since comparing arguments using object identity leads to recompilation each time the object identity changes. Instead, we now ban unhashable arguments: if a user of jax.pmap wants to compare static arguments by object identity, they can define __hash__ and __eq__ methods on their objects that do that, or wrap their objects in an object that has those operations with object identity semantics. Another option is to use functools.partial to encapsulate the unhashable static arguments into the function object.

    • jax.util.partial was an accidental export that has now been removed. Use functools.partial from the Python standard library instead.

  • Deprecations

    • The functions jax.ops.index_update, jax.ops.index_add etc. are deprecated and will be removed in a future JAX release. Please use the .at property on JAX arrays instead, e.g., x.at[idx].set(y). For now, these functions produce a DeprecationWarning.

  • New features:

    • An optimized C++ code-path improving the dispatch time for pmap is now the default when using jaxlib 0.1.72 or newer. The feature can be disabled using the --experimental_cpp_pmap flag (or JAX_CPP_PMAP environment variable).

    • jax.numpy.unique now supports an optional fill_value argument (#8121)

jaxlib 0.1.72 (Oct 12, 2021)#

  • Breaking changes:

    • Support for CUDA 10.2 and CUDA 10.1 has been dropped. Jaxlib now supports CUDA 11.1+.

  • Bug fixes:

    • Fixes https://github.com/google/jax/issues/7461, which caused wrong outputs on all platforms due to incorrect buffer aliasing inside the XLA compiler.

jax 0.2.21 (Sept 23, 2021)#

  • GitHub commits.

  • Breaking Changes

    • jax.api has been removed. Functions that were available as jax.api.* were aliases for functions in jax.*; please use the functions in jax.* instead.

    • jax.partial, and jax.lax.partial were accidental exports that have now been removed. Use functools.partial from the Python standard library instead.

    • Boolean scalar indices now raise a TypeError; previously this silently returned wrong results (#7925).

    • Many more jax.numpy functions now require array-like inputs, and will error if passed a list (#7747 #7802 #7907). See #7737 for a discussion of the rationale behind this change.

    • When inside a transformation such as jax.jit, jax.numpy.array always stages the array it produces into the traced computation. Previously jax.numpy.array would sometimes produce an on-device array, even under a jax.jit decorator. This change may break code that used JAX arrays to perform shape or index computations that must be known statically; the workaround is to perform such computations using classic NumPy arrays instead.

    • jnp.ndarray is now a true base-class for JAX arrays. In particular, this means that for a standard numpy array x, isinstance(x, jnp.ndarray) will now return False (#7927).
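
      To illustrate the isinstance change:

        import numpy as np
        import jax.numpy as jnp

        print(isinstance(np.ones(3), jnp.ndarray))    # now False: NumPy arrays are not JAX arrays
        print(isinstance(jnp.ones(3), jnp.ndarray))   # True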

  • New features:

jax 0.2.20 (Sept 2, 2021)#

  • GitHub commits.

  • Breaking Changes

    • jnp.poly* functions now require array-like inputs (#7732)

    • jnp.unique and other set-like operations now require array-like inputs (#7662)

jaxlib 0.1.71 (Sep 1, 2021)#

  • Breaking changes:

    • Support for CUDA 11.0 and CUDA 10.1 has been dropped. Jaxlib now supports CUDA 10.2 and CUDA 11.1+.

jax 0.2.19 (Aug 12, 2021)#

  • GitHub commits.

  • Breaking changes:

    • Support for NumPy 1.17 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.

    • The jit decorator has been added around the implementation of a number of operators on JAX arrays. This speeds up dispatch times for common operators such as +.

      This change should largely be transparent to most users. However, there is one known behavioral change, which is that large integer constants may now produce an error when passed directly to a JAX operator (e.g., x + 2**40). The workaround is to cast the constant to an explicit type (e.g., np.float64(2**40)).

  • New features:

    • Improved the support for shape polymorphism in jax2tf for operations that need to use a dimension size in array computation, e.g., jnp.mean. (#7317)

  • Bug fixes:

    • Fixed some leaked trace errors from the previous release (#7613)

jaxlib 0.1.70 (Aug 9, 2021)#

  • Breaking changes:

    • Support for Python 3.6 has been dropped, per the deprecation policy. Please upgrade to a supported Python version.

    • Support for NumPy 1.17 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.

    • The host_callback mechanism now uses one thread per local device for making the calls to the Python callbacks. Previously there was a single thread for all devices. This means that the callbacks may now be called interleaved. The callbacks corresponding to one device will still be called in sequence.

jax 0.2.18 (July 21 2021)#

  • GitHub commits.

  • Breaking changes:

    • Support for Python 3.6 has been dropped, per the deprecation policy. Please upgrade to a supported Python version.

    • The minimum jaxlib version is now 0.1.69.

    • The backend argument to jax.dlpack.from_dlpack() has been removed.

  • New features:

  • Bug fixes:

    • Tightened the checks for lax.argmin and lax.argmax to ensure they are not used with an invalid axis value, or with an empty reduction dimension. (#7196)

jaxlib 0.1.69 (July 9 2021)#

  • Fixed bugs in the TFRT CPU backend that resulted in incorrect results.

jax 0.2.17 (July 9 2021)#

jax 0.2.16 (June 23 2021)#

jax 0.2.15 (June 23 2021)#

  • GitHub commits.

  • New features:

    • #7042 Turned on TFRT CPU backend with significant dispatch performance improvements on CPU.

    • The jax2tf.convert() supports inequalities and min/max for booleans (#6956).

    • New SciPy function jax.scipy.special.lpmn_values().

  • Breaking changes:

  • Bug fixes:

    • Fixed bug that prevented round-tripping from JAX to TF and back: jax2tf.call_tf(jax2tf.convert) (#6947).

jaxlib 0.1.68 (June 23 2021)#

  • Bug fixes:

    • Fixed a bug in the TFRT CPU backend that produced NaNs when transferring a TPU buffer to CPU.

jax 0.2.14 (June 10 2021)#

  • GitHub commits.

  • New features:

    • The jax2tf.convert() now has support for pjit and sharded_jit.

    • A new configuration option JAX_TRACEBACK_FILTERING controls how JAX filters tracebacks.

    • A new traceback filtering mode using __tracebackhide__ is now enabled by default in sufficiently recent versions of IPython.

    • The jax2tf.convert() supports shape polymorphism even when the unknown dimensions are used in arithmetic operations, e.g., jnp.reshape(-1) (#6827).

    • The jax2tf.convert() generates custom attributes with location information in TF ops. The code that XLA generates after jax2tf has the same location information as JAX/XLA.

    • New SciPy function jax.scipy.special.lpmn().

  • Bug fixes:

    • The jax2tf.convert() now ensures that it uses the same typing rules for Python scalars and for choosing 32-bit vs. 64-bit computations as JAX (#6883).

    • The jax2tf.convert() now scopes the enable_xla conversion parameter properly to apply only during the just-in-time conversion (#6720).

    • The jax2tf.convert() now converts lax.dot_general using the XlaDot TensorFlow op, for better fidelity w.r.t. JAX numerical precision (#6717).

    • The jax2tf.convert() now has support for inequality comparisons and min/max for complex numbers (#6892).

jaxlib 0.1.67 (May 17 2021)#

jaxlib 0.1.66 (May 11 2021)#

  • New features:

    • CUDA 11.1 wheels are now supported on all CUDA 11 versions 11.1 or higher.

      NVidia now promises compatibility between CUDA minor releases starting with CUDA 11.1. This means that JAX can release a single CUDA 11.1 wheel that is compatible with CUDA 11.2 and 11.3.

      There is no longer a separate jaxlib release for CUDA 11.2 (or higher); use the CUDA 11.1 wheel for those versions (cuda111).

    • Jaxlib now bundles libdevice.10.bc in CUDA wheels. There should be no need to point JAX to a CUDA installation to find this file.

    • Added automatic support for static keyword arguments to the jit() implementation.

    • Added support for pretransformation exception traces.

    • Initial support for pruning unused arguments from jit()-transformed computations. Pruning is still a work in progress.

    • Improved the string representation of PyTreeDef objects.

    • Added support for XLA’s variadic ReduceWindow.

  • Bug fixes:

    • Fixed a bug in the remote cloud TPU support when large numbers of arguments are passed to a computation.

    • Fix a bug that meant that JAX garbage collection was not triggered by jit() transformed functions.

jax 0.2.13 (May 3 2021)#

  • GitHub commits.

  • New features:

    • When combined with jaxlib 0.1.66, jax.jit() now supports static keyword arguments. A new static_argnames option has been added to specify keyword arguments as static (see the sketch at the end of this list).

    • jax.nonzero() has a new optional size argument that allows it to be used within jit (#6501)

    • jax.numpy.unique() now supports the axis argument (#6532).

    • jax.experimental.host_callback.call() now supports pjit.pjit (#6569).

    • Added jax.scipy.linalg.eigh_tridiagonal() that computes the eigenvalues of a tridiagonal matrix. Only eigenvalues are supported at present.

    • The order of the filtered and unfiltered stack traces in exceptions has been changed. The traceback attached to an exception thrown from JAX-transformed code is now filtered, with an UnfilteredStackTrace exception containing the original trace as the __cause__ of the filtered exception. Filtered stack traces now also work with Python 3.6.

    • If an exception is thrown by code that has been transformed by reverse-mode automatic differentiation, JAX now attempts to attach as a __cause__ of the exception a JaxStackTraceBeforeTransformation object that contains the stack trace that created the original operation in the forward pass. Requires jaxlib 0.1.66.
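
      A minimal sketch of the static keyword-argument support mentioned above; the function and argument names are illustrative:

        from functools import partial

        import jax
        import jax.numpy as jnp

        @partial(jax.jit, static_argnames=("reduce",))
        def summarize(x, reduce="sum"):
            # "reduce" is static, so Python control flow on it is fine under jit.
            return jnp.sum(x) if reduce == "sum" else jnp.prod(x)

        print(summarize(jnp.arange(1.0, 5.0), reduce="prod"))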

  • Breaking changes:

    • The following function names have changed. There are still aliases, so this should not break existing code, but the aliases will eventually be removed so please change your code.

    • Similarly, the argument to local_devices() has been renamed from host_id to process_index.

    • Arguments to jax.jit() other than the function are now marked as keyword-only. This change is to prevent accidental breakage when arguments are added to jit.

  • Bug fixes:

    • The jax2tf.convert() now works in presence of gradients for functions with integer inputs (#6360).

    • Fixed assertion failure in jax2tf.call_tf() when used with captured tf.Variable (#6572).

jaxlib 0.1.65 (April 7 2021)#

jax 0.2.12 (April 1 2021)#

  • GitHub commits.

  • New features

  • Breaking changes:

    • The minimum jaxlib version is now 0.1.64.

    • Some profiler APIs names have been changed. There are still aliases, so this should not break existing code, but the aliases will eventually be removed so please change your code.

    • Omnistaging can no longer be disabled. See omnistaging for more information.

    • Python integers larger than the maximum int64 value will now lead to an overflow in all cases, rather than being silently converted to uint64 in some cases (#6047).

    • Outside X64 mode, Python integers outside the range representable by int32 will now lead to an OverflowError rather than having their value silently truncated.

  • Bug fixes:

    • host_callback now supports empty arrays in arguments and results (#6262).

    • jax.random.randint() clips rather than wraps out-of-bounds limits, and can now generate integers in the full range of the specified dtype (#5868)

jax 0.2.11 (March 23 2021)#

  • GitHub commits.

  • New features:

    • #6112 added context managers: jax.enable_checks, jax.check_tracer_leaks, jax.debug_nans, jax.debug_infs, jax.log_compiles (see the sketch at the end of this list).

    • #6085 added jnp.delete
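
      A sketch of using one of the context managers listed above; this assumes the callable-config pattern, e.g. jax.log_compiles(True), which temporarily enables the option inside the with block:

        import jax
        import jax.numpy as jnp

        # Log every XLA compilation triggered inside the block.
        with jax.log_compiles(True):
            jax.jit(lambda x: x + 1)(jnp.ones(3))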

  • Bug fixes:

    • #6136 generalized jax.flatten_util.ravel_pytree to handle integer dtypes.

    • #6129 fixed a bug with handling some constants like enum.IntEnums

    • #6145 fixed batching issues with incomplete beta functions

    • #6014 fixed H2D transfers during tracing

    • #6165 avoids OverflowErrors when converting some large Python integers to floats

  • Breaking changes:

    • The minimum jaxlib version is now 0.1.62.

jaxlib 0.1.64 (March 18 2021)#

jaxlib 0.1.63 (March 17 2021)#

jax 0.2.10 (March 5 2021)#

  • GitHub commits.

  • New features:

    • jax.scipy.stats.chi2() is now available as a distribution with logpdf and pdf methods.

    • jax.scipy.stats.betabinom() is now available as a distribution with logpmf and pmf methods.

    • Added jax.experimental.jax2tf.call_tf() to call TensorFlow functions from JAX (#5627); see also the jax2tf README.

    • Extended the batching rule for lax.pad to support batching of the padding values.

  • Bug fixes:

  • Breaking changes:

    • JAX’s promotion rules were adjusted to make promotion more consistent and invariant to JIT. In particular, binary operations can now result in weakly-typed values when appropriate. The main user-visible effect of the change is that some operations result in outputs of different precision than before; for example the expression jnp.bfloat16(1) + 0.1 * jnp.arange(10) previously returned a float64 array, and now returns a bfloat16 array. JAX’s type promotion behavior is described at Type promotion semantics.

    • jax.numpy.linspace() now computes the floor of integer values, i.e., rounding towards -inf rather than 0. This change was made to match NumPy 1.20.0.

    • jax.numpy.i0() no longer accepts complex numbers. Previously the function computed the absolute value of complex arguments. This change was made to match the semantics of NumPy 1.20.0.

    • Several jax.numpy functions no longer accept tuples or lists in place of array arguments: jax.numpy.pad(), jax.numpy.ravel(), jax.numpy.repeat(), jax.numpy.reshape(). In general, jax.numpy functions should be used with scalars or array arguments.

jaxlib 0.1.62 (March 9 2021)#

  • New features:

    • jaxlib wheels are now built to require AVX instructions on x86-64 machines by default. If you want to use JAX on a machine that doesn’t support AVX, you can build a jaxlib from source using the --target_cpu_features flag to build.py. --target_cpu_features also replaces --enable_march_native.

jaxlib 0.1.61 (February 12 2021)#

jaxlib 0.1.60 (February 3 2021)#

  • Bug fixes:

    • Fixed a memory leak when converting CPU DeviceArrays to NumPy arrays. The memory leak was present in jaxlib releases 0.1.58 and 0.1.59.

    • bool, int8, and uint8 are now considered safe to cast to bfloat16 NumPy extension type.

jax 0.2.9 (January 26 2021)#

  • GitHub commits.

  • New features:

  • Breaking changes:

    • jax.ops.segment_sum() now drops segment IDs that are out of range rather than wrapping them into the segment ID space. This was done for performance reasons.

jaxlib 0.1.59 (January 15 2021)#

jax 0.2.8 (January 12 2021)#

  • GitHub commits.

  • New features:

  • Bug fixes:

    • jax.numpy.arccosh now returns the same branch as numpy.arccosh for complex inputs (#5156)

    • host_callback.id_tap now works for jax.pmap also. There is an optional parameter for id_tap and id_print to request that the device from which the value is tapped be passed as a keyword argument to the tap function (#5182).

  • Breaking changes:

  • New features:

    • New flag for debugging inf, analogous to that for NaN (#5224).

jax 0.2.7 (Dec 4 2020)#

  • GitHub commits.

  • New features:

    • Add jax.device_put_replicated

    • Add multi-host support to jax.experimental.sharded_jit

    • Add support for differentiating eigenvalues computed by jax.numpy.linalg.eig

    • Add support for building on Windows platforms

    • Add support for general in_axes and out_axes in jax.pmap

    • Add complex support for jax.numpy.linalg.slogdet

  • Bug fixes:

    • Fix higher-than-second order derivatives of jax.numpy.sinc at zero

    • Fix some hard-to-hit bugs around symbolic zeros in transpose rules

  • Breaking changes:

    • jax.experimental.optix has been deleted, in favor of the standalone optax Python package.

    • indexing of JAX arrays with non-tuple sequences now raises a TypeError. This type of indexing has been deprecated in Numpy since v1.16, and in JAX since v0.2.4. See #4564.

jax 0.2.6 (Nov 18 2020)#

  • GitHub commits.

  • New Features:

    • Add support for shape-polymorphic tracing for the jax.experimental.jax2tf converter. See README.md.

  • Breaking change cleanup

    • Raise an error on non-hashable static arguments for jax.jit and xla_computation. See cb48f42.

    • Improve consistency of type promotion behavior (#4744):

      • Adding a complex Python scalar to a JAX floating point number respects the precision of the JAX float. For example, jnp.float32(1) + 1j now returns complex64, where previously it returned complex128.

      • Results of type promotion with 3 or more terms involving uint64, a signed int, and a third type are now independent of the order of arguments. For example: jnp.result_type(jnp.uint64, jnp.int64, jnp.float16) and jnp.result_type(jnp.float16, jnp.uint64, jnp.int64) both return float16, where previously the first returned float64 and the second returned float16.

    • The contents of the (undocumented) jax.lax_linalg linear algebra module are now exposed publicly as jax.lax.linalg.

    • jax.random.PRNGKey now produces the same results in and out of JIT compilation (#4877). This required changing the result for a given seed in a few particular cases:

      • With jax_enable_x64=False, negative seeds passed as Python integers now return a different result outside JIT mode. For example, jax.random.PRNGKey(-1) previously returned [4294967295, 4294967295], and now returns [0, 4294967295]. This matches the behavior in JIT.

      • Seeds outside the range representable by int64 outside JIT now result in an OverflowError rather than a TypeError. This matches the behavior in JIT.

      To recover the keys returned previously for negative integers with jax_enable_x64=False outside JIT, you can use:

      key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)
      
    • DeviceArray now raises RuntimeError instead of ValueError when trying to access its value while it has been deleted.

jaxlib 0.1.58 (January 12ish 2021)#

  • Fixed a bug that meant JAX sometimes returned platform-specific types (e.g., np.cint) instead of standard types (e.g., np.int32). (#4903)

  • Fixed a crash when constant-folding certain int16 operations. (#4971)

  • Added an is_leaf predicate to pytree.flatten().

jaxlib 0.1.57 (November 12 2020)#

  • Fixed manylinux2010 compliance issues in GPU wheels.

  • Switched the CPU FFT implementation from Eigen to PocketFFT.

  • Fixed a bug where the hash of bfloat16 values was not correctly initialized and could change (#4651).

  • Add support for retaining ownership when passing arrays to DLPack (#4636).

  • Fixed a bug for batched triangular solves with sizes greater than 128 but not a multiple of 128.

  • Fixed a bug when performing concurrent FFTs on multiple GPUs (#3518).

  • Fixed a bug in profiler where tools are missing (#4427).

  • Dropped support for CUDA 10.0.

jax 0.2.5 (October 27 2020)#

jax 0.2.4 (October 19 2020)#

  • GitHub commits.

  • Improvements:

    • Add support for remat to jax.experimental.host_callback. See #4608.

  • Deprecations

    • Indexing with non-tuple sequences is now deprecated, following a similar deprecation in Numpy. In a future release, this will result in a TypeError. See #4564.

jaxlib 0.1.56 (October 14, 2020)#

jax 0.2.3 (October 14 2020)#

  • GitHub commits.

  • The reason for another release so soon is that we needed to temporarily roll back a new jit fastpath while we look into a performance degradation.

jax 0.2.2 (October 13 2020)#

jax 0.2.1 (October 6 2020)#

jax (0.2.0) (September 23 2020)#

jax (0.1.77) (September 15 2020)#

jaxlib 0.1.55 (September 8, 2020)#

  • Update XLA:

    • Fix bug in DLPackManagedTensorToBuffer (#4196)

jax 0.1.76 (September 8, 2020)#

jax 0.1.75 (July 30, 2020)#

  • GitHub commits.

  • Bug Fixes:

    • make jnp.abs() work for unsigned inputs (#3914)

  • Improvements:

    • “Omnistaging” behavior added behind a flag, disabled by default (#3370)

jax 0.1.74 (July 29, 2020)#

  • GitHub commits.

  • New Features:

    • BFGS (#3101)

    • TPU support for half-precision arithmetic (#3878)

  • Bug Fixes:

    • Prevent some accidental dtype warnings (#3874)

    • Fix a multi-threading bug in custom derivatives (#3845, #3869)

  • Improvements:

    • Faster searchsorted implementation (#3873)

    • Better test coverage for jax.numpy sorting algorithms (#3836)

jaxlib 0.1.52 (July 22, 2020)#

  • Update XLA.

jax 0.1.73 (July 22, 2020)#

  • GitHub commits.

  • The minimum jaxlib version is now 0.1.51.

  • New Features:

    • jax.image.resize. (#3703)

    • hfft and ihfft (#3664)

    • jax.numpy.intersect1d (#3726)

    • jax.numpy.lexsort (#3812)

    • lax.scan and the scan primitive support an unroll parameter for loop unrolling when lowering to XLA (#3738).

  • Bug Fixes:

    • Fix reduction repeated axis error (#3618)

    • Fix shape rule for lax.pad for input dimensions of size 0. (#3608)

    • make psum transpose handle zero cotangents (#3653)

    • Fix shape error when taking JVP of reduce-prod over size 0 axis. (#3729)

    • Support differentiation through jax.lax.all_to_all (#3733)

    • address nan issue in jax.scipy.special.zeta (#3777)

  • Improvements:

    • Many improvements to jax2tf

    • Reimplement argmin/argmax using a single pass variadic reduction. (#3611)

    • Enable XLA SPMD partitioning by default. (#3151)

    • Add support for 0d transpose convolution (#3643)

    • Make LU gradient work for low-rank matrices (#3610)

    • support multiple_results and custom JVPs in jet (#3657)

    • Generalize reduce-window padding to support (lo, hi) pairs. (#3728)

    • Implement complex convolutions on CPU and GPU. (#3735)

    • Make jnp.take work for empty slices of empty arrays. (#3751)

    • Relax dimension ordering rules for dot_general. (#3778)

    • Enable buffer donation for GPU. (#3800)

    • Add support for base dilation and window dilation to reduce window op… (#3803)

jaxlib 0.1.51 (July 2, 2020)#

  • Update XLA.

  • Add new runtime support for host_callback.

jax 0.1.72 (June 28, 2020)#

jax 0.1.71 (June 25, 2020)#

  • GitHub commits.

  • The minimum jaxlib version is now 0.1.48.

  • Bug fixes:

    • Allow jax.experimental.ode.odeint dynamics functions to close over values with respect to which we’re differentiating #3562.

jaxlib 0.1.50 (June 25, 2020)#

  • Add support for CUDA 11.0.

  • Drop support for CUDA 9.2 (we only maintain support for the last four CUDA versions.)

  • Update XLA.

jaxlib 0.1.49 (June 19, 2020)#

jaxlib 0.1.48 (June 12, 2020)#

  • New features:

    • Adds support for fast traceback collection.

    • Adds preliminary support for on-device heap profiling.

    • Implements np.nextafter for bfloat16 types.

    • Complex128 support for FFTs on CPU and GPU.

  • Bug fixes:

    • Improved float64 tanh accuracy on GPU.

    • float64 scatters on GPU are much faster.

    • Complex matrix multiplication on CPU should be much faster.

    • Stable sorts on CPU should actually be stable now.

    • Concurrency bug fix in CPU backend.

jax 0.1.70 (June 8, 2020)#

  • GitHub commits.

  • New features:

    • lax.switch introduces indexed conditionals with multiple branches, together with a generalization of the cond primitive #3318.
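
    For illustration, here is a minimal sketch of lax.switch (the branch functions and values are made up):

      import jax.numpy as jnp
      from jax import lax

      branches = [lambda x: x + 1, lambda x: x * 2, lambda x: -x]
      # The index picks which branch to apply; unlike Python if/else it can be
      # a traced value, and it is clamped into the valid range of branches.
      result = lax.switch(1, branches, jnp.arange(3.0))   # applies x * 2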

jax 0.1.69 (June 3, 2020)#

jax 0.1.68 (May 21, 2020)#

jax 0.1.67 (May 12, 2020)#

  • GitHub commits.

  • New features:

    • Support for reduction over subsets of a pmapped axis using axis_index_groups #2382.

    • Experimental support for printing and calling host-side Python functions from compiled code. See id_print and id_tap (#3006).

  • Notable changes:

    • The visibility of names exported from jax.numpy has been tightened. This may break code that was making use of names that were previously exported accidentally.

jaxlib 0.1.47 (May 8, 2020)#

  • Fixes crash for outfeed.

jax 0.1.66 (May 5, 2020)#

jaxlib 0.1.46 (May 5, 2020)#

  • Fixes crash for linear algebra functions on Mac OS X (#432).

  • Fixes an illegal instruction crash caused by using AVX512 instructions when an operating system or hypervisor disabled them (#2906).

jax 0.1.65 (April 30, 2020)#

  • GitHub commits.

  • New features:

    • Differentiation of determinants of singular matrices #2809.

  • Bug fixes:

    • Fix odeint() differentiation with respect to time of ODEs with time-dependent dynamics #2817, also add ODE CI testing.

    • Fix lax_linalg.qr() differentiation #2867.

jaxlib 0.1.45 (April 21, 2020)#

  • Fixes segfault: #2755

  • Plumb is_stable option on Sort HLO through to Python.

jax 0.1.64 (April 21, 2020)#

jaxlib 0.1.44 (April 16, 2020)#

  • Fixes a bug where if multiple GPUs of different models were present, JAX would only compile programs suitable for the first GPU.

  • Bugfix for batch_group_count convolutions.

  • Added precompiled SASS for more GPU versions to avoid startup PTX compilation hang.

jax 0.1.63 (April 12, 2020)#

  • GitHub commits.

  • Added jax.custom_jvp and jax.custom_vjp from #2026; see the tutorial notebook and the sketch after this list. Deprecated jax.custom_transforms and removed it from the docs (though it still works).

  • Add scipy.sparse.linalg.cg #2566.

  • Changed how Tracers are printed to show more useful information for debugging #2591.

  • Made jax.numpy.isclose handle nan and inf correctly #2501.

  • Added several new rules for jax.experimental.jet #2537.

  • Fixed jax.experimental.stax.BatchNorm when scale/center isn’t provided.

  • Fix some missing cases of broadcasting in jax.numpy.einsum #2512.

  • Implement jax.numpy.cumsum and jax.numpy.cumprod in terms of a parallel prefix scan #2596 and make reduce_prod differentiable to arbitrary order #2597.

  • Add batch_group_count to conv_general_dilated #2635.

  • Add docstring for test_util.check_grads #2656.

  • Add callback_transform #2665.

  • Implement rollaxis, convolve/correlate 1d & 2d, copysign, trunc, roots, and quantile/percentile interpolation options.
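
    As a sketch of the jax.custom_jvp API added in this release (the log1pexp example below follows the pattern in the custom derivatives tutorial and is illustrative only):

      import jax
      import jax.numpy as jnp

      @jax.custom_jvp
      def log1pexp(x):
          return jnp.log(1.0 + jnp.exp(x))

      @log1pexp.defjvp
      def log1pexp_jvp(primals, tangents):
          (x,), (x_dot,) = primals, tangents
          y = log1pexp(x)
          # A hand-written tangent rule that avoids the overflow the
          # automatically derived rule runs into for large x.
          y_dot = (1.0 - 1.0 / (1.0 + jnp.exp(x))) * x_dot
          return y, y_dot

      jax.grad(log1pexp)(100.0)   # 1.0 rather than nan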

jaxlib 0.1.43 (March 31, 2020)#

  • Fixed a performance regression for Resnet-50 on GPU.

jax 0.1.62 (March 21, 2020)#

  • GitHub commits.

  • JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.

  • Removed the internal function lax._safe_mul, which implemented the convention 0. * nan == 0. This change means that some programs, when differentiated, will produce nans where they previously produced correct values, though it ensures that nans rather than silently incorrect results are produced for other programs. See #2447 and #1052 for details, and the sketch after this list for an illustration.

  • Added an all_gather parallel convenience function.

  • More type annotations in core code.
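
    As a hedged illustration of the kind of program affected: with lax._safe_mul gone, a zero cotangent multiplied by an infinite partial derivative in the backward pass now propagates as nan instead of being silenced to zero.

      import jax
      import jax.numpy as jnp

      def f(x):
          # The untaken log branch still contributes inf * 0 in the backward pass.
          return jnp.where(x > 0.0, jnp.log(x), 0.0)

      jax.grad(f)(0.0)   # nan, rather than a silently "safe" 0.0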

jaxlib 0.1.42 (March 19, 2020)#

  • jaxlib 0.1.41 broke cloud TPU support due to an API incompatibility. This release fixes it again.

  • JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.

jax 0.1.61 (March 17, 2020)#

  • GitHub commits.

  • Fixes Python 3.5 support. This will be the last JAX or jaxlib release that supports Python 3.5.

jax 0.1.60 (March 17, 2020)#

  • GitHub commits.

  • New features:

    • jax.pmap() has a static_broadcast_argnums argument, which allows the user to specify arguments that should be treated as compile-time constants and broadcast to all devices. It works analogously to static_argnums in jax.jit().

    • Improved error messages for when tracers are mistakenly saved in global state.

    • Added the jax.nn.one_hot() utility function (see the example after this list).

    • Added jax.experimental.jet for exponentially faster higher-order automatic differentiation.

    • Added more correctness checking to arguments of jax.lax.broadcast_in_dim().

  • The minimum jaxlib version is now 0.1.41.
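
    For reference, a small example of the jax.nn.one_hot() helper mentioned above:

      import jax
      import jax.numpy as jnp

      labels = jnp.array([0, 2, 1])
      jax.nn.one_hot(labels, num_classes=3)
      # [[1. 0. 0.]
      #  [0. 0. 1.]
      #  [0. 1. 0.]]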

jaxlib 0.1.40 (March 4, 2020)#

  • Adds experimental support in Jaxlib for TensorFlow profiler, which allows tracing of CPU and GPU computations from TensorBoard.

  • Includes prototype support for multihost GPU computations that communicate via NCCL.

  • Improves performance of NCCL collectives on GPU.

  • Adds TopK, CustomCallWithoutLayout, CustomCallWithLayout, IGammaGradA and RandomGamma implementations.

  • Supports device assignments known at XLA compilation time.

jax 0.1.59 (February 11, 2020)#

  • GitHub commits.

  • Breaking changes

    • The minimum jaxlib version is now 0.1.38.

    • Simplified Jaxpr by removing the Jaxpr.freevars and Jaxpr.bound_subjaxprs. The call primitives (xla_call, xla_pmap, sharded_call, and remat_call) get a new parameter call_jaxpr with a fully-closed (no constvars) jaxpr. Also, added a new field call_primitive to primitives.

  • New features:

    • Reverse-mode automatic differentiation (e.g. grad) of lax.cond, making it now differentiable in both modes (#2091)

    • JAX now supports DLPack, which allows sharing CPU and GPU arrays in a zero-copy way with other libraries, such as PyTorch.

    • JAX GPU DeviceArrays now support __cuda_array_interface__, which is another zero-copy protocol for sharing GPU arrays with other libraries such as CuPy and Numba.

    • JAX CPU device buffers now implement the Python buffer protocol, which allows zero-copy buffer sharing between JAX and NumPy.

    • Added the JAX_SKIP_SLOW_TESTS environment variable to skip tests known to be slow.

jaxlib 0.1.39 (February 11, 2020)#

  • Updates XLA.

jaxlib 0.1.38 (January 29, 2020)#

  • CUDA 9.0 is no longer supported.

  • CUDA 10.2 wheels are now built by default.

jax 0.1.58 (January 28, 2020)#

Notable bug fixes#

  • With the Python 3 upgrade, JAX no longer depends on fastcache, which should help with installation.

JAX Glossary of Terms#

Array#

JAX’s analog of numpy.ndarray. See jax.Array.

CPU#

Short for Central Processing Unit, CPUs are the standard computational architecture available in most computers. JAX can run computations on CPUs, but can often achieve much better performance on GPUs and TPUs.

Device#

The generic name used to refer to the CPU, GPU, or TPU used by JAX to perform computations.

forward-mode autodiff#

See JVP

functional programming#

A programming paradigm in which programs are defined by applying and composing pure functions. JAX is designed for use with functional programs.

GPU#

Short for Graphics Processing Unit, GPUs were originally specialized for operations related to rendering images on screen, but are now much more general-purpose. JAX is able to target GPUs for fast operations on arrays (see also CPU and TPU).

jaxpr#

Short for JAX Expression, a jaxpr is an intermediate representation of a computation that is generated by JAX, and is forwarded to XLA for compilation and execution. See Understanding Jaxprs for more discussion and examples.
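
For example, jax.make_jaxpr shows the jaxpr that JAX records for a function (output shown schematically):

    import jax
    import jax.numpy as jnp

    def f(x):
        return jnp.sin(x) ** 2

    print(jax.make_jaxpr(f)(1.0))
    # { lambda ; a:f32[]. let b:f32[] = sin a; c:f32[] = integer_pow[y=2] b in (c,) }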

JIT#

Short for Just In Time compilation, JIT in JAX generally refers to the compilation of array operations to XLA, most often accomplished using jax.jit().
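
A minimal example:

    import jax
    import jax.numpy as jnp

    @jax.jit
    def selu(x, alpha=1.67, lam=1.05):
        return lam * jnp.where(x > 0, x, alpha * (jnp.exp(x) - 1))

    selu(jnp.arange(5.0))   # compiled to XLA on the first call, then cached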

JVP#

Short for Jacobian Vector Product, also sometimes known as forward-mode automatic differentiation. For more details, see Jacobian-Vector products (JVPs, aka forward-mode autodiff). In JAX, JVP is a transformation that is implemented via jax.jvp(). See also VJP.
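
A minimal sketch:

    import jax
    import jax.numpy as jnp

    def f(x):
        return jnp.sin(x) * x

    # Push the tangent 1.0 forward through f at the point 2.0.
    y, y_dot = jax.jvp(f, (2.0,), (1.0,))   # y = f(2.0), y_dot = f'(2.0) * 1.0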

primitive#

A primitive is a fundamental unit of computation used in JAX programs. Most functions in jax.lax represent individual primitives. When representing a computation in a jaxpr, each operation in the jaxpr is a primitive.

pure function#

A pure function is a function whose outputs are based only on its inputs, and which has no side-effects. JAX’s transformation model is designed to work with pure functions. See also functional programming.
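
A small sketch contrasting an impure function, whose hidden state JAX transformations may not track, with a pure equivalent that threads state explicitly:

    import jax.numpy as jnp

    counter = 0

    def impure_add(x):
        global counter
        counter += 1                          # hidden side effect
        return x + counter

    def pure_add(x, counter):
        return x + counter + 1, counter + 1   # state passed in and returned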

pytree#

A pytree is an abstraction that lets JAX handle tuples, lists, dicts, and other more general containers of array values in a uniform way. Refer to Working with pytrees for a more detailed discussion.
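
For example, a dict of arrays is a pytree, and the jax.tree_util functions operate on its leaves:

    import jax
    import jax.numpy as jnp

    params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
    doubled = jax.tree_util.tree_map(lambda leaf: 2 * leaf, params)
    leaves, treedef = jax.tree_util.tree_flatten(params)   # leaves sorted by key: [b, w]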

reverse-mode autodiff#

See VJP.

SPMD#

Short for Single Program Multi Data, it refers to a parallel computation technique in which the same computation (e.g., the forward pass of a neural net) is run on different input data (e.g., different inputs in a batch) in parallel on different devices (e.g., several TPUs). jax.pmap() is a JAX transformation that implements SPMD parallelism.
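
A minimal sketch with jax.pmap(): the leading axis of the input is split across the local devices, and each device runs the same program on its shard.

    import jax
    import jax.numpy as jnp

    n = jax.local_device_count()
    xs = jnp.arange(n * 4.0).reshape(n, 4)        # one row per device
    ys = jax.pmap(lambda x: jnp.sum(x ** 2))(xs)  # one result per device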

static#

In a JIT compilation, a value that is not traced (see Tracer). Also sometimes refers to compile-time computations on static values.
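
For example, an argument that determines an output shape must be static. Marking it with static_argnums bakes its value into the compiled program; a new value triggers recompilation rather than being traced:

    from functools import partial

    import jax
    import jax.numpy as jnp

    @partial(jax.jit, static_argnums=1)
    def reshape_to(x, n):
        return jnp.reshape(x, (n, -1))   # n must be known at trace time

    reshape_to(jnp.arange(12.0), 3)      # shape (3, 4)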

TPU#

Short for Tensor Processing Unit, TPUs are chips specifically engineered for fast operations on N-dimensional tensors used in deep learning applications. JAX is able to target TPUs for fast operations on arrays (see also CPU and GPU).

Tracer#

An object used as a stand-in for a JAX Array in order to determine the sequence of operations performed by a Python function. Internally, JAX implements this via the jax.core.Tracer class.
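
Inside a traced function, array arguments are replaced by Tracers, which record shapes and dtypes but carry no concrete values:

    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x):
        print("tracing with:", x)   # prints a Tracer such as Traced<ShapedArray(int32[3])>
        return x * 2

    f(jnp.arange(3))   # the print runs at trace time, not on later cached calls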

transformation#

A higher-order function: that is, a function that takes a function as input and outputs a transformed function. Examples in JAX include jax.jit(), jax.vmap(), and jax.grad().
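
Because transformations return ordinary functions, they compose freely:

    import jax
    import jax.numpy as jnp

    def loss(w, x):
        return jnp.sum((w * x) ** 2)

    # grad differentiates with respect to w, jit compiles the result,
    # and vmap maps it over a batch of x values.
    batched_grad = jax.vmap(jax.jit(jax.grad(loss)), in_axes=(None, 0))
    batched_grad(jnp.ones(3), jnp.arange(6.0).reshape(2, 3))   # shape (2, 3)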

VJP#

Short for Vector Jacobian Product, also sometimes known as reverse-mode automatic differentiation. For more details, see Vector-Jacobian products (VJPs, aka reverse-mode autodiff). In JAX, VJP is a transformation that is implemented via jax.vjp(). See also JVP.
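
A minimal sketch:

    import jax
    import jax.numpy as jnp

    def f(x):
        return jnp.sin(x) * x

    # Get f(2.0) plus a function that pulls an output cotangent back to the input.
    y, f_vjp = jax.vjp(f, 2.0)
    (x_bar,) = f_vjp(1.0)   # equals jax.grad(f)(2.0)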

XLA#

Short for Accelerated Linear Algebra, XLA is a domain-specific compiler for linear algebra operations that is the primary backend for JIT-compiled JAX code. See https://www.tensorflow.org/xla/.

weak type#

A JAX data type that has the same type promotion semantics as Python scalars; see Weakly-typed values in JAX.
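
For example, a Python scalar literal is weakly typed, so it defers to the array's dtype, while an explicitly typed value follows the normal promotion rules:

    import jax.numpy as jnp

    x = jnp.ones(3, dtype=jnp.bfloat16)
    (x + 1.5).dtype                 # bfloat16: the weak Python float defers to x
    (x + jnp.float32(1.5)).dtype    # float32: a strongly typed value promotes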