jax.random.shuffle

jax.random.shuffle(key, x, axis=0)

Shuffle the elements of an array uniformly at random along an axis.
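A minimal sketch of a 1-D shuffle follows. It substitutes `jax.random.permutation(key, x, independent=True)` for `jax.random.shuffle`, on the assumption that the two perform the same uniform shuffle (and because `shuffle` may not be available in every JAX release). The check at the end confirms that shuffling only reorders the elements.

```python
import jax
import jax.numpy as jnp

# Build a PRNG key; the same key always produces the same shuffle.
key = jax.random.PRNGKey(0)
x = jnp.arange(10)

# Stand-in for jax.random.shuffle(key, x): permutation with independent=True
# is assumed here to perform the same uniform shuffle along axis 0.
shuffled = jax.random.permutation(key, x, independent=True)

# A shuffle only reorders elements, so sorting recovers the original values.
print(shuffled)
print(jnp.array_equal(jnp.sort(shuffled), x))
```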

Parameters
  • key (Union[Any, PRNGKeyArray]) – the PRNG key that determines the shuffle, e.g. one created with jax.random.PRNGKey.

  • x (Any) – the array to be shuffled.

  • axis (int) – optional; the axis along which to shuffle (default 0).
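The sketch below illustrates the `axis` parameter on a 2-D array by shuffling within each row (`axis=1`). As an assumption, it again uses `jax.random.permutation(..., independent=True)` in place of `jax.random.shuffle`; in that mode each slice along the chosen axis receives its own random ordering, so every row remains some ordering of its original values.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
# Each row holds the same values 0..4, so the effect of the shuffle is easy to see.
x = jnp.tile(jnp.arange(5), (3, 1))  # shape (3, 5)

# Shuffle along axis=1 (within each row). independent=True is assumed to
# match jax.random.shuffle, which orders each slice independently.
out = jax.random.permutation(key, x, axis=1, independent=True)

print(out.shape)  # (3, 5)
# Every row is still some ordering of 0..4.
print(jnp.array_equal(jnp.sort(out, axis=1), x))
```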

Return type

ndarray

Returns

A shuffled copy of x with the same shape and dtype; x itself is not modified.
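The following sketch shows what the returned array looks like in practice: reusing a key reproduces the same shuffle exactly, splitting the key yields a fresh one, and the result keeps the input's shape and dtype. As in the other sketches, `jax.random.permutation(..., independent=True)` is used as an assumed stand-in for `jax.random.shuffle`.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(8, dtype=jnp.float32)
key = jax.random.PRNGKey(7)

# Same key -> identical shuffle (JAX randomness is explicit and deterministic).
a = jax.random.permutation(key, x, independent=True)
b = jax.random.permutation(key, x, independent=True)

# Split the key to draw a fresh shuffle instead of reusing the old key.
key, subkey = jax.random.split(key)
c = jax.random.permutation(subkey, x, independent=True)

print(jnp.array_equal(a, b))  # True
print(a.shape, a.dtype)  # (8,) float32
```

Because JAX arrays are immutable, `x` still holds `0, 1, ..., 7` in order after all three calls.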