jax.experimental.host_callback module

Primitives for calling from accelerators to Python functions on the host.

Experimental: please give feedback, and expect changes.

This module introduces the host callback functions id_tap() and id_print(), which behave like the identity function but have the side-effect of sending the arguments from the device to the host and invoking a user-specified Python function (for id_tap()) or printing the arguments on the host (for id_print()). The Python function passed to id_tap() takes two positional arguments: the value tapped from the device computation and the transforms sequence, described below. A few examples:

# calls func(2x, []) on host and returns 2x
y = id_tap(func, 2 * x)
# calls func((2x, 3x), []) and returns (2x, 3x)
y, z = id_tap(func, (2 * x, 3 * x))  # The argument can be a pytree
# calls func(2x, []) and returns y
y = id_tap(func, 2 * x, result=y)  # override the result of id_tap
# calls func(2x, [], what='activation') and returns 2x
y = id_tap(functools.partial(func, what='activation'), 2 * x)
# calls func(dict(x=x, y=y), what='data') and returns dict(x=x, y=y)
d = id_tap(lambda tap, transforms: func(tap, what='data'), dict(x=x, y=y))  # d is a dict; unpacking it would yield its keys

The above examples can all be adapted to use id_print() instead, with the difference that id_print() takes one positional argument (to print on the host), the optional kwarg result, and possibly additional kwargs that are also printed along with the automatic kwarg transforms.
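The tap function itself is an ordinary Python callable, so its behaviour can be checked on the host without any device. The following is a minimal sketch of a collecting tap function, assuming the (arg, transforms) calling convention described above; the names func, records and tap_func are illustrative, and the direct calls at the end stand in for the invocations that id_tap() would trigger.

```python
import functools

# Illustrative host-side storage; not part of the host_callback API.
records = []

def func(arg, transforms, what=None):
    """Tap function: receives the tapped value and the transforms sequence."""
    records.append({"what": what, "arg": arg, "transforms": list(transforms)})

# The callable that would be passed as the first argument of id_tap().
tap_func = functools.partial(func, what="activation")

# Stand-in for the host-side invocations: outside any JAX transformation
# the transforms sequence is empty.
tap_func(6.0, [])
tap_func((6.0, 9.0), [])  # the tapped value may be a pytree, e.g. a tuple
```

Binding extra keyword arguments with functools.partial keeps the tap function reusable across several tap sites.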

The order of execution of the tap functions is constrained by data dependency: the arguments are tapped after all the arguments are computed and before the result of the call is used. As of September 2020, it is not necessary anymore for the results of the tap to be used in the rest of the computation. The tap function will execute based on program order. The host tap functions will be executed for each device in the order in which the send operations were performed on the device.

The host tap functions for multiple devices may be interleaved. The data from the devices is received by separate threads managed by the JAX runtime (one thread per device). The runtime maintains a buffer of configurable size. When the buffer is full, all the receiving threads are paused which eventually pauses the computation on devices. The runtime has one additional thread that invokes the Python user functions with the received data. If the processing of the callbacks is slow, it may actually lead to the runtime buffer filling up, and eventually pausing the computation on the devices when they need to send something. For more details on the outfeed receiver runtime mechanism see runtime code.
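The threading arrangement described above can be pictured with a small pure-Python analogy. This is only a sketch of the idea (one receiver thread per device, a bounded buffer, one thread that runs the callbacks); it is not the runtime's implementation, and the constants and function names are hypothetical.

```python
import queue
import threading

NUM_DEVICES = 2        # hypothetical device count
ITEMS_PER_DEVICE = 5   # hypothetical number of taps per device
BUFFER_SIZE = 3        # stands in for the configurable runtime buffer

buffer = queue.Queue(maxsize=BUFFER_SIZE)
processed = []         # (device, value) pairs in callback order

def receiver(device):
    # One receiver thread per device. put() blocks while the buffer is
    # full, which models the receiving threads being paused.
    for value in range(ITEMS_PER_DEVICE):
        buffer.put((device, value))

def callback_worker():
    # A single thread invokes the user functions with the received data.
    while True:
        item = buffer.get()
        if item is None:  # sentinel: no more data
            break
        processed.append(item)

worker = threading.Thread(target=callback_worker)
receivers = [threading.Thread(target=receiver, args=(d,))
             for d in range(NUM_DEVICES)]
worker.start()
for t in receivers:
    t.start()
for t in receivers:
    t.join()
buffer.put(None)
worker.join()
```

In this model the entries for different devices may interleave in processed, while the entries for any one device keep their send order; a slow callback_worker would leave the queue full and block the receivers.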

Exceptions from the user-defined tap functions are logged along with their stack traces, but the receiving threads are not stopped.

In order to pause the execution until all data from computations already started on devices has arrived and has been processed, use barrier_wait(). This will also raise TapFunctionException if any exception had occurred in one of the tap functions.
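The error-handling contract in the two paragraphs above (failures are logged when they happen and only reported at the barrier) can be modelled in a few lines. This is a toy model under stated assumptions, not the module's code: TapRunner and TapError are hypothetical names, with TapError standing in for TapFunctionException.

```python
import logging

class TapError(Exception):
    """Hypothetical stand-in for TapFunctionException."""

class TapRunner:
    """Toy model: log tap failures, report them only at the barrier."""

    def __init__(self):
        self.failed = False
        self.calls = 0

    def invoke(self, tap_func, arg, transforms=()):
        self.calls += 1
        try:
            tap_func(arg, transforms)
        except Exception:
            # Logged with its stack trace; processing continues.
            logging.exception("tap function failed")
            self.failed = True

    def barrier_wait(self):
        # Report that some earlier tap function failed.
        if self.failed:
            raise TapError("some tap function raised an exception")

def bad_tap(arg, transforms):
    raise ValueError(f"cannot handle {arg}")

seen = []
runner = TapRunner()
runner.invoke(bad_tap, 1.0)
runner.invoke(lambda arg, transforms: seen.append(arg), 2.0)  # still runs

try:
    runner.barrier_wait()
    raised = False
except TapError:
    raised = True
```

The second tap still runs after the first one fails, and the failure surfaces only when barrier_wait() is called.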

The current implementation uses the outfeed mechanism provided by XLA. The mechanism itself is quite primitive in the sense that a receiver must know exactly the shape of each incoming packet, and how many packets are expected. This makes it hard to use for multiple kinds of data in the same computation, and it is practically impossible to use it under conditionals or in loops of non-constant iteration count. Furthermore, code that uses the outfeed mechanism directly cannot be transformed by JAX. All these limitations are addressed by the host callback functions. The tapping API introduced here makes it easy to share the outfeed mechanism for multiple purposes, while supporting all transformations.

Note that after you have used the host callback functions, you cannot use lax.outfeed directly. You may want to stop_outfeed_receiver() if you later need to use lax.outfeed.

We describe the behaviour under transformations in the context of the following function definition:

def power3(x):
   y = x * x
   _, y = id_print((x, y), what="x,x^2")  # Must pack multiple arguments
   return y * x

power3(3.)
# what: x,x^2 : [3., 9.]

During JAX transformations the special parameter transforms is added to contain a list of transformation descriptors in the form (transform_name, transform_params).

For jax.vmap() the arguments are batched, and transforms is extended with transformation name batch and batch_dims set to the tuple of batched dimensions (one entry per argument, None denotes an argument that was broadcast):

jax.vmap(power3)(np.arange(3.))
# transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 : [[0, 1, 2], [0, 1, 4]]

For jax.jvp() there will be two callbacks, one with the values of the primals and one with the tangents:

jax.jvp(power3, (3.,), (0.1,))
# what: x,x^2 : [3., 9.]
# transforms: ['jvp'] what: x,x^2 : [0.1, 0.6]

For jax.vjp() or jax.grad() there will be one callback with the values of the adjoints for the arguments. You may also see a callback with the values of the primals from the forward pass, if those values are needed for the backward pass:

jax.grad(power3)(3.)
# what: x,x^2 : [3., 9.]  # from forward pass, since y is used in backward pass
# transforms: ['jvp', 'transpose'] what: x,x^2 : [0., 3.]  # from backward pass, adjoints of _, y
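A tap function can use the transforms sequence to tell these callbacks apart. The sketch below is one possible way to do that, assuming descriptors may appear either as bare names (as in ['jvp']) or as (name, params) tuples (as in the batch example); transform_names and describe are hypothetical helpers, not part of the module.

```python
def transform_names(transforms):
    """Return the transformation names from a transforms sequence.

    Accepts both bare names ('jvp') and (name, params) tuples, since both
    spellings appear in the example outputs above.
    """
    return [t if isinstance(t, str) else t[0] for t in transforms]

def describe(arg, transforms):
    """Hypothetical tap function that labels each callback by its origin.

    It returns the label so it can be inspected here; a real tap function
    would typically print or record it instead.
    """
    names = transform_names(transforms)
    if "transpose" in names:
        kind = "adjoints"
    elif "jvp" in names:
        kind = "tangents"
    else:
        kind = "primals"
    if "batch" in names:
        kind = "batched " + kind
    return f"{kind}: {arg}"

# Direct calls standing in for the host-side invocations shown above.
label_primal = describe([3.0, 9.0], [])
label_tangent = describe([0.1, 0.6], ["jvp"])
label_adjoint = describe([0.0, 3.0], ["jvp", "transpose"])
label_batched = describe([[0, 1, 2], [0, 1, 4]],
                         [("batch", {"batch_dims": (0, 0)})])
```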

See documentation for id_tap() and id_print(). For more usage examples, see tests/host_callback_test.py.

Still to do:
  • Performance tests.

  • Add flags for logging.

  • Add unit tests with mocks.

  • Explore a simpler API that uses Python program-order, instead of data dependency-order.

  • Explore implementation with outside compilation.

  • Explore an extended API that allows the host function to return values to the accelerator computation.

Low-level details and debugging

The C++ receiver is started automatically on the first call to id_tap(). In order to stop it properly, upon start an atexit handler is registered to call barrier_wait() with the logging name “at_exit”.

There are a few environment variables that you can use to turn on logging for the C++ outfeed receiver backend.

  • TF_CPP_MIN_LOG_LEVEL=0: will turn on INFO logging, needed for all below.

  • TF_CPP_MIN_VLOG_LEVEL=3: will make all VLOG logging up to level 3 behave like INFO logs. This may be too much, but you will see which modules are logging relevant info, and then you can select which modules to log from:

  • TF_CPP_VMODULE=<module_name>=3

You should also use the --verbosity=2 flag so that you see the logs from Python.

For example:

TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,cpu_transfer_manager=3,xfeed_manager=3,pjrt_client=3 python tests/host_callback_test.py --verbosity=2 HostCallbackTest.test_jit_simple
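When the test is launched from Python rather than from a shell, the same variables can be assembled programmatically. This is a small sketch using only the variable names listed above; outfeed_logging_env is a hypothetical helper, and the module list passed to it is just an excerpt of the one in the example command.

```python
import os

def outfeed_logging_env(modules, level=3, base_env=None):
    """Return a copy of the environment with the logging variables set."""
    env = dict(os.environ if base_env is None else base_env)
    env["TF_CPP_MIN_LOG_LEVEL"] = "0"  # INFO logging, needed for the rest
    env["TF_CPP_VMODULE"] = ",".join(f"{m}={level}" for m in modules)
    return env

# base_env={} keeps the example independent of the caller's environment.
env = outfeed_logging_env(["outfeed_receiver", "host_callback"], base_env={})
# The result can be passed as env= to subprocess.run([...]).
```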

API

jax.experimental.host_callback.id_tap(tap_func, arg, *, result=None, **kwargs)

Host-callback tap primitive, like identity function with a call to tap_func.

Experimental: please give feedback, and expect changes!

id_tap behaves semantically like the identity function but has the side-effect that a user-defined Python function is called with the runtime value of the argument.

Parameters
  • tap_func – tap function to call like tap_func(arg, transforms), with arg as described below and where transforms is the sequence of applied JAX transformations in the form (name, params).

  • arg – the argument passed to the tap function, can be a pytree of JAX types.

  • result – if given, specifies the return value of id_tap. This value is not passed to the tap function, and in fact is not sent from the device to the host. If the result parameter is not specified then the return value of id_tap is arg.

Returns

arg, or result if given.

The order of execution is by data dependency: after all the arguments, and the value of result if present, are computed and before the returned value is used. As stated in the module documentation, as of September 2020 the returned values of id_tap no longer need to be used in the rest of the computation for the tap to execute.

If you want to tap a constant value, you should use the result parameter to control when it is tapped, otherwise it will be tapped during tracing of the function:

x = id_tap(tap_func, 42, result=x)

Tapping works even for code executed on accelerators and even for code under JAX transformations. Code that uses taps no longer needs to run inside outfeed_receiver(), which is deprecated; use barrier_wait() to wait until outstanding taps have been processed.

For more details see the module documentation.

jax.experimental.host_callback.id_print(arg, *, result=None, output_stream=None, threshold=None, **kwargs)

Like id_tap() with a printing tap function.

Experimental: please give feedback, and expect changes!

On each invocation of the printing tap, the kwargs if present will be printed first (sorted by keys). Then arg will be printed, with the arrays stringified with numpy.array2string.

See the id_tap() documentation.

Additional keyword arguments:

  • output_stream: if given, it is used instead of the built-in print. The string is passed as output_stream.write(s).

  • threshold is passed to numpy.array2string.
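The printing behaviour described above can be approximated with a plain tap function. The sketch below is an illustration rather than the module's implementation: make_printing_tap and its stringify parameter are hypothetical, str is used in place of numpy.array2string so that the example needs no third-party packages, and the exact text layout is an assumption modelled on the sample outputs in the module documentation.

```python
import io

def make_printing_tap(output_stream=None, stringify=str, **kwargs):
    """Build a tap function that prints kwargs (sorted by key), then arg."""
    def tap(arg, transforms):
        shown = dict(kwargs)
        if transforms:
            # transforms is reported alongside the user-supplied kwargs.
            shown["transforms"] = list(transforms)
        header = " ".join(f"{k}: {shown[k]}" for k in sorted(shown))
        body = stringify(arg)
        s = f"{header} : {body}\n" if header else f"{body}\n"
        if output_stream is not None:
            output_stream.write(s)
        else:
            print(s, end="")
    return tap

stream = io.StringIO()
tap = make_printing_tap(output_stream=stream, what="x,x^2", step=1)
tap([3.0, 9.0], [])
tap([0.1, 0.6], ["jvp"])
lines = stream.getvalue().splitlines()
```

Passing an io.StringIO object as output_stream is a convenient way to capture the printed text in tests.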

jax.experimental.host_callback.outfeed_receiver()

Implements a barrier after a block of code.

DEPRECATED: This function is no longer necessary; it is kept for backwards compatibility. At the moment it implements a barrier_wait after the body of the context manager finishes.

exception jax.experimental.host_callback.TapFunctionException

Signals that some tap function had exceptions.

Raised by barrier_wait(), and therefore also by the deprecated outfeed_receiver(), which performs a barrier_wait().