jax.random module

Utilities for pseudo-random number generation.

The jax.random package provides a number of routines for deterministic generation of sequences of pseudorandom numbers.

Basic usage

>>> import jax
>>> seed = 1701
>>> num_steps = 100
>>> key = jax.random.PRNGKey(seed)
>>> for i in range(num_steps):
...   # compiled_update, params, and batches are assumed to be defined by the caller
...   key, subkey = jax.random.split(key)
...   params = compiled_update(subkey, params, next(batches))
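The loop above assumes that compiled_update, params, and batches are defined elsewhere. A minimal self-contained sketch of the same pattern follows. The update rule, the parameter shape, and the batch iterator are hypothetical stand-ins chosen only for illustration and are not part of jax.random.

```python
import itertools

import jax
import jax.numpy as jnp


@jax.jit
def compiled_update(rng, params, batch):
    # Hypothetical update: move params toward the batch mean and add a
    # little key-dependent noise.
    noise = 0.01 * jax.random.normal(rng, params.shape)
    return params + 0.1 * (batch.mean(axis=0) - params) + noise


def run(seed, num_steps):
    params = jnp.zeros(3)
    batches = itertools.repeat(jnp.ones((4, 3)))  # stand-in data iterator
    key = jax.random.PRNGKey(seed)
    for _ in range(num_steps):
        # Split every step: `key` carries on, `subkey` is consumed once.
        key, subkey = jax.random.split(key)
        params = compiled_update(subkey, params, next(batches))
    return params


params_a = run(seed=1701, num_steps=100)
params_b = run(seed=1701, num_steps=100)
```

Because all randomness flows from the seed through explicit keys, the two runs should produce the same parameters.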

PRNG Keys

Unlike the stateful pseudorandom number generators (PRNGs) that users of NumPy and SciPy may be accustomed to, JAX random functions all require an explicit PRNG state to be passed as a first argument. The random state is described by two unsigned 32-bit integers that we call a key, usually generated by the jax.random.PRNGKey() function:

>>> from jax import random
>>> key = random.PRNGKey(0)
>>> key
Array([0, 0], dtype=uint32)

This key can then be used in any of JAX’s random number generation routines:

>>> random.uniform(key)
Array(0.41845703, dtype=float32)

Note that using a key does not modify it, so reusing the same key will lead to the same result:

>>> random.uniform(key)
Array(0.41845703, dtype=float32)

If you need a new random number, you can use jax.random.split() to generate new subkeys:

>>> key, subkey = random.split(key)
>>> random.uniform(subkey)
Array(0.10536897, dtype=float32)
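When several independent draws are needed at once, jax.random.split() can also be asked for more than two keys. The following is a small sketch, assuming the default implementation (where each key is a pair of uint32 values); the variable names are illustrative only.

```python
import jax
from jax import random

key = random.PRNGKey(0)

# Request four keys at once; the result gains a new leading axis.
keys = random.split(key, 4)  # shape (4, 2) with the default implementation

# Keep the first key for later splits and consume the remaining three.
key, subkeys = keys[0], keys[1:]

# One uniform draw per subkey, vectorized over the leading axis.
samples = jax.vmap(random.uniform)(subkeys)
```

Each subkey is used exactly once here, which keeps the draws independent of each other and of anything later derived from key.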

Advanced

Design and Context

TLDR: JAX PRNG = Threefry counter PRNG + a functional array-oriented splitting model

See docs/jep/263-prng.md for more details.

To summarize, among other requirements, the JAX PRNG aims to:

  1. ensure reproducibility,

  2. parallelize well, both in terms of vectorization (generating array values) and multi-replica, multi-core computation. In particular it should not use sequencing constraints between random function calls.
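The two goals above can be illustrated with a short sketch. It assumes the default implementation and uses jax.random.fold_in() to derive one key per step index; the helper name draw is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(42)

# Reproducibility: the same key always yields the same values.
a = random.normal(key, (3,))
b = random.normal(key, (3,))


def draw(step):
    # The per-step key depends only on (key, step), so steps can be
    # evaluated in any order, or in parallel.
    return random.normal(random.fold_in(key, step), (2,))


# Evaluate the steps in forward order, then in reverse order (flipped back).
forward = jnp.stack([draw(i) for i in range(4)])
backward = jnp.stack([draw(i) for i in reversed(range(4))])[::-1]

# Vectorization: map the same function over the step index.
batched = jax.vmap(draw)(jnp.arange(4))
```

No call depends on hidden state left behind by another call, so the loop order does not affect the result, and the vectorized version is expected to match the per-step draws.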

Advanced RNG configuration

JAX provides several PRNG implementations (controlled by the jax_default_prng_impl flag).

  • default: A counter-based PRNG built around the Threefry hash function.

  • experimental: A PRNG that thinly wraps the XLA Random Bit Generator (RBG) algorithm. See TF doc.

    • “rbg” uses Threefry for splitting, and XLA RBG for data generation.

    • “unsafe_rbg” exists only for demonstration purposes, using RBG both for splitting (using an untested, made-up algorithm) and for generating.

    The random streams generated by these experimental implementations haven’t been subject to any empirical randomness testing (e.g. BigCrush). The random bits generated may change between JAX versions.
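As a rough sketch of how the jax_default_prng_impl flag mentioned above might be set from Python, the snippet below uses jax.config.update(). It assumes the value strings "rbg" and "threefry2x32" (the name of the default implementation in current JAX releases). Keys are created and consumed under the same setting, since a key made for one implementation is not meant to be used with another.

```python
import jax
from jax import random

# Switch the default implementation to the experimental "rbg" generator.
jax.config.update("jax_default_prng_impl", "rbg")
rbg_key = random.PRNGKey(0)
rbg_sample = random.uniform(rbg_key, (3,))

# Restore the default Threefry-based implementation.
jax.config.update("jax_default_prng_impl", "threefry2x32")
threefry_key = random.PRNGKey(0)
threefry_sample = random.uniform(threefry_key, (3,))
```

The two implementations produce different streams for the same seed, so results are only comparable between runs that use the same setting.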

Possible reasons not to use the default RNG are:

  1. it may be slow to compile (specifically for Google Cloud TPUs)

  2. it’s slower to execute on TPUs

  3. it doesn’t support efficient automatic sharding / partitioning

Here is a short summary:

Property                            Threefry  Threefry*  rbg  unsafe_rbg  rbg**  unsafe_rbg**
----------------------------------  --------  ---------  ---  ----------  -----  ------------
Fastest on TPU                      -         -          ✅   ✅          ✅     ✅
efficiently shardable (w/ pjit)     -         ✅         -    -           ✅     ✅
identical across shardings          ✅        ✅         ✅   ✅          -      -
identical across CPU/GPU/TPU        ✅        ✅         -    -           -      -
identical across JAX/XLA versions   ✅        ✅         -    -           -      -

(*): with jax_threefry_partitionable=1 set
(**): with XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1 set

The difference between “rbg” and “unsafe_rbg” is that while “rbg” uses a less robust/studied hash function for random value generation (but not for jax.random.split or jax.random.fold_in), “unsafe_rbg” additionally uses less robust hash functions for jax.random.split and jax.random.fold_in. It is therefore less safe, in the sense that the quality of the random streams it generates from different keys is less well understood.

For more about jax_threefry_partitionable, see https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
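A minimal sketch of opting in to the partitionable Threefry variant from Python, assuming the jax_threefry_partitionable flag can be set through jax.config.update() (newer JAX releases may already enable it by default):

```python
import jax
from jax import random

# Opt in to the partitionable Threefry variant.
jax.config.update("jax_threefry_partitionable", True)

key = random.PRNGKey(0)

# Values remain a deterministic function of the key and the requested shape.
x = random.normal(key, (8, 4))
y = random.normal(key, (8, 4))
```

Toggling this flag may change the values produced for a given key, so it is best set once at program start rather than partway through a run.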

List of Available Functions

PRNGKey(seed)

Create a pseudo-random number generator (PRNG) key given an integer seed.

ball(key, d[, p, shape, dtype])

Sample uniformly from the unit Lp ball.

bernoulli(key[, p, shape])

Sample Bernoulli random values with given shape and mean.

beta(key, a, b[, shape, dtype])

Sample Beta random values with given shape and float dtype.

bits(key[, shape, dtype])

Sample uniform bits in the form of unsigned integers.

categorical(key, logits[, axis, shape])

Sample random values from categorical distributions.

cauchy(key[, shape, dtype])

Sample Cauchy random values with given shape and float dtype.

chisquare(key, df[, shape, dtype])

Sample Chisquare random values with given shape and float dtype.

choice(key, a[, shape, replace, p, axis])

Generates a random sample from a given array.

dirichlet(key, alpha[, shape, dtype])

Sample Dirichlet random values with given shape and float dtype.

double_sided_maxwell(key, loc, scale[, ...])

Sample from a double sided Maxwell distribution.

exponential(key[, shape, dtype])

Sample Exponential random values with given shape and float dtype.

f(key, dfnum, dfden[, shape, dtype])

Sample F-distribution random values with given shape and float dtype.

fold_in(key, data)

Folds in data to a PRNG key to form a new PRNG key.

gamma(key, a[, shape, dtype])

Sample Gamma random values with given shape and float dtype.

generalized_normal(key, p[, shape, dtype])

Sample from the generalized normal distribution.

geometric(key, p[, shape, dtype])

Sample Geometric random values with given shape and float dtype.

gumbel(key[, shape, dtype])

Sample Gumbel random values with given shape and float dtype.

laplace(key[, shape, dtype])

Sample Laplace random values with given shape and float dtype.

loggamma(key, a[, shape, dtype])

Sample log-gamma random values with given shape and float dtype.

logistic(key[, shape, dtype])

Sample logistic random values with given shape and float dtype.

maxwell(key[, shape, dtype])

Sample from a one sided Maxwell distribution.

multivariate_normal(key, mean, cov[, shape, ...])

Sample multivariate normal random values with given mean and covariance.

normal(key[, shape, dtype])

Sample standard normal random values with given shape and float dtype.

orthogonal(key, n[, shape, dtype])

Sample uniformly from the orthogonal group O(n).

pareto(key, b[, shape, dtype])

Sample Pareto random values with given shape and float dtype.

permutation(key, x[, axis, independent])

Returns a randomly permuted array or range.

poisson(key, lam[, shape, dtype])

Sample Poisson random values with given shape and integer dtype.

rademacher(key, shape[, dtype])

Sample from a Rademacher distribution.

randint(key, shape, minval, maxval[, dtype])

Sample uniform random values in [minval, maxval) with given shape/dtype.

rayleigh(key, scale[, shape, dtype])

Sample Rayleigh random values with given shape and float dtype.

shuffle(key, x[, axis])

Shuffle the elements of an array uniformly at random along an axis.

split(key[, num])

Splits a PRNG key into num new keys by adding a leading axis.

t(key, df[, shape, dtype])

Sample Student's t random values with given shape and float dtype.

truncated_normal(key, lower, upper[, shape, ...])

Sample truncated standard normal random values with given shape and dtype.

uniform(key[, shape, dtype, minval, maxval])

Sample uniform random values in [minval, maxval) with given shape/dtype.

wald(key, mean[, shape, dtype])

Sample Wald random values with given shape and float dtype.

weibull_min(key, scale, concentration[, ...])

Sample from a Weibull distribution.
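The following sketch calls a handful of the samplers listed above, using one subkey per call. The shapes and parameter values are arbitrary choices for illustration.

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
# One subkey per draw, so the samples are independent of each other.
k_norm, k_int, k_bern, k_cat, k_perm, k_choice = random.split(key, 6)

# Standard normal values, shape (2, 3).
normals = random.normal(k_norm, (2, 3))

# Integers drawn uniformly from [0, 10).
ints = random.randint(k_int, (5,), minval=0, maxval=10)

# Boolean array; each entry is True with probability 0.3.
coins = random.bernoulli(k_bern, p=0.3, shape=(4,))

# Class indices in {0, 1, 2}, drawn according to the given logits.
logits = jnp.log(jnp.array([0.2, 0.3, 0.5]))
classes = random.categorical(k_cat, logits, shape=(6,))

# A random ordering of 0..4.
perm = random.permutation(k_perm, 5)

# Three distinct elements of 0..9, sampled without replacement.
picked = random.choice(k_choice, jnp.arange(10), shape=(3,), replace=False)
```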