jax.random package

JAX pseudo-random number generators (PRNGs).

The JAX PRNG system is based on “Parallel random numbers: as easy as 1, 2, 3” (Salmon et al. 2011). For details on the design and its motivation, see:

https://github.com/google/jax/blob/master/design_notes/prng.md

jax.random.PRNGKey(seed)

Create a pseudo-random number generator (PRNG) key given an integer seed.

Parameters:
  • seed – a 64- or 32-bit integer used as the value of the key.
Returns:

A PRNG key, which is modeled as an array of shape (2,) and dtype uint32. The key is constructed from a 64-bit seed by effectively bit-casting to a pair of uint32 values (or from a 32-bit seed by first padding out with zeros).

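
A minimal sketch of creating a key, assuming the default configuration in which a key is a raw uint32 array. Because the seed 42 fits in 32 bits, the high word is zero-padded, so the key holds the pair (0, 42).

```python
import jax

# Create a key from a small integer seed.
key = jax.random.PRNGKey(42)

# The key is an ordinary array of shape (2,) and dtype uint32.
print(key.shape, key.dtype)

# A 32-bit seed is padded with zeros: high word 0, low word equal to the seed.
print(key)
```
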
jax.random.bernoulli(key, p=0.5, shape=None)

Sample Bernoulli random values with given shape and mean.

Parameters:
  • key – a PRNGKey used as the random key.
  • p – optional, a float or array of floats for the mean of the random variables. Must be broadcast-compatible with shape. Default 0.5.
  • shape – optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with p.shape. The default (None) produces a result shape equal to p.shape.
Returns:

A random array with boolean dtype and shape given by shape if shape is not None, or else p.shape.
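
A short sketch of the shape rules above; the probabilities and shapes are illustrative. An explicit shape of (4, 3) broadcasts against p.shape == (3,), while omitting shape returns one draw per entry of p.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# One success probability per column.
p = jnp.array([0.1, 0.5, 0.9])

# shape (4, 3) is broadcast-compatible with p.shape == (3,).
samples = jax.random.bernoulli(key, p=p, shape=(4, 3))

# Without an explicit shape, the result takes p.shape.
single = jax.random.bernoulli(key, p=p)
```
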

jax.random.beta(key, a, b, shape=None, dtype=<class 'numpy.float64'>)

Sample Beta random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • a – a float or array of floats broadcast-compatible with shape representing the first parameter “alpha”.
  • b – a float or array of floats broadcast-compatible with shape representing the second parameter “beta”.
  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with a and b. The default (None) produces a result shape by broadcasting a and b.
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified dtype and shape given by shape if shape is not None, or else by broadcasting a and b.
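
A minimal sketch of the default shape behavior, with illustrative parameter values: an array a of shape (3,) and a scalar b broadcast to a result of shape (3,), and any explicit shape must be broadcast-compatible with both.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

a = jnp.array([0.5, 1.0, 2.0])  # shape (3,)
b = 2.0                         # scalar

# With shape=None the result shape is the broadcast of a and b: (3,).
x = jax.random.beta(key, a, b)

# An explicit shape must be broadcast-compatible with a and b.
xs = jax.random.beta(key, a, b, shape=(1000, 3))
```

Beta draws lie in the unit interval, which makes them a natural choice for sampling random probabilities.
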

jax.random.cauchy(key, shape=(), dtype=<class 'numpy.float64'>)

Sample Cauchy random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified shape and dtype.

jax.random.dirichlet(key, alpha, shape=None, dtype=<class 'numpy.float64'>)

Sample Dirichlet random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • alpha – an array of shape (..., n) used as the concentration parameter of the random variables.
  • shape – optional, a tuple of nonnegative integers specifying the result batch shape; that is, the prefix of the result shape excluding the last axis, whose size is n. Must be broadcast-compatible with alpha.shape[:-1]. The default (None) produces a result shape equal to alpha.shape.
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified dtype and shape given by shape + (alpha.shape[-1],) if shape is not None, or else alpha.shape.
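
A sketch of how the batch shape combines with the trailing axis of alpha, using an illustrative three-category concentration vector. Passing shape=(5,) yields a (5, 3) array, and each row lies on the probability simplex.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Concentration parameters for three categories: alpha.shape == (3,).
alpha = jnp.array([1.0, 2.0, 3.0])

# shape is the batch shape; alpha.shape[-1] is appended, giving (5, 3).
probs = jax.random.dirichlet(key, alpha, shape=(5,))

# Each row is a probability vector, so it sums to 1 up to float rounding.
row_sums = probs.sum(axis=-1)
```
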

jax.random.exponential(key, shape=(), dtype=<class 'numpy.float64'>)

Sample Exponential random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified shape and dtype.
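
This sampler takes no rate parameter; the sketch below assumes it draws from the standard (rate 1) exponential distribution and obtains another rate by dividing, a general property of the exponential distribution rather than a feature of this function. The rate value is illustrative.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Standard exponential draws (rate 1, mean 1).
e = jax.random.exponential(key, shape=(10000,))

# Dividing by a rate lambda gives Exponential(lambda) draws with mean 1/lambda.
rate = 4.0
waits = e / rate
```
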

jax.random.fold_in(key, data)

Folds data into a PRNG key to form a new PRNG key.

Parameters:
  • key – a PRNGKey (an array with shape (2,) and dtype uint32).
  • data – a 32-bit integer representing data to be folded into the key.
Returns:

A new PRNGKey that is a deterministic function of the inputs and is statistically safe for producing a stream of new pseudo-random values.
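
A sketch of one common use, assuming a loop index is the data to fold in: each step gets its own reproducible key derived from one base key, without threading split results through the loop.

```python
import jax

base_key = jax.random.PRNGKey(0)

# Derive a distinct, reproducible key for each step by folding in the index.
step_keys = [jax.random.fold_in(base_key, step) for step in range(3)]

# The same inputs always give the same key...
same = jax.random.fold_in(base_key, 1)

# ...so draws made from it are reproducible.
x1 = jax.random.normal(step_keys[1])
x1_again = jax.random.normal(same)
```
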

jax.random.gamma(key, a, shape=None, dtype=<class 'numpy.float64'>)

Sample Gamma random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • a – a float or array of floats broadcast-compatible with shape representing the parameter of the distribution.
  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with a. The default (None) produces a result shape equal to a.shape.
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified dtype and with shape given by shape if shape is not None, or else by a.shape.
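
A sketch of the shape rules, assuming the sampler draws Gamma(a) with unit scale, so that each column mean should approach a. The parameter values are illustrative.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# One shape parameter per column; with shape=None the result takes a.shape.
a = jnp.array([0.5, 1.0, 5.0])
g = jax.random.gamma(key, a)

# Many draws per parameter: (10000, 3) broadcasts against a.shape == (3,).
gs = jax.random.gamma(key, a, shape=(10000, 3))

# With unit scale, the mean of Gamma(a) is a.
col_means = gs.mean(axis=0)
```
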

jax.random.gumbel(key, shape=(), dtype=<class 'numpy.float64'>)

Sample Gumbel random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified shape and dtype.
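
One well-known use of Gumbel noise is the Gumbel-max trick: argmax(logits + Gumbel noise) is distributed as Categorical(softmax(logits)). This sketch illustrates that general technique; it is not a categorical sampler provided by this module, and the probabilities are illustrative.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Log-probabilities for three categories.
logits = jnp.log(jnp.array([0.2, 0.3, 0.5]))

# Gumbel-max trick: add independent Gumbel noise and take the argmax per row.
noise = jax.random.gumbel(key, shape=(10000, 3))
draws = jnp.argmax(logits + noise, axis=-1)

# Empirical category frequencies should approach [0.2, 0.3, 0.5].
freqs = jnp.bincount(draws, length=3) / draws.shape[0]
```
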

jax.random.laplace(key, shape=(), dtype=<class 'numpy.float64'>)

Sample Laplace random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified shape and dtype.

jax.random.logistic(key, shape=(), dtype=<class 'numpy.float64'>)

Sample logistic random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified shape and dtype.

jax.random.multivariate_normal(key, mean, cov, shape=None, dtype=<class 'numpy.float64'>)

Sample multivariate normal random values with given mean and covariance.

Parameters:
  • key – a PRNGKey used as the random key.
  • mean – a mean vector of shape (..., n).
  • cov – a positive definite covariance matrix of shape (..., n, n). The batch shape ... must be broadcast-compatible with that of mean.
  • shape – optional, a tuple of nonnegative integers specifying the result batch shape; that is, the prefix of the result shape excluding the last axis. Must be broadcast-compatible with mean.shape[:-1] and cov.shape[:-2]. The default (None) produces a result batch shape by broadcasting together the batch shapes of mean and cov.
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified dtype and shape given by shape + mean.shape[-1:] if shape is not None, or else broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:].
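
A sketch with an illustrative two-dimensional mean and correlated covariance. With shape=(5000,) the result is (5000, 2), and the sample statistics should approach mean and cov.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

mean = jnp.array([1.0, -1.0])        # shape (n,) with n = 2
cov = jnp.array([[1.0, 0.8],
                 [0.8, 1.0]])        # shape (n, n), positive definite

# shape is the batch shape; the result shape is shape + mean.shape[-1:].
x = jax.random.multivariate_normal(key, mean, cov, shape=(5000,))

# Sample mean and covariance (rows are observations).
sample_mean = x.mean(axis=0)
sample_cov = jnp.cov(x, rowvar=False)
```
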

jax.random.normal(key, shape=(), dtype=<class 'numpy.float64'>)

Sample standard normal random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified shape and dtype.
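
A short sketch showing two points: an affine transform of standard normal draws gives N(loc, scale²) draws, and calling the sampler again with the same key returns the same values, which is why fresh randomness requires a new key from split or fold_in. The loc and scale values are illustrative.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Standard normal draws (mean 0, standard deviation 1).
z = jax.random.normal(key, shape=(3, 4))

# An affine transform gives N(loc, scale**2) draws.
loc, scale = 10.0, 2.0
x = loc + scale * z

# Reusing the same key reproduces exactly the same numbers.
z2 = jax.random.normal(key, shape=(3, 4))
```
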

jax.random.pareto(key, b, shape=None, dtype=<class 'numpy.float64'>)

Sample Pareto random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • b – a float or array of floats broadcast-compatible with shape representing the parameter of the distribution.
  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with b. The default (None) produces a result shape equal to b.shape.
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified dtype and with shape given by shape if shape is not None, or else by b.shape.
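
A sketch of the shape rules with illustrative parameters. It assumes unit scale, that is, support starting at 1 (density b / x^(b+1) for x ≥ 1), which is what the test below relies on.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# One shape parameter per entry; with shape=None the result takes b.shape.
b = jnp.array([1.0, 2.0, 3.0])
x = jax.random.pareto(key, b)

# An explicit shape must be broadcast-compatible with b.
xs = jax.random.pareto(key, b, shape=(1000, 3))
```
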

jax.random.randint(key, shape, minval, maxval, dtype=<class 'numpy.int64'>)

Sample uniform random values in [minval, maxval) with given shape/dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • shape – a tuple of nonnegative integers representing the result shape.
  • minval – int or array of ints broadcast-compatible with shape, a minimum (inclusive) value for the range.
  • maxval – int or array of ints broadcast-compatible with shape, a maximum (exclusive) value for the range.
  • dtype – optional, an int dtype for the returned values (default int64 if jax_enable_x64 is true, otherwise int32).
Returns:

A random array with the specified shape and dtype.
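
Because maxval is exclusive, simulating a six-sided die requires maxval=7. A minimal sketch:

```python
import jax

key = jax.random.PRNGKey(0)

# Ten die rolls: minval is inclusive, maxval is exclusive, so values are 1..6.
rolls = jax.random.randint(key, shape=(10,), minval=1, maxval=7)
```
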

jax.random.shuffle(key, x, axis=0)

Shuffle the elements of an array uniformly at random along an axis.

Parameters:
  • key – a PRNGKey used as the random key.
  • x – the array to be shuffled.
  • axis – optional, an int axis along which to shuffle (default 0).
Returns:

A shuffled version of x.

jax.random.split(key, num=2)

Splits a PRNG key into num new keys by adding a leading axis.

Parameters:
  • key – a PRNGKey (an array with shape (2,) and dtype uint32).
  • num – optional, a positive integer indicating the number of keys to produce (default 2).
Returns:

An array with shape (num, 2) and dtype uint32 representing num new keys.
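
A sketch of the usual key-passing pattern: split into a key to carry forward and a subkey to consume, so that no key is used twice. The variable names are illustrative.

```python
import jax

key = jax.random.PRNGKey(0)

# Split into a key to carry forward and a subkey to consume now.
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, shape=(2,))

# Split into several keys at once: the result has shape (num, 2).
keys = jax.random.split(key, num=3)
```
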

jax.random.t(key, df, shape=(), dtype=<class 'numpy.float64'>)

Sample Student’s t random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • df – a float or array of floats broadcast-compatible with shape representing the parameter of the distribution.
  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with df. If shape is None, the result shape equals df.shape.
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified dtype and with shape given by shape if shape is not None, or else by df.shape.
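
A sketch that passes shape explicitly in both calls, since the signature above shows shape=() as the default. The degrees-of-freedom values are illustrative.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Scalar degrees of freedom with an explicit result shape.
x = jax.random.t(key, df=5.0, shape=(1000,))

# One df per column; shape (1000, 3) is broadcast-compatible with df.shape.
df = jnp.array([2.0, 5.0, 30.0])
y = jax.random.t(key, df, shape=(1000, 3))
```
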

jax.random.threefry_2x32(keypair, count)

Apply the Threefry 2x32 hash.

Parameters:
  • keypair – a pair of 32-bit unsigned integers used for the key.
  • count – an array of dtype uint32 used for the counts.
Returns:

An array of dtype uint32 with the same shape as count.

jax.random.truncated_normal(key, lower, upper, shape=None, dtype=<class 'numpy.float64'>)

Sample truncated standard normal random values with given shape and dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • lower – a float or array of floats representing the lower bound for truncation. Must be broadcast-compatible with upper.
  • upper – a float or array of floats representing the upper bound for truncation. Must be broadcast-compatible with lower.
  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with lower and upper. The default (None) produces a result shape by broadcasting lower and upper.
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified dtype and shape given by shape if shape is not None, or else by broadcasting lower and upper.
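
A sketch with illustrative bounds. The first call uses scalar bounds and an explicit shape. The second uses array bounds and lets the result shape come from broadcasting lower and upper.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Standard normal draws restricted to the interval [-1, 2].
x = jax.random.truncated_normal(key, lower=-1.0, upper=2.0, shape=(1000,))

# Array bounds: with shape=None the result shape is broadcast((2,), ()) == (2,).
lower = jnp.array([-2.0, 0.0])
upper = 1.0
y = jax.random.truncated_normal(key, lower, upper)
```
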

jax.random.uniform(key, shape=(), dtype=<class 'numpy.float64'>, minval=0.0, maxval=1.0)

Sample uniform random values in [minval, maxval) with given shape/dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
  • minval – optional, a minimum (inclusive) value for the range (default 0).
  • maxval – optional, a maximum (exclusive) value for the range (default 1).
Returns:

A random array with the specified shape and dtype.
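
A minimal sketch of drawing from a range other than the default [0, 1), using illustrative bounds:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Floats in the half-open interval [-1, 1).
x = jax.random.uniform(key, shape=(1000,), minval=-1.0, maxval=1.0)
```
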