jax.random.split

jax.random.split(key, num=2)

Splits a PRNG key into num new keys by adding a leading axis.

Parameters:
  • key (KeyArrayLike) – a PRNG key (from key, split, fold_in).

  • num (int | tuple[int, …]) – optional, a positive integer (or tuple of integers) indicating the number (or shape) of keys to produce. Defaults to 2.

Return type:
  KeyArray

Returns:
  An array-like object of num new PRNG keys.
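
A minimal usage sketch of the signature above, assuming new-style typed keys created with jax.random.key. The seed value 0 and the variable names are illustrative only. It shows the default two-way split, an integer num, and a tuple num:

```python
import jax

# Root key from an arbitrary seed (the seed 0 is an assumption for illustration).
key = jax.random.key(0)

# Default num=2: the result has a leading axis of length 2, so it unpacks
# into a carried-forward key and a subkey to consume.
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, (3,))

# Integer num: three new keys stacked along a new leading axis.
keys = jax.random.split(key, num=3)
print(keys.shape)  # (3,)

# Tuple num: new keys arranged in the requested shape.
grid = jax.random.split(key, num=(2, 4))
print(grid.shape)  # (2, 4)
```

The shapes in the comments hold for typed keys. With legacy raw keys from jax.random.PRNGKey, each key is itself a uint32 array, so the result carries an extra trailing axis.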