jax.lax.pshuffle

jax.lax.pshuffle(x, axis_name, perm)

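Convenience wrapper around jax.lax.ppermute that uses an alternate permutation encoding: instead of a list of (source, destination) index pairs, perm lists the source index for each output position.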
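The relationship between the two encodings can be checked with a small sketch. It assumes ppermute's documented convention of (source_index, destination_index) pairs. It uses jax.vmap with axis_name, rather than jax.pmap, so that it runs on a single device. The permutation and the helper names via_pshuffle and via_ppermute are illustrative only.

```python
import jax
import jax.numpy as jnp

# Illustrative permutation: output index i reads from input index perm[i].
perm = [1, 2, 3, 0]

# Convert the pshuffle encoding into ppermute's (source, destination) pairs.
pairs = [(src, dst) for dst, src in enumerate(perm)]

def via_pshuffle(v):
    return jax.lax.pshuffle(v, "i", perm)

def via_ppermute(v):
    return jax.lax.ppermute(v, "i", pairs)

x = jnp.arange(4.0)  # one scalar per index of the named axis "i"

# vmap(axis_name="i") creates the named axis that the collectives refer to.
a = jax.vmap(via_pshuffle, axis_name="i")(x)
b = jax.vmap(via_ppermute, axis_name="i")(x)
```

Both calls should produce the same result, with index i holding x[perm[i]].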

If x is a pytree, the result is equivalent to applying this function to each leaf in the tree.
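Pytree handling can be illustrated with a sketch like the following. The dict keys and the reversing permutation are arbitrary choices for the example. As before, jax.vmap with axis_name stands in for jax.pmap so that the code runs on one device. Every leaf must carry the mapped axis, here as its leading dimension.

```python
import jax
import jax.numpy as jnp

perm = [3, 2, 1, 0]  # reverse the named axis of size 4

# Hypothetical pytree: both leaves have a leading axis of size 4.
tree = {"a": jnp.arange(4.0), "b": jnp.arange(8.0).reshape(4, 2)}

# The same permutation is applied independently to every leaf.
out = jax.vmap(lambda t: jax.lax.pshuffle(t, "i", perm), axis_name="i")(tree)
```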

Parameters:
  • x – array(s) with a mapped axis named axis_name.

  • axis_name – hashable Python object used to name a pmapped axis (see the jax.pmap() documentation for more details).

  • perm – list of ints giving, for each index along the axis named axis_name, the index it is read from: the output at axis index i comes from the input at axis index perm[i]. For an axis of size N, every integer in [0, N) should appear exactly once.
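The gather semantics of perm can be seen in a small sketch. The permutation and input values are arbitrary. It assumes the named axis is created with jax.vmap(axis_name=...), which works on a single device. Under jax.pmap, the same function would need as many devices as the axis size.

```python
import jax
import jax.numpy as jnp

# Arbitrary permutation for an axis of size 4.
perm = [2, 0, 3, 1]

x = jnp.arange(4.0) * 10  # values 0, 10, 20, 30 along the named axis

shuffled = jax.vmap(lambda v: jax.lax.pshuffle(v, "i", perm), axis_name="i")(x)
# Index i holds x[perm[i]]: expected values 20, 0, 30, 10.
```

The output has the same shape as x. Only the order of slices along the named axis changes.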

Returns:

Array(s) with the same shape as x with slices along the axis axis_name gathered from x according to the permutation perm.