External Callbacks in JAX

This guide is a work in progress outlining the uses of various callback functions, which allow JAX code to execute certain commands on the host, even while running under jit, vmap, grad, or another transformation. It will be updated soon.

TODO(jakevdp, sharadmv): fill in some simple examples of jax.pure_callback(), jax.debug.callback(), jax.debug.print(), and others.
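As a minimal sketch of two of the callbacks named above (not part of the original guide), the jitted function below uses jax.debug.print() to print a runtime value and jax.debug.callback() to hand that value to an ordinary host-side Python function. The helper `record` and the list `seen` are illustrative names only; jax.effects_barrier() is used to wait for pending host callbacks before inspecting the list.

```python
import jax
import jax.numpy as jnp

# Host-side storage for values reported by the callback (illustrative only).
seen = []

def record(y):
    # Called on the host with a concrete value, even under jit.
    seen.append(float(y))

@jax.jit
def f(x):
    y = jnp.sin(x)
    # Prints the runtime value of y; a plain print() would only show a tracer.
    jax.debug.print("y = {y}", y=y)
    # Passes the runtime value of y to an arbitrary host function.
    jax.debug.callback(record, y)
    return 2 * y

out = f(1.0)
# Debug callbacks may run asynchronously; wait for them before reading `seen`.
jax.effects_barrier()
print(seen)
```

Note that these debugging callbacks are for side effects such as logging; to compute values on the host and feed them back into the JAX computation, the guide uses jax.pure_callback() instead.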

Example: pure_callback with custom_jvp

One powerful way to take advantage of jax.pure_callback() is to combine it with jax.custom_jvp (see Custom derivative rules for more details on custom_jvp). Suppose we want to create a JAX-compatible wrapper for a scipy or numpy function that is not yet available in the jax.scipy or jax.numpy wrappers.
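Before the full example, here is a minimal sketch of the basic pure_callback mechanics, assuming only NumPy on the host. It wraps np.cumsum (which jax.numpy already provides, so the result is easy to check) purely to show the three ingredients: a host function that receives and returns concrete arrays, a jax.ShapeDtypeStruct describing the output, and the arguments to pass through. The name `host_cumsum` is hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_cumsum(x):
    # Runs on the host: `x` is a concrete array here, not a JAX tracer.
    # Cast so the result matches the declared output dtype exactly.
    return np.cumsum(x).astype(x.dtype)

@jax.jit
def f(x):
    # Declare the output shape and dtype so JAX can trace f without running NumPy.
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_cumsum, out_type, x)

x = jnp.arange(1.0, 5.0)
print(f(x))
```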

Here, we’ll consider creating a wrapper for the Bessel function of the first kind, implemented in scipy.special.jv. We can start by defining a straightforward pure_callback:

import jax
import jax.numpy as jnp
import scipy.special

def jv(v, z):
  v, z = jnp.asarray(v), jnp.asarray(z)

  # Require the order v to be integer type: this simplifies
  # the JVP rule below.
  assert jnp.issubdtype(v.dtype, jnp.integer)

  # Promote the input to inexact (float/complex).
  # Note that jnp.result_type() accounts for the enable_x64 flag.
  z = z.astype(jnp.result_type(float, z.dtype))

  # Wrap scipy function to return the expected dtype.
  _scipy_jv = lambda v, z: scipy.special.jv(v, z).astype(z.dtype)

  # Define the expected shape & dtype of output.
  result_shape_dtype = jax.ShapeDtypeStruct(
      shape=jnp.broadcast_shapes(v.shape, z.shape),
      dtype=z.dtype)

  # We use vectorized=True because scipy.special.jv handles broadcasted inputs.
  return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)
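To make the shape and dtype bookkeeping above concrete, here is a small sketch, with arbitrary example inputs, of the jax.ShapeDtypeStruct that jv declares for a scalar integer order v and an integer-valued argument z of shape (5,).

```python
import jax
import jax.numpy as jnp

v = jnp.asarray(1)   # scalar integer order
z = jnp.arange(5)    # integer-valued argument of shape (5,)

# Promote z to an inexact dtype: float32 by default, or float64 when
# the jax_enable_x64 flag is set.
z_dtype = jnp.result_type(float, z.dtype)

# Declared output: the broadcast shape of v and z, with z's promoted dtype.
spec = jax.ShapeDtypeStruct(jnp.broadcast_shapes(v.shape, z.shape), z_dtype)
print(spec)
```

Because the callback's return value must match this declaration, the wrapper casts SciPy's output with .astype(z.dtype) rather than relying on SciPy's default float64 result.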

This lets us call scipy.special.jv from JAX code, including code transformed by jit and vmap:

from functools import partial
j1 = partial(jv, 1)
z = jnp.arange(5.0)
print(j1(z))
[ 0.          0.44005057  0.5767248   0.33905897 -0.06604332]

Here is the same result with jit:

print(jax.jit(j1)(z))
[ 0.          0.44005057  0.5767248   0.33905897 -0.06604332]

And here is the same result again with vmap:

print(jax.vmap(j1)(z))
[ 0.          0.44005057  0.5767248   0.33905897 -0.06604332]

However, if we call jax.grad, we see an error because there is no autodiff rule defined for this function:

jax.grad(j1)(z)
ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.

Let’s define a custom gradient rule for this. Looking at the definition of the Bessel function of the first kind, we find a relatively straightforward recurrence relation for the derivative with respect to the argument z:

\[ \frac{d}{dz} J_\nu(z) = \begin{cases} -J_1(z), & \nu = 0 \\ \frac{1}{2}\left[J_{\nu - 1}(z) - J_{\nu + 1}(z)\right], & \nu \ne 0 \end{cases} \]
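As a quick host-side sanity check of this recurrence (not part of the original guide), the sketch below compares both branches against scipy.special.jvp, which evaluates derivatives of \(J_\nu\) directly. It assumes NumPy and SciPy are installed; the range of orders and the test points are arbitrary.

```python
import numpy as np
from scipy.special import jv, jvp

z = np.linspace(0.5, 10.0, 20)  # arbitrary test points

errors = []
for nu in range(5):
    if nu == 0:
        # Special case: dJ_0/dz = -J_1(z)
        recurrence = -jv(1, z)
    else:
        # General case: dJ_nu/dz = [J_{nu-1}(z) - J_{nu+1}(z)] / 2
        recurrence = 0.5 * (jv(nu - 1, z) - jv(nu + 1, z))
    # jvp(nu, z) is SciPy's first derivative of J_nu with respect to z.
    errors.append(float(np.max(np.abs(recurrence - jvp(nu, z)))))

print(max(errors))
```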

The gradient with respect to \(\nu\) is more complicated, but since we’ve restricted the v argument to integer types, we don’t need to worry about it for the sake of this example.

We can use jax.custom_jvp to define this automatic differentiation rule for our callback function:

jv = jax.custom_jvp(jv)

@jv.defjvp
def _jv_jvp(primals, tangents):
  v, z = primals
  _, z_dot = tangents  # Note: v_dot is always 0 because v is integer.
  jv_minus_1, jv_plus_1 = jv(v - 1, z), jv(v + 1, z)
  djv_dz = jnp.where(v == 0, -jv_plus_1, 0.5 * (jv_minus_1 - jv_plus_1))
  return jv(v, z), z_dot * djv_dz

Now computing the gradient of our function will work correctly:

j1 = partial(jv, 1)
print(jax.grad(j1)(2.0))
-0.06447162

Further, since we’ve defined our gradient in terms of jv itself, JAX can differentiate the JVP rule in turn, which means we get second-order and higher derivatives for free:

jax.hessian(j1)(2.0)
Array(-0.4003078, dtype=float32, weak_type=True)
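As an independent check (not part of the original guide), the sketch below compares these autodiff results with scipy.special.jvp, which evaluates Bessel-function derivatives directly on the host. It restates the wrapper compactly so it runs on its own, omits vectorized=True because it never calls vmap, and nests jax.grad for the second derivative, which for a scalar-to-scalar function agrees with jax.hessian.

```python
from functools import partial

import jax
import jax.numpy as jnp
import scipy.special

@jax.custom_jvp
def jv(v, z):
    # Compact restatement of the callback-based wrapper from this guide.
    v, z = jnp.asarray(v), jnp.asarray(z)
    assert jnp.issubdtype(v.dtype, jnp.integer)
    z = z.astype(jnp.result_type(float, z.dtype))
    _scipy_jv = lambda v, z: scipy.special.jv(v, z).astype(z.dtype)
    out_type = jax.ShapeDtypeStruct(jnp.broadcast_shapes(v.shape, z.shape), z.dtype)
    return jax.pure_callback(_scipy_jv, out_type, v, z)

@jv.defjvp
def _jv_jvp(primals, tangents):
    v, z = primals
    _, z_dot = tangents  # v is an integer, so its tangent is ignored
    jv_minus_1, jv_plus_1 = jv(v - 1, z), jv(v + 1, z)
    djv_dz = jnp.where(v == 0, -jv_plus_1, 0.5 * (jv_minus_1 - jv_plus_1))
    return jv(v, z), z_dot * djv_dz

j1 = partial(jv, 1)
z0 = 2.0

# First and second derivatives via JAX autodiff through the callback.
d1 = jax.grad(j1)(z0)
d2 = jax.grad(jax.grad(j1))(z0)

# Reference values computed directly by SciPy on the host.
ref1 = scipy.special.jvp(1, z0, n=1)
ref2 = scipy.special.jvp(1, z0, n=2)

print(d1, ref1)
print(d2, ref2)
```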

Keep in mind that although this all works correctly with JAX, each call to our callback-based jv function passes the input data from the device to the host and passes the output of scipy.special.jv from the host back to the device. When running on accelerators like GPU or TPU, this data movement and host synchronization can add significant overhead each time jv is called. However, if you are running JAX on a single CPU (where the “host” and “device” are the same hardware), JAX will generally do this data transfer in a fast, zero-copy fashion, making this pattern a relatively straightforward way to extend JAX’s capabilities.
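If you are unsure which of these situations applies, a minimal check is to ask JAX for its default backend. The sketch below uses jax.default_backend(), which returns a platform name such as "cpu", "gpu", or "tpu"; the printed messages simply paraphrase the paragraph above and are illustrative only.

```python
import jax

# Platform of the default backend, e.g. "cpu", "gpu", or "tpu".
backend = jax.default_backend()

if backend == "cpu":
    print("Host and device share hardware: callback transfers are generally cheap.")
else:
    print(f"Running on {backend}: each callback round-trips data through the host.")
```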