jax.experimental.host_callback.id_tap
- jax.experimental.host_callback.id_tap(tap_func, arg, *, result=None, tap_with_device=False, device_index=0, **kwargs)
Host-callback tap primitive, like the identity function with a call to tap_func.
Experimental: please give feedback, and expect changes!
id_tap behaves semantically like the identity function, but has the side effect that a user-defined Python function is called with the runtime value of the argument.
- Parameters:
tap_func – tap function to call like tap_func(arg, transforms), with arg as described below and where transforms is the sequence of applied JAX transformations in the form (name, params). If the tap_with_device optional argument is True, then the invocation also includes the device from which the value is tapped as a keyword argument: tap_func(arg, transforms, device=dev).
arg – the argument passed to the tap function; can be a pytree of JAX types.
result – if given, specifies the return value of id_tap. This value is not passed to the tap function, and in fact is not sent from the device to the host. If the result parameter is not specified, then the return value of id_tap is arg.
tap_with_device – if True, then the tap function is invoked with the device from which the tap originates as a keyword argument.
device_index – specifies from which device the tap function is invoked in an SPMD program. Works only when using the outfeed implementation mechanism, i.e., does not work on CPU unless --jax_host_callback_outfeed=True.
- Returns:
arg, or result if given.
The order of execution is by data dependency: the tap runs after all the arguments, and the value of result if present, are computed, and before the returned value is used. At least one of the returned values of id_tap must be used in the rest of the computation, or else this operation has no effect.
Tapping works even for code executed on accelerators and even for code under JAX transformations.
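The calling convention, the result parameter, and the data-dependency rule described above can be illustrated with a minimal sketch. It assumes a JAX release that still ships jax.experimental.host_callback. Where that import fails, the code falls back to a hypothetical stand-in built on jax.debug.callback, which mimics only the documented tap_func(arg, transforms) call and return value and is not the library's implementation. The names record, seen, point, and f are illustrative.

```python
import jax
import jax.numpy as jnp

try:
    # Available only in JAX releases that still ship host_callback.
    from jax.experimental import host_callback as hcb
    id_tap, barrier_wait = hcb.id_tap, hcb.barrier_wait
except ImportError:
    # Hypothetical stand-in (NOT the library API): it reproduces only the
    # documented calling convention and return value, via jax.debug.callback.
    def id_tap(tap_func, arg, *, result=None):
        def _call(a):
            tap_func(a, ())  # this stand-in carries no transform information

        jax.debug.callback(_call, arg)
        return arg if result is None else result

    barrier_wait = jax.effects_barrier

seen = []  # values observed on the host


def record(arg, transforms):
    # `arg` has the same pytree structure that was passed to id_tap;
    # `transforms` is ignored in this sketch.
    seen.append({k: float(v) for k, v in arg.items()})


point = {"x": jnp.float32(2.0), "y": jnp.float32(3.0)}

# Without `result`, id_tap returns its argument.
same = id_tap(record, point)

# With `result`, id_tap returns that value instead; `record` still sees `point`.
ten = id_tap(record, point, result=jnp.float32(10.0))

barrier_wait()  # wait for outstanding taps before reading `seen`


@jax.jit
def f(x):
    doubled = x * 2.0
    # Keep using the returned value so the tap stays part of the computation.
    doubled = id_tap(record, {"doubled": doubled})["doubled"]
    return doubled + 1.0


z = f(jnp.float32(3.0))
barrier_wait()
```

Taps are delivered asynchronously, so the sketch calls barrier_wait before inspecting seen. Inside f, the value returned by id_tap feeds the final addition, which is what keeps the tap tied to the rest of the computation.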
For more details, see the jax.experimental.host_callback module documentation.