jax.experimental.host_callback.id_tap(tap_func, arg, *, result=None, tap_with_device=False, device_index=0, **kwargs)

Host-callback tap primitive: acts like the identity function, plus a call to tap_func.

Experimental: please give feedback, and expect changes!

id_tap behaves semantically like the identity function, but with the side effect that a user-defined Python function is called with the runtime value of the argument.
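A minimal sketch of this identity-plus-side-effect behavior, assuming a JAX release that still ships jax.experimental.host_callback (the module is deprecated in newer releases). The names record and tapped are illustrative only. Because taps are delivered to the host asynchronously, hcb.barrier_wait() is called before the recorded values are inspected.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import host_callback as hcb  # assumes a JAX version that still ships this module

tapped = []  # host-side record of every value the tap function receives


def record(arg, transforms):
    # Runs on the host; `arg` arrives as a NumPy array holding the runtime value.
    tapped.append(np.asarray(arg))


@jax.jit
def f(x):
    y = x * 2.0
    y = hcb.id_tap(record, y)  # identity on y, plus a call to record(y, transforms)
    return y + 1.0


out = f(jnp.arange(3.0))
hcb.barrier_wait()  # block until all pending tap calls have run on the host
print(out)        # [1. 3. 5.]
print(tapped[0])  # [0. 2. 4.]
```

The value returned by id_tap is used in `y + 1.0`, so the tap stays part of the computation.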

Parameters:
  • tap_func – the tap function, called as tap_func(arg, transforms), where arg is described below and transforms is the sequence of applied JAX transformations, each in the form (name, params). If the optional tap_with_device argument is True, the device from which the value is tapped is also passed as a keyword argument: tap_func(arg, transforms, device=dev).

  • arg – the argument passed to the tap function; it can be a pytree of JAX types.

  • result – if given, specifies the return value of id_tap. This value is not passed to the tap function, and in fact is not sent from the device to the host. If result is not specified, id_tap returns arg.

  • tap_with_device – if True then the tap function is invoked with the device from which the tap originates as a keyword argument.

  • device_index – specifies the device from which the tap function is invoked in an SPMD program. Works only with the outfeed implementation mechanism, i.e., it does not work on CPU unless --jax_host_callback_outfeed=True.

Returns:
arg, or result if given.
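The sketch below combines the result and tap_with_device parameters, again assuming a JAX release that still provides jax.experimental.host_callback. The tap function name record_with_device and the list seen are illustrative. Here x is tapped while sin(x) is what the jitted function returns, so only x travels to the host, and the tap function receives the originating device through the device keyword.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import host_callback as hcb  # assumes a JAX version that still ships this module

seen = []  # (tapped value, device) pairs collected on the host


def record_with_device(arg, transforms, *, device):
    # With tap_with_device=True, the originating device is passed as `device=`.
    seen.append((np.asarray(arg), device))


@jax.jit
def g(x):
    y = jnp.sin(x)
    # Tap x but return y: only x is sent to the host, and g's output is y.
    return hcb.id_tap(record_with_device, x, result=y, tap_with_device=True)


out = g(jnp.array([0.0, 1.0]))
hcb.barrier_wait()  # make sure the tap has been processed before reading `seen`
print(out)          # sin of the input, i.e. the `result` value
print(seen[0][0])   # [0. 1.]  (the tapped `arg`, not `result`)
```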

Execution order follows data dependency: the tap runs after all the arguments (and the value of result, if present) have been computed, and before the returned value is used. At least one of the values returned by id_tap must be used in the rest of the computation; otherwise this operation has no effect.
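One way to honor this rule inside a loop is to thread the value returned by id_tap into the next iteration, as in this sketch (assuming a JAX release that still ships jax.experimental.host_callback; log_step, body, and running_sum are illustrative names). Each tap depends on the freshly computed accumulator, and its returned value is what the loop carries forward.

```python
import jax
import jax.numpy as jnp
from jax.experimental import host_callback as hcb  # assumes a JAX version that still ships this module

steps = []  # (i, acc) pairs observed on the host


def log_step(arg, transforms):
    i, acc = arg  # the tapped pytree (a tuple) is rebuilt on the host
    steps.append((int(i), int(acc)))


def body(i, acc):
    acc = acc + i
    # Tap (i, acc) but return acc: the returned value feeds the next iteration,
    # so the tap runs after acc is computed and before it is consumed.
    return hcb.id_tap(log_step, (i, acc), result=acc)


@jax.jit
def running_sum(init):
    return jax.lax.fori_loop(0, 4, body, init)


total = running_sum(jnp.int32(0))
hcb.barrier_wait()  # wait for all four taps to reach the host
print(int(total))   # 6
```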

Tapping works even for code executed on accelerators and even for code under JAX transformations.
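As an illustration under vmap, the sketch below records the transforms argument. It assumes a JAX release that still ships jax.experimental.host_callback, and it assumes (based on the host_callback module documentation) that the batching transformation is reported with the name 'batch' and that the tap function receives the whole batched value in a single call. The names log_batch and square are illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import host_callback as hcb  # assumes a JAX version that still ships this module

calls = []  # (tapped value, transformation names) for each tap invocation


def log_batch(arg, transforms):
    # `transforms` lists the applied JAX transformations as (name, params) pairs.
    calls.append((np.asarray(arg), [name for name, _ in transforms]))


def square(x):
    return hcb.id_tap(log_batch, x) ** 2


out = jax.jit(jax.vmap(square))(jnp.arange(3.0))
hcb.barrier_wait()  # wait until the tap has been processed on the host
print(out)          # [0. 1. 4.]
```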

For more details see the jax.experimental.host_callback module documentation.