jax.lax.pswapaxes(x, axis_name, axis, *, axis_index_groups=None)
Swap the pmapped axis axis_name with the unmapped axis axis.
If x is a pytree then the result is equivalent to mapping this function to each leaf in the tree.
The group size of the mapped axis must equal the size of the unmapped axis; that is, we must have
lax.psum(1, axis_name, axis_index_groups=axis_index_groups) == x.shape[axis]. By default, when
axis_index_groups=None, the group encompasses all the devices.
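The size constraint and the swap itself can be illustrated with a minimal sketch. It assumes a CPU-only setup where XLA is asked for four virtual devices through the XLA_FLAGS environment variable; the names n, x, swap and out are illustrative only. The array is sized from the actual device count, so the per-device axis being swapped always has the same length as the mapped axis.

```python
import os

# Assumption: request 4 virtual CPU devices. This must happen before jax is
# imported, and it is not applied if XLA_FLAGS is already set.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

n = jax.local_device_count()  # size of the pmapped axis

# Global shape (n, n, 3): pmap maps over axis 0, so each device holds an
# (n, 3) block whose axis 0 has size n, as pswapaxes requires.
x = jnp.arange(n * n * 3, dtype=jnp.float32).reshape(n, n, 3)


def swap(x_local):
    # Swap the mapped axis "i" with unmapped axis 0 of the per-device block.
    return lax.pswapaxes(x_local, "i", 0)


out = jax.pmap(swap, axis_name="i")(x)

# Viewed as one stacked array, the result is x with axes 0 and 1 exchanged.
assert np.array_equal(np.asarray(out), np.swapaxes(np.asarray(x), 0, 1))
```

If the unmapped axis had any other length, the constraint x.shape[axis] == group size would be violated and the call would be rejected.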
This function is a special case of all_to_all where the pmapped axis of the input is placed at the position axis in the output. That is, it is equivalent to all_to_all(x, axis_name, axis, axis).
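The stated equivalence with all_to_all can be checked directly. The following is a sketch under the same assumption of virtual CPU devices requested through XLA_FLAGS; the helper names via_pswapaxes and via_all_to_all are hypothetical. It uses axis=1 to show that the swap is not limited to the leading unmapped axis.

```python
import os

# Assumption: request 4 virtual CPU devices before importing jax.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

n = jax.local_device_count()

# Global shape (n, 2, n): each device holds a (2, n) block, and the unmapped
# axis 1 has size n, matching the size of the mapped axis.
x = jnp.arange(n * 2 * n, dtype=jnp.float32).reshape(n, 2, n)


def via_pswapaxes(x_local):
    return lax.pswapaxes(x_local, "i", 1)


def via_all_to_all(x_local):
    # Same split and concatenation axis, as in the equivalence above.
    return lax.all_to_all(x_local, "i", 1, 1)


a = jax.pmap(via_pswapaxes, axis_name="i")(x)
b = jax.pmap(via_all_to_all, axis_name="i")(x)

# Both calls give the same values.
assert np.array_equal(np.asarray(a), np.asarray(b))
# In the stacked view, per-device axis 1 is global axis 2, so the mapped
# axis (global axis 0) has been exchanged with global axis 2.
assert np.array_equal(np.asarray(a), np.swapaxes(np.asarray(x), 0, 2))
```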
Parameters:
x – array(s) with a mapped axis named axis_name.
axis_name – hashable Python object used to name a pmapped axis (see the jax.pmap() documentation for more details).
axis – int indicating the unmapped axis of x to map with the name axis_name.
axis_index_groups – optional list of lists containing axis indices (e.g. for an axis of size 4, [[0, 1], [2, 3]] would run pswapaxes over the first two and last two replicas). Groups must cover all axis indices exactly once, and all groups must be the same size.
Returns:
Array(s) with the same shape as x.