jax.random package

JAX pseudo-random number generators (PRNGs).

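Example usage (here seed, num_steps, compiled_update, params, and batches are placeholders supplied by the caller):

>>> import jax
>>> rng = jax.random.PRNGKey(seed)
>>> for i in range(num_steps):
...   rng, rng_input = jax.random.split(rng)  # derive a fresh key for each step
...   params = compiled_update(rng_input, params, next(batches))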

Context:

Among other requirements, the JAX PRNG aims to (a) ensure reproducibility and (b) parallelize well, both in terms of vectorization (generating array values) and multi-replica, multi-core computation. In particular, it should not impose sequencing constraints between random function calls.

The approach is based on:

1. “Parallel random numbers: as easy as 1, 2, 3” (Salmon et al. 2011)
2. “Splittable pseudorandom number generators using cryptographic hashing” (Claessen et al. 2013)

See also https://github.com/google/jax/blob/master/design_notes/prng.md for the design and its motivation.
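The absence of sequencing constraints can be illustrated with a minimal sketch: each sampler's output depends only on the key it receives, so the order of calls does not matter. The seed 0 and the shape (3,) are arbitrary choices, and the keys are reused here only to demonstrate the point; reusing keys is not a recommended practice.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
key_normal, key_uniform = jax.random.split(key)

# Draw in one order...
normal_first = jax.random.normal(key_normal, (3,))
uniform_second = jax.random.uniform(key_uniform, (3,))

# ...and then in the opposite order, reusing the same keys on purpose.
uniform_first = jax.random.uniform(key_uniform, (3,))
normal_second = jax.random.normal(key_normal, (3,))

# Each result depends only on its key, not on what was called before it.
normals_match = bool(jnp.array_equal(normal_first, normal_second))
uniforms_match = bool(jnp.array_equal(uniform_first, uniform_second))
order_independent = normals_match and uniforms_match
print(order_independent)
```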

jax.random.PRNGKey(seed)

Create a pseudo-random number generator (PRNG) key given an integer seed.

Parameters

seed (int) – a 64- or 32-bit integer used as the value of the key.

Return type

ndarray

Returns

A PRNG key, which is modeled as an array of shape (2,) and dtype uint32. The key is constructed from a 64-bit seed by effectively bit-casting to a pair of uint32 values (or from a 32-bit seed by first padding out with zeros).
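As a small illustration of the key layout described above, the sketch below builds a key from a 32-bit seed and inspects it. The seed value 42 is an arbitrary choice, and only the 32-bit case is shown.

```python
import jax
import numpy as np

# Build a key from a small (32-bit) integer seed; 42 is arbitrary.
key = jax.random.PRNGKey(42)

# View the key as a plain NumPy array: two uint32 words.
key_words = np.asarray(key)
print(key_words.shape, key_words.dtype)

# The same seed always yields the same key, which is what makes runs reproducible.
same_key_words = np.asarray(jax.random.PRNGKey(42))
print(bool((key_words == same_key_words).all()))
```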

jax.random.apply_round(v, rot)

jax.random.bernoulli(key, p=0.5, shape=None)

Sample Bernoulli random values with given shape and mean.

Parameters
  • key (ndarray) – a PRNGKey used as the random key.

  • p (ndarray) – optional, a float or array of floats for the mean of the random variables. Must be broadcast-compatible with shape. Default 0.5.

  • shape (Optional[Sequence[int]]) – optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with p.shape. The default (None) produces a result shape equal to p.shape.

Return type

ndarray

Returns

A random array with boolean dtype and shape given by shape if shape is not None, or else p.shape.
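The interaction between p and shape can be sketched as follows. The probabilities and the (4, 3) shape are made-up values for illustration.

```python
import jax
import jax.numpy as jnp

key_per_entry, key_grid = jax.random.split(jax.random.PRNGKey(0))

# With shape=None the result takes the shape of p: one draw per probability.
p = jnp.array([0.1, 0.5, 0.9])
draws = jax.random.bernoulli(key_per_entry, p)

# An explicit shape broadcasts a scalar p over the whole result.
coin_flips = jax.random.bernoulli(key_grid, 0.5, shape=(4, 3))

print(draws.shape, draws.dtype)
print(coin_flips.shape)
```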

jax.random.beta(key, a, b, shape=None, dtype=<class 'numpy.float64'>)

Sample Beta random values with given shape and float dtype.

Parameters
  • key (ndarray) – a PRNGKey used as the random key.

  • a (Union[float, ndarray]) – a float or array of floats broadcast-compatible with shape representing the first parameter “alpha”.

  • b (Union[float, ndarray]) – a float or array of floats broadcast-compatible with shape representing the second parameter “beta”.

  • shape (Optional[Sequence[int]]) – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with a and b. The default (None) produces a result shape by broadcasting a and b.

  • dtype (dtype) – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Return type

ndarray

Returns

A random array with the specified dtype and shape given by shape if shape is not None, or else by broadcasting a and b.

jax.random.categorical(key, logits, axis=-1, shape=None)

Sample random values from categorical distributions.

Parameters
  • key – a PRNGKey used as the random key.

  • logits – Unnormalized log probabilities of the categorical distribution(s) to sample from, so that softmax(logits, axis) gives the corresponding probabilities.

  • axis – Axis along which logits belong to the same categorical distribution.

  • shape – Optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with np.delete(logits.shape, axis). The default (None) produces a result shape equal to np.delete(logits.shape, axis).

Returns

A random array with int dtype and shape given by shape if shape is not None, or else np.delete(logits.shape, axis).
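A minimal sketch of how axis and shape relate to the logits array. The probabilities below are invented for illustration; each row of logits defines one distribution over three categories.

```python
import jax
import jax.numpy as jnp

key_single, key_batch = jax.random.split(jax.random.PRNGKey(0))

# Two independent distributions (rows) over three categories (columns).
logits = jnp.log(jnp.array([[0.2, 0.3, 0.5],
                            [0.9, 0.05, 0.05]]))

# Default shape: logits.shape with the category axis removed, here (2,).
samples = jax.random.categorical(key_single, logits, axis=-1)

# An explicit shape must be broadcast-compatible with (2,); (5, 2) draws
# five samples from each of the two distributions.
many_samples = jax.random.categorical(key_batch, logits, axis=-1, shape=(5, 2))

print(samples.shape, many_samples.shape)
```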

jax.random.cauchy(key, shape=(), dtype=<class 'numpy.float64'>)

Sample Cauchy random values with given shape and float dtype.

Parameters
  • key – a PRNGKey used as the random key.

  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns

A random array with the specified shape and dtype.

jax.random.choice(key, a, shape=(), replace=True, p=None)

Generates a random sample from a given 1-D array.

Parameters
  • key – a PRNGKey used as the random key.

  • a – 1-D array or int. If an ndarray, a random sample is generated from its elements. If an int, the random sample is generated as if a were arange(a).

  • shape – tuple of ints, optional. Output shape. If the given shape is, e.g., (m, n), then m * n samples are drawn. Default is (), in which case a single value is returned.

  • replace – boolean. Whether the sample is drawn with or without replacement. Default is True.

  • p – 1-D array-like. The probabilities associated with each entry in a. If not given, the sample assumes a uniform distribution over all entries in a.

Returns

An array of shape shape containing samples from a.
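The sketch below shows the two forms of a and the effect of replace and p. The item values and weights are hypothetical.

```python
import jax
import jax.numpy as jnp

key_int, key_array = jax.random.split(jax.random.PRNGKey(0))

# An int a behaves like arange(a): a (2, 3) block of values drawn from 0..9.
block = jax.random.choice(key_int, 10, shape=(2, 3))

# Weighted sampling without replacement from an explicit 1-D array.
items = jnp.array([10, 20, 30, 40])
weights = jnp.array([0.1, 0.2, 0.3, 0.4])
picked = jax.random.choice(key_array, items, shape=(3,), replace=False, p=weights)

print(block.shape, picked.shape)
```

Because replace=False, the three picked values are distinct entries of items.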

jax.random.dirichlet(key, alpha, shape=None, dtype=<class 'numpy.float64'>)

Sample Dirichlet random values with given shape and float dtype.

Parameters
  • key – a PRNGKey used as the random key.

  • alpha – an array of shape (..., n) used as the concentration parameter of the random variables.

  • shape – optional, a tuple of nonnegative integers specifying the result batch shape; that is, the prefix of the result shape excluding the last element of value n. Must be broadcast-compatible with alpha.shape[:-1]. The default (None) produces a result shape equal to alpha.shape.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns

A random array with the specified dtype and shape given by shape + (alpha.shape[-1],) if shape is not None, or else alpha.shape.
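A short sketch of the batch-shape convention described above, assuming an arbitrary concentration vector of length 3. The shape argument only covers the leading (batch) axes; the final axis of length n comes from alpha.

```python
import jax
import jax.numpy as jnp

key_single, key_batch = jax.random.split(jax.random.PRNGKey(0))

alpha = jnp.array([1.0, 2.0, 3.0])

# Default: result shape equals alpha.shape, here (3,).
single = jax.random.dirichlet(key_single, alpha)

# shape=(4,) is the batch shape; the result has shape (4,) + (3,).
batch = jax.random.dirichlet(key_batch, alpha, shape=(4,))

# Each sample lies on the probability simplex, so it sums to about 1.
row_sums = batch.sum(axis=-1)
print(single.shape, batch.shape)
```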

jax.random.double_sided_maxwell(key, loc, scale, shape=(), dtype=<class 'numpy.float64'>)

Sample from a double-sided Maxwell distribution.

Samples using:

loc + scale * sgn(U - 0.5) * one_sided_maxwell, where U ~ Uniform(0, 1)

Parameters
  • key – a PRNGKey used as the random key.

  • loc – The location parameter of the distribution.

  • scale – The scale parameter of the distribution.

  • shape – The shape added to the broadcast shape of the parameters loc and scale.

  • dtype – The type used for samples.

Returns

A jnp.array of samples.

jax.random.exponential(key, shape=(), dtype=<class 'numpy.float64'>)

Sample Exponential random values with given shape and float dtype.

Parameters
  • key – a PRNGKey used as the random key.

  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns

A random array with the specified shape and dtype.

jax.random.fold_in(key, data)

Folds in data to a PRNG key to form a new PRNG key.

Parameters
  • key – a PRNGKey (an array with shape (2,) and dtype uint32).

  • data – a 32-bit integer representing data to be folded in to the key.

Returns

A new PRNGKey that is a deterministic function of the inputs and is statistically safe for producing a stream of new pseudo-random values.
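One common use of fold_in is deriving a distinct key from a base key and an integer such as a step counter. The sketch below assumes a hypothetical three-step loop.

```python
import jax
import numpy as np

base_key = jax.random.PRNGKey(0)

# Derive one key per step index (the three-step loop is hypothetical).
step_keys = [jax.random.fold_in(base_key, step) for step in range(3)]

# Folding in the same data again reproduces the same key...
repeated = jax.random.fold_in(base_key, 1)
is_deterministic = bool((np.asarray(repeated) == np.asarray(step_keys[1])).all())

# ...while different data gives a different key.
is_distinct = bool((np.asarray(step_keys[0]) != np.asarray(step_keys[1])).any())
print(is_deterministic, is_distinct)
```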

jax.random.gamma(key, a, shape=None, dtype=<class 'numpy.float64'>)

Sample Gamma random values with given shape and float dtype.

Parameters
  • key – a PRNGKey used as the random key.

  • a – a float or array of floats broadcast-compatible with shape representing the parameter of the distribution.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with a. The default (None) produces a result shape equal to a.shape.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns

A random array with the specified dtype and with shape given by shape if shape is not None, or else by a.shape.

jax.random.gumbel(key, shape=(), dtype=<class 'numpy.float64'>)

Sample Gumbel random values with given shape and float dtype.

Parameters
  • key – a PRNGKey used as the random key.

  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns

A random array with the specified shape and dtype.

jax.random.laplace(key, shape=(), dtype=<class 'numpy.float64'>)

Sample Laplace random values with given shape and float dtype.

Parameters
  • key – a PRNGKey used as the random key.

  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns

A random array with the specified shape and dtype.

jax.random.logistic(key, shape=(), dtype=<class 'numpy.float64'>)

Sample logistic random values with given shape and float dtype.

Parameters
  • key – a PRNGKey used as the random key.

  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns

A random array with the specified shape and dtype.

jax.random.maxwell(key, shape=(), dtype=<class 'numpy.float64'>)

Sample from a one-sided Maxwell distribution.

The scipy counterpart is scipy.stats.maxwell.

Parameters
  • key – a PRNGKey used as the random key.

  • shape – The shape of the returned samples.

  • dtype – The type used for samples.

Returns

A jnp.array of samples, of shape shape.

jax.random.multivariate_normal(key, mean, cov, shape=None, dtype=<class 'numpy.float64'>)

Sample multivariate normal random values with given mean and covariance.

Parameters
  • key (ndarray) – a PRNGKey used as the random key.

  • mean (ndarray) – a mean vector of shape (..., n).

  • cov (ndarray) – a positive definite covariance matrix of shape (..., n, n). The batch shape ... must be broadcast-compatible with that of mean.

  • shape (Optional[Sequence[int]]) – optional, a tuple of nonnegative integers specifying the result batch shape; that is, the prefix of the result shape excluding the last axis. Must be broadcast-compatible with mean.shape[:-1] and cov.shape[:-2]. The default (None) produces a result batch shape by broadcasting together the batch shapes of mean and cov.

  • dtype (dtype) – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Return type

ndarray

Returns

A random array with the specified dtype and shape given by shape + mean.shape[-1:] if shape is not None, or else broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:].
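The result-shape rule above can be checked with a small sketch, using a zero mean and identity covariance in three dimensions as placeholder parameters.

```python
import jax
import jax.numpy as jnp

key_single, key_batch = jax.random.split(jax.random.PRNGKey(0))

mean = jnp.zeros(3)   # shape (n,) with n = 3
cov = jnp.eye(3)      # shape (n, n), positive definite

# No batch shape given: the result has shape mean.shape[-1:], here (3,).
single = jax.random.multivariate_normal(key_single, mean, cov)

# shape=(5,) is the batch shape; the result has shape (5,) + (3,).
batch = jax.random.multivariate_normal(key_batch, mean, cov, shape=(5,))

print(single.shape, batch.shape)
```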

jax.random.normal(key, shape=(), dtype=<class 'numpy.float64'>)

Sample standard normal random values with given shape and float dtype.

Parameters
  • key (ndarray) – a PRNGKey used as the random key.

  • shape (Sequence[int]) – optional, a tuple of nonnegative integers representing the result shape. Default ().

  • dtype (dtype) – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Return type

ndarray

Returns

A random array with the specified shape and dtype.

jax.random.pareto(key, b, shape=None, dtype=<class 'numpy.float64'>)

Sample Pareto random values with given shape and float dtype.

Parameters
  • key – a PRNGKey used as the random key.

  • b – a float or array of floats broadcast-compatible with shape representing the parameter of the distribution.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with b. The default (None) produces a result shape equal to b.shape.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns

A random array with the specified dtype and with shape given by shape if shape is not None, or else by b.shape.

jax.random.permutation(key, x)

Permute elements of an array along its first axis or return a permuted range.

If x is a multi-dimensional array, it is only shuffled along its first axis.

Parameters
  • key – a PRNGKey used as the random key.

  • x – the array or integer range to be shuffled.

Returns

A shuffled version of x, or a shuffled range if x is an integer.
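A brief sketch of both input forms: an integer, which yields a permuted range, and a 2-D array, whose rows are reordered while each row's contents stay intact. The sizes are arbitrary.

```python
import jax
import jax.numpy as jnp

key_range, key_rows = jax.random.split(jax.random.PRNGKey(0))

# Integer input: a random ordering of 0, 1, 2, 3, 4.
perm = jax.random.permutation(key_range, 5)

# Array input: only the first axis is shuffled, so rows move as units.
rows = jnp.arange(6).reshape(3, 2)
shuffled_rows = jax.random.permutation(key_rows, rows)

print(perm.shape, shuffled_rows.shape)
```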

jax.random.poisson(key, lam, shape=(), dtype=<class 'numpy.int64'>)

Sample Poisson random values with given shape and integer dtype.

Parameters
  • key – a PRNGKey used as the random key.

  • lam – rate parameter (mean of the distribution), must be >= 0.

  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().

  • dtype – optional, an integer dtype for the returned values (default int64 if jax_enable_x64 is true, otherwise int32).

Returns

A random array with the specified shape and dtype.

jax.random.rademacher(key, shape, dtype=<class 'numpy.int64'>)

Sample from a Rademacher distribution.

Parameters
  • key – a PRNGKey used as the random key.

  • shape – The shape of the returned samples.

  • dtype – The type used for samples.

Returns

A jnp.array of samples, of shape shape. Each element in the output has a 50% chance of being 1 or -1.

jax.random.randint(key, shape, minval, maxval, dtype=<class 'numpy.int64'>)

Sample uniform random values in [minval, maxval) with given shape/dtype.

Parameters
  • key (ndarray) – a PRNGKey used as the random key.

  • shape (Sequence[int]) – a tuple of nonnegative integers representing the shape.

  • minval (Union[int, ndarray]) – int or array of ints broadcast-compatible with shape, a minimum (inclusive) value for the range.

  • maxval (Union[int, ndarray]) – int or array of ints broadcast-compatible with shape, a maximum (exclusive) value for the range.

  • dtype (dtype) – optional, an int dtype for the returned values (default int64 if jax_enable_x64 is true, otherwise int32).

Returns

A random array with the specified shape and dtype.
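The half-open range [minval, maxval) is easy to get wrong, so here is a minimal sketch that simulates rolls of a six-sided die. The sample count of 1000 is arbitrary.

```python
import jax

key = jax.random.PRNGKey(0)

# minval=1 is inclusive and maxval=7 is exclusive, so values lie in 1..6.
rolls = jax.random.randint(key, (1000,), 1, 7)

print(rolls.shape, int(rolls.min()), int(rolls.max()))
```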

jax.random.rolled_loop_step(i, state)

jax.random.rotate_left(x, d)

jax.random.rotate_list(xs)

jax.random.shuffle(key, x, axis=0)

Shuffle the elements of an array uniformly at random along an axis.

Parameters
  • key (ndarray) – a PRNGKey used as the random key.

  • x (ndarray) – the array to be shuffled.

  • axis (int) – optional, an int axis along which to shuffle (default 0).

Return type

ndarray

Returns

A shuffled version of x.

jax.random.split(key, num=2)

Splits a PRNG key into num new keys by adding a leading axis.

Parameters
  • key (ndarray) – a PRNGKey (an array with shape (2,) and dtype uint32).

  • num (int) – optional, a positive integer indicating the number of keys to produce (default 2).

Return type

ndarray

Returns

An array with shape (num, 2) and dtype uint32 representing num new keys.
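The leading axis added by split makes it straightforward to hand one key to each of several parallel computations. The following sketch assumes four hypothetical workers and uses jax.vmap to map a sampler over the stacked keys.

```python
import jax

key = jax.random.PRNGKey(0)

# Typical pattern: keep one key for later splits, consume the other now.
key, subkey = jax.random.split(key)
noise = jax.random.normal(subkey, (3,))

# num=4 returns four keys stacked along a new leading axis: shape (4, 2).
worker_keys = jax.random.split(key, num=4)

# Each (hypothetical) worker draws its own independent samples from its own key.
per_worker = jax.vmap(lambda k: jax.random.normal(k, (3,)))(worker_keys)

print(worker_keys.shape, per_worker.shape)
```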

jax.random.t(key, df, shape=(), dtype=<class 'numpy.float64'>)

Sample Student’s t random values with given shape and float dtype.

Parameters
  • key – a PRNGKey used as the random key.

  • df – a float or array of floats broadcast-compatible with shape representing the parameter of the distribution.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with df. The default (None) produces a result shape equal to df.shape.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns

A random array with the specified dtype and with shape given by shape if shape is not None, or else by df.shape.

jax.random.threefry_2x32(keypair, count)

Apply the Threefry 2x32 hash.

Parameters
  • keypair – a pair of 32-bit unsigned integers used for the key.

  • count – an array of dtype uint32 used for the counts.

Returns

An array of dtype uint32 with the same shape as count.

jax.random.truncated_normal(key, lower, upper, shape=None, dtype=<class 'numpy.float64'>)

Sample truncated standard normal random values with given shape and dtype.

Parameters
  • key (ndarray) – a PRNGKey used as the random key.

  • lower (Union[float, ndarray]) – a float or array of floats representing the lower bound for truncation. Must be broadcast-compatible with upper.

  • upper (Union[float, ndarray]) – a float or array of floats representing the upper bound for truncation. Must be broadcast-compatible with lower.

  • shape (Optional[Sequence[int]]) – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with lower and upper. The default (None) produces a result shape by broadcasting lower and upper.

  • dtype (dtype) – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Return type

ndarray

Returns

A random array with the specified dtype and shape given by shape if shape is not None, or else by broadcasting lower and upper. Returns values in the open interval (lower, upper).
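A minimal sketch of truncation, assuming bounds of -2 and 2 and an arbitrary sample count: every returned value stays between lower and upper.

```python
import jax

key = jax.random.PRNGKey(0)

# Standard normal samples restricted to the interval between -2 and 2.
samples = jax.random.truncated_normal(key, -2.0, 2.0, shape=(1000,))

print(samples.shape, float(samples.min()), float(samples.max()))
```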

jax.random.uniform(key, shape=(), dtype=<class 'numpy.float64'>, minval=0.0, maxval=1.0)

Sample uniform random values in [minval, maxval) with given shape/dtype.

Parameters
  • key (ndarray) – a PRNGKey used as the random key.

  • shape (Sequence[int]) – optional, a tuple of nonnegative integers representing the result shape. Default ().

  • dtype (dtype) – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • minval (Union[float, ndarray]) – optional, a minimum (inclusive) value broadcast-compatible with shape for the range (default 0).

  • maxval (Union[float, ndarray]) – optional, a maximum (exclusive) value broadcast-compatible with shape for the range (default 1).

Return type

ndarray

Returns

A random array with the specified shape and dtype.
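As a brief sketch of the minval and maxval parameters, the example below draws values from the range starting at -1 and ending before 1. The bounds and the sample count are arbitrary.

```python
import jax

key = jax.random.PRNGKey(0)

# Rescale the default unit interval to start at -1 and end before 1.
values = jax.random.uniform(key, (1000,), minval=-1.0, maxval=1.0)

print(values.shape, float(values.min()), float(values.max()))
```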

jax.random.weibull_min(key, scale, concentration, shape=(), dtype=<class 'numpy.float64'>)

Sample from a Weibull distribution.

The scipy counterpart is scipy.stats.weibull_min.

Parameters
  • key – a PRNGKey used as the random key.

  • scale – The scale parameter of the distribution.

  • concentration – The concentration parameter of the distribution.

  • shape – The shape added to the broadcast shape of the parameters scale and concentration.

  • dtype – The type used for samples.

Returns

A jnp.array of samples.