jax.experimental.multihost_utils.broadcast_one_to_all

jax.experimental.multihost_utils.broadcast_one_to_all(in_tree, is_source=None)

Broadcast data from a source host (host 0 by default) to all other hosts.
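A minimal sketch of a typical use follows. It assumes a multi-process job in which `jax.distributed.initialize()` has already been called on every process. Each process draws its own random seed, and the broadcast makes every process adopt the seed drawn on host 0. The helper name `agree_on_seed` is hypothetical. In a single-process run there is nothing to broadcast, so the value comes back unchanged.

```python
import numpy as np
from jax.experimental import multihost_utils


def agree_on_seed() -> int:
    """Hypothetical helper: return a seed that is identical on every process."""
    # Each process draws its own candidate; these generally differ across hosts.
    rng = np.random.default_rng()
    local_seed = np.asarray(rng.integers(0, 2**31 - 1), dtype=np.int32)
    # Only host 0 (the default source, since is_source is omitted) contributes
    # data; every other process receives host 0's value.
    shared_seed = multihost_utils.broadcast_one_to_all(local_seed)
    return int(shared_seed)


seed = agree_on_seed()
print(seed)
```

This approach broadcasts a fresh seed instead of hard-coding one. Separate runs therefore still differ from each other, while all hosts within a run stay consistent.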

Parameters:
  • in_tree (Any) – pytree of arrays; each array must have the same shape on every host.

  • is_source (bool | None) – optional bool indicating whether the calling host is the source. Only the source host contributes data to the broadcast. If None, host 0 is used as the source.
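The sketch below illustrates both parameters under the same multi-process assumption. Only the source host builds the real data, standing in for something like reading a checkpoint from disk. The other hosts pass placeholder arrays of the same shape, and their contents are ignored. `load_params` and `share_params` are hypothetical names chosen for illustration.

```python
import jax
import numpy as np
from jax.experimental import multihost_utils


def load_params() -> dict:
    """Hypothetical stand-in for an expensive load done only on the source host."""
    return {
        "w": np.arange(6, dtype=np.float32).reshape(2, 3),
        "b": np.ones(3, dtype=np.float32),
    }


def share_params() -> dict:
    # Explicitly designate host 0 as the source (same choice as is_source=None).
    is_source = jax.process_index() == 0
    if is_source:
        params = load_params()
    else:
        # Placeholders must match the source's shapes (dtypes are kept identical
        # here as well); their values do not contribute to the result.
        params = {
            "w": np.zeros((2, 3), dtype=np.float32),
            "b": np.zeros(3, dtype=np.float32),
        }
    return multihost_utils.broadcast_one_to_all(params, is_source=is_source)


params = share_params()
print(params["w"])
```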

Return type:

Any

Returns:

A pytree with the same structure as in_tree whose leaves all contain the data from the source host (host 0 unless is_source designates another host).