External callbacks

This tutorial outlines how you can use various callback functions, which allow JAX runtimes to execute Python code on the host. Examples of JAX callbacks are jax.pure_callback(), jax.experimental.io_callback() and jax.debug.callback(). Depending on the flavor, you can use them even under JAX transformations such as jit(), vmap(), and grad().

Why callbacks?

A callback routine is a way to perform host-side execution of code at runtime. As a simple example, suppose you’d like to print the value of some variable during the course of a computation. Using a simple Python print() statement, it looks like this:

import jax

@jax.jit
def f(x):
  y = x + 1
  print("intermediate value: {}".format(y))
  return y * 2

result = f(2)
intermediate value: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>

What is printed is not the runtime value, but the trace-time abstract value (if you’re not familiar with tracing in JAX, a good primer can be found in Tracing).

To print the value at runtime, you need a callback, for example jax.debug.print() (you can learn more about debugging in Introduction to debugging):

@jax.jit
def f(x):
  y = x + 1
  jax.debug.print("intermediate value: {}", y)
  return y * 2

result = f(2)
intermediate value: 3

This works by passing the runtime value of y as a CPU jax.Array back to the host process, where the host can print it.
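To make the trace-time versus run-time distinction concrete, here is a minimal sketch that uses jax.debug.callback() directly (the function jax.debug.print() wraps) plus jax.effects_barrier() to wait for pending callbacks. The names g, trace_log and runtime_log are hypothetical. The ordinary Python append runs only while g is traced, whereas the callback runs on every call with the concrete value of y:

```python
import jax

trace_log = []    # hypothetical: filled by ordinary Python code, i.e. at trace time
runtime_log = []  # hypothetical: filled by the host callback, i.e. at run time

@jax.jit
def g(x):
  y = x + 1
  trace_log.append("traced")  # executes once per trace, not once per call
  # The callback receives the concrete runtime value of y on the host.
  jax.debug.callback(lambda v: runtime_log.append(int(v)), y)
  return y * 2

g(2)
g(5)                   # same shape/dtype as before, so the cached trace is reused
jax.effects_barrier()  # block until all pending host callbacks have run
```

After both calls, trace_log holds a single entry while runtime_log holds one value per call.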

Flavors of callback

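In earlier versions of JAX, there was only one kind of callback available, implemented in jax.experimental.host_callback. The host_callback routines had some deficiencies, and are now deprecated in favor of several callbacks designed for different situations:

- jax.pure_callback(): appropriate for pure functions, i.e. functions with no side effects.
- jax.experimental.io_callback(): appropriate for impure functions, e.g. functions that read or write data to disk.
- jax.debug.callback(): appropriate for functions that should reflect the execution behavior of the compiler.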

(The jax.debug.print() function you used previously is a wrapper around jax.debug.callback().)

From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow.

| callback function | supports return value | jit | vmap | grad | scan/while_loop | guaranteed execution |
|---|---|---|---|---|---|---|
| jax.pure_callback() | ✅ | ✅ | ✅ | ❌¹ | ✅ | ❌ |
| jax.experimental.io_callback() | ✅ | ✅ | ✅/❌² | ❌ | ✅³ | ✅ |
| jax.debug.callback() | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ |

¹ jax.pure_callback can be used with custom_jvp to make it compatible with autodiff.

² jax.experimental.io_callback is compatible with vmap only if ordered=False.

³ Note that vmap of scan/while_loop of io_callback has complicated semantics, and its behavior may change in future releases.
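As a quick side-by-side view of the "supports return value" column, here is a sketch that calls all three flavors inside one jitted function. The helper host_double, the list seen and the function all_three are hypothetical names; the call patterns mirror the examples later in this tutorial. pure_callback and io_callback both take the expected output shape/dtype and return a value, while debug.callback returns None:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

def host_double(x):
  # Plain NumPy on the host; multiplying float32 by a Python int keeps float32.
  return np.asarray(x) * 2

seen = []  # hypothetical sink filled by the debug callback

@jax.jit
def all_three(x):
  spec = jax.ShapeDtypeStruct(x.shape, x.dtype)  # expected output shape/dtype
  a = jax.pure_callback(host_double, spec, x)    # returns a value; assumed pure
  b = io_callback(host_double, spec, x)          # returns a value; assumed impure
  r = jax.debug.callback(lambda v: seen.append(np.asarray(v).tolist()), x)
  assert r is None  # debug.callback never hands a value back to the program
  return a, b

a, b = all_three(jnp.arange(3.0))
jax.effects_barrier()  # make sure the debug callback has finished
```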

Exploring pure_callback

jax.pure_callback() is generally the callback function you should reach for when you want host-side execution of a pure function: i.e. a function that has no side-effects (such as printing values, reading data from disk, updating a global state, etc.).

The function you pass to jax.pure_callback() need not actually be pure, but it will be assumed pure by JAX’s transformations and higher-order functions, which means that it may be silently elided or called multiple times.

import jax
import jax.numpy as jnp
import numpy as np

def f_host(x):
  # call a numpy (not jax.numpy) operation:
  return np.sin(x).astype(x.dtype)

def f(x):
  result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
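  # vmap_method="sequential" maps the callback over the batch in a loop under vmap.
  return jax.pure_callback(f_host, result_shape, x, vmap_method="sequential")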

x = jnp.arange(5.0)
f(x)
Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32)

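Because pure_callback can be elided or duplicated, it is compatible out-of-the-box with transformations like jit and vmap, as well as higher-order primitives like scan and while_loop (under vmap, the vmap_method argument controls how the callback is batched; "sequential" runs it once per mapped element):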

jax.jit(f)(x)
Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32)
jax.vmap(f)(x)
Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32)
def body_fun(_, x):
  return _, f(x)
jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]
Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32)

However, because there is no way for JAX to introspect the content of the callback, pure_callback has undefined autodiff semantics:

jax.grad(f)(x)
ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.

For an example of using pure_callback with jax.custom_jvp(), see Example: pure_callback with custom_jvp below.

By design, functions passed to pure_callback are treated as if they have no side-effects. One consequence is that if the output of the function is not used, the compiler may eliminate the callback entirely:

def print_something():
  print('printing something')
  return np.int32(0)

@jax.jit
def f1():
  return jax.pure_callback(print_something, np.int32(0))
f1();
printing something
@jax.jit
def f2():
  jax.pure_callback(print_something, np.int32(0))
  return 1.0
f2();

In f1, the output of the callback is used in the return value of the function, so the callback is executed and we see the printed output. In f2 on the other hand, the output of the callback is unused, and so the compiler notices this and eliminates the function call. These are the correct semantics for a callback to a function with no side-effects.

Exploring io_callback

In contrast to jax.pure_callback(), jax.experimental.io_callback() is explicitly meant to be used with impure functions, i.e. functions that do have side-effects.

As an example, here is a callback to a global host-side numpy random generator. This is an impure operation because a side-effect of generating a random number in numpy is that the random state is updated (Please note that this is meant as a toy example of io_callback and not necessarily a recommended way of generating random numbers in JAX!).

from jax.experimental import io_callback
from functools import partial

global_rng = np.random.default_rng(0)

def host_side_random_like(x):
  """Generate a random array like x using the global_rng state"""
  # We have two side-effects here:
  # - printing the shape and dtype
  # - calling global_rng, thus updating its state
  print(f'generating {x.dtype}{list(x.shape)}')
  return global_rng.uniform(size=x.shape).astype(x.dtype)

@jax.jit
def numpy_random_like(x):
  return io_callback(host_side_random_like, x, x)

x = jnp.zeros(5)
numpy_random_like(x)
generating float32[5]
Array([0.6369617 , 0.26978672, 0.04097353, 0.01652764, 0.8132702 ],      dtype=float32)
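For contrast with the toy example above, here is a sketch of the JAX-native approach, which threads an explicit PRNG key through the computation instead of mutating host-side state. It uses jax.random.key, jax.random.split and jax.random.uniform; the function name jax_random_like is hypothetical. Because the key is an ordinary input, the function is pure and needs no host round-trip:

```python
import jax
import jax.numpy as jnp

@jax.jit
def jax_random_like(key, x):
  # The key is an explicit argument, so there is no hidden state to update.
  return jax.random.uniform(key, shape=x.shape, dtype=x.dtype)

key = jax.random.key(0)
key, subkey = jax.random.split(key)  # consume a fresh subkey for this draw
x = jnp.zeros(5)
sample = jax_random_like(subkey, x)
again = jax_random_like(subkey, x)   # same key, therefore the same numbers
```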

The io_callback is compatible with vmap by default:

jax.vmap(numpy_random_like)(x)
generating float32[]
generating float32[]
generating float32[]
generating float32[]
generating float32[]
Array([0.91275555, 0.60663575, 0.72949654, 0.543625  , 0.9350724 ],      dtype=float32)

Note, however, that this may execute the mapped callbacks in any order. So, for example, if you ran this on a GPU, the order of the mapped outputs might differ from run to run.

If it is important that the order of callbacks be preserved, you can set ordered=True, in which case attempting to vmap will raise an error:

@jax.jit
def numpy_random_like_ordered(x):
  return io_callback(host_side_random_like, x, x, ordered=True)

jax.vmap(numpy_random_like_ordered)(x)
ValueError: Cannot `vmap` ordered IO callback.
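What ordered=True buys you is a guarantee on execution order. Here is a minimal sketch, with hypothetical names record, events and step, in which two ordered io_callbacks log values before and after a computation. Under ordered=True the host-side log should follow program order within a call and across successive calls:

```python
from functools import partial

import jax
from jax.experimental import io_callback

events = []  # hypothetical host-side log

def record(tag, value):
  # Side effect on the host: append a (tag, value) pair; returns None.
  events.append((tag, float(value)))

@jax.jit
def step(x):
  io_callback(partial(record, "in"), None, x, ordered=True)   # no return value
  y = x * 2
  io_callback(partial(record, "out"), None, y, ordered=True)
  return y

step(1.0)
step(3.0)
jax.effects_barrier()  # wait until all callbacks have executed
```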

On the other hand, scan and while_loop work with io_callback regardless of whether ordering is enforced:

def body_fun(_, x):
  return _, numpy_random_like_ordered(x)
jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]
generating float32[]
generating float32[]
generating float32[]
generating float32[]
generating float32[]
Array([0.81585354, 0.0027385 , 0.8574043 , 0.03358557, 0.72965544],      dtype=float32)

Like pure_callback, io_callback fails under automatic differentiation if it is passed a differentiated variable:

jax.grad(numpy_random_like)(x)
ValueError: IO callbacks do not support JVP.

However, if the callback is not dependent on a differentiated variable, it will execute:

@jax.jit
def f(x):
  io_callback(lambda: print('hello'), None)
  return x

jax.grad(f)(1.0);
hello

Unlike pure_callback, the compiler will not remove the callback execution in this case, even though the output of the callback is unused in the subsequent computation.
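To see both behaviors in one place, the sketch below (hypothetical names calls, bump and unused_outputs) places a pure_callback and an io_callback in the same jitted function and discards both outputs. The io_callback is expected to run exactly once. Per the f2 example earlier, the pure_callback may be eliminated, but since that is a compiler decision the check only relies on the io_callback count:

```python
import jax
import numpy as np
from jax.experimental import io_callback

calls = {"pure": 0, "io": 0}  # hypothetical counters updated on the host

def bump(name):
  # Build a host callback that counts how often it actually runs.
  def callback():
    calls[name] += 1
    return np.int32(0)
  return callback

@jax.jit
def unused_outputs(x):
  spec = jax.ShapeDtypeStruct((), np.int32)
  jax.pure_callback(bump("pure"), spec)  # output unused: may be eliminated
  io_callback(bump("io"), spec)          # output unused: still executed
  return x

unused_outputs(1.0)
jax.effects_barrier()
```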

Exploring debug.callback

Both pure_callback and io_callback enforce some assumptions about the purity of the function they’re calling, and limit in various ways what JAX transforms and compilation machinery may do. debug.callback essentially assumes nothing about the callback function, such that the action of the callback reflects exactly what JAX is doing during the course of a program. Further, debug.callback cannot return any value to the program.

from jax import debug

def log_value(x):
  # This could be an actual logging call; we'll use
  # print() for demonstration
  print("log:", x)

@jax.jit
def f(x):
  debug.callback(log_value, x)
  return x

f(1.0);
log: 1.0

The debug callback is compatible with vmap:

x = jnp.arange(5.0)
jax.vmap(f)(x);
log: 0.0
log: 1.0
log: 2.0
log: 3.0
log: 4.0

It is also compatible with grad and other autodiff transformations:

jax.grad(f)(1.0);
log: 1.0

This can make debug.callback more useful for general-purpose debugging than either pure_callback or io_callback.
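As one more debugging pattern, here is a sketch that logs every iteration of a lax.scan through debug.callback, using its ordered=True option so entries arrive in loop order. The names record_step, trace and body are hypothetical; the callback receives the loop input and the running sum at each step:

```python
import jax
import jax.numpy as jnp

trace = []  # hypothetical per-iteration log

def record_step(i, carry):
  trace.append((int(i), float(carry)))

def body(carry, i):
  carry = carry + i
  # ordered=True keeps host-side log entries in iteration order.
  jax.debug.callback(record_step, i, carry, ordered=True)
  return carry, None

total, _ = jax.lax.scan(body, jnp.float32(0.0), jnp.arange(4.0))
jax.effects_barrier()
```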

Example: pure_callback with custom_jvp

One powerful way to take advantage of jax.pure_callback() is to combine it with jax.custom_jvp. (Refer to Custom derivative rules for JAX-transformable Python functions for more details on jax.custom_jvp()).

Suppose you want to create a JAX-compatible wrapper for a scipy or numpy function that is not yet available in the jax.scipy or jax.numpy wrappers.

Here, we’ll consider creating a wrapper for the Bessel function of the first kind, available in scipy.special.jv. You can start by defining a straightforward pure_callback():

import jax
import jax.numpy as jnp
import scipy.special

def jv(v, z):
  v, z = jnp.asarray(v), jnp.asarray(z)

  # Require the order v to be integer type: this simplifies
  # the JVP rule below.
  assert jnp.issubdtype(v.dtype, jnp.integer)

  # Promote the input to inexact (float/complex).
  # Note that jnp.result_type() accounts for the enable_x64 flag.
  z = z.astype(jnp.result_type(float, z.dtype))

  # Wrap scipy function to return the expected dtype.
  _scipy_jv = lambda v, z: scipy.special.jv(v, z).astype(z.dtype)

  # Define the expected shape & dtype of output.
  result_shape_dtype = jax.ShapeDtypeStruct(
      shape=jnp.broadcast_shapes(v.shape, z.shape),
      dtype=z.dtype)

  # Use vmap_method="broadcast_all" because scipy.special.jv handles broadcasted inputs.
  return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vmap_method="broadcast_all")

This lets us call into scipy.special.jv() from transformed JAX code, including when transformed by jit() and vmap():

from functools import partial
j1 = partial(jv, 1)
z = jnp.arange(5.0)
print(j1(z))
[ 0.          0.44005057  0.5767248   0.33905897 -0.06604332]

Here is the same result with jit():

print(jax.jit(j1)(z))
[ 0.          0.44005057  0.5767248   0.33905897 -0.06604332]

And here is the same result again with vmap():

print(jax.vmap(j1)(z))
[ 0.          0.44005057  0.5767248   0.33905897 -0.06604332]

However, if you call grad(), you will get an error because there is no autodiff rule defined for this function:

jax.grad(j1)(z)
ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.

Let’s define a custom gradient rule for this. Looking at the definition of the Bessel function of the first kind, you find that there is a relatively straightforward recurrence relation for the derivative with respect to the argument z:

\[ \frac{d}{dz} J_\nu(z) = \begin{cases} -J_1(z), & \nu = 0 \\ \tfrac{1}{2}\left[J_{\nu - 1}(z) - J_{\nu + 1}(z)\right], & \nu \ne 0 \end{cases} \]

The gradient with respect to \(\nu\) is more complicated, but since the v argument is restricted to integer types, you don’t need to worry about its gradient for the sake of this example.
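Before wiring the recurrence into JAX, it can be sanity-checked on the host. This sketch compares it against scipy.special.jvp, SciPy's derivative of the Bessel function of the first kind, for a few integer orders on a hypothetical grid of z values; the dictionary max_err holds the largest absolute disagreement per order:

```python
import numpy as np
import scipy.special

z = np.linspace(0.5, 10.0, 50)  # hypothetical test grid
max_err = {}
for nu in range(4):
  reference = scipy.special.jvp(nu, z)  # SciPy's dJ_nu/dz
  if nu == 0:
    recurrence = -scipy.special.jv(1, z)
  else:
    recurrence = 0.5 * (scipy.special.jv(nu - 1, z) - scipy.special.jv(nu + 1, z))
  max_err[nu] = float(np.max(np.abs(reference - recurrence)))
print(max_err)
```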

You can use jax.custom_jvp() to define this automatic differentiation rule for your callback function:

jv = jax.custom_jvp(jv)

@jv.defjvp
def _jv_jvp(primals, tangents):
  v, z = primals
  _, z_dot = tangents  # Note: v_dot is always 0 because v is integer.
  jv_minus_1, jv_plus_1 = jv(v - 1, z), jv(v + 1, z)
  djv_dz = jnp.where(v == 0, -jv_plus_1, 0.5 * (jv_minus_1 - jv_plus_1))
  return jv(v, z), z_dot * djv_dz

Now computing the gradient of your function will work correctly:

j1 = partial(jv, 1)
print(jax.grad(j1)(2.0))
-0.06447162

Further, since the gradient is defined in terms of jv itself, JAX’s architecture means that you get second-order and higher derivatives for free:

jax.hessian(j1)(2.0)
Array(-0.4003078, dtype=float32, weak_type=True)

Keep in mind that although this all works correctly with JAX, each call to your callback-based jv function will result in passing the input data from the device to the host, and passing the output of scipy.special.jv() from the host back to the device.

When running on accelerators like GPU or TPU, this data movement and host synchronization can lead to significant overhead each time jv is called.

However, if you are running JAX on a single CPU (where the “host” and “device” are on the same hardware), JAX will generally do this data transfer in a fast, zero-copy fashion, making this pattern a relatively straightforward way to extend JAX’s capabilities.