jax.lax.pshuffle(x, axis_name, perm)
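Convenience wrapper around jax.lax.ppermute() with an alternate permutation encoding: ppermute takes a list of (source, destination) index pairs, while pshuffle takes a list giving, for each output index, the input index it comes from.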
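A minimal sketch of how the two encodings correspond, assuming that converting each position i of perm into the ppermute pair (perm[i], i) reproduces the pshuffle result. The sketch uses jax.vmap with axis_name so the collective runs on a single device. The axis name "i" and the helper names with_pshuffle and with_ppermute are illustrative only.

```python
import jax
import jax.numpy as jnp

# Output index i takes the input at index perm[i].
perm = [3, 0, 1, 2]

def with_pshuffle(x):
    return jax.lax.pshuffle(x, axis_name="i", perm=perm)

def with_ppermute(x):
    # ppermute expects (source, destination) pairs. Since output i
    # comes from input perm[i], each pair is (perm[i], i).
    pairs = [(src, dst) for dst, src in enumerate(perm)]
    return jax.lax.ppermute(x, axis_name="i", perm=pairs)

xs = jnp.array([0, 10, 20, 30])
# vmap with axis_name binds the named axis, so no multi-device pmap is needed.
a = jax.vmap(with_pshuffle, axis_name="i")(xs)
b = jax.vmap(with_ppermute, axis_name="i")(xs)
# Both give [30, 0, 10, 20].
```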
If x is a pytree, the result is equivalent to mapping this function over each leaf in the tree.
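As an illustration of the pytree behavior, this sketch passes a dict whose leaves share a mapped leading axis of size 3 and applies the same perm to every leaf. Running under jax.vmap with an axis_name is an assumption made so the example works on a single device. The axis name "batch" and the function shuffle_tree are hypothetical.

```python
import jax
import jax.numpy as jnp

perm = [1, 2, 0]

def shuffle_tree(tree):
    # The same permutation is applied independently to each leaf.
    return jax.lax.pshuffle(tree, axis_name="batch", perm=perm)

tree = {
    "a": jnp.array([1, 2, 3]),
    "b": jnp.array([[1, 1], [2, 2], [3, 3]]),
}
# vmap maps over the leading axis (size 3) of every leaf.
out = jax.vmap(shuffle_tree, axis_name="batch")(tree)
# out["a"] -> [2, 3, 1]; out["b"] -> [[2, 2], [3, 3], [1, 1]]
```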
Parameters:
x – array(s) with a mapped axis named axis_name.
axis_name – hashable Python object used to name a pmapped axis (see the jax.pmap() documentation for more details).
perm – list of ints encoding the sources of the permutation applied to the axis named axis_name, so that the output at axis index i comes from the input at axis index perm[i]. Every integer in [0, N) should appear exactly once, where N is the size of the axis.
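A sketch of the perm encoding under jax.pmap, assuming the goal is to rotate values so that each device receives the value held by its next neighbor. Because perm is built from jax.local_device_count(), N matches the axis size on any machine, including a single-CPU setup where the rotation is trivially [0]. The axis name "devices" is illustrative.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
# Output at device i comes from input at device (i + 1) % n,
# so every integer in [0, n) appears exactly once.
perm = [(i + 1) % n for i in range(n)]

rotate = jax.pmap(
    lambda x: jax.lax.pshuffle(x, axis_name="devices", perm=perm),
    axis_name="devices",
)
out = rotate(jnp.arange(n))  # one value per device
# out[i] == (i + 1) % n
```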
Returns:
Array(s) with the same shape as x, with slices along the axis axis_name gathered from x according to the permutation perm.
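Because output i is gathered from input perm[i], a shuffle can be undone by applying the inverse permutation, which satisfies inv[perm[i]] = i and can be computed with numpy.argsort. A minimal sketch under that reasoning, again using jax.vmap with a named axis so it runs on one device. The names roundtrip and inv are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

perm = [2, 0, 3, 1]
# Inverse permutation: inv[perm[i]] == i, which is what argsort returns.
inv = [int(i) for i in np.argsort(perm)]

def roundtrip(x):
    y = jax.lax.pshuffle(x, "i", perm)  # y at index i is x at index perm[i]
    z = jax.lax.pshuffle(y, "i", inv)   # undoes the first shuffle
    return y, z

xs = jnp.array([5, 6, 7, 8])
shuffled, restored = jax.vmap(roundtrip, axis_name="i")(xs)
# shuffled -> [7, 5, 8, 6]; restored -> [5, 6, 7, 8]
```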