jax.experimental.host_callback.call

jax.experimental.host_callback.call(callback_func, arg, *, result_shape=None, call_with_device=False)

Invoke a Python function on the host and return its result to the calling computation.

Experimental: please give feedback, and expect changes!
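A minimal sketch of typical use, assuming a CPU backend and JAX's default 32-bit dtypes. A NumPy function runs on the host inside a jitted computation, and result_shape is given as a jax.ShapeDtypeStruct, which has the .shape and .dtype attributes this API expects. The names host_sin and f are illustrative only. Because jax.experimental.host_callback has been deprecated and may be absent from newer JAX releases, the sketch falls back to jax.pure_callback when the import fails; that fallback is a swapped-in mechanism, not part of this API.

```python
import numpy as np
import jax
import jax.numpy as jnp

try:
    # Present in JAX releases that still ship host_callback.
    from jax.experimental import host_callback as hcb
    host_call = hcb.call
except (ImportError, AttributeError):
    # Assumption: a newer JAX without host_callback; use the fallback below.
    host_call = None


def host_sin(x):
    # Runs as ordinary Python on the host; convert to NumPy to be safe.
    x = np.asarray(x)
    return np.sin(x).astype(np.float32)


@jax.jit
def f(x):
    # Expected result: same shape as x, float32 (has .shape and .dtype).
    spec = jax.ShapeDtypeStruct(x.shape, jnp.float32)
    if host_call is not None:
        y = host_call(host_sin, x, result_shape=spec)
    else:
        # Swapped-in fallback, not part of host_callback.
        y = jax.pure_callback(host_sin, spec, x)
    # The host result flows back into the device computation.
    return 2.0 * y
```

For example, f(jnp.arange(4, dtype=jnp.float32)) computes sin on the host and doubles it on the device. The callback casts its output to float32 explicitly so that it matches the declared result_shape.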

Parameters
  • callback_func (Callable) – The Python function to invoke on the host as callback_func(arg). If the optional call_with_device argument is True, the invocation also passes a device keyword argument holding the device from which the call originates: callback_func(arg, device=dev). The function must return a pytree of NumPy ndarrays.

  • arg – The argument passed to the callback function; it can be a pytree of JAX types.

  • result_shape – A value that describes the expected shape and dtype of the result. It can be a numeric scalar, from which a shape and dtype are obtained, or an object that has .shape and .dtype attributes. If the result of the callback is a pytree, result_shape should be a pytree with the same structure. In particular, result_shape can be () or None if the function does not return any results. The device code containing the call is compiled with the expected result shape and dtype, and an error is raised at runtime if the actual callback_func invocation returns a result that does not match.

  • call_with_device – If True, the callback function is invoked with the device from which the call originates, passed as the device keyword argument.
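The sketch below combines a pytree result with call_with_device=True, assuming a CPU backend and JAX's default float32 dtypes. The hypothetical host_stats callback returns a dict of NumPy scalars, so result_shape is a dict with the same keys, and it records the device keyword it receives in a hypothetical seen_devices list. When host_callback cannot be imported, the sketch assumes a fallback to jax.pure_callback, which has no call_with_device option; in that case device stays None.

```python
import numpy as np
import jax
import jax.numpy as jnp

try:
    # Present in JAX releases that still ship host_callback.
    from jax.experimental import host_callback as hcb
    host_call = hcb.call
except (ImportError, AttributeError):
    host_call = None  # assumption: newer JAX without host_callback

seen_devices = []  # devices reported to the callback, for inspection


def host_stats(x, device=None):
    # device is supplied by host_callback when call_with_device=True.
    seen_devices.append(device)
    x = np.asarray(x)
    # The result is a pytree (a dict) of NumPy arrays.
    return {"mean": np.asarray(x.mean(), dtype=np.float32),
            "max": np.asarray(x.max(), dtype=np.float32)}


@jax.jit
def stats(x):
    # result_shape mirrors the pytree structure of the callback's result.
    result_shape = {"mean": jax.ShapeDtypeStruct((), jnp.float32),
                    "max": jax.ShapeDtypeStruct((), jnp.float32)}
    if host_call is not None:
        return host_call(host_stats, x, result_shape=result_shape,
                         call_with_device=True)
    # Fallback: pure_callback has no call_with_device, so device stays None.
    return jax.pure_callback(host_stats, result_shape, x)
```

Calling stats(jnp.array([1.0, 3.0, 2.0])) returns a dict with the same keys as result_shape. If the callback returned, say, a float64 value or a differently keyed dict, the mismatch with result_shape would be reported as an error at runtime.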

Returns

The result of the callback_func invocation.

For more details, see the module documentation.