jax.experimental.host_callback.barrier_wait

jax.experimental.host_callback.barrier_wait(logging_name=None)

Blocks the calling thread until all current outfeed is processed.

Waits until all callbacks issued by computations already running on all devices have been received and processed by their Python callback functions. Raises CallbackException if any of those callbacks raised an exception while being processed.

This works by enqueueing a special tap computation on every device whose outfeed is being listened to. Once all of those tap computations have completed, barrier_wait returns.
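The barrier idea can be illustrated with a small pure-Python toy model. This is a sketch under stated assumptions, not JAX's implementation: it assumes each device's outfeed behaves like a FIFO queue drained by one host listener thread, and it stands in a `threading.Event` sentinel for the on-device tap computation. The names `ToyOutfeedListener` and `outfeed`, and the local `CallbackException` class, are hypothetical and exist only for this illustration. Because each sentinel is queued behind all pending outfeed, its being processed implies everything sent earlier on that device has already reached the Python callback.

```python
import queue
import threading


class CallbackException(Exception):
    """Hypothetical stand-in for the exception barrier_wait raises."""


class ToyOutfeedListener:
    """Hypothetical model: one FIFO outfeed queue per device, each drained by a
    listener thread that passes every item to a Python callback."""

    def __init__(self, num_devices, callback):
        self._callback = callback
        self._queues = [queue.Queue() for _ in range(num_devices)]
        self._errors = []
        self._lock = threading.Lock()
        for q in self._queues:
            threading.Thread(target=self._drain, args=(q,), daemon=True).start()

    def _drain(self, q):
        while True:
            item = q.get()
            if isinstance(item, threading.Event):
                # The "tap" sentinel: everything queued before it is now processed.
                item.set()
                continue
            try:
                self._callback(item)
            except Exception as e:
                # Record the failure and keep listening instead of dying.
                with self._lock:
                    self._errors.append(e)

    def outfeed(self, device, data):
        """Simulate a computation on `device` sending `data` to the host."""
        self._queues[device].put(data)

    def barrier_wait(self, logging_name=None):
        name = logging_name or "barrier_wait"
        # Enqueue one sentinel per device, behind all outfeed already pending.
        taps = [threading.Event() for _ in self._queues]
        for q, tap in zip(self._queues, taps):
            q.put(tap)
        for tap in taps:
            tap.wait()  # would block forever if a queue were never drained
        with self._lock:
            errors, self._errors = self._errors, []
        if errors:
            raise CallbackException(
                f"{name}: {len(errors)} callback(s) failed") from errors[0]


received = []
listener = ToyOutfeedListener(num_devices=2, callback=received.append)
for i in range(6):
    listener.outfeed(i % 2, i)
listener.barrier_wait()
print(sorted(received))  # [0, 1, 2, 3, 4, 5]
```

In this model, errors are collected and re-raised only at the barrier, so a failing callback does not stop the listener from draining later outfeed. The unbounded `tap.wait()` also mirrors the deadlock caveat: if a device (here, a queue) never reaches the sentinel, the barrier never returns.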

Note: If any of the devices are busy and cannot accept new computations, this will deadlock.

Parameters

logging_name (Optional[str]) – an optional string that will be used in the logging statements for this invocation. See Debugging in the module documentation.