jax.random.shuffle

jax.random.shuffle(key, x, axis=0)

Shuffle the elements of an array uniformly at random along an axis.
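A minimal sketch of a 1-D shuffle follows. Recent JAX releases deprecate `jax.random.shuffle` and point to `jax.random.permutation(key, x, axis, independent=True)` instead, and some newer releases may no longer ship `shuffle` at all. The sketch therefore uses `permutation` with `independent=True`. Treat that equivalence as an assumption to check against your installed JAX version. The key seed (`0`) is arbitrary.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)  # arbitrary seed, chosen for illustration
x = jnp.arange(8)

# Shuffle along axis 0. independent=True is assumed to match the
# semantics of jax.random.shuffle (per its deprecation notice).
y = jax.random.permutation(key, x, axis=0, independent=True)

print(y)                                      # some reordering of 0..7
print(bool(jnp.array_equal(jnp.sort(y), x)))  # same elements, new order
```

The exact order depends on the key, so the sketch only checks that no element is lost or duplicated.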

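Parameters

key – a PRNG key used as the random key.

x – the array to be shuffled.

axis (int) – optional, the axis along which to shuffle (default 0).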
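The sketch below illustrates the `axis` parameter on a 2-D array. It assumes, as above, that `jax.random.permutation(..., independent=True)` stands in for `shuffle`. With that flag each 1-D slice along `axis` is shuffled on its own, so values never move between rows when `axis=1`. The seed is arbitrary.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(1)       # arbitrary seed, chosen for illustration
x = jnp.arange(12).reshape(3, 4)  # rows: [0..3], [4..7], [8..11]

# axis=1: each row is reordered independently of the other rows.
y = jax.random.permutation(key, x, axis=1, independent=True)

# x's rows are already sorted, so sorting each row of y recovers x.
print(bool(jnp.array_equal(jnp.sort(y, axis=1), x)))  # True
```

By contrast, `independent=False` (the default for `permutation`) applies one shared reordering to whole columns, so every row would be rearranged the same way.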
Return type

Array

Returns

A shuffled version of x.
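As a sketch of what the returned array looks like, the example below checks that the result keeps the shape and dtype of `x`, and that the same key always yields the same shuffle. A new shuffle needs a new key, obtained here with `jax.random.split`. As before, it assumes `jax.random.permutation(..., independent=True)` as a stand-in for `shuffle`, and the seed is arbitrary.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(10, dtype=jnp.float32)
key = jax.random.PRNGKey(42)  # arbitrary seed, chosen for illustration

# Same key, same result: JAX random functions are pure functions of the key.
a = jax.random.permutation(key, x, independent=True)
b = jax.random.permutation(key, x, independent=True)

# Split the key to draw a fresh shuffle instead of reusing the old key.
key, subkey = jax.random.split(key)
c = jax.random.permutation(subkey, x, independent=True)

print(a.shape == x.shape, a.dtype == x.dtype)  # True True
print(bool(jnp.array_equal(a, b)))             # True
print(bool(jnp.array_equal(jnp.sort(c), x)))   # True
```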