jax.random.permutation

jax.random.permutation(key, x)

Permute elements of an array along its first axis or return a permuted range.
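The following is a minimal sketch of both modes of the call. It assumes a JAX release where `jax.random.PRNGKey` is available and uses only the `(key, x)` arguments from the signature above. The printed orders depend on the key, so they are not shown. Only the fact that each result contains the original elements is guaranteed.

```python
import jax
import jax.numpy as jnp

# A fixed key; any PRNG key works here.
key = jax.random.PRNGKey(0)

# Shuffle a 1-D array: the result holds the same elements in a random order.
x = jnp.array([10, 20, 30, 40, 50])
shuffled = jax.random.permutation(key, x)

# Passing an integer n shuffles the range 0, 1, ..., n - 1 instead.
perm = jax.random.permutation(key, 5)

print(shuffled)
print(perm)
```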

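If x is a multi-dimensional array, only its first axis is shuffled: whole slices along that axis (for a matrix, whole rows) are reordered, and the contents of each slice are left unchanged.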
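The sketch below illustrates the multi-dimensional case under the same assumptions as a typical JAX install. It shuffles a small matrix and checks that the set of rows is unchanged, so only their order differs.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)

# A 4x3 matrix whose rows are easy to tell apart.
m = jnp.arange(12).reshape(4, 3)
out = jax.random.permutation(key, m)

# Only the row order changes; every row keeps its contents intact.
original_rows = {tuple(row) for row in m.tolist()}
shuffled_rows = {tuple(row) for row in out.tolist()}

print(out.shape)                       # (4, 3)
print(original_rows == shuffled_rows)  # True
```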

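Parameters
  • key (ndarray) – a PRNG key (for example, one created with jax.random.PRNGKey) used as the source of randomness.

  • x (Any) – the array to be shuffled, or an integer n, in which case the range 0 through n - 1 is shuffled.

Returns

A shuffled copy of x, permuted along its first axis; if x is an integer n, a random permutation of the integers 0 through n - 1.

Return type

ndarray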
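Because JAX random functions are pure functions of their key, the same `key` and `x` always give the same permutation. The sketch below shows this, and also the usual `jax.random.split` pattern for drawing fresh permutations. The variable names are illustrative only.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(8)
key = jax.random.PRNGKey(0)

# Same key, same input -> identical permutation.
a = jax.random.permutation(key, x)
b = jax.random.permutation(key, x)

# Split the key to get new subkeys, each giving an independent permutation.
key, subkey1, subkey2 = jax.random.split(key, 3)
c = jax.random.permutation(subkey1, x)
d = jax.random.permutation(subkey2, x)

print(bool(jnp.array_equal(a, b)))  # True
```

Reusing one key for several draws would silently repeat the same shuffle. Splitting before each call is the standard way to avoid that.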