jax.experimental.multihost_utils.broadcast_one_to_all
- jax.experimental.multihost_utils.broadcast_one_to_all(in_tree, is_source=None)
Broadcast data from a source host (host 0 by default) to all other hosts.
- Parameters:
in_tree (Any) – pytree of arrays; each array must have the same shape on every host.
is_source (bool | None) – optional bool indicating whether the calling host is the source. Only the source host contributes data to the broadcast. If None, host 0 (process index 0) is the source.
- Return type:
Any
- Returns:
A pytree matching in_tree in which every leaf holds the data from the source host.
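A minimal usage sketch follows. It assumes a multi-process job in which every process has already called `jax.distributed.initialize()`; the tree keys `seed` and `scale` are hypothetical. Because the broadcast is a collective operation, every process must make the same call. When run as a single process, the function simply returns NumPy copies of the inputs, so the sketch also runs on one machine.

```python
import numpy as np
import jax
from jax.experimental import multihost_utils

# Each host builds a pytree whose *values* depend on its process index but
# whose array *shapes* are identical on every host (required by this function).
# The keys "seed" and "scale" are hypothetical, for illustration only.
local_tree = {
    "seed": np.array(1234 + jax.process_index(), dtype=np.int32),
    "scale": np.full((3,), float(jax.process_index()), dtype=np.float32),
}

# Default: host 0 (process index 0) is the source. Every process must call this.
synced = multihost_utils.broadcast_one_to_all(local_tree)

# Alternative: pick a different source explicitly, here the last process.
# Exactly one process should evaluate is_source to True.
last = jax.process_count() - 1
synced_from_last = multihost_utils.broadcast_one_to_all(
    local_tree, is_source=(jax.process_index() == last)
)

# On every host, synced now holds host 0's values: seed 1234, scale all zeros.
print(synced["seed"], synced["scale"])
```

A typical reason to do this is to make a value chosen on one host, such as a random seed or a checkpoint step, identical everywhere before it feeds into computation that must agree across hosts.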
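To illustrate why the shapes must match and why only one host should act as the source, here is a simplified single-process model of the semantics. It assumes the broadcast behaves like a cross-host sum in which non-source hosts contribute zeros. This is an assumption made for illustration, not the library's code. The helper `simulate_broadcast_one_to_all` is hypothetical and handles one array per host rather than a full pytree; the real function applies the same idea to every leaf.

```python
import numpy as np

def simulate_broadcast_one_to_all(host_values, source=0):
    """Hypothetical model of broadcast_one_to_all, simulated in one process.

    host_values: list with one array per simulated host; all must share a shape.
    source: index of the simulated source host (host 0 by default).
    Returns one output array per simulated host.
    """
    shapes = {np.shape(v) for v in host_values}
    if len(shapes) != 1:
        raise ValueError(f"arrays must have the same shape on every host, got {shapes}")
    dtype = np.asarray(host_values[0]).dtype
    # Assumed model: the source contributes its data, all other hosts contribute zeros.
    contributions = [
        np.asarray(v) if i == source else np.zeros_like(v)
        for i, v in enumerate(host_values)
    ]
    # Summing across hosts then yields exactly the source's data.
    total = np.stack(contributions).sum(axis=0, dtype=dtype)
    # Every host receives its own copy of the same result.
    return [total.copy() for _ in host_values]

hosts = [np.array([1.0, 2.0]), np.array([5.0, 6.0]), np.array([9.0, 9.0])]
out = simulate_broadcast_one_to_all(hosts)                # every entry is [1.0, 2.0]
out_src2 = simulate_broadcast_one_to_all(hosts, source=2)  # every entry is [9.0, 9.0]
```

Under this model, mismatched shapes could not be stacked and summed, and if two hosts both claimed to be the source, their data would be added together rather than broadcast.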