jax.random package
JAX pseudo-random number generators (PRNGs).
Example usage:
>>> import jax
>>> rng = jax.random.PRNGKey(seed)
>>> for i in range(num_steps):
...     rng, rng_input = jax.random.split(rng)
...     params = compiled_update(rng_input, params, next(batches))
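Here is a runnable version of the loop. This page does not define `compiled_update`, `params`, `seed`, `num_steps`, or `batches`, so the definitions below are hypothetical stand-ins. They exist only to make the key-threading pattern execute and are not part of jax.random.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the names used in the example above.
seed, num_steps = 0, 3
params = jnp.zeros(4)
batches = iter([jnp.ones(4) * i for i in range(num_steps)])

@jax.jit
def compiled_update(rng_input, params, batch):
    # Hypothetical update: a deterministic step plus noise drawn from the per-step key.
    noise = 0.01 * jax.random.normal(rng_input, params.shape)
    return params + 0.1 * batch + noise

rng = jax.random.PRNGKey(seed)
for i in range(num_steps):
    # Split off a fresh subkey for this step and keep `rng` for later splits.
    rng, rng_input = jax.random.split(rng)
    params = compiled_update(rng_input, params, next(batches))
```

Each iteration consumes `rng_input` exactly once. The carried `rng` is never passed to a sampling function directly.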
Context:
Among other requirements, the JAX PRNG aims to: (a) ensure reproducibility, (b) parallelize well, both in terms of vectorization (generating array values) and multi-replica, multi-core computation. In particular, it should not use sequencing constraints between random function calls.
The approach is based on:
1. "Parallel random numbers: as easy as 1, 2, 3" (Salmon et al. 2011)
2. "Splittable pseudorandom number generators using cryptographic hashing" (Claessen et al. 2013)
See also https://github.com/google/jax/blob/master/design_notes/prng.md for the design and its motivation.
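A minimal sketch of what "no sequencing constraints" means in practice, assuming default (non-x64) settings. Every sampler is a pure function of its key, so reusing a key reproduces the same values. Fresh values come from explicitly derived keys rather than from hidden global state.

```python
import jax

key = jax.random.PRNGKey(0)

# Same key in, same values out: there is no global state that advances.
a = jax.random.uniform(key, (3,))
b = jax.random.uniform(key, (3,))

# To get different values, derive new keys explicitly instead of reusing `key`.
k1, k2 = jax.random.split(key)
c = jax.random.uniform(k1, (3,))
d = jax.random.uniform(k2, (3,))
```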
-
jax.random.PRNGKey(seed)
Create a pseudo-random number generator (PRNG) key given an integer seed.
- Parameters
seed (int): a 64- or 32-bit integer used as the value of the key.
- Return type
ndarray
- Returns
A PRNG key, which is modeled as an array of shape (2,) and dtype uint32. The key is constructed from a 64-bit seed by effectively bit-casting to a pair of uint32 values (or from a 32-bit seed by first padding out with zeros).
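The following sketch assumes 64-bit mode is disabled (the JAX default). In that case a small seed fits in the low 32 bits, so the zero-padded high word comes first.

```python
import jax

key = jax.random.PRNGKey(42)
# A length-2 uint32 array: [high 32 bits, low 32 bits] of the seed.
print(key.shape)     # (2,)
print(key.tolist())  # [0, 42]
```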
-
jax.random.bernoulli(key, p=0.5, shape=None)
Sample Bernoulli random values with given shape and mean.
- Parameters
key (ndarray): a PRNGKey used as the random key.
p (ndarray): optional, a float or array of floats for the mean of the random variables. Must be broadcast-compatible with shape. Default 0.5.
shape (Optional[Sequence[int]]): optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with p.shape. The default (None) produces a result shape equal to p.shape.
- Return type
ndarray
- Returns
A random array with boolean dtype and shape given by shape if shape is not None, or else p.shape.
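This sketch covers both ways of setting the result shape. The first call passes an explicit shape. The second passes an array of means with shape left at its default of None, so the result takes p.shape.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# 10,000 draws, each True with probability 0.3; the empirical mean should be near 0.3.
samples = jax.random.bernoulli(key, p=0.3, shape=(10_000,))

# shape=None: one draw per entry of p, so the result has shape p.shape == (3,).
# p=0 never fires and p=1 always fires.
per_coin = jax.random.bernoulli(key, p=jnp.array([0.0, 0.5, 1.0]))
```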
-
jax.random.beta(key, a, b, shape=None, dtype=numpy.float64)
Sample Beta random values with given shape and float dtype.
- Parameters
key (ndarray): a PRNGKey used as the random key.
a (Union[float, ndarray]): a float or array of floats broadcast-compatible with shape representing the first parameter "alpha".
b (Union[float, ndarray]): a float or array of floats broadcast-compatible with shape representing the second parameter "beta".
shape (Optional[Sequence[int]]): optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with a and b. The default (None) produces a result shape by broadcasting a and b.
dtype (dtype): optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
- Return type
ndarray
- Returns
A random array with the specified dtype and shape given by shape if shape is not None, or else by broadcasting a and b.
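As a quick sanity check, Beta(a, b) has support [0, 1] and mean a / (a + b). The sketch below draws enough samples that the empirical mean lands close to 2/7 for a = 2, b = 5. The tolerance used in checks is an assumption, not a guarantee from the library.

```python
import jax

key = jax.random.PRNGKey(2)
a, b = 2.0, 5.0
x = jax.random.beta(key, a, b, shape=(20_000,))
# Expected mean: a / (a + b) = 2/7, roughly 0.286.
empirical_mean = float(x.mean())
```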
-
jax.random.categorical(key, logits, axis=-1, shape=None)
Sample random values from categorical distributions.
- Parameters
key: a PRNGKey used as the random key.
logits: Unnormalized log probabilities of the categorical distribution(s) to sample from, so that softmax(logits, axis) gives the corresponding probabilities.
axis: Axis along which logits belong to the same categorical distribution.
shape: Optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with np.delete(logits.shape, axis). The default (None) produces a result shape equal to np.delete(logits.shape, axis).
- Returns
A random array with int dtype and shape given by shape if shape is not None, or else np.delete(logits.shape, axis).
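The sketch below checks that empirical frequencies track softmax(logits) and shows how axis removes one dimension from the default result shape. The specific probabilities and tolerance are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(1)
logits = jnp.log(jnp.array([0.2, 0.3, 0.5]))  # unnormalized log-probabilities

# Many draws from one distribution; frequencies should approach softmax(logits).
draws = jax.random.categorical(key, logits, shape=(20_000,))
freqs = jnp.bincount(draws, length=3) / draws.size
probs = jax.nn.softmax(logits)

# A batch of 4 distributions over 3 classes: the default shape drops `axis`,
# giving np.delete((4, 3), -1) == (4,).
batch = jax.random.categorical(key, jnp.zeros((4, 3)))
```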
-
jax.random.cauchy(key, shape=(), dtype=numpy.float64)
Sample Cauchy random values with given shape and float dtype.
- Parameters
key: a PRNGKey used as the random key.
shape: optional, a tuple of nonnegative integers representing the result shape. Default ().
dtype: optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
- Returns
A random array with the specified shape and dtype.
-
jax.random.choice(key, a, shape=(), replace=True, p=None)
Generates a random sample from a given 1-D array.
- Parameters
key: a PRNGKey used as the random key.
a: 1-D array or int. If an ndarray, a random sample is generated from its elements. If an int, the random sample is generated as if a were arange(a).
shape: tuple of ints, optional. Output shape. If the given shape is, e.g., (m, n), then m * n samples are drawn. Default is (), in which case a single value is returned.
replace: boolean. Whether the sample is with or without replacement. Default is True.
p: 1-D array-like. The probabilities associated with each entry in a. If not given, the sample assumes a uniform distribution over all entries in a.
- Returns
An array of shape shape containing samples from a.
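This sketch shows two modes. The first samples without replacement from an integer range. The second samples with replacement under non-uniform probabilities p. The particular values of a and p are illustrative.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(3)
k1, k2 = jax.random.split(key)

# Without replacement: 5 distinct values drawn as if from arange(10).
no_repeat = jax.random.choice(k1, 10, shape=(5,), replace=False)

# With replacement, weighting the entries of `a` by `p`.
a = jnp.array([10, 20, 30])
p = jnp.array([0.1, 0.1, 0.8])
weighted = jax.random.choice(k2, a, shape=(10_000,), p=p)
```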
-
jax.random.dirichlet(key, alpha, shape=None, dtype=numpy.float64)
Sample Dirichlet random values with given shape and float dtype.
- Parameters
key: a PRNGKey used as the random key.
alpha: an array of shape (..., n) used as the concentration parameter of the random variables.
shape: optional, a tuple of nonnegative integers specifying the result batch shape; that is, the prefix of the result shape excluding the last axis, of size n. Must be broadcast-compatible with alpha.shape[:-1]. The default (None) produces a result shape equal to alpha.shape.
dtype: optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
- Returns
A random array with the specified dtype and shape given by shape + (alpha.shape[-1],) if shape is not None, or else alpha.shape.
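To make the shape rule concrete, the sketch below sets the batch shape to (5,) for n = 3 categories. The result therefore has shape (5, 3), and each row is a probability vector summing to 1 up to floating-point rounding.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(4)
alpha = jnp.array([1.0, 2.0, 3.0])  # concentration for n = 3 categories
x = jax.random.dirichlet(key, alpha, shape=(5,))
# x.shape == shape + (alpha.shape[-1],) == (5, 3); rows sum to ~1.
row_sums = x.sum(axis=-1)
```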
-
jax.random.double_sided_maxwell(key, loc, scale, shape=(), dtype=numpy.float64)
Sample from a double-sided Maxwell distribution.
- Samples using:
loc + scale * sgn(U - 0.5) * one_sided_maxwell, where U ~ Unif(0, 1).
- Parameters
key: a PRNGKey used as the random key.
loc: The location parameter of the distribution.
scale: The scale parameter of the distribution.
shape: The shape added to the broadcast shape of the parameters loc and scale.
dtype: The type used for samples.
- Returns
A jnp.array of samples.
-
jax.random.exponential(key, shape=(), dtype=numpy.float64)
Sample Exponential random values with given shape and float dtype.
- Parameters
key: a PRNGKey used as the random key.
shape: optional, a tuple of nonnegative integers representing the result shape. Default ().
dtype: optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
- Returns
A random array with the specified shape and dtype.
-
jax.random.fold_in(key, data)
Folds in data to a PRNG key to form a new PRNG key.
- Parameters
key: a PRNGKey (an array with shape (2,) and dtype uint32).
data: a 32-bit integer representing data to be folded into the key.
- Returns
A new PRNGKey that is a deterministic function of the inputs and is statistically safe for producing a stream of new pseudo-random values.
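One common use, shown here as a sketch rather than the only pattern, is deriving a key per step index from a single base key. This avoids threading split() results through a loop. The derivation is deterministic, so the same (key, data) pair always gives the same key.

```python
import jax
import jax.numpy as jnp

base = jax.random.PRNGKey(0)

# One key per step index, each a deterministic function of (base, step).
step_keys = [jax.random.fold_in(base, step) for step in range(3)]

# Recomputing with the same data reproduces the same key.
same = jax.random.fold_in(base, 1)
```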
-
jax.random.gamma(key, a, shape=None, dtype=numpy.float64)
Sample Gamma random values with given shape and float dtype.
- Parameters
key: a PRNGKey used as the random key.
a: a float or array of floats broadcast-compatible with shape representing the parameter of the distribution.
shape: optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with a. The default (None) produces a result shape equal to a.shape.
dtype: optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
- Returns
A random array with the specified dtype and with shape given by shape if shape is not None, or else by a.shape.
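The function takes only the shape parameter a, which implies unit scale and mean a. A different scale theta can be obtained by multiplying the samples, since theta * Gamma(a, 1) is distributed as Gamma(a, scale=theta). This is a sketch of that standard identity, not an extra argument of the function.

```python
import jax

key = jax.random.PRNGKey(5)
a = 2.0
x = jax.random.gamma(key, a, shape=(20_000,))  # unit scale: mean a
theta = 3.0
y = theta * x  # Gamma(a, scale=theta): mean a * theta
```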
-
jax.random.gumbel(key, shape=(), dtype=numpy.float64)
Sample Gumbel random values with given shape and float dtype.
- Parameters
key: a PRNGKey used as the random key.
shape: optional, a tuple of nonnegative integers representing the result shape. Default ().
dtype: optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
- Returns
A random array with the specified shape and dtype.
-
jax.random.laplace(key, shape=(), dtype=numpy.float64)
Sample Laplace random values with given shape and float dtype.
- Parameters
key: a PRNGKey used as the random key.
shape: optional, a tuple of nonnegative integers representing the result shape. Default ().
dtype: optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
- Returns
A random array with the specified shape and dtype.
-
jax.random.logistic(key, shape=(), dtype=numpy.float64)
Sample logistic random values with given shape and float dtype.
- Parameters
key: a PRNGKey used as the random key.
shape: optional, a tuple of nonnegative integers representing the result shape. Default ().
dtype: optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
- Returns
A random array with the specified shape and dtype.
-
jax.random.maxwell(key, shape=(), dtype=numpy.float64)
Sample from a one-sided Maxwell distribution.
The scipy counterpart is scipy.stats.maxwell.
- Parameters
key: a PRNGKey used as the random key.
shape: The shape of the returned samples.
dtype: The type used for samples.
- Returns
A jnp.array of samples, of shape shape.
-
jax.random.multivariate_normal(key, mean, cov, shape=None, dtype=numpy.float64)
Sample multivariate normal random values with given mean and covariance.
- Parameters
key (ndarray): a PRNGKey used as the random key.
mean (ndarray): a mean vector of shape (..., n).
cov (ndarray): a positive definite covariance matrix of shape (..., n, n). The batch shape ... must be broadcast-compatible with that of mean.
shape (Optional[Sequence[int]]): optional, a tuple of nonnegative integers specifying the result batch shape; that is, the prefix of the result shape excluding the last axis. Must be broadcast-compatible with mean.shape[:-1] and cov.shape[:-2]. The default (None) produces a result batch shape by broadcasting together the batch shapes of mean and cov.
dtype (dtype): optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
- Return type
ndarray
- Returns
A random array with the specified dtype and shape given by shape + mean.shape[-1:] if shape is not None, or else broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:].
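The sketch below illustrates the shape rule and checks that the empirical mean and covariance approach the inputs. The particular mean, covariance, and tolerances are illustrative assumptions. Setting the batch shape to (50_000,) for n = 2 gives a result of shape (50_000, 2).

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(6)
mean = jnp.array([1.0, -2.0])
cov = jnp.array([[2.0, 0.5],
                 [0.5, 1.0]])  # symmetric positive definite
x = jax.random.multivariate_normal(key, mean, cov, shape=(50_000,))
# x.shape == shape + mean.shape[-1:] == (50000, 2)
emp_mean = x.mean(axis=0)
emp_cov = jnp.cov(x, rowvar=False)  # columns are variables
```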
-
jax.random.normal(key, shape=(), dtype=numpy.float64)
Sample standard normal random values with given shape and float dtype.
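- Parameters
key: a PRNGKey used as the random key.
shape: optional, a tuple of nonnegative integers representing the result shape. Default ().
dtype: optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).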
- Return type
ndarray
- Returns
A random array with the specified shape and dtype.
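Only the standard normal is sampled here. A general N(mu, sigma**2) is usually obtained by shifting and scaling, as in this sketch, where mu and sigma are illustrative values.

```python
import jax

key = jax.random.PRNGKey(7)
z = jax.random.normal(key, (20_000,))  # standard normal: mean 0, std 1
mu, sigma = 5.0, 2.0
x = mu + sigma * z                     # N(mu, sigma**2)
```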
-
jax.random.pareto(key, b, shape=None, dtype=numpy.float64)
Sample Pareto random values with given shape and float dtype.
- Parameters
key: a PRNGKey used as the random key.
b: a float or array of floats broadcast-compatible with shape representing the parameter of the distribution.
shape: optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with b. The default (None) produces a result shape equal to b.shape.
dtype: optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
- Returns
A random array with the specified dtype and with shape given by shape if shape is not None, or else by b.shape.
-
jax.random.permutation(key, x)
Permute elements of an array along its first axis or return a permuted range.
If x is a multi-dimensional array, it is only shuffled along its first axis.
- Parameters
key: a PRNGKey used as the random key.
x: the array or integer range to be shuffled.
- Returns
A shuffled version of x, or, if x is an integer, a shuffled version of arange(x).
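The sketch below shows both input kinds. An integer gives a shuffled range. A 2-D array has its rows reordered while each row stays intact, consistent with shuffling only along the first axis.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(8)

# Integer input: a shuffled arange(5).
perm = jax.random.permutation(key, 5)

# Array input: rows move as whole units; elements within a row are not mixed.
x = jnp.arange(6).reshape(3, 2)
shuffled_rows = jax.random.permutation(key, x)
```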
-
jax.random.poisson(key, lam, shape=(), dtype=numpy.int64)
Sample Poisson random values with given shape and integer dtype.
- Parameters
key: a PRNGKey used as the random key.
lam: rate parameter (mean of the distribution), must be >= 0.
shape: optional, a tuple of nonnegative integers representing the result shape. Default ().
dtype: optional, an integer dtype for the returned values (default int64 if jax_enable_x64 is true, otherwise int32).
- Returns
A random array with the specified shape and dtype.
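A Poisson(lam) variable has mean and variance both equal to lam. The sketch below checks both on a large sample with lam = 4. The tolerances are assumptions sized for 20,000 draws.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(9)
lam = 4.0
counts = jax.random.poisson(key, lam, shape=(20_000,))
# Integer counts; empirical mean and variance should both be near lam.
is_integer = jnp.issubdtype(counts.dtype, jnp.integer)
```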
-
jax.random.rademacher(key, shape, dtype=numpy.int64)
Sample from a Rademacher distribution.
- Parameters
key: a PRNGKey used as the random key.
shape: The shape of the returned samples.
dtype: The type used for samples.
- Returns
A jnp.array of samples with the given shape. Each element in the output has a 50% chance of being 1 or -1.
-
jax.random.randint(key, shape, minval, maxval, dtype=numpy.int64)
Sample uniform random values in [minval, maxval) with given shape/dtype.
- Parameters
key (ndarray): a PRNGKey used as the random key.
shape (Sequence[int]): a tuple of nonnegative integers representing the shape.
minval (Union[int, ndarray]): int or array of ints broadcast-compatible with shape, a minimum (inclusive) value for the range.
maxval (Union[int, ndarray]): int or array of ints broadcast-compatible with shape, a maximum (exclusive) value for the range.
dtype (dtype): optional, an int dtype for the returned values (default int64 if jax_enable_x64 is true, otherwise int32).
- Returns
A random array with the specified shape and dtype.
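Because maxval is exclusive, the sketch below with minval=3 and maxval=7 can only produce 3, 4, 5, or 6.

```python
import jax

key = jax.random.PRNGKey(10)
# Integers in the half-open range [3, 7).
x = jax.random.randint(key, (1_000,), minval=3, maxval=7)
```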
-
jax.random.rotate_left(x, d)
-
jax.random.shuffle(key, x, axis=0)
Shuffle the elements of an array uniformly at random along an axis.
- Parameters
key (ndarray): a PRNGKey used as the random key.
x (ndarray): the array to be shuffled.
axis (int): optional, an int axis along which to shuffle (default 0).
- Return type
ndarray
- Returns
A shuffled version of x.
-
jax.random.split(key, num=2)
Splits a PRNG key into num new keys by adding a leading axis.
- Parameters
key (ndarray): a PRNGKey (an array with shape (2,) and dtype uint32).
num (int): optional, a positive integer indicating the number of keys to produce (default 2).
- Return type
ndarray
- Returns
An array with shape (num, 2) and dtype uint32 representing num new keys.
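Here the new keys are stacked along a leading axis, so unpacking that axis gives one key per consumer. The variable names in this sketch are arbitrary illustrations.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 3)  # shape (3, 2), dtype uint32

# Unpack along the leading axis, one key per independent use.
k_init, k_dropout, k_noise = keys
```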
-
jax.random.t(key, df, shape=(), dtype=numpy.float64)
Sample Student's t random values with given shape and float dtype.
- Parameters
key: a PRNGKey used as the random key.
df: a float or array of floats broadcast-compatible with shape representing the parameter of the distribution.
shape: optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with df. The default (None) produces a result shape equal to df.shape.
dtype: optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
- Returns
A random array with the specified dtype and with shape given by shape if shape is not None, or else by df.shape.
-
jax.random.threefry_2x32(keypair, count)
Apply the Threefry 2x32 hash.
- Parameters
keypair: a pair of 32-bit unsigned integers used for the key.
count: an array of dtype uint32 used for the counts.
- Returns
An array of dtype uint32 with the same shape as count.
-
jax.random.truncated_normal(key, lower, upper, shape=None, dtype=numpy.float64)
Sample truncated standard normal random values with given shape and dtype.
- Parameters
key (ndarray): a PRNGKey used as the random key.
lower (Union[float, ndarray]): a float or array of floats representing the lower bound for truncation. Must be broadcast-compatible with upper.
upper (Union[float, ndarray]): a float or array of floats representing the upper bound for truncation. Must be broadcast-compatible with lower.
shape (Optional[Sequence[int]]): optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with lower and upper. The default (None) produces a result shape by broadcasting lower and upper.
dtype (dtype): optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
- Return type
ndarray
- Returns
A random array with the specified dtype and shape given by shape if shape is not None, or else by broadcasting lower and upper.
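The sketch below shows two calls. The first passes an explicit shape. The second leaves shape as None, so the result shape comes from broadcasting an array lower bound against a scalar upper bound. In both cases every sample lies within its bounds.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(11)

# Explicit shape: 10,000 samples truncated to [-1, 2].
x = jax.random.truncated_normal(key, lower=-1.0, upper=2.0, shape=(10_000,))

# shape=None: result shape is broadcast(lower.shape, ()) == (3,).
lower = jnp.array([-1.0, 0.0, 1.0])
y = jax.random.truncated_normal(key, lower, 2.0)
```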
-
jax.random.uniform(key, shape=(), dtype=numpy.float64, minval=0.0, maxval=1.0)
Sample uniform random values in [minval, maxval) with given shape/dtype.
- Parameters
key (ndarray): a PRNGKey used as the random key.
shape (Sequence[int]): optional, a tuple of nonnegative integers representing the result shape. Default ().
dtype (dtype): optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
minval (Union[float, ndarray]): optional, a minimum (inclusive) value broadcast-compatible with shape for the range (default 0).
maxval (Union[float, ndarray]): optional, a maximum (exclusive) value broadcast-compatible with shape for the range (default 1).
- Return type
ndarray
- Returns
A random array with the specified shape and dtype.
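Passing minval and maxval rescales the output range, as in this sketch. The nominal interval is [2.0, 5.0). The check below allows the upper end inclusively to stay robust to float32 rounding, which is an assumption of this sketch.

```python
import jax

key = jax.random.PRNGKey(12)
# Floats in the nominal half-open interval [2.0, 5.0).
x = jax.random.uniform(key, (1_000,), minval=2.0, maxval=5.0)
```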
-
jax.random.weibull_min(key, scale, concentration, shape=(), dtype=numpy.float64)
Sample from a Weibull distribution.
The scipy counterpart is scipy.stats.weibull_min.
- Parameters
key: a PRNGKey used as the random key.
scale: The scale parameter of the distribution.
concentration: The concentration parameter of the distribution.
shape: The shape added to the broadcast shape of the parameters scale and concentration.
dtype: The type used for samples.
- Returns
A jnp.array of samples.