jax.experimental.multihost_utils module
Utilities for synchronization and communication across multiple hosts.
Multihost Utils API Reference
Broadcasts data from a source host (host 0 by default) to all other hosts.
Creates a barrier across all hosts/devices.
Gathers data from all processes.
Verifies that all hosts have the same tree of values.
Converts a host-local value to a globally sharded jax.Array.
Converts a global jax.Array to a host-local jax.Array.
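A minimal sketch of how these utilities might be combined in one script. It assumes the same script is launched on every host (after `jax.distributed.initialize()` on a real cluster) and uses the module's functions `sync_global_devices`, `broadcast_one_to_all`, `assert_equal`, `host_local_array_to_global_array`, `global_array_to_host_local_array` and `process_allgather`. The barrier name, the seed value, the mesh axis name `"data"` and the array shapes are hypothetical choices for illustration only. With a single process, each collective has one participant, so the sketch can also be tried on one machine.

```python
import numpy as np
import jax
from jax.experimental import multihost_utils
from jax.sharding import Mesh, PartitionSpec as P

# Assumption: on a real multi-host cluster every process has already called
# jax.distributed.initialize(...) and runs this same script. In a single
# process each collective below simply has one participant.

# Barrier: each process blocks here until all processes reach this call.
# The name ("startup" is arbitrary) must be identical on every process.
multihost_utils.sync_global_devices("startup")

# Broadcast: the value held by process 0 replaces the value on every process.
# 1234 is an arbitrary illustrative seed; other processes pass a dummy value.
seed = np.int32(1234 if jax.process_index() == 0 else -1)
seed = int(multihost_utils.broadcast_one_to_all(seed))

# Consistency check: raises AssertionError if any process holds a different value.
multihost_utils.assert_equal(np.int32(seed), "seed differs across hosts")

# A 1-D mesh over all devices with a hypothetical axis name "data".
# Assumes jax.devices() lists each process's devices contiguously.
mesh = Mesh(np.array(jax.devices()), ("data",))

# Each process builds its own slice of the batch; rows are offset by the
# process index so that different processes hold different data.
rows_per_process = 2 * jax.local_device_count()
local_batch = np.arange(rows_per_process * 3, dtype=np.float32).reshape(rows_per_process, 3)
local_batch += 100.0 * jax.process_index()

# Host-local -> global: combine the per-process slices into one jax.Array
# whose leading axis is sharded over the "data" mesh axis.
global_batch = multihost_utils.host_local_array_to_global_array(local_batch, mesh, P("data"))

# A jitted computation operates on the global array as a whole.
doubled = jax.jit(lambda x: 2 * x)(global_batch)

# Global -> host-local: each process gets back only its own rows.
local_result = multihost_utils.global_array_to_host_local_array(doubled, mesh, P("data"))

# All-gather: collect a small per-process value on every process.
# tiled=True concatenates along axis 0 instead of stacking on a new axis.
row_counts = multihost_utils.process_allgather(np.array([rows_per_process]), tiled=True)

if jax.process_index() == 0:
    print("seed:", seed)
    print("global batch shape:", global_batch.shape)
    print("rows per process:", row_counts)
```

The sketch passes `tiled=True` to `process_allgather` so that the one-element arrays from each process are concatenated into a single flat vector of length `jax.process_count()`, rather than stacked along a new leading process axis.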