jax.lax.all_gather

jax.lax.all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False)

Gather values of x across all replicas.

If x is a pytree, the result is equivalent to mapping this function over each leaf of the tree.

This is equivalent to, but faster than, all_to_all(broadcast(x)).

Parameters
  • x – array(s) with a mapped axis named axis_name.

  • axis_name – hashable Python object used to name a pmapped axis (see the jax.pmap() documentation for more details).

  • axis_index_groups – optional list of lists containing axis indices (e.g. for an axis of size 4, [[0, 1], [2, 3]] would run a separate all-gather over the first two and the last two replicas). Groups must cover all axis indices exactly once, and all groups must be the same size.

  • axis – a positional axis into which the chunks along axis_name will be concatenated.

  • tiled – when False, the chunks will be stacked into a fresh positional axis at index axis in the output. When True, axis has to refer to an existing positional dimension and the chunks will be concatenated into that dimension.
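The effect of axis and tiled can be sketched on a single device. The snippet below is a minimal illustration, not taken from the original documentation: it uses jax.vmap with axis_name in place of jax.pmap (an assumption that lets it run without multiple devices), and the names gather, stacked, stacked_last and tiled_out are hypothetical.

```python
import jax
import jax.numpy as jnp

# Four logical "replicas", each holding a length-2 chunk.
x = jnp.arange(8).reshape(4, 2)

def gather(axis, tiled):
    # Assumption: vmap with axis_name stands in for pmap so that the
    # sketch runs on one device. Each mapped instance sees a (2,) chunk.
    f = lambda chunk: jax.lax.all_gather(chunk, 'i', axis=axis, tiled=tiled)
    return jax.vmap(f, axis_name='i')(x)

stacked = gather(axis=0, tiled=False)       # per-instance shape (4, 2)
stacked_last = gather(axis=1, tiled=False)  # per-instance shape (2, 4)
tiled_out = gather(axis=0, tiled=True)      # per-instance shape (8,)

# Every instance receives the same gathered result, so inspect instance 0.
print(stacked[0])
print(stacked_last[0])
print(tiled_out[0])
```

With tiled=False the chunks are stacked along a new axis, so moving axis from 0 to 1 transposes the per-instance result. With tiled=True the chunks are concatenated end to end along the existing axis instead.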

Returns

Array(s) representing the result of an all-gather along the axis axis_name. The output shape matches x.shape, except that:

  • when tiled is False, a new dimension whose size equals the size of axis axis_name is inserted at position axis,

  • when tiled is True, the size of the existing dimension at position axis is multiplied by the size of axis axis_name.
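The two shape rules can be checked with a short sketch. The helper expected_shape below is hypothetical (it is not part of JAX) and simply restates the rules. As an assumption, jax.vmap with axis_name is used instead of jax.pmap so that the check runs on a single device.

```python
import jax
import jax.numpy as jnp

def expected_shape(x_shape, axis_size, axis, tiled):
    # Hypothetical helper restating the two rules; not a JAX function.
    shape = list(x_shape)
    if tiled:
        shape[axis] *= axis_size       # existing dimension grows
    else:
        shape.insert(axis, axis_size)  # new dimension is inserted
    return tuple(shape)

axis_size = 4
x = jnp.zeros((axis_size, 3, 5))  # per-instance x.shape is (3, 5)

shapes = {}
for tiled in (False, True):
    f = lambda v, tiled=tiled: jax.lax.all_gather(v, 'i', axis=1, tiled=tiled)
    out = jax.vmap(f, axis_name='i')(x)
    shapes[tiled] = out.shape[1:]  # drop vmap's leading mapped dimension
    print(tiled, shapes[tiled], expected_shape((3, 5), axis_size, 1, tiled))
```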

For example, with 4 XLA devices available:

>>> import jax
>>> import numpy as np
>>> x = np.arange(4)
>>> y = jax.pmap(lambda x: jax.lax.all_gather(x, 'i'), axis_name='i')(x)
>>> print(y)
[[0 1 2 3]
 [0 1 2 3]
 [0 1 2 3]
 [0 1 2 3]]

An example using axis_index_groups, with groups split by even and odd device IDs. The odd group is listed as [3, 1], so its chunks are gathered in that order:

>>> x = np.arange(16).reshape(4, 4)
>>> print(x)
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]
>>> def f(x):
...   return jax.lax.all_gather(
...       x, 'i', axis_index_groups=[[0, 2], [3, 1]])
>>> y = jax.pmap(f, axis_name='i')(x)
>>> print(y)
[[[ 0  1  2  3]
  [ 8  9 10 11]]
 [[12 13 14 15]
  [ 4  5  6  7]]
 [[ 0  1  2  3]
  [ 8  9 10 11]]
 [[12 13 14 15]
  [ 4  5  6  7]]]
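The grouping behaviour shown above can be mimicked with a plain NumPy reference model. This is only an illustrative sketch of the semantics as they appear in the example output, not how JAX implements the collective. The function name all_gather_reference is hypothetical.

```python
import numpy as np

def all_gather_reference(x, axis_index_groups):
    # x[i] plays the role of the value held by replica i.
    out = []
    for i in range(x.shape[0]):
        # Find the one group that contains replica i.
        group = next(g for g in axis_index_groups if i in g)
        # Stack the group's values in the order the group lists them.
        out.append(np.stack([x[j] for j in group]))
    return np.stack(out)

x = np.arange(16).reshape(4, 4)
y_ref = all_gather_reference(x, [[0, 2], [3, 1]])
print(y_ref)
```

Replicas 0 and 2 share one result and replicas 1 and 3 share the other, which is the pattern seen in the pmap output above.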