jax.random.clone

jax.random.clone(key)

Clone a key for reuse.

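Outside the context of key reuse checking (see jax.experimental.key_reuse), this function operates as an identity: it returns a key equivalent to its input, so sampling with the clone reproduces the values drawn from the original key.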
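One way to observe the identity behavior is to compare the raw key bits and the samples drawn from a key and its clone. This is a minimal sketch that assumes key reuse checking is left disabled (the default); it uses `jax.random.key_data` to inspect the underlying bits, and the variable names are illustrative only.

```python
import jax
import jax.numpy as jnp

# Assumption: key reuse checking is NOT enabled, so clone acts as an identity.
key = jax.random.key(0)
cloned = jax.random.clone(key)

# The underlying key bits are unchanged by cloning.
same_bits = bool(jnp.array_equal(jax.random.key_data(key),
                                 jax.random.key_data(cloned)))

# Consequently, sampling with either key yields identical values.
a = jax.random.normal(key, (3,))
b = jax.random.normal(cloned, (3,))
same_samples = bool(jnp.array_equal(a, b))

print(same_bits, same_samples)  # True True
```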

Example

>>> import jax
>>> key = jax.random.key(0)
>>> data = jax.random.uniform(key)
>>> cloned_key = jax.random.clone(key)
>>> same_data = jax.random.uniform(cloned_key)
>>> assert data == same_data
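To contrast cloning with deriving new keys, the sketch below draws from a key and its clone, and separately from the two keys produced by `jax.random.split`. It is illustrative only and assumes key reuse checking is disabled; a different seed is used for the split so that no key is consumed more than once except through the clone.

```python
import jax
import jax.numpy as jnp

# clone: the copy carries the same key bits, so both draws match.
key = jax.random.key(0)
cloned = jax.random.clone(key)
x1 = jax.random.uniform(key, (4,))
x2 = jax.random.uniform(cloned, (4,))

# split: derives new, independent keys, so their draws differ.
k1, k2 = jax.random.split(jax.random.key(1))
y1 = jax.random.uniform(k1, (4,))
y2 = jax.random.uniform(k2, (4,))

clone_matches = bool(jnp.array_equal(x1, x2))
split_matches = bool(jnp.array_equal(y1, y2))
print(clone_matches, split_matches)  # True False
```

In practice, `clone` is for deliberately reproducing the same random values, while `jax.random.split` is the usual way to obtain fresh randomness.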