jax.random package

JAX pseudo-random number generators (PRNGs).

The JAX PRNG system is based on “Parallel random numbers: as easy as 1, 2, 3” (Salmon et al. 2011). For details on the design and its motivation, see:

https://github.com/google/jax/blob/master/design_notes/prng.md

jax.random.PRNGKey(seed)

Create a pseudo-random number generator (PRNG) key given an integer seed.

Parameters:
  • seed – a 64- or 32-bit integer used as the value of the key.
Returns:

A PRNG key, which is modeled as an array of shape (2,) and dtype uint32. The key is constructed from a 64-bit seed by effectively bit-casting to a pair of uint32 values (or from a 32-bit seed by first padding out with zeros).

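The sketch below creates a key and checks its layout. It assumes a JAX install with the default Threefry-based PRNG and 64-bit mode disabled; the seed 42 is an arbitrary choice.

```python
import jax
import jax.numpy as jnp

# Create a key from an integer seed.
key = jax.random.PRNGKey(42)

# The key is a plain array of two uint32 words.
print(key.shape, key.dtype)  # (2,) uint32

# The same seed always yields the same key, which makes runs reproducible.
same_key = jax.random.PRNGKey(42)
print(bool(jnp.array_equal(key, same_key)))  # True
```
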
jax.random.bernoulli(key, p=0.5, shape=())

Sample Bernoulli random values with given shape and mean.

Parameters:
  • key – a PRNGKey used as the random key.
  • p – optional, an array-like of floating dtype broadcastable to shape for the mean of the random variables (default 0.5).
  • shape – optional, a tuple of nonnegative integers representing the shape (default scalar).
Returns:

A random array with the specified shape and boolean dtype.
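
A minimal sketch of drawing Bernoulli samples and of p broadcasting against shape. The seed, sample counts, and probabilities are arbitrary illustration values; the key is split so that each call gets its own key.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
key1, key2 = jax.random.split(key)

# 10,000 independent draws that are True with probability p = 0.3.
flips = jax.random.bernoulli(key1, p=0.3, shape=(10_000,))
print(flips.dtype)  # bool

# The fraction of True values should be close to p.
freq = jnp.mean(flips)

# p broadcasts against shape: column 0 uses p = 0.1, column 1 uses p = 0.9.
cols = jax.random.bernoulli(key2, p=jnp.array([0.1, 0.9]), shape=(5_000, 2))
col_freq = jnp.mean(cols, axis=0)
```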

jax.random.beta(key, a, b, shape=(), dtype=<class 'numpy.float64'>)

Sample Beta random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • a – an array-like broadcastable to shape and used as the shape parameter alpha of the random variables.
  • b – an array-like broadcastable to shape and used as the shape parameter beta of the random variables.
  • shape – optional, a tuple of nonnegative integers representing the shape (default scalar).
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified shape and dtype.
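
A short sketch that draws Beta(2, 5) samples and compares the sample mean with the analytic mean a / (a + b). The parameters and sample size are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
a, b = 2.0, 5.0

samples = jax.random.beta(key, a, b, shape=(10_000,))

# Beta(a, b) is supported on [0, 1] and has mean a / (a + b).
sample_mean = jnp.mean(samples)
expected_mean = a / (a + b)
```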

jax.random.cauchy(key, shape=(), dtype=<class 'numpy.float64'>)

Sample Cauchy random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • shape – optional, a tuple of nonnegative integers representing the shape (default scalar).
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified shape and dtype.
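
Because the Cauchy distribution has no mean, the sketch below summarizes samples with quantile-based statistics instead. It assumes the samples are standard Cauchy (location 0, scale 1), for which the median is 0 and half of the mass lies in [-1, 1]; the seed and sample size are arbitrary.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
samples = jax.random.cauchy(key, shape=(10_000,))

# Robust summaries: the median, and the fraction of draws within [-1, 1].
median = jnp.median(samples)
frac_within_1 = jnp.mean(jnp.abs(samples) <= 1.0)
```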

jax.random.dirichlet(key, alpha, shape=(), dtype=<class 'numpy.float64'>)

Sample Dirichlet random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • alpha – an array-like with alpha.shape[:-1] broadcastable to shape and used as the concentration parameter of the random variables.
  • shape – optional, a tuple of nonnegative integers representing the batch shape (defaults to alpha.shape[:-1]).
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified shape and dtype.
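
A minimal sketch showing how the batch shape and the trailing axis of alpha combine, and that each sample is a probability vector. The concentration values and batch size are arbitrary illustration choices.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
alpha = jnp.array([1.0, 2.0, 3.0])  # concentration for 3 categories

# shape is the batch shape; the trailing axis of alpha is appended,
# so the result has shape (5, 3).
samples = jax.random.dirichlet(key, alpha, shape=(5,))

# Every sample is a probability vector: nonnegative entries summing to 1.
row_sums = samples.sum(axis=-1)
```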

jax.random.exponential(key, shape=(), dtype=<class 'numpy.float64'>)

Sample Exponential random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • shape – optional, a tuple of nonnegative integers representing the shape (default scalar).
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified shape and dtype.
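
The sketch below assumes the samples follow the standard (rate 1) exponential distribution and shows how dividing by a rate produces draws from a different exponential distribution. The rate 2.0 and the sample size are arbitrary.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Standard exponential draws (assumed rate 1, mean 1).
samples = jax.random.exponential(key, shape=(10_000,))

# Dividing by a rate lambda gives Exponential(lambda) draws with mean 1 / lambda.
rate = 2.0
scaled = samples / rate
```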

jax.random.fold_in(key, data)

Fold an integer into a PRNG key to form a new PRNG key.

Parameters:
  • key – a PRNGKey (an array with shape (2,) and dtype uint32).
  • data – a 32-bit integer representing data to be folded into the key.
Returns:

A new PRNGKey that is a deterministic function of the inputs and is statistically safe for producing a stream of new pseudo-random values.
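
One common use, sketched below, is deriving a separate key for each loop iteration from the iteration index. The "training step" framing and the variable names are illustrative only.

```python
import jax
import jax.numpy as jnp

base_key = jax.random.PRNGKey(0)

# Derive one key per step from the step index, without having to
# thread a split key through the loop.
step_keys = [jax.random.fold_in(base_key, step) for step in range(3)]

# fold_in is deterministic: the same (key, data) pair gives the same key.
again = jax.random.fold_in(base_key, 1)
```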

jax.random.gamma(key, a, shape=(), dtype=<class 'numpy.float64'>)

Sample Gamma random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • a – an array-like broadcastable to shape and used as the shape parameter of the random variables.
  • shape – optional, a tuple of nonnegative integers representing the shape (default scalar).
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified shape and dtype.
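
A sketch assuming the samples have unit scale (density proportional to x^(a-1) e^(-x)), so that Gamma(a) has mean a and multiplying by theta yields a scale-theta distribution. The values a = 3 and theta = 2 are arbitrary.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
a = 3.0
samples = jax.random.gamma(key, a, shape=(10_000,))

# With unit scale, Gamma(a) has mean a. Multiplying by theta gives
# Gamma(a, scale=theta) draws with mean a * theta.
theta = 2.0
scaled = theta * samples
```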

jax.random.gumbel(key, shape=(), dtype=<class 'numpy.float64'>)

Sample Gumbel random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • shape – optional, a tuple of nonnegative integers representing the shape (default scalar).
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified shape and dtype.
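
Standard Gumbel noise is often used in the Gumbel-max trick: adding it to log-probabilities and taking the argmax yields a categorical sample. That technique is not part of this API; it is shown here only as an illustrative use, with arbitrary probabilities.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
probs = jnp.array([0.2, 0.5, 0.3])

# Gumbel-max trick: argmax(log p + G) with G i.i.d. standard Gumbel
# is a draw from the categorical distribution with probabilities p.
g = jax.random.gumbel(key, shape=(10_000, 3))
draws = jnp.argmax(jnp.log(probs) + g, axis=-1)

# Empirical category frequencies should be close to probs.
freqs = jnp.bincount(draws, length=3) / draws.shape[0]
```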

jax.random.laplace(key, shape=(), dtype=<class 'numpy.float64'>)

Sample Laplace random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • shape – optional, a tuple of nonnegative integers representing the shape (default scalar).
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified shape and dtype.
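
A sketch assuming standard Laplace samples (location 0, scale 1), for which the mean is 0 and the mean absolute value is 1. A location/scale shift gives a general Laplace distribution; mu and b are arbitrary illustration values.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
samples = jax.random.laplace(key, shape=(10_000,))

# Standard Laplace: mean 0 and mean absolute value 1.
sample_mean = jnp.mean(samples)
mean_abs = jnp.mean(jnp.abs(samples))

# A location/scale shift gives Laplace(mu, b) draws with mean mu.
mu, b = 1.0, 0.5
shifted = mu + b * samples
```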

jax.random.logistic(key, shape=(), dtype=<class 'numpy.float64'>)

Sample logistic random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • shape – optional, a tuple of nonnegative integers representing the shape (default scalar).
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified shape and dtype.
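
Assuming standard logistic samples, whose CDF is the sigmoid function, the fraction of draws below 1.0 should be close to sigmoid(1). The threshold and sample size below are arbitrary.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
samples = jax.random.logistic(key, shape=(10_000,))

# The standard logistic CDF is the sigmoid, so P(X < 1) = sigmoid(1).
frac_below_1 = jnp.mean(samples < 1.0)
expected = jax.nn.sigmoid(1.0)
```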

jax.random.normal(key, shape=(), dtype=<class 'numpy.float64'>)

Sample standard normal random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • shape – optional, a tuple of nonnegative integers representing the shape (default scalar).
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified shape and dtype.
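
Because the function samples the standard normal, a general normal distribution is obtained by scaling and shifting, as in this sketch with arbitrary mu and sigma.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
z = jax.random.normal(key, shape=(10_000,))

# Rescale standard normal draws to N(mu, sigma**2).
mu, sigma = 5.0, 2.0
x = mu + sigma * z
```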

jax.random.pareto(key, b, shape=(), dtype=<class 'numpy.float64'>)

Sample Pareto random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • b – an array-like broadcastable to shape and used as the shape parameter of the random variables.
  • shape – optional, a tuple of nonnegative integers representing the shape (default scalar).
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified shape and dtype.
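
A sketch assuming Pareto samples with unit scale, so that the support starts at 1 and, for b > 1, the mean is b / (b - 1). The shape parameter b = 3 is an arbitrary choice.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
b = 3.0
samples = jax.random.pareto(key, b, shape=(10_000,))

# Assumed unit scale: all draws are >= 1, and the mean is b / (b - 1).
sample_mean = jnp.mean(samples)
expected_mean = b / (b - 1)
```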

jax.random.randint(key, shape, minval, maxval, dtype=<class 'numpy.int64'>)

Sample uniform random values in [minval, maxval) with given shape/dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • shape – a tuple of nonnegative integers representing the shape.
  • minval – int or array of ints broadcast-compatible with shape, a minimum (inclusive) value for the range.
  • maxval – int or array of ints broadcast-compatible with shape, a maximum (exclusive) value for the range.
  • dtype – optional, an int dtype for the returned values (default int64 if jax_enable_x64 is true, otherwise int32).
Returns:

A random array with the specified shape and dtype.
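
Since minval is inclusive and maxval is exclusive, simulating a six-sided die uses the range [1, 7), as in this sketch (the number of rolls is arbitrary).

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Roll a six-sided die 1,000 times: values are drawn from 1, 2, ..., 6.
rolls = jax.random.randint(key, (1_000,), 1, 7)
```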

jax.random.shuffle(key, x, axis=0)

Shuffle the elements of an array uniformly at random along an axis.

Parameters:
  • key – a PRNGKey used as the random key.
  • x – the array to be shuffled.
  • axis – optional, an int axis along which to shuffle (default 0).
Returns:

A shuffled version of x.
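
Later JAX releases deprecate shuffle in favor of jax.random.permutation, a separate function in the same module, so the sketch below uses permutation to shuffle a 1-D array along its only axis. The assumption is that it serves the same purpose as shuffle for this case.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jnp.arange(10)

# permutation shuffles along the leading axis and returns a new array;
# x itself is unchanged, since JAX arrays are immutable.
shuffled = jax.random.permutation(key, x)
```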

jax.random.split(key, num=2)

Split a PRNG key into num new keys by adding a leading axis.

Parameters:
  • key – a PRNGKey (an array with shape (2,) and dtype uint32).
  • num – optional, a positive integer indicating the number of keys to produce (default 2).
Returns:

An array with shape (num, 2) and dtype uint32 representing num new keys.
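
A sketch of a common usage pattern: split off a fresh subkey before each sampling call instead of reusing a key. The variable names are illustrative.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Keep one key for future splits and consume the other.
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, shape=(3,))

# Split into several keys at once, e.g. one per parallel task.
many = jax.random.split(key, num=4)
```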

jax.random.t(key, df, shape=(), dtype=<class 'numpy.float64'>)

Sample Student’s t random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • df – an array-like broadcastable to shape and used as the degrees-of-freedom parameter of the random variables.
  • shape – optional, a tuple of nonnegative integers representing the shape (default scalar).
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified shape and dtype.
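
A sketch that checks a tail property of Student's t with 5 degrees of freedom: its 0.95 quantile is about 2.015, so roughly 90% of draws should fall in [-2.015, 2.015]. The value of df is an arbitrary choice.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
df = 5.0
samples = jax.random.t(key, df, shape=(10_000,))

# Fraction of draws inside the central 90% interval for df = 5.
frac_central = jnp.mean(jnp.abs(samples) <= 2.015)

# The distribution is symmetric about 0, so the median should be near 0.
median = jnp.median(samples)
```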

jax.random.threefry_2x32(keypair, count)

Apply the Threefry 2x32 hash.

Parameters:
  • keypair – a pair of 32-bit unsigned integers used for the key.
  • count – an array of dtype uint32 used for the counts.
Returns:

An array of dtype uint32 with the same shape as count.

jax.random.truncated_normal(key, lower, upper, shape=(), dtype=<class 'numpy.float64'>)

Sample truncated standard normal random values with given shape and dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • lower – a floating-point lower bound for truncation.
  • upper – a floating-point upper bound for truncation.
  • shape – optional, a tuple of nonnegative integers representing the shape (default scalar).
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified shape and dtype.
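
A sketch that checks draws stay inside the truncation bounds and compares the sample mean with the closed-form mean of a truncated standard normal, (pdf(lower) - pdf(upper)) / (cdf(upper) - cdf(lower)). The bounds are arbitrary; for [-1, 2] this mean is about 0.23.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

key = jax.random.PRNGKey(0)
lower, upper = -1.0, 2.0
samples = jax.random.truncated_normal(key, lower, upper, shape=(10_000,))

# Closed-form mean of the standard normal truncated to [lower, upper].
expected_mean = (norm.pdf(lower) - norm.pdf(upper)) / (norm.cdf(upper) - norm.cdf(lower))
sample_mean = jnp.mean(samples)
```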

jax.random.uniform(key, shape=(), dtype=<class 'numpy.float64'>, minval=0.0, maxval=1.0)

Sample uniform random values in [minval, maxval) with given shape/dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • shape – optional, a tuple of nonnegative integers representing the shape (default scalar).
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
  • minval – optional, a minimum (inclusive) value for the range (default 0).
  • maxval – optional, a maximum (exclusive) value for the range (default 1).
Returns:

A random array with the specified shape and dtype.
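
A sketch that draws from a custom range using minval and maxval; the range [-2, 3) is an arbitrary choice, and its mean is (minval + maxval) / 2 = 0.5.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
samples = jax.random.uniform(key, shape=(10_000,), minval=-2.0, maxval=3.0)

# Draws fall in [minval, maxval) with mean (minval + maxval) / 2.
sample_mean = jnp.mean(samples)
```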