- jax.lax.all_to_all(x, axis_name, split_axis, concat_axis, *, axis_index_groups=None, tiled=False)
Materialize the mapped axis and map a different axis.
If x is a pytree, then the result is equivalent to mapping this function to each leaf in the tree.
In the output, the input mapped axis axis_name is materialized at the logical axis position concat_axis, and the input unmapped axis at position split_axis is mapped with the name axis_name.
The group size of the mapped axis must equal the size of the unmapped axis at split_axis; that is, we must have lax.psum(1, axis_name, axis_index_groups=axis_index_groups) == x.shape[split_axis]. By default, when axis_index_groups=None, the group encompasses all devices along the mapped axis. When tiled=True, x.shape[split_axis] only needs to be divisible by the group size.
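To see the axis swap concretely, the sketch below emulates a mapped axis of size 4 on a single device with jax.vmap(..., axis_name="i"), which lets collectives such as lax.all_to_all refer to the vmapped axis. The axis name "i" and the 4x4 input are illustrative choices, and the sketch assumes a JAX version in which all_to_all is supported under vmap; on real hardware the same function would typically run under jax.pmap or shard_map. With split_axis = concat_axis = 0, device i sends element j of its row to device j, so the stacked result should be the transpose of the stacked input.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Four emulated devices along the mapped axis "i"; device d holds row x[d].
x = jnp.arange(16).reshape(4, 4)

def exchange(row):
    # row has shape (4,). Its unmapped axis 0 (split_axis) has size 4,
    # which equals the size of the mapped axis: lax.psum(1, "i") == 4.
    return lax.all_to_all(row, "i", split_axis=0, concat_axis=0)

y = jax.vmap(exchange, axis_name="i")(x)

# Device j's output at position i is device i's input at position j,
# so stacking the per-device outputs gives the transpose of x.
print(y)
```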
Parameters:
x – array(s) with a mapped axis named axis_name.
axis_name – hashable Python object used to name a pmapped axis (see the jax.pmap() documentation for more details).
split_axis – int indicating the unmapped axis of x to map with the name axis_name.
concat_axis – int indicating the position in the output at which to materialize the mapped axis of the input with the name axis_name.
axis_index_groups – optional list of lists containing axis indices (e.g. for an axis of size 4, [[0, 1], [2, 3]] would run all_to_all over the first two and last two replicas). Groups must cover all axis indices exactly once, and all groups must be the same size.
tiled – when True, all_to_all will divide split_axis into as many equal chunks as there are members in the group and concatenate the received chunks along concat_axis. In particular, no dimensions are added or removed. False by default.
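The effect of axis_index_groups can be modeled without any devices. The NumPy function below, all_to_all_reference, is a hypothetical helper (not a JAX API) that takes the per-device inputs stacked along axis 0 and returns the per-device outputs stacked the same way, for the non-tiled case only. It assumes that, within each group, device k receives chunk k from every group member and places the chunks in the order the members are listed in the group.

```python
import numpy as np

def all_to_all_reference(xs, split_axis, concat_axis, axis_index_groups=None):
    """Hypothetical NumPy model of non-tiled all_to_all.

    xs[d] is the input x on device d; the result's [d] entry is device d's output.
    """
    n = xs.shape[0]
    groups = axis_index_groups or [list(range(n))]
    out = [None] * n
    for group in groups:
        g = len(group)
        # Non-tiled case: the split axis must have exactly one entry per group member.
        assert xs.shape[1 + split_axis] == g
        for k, dst in enumerate(group):
            # Device `dst` (k-th in its group) gets chunk k along split_axis from
            # every member, ordered by the member's position in the group.
            pieces = [np.take(xs[src], k, axis=split_axis) for src in group]
            # Stacking removes nothing and inserts a new axis at concat_axis,
            # matching np.insert(np.delete(x.shape, split_axis), concat_axis, g).
            out[dst] = np.stack(pieces, axis=concat_axis)
    return np.stack(out)

# Axis of size 4 split into two groups: devices 0 and 1 exchange with each
# other only, as do devices 2 and 3.
xs = np.arange(8).reshape(4, 2)
grouped = all_to_all_reference(xs, 0, 0, axis_index_groups=[[0, 1], [2, 3]])
print(grouped.tolist())
```

Without groups, the same helper applied to a square input reduces to a transpose, since every device exchanges with every other device.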
Returns:
When tiled is False, array(s) with shape given by the expression:
np.insert(np.delete(x.shape, split_axis), concat_axis, axis_size)
where axis_size is the size of the mapped axis named axis_name in the input x, i.e. axis_size = lax.psum(1, axis_name).
Otherwise, array(s) with the same shape as the input, except that the split_axis dimension is divided by axis_size and the concat_axis dimension is multiplied by axis_size.
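Both shape rules can be checked directly. This sketch emulates a mapped axis named "i" of size 4 with jax.vmap (an assumption: your JAX version must support all_to_all under vmap), compares the non-tiled output shape with the np.insert/np.delete expression, and shows that tiled=True keeps the rank while dividing split_axis and multiplying concat_axis by axis_size. The block shapes (3, 4, 5) and (8, 3) are arbitrary illustrative choices.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

axis_size = 4  # size of the emulated mapped axis "i"

# Non-tiled: each device holds a (3, 4, 5) block, so x.shape[split_axis] == axis_size.
xs = jnp.arange(axis_size * 3 * 4 * 5).reshape(axis_size, 3, 4, 5)
split_axis, concat_axis = 1, 2
ys = jax.vmap(lambda x: lax.all_to_all(x, "i", split_axis, concat_axis),
              axis_name="i")(xs)
per_device_in = xs.shape[1:]  # (3, 4, 5); the leading axis is the mapped one
expected = tuple(int(d) for d in
                 np.insert(np.delete(per_device_in, split_axis), concat_axis, axis_size))
print(ys.shape[1:], expected)  # per-device output shape vs. the formula

# Tiled: each device holds an (8, 3) block; 8 must be divisible by axis_size.
xt = jnp.arange(axis_size * 8 * 3).reshape(axis_size, 8, 3)
yt = jax.vmap(lambda x: lax.all_to_all(x, "i", 0, 1, tiled=True),
              axis_name="i")(xt)
# Axis 0 shrinks from 8 to 8 // 4 and axis 1 grows from 3 to 3 * 4.
print(yt.shape[1:])
```

In the tiled case, device k ends up with rows 2k to 2k + 1 of every device's block, placed side by side along axis 1 in device order.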