jax.lax.all_to_all
- jax.lax.all_to_all(x, axis_name, split_axis, concat_axis, *, axis_index_groups=None, tiled=False)
Materialize the mapped axis and map a different axis.
If x is a pytree, the result is equivalent to mapping this function over each leaf in the tree.
In the output, the input mapped axis axis_name is materialized at the logical axis position concat_axis, and the input unmapped axis at position split_axis is mapped with the name axis_name.
The size of the mapped axis (or of each group, when axis_index_groups is given) must equal the size of the unmapped axis split_axis; that is, we must have lax.psum(1, axis_name, axis_index_groups=axis_index_groups) == x.shape[split_axis]. When tiled=True, x.shape[split_axis] only needs to be divisible by that size. By default, when axis_index_groups=None, a single group spans the whole mapped axis, i.e. all the devices.
- Parameters:
x – array(s) with a mapped axis named axis_name.
axis_name – hashable Python object used to name a pmapped axis (see the jax.pmap() documentation for more details).
split_axis – int indicating the unmapped axis of x to map with the name axis_name.
concat_axis – int indicating the position in the output to materialize the mapped axis of the input with the name axis_name.
.axis_index_groups – optional list of lists containing axis indices (e.g. for an axis of size 4, [[0, 1], [2, 3]] would run all_to_all over the first two and last two replicas). Groups must cover all axis indices exactly once, and all groups must be the same size.
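tiled – when True, all_to_all splits dimension split_axis of x into as many chunks as the group size, sends chunk k to the k-th member of the group, and concatenates the chunks it receives along concat_axis. In particular, no dimensions are added or removed. False by default.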
- Returns:
When tiled is False, array(s) with shape given by the expression:
np.insert(np.delete(x.shape, split_axis), concat_axis, axis_size)
where axis_size is the size of the mapped axis named axis_name in the input x (or the group size, when axis_index_groups is given), i.e. axis_size = lax.psum(1, axis_name, axis_index_groups=axis_index_groups).
When tiled is True, array(s) with the same shape as the input, except that dimension split_axis is divided by axis_size and dimension concat_axis is multiplied by axis_size.
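The following is a minimal sketch of the tiled=False behavior. It assumes a recent JAX release in which lax collectives can be used under jax.vmap with an axis_name, so vmap stands in for pmap and the example runs on a single CPU device. The names N, exchange, and x are illustrative only. Each of the N instances splits its (N, 3) block along split_axis=0, sends chunk k to instance k, and stacks what it receives at concat_axis=1; the per-instance output shape is then checked against the np.insert/np.delete expression above.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

N = 4  # size of the named axis "i" introduced by vmap (illustrative)

# Stacked input: vmap maps the leading axis as "i", so each instance sees
# a block of shape (N, 3) whose axis 0 has size N, as tiled=False requires.
x = jnp.arange(N * N * 3).reshape(N, N, 3)

def exchange(block):
    # Split axis 0 into N chunks, send chunk k to instance k, and stack the
    # received chunks (in sender order) at logical position 1 of the output.
    return lax.all_to_all(block, "i", split_axis=0, concat_axis=1)

out = jax.vmap(exchange, axis_name="i")(x)

# Per-instance output shape follows the tiled=False formula.
expected_block_shape = tuple(int(d) for d in np.insert(np.delete((N, 3), 0), 1, N))
print(out.shape[1:], expected_block_shape)  # (3, 4) (3, 4)

# Instance k receives chunk k from every instance j: out[k, :, j] == x[j, k, :].
print(bool(jnp.array_equal(out, jnp.transpose(x, (1, 2, 0)))))  # True
```

Because both input and output are stacked along the mapped axis here, the whole exchange amounts to a transpose of the stacked array; under pmap the same call moves the chunks between devices instead.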
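For tiled=True, a similar sketch (again assuming collectives under jax.vmap with an axis_name, standing in for pmap; N, exchange_tiled, and x are illustrative names) shows that no dimension is added or removed: split_axis shrinks by the axis size and concat_axis grows by it. A plain NumPy reference reproduces the expected values, with instance k concatenating chunk k of every instance j in sender order.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

N = 4  # size of the named axis "i" introduced by vmap (illustrative)

# Each instance holds an (8, 2) block; 8 is divisible by N, as tiled=True requires.
x = jnp.arange(N * 8 * 2).reshape(N, 8, 2)

def exchange_tiled(block):
    # Cut axis 0 into N chunks of 2 rows, send chunk k to instance k, and
    # concatenate the received chunks side by side along axis 1.
    return lax.all_to_all(block, "i", split_axis=0, concat_axis=1, tiled=True)

out = jax.vmap(exchange_tiled, axis_name="i")(x)
print(out.shape)  # (4, 2, 8): axis 0 of each block shrank by N, axis 1 grew by N

# NumPy reference: instance k concatenates chunk k of every instance j.
xs = np.asarray(x)
rows = 8 // N
ref = np.stack([
    np.concatenate([xs[j, k * rows:(k + 1) * rows, :] for j in range(N)], axis=1)
    for k in range(N)
])
print(np.array_equal(np.asarray(out), ref))  # True
```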