jax.experimental.multihost_utils.process_allgather

jax.experimental.multihost_utils.process_allgather(in_tree, tiled=False)

Gather data from all processes and return the combined result on every process.

Parameters:
  • in_tree (Any) – Pytree of arrays. Each array must have the same shape on every host.

  • tiled (bool) – Whether to stack or concatenate the gathered arrays. Defaults to False, i.e. stack them along a new leading axis (index 0). If True, concatenate them along the existing axis 0.
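The two layouts selected by tiled can be modeled with plain NumPy. This is a sketch of the documented shape semantics only, not JAX's implementation: simulate_allgather is a hypothetical helper, and the list per_process_arrays stands in for the arrays each process would pass in, so the shapes can be checked without a multi-host setup.

```python
import numpy as np


def simulate_allgather(per_process_arrays, tiled=False):
    """Hypothetical helper modeling process_allgather's output shapes for
    NumPy / fully addressable inputs. per_process_arrays[i] stands in for
    the array that process i would pass in."""
    shapes = {np.shape(a) for a in per_process_arrays}
    if len(shapes) != 1:
        # Mirrors the requirement that every host passes the same shape.
        raise ValueError(f"arrays must have the same shape on every process, got {shapes}")
    arrays = [np.asarray(a) for a in per_process_arrays]
    if tiled and arrays[0].ndim > 0:
        # tiled=True: concatenate along the existing axis 0.
        return np.concatenate(arrays, axis=0)
    # tiled=False (or scalar input): stack along a new leading axis.
    return np.stack(arrays, axis=0)


# Two simulated processes, each holding a (2, 3) array.
a = np.zeros((2, 3))
b = np.ones((2, 3))
stacked = simulate_allgather([a, b])                  # shape (2, 2, 3)
tiled = simulate_allgather([a, b], tiled=True)        # shape (4, 3)
scalars = simulate_allgather([1.0, 2.0], tiled=True)  # shape (2,): scalars are stacked
```

Stacking keeps each process's contribution addressable by its process index, while tiling yields one longer array, which is convenient when every process holds a slice of the same batch.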

Return type:

Any

Returns:

A pytree of NumPy arrays with the same structure as in_tree.
  • If the input is a non-fully addressable jax.Array, the data is fully replicated, i.e. every process receives the complete global array.

  • If the input is a NumPy array or a fully addressable jax.Array, the output shape depends on the tiled argument: if it is False, the per-process arrays are stacked; otherwise they are concatenated.

  • If the input is a scalar, then the output will be stacked.