jax.experimental.host_callback.call
- jax.experimental.host_callback.call(callback_func, arg, *, result_shape=None, call_with_device=False, device_index=0, callback_flavor=CallbackFlavor.IO_CALLBACK)
Make a call to the host and expect a result.
Warning
The host_callback APIs are deprecated as of March 20, 2024. Their functionality is subsumed by the new JAX external callbacks; see google/jax#20385.
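Because this API is deprecated, a sketch of the replacement path may help. The following is a minimal sketch, assuming a recent JAX release that provides jax.pure_callback and jax.ShapeDtypeStruct; it is not taken from the migration notes in google/jax#20385. It also assumes the host function is pure (no side effects), because JAX may skip or repeat a pure_callback whose result is unused. The names host_sin and f are hypothetical. The result_shape argument of call roughly corresponds to the result_shape_dtypes argument of jax.pure_callback.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_sin(x):
    # Runs on the host. The argument arrives as array data that NumPy can
    # consume, and the return value must match the declared shape and dtype.
    return np.sin(x)


@jax.jit
def f(x):
    # Roughly the pattern call(host_sin, x, result_shape=x): the expected
    # result shape and dtype are declared up front so the jitted code can
    # be compiled before the host function ever runs.
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    y = jax.pure_callback(host_sin, result_shape, x)
    # The callback result is an ordinary traced value in the computation.
    return y * 2.0


x = jnp.arange(4.0)
print(f(x))
```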
- Parameters:
callback_func (Callable) – The Python function to invoke on the host as callback_func(arg). If the optional call_with_device argument is True, the invocation also includes the device kwarg with the device from which the call originates: callback_func(arg, device=dev). This function must return a pytree of NumPy ndarrays.
arg – the argument passed to the callback function; it can be a pytree of JAX types.
result_shape – a value that describes the expected shape and dtype of the result. This can be a numeric scalar, from which a shape and dtype are obtained, or an object that has .shape and .dtype attributes. If the result of the callback is a pytree, then result_shape should also be a pytree with the same structure. In particular, result_shape can be () or None if the function does not have any results. The device code containing call is compiled with the expected result shape and dtype, and an error is raised at runtime if the actual callback_func invocation returns a result of a different shape or dtype.
call_with_device – if True, the callback function is invoked with the device from which the call originates as a keyword argument.
device_index – specifies the device from which the callback function is invoked in an SPMD program. This works only with the outfeed implementation mechanism, i.e., it does not work on CPU unless --jax_host_callback_outfeed=True.
callback_flavor – if running with JAX_HOST_CALLBACK_LEGACY=False, specifies the flavor of callback to use. See google/jax#20385.
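The default callback_flavor is CallbackFlavor.IO_CALLBACK, which suggests jax.experimental.io_callback as the closest modern counterpart when the host function has side effects; that mapping is an assumption here, not something stated above. The sketch below, assuming a recent JAX release, illustrates the pytree rule for result_shape: the callback returns a dict, so the declared result shapes form a dict with the same keys. The names host_stats, g, and log are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

# Host-side record of calls, used only to show that the side effect happened.
log = []


def host_stats(x):
    # Host function with a side effect (appending to log). It returns a
    # pytree (a dict) of NumPy scalars whose structure, shapes, and dtypes
    # must match the declared result shape exactly.
    log.append(x.shape)
    return {
        "mean": np.asarray(np.mean(x), dtype=np.float32),
        "max": np.asarray(np.max(x), dtype=np.float32),
    }


@jax.jit
def g(x):
    # The expected result is a pytree with the same structure as the
    # callback's return value: a dict of float32 scalars.
    result_shape = {
        "mean": jax.ShapeDtypeStruct((), jnp.float32),
        "max": jax.ShapeDtypeStruct((), jnp.float32),
    }
    stats = io_callback(host_stats, result_shape, x)
    return stats["max"] - stats["mean"]


x = jnp.array([1.0, 2.0, 3.0, 6.0])
print(g(x))  # max - mean = 6.0 - 3.0 = 3.0
```

If the relative order of host calls matters, io_callback also accepts ordered=True.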
- Returns:
the result of the callback_func invocation.
For more details, see the jax.experimental.host_callback module documentation.