jax.lax.pshuffle
- jax.lax.pshuffle(x, axis_name, perm)
Convenience wrapper around jax.lax.ppermute that uses an alternate permutation encoding.
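The two encodings differ in direction. pshuffle's perm lists, for each output index i, the source index perm[i]. jax.lax.ppermute instead takes a list of (source_index, destination_index) pairs. The sketch below converts one encoding into the other and checks that both calls agree. It is illustrative only. The helper `pshuffle_perm_to_ppermute_pairs` is hypothetical and not part of JAX. It also assumes that a named axis created by `jax.vmap(..., axis_name=...)` is an acceptable stand-in for a pmapped axis, which lets it run on a single device.

```python
import jax
import jax.numpy as jnp


def pshuffle_perm_to_ppermute_pairs(perm):
    """Hypothetical helper: convert pshuffle's encoding (output index i
    reads from input index perm[i]) into ppermute's list of
    (source_index, destination_index) pairs."""
    return [(src, dst) for dst, src in enumerate(perm)]


perm = [2, 0, 3, 1]
pairs = pshuffle_perm_to_ppermute_pairs(perm)

# One value per position along the named axis "i".
x = jnp.arange(4.0) * 10  # [0., 10., 20., 30.]

# jax.vmap with axis_name creates a named axis on a single device, so the
# collectives run without needing multiple accelerators.
via_pshuffle = jax.vmap(lambda v: jax.lax.pshuffle(v, "i", perm), axis_name="i")(x)
via_ppermute = jax.vmap(lambda v: jax.lax.ppermute(v, "i", pairs), axis_name="i")(x)

print(pairs)                  # [(2, 0), (0, 1), (3, 2), (1, 3)]
print(via_pshuffle.tolist())  # [20.0, 0.0, 30.0, 10.0]
```

Output position 0 holds 20.0 because perm[0] == 2 and the input at index 2 is 20.0. The same data movement is expressed to ppermute as the pair (2, 0).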
If x is a pytree, the result is equivalent to mapping this function over each leaf in the tree.
- Parameters:
  - x – array(s) with a mapped axis named axis_name.
  - axis_name – hashable Python object used to name a pmapped axis (see the jax.pmap() documentation for more details).
  - perm – list of ints encoding sources for the permutation to be applied to the axis named axis_name, so that the output at axis index i comes from the input at axis index perm[i]. Every integer in [0, N) should be included exactly once for axis size N.
- Returns:
  Array(s) with the same shape as x, with slices along the axis axis_name gathered from x according to the permutation perm.
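The following is a minimal sketch of the return value for a pytree input. Each leaf keeps its shape, and the slice at index i along the named axis is taken from index perm[i]. This matches an ordinary gather `leaf[perm]` along the mapped axis. As in any such example, the use of `jax.vmap(..., axis_name="i")` in place of `jax.pmap` is an assumption made so the code runs on one device. The dictionary keys "a" and "b" are arbitrary.

```python
import jax
import jax.numpy as jnp

# Output index i takes its slice from input index perm[i].
perm = [1, 2, 0]

# A pytree whose leaves all have a leading axis of size 3; vmap maps over
# that axis and names it "i".
tree = {
    "a": jnp.arange(3.0),              # one scalar per index
    "b": jnp.arange(6).reshape(3, 2),  # one length-2 row per index
}

# pshuffle is applied leaf by leaf across the named axis.
shuffled = jax.vmap(lambda t: jax.lax.pshuffle(t, "i", perm), axis_name="i")(tree)

# Reference: an ordinary gather along the mapped (leading) axis of each leaf.
reference = jax.tree_util.tree_map(lambda leaf: leaf[jnp.array(perm)], tree)

print(shuffled["a"].tolist())  # [1.0, 2.0, 0.0]
print(shuffled["b"].tolist())  # [[2, 3], [4, 5], [0, 1]]
```

The perm length here equals the axis size (3), in line with the requirement that every integer in [0, N) appear exactly once.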