jax.random package
JAX pseudo-random number generators (PRNGs).
The JAX PRNG system is based on "Parallel random numbers: as easy as 1, 2, 3" (Salmon et al. 2011). For details on the design and its motivation, see:
https://github.com/google/jax/blob/master/design_notes/prng.md

jax.random.PRNGKey(seed)
Create a pseudo-random number generator (PRNG) key given an integer seed.
Parameters: seed – a 64- or 32-bit integer used as the value of the key.
Returns: A PRNG key, which is modeled as an array of shape (2,) and dtype uint32. The key is constructed from a 64-bit seed by effectively bit-casting to a pair of uint32 values (or from a 32-bit seed by first padding out with zeros).
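A minimal sketch of creating a key, assuming a JAX release in which jax.random.PRNGKey still returns a raw uint32 key array as described above. Because 42 fits in 32 bits, the high word of the key is the zero padding.

```python
import jax
import jax.numpy as jnp

# Build a key from an integer seed: a uint32 array of shape (2,).
key = jax.random.PRNGKey(42)
print(key.shape, key.dtype)  # (2,) uint32

# The same seed always yields the same key, so results are reproducible.
same_key = jax.random.PRNGKey(42)
assert bool(jnp.all(key == same_key))
```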

jax.random.bernoulli(key, p=0.5, shape=None)
Sample Bernoulli random values with given shape and mean.
Parameters:
- key – a PRNGKey used as the random key.
- p – optional, a float or array of floats for the mean of the random variables. Must be broadcast-compatible with shape. Default 0.5.
- shape – optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with p.shape. The default (None) produces a result shape equal to p.shape.
Returns: A random array with boolean dtype and shape given by shape if shape is not None, or else p.shape.
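The sketch below illustrates how p broadcasts against shape: one probability per column, 1000 rows. The particular probabilities are arbitrary example values, and the exact samples depend on the key.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# p has shape (3,), which broadcasts against the requested shape (1000, 3).
p = jnp.array([0.1, 0.5, 0.9])
samples = jax.random.bernoulli(key, p, shape=(1000, 3))
print(samples.shape, samples.dtype)  # (1000, 3) bool

# Column means approximate p; exact values depend on the key.
print(samples.mean(axis=0))
```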

jax.random.beta(key, a, b, shape=None, dtype=numpy.float64)
Sample Beta random values with given shape and float dtype.
Parameters:
- key – a PRNGKey used as the random key.
- a – a float or array of floats broadcast-compatible with shape representing the first parameter "alpha".
- b – a float or array of floats broadcast-compatible with shape representing the second parameter "beta".
- shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with a and b. The default (None) produces a result shape by broadcasting a and b.
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified dtype and shape given by shape if shape is not None, or else by broadcasting a and b.
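A minimal sketch of the shape=None default: the result shape comes from broadcasting a and b. The parameter values are arbitrary examples.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

a = jnp.array([2.0, 5.0])      # shape (2,)
b = jnp.array([[1.0], [3.0]])  # shape (2, 1); broadcasts with a to (2, 2)
x = jax.random.beta(key, a, b)
print(x.shape)  # (2, 2)
```

Every sample lies in the unit interval, as expected for a Beta distribution.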

jax.random.categorical(key, logits, axis=-1, shape=None)
Sample random values from categorical distributions.
Parameters:
- key – a PRNGKey used as the random key.
- logits – unnormalized log probabilities of the categorical distribution(s) to sample from, so that softmax(logits, axis) gives the corresponding probabilities.
- axis – axis along which logits belong to the same categorical distribution.
- shape – optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with onp.delete(logits.shape, axis). The default (None) produces a result shape equal to onp.delete(logits.shape, axis).
Returns: A random array with int dtype and shape given by shape if shape is not None, or else onp.delete(logits.shape, axis).
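A sketch with three distributions over four classes, where axis=-1 indexes the classes. The last row puts all mass on class 3 (its other logits are -inf), so that draw is deterministic; the other rows depend on the key.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

probs = jnp.array([[0.1, 0.2, 0.3, 0.4],
                   [0.25, 0.25, 0.25, 0.25],
                   [0.0, 0.0, 0.0, 1.0]])
logits = jnp.log(probs)  # log(0) = -inf, i.e. impossible classes

idx = jax.random.categorical(key, logits, axis=-1)
print(idx.shape)  # (3,): logits.shape with the class axis removed

# An explicit shape must be broadcast-compatible with that batch shape (3,).
many = jax.random.categorical(key, logits, shape=(5, 3))
```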

jax.random.cauchy(key, shape=(), dtype=numpy.float64)
Sample Cauchy random values with given shape and float dtype.
Parameters:
- key – a PRNGKey used as the random key.
- shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified shape and dtype.

jax.random.dirichlet(key, alpha, shape=None, dtype=numpy.float64)
Sample Dirichlet random values with given shape and float dtype.
Parameters:
- key – a PRNGKey used as the random key.
- alpha – an array of shape (..., n) used as the concentration parameter of the random variables.
- shape – optional, a tuple of nonnegative integers specifying the result batch shape; that is, the prefix of the result shape excluding the last axis (of size n). Must be broadcast-compatible with alpha.shape[:-1]. The default (None) produces a result shape equal to alpha.shape.
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified dtype and shape given by shape + (alpha.shape[-1],) if shape is not None, or else alpha.shape.
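A sketch showing that shape is only the batch shape: the trailing axis of size n = alpha.shape[-1] is appended. The concentration values are arbitrary examples.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
alpha = jnp.array([1.0, 2.0, 3.0])  # n = 3 categories

x = jax.random.dirichlet(key, alpha, shape=(4,))
print(x.shape)         # (4, 3)
print(x.sum(axis=-1))  # each row sums to 1 up to float rounding
```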

jax.random.exponential(key, shape=(), dtype=numpy.float64)
Sample Exponential random values with given shape and float dtype.
Parameters:
- key – a PRNGKey used as the random key.
- shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified shape and dtype.
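The function takes no rate parameter. The sketch below assumes it draws from the standard (rate 1) exponential and rescales to a different rate; the rate of 2.0 is only an example value. If X ~ Exp(1), then X / rate ~ Exp(rate), with mean 1 / rate.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

rate = 2.0  # example rate; the expected mean is 1 / rate = 0.5
x = jax.random.exponential(key, shape=(10000,)) / rate
print(x.dtype)   # float32 unless jax_enable_x64 is set
print(x.mean())  # close to 0.5
```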

jax.random.fold_in(key, data)
Folds in data to a PRNG key to form a new PRNG key.
Parameters:
- key – a PRNGKey (an array with shape (2,) and dtype uint32).
- data – a 32-bit integer representing data to be folded into the key.
Returns: A new PRNGKey that is a deterministic function of the inputs and is statistically safe for producing a stream of new pseudo-random values.
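One common use, sketched here with a hypothetical loop over step indices, is deriving a distinct key per step from a single base key instead of threading split keys through the loop.

```python
import jax
import jax.numpy as jnp

base = jax.random.PRNGKey(0)

# One key per (hypothetical) step index.
keys = [jax.random.fold_in(base, step) for step in range(3)]

# Same inputs give the same key; different data gives a different key.
assert bool(jnp.all(jax.random.fold_in(base, 1) == keys[1]))
assert not bool(jnp.all(keys[0] == keys[1]))
```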

jax.random.gamma(key, a, shape=None, dtype=numpy.float64)
Sample Gamma random values with given shape and float dtype.
Parameters:
- key – a PRNGKey used as the random key.
- a – a float or array of floats broadcast-compatible with shape representing the shape (concentration) parameter of the distribution.
- shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with a. The default (None) produces a result shape equal to a.shape.
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified dtype and with shape given by shape if shape is not None, or else by a.shape.
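There is no scale argument. This sketch assumes unit scale and multiplies by an example scale theta to obtain Gamma(a, theta), whose mean is a * theta.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

a = jnp.array([0.5, 1.0, 5.0])  # one concentration per column
theta = 0.5                     # example scale

x = jax.random.gamma(key, a, shape=(20000, 3)) * theta
print(x.mean(axis=0))  # approximately a * theta = [0.25, 0.5, 2.5]
```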

jax.random.gumbel(key, shape=(), dtype=numpy.float64)
Sample Gumbel random values with given shape and float dtype.
Parameters:
- key – a PRNGKey used as the random key.
- shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified shape and dtype.

jax.random.laplace(key, shape=(), dtype=numpy.float64)
Sample Laplace random values with given shape and float dtype.
Parameters:
- key – a PRNGKey used as the random key.
- shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified shape and dtype.

jax.random.logistic(key, shape=(), dtype=numpy.float64)
Sample logistic random values with given shape and float dtype.
Parameters:
- key – a PRNGKey used as the random key.
- shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified shape and dtype.

jax.random.multivariate_normal(key, mean, cov, shape=None, dtype=numpy.float64)
Sample multivariate normal random values with given mean and covariance.
Parameters:
- key – a PRNGKey used as the random key.
- mean – a mean vector of shape (..., n).
- cov – a positive definite covariance matrix of shape (..., n, n). The batch shape (the leading ... dimensions) must be broadcast-compatible with that of mean.
- shape – optional, a tuple of nonnegative integers specifying the result batch shape; that is, the prefix of the result shape excluding the last axis. Must be broadcast-compatible with mean.shape[:-1] and cov.shape[:-2]. The default (None) produces a result batch shape by broadcasting together the batch shapes of mean and cov.
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified dtype and shape given by shape + mean.shape[-1:] if shape is not None, or else broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:].
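A sketch that draws a batch of 2-D correlated samples and checks the empirical moments. The mean, covariance and sample count are example values; shape=(50000,) is the batch shape, and the event axis of size n = 2 is appended.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
mean = jnp.array([1.0, -1.0])
cov = jnp.array([[1.0, 0.8],
                 [0.8, 1.0]])

x = jax.random.multivariate_normal(key, mean, cov, shape=(50000,))
print(x.shape)                   # (50000, 2)
print(jnp.cov(x, rowvar=False))  # close to cov
```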

jax.random.normal(key, shape=(), dtype=numpy.float64)
Sample standard normal random values with given shape and float dtype.
Parameters:
- key – a PRNGKey used as the random key.
- shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified shape and dtype.
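Only the standard normal is sampled, so a general N(mu, sigma**2) is obtained by shifting and scaling. In the sketch, mu and sigma are example values. The key is split so that each draw consumes a fresh subkey.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)

mu, sigma = 3.0, 2.0
x = mu + sigma * jax.random.normal(k1, shape=(100000,))

# An explicit float dtype can be requested regardless of jax_enable_x64.
y = jax.random.normal(k2, shape=(4,), dtype=jnp.float32)
```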

jax.random.pareto(key, b, shape=None, dtype=numpy.float64)
Sample Pareto random values with given shape and float dtype.
Parameters:
- key – a PRNGKey used as the random key.
- b – a float or array of floats broadcast-compatible with shape representing the parameter of the distribution.
- shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with b. The default (None) produces a result shape equal to b.shape.
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified dtype and with shape given by shape if shape is not None, or else by b.shape.

jax.random.randint(key, shape, minval, maxval, dtype=numpy.int64)
Sample uniform random values in [minval, maxval) with given shape/dtype.
Parameters:
- key – a PRNGKey used as the random key.
- shape – a tuple of nonnegative integers representing the result shape.
- minval – int or array of ints broadcast-compatible with shape, a minimum (inclusive) value for the range.
- maxval – int or array of ints broadcast-compatible with shape, a maximum (exclusive) value for the range.
- dtype – optional, an int dtype for the returned values (default int64 if jax_enable_x64 is true, otherwise int32).
Returns: A random array with the specified shape and dtype.
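A sketch of the half-open range [minval, maxval), simulating die rolls with the example bounds 1 and 7. The second call broadcasts array bounds against shape. Each range there has width 1, so the result must equal minval.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)

# Ten rolls of a six-sided die: values in [1, 7), i.e. 1..6.
rolls = jax.random.randint(k1, (10,), 1, 7)
print(rolls.dtype)  # int32 unless jax_enable_x64 is set

# Per-element bounds; each range [m, m + 1) contains only m.
exact = jax.random.randint(k2, (3,), jnp.array([0, 10, 100]), jnp.array([1, 11, 101]))
print(exact)  # [  0  10 100]
```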

jax.random.shuffle(key, x, axis=0)
Shuffle the elements of an array uniformly at random along an axis.
Parameters:
- key – a PRNGKey used as the random key.
- x – the array to be shuffled.
- axis – optional, an int axis along which to shuffle (default 0).
Returns: A shuffled version of x.

jax.random.split(key, num=2)
Splits a PRNG key into num new keys by adding a leading axis.
Parameters:
- key – a PRNGKey (an array with shape (2,) and dtype uint32).
- num – optional, a positive integer indicating the number of keys to produce (default 2).
Returns: An array with shape (num, 2) and dtype uint32 representing num new keys.
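The usual pattern, sketched below, is to keep one half of a split for future splitting and consume the other. A key should not be reused once it has been passed to a sampler.

```python
import jax

key = jax.random.PRNGKey(0)

# Keep `key` for later splits and use `subkey` for sampling.
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, (3,))

# Or produce several keys at once: shape (num, 2).
keys = jax.random.split(key, num=4)
print(keys.shape)  # (4, 2)
```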

jax.random.t(key, df, shape=(), dtype=numpy.float64)
Sample Student's t random values with given shape and float dtype.
Parameters:
- key – a PRNGKey used as the random key.
- df – a float or array of floats broadcast-compatible with shape representing the parameter of the distribution.
- shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with df. The default (None) produces a result shape equal to df.shape.
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified dtype and with shape given by shape if shape is not None, or else by df.shape.

jax.random.threefry_2x32(keypair, count)
Apply the Threefry 2x32 hash.
Parameters:
- keypair – a pair of 32-bit unsigned integers used for the key.
- count – an array of dtype uint32 used for the counts.
Returns: An array of dtype uint32 with the same shape as count.

jax.random.truncated_normal(key, lower, upper, shape=None, dtype=numpy.float64)
Sample truncated standard normal random values with given shape and dtype.
Parameters:
- key – a PRNGKey used as the random key.
- lower – a float or array of floats representing the lower bound for truncation. Must be broadcast-compatible with upper.
- upper – a float or array of floats representing the upper bound for truncation. Must be broadcast-compatible with lower.
- shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with lower and upper. The default (None) produces a result shape by broadcasting lower and upper.
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified dtype and shape given by shape if shape is not None, or else by broadcasting lower and upper.
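A sketch with per-column bounds that broadcast against the requested shape. The bounds are example values, and every sample should fall inside its column's interval.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

lower = jnp.array([-1.0, 0.0])  # column 0 in [-1, 1], column 1 in [0, 2]
upper = jnp.array([1.0, 2.0])
x = jax.random.truncated_normal(key, lower, upper, shape=(1000, 2))
print(x.min(axis=0), x.max(axis=0))
```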

jax.random.uniform(key, shape=(), dtype=numpy.float64, minval=0.0, maxval=1.0)
Sample uniform random values in [minval, maxval) with given shape/dtype.
Parameters:
- key – a PRNGKey used as the random key.
- shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
- minval – optional, a minimum (inclusive) value for the range (default 0).
- maxval – optional, a maximum (exclusive) value for the range (default 1).
Returns: A random array with the specified shape and dtype.
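A minimal sketch of drawing floats from the half-open interval [-1, 1). The bounds are example values passed by keyword.

```python
import jax

key = jax.random.PRNGKey(0)

x = jax.random.uniform(key, shape=(5, 2), minval=-1.0, maxval=1.0)
print(x.shape)  # (5, 2)
```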