jax.random.split

jax.random.split(key, num=2)

Splits a PRNG key into num new keys, stacked along a new leading axis.
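A minimal sketch of the usual split-then-consume pattern follows. It assumes the legacy raw-key API (jax.random.PRNGKey) with JAX's default PRNG implementation, and it uses jax.random.uniform and jax.random.normal only as illustrative consumers of the new keys. Each subkey is used once, and the remaining key is split again before the next draw.

```python
import jax

# Legacy raw key: a uint32 array of shape (2,) under the default PRNG implementation.
key = jax.random.PRNGKey(0)

# Split off a subkey to consume now; keep `key` for future splits.
key, subkey = jax.random.split(key)
x = jax.random.uniform(subkey, shape=(3,))

# Do not reuse the consumed subkey; split `key` again for the next draw.
key, subkey = jax.random.split(key)
y = jax.random.normal(subkey, shape=(3,))
```

Unpacking works because the default num=2 yields two rows, and each row is itself a complete key.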

Parameters
  • key (ndarray) – a PRNGKey (an array with shape (2,) and dtype uint32).

  • num (int) – optional, a positive integer indicating the number of keys to produce (default 2).
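The num parameter controls how many keys come back. The sketch below, assuming legacy raw keys and the default PRNG implementation, requests four keys. Because each row of the result is a full key, the batch can be mapped over with jax.vmap (used here only for illustration) to draw one independent sample per key.

```python
import jax

key = jax.random.PRNGKey(42)

# num=4 produces four keys stacked along a new leading axis.
keys = jax.random.split(key, num=4)

# Map over the leading axis: one scalar normal sample per key.
samples = jax.vmap(lambda k: jax.random.normal(k, ()))(keys)
```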

Return type

ndarray

Returns

An array with shape (num, 2) and dtype uint32 representing num new keys.
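The sketch below checks the return value described above, assuming the default PRNG implementation and legacy raw keys. It also illustrates that split is a pure function: calling it again with the same key and num reproduces exactly the same keys, while the individual rows differ from one another.

```python
import jax

key = jax.random.PRNGKey(0)
new_keys = jax.random.split(key, num=3)

# Shape (num, 2) and dtype uint32 under the default implementation.
print(new_keys.shape, new_keys.dtype)

# Same input key and num -> identical output keys (no hidden state).
same = jax.random.split(key, num=3)
deterministic = bool((new_keys == same).all())

# The three returned keys are distinct from one another.
distinct_rows = len({tuple(row) for row in new_keys.tolist()})
```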