jax.experimental.host_callback.call
- jax.experimental.host_callback.call(callback_func, arg, *, result_shape=None, call_with_device=False, device_index=0)
Make a call to the host, and expect a result.
Experimental: please give feedback, and expect changes!
- Parameters
callback_func (Callable) – The Python function to invoke on the host as callback_func(arg). If the optional call_with_device argument is True, the invocation also includes a device keyword argument carrying the device from which the call originates: callback_func(arg, device=dev). This function must return a pytree of NumPy ndarrays.
arg – The argument passed to the callback function. It can be a pytree of JAX types.
result_shape – A value that describes the expected shape and dtype of the result. This can be a numeric scalar, from which a shape and dtype are obtained, or an object that has .shape and .dtype attributes. If the result of the callback is a pytree, then result_shape should also be a pytree with the same structure. In particular, result_shape can be () or None if the function does not have any results. The device code containing call is compiled with the expected result shape and dtype, and an error will be raised at runtime if the actual callback_func invocation returns a different kind of result.
call_with_device – If True, the callback function is invoked with the device from which the call originates as a keyword argument.
device_index – Specifies from which device the callback function is invoked in an SPMD program. Works only when using the outfeed implementation mechanism, i.e., does not work on CPU unless --jax_host_callback_outfeed=True.
- Returns
The result of the callback_func invocation.
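As a minimal sketch of the parameters above, the following calls a NumPy function on the host from inside a jitted function. It assumes a JAX release that still ships jax.experimental.host_callback; the names host_log1p and device_fn are hypothetical and chosen only for illustration. The result shape is declared with jax.ShapeDtypeStruct, which is one kind of object that has .shape and .dtype attributes.

```python
import numpy as np
import jax
import jax.numpy as jnp
# Assumption: this JAX release still provides the experimental host_callback module.
from jax.experimental import host_callback as hcb


def host_log1p(x):
    # Hypothetical host-side function: runs in Python on the host
    # and returns a NumPy array with the same shape and dtype as its input.
    return np.log1p(np.asarray(x))


@jax.jit
def device_fn(x):
    # Declare the expected result shape and dtype up front so the
    # device code can be compiled against it.
    y = hcb.call(
        host_log1p,
        x,
        result_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )
    # The returned value is an ordinary JAX value and can be used further.
    return y * 2.0


x = jnp.arange(4, dtype=jnp.float32)
out = device_fn(x)
```

Here result_shape mirrors the input because log1p is elementwise; a callback that changes the shape or dtype would need a matching declaration.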
For more details see the jax.experimental.host_callback module documentation.
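The pytree form of result_shape described in the parameters can be sketched as follows. This is an illustration under assumptions, not a definitive usage pattern: it again assumes a JAX release that still ships jax.experimental.host_callback, and host_stats and device_range are hypothetical names. The callback returns a dict of NumPy arrays, and result_shape is a dict with the same keys.

```python
import numpy as np
import jax
import jax.numpy as jnp
# Assumption: this JAX release still provides the experimental host_callback module.
from jax.experimental import host_callback as hcb


def host_stats(x):
    # Hypothetical host-side function returning a pytree (a dict)
    # of zero-dimensional NumPy arrays.
    x = np.asarray(x)
    return {"min": np.asarray(x.min()), "max": np.asarray(x.max())}


@jax.jit
def device_range(x):
    # One scalar spec per leaf; the dict structure matches the
    # structure of the callback's return value.
    scalar = jax.ShapeDtypeStruct((), x.dtype)
    stats = hcb.call(
        host_stats,
        x,
        result_shape={"min": scalar, "max": scalar},
    )
    return stats["max"] - stats["min"]


x = jnp.array([3.0, -1.0, 2.0], dtype=jnp.float32)
value_range = device_range(x)
```

Keeping the structure of result_shape identical to the callback's return value matters because the device code is compiled against the declared shapes and dtypes.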