jax.random.split
- jax.random.split(key, num=2)
Splits a PRNG key into num new keys by adding a leading axis.
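As a minimal sketch of the leading-axis behavior, the example below assumes new-style typed keys created with jax.random.key, which have shape (). It also shows the common idiom of keeping one key for future splits and consuming the other. jax.random.normal appears only to illustrate using a subkey.

```python
import jax

# Assumption: typed PRNG key from jax.random.key (scalar, shape ()).
key = jax.random.key(0)

# num defaults to 2, so split adds a leading axis of length 2.
keys = jax.random.split(key)
print(keys.shape)  # (2,)

# Idiom: carry `key` forward for later splits and consume `subkey` now.
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, (3,))
```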
- Parameters:
key (KeyArrayLike) – a PRNG key (e.g. from jax.random.key, jax.random.split, or jax.random.fold_in).
num (int | tuple[int, …]) – optional, a positive integer (or tuple of integers) indicating the number (or shape) of keys to produce. Defaults to 2.
- Return type:
KeyArray
- Returns:
An array-like object of num new PRNG keys.
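The sketch below checks the returned shapes for both forms of num. It assumes a JAX release recent enough to accept a tuple num, as the signature above states. For the legacy raw-key case, it also assumes the default threefry2x32 implementation, whose uint32 keys from jax.random.PRNGKey have shape (2,).

```python
import jax

key = jax.random.key(42)  # typed key, shape ()

# Integer num: one leading axis of length num.
keys = jax.random.split(key, num=4)
print(keys.shape)  # (4,)

# Tuple num: the leading axes take exactly that shape.
grid = jax.random.split(key, num=(2, 3))
print(grid.shape)  # (2, 3)

# Legacy raw keys keep a trailing key-data axis.
# Assumes the default threefry2x32 impl, where PRNGKey returns shape (2,).
raw = jax.random.PRNGKey(42)
raw_keys = jax.random.split(raw, 3)
print(raw_keys.shape)  # (3, 2)
```

Typed keys hide the underlying key data inside a dedicated key dtype. Their split result therefore has shape num, while raw keys carry an extra trailing axis.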