jax.lax.ppermute(x, axis_name, perm)
Perform a collective permutation according to the permutation perm.

If x is a pytree, the result is equivalent to mapping this function over each leaf in the tree.
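As a minimal sketch of the pytree behavior, the following example runs ppermute under jax.vmap with an axis_name instead of jax.pmap, so it works on a single device. The dict keys "a" and "b" and the two-element swap permutation are illustrative choices, not part of the API. Each leaf is permuted independently along the mapped axis.

```python
import jax
import jax.numpy as jnp

def swap_pair(tree):
    # Exchange the slices at positions 0 and 1 of the mapped axis "i".
    # ppermute is applied to every leaf of the pytree separately.
    return jax.lax.ppermute(tree, axis_name="i", perm=[(0, 1), (1, 0)])

# Both leaves have a leading axis of size 2, which vmap maps as axis "i".
tree = {"a": jnp.array([1, 2]), "b": jnp.array([[1.0, 1.0], [2.0, 2.0]])}
out = jax.vmap(swap_pair, axis_name="i")(tree)
print(out["a"])  # [2 1]
print(out["b"])  # rows swapped: [[2. 2.] [1. 1.]]
```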
This function is an analog of the CollectivePermute HLO.
Parameters:
x – array(s) with a mapped axis named axis_name.
axis_name – hashable Python object used to name a pmapped axis (see the jax.pmap() documentation for more details).
perm – list of pairs of ints, representing (source_index, destination_index) pairs that encode how the mapped axis named axis_name should be shuffled. The integer values are treated as indices into the mapped axis axis_name. No two pairs should have the same source index or the same destination index. For each index of the axis axis_name that does not correspond to a destination index in perm, the corresponding values in the result are filled with zeros of the appropriate type.
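To make the direction of each (source_index, destination_index) pair concrete, here is a small sketch of a ring rotation, again using jax.vmap with axis_name as a single-device stand-in for a pmapped axis. The axis size of 4 and the name SIZE are assumptions for illustration. The pair (j, j + 1) sends the slice at index j to index j + 1, so the result at index k holds the input from index k - 1.

```python
import jax
import jax.numpy as jnp

SIZE = 4  # hypothetical size of the mapped axis

def rotate(x):
    # Every index is both a source and a destination: a full ring shift.
    perm = [(j, (j + 1) % SIZE) for j in range(SIZE)]
    return jax.lax.ppermute(x, axis_name="i", perm=perm)

x = jnp.array([10, 20, 30, 40])
y = jax.vmap(rotate, axis_name="i")(x)
print(y)  # [40 10 20 30]
```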
Returns:
Array(s) with the same shape as x, with slices along the axis axis_name gathered from x according to the permutation perm.
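The zero-filling rule can be seen with a partial permutation. This sketch assumes a CPU host on which XLA is asked to expose 4 virtual devices through the XLA_FLAGS option --xla_force_host_platform_device_count, which must be set before JAX initializes its backends. It uses jax.pmap rather than vmap because the vmap batching rule for ppermute may expect a full permutation (worth checking against your JAX version). The function name shift_right_no_wrap is hypothetical.

```python
import os
# Assumption: simulate 4 CPU devices so pmap has 4 shards to map over.
# This must run before JAX initializes its backends.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp

devices = jax.devices("cpu")[:4]
n = len(devices)

def shift_right_no_wrap(x):
    # Send shard i to shard i + 1 but omit the wrap-around pair (n - 1, 0).
    # Index 0 is never a destination, so its result is filled with zeros.
    perm = [(i, i + 1) for i in range(n - 1)]
    return jax.lax.ppermute(x, axis_name="i", perm=perm)

x = jnp.arange(1.0, n + 1.0)  # one scalar per device: [1. 2. 3. 4.]
y = jax.pmap(shift_right_no_wrap, axis_name="i", devices=devices)(x)
print(y)  # with 4 devices: [0. 1. 2. 3.]
```

The output keeps the shape and dtype of x; only the slice at the index that receives no data is replaced by zeros.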