jax.random.shuffle

jax.random.shuffle(key, x, axis=0)

Shuffle the elements of an array uniformly at random along an axis.

Parameters:
  key (Union[Array, PRNGKeyArray]) – a PRNG key used as the random key.
  x (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – the array to be shuffled.
  axis (int) – optional, an int axis along which to shuffle (default 0).

Return type: Array

Returns: A shuffled version of x.
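The following is a minimal sketch of the shuffling behavior, assuming a recent JAX install. Because `jax.random.shuffle` emits a deprecation warning pointing to `jax.random.permutation` with `independent=True`, the sketch calls that replacement instead of `shuffle` itself. It assumes both functions share the same semantics: each one-dimensional slice along `axis` is shuffled independently, so whole rows are not moved together. The array `x` and the seed `0` are illustrative choices.

```python
import jax
import jax.numpy as jnp

# Illustrative PRNG key and input; the values are arbitrary.
key = jax.random.PRNGKey(0)
x = jnp.arange(12).reshape(3, 4)  # columns: [0,4,8], [1,5,9], [2,6,10], [3,7,11]

# Assumed equivalent of jax.random.shuffle(key, x, axis=0): each column
# (each 1-D slice along axis 0) is permuted independently of the others.
shuffled = jax.random.permutation(key, x, axis=0, independent=True)

# Reusing the same key yields exactly the same shuffle.
again = jax.random.permutation(key, x, axis=0, independent=True)

print(shuffled)
```

Sorting each column of `shuffled` along axis 0 recovers `x`, which shows that only the order within each column changed. Because JAX randomness is keyed, reusing a key reproduces the same result. To get a fresh shuffle, derive new keys with `jax.random.split`.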