jax.random package

Utilities for pseudo-random number generation.

The jax.random package provides a number of routines for deterministic generation of sequences of pseudorandom numbers.

Basic usage

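>>> import jax
>>> seed = 1701
>>> num_steps = 100
>>> key = jax.random.PRNGKey(seed)
>>> for i in range(num_steps):
...   # Split off a fresh subkey for this step; `key` is kept only for further splitting.
...   key, subkey = jax.random.split(key)
...   # `compiled_update`, `params`, and `batches` are user-supplied placeholders.
...   params = compiled_update(subkey, params, next(batches))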
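The loop above leaves `compiled_update`, `params`, and `batches` undefined. One possible way to fill them in so the loop runs end to end is sketched below. The mean-squared-error loss, learning rate, noise scale, and constant data stream are all hypothetical choices made purely for illustration; only the key handling mirrors the example.

```python
import jax
import jax.numpy as jnp


@jax.jit
def compiled_update(subkey, params, batch):
    # Hypothetical update: one gradient step on a mean-squared-error loss,
    # plus small Gaussian noise drawn from `subkey`.
    def loss(p):
        return jnp.mean((batch - p) ** 2)

    grad = jax.grad(loss)(params)
    noise = 0.01 * jax.random.normal(subkey, params.shape)
    return params - 0.1 * grad + noise


def make_batches(dim=4):
    # Hypothetical data stream: an endless supply of identical batches.
    while True:
        yield jnp.full((dim,), 3.0)


def train(seed=1701, num_steps=100):
    params = jnp.zeros((4,))
    batches = make_batches()
    key = jax.random.PRNGKey(seed)
    for i in range(num_steps):
        # Split off a fresh subkey each step; carry `key` forward for the next split.
        key, subkey = jax.random.split(key)
        params = compiled_update(subkey, params, next(batches))
    return params
```

Because every step consumes a fresh `subkey` and the loop threads `key` forward explicitly, calling `train()` twice with the same seed reproduces the same parameters on the same backend, while a different seed gives a different noise sequence.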

PRNG Keys

Unlike the stateful pseudorandom number generators (PRNGs) that users of NumPy and SciPy may be accustomed to, JAX random functions all require an explicit PRNG state to be passed as the first argument. The random state is described by two unsigned 32-bit integers that we call a key, usually generated by the jax.random.PRNGKey() function:

>>> from jax import random
>>> key = random.PRNGKey(0)
>>> key
DeviceArray([0, 0], dtype=uint32)

This key can then be used in any of JAX’s random number generation routines:

>>> random.uniform(key)
DeviceArray(0.41845703, dtype=float32)

Note that using a key does not modify it, so reusing the same key will lead to the same result:

>>> random.uniform(key)
DeviceArray(0.41845703, dtype=float32)

If you need a new random number, you can use jax.random.split() to generate new subkeys:

>>> key, subkey = random.split(key)
>>> random.uniform(subkey)
DeviceArray(0.10536897, dtype=float32)
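Both rules, that using a key never changes it and that split() is how fresh randomness is obtained, can be checked directly. The sketch below also passes the num argument of jax.random.split() to create several subkeys at once; the seed and the count of three are arbitrary.

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)

# Using a key does not change it, so the same key reproduces the same value.
first = random.uniform(key)
second = random.uniform(key)

# split(key, num) returns `num` new keys stacked along a leading axis.
keys = random.split(key, num=3)

# Each subkey yields its own sample, so the three values differ from one another.
samples = jnp.stack([random.uniform(k) for k in keys])

# Splitting is itself deterministic: the same key always splits the same way.
keys_again = random.split(key, num=3)
```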

Design and Context

Among other requirements, the JAX PRNG aims to:

  1. ensure reproducibility,

  2. parallelize well, both in terms of vectorization (generating array values) and multi-replica, multi-core computation. In particular, it should not impose sequencing constraints between random function calls.
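Because each random function depends only on the key it receives, draws made with different keys have no ordering dependence on each other. A minimal sketch of what that enables: mapping a sampler over a batch of split keys with jax.vmap produces the same values as calling it key by key in a Python loop, even when the loop runs in reverse order. The eight keys and the per-key shape (3,) are arbitrary illustrative choices.

```python
import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(42)
keys = random.split(key, 8)  # one independent key per "replica"


def sample_row(k):
    # Depends only on its own key, never on a previous call.
    return random.normal(k, (3,))


# Vectorized: all eight draws happen in one batched computation.
batched = jax.vmap(sample_row)(keys)

# Sequential and in reverse order, then flipped back into key order.
# No hidden state links one call to the next, so the values match.
looped = jnp.stack([sample_row(k) for k in keys[::-1]])[::-1]
```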

The approach is based on:

  1. “Parallel random numbers: as easy as 1, 2, 3” (Salmon et al. 2011)

  2. “Splittable pseudorandom number generators using cryptographic hashing” (Claessen et al. 2013)

See also https://github.com/google/jax/blob/main/design_notes/prng.md for the design and its motivation.

List of Available Functions

PRNGKey(seed)

Create a pseudo-random number generator (PRNG) key given an integer seed.

bernoulli(key[, p, shape])

Sample Bernoulli random values with given shape and mean.

beta(key, a, b[, shape, dtype])

Sample Beta random values with given shape and float dtype.

categorical(key, logits[, axis, shape])

Sample random values from categorical distributions.

cauchy(key[, shape, dtype])

Sample Cauchy random values with given shape and float dtype.

choice(key, a[, shape, replace, p])

Generates a random sample from a given 1-D array.

dirichlet(key, alpha[, shape, dtype])

Sample Dirichlet random values with given shape and float dtype.

double_sided_maxwell(key, loc, scale[, …])

Sample from a double-sided Maxwell distribution.

exponential(key[, shape, dtype])

Sample Exponential random values with given shape and float dtype.

fold_in(key, data)

Folds in data to a PRNG key to form a new PRNG key.

gamma(key, a[, shape, dtype])

Sample Gamma random values with given shape and float dtype.

gumbel(key[, shape, dtype])

Sample Gumbel random values with given shape and float dtype.

laplace(key[, shape, dtype])

Sample Laplace random values with given shape and float dtype.

logistic(key[, shape, dtype])

Sample logistic random values with given shape and float dtype.

maxwell(key[, shape, dtype])

Sample from a one-sided Maxwell distribution.

multivariate_normal(key, mean, cov[, shape, …])

Sample multivariate normal random values with given mean and covariance.

normal(key[, shape, dtype])

Sample standard normal random values with given shape and float dtype.

pareto(key, b[, shape, dtype])

Sample Pareto random values with given shape and float dtype.

permutation(key, x)

Permute elements of an array along its first axis or return a permuted range.

poisson(key, lam[, shape, dtype])

Sample Poisson random values with given shape and integer dtype.

rademacher(key, shape[, dtype])

Sample from a Rademacher distribution.

randint(key, shape, minval, maxval[, dtype])

Sample uniform random values in [minval, maxval) with given shape/dtype.

shuffle(key, x[, axis])

Shuffle the elements of an array uniformly at random along an axis.

split(key[, num])

Splits a PRNG key into num new keys by adding a leading axis.

t(key, df[, shape, dtype])

Sample Student’s t random values with given shape and float dtype.

truncated_normal(key, lower, upper[, shape, …])

Sample truncated standard normal random values with given shape and dtype.

uniform(key[, shape, dtype, minval, maxval])

Sample uniform random values in [minval, maxval) with given shape/dtype.

weibull_min(key, scale, concentration[, …])

Sample from a Weibull distribution.
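To show how several of the functions above are called, here is a short sketch. The seed, shapes, probabilities, and value ranges are arbitrary illustrative choices. Each sampler receives its own subkey from split(), and fold_in() derives per-step keys from integer data, such as a step counter, without splitting.

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
# One subkey per sampler, so no two draws share randomness.
k_norm, k_unif, k_int, k_cat, k_choice, k_perm, k_bern, k_fold = random.split(key, 8)

x = random.normal(k_norm, shape=(2, 3))                          # standard normal; float32 by default
u = random.uniform(k_unif, shape=(5,), minval=-1.0, maxval=1.0)  # values in [-1, 1)
i = random.randint(k_int, shape=(4,), minval=0, maxval=10)       # integers in [0, 10)

# categorical takes unnormalized log-probabilities; index 2 has probability 0.7.
logits = jnp.log(jnp.array([0.1, 0.2, 0.7]))
c = random.categorical(k_cat, logits, shape=(1000,))

# choice without replacement draws 3 distinct elements of arange(10).
s = random.choice(k_choice, 10, shape=(3,), replace=False)

# permutation of an integer n returns a shuffled arange(n).
perm = random.permutation(k_perm, 5)

# bernoulli returns a boolean array whose mean is approximately p.
b = random.bernoulli(k_bern, p=0.3, shape=(1000,))

# fold_in mixes integer data (here a step counter) into a key to form a new key.
step_keys = [random.fold_in(k_fold, step) for step in range(3)]
```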