jax.random.split

jax.random.split(key, num=2)

Splits a PRNG key into num new keys by adding a leading axis.
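The following minimal sketch shows what "adding a leading axis" means for the result's shape. It assumes a JAX version that provides new-style typed keys via `jax.random.key` alongside the legacy raw `uint32` keys from `jax.random.PRNGKey`, using the default `threefry2x32` implementation.

```python
import jax

# New-style typed key: a scalar key array with shape ().
key = jax.random.key(0)
keys = jax.random.split(key, num=3)
print(keys.shape)      # (3,): one new axis of length num

# Legacy raw key: a uint32 array of shape (2,) under the default threefry impl.
old_key = jax.random.PRNGKey(0)
old_keys = jax.random.split(old_key, num=3)
print(old_keys.shape)  # (3, 2): the new axis is prepended to the key's own shape
```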

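Parameters:

key – a PRNG key, for example one created by jax.random.key or jax.random.PRNGKey.
num – optional; the number of new keys to produce. Defaults to 2.
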
Return type:

Array

Returns:

An array-like object of num new PRNG keys.
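A common usage pattern, sketched below under the same assumption that `jax.random.key` is available, is to split before every random draw. One output replaces the old key for future splits and the other is consumed by a sampler, so no key is ever used twice. Because the new keys are stacked along the leading axis, they can be unpacked directly.

```python
import jax

key = jax.random.key(42)

# Default num=2: keep `key` for later splits, consume `subkey` now.
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, (4,))

# Request several keys at once and unpack them along the leading axis.
k1, k2, k3 = jax.random.split(key, num=3)
samples = [jax.random.uniform(k) for k in (k1, k2, k3)]
```

Reusing the same key for two draws would produce identical random values, which is why the split-then-consume pattern is the usual idiom.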