jax.random.split

jax.random.split(key, num=2)

Splits a PRNG key into num new keys by adding a leading axis.

Parameters:
key (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – a PRNG key (from PRNGKey, split, fold_in).
num (Union[int, tuple[int, ...]]) – optional, a positive integer (or tuple of integers) indicating the number (or shape) of keys to produce. Defaults to 2.

Return type: Array

Returns: An array-like object of num new PRNG keys.
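The sketch below shows how the num argument controls the shape of the result, and the common "split, keep one, consume one" pattern. It assumes the default configuration, in which jax.random.PRNGKey returns a raw uint32 key of shape (2,) (threefry2x32), and a JAX version that accepts a tuple for num, as the signature above indicates. The variable names pair, three, grid, subkey and sample are illustrative only.

```python
import jax
import jax.numpy as jnp

# Legacy raw key; under the default threefry2x32 implementation it has shape (2,).
key = jax.random.PRNGKey(0)

# Default num=2: one new leading axis of size 2 in front of the key's own shape.
pair = jax.random.split(key)

# An int num gives that many keys; a tuple num gives keys laid out in that shape.
three = jax.random.split(key, num=3)
grid = jax.random.split(key, num=(2, 3))

# Splitting is deterministic: the same input key always yields the same children.
same = jnp.array_equal(jax.random.split(jax.random.PRNGKey(0)), pair)

# Typical usage: keep one key for future splits, use the other exactly once.
key, subkey = jax.random.split(key)
sample = jax.random.normal(subkey, (4,))

print(pair.shape, three.shape, grid.shape, sample.shape, same)
```

With the default settings assumed above, the leading dimensions of pair, three and grid are (2,), (3,) and (2, 3), each followed by the key's own shape. Reusing a key after splitting it (rather than the returned children) would reuse randomness, which is why the pattern rebinds key to one of the outputs.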