jax.random.shuffle

jax.random.shuffle(key, x, axis=0)

Shuffle the elements of an array uniformly at random along an axis.

Parameters
  • key (ndarray) – a PRNGKey used as the random key.

  • x (ndarray) – the array to be shuffled.

  • axis (int) – optional; the axis along which to shuffle (default 0).

Return type

ndarray

Returns

A shuffled version of x.
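
Example

The following is a minimal sketch of a call matching the signature above, not an official example. The helper name shuffle_along_axis is hypothetical. The sketch assumes that later JAX releases no longer expose jax.random.shuffle and that jax.random.permutation with independent=True is a suitable stand-in there; check both assumptions against the documentation for your installed version.

```python
import jax
import jax.numpy as jnp


def shuffle_along_axis(key, x, axis=0):
    """Hypothetical helper: shuffle x along axis on older or newer JAX releases."""
    legacy_shuffle = getattr(jax.random, "shuffle", None)
    if legacy_shuffle is not None:
        # Older releases: the function documented in this section.
        return legacy_shuffle(key, x, axis)
    # Assumed replacement on newer releases; independent=True shuffles each
    # 1-D slice along `axis` separately.
    return jax.random.permutation(key, x, axis=axis, independent=True)


key = jax.random.PRNGKey(0)          # explicit PRNG key, as the key parameter requires
x = jnp.arange(12).reshape(3, 4)     # 3 rows, 4 columns

shuffled = shuffle_along_axis(key, x, axis=1)  # reorder the entries within each row
print(shuffled.shape)
```

The result has the same shape as x and contains the same elements along the chosen axis. Reusing the same key reproduces the same ordering; to draw a different shuffle, derive a new key first, for example with jax.random.split.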