jax.experimental.multihost_utils.process_allgather
- jax.experimental.multihost_utils.process_allgather(in_tree, tiled=False)
Gather data from across processes.
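- Parameters:
  - in_tree (Any) – pytree of arrays to gather. Each array must have the same shape on every process.
  - tiled (bool, default False) – whether to stack (False) or concatenate (True) the per-process values. This has no effect on non-fully addressable jax.Arrays.
- Return type:
  - Any (a pytree with the same structure as in_tree).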
- Returns:
  - A pytree of NumPy arrays:
    - If the input is a non-fully addressable jax.Array, the data is fully replicated, so every process receives the complete global array.
    - If the input is a NumPy array or a fully addressable jax.Array, the output shape depends on the tiled argument. If tiled is False, the per-process arrays are stacked along a new leading axis (index 0). If it is True, they are concatenated along the existing leading axis.
    - If the input is a scalar, the output is always stacked, regardless of tiled.
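The sketch below uses plain NumPy to show how the stacked and concatenated outputs differ in shape. It relies on a hypothetical helper, `simulate_allgather`, which takes a list standing in for the value each process would pass in. The helper does no cross-process communication and does not call JAX. It also does not model the non-fully addressable jax.Array case.

```python
import numpy as np


def simulate_allgather(per_process_values, tiled=False):
    """Hypothetical shape model of process_allgather (not the JAX API).

    per_process_values[i] stands in for the array that process i would pass.
    """
    arrays = [np.asarray(v) for v in per_process_values]
    # Every process must contribute an array of the same shape.
    if any(a.shape != arrays[0].shape for a in arrays):
        raise ValueError("all processes must pass arrays of the same shape")
    if tiled and arrays[0].ndim > 0:
        # tiled=True: concatenate along the existing leading axis.
        return np.concatenate(arrays, axis=0)
    # tiled=False, or scalar inputs: stack along a new leading axis 0.
    return np.stack(arrays, axis=0)


# Pretend there are 4 processes, each holding a (2, 3) array filled with its index.
local = [np.full((2, 3), i, dtype=np.float32) for i in range(4)]
print(simulate_allgather(local).shape)                             # (4, 2, 3)
print(simulate_allgather(local, tiled=True).shape)                 # (8, 3)
print(simulate_allgather([1.0, 2.0, 3.0, 4.0], tiled=True).shape)  # (4,)
```

The model uses np.stack for tiled=False and for scalars, and np.concatenate for tiled=True. On an actual multi-process job, the corresponding call would be `multihost_utils.process_allgather(local_array, tiled=...)`, where each process passes only its own array.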