jax.experimental.host_callback.call(callback_func, arg, *, result_shape=None, call_with_device=False, device_index=0)
Make a call to the host, and expect a result.
Experimental: please give feedback, and expect changes!
Parameters:
callback_func (Callable) – The Python function to invoke on the host as callback_func(arg). If the optional call_with_device argument is True, then the invocation also includes the device kwarg with the device from which the call originates: callback_func(arg, device=dev). This function must return a pytree of numpy ndarrays.
arg – the argument passed to the callback function. It can be a pytree of JAX types.
result_shape – a value that describes the expected shape and dtype of the result. This can be a numeric scalar, from which a shape and dtype are obtained, or an object that has .shape and .dtype attributes. If the result of the callback is a pytree, then result_shape should also be a pytree with the same structure. In particular, result_shape can be () or None if the function does not have any results. The device code containing call is compiled with the expected result shape and dtype, and an error will be raised at runtime if the actual callback_func invocation returns a different kind of result.
call_with_device – if True then the callback function is invoked with the device from which the call originates as a keyword argument.
device_index – specifies from which device the tap function is invoked in an SPMD program. Works only when using the outfeed implementation mechanism, i.e., does not work on CPU unless --jax_host_callback_outfeed=True.
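The parameters above can be illustrated with a minimal sketch. It assumes a JAX release that still ships the experimental host_callback module, so the import is guarded and the example is skipped otherwise. The names host_sort_with_order and device_fn are hypothetical and exist only for illustration. The callback returns a pytree (a tuple of two arrays), so result_shape is a tuple with the same structure.

```python
import numpy as np
import jax
import jax.numpy as jnp

try:
    # Assumption: the experimental host_callback module is available in the
    # installed JAX release. If it is not, the example is skipped.
    from jax.experimental import host_callback as hcb
except ImportError:
    hcb = None


def host_sort_with_order(x):
    """Hypothetical host-side function: returns a pytree (tuple) of ndarrays."""
    x = np.asarray(x)
    return np.sort(x), np.argsort(x).astype(np.int32)


def device_fn(x):
    # result_shape mirrors the structure of the callback's return value:
    # one entry per returned array, each carrying a shape and a dtype.
    expected = (
        jax.ShapeDtypeStruct(x.shape, x.dtype),
        jax.ShapeDtypeStruct(x.shape, np.int32),
    )
    return hcb.call(host_sort_with_order, x, result_shape=expected)


x = jnp.array([3.0, 1.0, 2.0], dtype=jnp.float32)

# The call is staged into the jitted computation; the host function runs
# when the compiled code executes.
outputs = jax.jit(device_fn)(x) if hcb is not None else None
```

Here jax.ShapeDtypeStruct is used as the object with .shape and .dtype attributes. Passing an example array with the right shape and dtype for each entry should work equally well.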
Returns:
the result of the callback_func invocation.
For more details see the