jax.experimental.multihost_utils module

Utilities for synchronization and communication across multiple hosts.

Multihost Utils API Reference

broadcast_one_to_all(in_tree[, is_source])

Broadcasts data from a source host (host 0 by default) to all other hosts.
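A minimal sketch of broadcast_one_to_all, assuming jax.distributed.initialize() has already been called when the job spans several processes. The same script also runs as a single process, where it simply returns the input values. The names local_seed and config are hypothetical.

```python
import jax
import numpy as np
from jax.experimental import multihost_utils

# Hypothetical per-process value; in a real multihost job each process might
# draw a different seed locally, and only process 0's copy should be used.
local_seed = np.array(jax.process_index() + 1234, dtype=np.int32)

# By default process 0 is the source, so every process receives its value.
seed = multihost_utils.broadcast_one_to_all(local_seed)

# Pytrees of arrays are broadcast leaf by leaf; here the source is named
# explicitly (equivalent to the default).
config = {"lr": np.float32(3e-4), "steps": np.int32(1000)}
config = multihost_utils.broadcast_one_to_all(
    config, is_source=jax.process_index() == 0)

print(seed, config)
```

Broadcasting values such as seeds from a single process is one way to keep processes consistent without relying on each of them computing the same value independently.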

sync_global_devices(name)

Creates a barrier across all hosts/devices.
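A sketch of a common use of sync_global_devices: making sure no process moves on until process 0 has finished a side effect. The helper write_checkpoint is hypothetical and only prints. Every process should pass the same barrier name.

```python
import jax
from jax.experimental import multihost_utils

def write_checkpoint(step: int) -> None:
  # Hypothetical helper: only process 0 writes, e.g. to shared storage.
  if jax.process_index() == 0:
    print(f"process 0 writing checkpoint for step {step}")

step = 100
write_checkpoint(step)

# Each process blocks here until every process has reached the barrier with
# the same name, so no process reads the checkpoint before it is written.
multihost_utils.sync_global_devices(f"checkpoint_written_{step}")
print(f"process {jax.process_index()} passed the barrier")
```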

process_allgather(in_tree[, tiled])

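Gathers data from all processes. By default the per-process arrays are stacked along a new leading axis; with tiled=True they are concatenated instead.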
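A minimal sketch of process_allgather for collecting per-process metrics. Each process must contribute arrays of the same shape. The local_metrics values are made up for illustration. The results come back as NumPy arrays on every process.

```python
import jax
import numpy as np
from jax.experimental import multihost_utils

# Hypothetical per-process metric, e.g. losses computed on this host's shard
# of an evaluation set. Every process must pass an array of the same shape.
local_metrics = np.array([0.5, 0.25], dtype=np.float32) + jax.process_index()

# tiled=False (default): stack along a new leading axis, one row per process.
stacked = multihost_utils.process_allgather(local_metrics)

# tiled=True: concatenate along the existing leading axis instead.
concatenated = multihost_utils.process_allgather(local_metrics, tiled=True)

print(np.asarray(stacked).shape, np.asarray(concatenated).shape)
```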

assert_equal(in_tree[, fail_message])

Verifies that all the hosts have the same tree of values.
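A sketch of assert_equal used as a sanity check that every process loaded the same settings. The dictionary contents are hypothetical. As currently implemented, each process compares its tree with the one broadcast from process 0 and raises an AssertionError that includes fail_message when they differ.

```python
import numpy as np
from jax.experimental import multihost_utils

# Hypothetical values that every process is expected to agree on, e.g. model
# hyperparameters parsed from command-line flags on each host.
settings = {"num_layers": np.int32(12), "vocab_size": np.int32(32000)}

# Returns None when all processes agree; otherwise raises an AssertionError
# whose message starts with fail_message.
multihost_utils.assert_equal(settings, fail_message="settings differ across hosts")
print("all processes agree on settings")
```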

host_local_array_to_global_array(...)

Converts a host-local value to a globally sharded jax.Array.
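A sketch of host_local_array_to_global_array for data-parallel input, assuming a 1-D mesh over all devices with a hypothetical axis name "data". Each process passes only its own slice of the batch. With PartitionSpec("data") the global leading dimension is the local one multiplied by the number of processes.

```python
import jax
import numpy as np
from jax.experimental import multihost_utils
from jax.sharding import Mesh, PartitionSpec as P

# A 1-D mesh over every device in the job; the axis name "data" is arbitrary.
mesh = Mesh(np.array(jax.devices()), ("data",))

# This process's slice of the batch; its leading dimension divides evenly
# across the process's local devices.
local_batch = np.arange(jax.local_device_count() * 4, dtype=np.float32).reshape(-1, 2)

# Shard axis 0 of the global array across the "data" mesh axis.
global_batch = multihost_utils.host_local_array_to_global_array(
    local_batch, mesh, P("data"))

print(global_batch.shape, global_batch.sharding)
```

In a single-process run the local and global arrays have the same shape, which makes the snippet easy to try on one machine before scaling out.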

global_array_to_host_local_array(...)

Converts a global jax.Array to a host-local jax.Array.
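A sketch of a round trip: build a global array from host-local data, run a computation on it, then use global_array_to_host_local_array to recover the slice owned by this process. The mesh axis name "data" and the doubling function are hypothetical stand-ins for a real sharded step.

```python
import jax
import numpy as np
from jax.experimental import multihost_utils
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("data",))
pspec = P("data")

# Host-local input, converted to a global array sharded along axis 0.
local_batch = np.arange(jax.local_device_count() * 4, dtype=np.float32).reshape(-1, 2)
global_batch = multihost_utils.host_local_array_to_global_array(local_batch, mesh, pspec)

# Hypothetical computation on the global array (stands in for a jitted step).
global_out = jax.jit(lambda x: x * 2)(global_batch)

# Convert back to the host-local slice, e.g. to write per-host outputs.
local_out = multihost_utils.global_array_to_host_local_array(global_out, mesh, pspec)
print(local_out.shape)
```

Passing the same mesh and PartitionSpec in both directions keeps the host-local view consistent with the original local_batch layout.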