jax.random.split

jax.random.split(key, num=2)

Splits a PRNG key into num new keys by adding a leading axis.

Parameters:
    key (Array | ndarray | bool_ | number | bool | int | float | complex) – a PRNG key (from key, split, fold_in).
    num (int | tuple[int, ...]) – optional, a positive integer (or tuple of integers) indicating the number (or shape) of keys to produce. Defaults to 2.

Returns: An array-like object of num new PRNG keys.

Return type: Array
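
Example: the following is a minimal sketch of typical usage, not taken from the original page. It assumes a recent JAX release in which jax.random.key returns a typed key array; the seed 0 and the variable names (subkey, keys, grid, x) are illustrative choices.

```python
import jax

# Build a root key from a seed (the seed value 0 is an arbitrary assumption).
key = jax.random.key(0)

# Default num=2: the result has a leading axis of length 2, so it can be
# unpacked into a carried-forward key and a subkey to consume.
key, subkey = jax.random.split(key)

# An integer num produces a leading axis of that length.
keys = jax.random.split(key, num=3)
print(keys.shape)  # (3,) for a typed key

# A tuple num produces keys arranged in that leading shape.
grid = jax.random.split(key, num=(2, 4))
print(grid.shape)  # (2, 4) for a typed key

# Consume the subkey in a sampling call; avoid reusing it afterwards.
x = jax.random.normal(subkey, (3,))
```

With legacy raw uint32 keys (for example from jax.random.PRNGKey), the same calls should work, but each key carries a trailing data dimension, so the reported shapes would include it.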