jax.experimental.host_callback module#
Primitives for calling Python functions on the host from JAX accelerator code.
Experimental: please give feedback, and expect changes.
This module introduces the host callback functions call(), id_tap(), and id_print(), which send their arguments from the device to the host and invoke user-defined Python functions on the host, optionally returning results back to the device computation.
We show below how these functions can be used. We start with call(), and we discuss examples of calling from JAX to arbitrary Python functions on the CPU, e.g., to use NumPy CPU custom kernels. Then we show uses of id_tap() and id_print(), which have the restriction that they cannot return values from the host to the device. These primitives are generally faster because they are executed asynchronously with the device code. In particular, they can be used to tap into and to debug JAX code.
Using call() to call a host function and return results to device#
Use call() to invoke a computation on the host and return NumPy arrays to the device computation.
Host computation is useful, e.g., when a device computation needs some data that requires I/O on the host, or it needs a library that is available on the host and you do not want to code it in JAX.
For example, eigendecomposition for general matrices in JAX does not work on TPU. We can call the NumPy implementation from any JAX accelerator computation, using a host computation:
# This function runs on the host
def host_eig(m: np.ndarray) -> np.ndarray:
  return np.linalg.eigvals(m)

# This function is used in JAX
def device_fun(m):
  # We send "m" to the host, asking it to call "host_eig" and return the result.
  # We have to specify the result shape and dtype, either in the form of an
  # example return value or any object that has `shape` and `dtype` attributes,
  # e.g., a NumPy array or a `jax.ShapeDtypeStruct`.
  return hcb.call(host_eig, m,
                  # Given an input of shape (..., d, d), eig output has shape (..., d)
                  result_shape=jax.ShapeDtypeStruct(m.shape[:-1], m.dtype))
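For instance, assuming the imports used throughout this page (numpy as np, jax, and jax.experimental.host_callback as hcb), the function above could be invoked under jit roughly as follows; the example input is diagonal so that the eigenvalues are real and match the declared result dtype:

m = np.diag(np.array([1., 2., 3.], dtype=np.float32))
print(jax.jit(device_fun)(m))
# prints something like [1. 2. 3.] (the order of the eigenvalues may vary)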
The call() function and the Python host function both take a single argument and return a single result, but those can be pytrees. Note that we must tell the call() what shape and dtype to expect from the host invocation, using the result_shape keyword argument. This is important because the device code is compiled with that expectation. There will be an error raised at runtime if the actual invocation produces a different result shape. In general, such errors and also exceptions raised by the host computation may be difficult to debug. See the Debugging section below.
This is a problem for call() but not for id_tap() because for the latter the device code does not expect a returned value.
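Returning to the pytree case: below is a hedged sketch (the helper names host_min_max and device_min_max are hypothetical) in which the host function receives and returns a dict, and result_shape mirrors the returned structure:

def host_min_max(d):
  # Runs on the host with NumPy inputs; returns a dict of scalars.
  return {"min": np.min(d["x"]), "max": np.max(d["x"])}

def device_min_max(x):
  scalar = jax.ShapeDtypeStruct((), x.dtype)
  return hcb.call(host_min_max, {"x": x},
                  # The result_shape pytree must match the structure returned by the host.
                  result_shape={"min": scalar, "max": scalar})

jax.jit(device_min_max)(np.arange(4., dtype=np.float32))  # a dict with the min and max of x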
The call() API can be used inside a jit or pmap computation or inside cond/scan/while control flow. When used inside jax.pmap(), there will be separate calls to the host from each of the participating devices:
def host_sin(x, *, device):
  # The ``device`` argument is passed due to ``call_with_device=True`` below.
  print(f"Invoking host_sin with {x.shape} on {device}")
  return np.sin(x)

# Use pmap to run the computation on two devices
jax.pmap(lambda x: hcb.call(host_sin, x,
                            result_shape=x,
                            # Ask that the `host_sin` function be passed `device=dev`
                            call_with_device=True))(
  np.ones((2, 4), dtype=np.float32))

# prints (in arbitrary order)
# Invoking host_sin with (4,) on cpu:0
# Invoking host_sin with (4,) on cpu:1
Note that call() does not support any JAX transformations, but as we show below one can make use of the existing support for Custom differentiation in JAX.
Using id_tap() to call a Python function on the host, with no returned values#
id_tap() and id_print() are special cases of call(), for when you just want the side effects of your Python callback. These functions have the advantage that once the arguments have been sent to the host, the device computation can proceed without waiting for the Python callback to return.
For id_tap() you can specify your Python callback to be called, while id_print() uses a built-in callback that prints the arguments to stdout on the host.
The Python function passed to id_tap() takes two positional arguments (the value tapped from the device computation along with a transforms tuple, described below). Optionally, the function may be passed a keyword argument device with the Device from which the value was tapped.
A few examples:
def host_func(arg, transforms):
  ...  # do something with arg

# calls host_func(2x, []) on the host
id_tap(host_func, 2 * x)

# calls host_func((2x, 3x), [])
id_tap(host_func, (2 * x, 3 * x))  # The argument can be a pytree

# calls host_func(2x, [], device=jax.devices()[0])
id_tap(host_func, 2 * x, tap_with_device=True)  # Pass the device to the tap

# calls host_func(2x, [], what='activation')
id_tap(functools.partial(host_func, what='activation'), 2 * x)

# calls host_func(dict(x=x, y=y), what='data')
id_tap(lambda tap, transforms: host_func(tap, what='data'), dict(x=x, y=y))
The above examples can all be adapted to use id_print() instead, with the difference that id_print() prints on the host the positional argument, along with any additional kwargs and the automatic kwarg transforms.
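For example, with x a JAX value as in the examples above (a small sketch; the exact formatting of the printed output may differ slightly):

# prints on the host something like:
#   what: x,2x : (3., 6.)
id_print((x, 2. * x), what="x,2x")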
Using barrier_wait() to wait until all callbacks have executed#
If your Python callbacks have side effects, you may need to wait until the computation has finished to ensure that the side effects have been observed. You can use the barrier_wait() function for that purpose:
accumulator = []

def host_log(arg, transforms):
  # We just record the arguments in a list
  accumulator.append(arg)

def device_fun(x):
  id_tap(host_log, x)
  id_tap(host_log, 2. * x)

jax.jit(device_fun)(1.)
jax.jit(device_fun)(1.)

# At this point, we have started two computations, each with two
# taps, but they may not have yet executed.
barrier_wait()
# Now we know that all the computations started before `barrier_wait`
# on all devices have finished, and all the callbacks have finished
# executing.
Note that barrier_wait() will start one tiny computation with one tap on each of the jax.local_devices() and will wait for all these taps to be received.
An alternative to using barrier_wait() is to just wait for the end of the computation, if all the callbacks are call():
accumulator = []

def host_log(arg):
  # We just record the arguments in a list
  accumulator.append(arg)
  return 0.  # return something

def device_fun(x):
  y = call(host_log, x, result_shape=jax.ShapeDtypeStruct((), np.float32))
  z = call(host_log, 2. * x, result_shape=jax.ShapeDtypeStruct((), np.float32))
  return y + z  # return something that uses both results

res1 = jax.jit(device_fun)(1.)
res2 = jax.jit(device_fun)(1.)
res1.block_until_ready()
res2.block_until_ready()
Behavior under parallelization transformations#
In the presence of jax.pmap() the code will run on multiple devices and each device will tap its values independently. It may be helpful to use the tap_with_device option for id_print() or id_tap(), so that you see which device is sending which data:
jax.pmap(power3, devices=jax.local_devices()[:2])(np.array([3., 4.]))
# device=cpu:0 what=x,x^2: (3., 9.)   # from the first device
# device=cpu:1 what=x,x^2: (4., 16.)  # from the second device
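The power3 function used here is the one defined in the autodiff section below; to produce the device= prefix shown in the output, it would be tapped with tap_with_device=True, roughly as in this sketch:

def power3(x):
  y = x * x
  # Print both 'x' and 'x^2', and ask for the originating device to be shown.
  hcb.id_print((x, y), what="x,x^2", tap_with_device=True)
  return y * x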
When using jax.pmap() with multiple devices on multiple hosts, every host will receive callbacks from all of its local devices, with an operand that corresponds to each device slice. For a call(), the callback must return to each device only the slice of the result that pertains to the corresponding device.
When using the experimental pjit.pjit() the code will run on multiple devices on different shards of the input. The current implementation of host callbacks will ensure that a single device will collect and outfeed the entire operand, in a single callback. The callback function is supposed to return the entire array, which will then be sent in a single infeed to the same device that issued the outfeed. This device is then responsible for sending the required shards to the other devices:
with jax.sharding.Mesh(jax.local_devices()[:2], ["d"]):
  pjit.pjit(power3, in_shardings=(P("d"),),
            out_shardings=(P("d"),))(np.array([3., 4.]))

# device=TPU:0 what=x,x^2: ( [3., 4.],
#                            [9., 16.] )
Note that the collection of the operand on one device may result in OOM if the operand was sharded across devices.
When using pjit.pjit() with multiple devices on multiple hosts, only the host of device 0 (with respect to the mesh) will receive the callback, with the operand collected from all participating devices on all hosts. For a call(), the callback must return the entire array for all devices on all hosts.
Behavior under JAX autodiff transformations#
When used under a JAX autodiff transformation, the host callback functions operate on the primal values only. Consider the following example:
def power3(x):
  y = x * x
  # Print both 'x' and 'x^2'. Must pack as a tuple.
  hcb.id_print((x, y), what="x,x^2")
  return y * x

power3(3.)
# what: x,x^2 : (3., 9.)
(You can see these examples tested in host_callback_test.HostCallbackTapTest.test_tap_transforms.)
When used under jax.jvp() there will be one callback with the primal values only:
jax.jvp(power3, (3.,), (0.1,))
# what: x,x^2 : (3., 9.)
Similarly for jax.grad(), we get a callback from the forward computation only:
jax.grad(power3)(3.)
# what: x,x^2 : (3., 9.)
If you want to invoke the callback on the tangents during a jax.jvp(), you can use a custom_jvp. For example, you can define a function that does nothing interesting except that its custom_jvp will print the tangents:
@jax.custom_jvp
def print_tangents(arg):
  return None

@print_tangents.defjvp
def print_tangents_jvp(primals, tangents):
  arg_dot, = tangents
  hcb.id_print(arg_dot, what="tangents")
  return primals, tangents
Then you use this function in the places where you want to tap the tangents:
def power3_with_tangents(x):
  y = x * x
  # Print both 'x' and 'x^2'. Must pack as a tuple.
  hcb.id_print((x, y), what="x,x^2")
  print_tangents((x, y))
  return y * x

jax.jvp(power3_with_tangents, (3.,), (0.1,))
# what: x,x^2 : (3., 9.)
# what: tangents : (0.1, 0.6)
You can do a similar thing for the cotangents during jax.grad(). This time you must be careful to use in the rest of the computation the values whose cotangents you want to tap. Hence we make the print_cotangents return its argument:
@jax.custom_vjp
def print_cotangents(arg):
  # Must return the argument for which we want the cotangent.
  return arg

# f_fwd: a -> (b, residual)
def print_cotangents_fwd(arg):
  return print_cotangents(arg), None

# f_bwd: (residual, CT b) -> [CT a]
def print_cotangents_bwd(residual, ct_b):
  hcb.id_print(ct_b, what="cotangents", output_stream=testing_stream)
  return ct_b,

print_cotangents.defvjp(print_cotangents_fwd, print_cotangents_bwd)

def power3_with_cotangents(x):
  y = x * x
  # Print both 'x' and 'x^2'. Must pack as a tuple.
  hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream)
  (x1, y1) = print_cotangents((x, y))
  # Must use the output of print_cotangents
  return y1 * x1

jax.grad(power3_with_cotangents)(3.)
# what: x,x^2 : (3., 9.)
# what: cotangents : (9., 3.)
If you use ad_checkpoint.checkpoint() to rematerialize the residuals for the backward pass, then the callbacks from the primal computation will be called twice:
jax.grad(lambda x: power3(ad_checkpoint.checkpoint(power3)(x)))(3.)
# what: x,x^2 : (3., 9.)
# what: x,x^2 : (27., 729.)
# what: x,x^2 : (3., 9.)
The callbacks are, in order, from: the primal computation of the inner power3, the primal computation of the outer power3, and the rematerialization of the residuals for the inner power3.
Behavior under jax.vmap#
The host callback functions id_print() and id_tap() support the vectorization transformation jax.vmap().
For jax.vmap() the arguments to the callback are batched, and the callback function is passed an additional special transforms argument containing a list of transformation descriptors in the form ("batch", {"batch_dims": ...}), where ... denotes the batched dimensions for the tapped values (one entry per argument; None denotes an argument that was broadcast).
jax.vmap(power3)(np.array([2., 3.]))
# transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 : ([2., 3.], [4., 9.])
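To see a None entry for a broadcast argument, consider the following hedged sketch (add_and_tap is a hypothetical helper, and the printed formatting may differ slightly):

def add_and_tap(x, y):
  hcb.id_print((x, y), what="x,y")
  return x + y

jax.vmap(add_and_tap, in_axes=(0, None))(np.array([2., 3.]), 10.)
# transforms: [('batch', {'batch_dims': (0, None)})] what: x,y : ([2., 3.], 10.)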
See the documentation for id_tap(), id_print(), and call().
For more usage examples, see tests/host_callback_test.py.
Using call() to call a TensorFlow function, with reverse-mode autodiff support#
Another possible use for host computation is to invoke a library written for another framework, such as TensorFlow.
In this case it becomes interesting to support JAX autodiff for host callbacks by deferring to the autodiff mechanism in TensorFlow, using the jax.custom_vjp() mechanism.
This is relatively easy to do, once one understands both the JAX custom VJP and the TensorFlow autodiff mechanisms. The code for how this can be done is shown in the call_tf_full_ad function in host_callback_to_tf_test.py. This example supports arbitrary higher-order differentiation as well.
Note that if you just want to call TensorFlow functions from JAX, you can also use the jax2tf.call_tf function.
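To give a rough idea of the forward pass only, here is a hedged sketch (this is not the call_tf_full_ad implementation; the name call_tf_no_ad is hypothetical, and the sketch assumes the TensorFlow function maps one array to one array of the same shape and dtype). The tested version additionally attaches a jax.custom_vjp whose backward rule defers to TensorFlow's autodiff on the host:

import numpy as np
import jax
import tensorflow as tf
from jax.experimental import host_callback as hcb

def call_tf_no_ad(tf_fun, arg):
  # Run `tf_fun` eagerly on the host and return a NumPy array with the same
  # shape and dtype as the argument.
  def host_fun(x):
    return np.asarray(tf_fun(tf.convert_to_tensor(x)))
  return hcb.call(host_fun, arg,
                  result_shape=jax.ShapeDtypeStruct(arg.shape, arg.dtype))

# For example, call tf.math.sin from a jitted JAX computation.
jax.jit(lambda x: call_tf_no_ad(tf.math.sin, x))(np.float32(1.))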
Using call() to call a JAX function on another device, with reverse-mode autodiff support#
It should not be surprising that we can use host computation to invoke a JAX computation on another device. The arguments are sent from the accelerator to the host, and then to the outside device on which the JAX host computation will run, and then the results are sent back to the original accelerator.
The code for how this can be done is shown in the call_jax_other_device function in host_callback_test.py.
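A minimal sketch of the forward pass is below (the helper name call_jax_other_device_fwd is hypothetical; the tested call_jax_other_device additionally adds custom-VJP support for reverse-mode autodiff):

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import host_callback as hcb

def call_jax_other_device_fwd(jax_outside_fun, arg, *, device):
  # From the device computation, send `arg` to the host, run `jax_outside_fun`
  # as a JAX computation on `device`, and send the result back.
  def run_outside(x):
    return np.asarray(jax.jit(jax_outside_fun, device=device)(x))
  # The traced argument works as `result_shape` because it has `shape` and `dtype`,
  # assuming the outside function preserves both.
  return hcb.call(run_outside, arg, result_shape=arg)

# For example, compute jnp.sin on the host CPU from an accelerator computation.
jax.jit(lambda x: call_jax_other_device_fwd(
    jnp.sin, x, device=jax.devices("cpu")[0]))(np.float32(1.))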
Low-level details and debugging#
The host callback functions will be executed for each device in the order in which the send operations were performed on the device.
The host callback functions for multiple devices may be interleaved.
The data from the devices is received by separate threads managed by the JAX runtime (one thread per device). The runtime maintains a buffer of configurable size (see the flag --jax_host_callback_max_queue_byte_size). When the buffer is full, all the receiving threads are paused, which eventually pauses the computation on the devices. The runtime has one additional thread per device to invoke the Python user functions with the received data. If the processing of the callbacks is slow, it may actually lead to the runtime buffer filling up, and eventually pausing the computation on the devices when they need to send something.
For more details on the outfeed receiver runtime mechanism see the runtime code.
In order to pause the execution until all data from computations already started on devices has arrived and has been processed, use barrier_wait().
Exceptions from the user-defined callback functions are logged along with their stack traces, but the receiving threads are not stopped. Instead the last exception is recorded and the subsequent barrier_wait() will raise CallbackException if any exception had occurred in one of the tap functions. This exception will include the text and the stack trace of the last exception encountered.
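For example, a hedged sketch of surfacing such an error (failing_tap is a hypothetical callback; hcb is jax.experimental.host_callback):

def failing_tap(arg, transforms):
  raise ValueError("something went wrong in the callback")

jax.jit(lambda x: hcb.id_tap(failing_tap, x))(1.)
try:
  hcb.barrier_wait()
except hcb.CallbackException as e:
  # The exception includes the text and stack trace of the last tap failure.
  print("A callback failed:", e)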
One further complication arises for callback functions that must return results to the call origin device, such as call(). This is handled differently on CPU/GPU devices compared to TPU devices.
On CPU/GPU devices, in order to avoid the device computation being stuck waiting for a result that will never arrive, in case of any error during the processing of the callback (whether raised by the user code itself or due to a mismatch between the returned value and the expected result_shape) we send the device a "fake" result of shape int8[12345]. This will make the device computation abort, because the received data is different from what it expects. On CPU the runtime will crash with a distinctive error message:
Check failed: buffer->length() == buffer_length (12345 vs. ...)
On GPU, the failure is more user-friendly and will be surfaced to the Python program as:
RET_CHECK failure ... Mismatch between infeed source buffer shape s8[12345] ...
To debug the underlying cause for these messages, see the Debugging section.
On TPU devices, there is currently no shape check for infeed, so we take the safer route of not sending this fake result in case of errors. This means that the computation will hang, and no exception will be raised (but any exceptions in the callback functions will still appear in the logs).
The current implementation uses the outfeed mechanism provided by XLA. The mechanism itself is quite primitive in the sense that a receiver must know exactly the shape of each incoming packet, and how many packets are expected. This makes it hard to use for multiple kinds of data in the same computation, and it is practically impossible to use it under conditionals or in loops of non-constant iteration count. Furthermore, code that uses the outfeed mechanism directly cannot be transformed by JAX. All these limitations are addressed by the host callback functions. The tapping API introduced here makes it easy to share the outfeed mechanism for multiple purposes, while supporting all transformations.
Note that after you have used the host callback functions, you cannot use lax.outfeed directly. You may want to stop_outfeed_receiver() if you later need to use lax.outfeed.
Since the actual calls to your callback functions are made from the C++ receiver, it may be hard to debug the calls. In particular, the stack trace will not include the calling code. You can use the flag jax_host_callback_inline (or the environment variable JAX_HOST_CALLBACK_INLINE) to ensure that the calls to the callbacks are inlined. This works only if the calls are outside a staging context (jit() or a control-flow primitive).
The C++ receiver is started automatically on the first call to id_tap(). In order to stop it properly, upon start an atexit handler is registered to call barrier_wait() with the logging name "at_exit".
There are a few environment variables that you can use to turn on logging for the C++ outfeed receiver backend.
- TF_CPP_MIN_LOG_LEVEL=0: will turn on INFO logging, needed for all below.
- TF_CPP_MIN_VLOG_LEVEL=3: will make all VLOG logging up to level 3 behave like INFO logs. This may be too much, but you will see which modules are logging relevant info, and then you can select which modules to log from.
- TF_CPP_VMODULE=<module_name>=3 (the module name can be either C++ or Python, without the extension).
You should also use the --verbosity=2 flag so that you see the logs from Python.
For example, you can try to enable logging in the host_callback module:
TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=host_callback=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple
If you want to enable logging in lower-level implementation modules try:
TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,infeed_thunk=3,cpu_transfer_manager=3,cpu_runtime=3,xfeed_manager=3,pjrt_client=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple
(For bazel tests use --test_arg=--vmodule=...)
Still to do:
- More performance tests.
- Explore implementation with outside compilation for TPU.
- Explore implementation with XLA CustomCall for CPU and GPU.
API#
id_tap(): Host-callback tap primitive, like the identity function with a call to the tap function.
id_print(): Like id_tap(), but with a built-in printing tap function.
call(): Make a call to the host, and expect a result.
barrier_wait(): Blocks the calling thread until all current outfeed is processed.
CallbackException: Signals that some callback function had exceptions.