jax.experimental.host_callback module

Primitives for calling from JAX accelerator code to Python functions on the host.

Experimental: please give feedback, and expect changes.

This module introduces the host callback functions call(), id_tap(), and id_print(), which send their arguments from the device to the host and invoke user-defined Python functions on the host, optionally returning results to the device computation.

We show below how these functions can be used. We start with call(), and we discuss examples of calling from JAX to NumPy CPU custom kernels, or to TensorFlow functions, or to JAX running on another device. In the latter two cases we show how we can support JAX autodiff for the host callbacks, by deferring to the reverse-mode AD on the target platform. Then we show uses of id_tap() and id_print(), which have the restriction that they cannot return values from the host to the device. These primitives are generally faster because they are executed asynchronously with the device code and they also support the whole spectrum of JAX transformations. In particular, they can be used to tap into and to debug JAX-transformed code.

Using call() to call a host function and return results to device

Use call() to invoke a computation on the host and return NumPy arrays to the device computation. Host computation is useful, e.g., when a device computation needs some data that requires I/O on the host, or it needs a library that is available on the host and you do not want to code it in JAX. For example, eigendecomposition for general matrices in JAX does not work on TPU. We can call the NumPy implementation from any JAX accelerator computation, using a host computation:

import jax
import numpy as np
from jax.experimental import host_callback as hcb

# This function runs on the host
def host_eig(m: np.ndarray) -> np.ndarray:
  return np.linalg.eigvals(m)

# This function is used in JAX
def device_fun(m):
  # We send "m" to the host, asking it to call "host_eig" and return the result.
  # We have to specify the result shape and dtype, either in the form of an
  # example return value or any object that has `shape` and `dtype` attributes,
  # e.g., a NumPy array or a `jax.ShapeDtypeStruct`.
  return hcb.call(host_eig, m,
                  # Given an input of shape (..., d, d), eig output has shape (..., d)
                  result_shape=jax.ShapeDtypeStruct(m.shape[:-1], m.dtype))

The call() function and the Python host function both take a single argument and return a single result, but those can be pytrees. Note that we must tell call() what shape and dtype to expect from the host invocation, using the result_shape kwarg. This is important because the device code is compiled with that expectation, and an error will be raised at runtime if the actual invocation produces a different result shape. In general, such errors, and also exceptions raised by the host computation, may be difficult to debug. See the Debugging section below. This is a problem for call() but not for id_tap().
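For example, here is a sketch of a host call whose argument and result are pytrees; host_stats and device_stats are illustrative names, not part of the module:

def host_stats(x):
  # Runs on the host; returns a pytree (here, a dict) of NumPy arrays.
  return {"mean": np.asarray(np.mean(x)), "std": np.asarray(np.std(x))}

def device_stats(x):
  return hcb.call(host_stats, x,
                  # The result_shape pytree must mirror the structure of
                  # the host function's result.
                  result_shape={"mean": jax.ShapeDtypeStruct((), x.dtype),
                                "std": jax.ShapeDtypeStruct((), x.dtype)})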

The call() API can be used inside a jit or pmap computation or inside cond/scan/while control flow. When used inside jax.pmap(), there will be separate calls to the host from each of the participating devices:

def host_sin(x, *, device):
  print(f"Invoking host_sin with {x.shape} on {device}")
  return np.sin(x)

# Use pmap to run the computation on two devices
jax.pmap(lambda x: hcb.call(host_sin, x,
                            result_shape=x,
                            # Ask that the `host_sin` function be passed `device=dev`
                            call_with_device=True))(
         np.ones((2, 4), dtype=np.float32))

# prints (in arbitrary order)
# Invoking host_sin with (4,) on cpu:0
# Invoking host_sin with (4,) on cpu:1

Note that call() does not (yet) support any JAX transformations, but, as we show in the next section, one can make use of the existing support for Custom differentiation in JAX.

Using call() to call a TensorFlow function, with reverse-mode autodiff support

Another possible use for host computation is to invoke a library written for another framework, such as TensorFlow. In this case it becomes interesting to support JAX autodiff for host callbacks by deferring to the autodiff mechanism in TensorFlow, using the jax.custom_vjp() mechanism.

This is relatively easy to do, once one understands both the JAX custom VJP and the TensorFlow autodiff mechanisms. The code for how this can be done is shown in the call_tf_full_ad function in host_callback_to_tf_test.py. This example supports arbitrary higher-order differentiation as well.
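The general pattern can be sketched with NumPy standing in for the other framework. This is not the call_tf_full_ad code itself, just a minimal illustration of pairing jax.custom_vjp() with call(); host_sin_np and its helpers are illustrative names:

import jax
import numpy as np
from jax.experimental import host_callback as hcb

@jax.custom_vjp
def host_sin_np(x):
  # Primal computation, on the host.
  return hcb.call(np.sin, x, result_shape=x)

def host_sin_np_fwd(x):
  # Save the primal input as the residual for the backward pass.
  return host_sin_np(x), x

def host_sin_np_bwd(x, ct):
  # d/dx sin(x) = cos(x); the cotangent is computed on the host as well.
  cos_x = hcb.call(np.cos, x, result_shape=x)
  return (cos_x * ct,)

host_sin_np.defvjp(host_sin_np_fwd, host_sin_np_bwd)

jax.grad(host_sin_np)(1.0)  # approximately cos(1.0)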

Using call() to call a JAX function on another device, with reverse-mode autodiff support

It should not be surprising that we can use host computation to invoke a JAX computation on another device. The arguments are sent from the accelerator to the host, and then to the outside device on which the JAX host computation will run, and then the results are sent back to the original accelerator.

The code for how this can be done is shown in the call_jax_other_device function in host_callback_test.py.
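The idea can be sketched as follows; call_jax_on_host is an illustrative name (the real, more complete code is in the test), and we assume here that jax_fun preserves the shape and dtype of its argument:

import jax
import numpy as np
from jax.experimental import host_callback as hcb

def call_jax_on_host(jax_fun, arg):
  host_device = jax.devices("cpu")[0]
  def host_runner(arg):
    # Runs on the host: re-enter JAX on the chosen device and convert
    # the result back to a NumPy array for the return trip.
    return np.asarray(jax.jit(jax_fun, device=host_device)(arg))
  # Assumes jax_fun's result has the same shape and dtype as arg.
  return hcb.call(host_runner, arg, result_shape=arg)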

Using id_tap() to call a Python function on the host, with no returned values, but full JAX transformation support

id_tap() and id_print() behave like the identity function but have the side-effect of sending the arguments from the device to the host and invoking a user-specified Python function (for id_tap()) or printing the arguments on the host (for id_print()). The Python function passed to id_tap() takes two positional arguments (the value tapped from the device computation, along with the transforms sequence described below). Optionally, the function may be passed a keyword argument device with the Device from which the value was tapped.

A few examples:

# calls func(2x, []) on host and returns 2x
y = id_tap(func, 2 * x)
# calls func((2x, 3x), []) and returns (2x, 3x)
y, z = id_tap(func, (2 * x, 3 * x))  # The argument can be a pytree
# calls func(2x, []) and returns y
y = id_tap(func, 2 * x, result=y)  # override the result of id_tap
# calls func(2x, [], device=jax.devices()[0])
y = id_tap(func, 2 * x, tap_with_device=True)  # Pass the device to the tap
# calls func(2x, [], what='activation') and returns 2x
y = id_tap(functools.partial(func, what='activation'), 2 * x)
# calls func(dict(x=x, y=y), what='data') and returns dict(x=x, y=y)
d = id_tap(lambda tap, transforms: func(tap, what='data'), dict(x=x, y=y))
x, y = d['x'], d['y']

The above examples can all be adapted to use id_print() instead, with the difference that id_print() takes one positional argument (to print on the host), the optional kwarg result, and possibly additional kwargs that are also printed along with the automatic kwarg transforms.
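For instance, adapting the first and third examples above (the exact print format may differ from what is shown in the comments):

# prints something like "what: x : 6.00" and returns 2x
y = id_print(2 * x, what='x')
# prints 2x but returns y
y = id_print(2 * x, what='x', result=y)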

The order of execution of the callback functions is constrained by data dependency: the arguments are tapped after they are computed and before the result of the call is used. As of September 2020, it is no longer strictly necessary for the results of the tap to be used in the rest of the computation. You can just do:

id_tap(func, x)

The tap function will execute based on program order. However, if this code is subject to transformations, the tap may appear to the transformation as dead code and be removed from the computation. In that case it is best to use the result of the callback, as shown below.
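For instance, instead of discarding the return value, thread it back into the computation:

def traced(x):
  # Tying the tap's result into the data flow prevents it from being
  # eliminated as dead code under transformations such as jit.
  x = id_tap(func, x)
  return x * 2.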

Behavior under JAX transformations

We describe the behavior under transformations for id_tap() and id_print() in the context of the following function definition:

def power3(x):
  y = x * x
  _, y = id_print((x, y), what="x,x^2")  # Must pack multiple arguments
  return y * x

power3(3.)
# what: x,x^2 : [3., 9.]

During JAX transformations the special parameter transforms is added, containing a list of transformation descriptors in the form (transform_name, transform_params).

For jax.vmap() the arguments are batched, and transforms is extended with the transformation name batch and batch_dims set to the tuple of batched dimensions (one entry per argument, None denotes an argument that was broadcast):

jax.vmap(power3)(np.arange(3.))
# transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 : [[0, 1, 2], [0, 1, 4]]

For jax.jvp() there will be two callbacks, one with the values of the primals and one with the tangents:

jax.jvp(power3, (3.,), (0.1,))
# what: x,x^2: [3., 9.]
# transforms: ['jvp'] what: x,x^2 : [0.1, 0.6]

For jax.vjp() or jax.grad() there will be one callback with the values of the adjoints for the arguments. You may also see a callback with the values of the primals from the forward pass, if those values are needed for the backward pass:

jax.grad(power3)(3.)
# what=x,x^2: [3., 9.]  # from forward pass, since y is used in backward pass
# transforms: ['jvp', 'transpose'] what: x,x^2 : [0., 3.]  # from backward pass, adjoints of _, y

In the presence of jax.pmap() the code will run on multiple devices and each device will tap its values independently. It may be helpful to use the tap_with_device option for id_print() or id_tap(), so that you see which device is sending which data:

jax.pmap(power3, devices=jax.devices()[0:2])(np.array([3., 4.]))
# device=cpu:0 what=x,x^2: [3., 9.]  # from the first device
# device=cpu:1 what=x,x^2: [4., 16.]  # from the second device

See documentation for id_tap() and id_print(). For more usage examples, see tests/host_callback_test.py.

Low-level details and debugging

The host callback functions will be executed for each device in the order in which the send operations were performed on the device.

The host callback functions for multiple devices may be interleaved. The data from the devices is received by separate threads managed by the JAX runtime (one thread per device). The runtime maintains a buffer of configurable size. When the buffer is full, all the receiving threads are paused, which eventually pauses the computation on the devices. The runtime has one additional thread that invokes the Python user functions with the received data. If the processing of the callbacks is slow, it can fill up the runtime buffer and eventually pause the computation on the devices when they need to send something. For more details on the outfeed receiver runtime mechanism see the runtime code.

In order to pause the execution until all data from computations already started on devices has arrived and has been processed, use barrier_wait(). Note that this is needed only for id_tap() and id_print(), which are processed asynchronously with the device computation.
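For example (a sketch, reusing the power3 definition from above):

power3(3.)                       # the tap is processed asynchronously
hcb.barrier_wait("power3_done")  # block until the tap callback has run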

Exceptions from the user-defined callback functions are logged along with their stack traces, but the receiving threads are not stopped. Instead the last exception is recorded and a subsequent barrier_wait() will raise CallbackException if any exception occurred in one of the tap functions. This exception will include the text and the stack trace of the last exception encountered.

One further complication arises for callback functions that must return results to the call origin device. In order to avoid the device computation being stuck waiting for a result that will never arrive, in case of any error during the processing of the callback (whether raised by the user code itself or due to a mismatch between the returned value and the expected result_shape) we send the device a "fake" result of shape int8[12345]. This will make the device computation abort because the received data is different from the one that it expects. On CPU the runtime will crash with a distinctive error message:

`Check failed: buffer->length() == buffer_length (12345 vs. ...)`

On GPU, the failure is more user-friendly and will be surfaced to the Python program as:

`RET_CHECK failure ... Mismatch between infeed source buffer shape s8[12345] ...`

On TPU, there is currently no shape check for infeed, so we take the safer route of not sending anything in case of errors, and let the computation hang.

The current implementation uses the outfeed mechanism provided by XLA. The mechanism itself is quite primitive in the sense that a receiver must know exactly the shape of each incoming packet, and how many packets are expected. This makes it hard to use for multiple kinds of data in the same computation, and it is practically impossible to use it under conditionals or in loops of non-constant iteration count. Furthermore, code that uses the outfeed mechanism directly cannot be transformed by JAX. All these limitations are addressed by the host callback functions. The tapping API introduced here makes it easy to share the outfeed mechanism for multiple purposes, while supporting all transformations.

Note that after you have used the host callback functions, you cannot use lax.outfeed directly. You may want to stop_outfeed_receiver() if you later need to use lax.outfeed.

Since the actual calls to your callback functions are made from the C++ receiver, it may be hard to debug the calls. In particular, the stack trace will not include the calling code. You can use the flag jax_host_callback_inline (or the environment variable JAX_HOST_CALLBACK_INLINE) to ensure that the calls to the callbacks are inlined. This works only if the calls are outside a staging context (jit or a control-flow primitive).
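For example, assuming the flag can be set through the standard JAX configuration mechanism, like other JAX flags:

from jax.config import config
config.update("jax_host_callback_inline", True)

y = hcb.id_tap(func, x)  # outside jit: the tap function is called inline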

The C++ receiver is started automatically on the first call to id_tap(). In order to stop it properly, an atexit handler is registered at startup; it calls barrier_wait() with the logging name "at_exit".

There are a few environment variables that you can use to turn on logging for the C++ outfeed receiver backend.

  • TF_CPP_MIN_LOG_LEVEL=0: will turn on INFO logging, needed for all below.

  • TF_CPP_MIN_VLOG_LEVEL=3: will make all VLOG logging up to level 3 behave like INFO logs. This may be too much, but you will see which modules are logging relevant info, and then you can select which modules to log from:

  • TF_CPP_VMODULE=<module_name>=3 (the module name can be either C++ or Python, without the extension).

You should also use the --verbosity=2 flag so that you see the logs from Python.

For example: `TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,infeed_thunk=3,cpu_transfer_manager=3,cpu_runtime=3,xfeed_manager=3,pjrt_client=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_jit_simple`

(For Bazel tests, use --test_arg=--vmodule=...)

Still to do:
  • More performance tests.

  • Explore implementation with outside compilation for TPU.

  • Explore implementation with XLA CustomCall for CPU and GPU.

API

jax.experimental.host_callback.id_tap(tap_func: Callable[[T, Sequence[Tuple[str, Dict[str, Any]]]], Any], arg: T) → T[source]
jax.experimental.host_callback.id_tap(tap_func: Callable[[T, Sequence[Tuple[str, Dict[str, Any]]]], Any], arg: T, *, result: U) → U
jax.experimental.host_callback.id_tap(tap_func: Callable[[T, Sequence[Tuple[str, Dict[str, Any]]]], Any], arg: T, *, result: U, tap_with_device: bool) → U

Host-callback tap primitive, like the identity function with a call to tap_func.

Experimental: please give feedback, and expect changes!

id_tap behaves semantically like the identity function but has the side-effect that a user-defined Python function is called with the runtime value of the argument.

Parameters
  • tap_func – tap function to call like tap_func(arg, transforms), with arg as described below and where transforms is the sequence of applied JAX transformations in the form (name, params). If the tap_with_device optional argument is True, then the invocation also includes the device from which the value is tapped as a keyword argument: tap_func(arg, transforms, device=dev).

  • arg – the argument passed to the tap function, can be a pytree of JAX types.

  • result – if given, specifies the return value of id_tap. This value is not passed to the tap function, and in fact is not sent from the device to the host. If the result parameter is not specified then the return value of id_tap is arg.

  • tap_with_device – if True then the tap function is invoked with the device from which the tap originates as a keyword argument.

Returns

arg, or result if given.

The order of execution is determined by data dependency: the tap occurs after all the arguments (and the value of result, if present) are computed, and before the returned value is used. At least one of the returned values of id_tap must be used in the rest of the computation, or else this operation has no effect.

Tapping works even for code executed on accelerators and even for code under JAX transformations.

For more details see the module documentation.
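A small usage sketch (tap_fn is an illustrative name):

def tap_fn(arg, transforms, device=None):
  print(f"tap: {arg}  transforms={transforms}  device={device}")

y = hcb.id_tap(tap_fn, x, tap_with_device=True)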

jax.experimental.host_callback.id_print(arg, *, result=None, tap_with_device=False, output_stream=None, threshold=None, **kwargs)[source]

Like id_tap() with a printing tap function.

Experimental: please give feedback, and expect changes!

On each invocation of the printing tap, the kwargs, if present, will be printed first (sorted by keys). Then arg will be printed, with the arrays stringified with numpy.array2string.

See the id_tap() documentation.

Additional keyword arguments:

  • tap_with_device if True, will print also the device from which the value originates.

  • output_stream if given then it will be used instead of the built-in print. The string will be passed as output_stream.write(s). See the example after this list.

  • threshold is passed to numpy.array2string.
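For example, directing the output to an io.StringIO buffer (a sketch; note the barrier_wait() before reading the buffer, since the tap runs asynchronously):

import io

stream = io.StringIO()
hcb.id_print(x, what="x", output_stream=stream)
hcb.barrier_wait()
print(stream.getvalue())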

jax.experimental.host_callback.call(callback_func, arg, *, result_shape=None, call_with_device=False)[source]

Make a call to the host, and expect a result.

Experimental: please give feedback, and expect changes!

Parameters
  • callback_func (Callable) – The Python function to invoke on the host as callback_func(arg). If the call_with_device optional argument is True, then the invocation also includes the device kwarg with the device from which the call originates: callback_func(arg, device=dev). This function must return a pytree of numpy ndarrays.

  • arg – the argument passed to the callback function, can be a pytree of JAX types.

  • result_shape – a value that describes the expected shape and dtype of the result. This can be a numeric scalar, from which a shape and dtype are obtained, or an object that has .shape and .dtype attributes. If the result of the callback is a pytree, then result_shape should also be a pytree with the same structure. In particular, result_shape can be () or None if the function does not have any results. The device code containing call is compiled with the expected result shape and dtype, and an error will be raised at runtime if the actual callback_func invocation returns a different kind of result.

  • call_with_device – if True then the callback function is invoked with the device from which the call originates as a keyword argument.

Returns

the result of the callback_func invocation.

For more details see the module documentation.
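For example, a call made purely for its side effects can pass result_shape=None (a sketch; host_log is an illustrative name):

def host_log(x):
  # Runs on the host purely for its side effect; returns no result.
  print("host saw:", x)

def device_fun(x):
  hcb.call(host_log, x, result_shape=None)
  return x * 2.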

jax.experimental.host_callback.barrier_wait(logging_name=None)[source]

Blocks the calling thread until all current outfeed is processed.

Waits until all outfeed from computations already running on all devices has been received and processed by the Python callbacks. Raises CallbackException if there were exceptions while processing the callbacks.

This works by enqueueing a special tap computation to all devices to which we are listening for outfeed. Once all those tap computations are done, we return from barrier_wait.

Note: If any of the devices are busy and cannot accept new computations, this will deadlock.

Parameters

logging_name (Optional[str]) – an optional string that will be used in the logging statements for this invocation. See Debugging in the module documentation.

exception jax.experimental.host_callback.CallbackException[source]

Signals that some callback function had exceptions.

Raised by barrier_wait(). See module documentation for details.