jax.lax.all_to_all(x, axis_name, split_axis, concat_axis)

Materialize the mapped axis and map a different axis.

If x is a pytree then the result is equivalent to mapping this function to each leaf in the tree.
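The following sketch checks the pytree behavior. It calls all_to_all once on a whole dict and then once per leaf, and compares the two results. It assumes `jax.vmap` with `axis_name` can stand in for `jax.pmap`, which lets it run on a single device. The names `tree` and `exchange` are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax import lax

n = 4  # size of the mapped axis 'i'
tree = {
    "a": jnp.arange(n * n, dtype=jnp.float32).reshape(n, n),         # each instance sees shape (4,)
    "b": jnp.arange(n * n * 2, dtype=jnp.float32).reshape(n, n, 2),  # each instance sees shape (4, 2)
}

# vmap with axis_name is used here as a single-device stand-in for pmap.
exchange = jax.vmap(lambda t: lax.all_to_all(t, "i", 0, 0), axis_name="i")

whole_tree = exchange(tree)                            # one call on the whole pytree
leaf_by_leaf = jax.tree_util.tree_map(exchange, tree)  # one call per leaf

same = all(jnp.array_equal(whole_tree[k], leaf_by_leaf[k]) for k in tree)
print(same)
```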

In the output, the input mapped axis axis_name is materialized at the logical axis position concat_axis, and the input unmapped axis at position split_axis is mapped with the name axis_name.
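With split_axis=0 and concat_axis=0, each instance splits its block into n pieces and sends piece i to instance i. It then stacks the pieces it receives along a new axis 0, ordered by sender. Viewed from outside the map, this is a transpose of the first two axes. The sketch below checks that reading. As before, it assumes `jax.vmap(..., axis_name=...)` as a single-device substitute for `jax.pmap`, and the names `X` and `exchange` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax

n = 4
# Row j of X is the block held by mapped instance j; each instance sees shape (n, 3).
X = jnp.arange(n * n * 3, dtype=jnp.float32).reshape(n, n, 3)

def exchange(x):
    # Split x along axis 0 into n pieces, send piece i to instance i,
    # and stack the received pieces (ordered by sender) along a new axis 0.
    return lax.all_to_all(x, axis_name="i", split_axis=0, concat_axis=0)

out = jax.vmap(exchange, axis_name="i")(X)

# Instance i ends up with out[i, j] == X[j, i].
print(jnp.array_equal(out, jnp.swapaxes(X, 0, 1)))
```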

The input mapped axis size must be equal to the size of the axis to be mapped; that is, we must have lax.psum(1, axis_name) == x.shape[split_axis].
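The size requirement can be checked directly. `lax.psum(1, axis_name)` counts the instances along the named axis. A split axis of any other length is rejected when the function is traced, and JAX raises `ValueError` in this case. The sketch again uses `jax.vmap` with `axis_name` as a stand-in for `jax.pmap`. The arrays `X_ok` and `X_bad` are illustrative.

```python
import jax
import jax.numpy as jnp
from jax import lax

n = 4
X_ok = jnp.zeros((n, n, 2))   # instances see (4, 2): x.shape[0] == 4 == axis size
X_bad = jnp.zeros((n, 3, 2))  # instances see (3, 2): x.shape[0] == 3 != axis size

# psum of the constant 1 counts the instances along 'i'; x[0, 0] * 0 just
# makes the result per-instance so vmap returns one value per instance.
axis_size = jax.vmap(lambda x: x[0, 0] * 0 + lax.psum(1, "i"), axis_name="i")(X_ok)

exchange = jax.vmap(lambda x: lax.all_to_all(x, "i", 0, 0), axis_name="i")
ok_shape = exchange(X_ok).shape

try:
    exchange(X_bad)
    raised = False
except ValueError:
    # Rejected because x.shape[split_axis] != lax.psum(1, axis_name).
    raised = True

print(axis_size, ok_shape, raised)
```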

Parameters:

  • x – array(s) with a mapped axis named axis_name.

  • axis_name – hashable Python object used to name a pmapped axis (see the jax.pmap() documentation for more details).

  • split_axis – int indicating the unmapped axis of x to map with the name axis_name.

  • concat_axis – int indicating the position in the output to materialize the mapped axis of the input with the name axis_name.


Returns:

Array(s) with shape given by the expression:

np.insert(np.delete(x.shape, split_axis), concat_axis, axis_size)

where axis_size is the size of the mapped axis named axis_name in the input x, i.e. axis_size = lax.psum(1, axis_name).
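The shape rule can be checked by evaluating it with NumPy and comparing the result against the per-instance output shape of a real call. The example uses split_axis=0 and concat_axis=2. The helper `predicted_shape` is a hypothetical name for illustration, and `jax.vmap` with `axis_name` is assumed as a single-device substitute for `jax.pmap`.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

def predicted_shape(x_shape, split_axis, concat_axis, axis_size):
    # Per-instance output shape, computed with the documented expression.
    return tuple(int(d) for d in np.insert(np.delete(x_shape, split_axis), concat_axis, axis_size))

axis_size = 4
split_axis, concat_axis = 0, 2
X = jnp.zeros((axis_size, 4, 5, 6))  # leading axis is mapped; each instance sees (4, 5, 6)

f = jax.vmap(lambda x: lax.all_to_all(x, "i", split_axis, concat_axis), axis_name="i")
out = f(X)

per_instance_out = out.shape[1:]  # drop the mapped axis that vmap stacks at position 0
expected = predicted_shape(X.shape[1:], split_axis, concat_axis, axis_size)
print(per_instance_out, expected)
```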