pswapaxes(x, axis_name, axis)
Swap the pmapped axis axis_name with the unmapped axis axis.
If x is a pytree then the result is equivalent to mapping this function to each leaf in the tree.
The mapped axis size must be equal to the size of the unmapped axis; that is, we must have
lax.psum(1, axis_name) == x.shape[axis].
This function is a special case of all_to_all where the pmapped axis of the input is placed at the position axis in the output. That is, it is equivalent to all_to_all(x, axis_name, axis, axis).
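The description above can be checked with a small sketch. It assumes that jax.vmap with an axis_name argument is an acceptable stand-in for jax.pmap, so that it runs on a single device; the array sizes n and k and the helper names swap and via_all_to_all are made up for illustration. Viewed on the whole (stacked) array, swapping the named axis with per-device axis 0 should look like exchanging global axes 0 and 1.

```python
import jax
import jax.numpy as jnp
from jax import lax

n, k = 4, 3  # n: size of the named axis; k: an extra trailing dimension (illustrative)

# Global array: axis 0 is mapped under the name "i"; each slice x[j] has shape (n, k).
x = jnp.arange(n * n * k).reshape(n, n, k)

def swap(block):
    # Inside the mapped function, block has shape (n, k). Its axis 0 has size n,
    # which equals the size of the named axis, as pswapaxes requires.
    return lax.pswapaxes(block, "i", 0)

def via_all_to_all(block):
    # The same operation spelled as all_to_all(x, axis_name, axis, axis).
    return lax.all_to_all(block, "i", 0, 0)

# Assumption: vmap with axis_name replaces pmap here so one device is enough.
swapped = jax.vmap(swap, axis_name="i")(x)
same = jax.vmap(via_all_to_all, axis_name="i")(x)

# Globally, the mapped axis (0) and the per-device axis 0 (global axis 1) trade places.
expected = jnp.swapaxes(x, 0, 1)

print(swapped.shape == x.shape)
print(bool(jnp.array_equal(swapped, expected)))
print(bool(jnp.array_equal(same, expected)))
```

With jax.pmap in place of jax.vmap the same functions would apply, provided the number of participating devices equals x.shape[axis] for each per-device block.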
x – array(s) with a mapped axis named axis_name.
axis_name – hashable Python object used to name a pmapped axis (see the jax.pmap() documentation for more details).
axis – int indicating the unmapped axis of x to map with the name axis_name.
Array(s) with the same shape as x.