jax.experimental.host_callback.call


jax.experimental.host_callback.call(callback_func, arg, *, result_shape=None, call_with_device=False, device_index=0, callback_flavor=CallbackFlavor.IO_CALLBACK)

Make a call to the host, and expect a result.

Warning

The host_callback APIs are deprecated as of March 20, 2024. Their functionality is subsumed by the new JAX external callbacks. See google/jax#20385.
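Below is a rough migration sketch, assuming a recent JAX release that provides jax.experimental.io_callback and jax.ShapeDtypeStruct. A call such as call(host_fn, x, result_shape=x) is rewritten with io_callback, which appears to be the external-callback counterpart of the default CallbackFlavor.IO_CALLBACK. The expected result is declared with a ShapeDtypeStruct in place of result_shape. The function names host_fn and f are hypothetical and chosen only for illustration.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import io_callback


# Hypothetical host-side function: receives array data on the host and
# returns a NumPy array with the same shape and dtype as its input.
def host_fn(x):
    return np.sin(np.asarray(x))


@jax.jit
def f(x):
    # Deprecated style (for comparison only):
    #   host_callback.call(host_fn, x, result_shape=x)
    # External-callback style: declare the result's shape and dtype explicitly.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    y = io_callback(host_fn, out_spec, x)
    # The callback result is an ordinary traced array inside jit.
    return y * 2.0


x = jnp.arange(3.0, dtype=jnp.float32)
result = f(x)
```

For a callback without side effects, jax.pure_callback takes the same (callback, result_shape_dtypes, *args) arguments and gives the compiler more freedom to reorder or eliminate the call.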

Parameters:
  • callback_func (Callable) – The Python function to invoke on the host as callback_func(arg). If the call_with_device optional argument is True, then the invocation also includes the device kwarg with the device from which the call originates: callback_func(arg, device=dev). This function must return a pytree of numpy ndarrays.

  • arg – the argument passed to the callback function; it can be a pytree of JAX types.

  • result_shape – a value that describes the expected shape and dtype of the result. This can be a numeric scalar, from which a shape and dtype are obtained, or an object that has .shape and .dtype attributes. If the result of the callback is a pytree, then result_shape should also be a pytree with the same structure. In particular, result_shape can be () or None if the function does not return any results. The device code containing the call is compiled with the expected result shape and dtype, and an error is raised at runtime if the actual callback_func invocation returns a result that does not match result_shape.

  • call_with_device – if True then the callback function is invoked with the device from which the call originates as a keyword argument.

  • device_index – specifies the device from which the callback function is invoked in an SPMD program. This works only with the outfeed implementation mechanism, i.e., it does not work on CPU unless --jax_host_callback_outfeed=True.

  • callback_flavor – if running with JAX_HOST_CALLBACK_LEGACY=False, specifies the flavor of callback to use. See google/jax#20385.
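The sketch below shows a pytree-valued result. It assumes a recent JAX release and assumes that the pure flavor corresponds to jax.pure_callback. As with a pytree result_shape, the result_shape_dtypes argument mirrors the structure returned by the host function (here a dict). A mismatch in structure, shape, or dtype should surface as an error when the callback runs. The names host_stats and stats are hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp


# Hypothetical host-side function returning a pytree (a dict) of NumPy scalars.
def host_stats(x):
    x = np.asarray(x)
    return {
        "mean": np.asarray(x.mean(), dtype=np.float32),
        "max": np.asarray(x.max(), dtype=np.float32),
    }


@jax.jit
def stats(x):
    # Same dict structure as host_stats' return value, with one
    # ShapeDtypeStruct per leaf describing its expected shape and dtype.
    spec = {
        "mean": jax.ShapeDtypeStruct((), jnp.float32),
        "max": jax.ShapeDtypeStruct((), jnp.float32),
    }
    return jax.pure_callback(host_stats, spec, x)


out = stats(jnp.array([1.0, 2.0, 6.0], dtype=jnp.float32))
```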

Returns:

the result of the callback_func invocation.

For more details see the jax.experimental.host_callback module documentation.