jax.experimental.multihost_utils module

Utilities for synchronizing and communicating across multiple hosts.

Multihost Utils API Reference

broadcast_one_to_all(in_tree[, is_source])

Broadcast data from a source host (host 0 by default) to all other hosts.
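A minimal sketch of a common use is making every process agree on a random seed. The `config` keys and values are hypothetical. Leaving `is_source` as `None` makes process 0 the source. The explicit `is_source=...` call shows the equivalent form where each process states whether it is the source. In a single-process run, the result simply carries the caller's own values back as NumPy arrays.

```python
import numpy as np
import jax
from jax.experimental import multihost_utils

# Each process draws its own candidate seed, so the values may differ per host.
local_seed = np.int32(np.random.randint(0, 2**31 - 1))
config = {"seed": local_seed, "learning_rate": np.float32(3e-4)}  # hypothetical

# is_source=None: process 0 is the source, and its leaves are returned everywhere.
shared = multihost_utils.broadcast_one_to_all(config)

# Equivalent explicit form: exactly one process passes is_source=True.
shared_explicit = multihost_utils.broadcast_one_to_all(
    config, is_source=jax.process_index() == 0)

print(jax.process_index(), int(shared["seed"]))
```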

sync_global_devices(name)

Creates a barrier across all hosts/devices.
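The sketch below uses `sync_global_devices` as a barrier around a file hand-off. Process 0 writes a marker file, and no process reads the file until every process has reached the barrier. It assumes `shared_dir` is visible to all processes, for example a shared filesystem. On a single machine the system temp directory stands in for it. The file name and the barrier name are hypothetical. Every process should pass the same `name` string to the same barrier.

```python
import os
import tempfile

import jax
from jax.experimental import multihost_utils

# Assumption: shared_dir is the same directory on every process.
shared_dir = tempfile.gettempdir()
marker = os.path.join(shared_dir, "multihost_demo_step_100.done")  # hypothetical

if jax.process_index() == 0:
    with open(marker, "w") as f:
        f.write("ok")

# No process continues past this call until all processes have made it.
multihost_utils.sync_global_devices("wrote_step_100_marker")

# Safe to read now: process 0 finished writing before the barrier released.
with open(marker) as f:
    contents = f.read()
```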

process_allgather(in_tree[, tiled])

Gather data from across processes.

assert_equal(in_tree[, fail_message])

Verifies that all the hosts have the same tree of values.
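A typical use is checking that all processes loaded the same hyperparameters before training starts. The `hparams` tree below is hypothetical. As we understand it, each process's tree is compared against process 0's. On a mismatch, `assert_equal` raises `AssertionError` with `fail_message` in the error text. Otherwise it returns `None`.

```python
import numpy as np
from jax.experimental import multihost_utils

# Hypothetical hyperparameters that every process is expected to share.
hparams = {"batch_size": np.int32(128), "learning_rate": np.float32(1e-3)}

# Raises AssertionError (message includes fail_message) if any process differs.
result = multihost_utils.assert_equal(
    hparams, fail_message="hparams differ across hosts:")
```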

host_local_array_to_global_array(local_inputs, global_mesh, pspecs)

Converts a host local value to a globally sharded jax.Array.
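The sketch below shows a data-parallel input pipeline where each process loads only its own slice of the batch. It assumes the following:

- A 1-D mesh over all devices.
- A hypothetical axis name `"data"`.
- Two rows per local device.
- Every process has the same number of local devices.

With `P("data")` on the leading axis, the global batch is the local batches of all processes laid end to end along that axis. Its leading dimension is therefore `2 * jax.device_count()`.

```python
import numpy as np
import jax
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import multihost_utils

# 1-D mesh over every device in the job; "data" is a hypothetical axis name.
mesh = Mesh(np.array(jax.devices()), ("data",))

# This process's share of the batch: 2 rows per local device (an assumption).
n_local = jax.local_device_count()
local_batch = np.arange(2 * n_local * 3, dtype=np.float32).reshape(2 * n_local, 3)

# Shard the leading (batch) axis over "data" to form one global jax.Array.
global_batch = multihost_utils.host_local_array_to_global_array(
    local_batch, mesh, P("data"))

print(global_batch.shape, global_batch.sharding)
```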

global_array_to_host_local_array(global_inputs, global_mesh, pspecs)

Converts a global jax.Array to a host local jax.Array.
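A round-trip sketch follows, using the same assumptions as a data-parallel setup: a 1-D mesh, a hypothetical `"data"` axis, and two rows per local device. It builds a global array, runs a jitted computation on it, and converts the result back. After the conversion, each process holds only the rows that live on its own devices. Because the local rows went in through the same mesh and `PartitionSpec`, each process should get back its own rows, doubled.

```python
import numpy as np
import jax
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import multihost_utils

mesh = Mesh(np.array(jax.devices()), ("data",))  # "data" is hypothetical
n_local = jax.local_device_count()
local_batch = np.arange(2 * n_local * 3, dtype=np.float32).reshape(2 * n_local, 3)

# Host-local -> global, then a jitted computation on the global array.
global_batch = multihost_utils.host_local_array_to_global_array(
    local_batch, mesh, P("data"))
global_out = jax.jit(lambda x: x * 2.0)(global_batch)

# Global -> host-local: keep only the rows held by this process's devices.
local_out = multihost_utils.global_array_to_host_local_array(
    global_out, mesh, P("data"))
```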