jax.experimental.key_reuse module

Experimental Key Reuse Checking

This module contains experimental functionality for detecting reuse of random keys within JAX programs. It is under active development, and the APIs here are likely to change.

Key reuse checking can be enabled on jit-compiled functions using the jax.enable_key_reuse_checks() configuration option, here applied as a context manager:

>>> import jax
>>> @jax.jit
... def f(key):
...   return jax.random.uniform(key) + jax.random.normal(key)
...
>>> key = jax.random.key(0)
>>> with jax.enable_key_reuse_checks():
...   f(key)
Traceback (most recent call last):
 ...
KeyReuseError: In random_bits, key values a are already consumed.

This flag can also be set globally if you wish to enable key reuse checks in every jit-compiled function.
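The error above arises because `f` passes the same key to both `jax.random.uniform` and `jax.random.normal`. A minimal sketch of the usual fix, assuming the intent is two independent samples, is to split the key once with `jax.random.split` so that each sampler consumes its own subkey. The sketch does not turn the checks on itself; wrapping the call in the `jax.enable_key_reuse_checks()` context shown above is one way to confirm that this version is not flagged.

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(key):
    # Split the incoming key once: `split` consumes `key` and returns two
    # fresh subkeys, so no key is passed to more than one sampler.
    key_uniform, key_normal = jax.random.split(key)
    return jax.random.uniform(key_uniform) + jax.random.normal(key_normal)


key = jax.random.key(0)
result = f(key)  # a scalar floating-point array
```

Splitting once at the top of the function keeps every key's single use easy to audit by eye.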

API

reuse_key(key)

Explicitly mark a key as unconsumed.

KeyReuseError