jax.lax.all_gather
- jax.lax.all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False)
Gather values of x across all replicas.
If x is a pytree then the result is equivalent to mapping this function to each leaf in the tree.
This is equivalent to, but faster than, all_to_all(broadcast(x)).
- Parameters:
x – array(s) with a mapped axis named axis_name.
axis_name – hashable Python object used to name a pmapped axis (see the jax.pmap() documentation for more details).
axis_index_groups – optional list of lists containing axis indices (e.g. for an axis of size 4, [[0, 1], [2, 3]] would run all-gather over the first two and last two replicas). Groups must cover all axis indices exactly once, and all groups must be the same size.
axis – a positional axis into which the chunks along axis_name will be concatenated.
tiled – when False, the chunks will be stacked into a fresh positional axis at index axis in the output. When True, axis has to refer to an existing positional dimension and the chunks will be concatenated into that dimension.
- Returns:
Array(s) representing the result of an all-gather along the axis axis_name. Shapes are the same as x.shape, but:
- when tiled is False, there is a new dimension equal to the size of axis axis_name in position axis,
- when tiled is True, the size of the dimension in position axis is multiplied by the size of axis axis_name.
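The two shape rules above can be sketched with a small NumPy model. This is not how JAX implements the collective: all_gather_model is a hypothetical helper, introduced here for illustration only, that takes the per-replica chunks as a plain Python list and returns the array each replica would end up holding.

```python
import numpy as np

def all_gather_model(chunks, axis=0, tiled=False):
    """Hypothetical NumPy model of the all-gather result seen by each replica.

    chunks: list of per-replica arrays, all with the same shape.
    """
    if tiled:
        # Concatenate the chunks into the existing dimension `axis`.
        return np.concatenate(chunks, axis=axis)
    # Stack the chunks into a fresh dimension inserted at index `axis`.
    return np.stack(chunks, axis=axis)

# Four replicas, each holding a chunk of shape (2, 3) filled with its index.
chunks = [np.full((2, 3), i) for i in range(4)]

stacked = all_gather_model(chunks, axis=0, tiled=False)
tiled = all_gather_model(chunks, axis=0, tiled=True)
stacked_last = all_gather_model(chunks, axis=2, tiled=False)

print(stacked.shape)       # (4, 2, 3): new dimension of size 4 at position 0
print(tiled.shape)         # (8, 3): dimension 0 multiplied by 4
print(stacked_last.shape)  # (2, 3, 4): new dimension of size 4 at position 2
```

With tiled=False the output rank is one higher than the rank of x, while with tiled=True the rank is unchanged and only the size of one dimension grows.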
For example, with 4 XLA devices available:
>>> import jax
>>> import numpy as np
>>> x = np.arange(4)
>>> y = jax.pmap(lambda x: jax.lax.all_gather(x, 'i'), axis_name='i')(x)
>>> print(y)
[[0 1 2 3]
 [0 1 2 3]
 [0 1 2 3]
 [0 1 2 3]]
An example of using axis_index_groups, groups split by even & odd device ids:
>>> x = np.arange(16).reshape(4, 4)
>>> print(x)
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]
>>> def f(x):
...   return jax.lax.all_gather(
...       x, 'i', axis_index_groups=[[0, 2], [3, 1]])
>>> y = jax.pmap(f, axis_name='i')(x)
>>> print(y)
[[[ 0  1  2  3]
  [ 8  9 10 11]]
 [[12 13 14 15]
  [ 4  5  6  7]]
 [[ 0  1  2  3]
  [ 8  9 10 11]]
 [[12 13 14 15]
  [ 4  5  6  7]]]
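To see why the grouped output has this layout, the behaviour can be modelled in plain NumPy. The sketch below is an assumption-laden illustration, not JAX code: check_groups and grouped_all_gather_model are hypothetical helpers, row i of x stands in for the chunk held by replica i, and each replica is assumed to receive the chunks of its own group in the order the group lists them.

```python
import numpy as np

def check_groups(groups, axis_size):
    """Raise ValueError unless groups split range(axis_size) into equal-size groups."""
    flat = [i for group in groups for i in group]
    if sorted(flat) != list(range(axis_size)):
        raise ValueError("groups must cover every axis index exactly once")
    if len({len(group) for group in groups}) != 1:
        raise ValueError("all groups must have the same size")

def grouped_all_gather_model(x, groups):
    """Hypothetical model: x[i] is replica i's chunk; returns all replicas' results stacked."""
    check_groups(groups, x.shape[0])
    out = [None] * x.shape[0]
    for group in groups:
        # Chunks are gathered in the order given by the group list.
        gathered = np.stack([x[i] for i in group])
        for i in group:
            out[i] = gathered
    return np.stack(out)

x = np.arange(16).reshape(4, 4)
y = grouped_all_gather_model(x, [[0, 2], [3, 1]])
print(y[1])  # replica 1 belongs to group [3, 1], so it holds rows 3 and 1
```

Replicas 0 and 2 share one result and replicas 1 and 3 share the other, which matches the output printed in the example above.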