jax.experimental.host_callback.id_tap(tap_func, arg, *, result=None, tap_with_device=False, device_index=0, **kwargs)
Host-callback tap primitive, like the identity function but with a call to tap_func.
Experimental: please give feedback, and expect changes!
id_tap behaves semantically like the identity function but has the side effect that a user-defined Python function is called with the runtime value of the argument.
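A minimal sketch of this identity-plus-side-effect behavior follows. It assumes a JAX version that still ships `jax.experimental.host_callback`, imported here as `hcb`. The names `record` and `tapped` are illustrative only. `hcb.barrier_wait()` blocks until all pending taps have reached the host, so the list can be inspected safely afterwards.

```python
import jax
import jax.numpy as jnp
from jax.experimental import host_callback as hcb  # assumes a JAX version that still provides this module

tapped = []  # host-side values received by the tap function


def record(arg, transforms):
    # Called on the host with the runtime value of `arg` (a NumPy array).
    tapped.append(arg)


@jax.jit
def f(x):
    y = jnp.sin(x)
    y = hcb.id_tap(record, y)  # behaves like `y = y`, plus the host call
    return y * 2.0


out = f(jnp.float32(0.5))
hcb.barrier_wait()  # make sure the tap has run before inspecting `tapped`
print(tapped, out)
```

The returned `y` is fed into the rest of the computation, so the tap is kept and observes `sin(0.5)` at run time, even under `jax.jit`.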
tap_func – tap function to call like tap_func(arg, transforms), with arg as described below and where transforms is the sequence of applied JAX transformations in the form (name, params). If the optional tap_with_device argument is True, then the invocation also includes the device from which the value is tapped as a keyword argument: tap_func(arg, transforms, device=dev).
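The two calling conventions can be sketched as follows. This again assumes a JAX version that still provides `jax.experimental.host_callback`. The function `show` and the list `seen` are hypothetical names. With `tap_with_device=True` the tap function must accept the `device` keyword.

```python
import jax
import jax.numpy as jnp
from jax.experimental import host_callback as hcb  # assumes a JAX version that still provides this module

seen = []  # one record per tap invocation


def show(arg, transforms, *, device):
    # With tap_with_device=True the originating device is passed as `device=`.
    # `transforms` describes the JAX transformations applied around the tap,
    # as a sequence of (name, params) entries.
    seen.append({"device": device, "value": arg, "transforms": transforms})


def g(x):
    return hcb.id_tap(show, x * 3.0, tap_with_device=True)


out = jax.jit(g)(jnp.arange(3.0))
hcb.barrier_wait()  # wait for the tap to be delivered to the host
print(seen[0]["device"], seen[0]["value"], seen[0]["transforms"])
```

Without `tap_with_device=True`, the same tap function would be called as `show(arg, transforms)` and would need no `device` parameter.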
arg – the argument passed to the tap function; can be a pytree of JAX types.
result – if given, specifies the return value of id_tap. This value is not passed to the tap function, and in fact is not sent from the device to the host. If the result parameter is not specified, then the return value of id_tap is arg.
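A sketch of the result parameter follows, assuming a JAX version that still provides `jax.experimental.host_callback`. Only `loss` travels to the host, while `id_tap` hands back `x` for the rest of the computation. The names `log_loss`, `logged`, and `step` are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.experimental import host_callback as hcb  # assumes a JAX version that still provides this module

logged = []  # losses observed on the host


def log_loss(arg, transforms):
    logged.append(float(arg))


@jax.jit
def step(x):
    loss = jnp.sum(x ** 2)
    # Send only `loss` to the host; id_tap returns `result` (here `x`),
    # and `x` itself is not transferred to the host.
    x = hcb.id_tap(log_loss, loss, result=x)
    return x * 0.5


new_x = step(jnp.array([1.0, 2.0, 3.0]))
hcb.barrier_wait()
print(logged, new_x)
```

This pattern is useful when the value worth logging (a scalar loss) is not the value the computation continues with.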
tap_with_device – if True then the tap function is invoked with the device from which the tap originates as a keyword argument.
device_index – specifies from which device the tap function is invoked in an SPMD program. Works only when using the outfeed implementation mechanism, i.e., does not work on CPU unless --jax_host_callback_outfeed=True.
The order of execution is by data dependency: the tap runs after all the arguments, and the value of result if present, are computed, and before the returned value is used. At least one of the returned values of id_tap must be used in the rest of the computation, or else this operation has no effect.
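One way to respect both rules is to thread each returned value into the next computation. The sketch below assumes a JAX version that still provides `jax.experimental.host_callback`, and `record` and `order` are illustrative names. The second tap's argument is computed from the first tap's output, so data dependency forces the first tap to run first.

```python
import jax
import jax.numpy as jnp
from jax.experimental import host_callback as hcb  # assumes a JAX version that still provides this module

order = []  # tapped values, in the order the host received them


def record(arg, transforms):
    order.append(float(arg))


@jax.jit
def f(x):
    a = hcb.id_tap(record, x)        # `a` is used below, so this tap is kept
    b = hcb.id_tap(record, a + 1.0)  # depends on `a`, so it runs after the first tap
    return b * 10.0                  # `b` is used, so the second tap is kept too


result = f(jnp.float32(1.0))
hcb.barrier_wait()
print(order, result)
```

Calling `hcb.id_tap(record, x)` and ignoring its return value would leave the tap with no consumer, which is the case the paragraph above warns has no effect.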
Tapping works even for code executed on accelerators and even for code under JAX transformations.
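A sketch of tapping under a transformation, here `jax.vmap`, follows. It assumes a JAX version that still provides `jax.experimental.host_callback`, and `record` and `calls` are illustrative names. How the batch is delivered (one call with the whole batch and a batching entry in `transforms`, or otherwise) may depend on the implementation. For that reason the example only prints what arrives.

```python
import jax
import jax.numpy as jnp
from jax.experimental import host_callback as hcb  # assumes a JAX version that still provides this module

calls = []  # (arg, transforms) pairs received on the host


def record(arg, transforms):
    # Under jax.vmap, `arg` may hold the whole batch and `transforms`
    # may describe the batching transformation.
    calls.append((arg, transforms))


def per_example(x):
    y = hcb.id_tap(record, x * x)
    return y + 1.0


ys = jax.vmap(per_example)(jnp.arange(3.0))
hcb.barrier_wait()
for arg, transforms in calls:
    print(arg, transforms)
```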
For more details see the jax.experimental.host_callback module documentation.