jax.lax.pswapaxes

jax.lax.pswapaxes(x, axis_name, axis)

Swap the pmapped axis axis_name with the unmapped axis axis.
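Below is a minimal sketch of the swap under jax.pmap(). It assumes a CPU-only machine on which XLA is asked for four virtual host devices through the XLA_FLAGS environment variable. That setting is only a convenience for this demo and is not something pswapaxes requires. Each device holds one row of a square matrix. After the swap, device i holds column i, so the stacked result is the transpose of the stacked input.

```python
import os

# Demo-only assumption: on a CPU-only machine, request 4 virtual host devices
# so pmap has several devices to map over. This must run before jax is imported,
# and it is skipped if XLA_FLAGS is already set.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax
import jax.numpy as jnp
import numpy as np

n = jax.local_device_count()  # size of the pmapped axis "i"

# Device i receives row i, so each per-device x has shape (n,) and
# x.shape[0] == n, which is the size constraint pswapaxes imposes.
X = jnp.arange(n * n).reshape(n, n)

swapped = jax.pmap(
    lambda x: jax.lax.pswapaxes(x, axis_name="i", axis=0),
    axis_name="i",
)(X)

# Device i now holds column i of X: swapped[i, j] == X[j, i].
print(np.asarray(swapped))
```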

If x is a pytree then the result is equivalent to mapping this function to each leaf in the tree.
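The pytree behaviour can be sketched with jax.vmap() and its axis_name argument. These also bind a named axis, and they avoid the need for several devices. The dict keys "a" and "b" are illustrative only. Each leaf is swapped independently, exactly as if pswapaxes had been called on that leaf alone.

```python
import jax
import jax.numpy as jnp
import numpy as np

n = 3  # size of the named axis "i"

# Hypothetical pytree: vmap maps over the leading axis of every leaf.
tree = {
    "a": jnp.arange(n * n).reshape(n, n),         # per-member shape (n,)
    "b": jnp.arange(n * n * 2).reshape(n, n, 2),  # per-member shape (n, 2)
}

# pswapaxes is applied leaf by leaf; each leaf must satisfy x.shape[0] == n.
swap_tree = jax.vmap(lambda t: jax.lax.pswapaxes(t, "i", 0), axis_name="i")
out = swap_tree(tree)

# For every leaf, the mapped axis and unmapped axis 0 trade places,
# i.e. axes 0 and 1 of the stacked leaf are swapped.
print(out["a"])
```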

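The size of the mapped axis must equal the size of the unmapped axis; that is, lax.psum(1, axis_name) == x.shape[axis] must hold (lax.psum(1, axis_name) sums 1 over every member of axis_name and therefore returns the size of the mapped axis).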
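The following sketch shows what happens when the size constraint is violated. It again uses jax.vmap() to bind the named axis. The exception type used here is an assumption that may differ across JAX versions: a ValueError, raised by the shape check in the underlying all_to_all.

```python
import jax
import jax.numpy as jnp


def swap(x):
    # Swap the named axis "i" with unmapped axis 0 of the per-member array.
    return jax.lax.pswapaxes(x, "i", 0)


batched_swap = jax.vmap(swap, axis_name="i")

# Mapped size 4, per-member shape (4,): 4 == x.shape[0], so this succeeds.
ok = batched_swap(jnp.ones((4, 4)))

# Mapped size 4, per-member shape (3,): 4 != x.shape[0], so this is rejected.
try:
    batched_swap(jnp.ones((4, 3)))
    raised = False
except ValueError as err:
    raised = True
    print(err)
```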

This function is a special case of all_to_all in which the unmapped axis axis of the input is split across the members of axis_name and the pmapped axis is placed at position axis in the output. That is, it is equivalent to all_to_all(x, axis_name, axis, axis).
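The equivalence with all_to_all can be checked numerically in a small sketch. It binds the named axis with jax.vmap() rather than jax.pmap(), which is an assumption made so that the example runs on a single device. It passes the same axis as both split_axis and concat_axis.

```python
import jax
import jax.numpy as jnp
import numpy as np

n = 4
# Stacked input: mapped axis first, then per-member shape (n, 2).
X = jnp.arange(n * n * 2).reshape(n, n, 2)

via_pswapaxes = jax.vmap(
    lambda x: jax.lax.pswapaxes(x, "i", 0), axis_name="i")(X)

# split_axis == concat_axis == 0, mirroring all_to_all(x, axis_name, axis, axis).
via_all_to_all = jax.vmap(
    lambda x: jax.lax.all_to_all(x, "i", 0, 0), axis_name="i")(X)

# Both should equal swapping the mapped axis with unmapped axis 0,
# i.e. swapping axes 0 and 1 of the stacked array.
expected = jnp.swapaxes(X, 0, 1)
print(bool(jnp.array_equal(via_pswapaxes, via_all_to_all)))
```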

Parameters
  • x – array(s) with a mapped axis named axis_name.

  • axis_name – hashable Python object used to name a pmapped axis (see the jax.pmap() documentation for more details).

  • axis – int indicating the unmapped axis of x to map with the name axis_name.

Returns

Array(s) with the same shape as x.
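This sketch illustrates a nonzero axis argument and the shape-preserving return value. It again assumes jax.vmap() as a single-device stand-in for jax.pmap(). The swap exchanges two axes of equal size n, so each per-member output has the same shape as its input, and the data moves between the stacked axes 0 and 2.

```python
import jax
import jax.numpy as jnp
import numpy as np

n = 3
# Stacked input: mapped axis first, then per-member shape (2, n).
# Per-member axis 1 has size n, matching the size of the named axis "i".
X = jnp.arange(n * 2 * n).reshape(n, 2, n)

Y = jax.vmap(lambda x: jax.lax.pswapaxes(x, "i", axis=1), axis_name="i")(X)

# Same shape as the input, with Y[i, a, j] == X[j, a, i].
print(Y.shape)  # (3, 2, 3)
```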