jax.random.multivariate_normal
- jax.random.multivariate_normal(key, mean, cov, shape=None, dtype=None, method='cholesky')
Sample multivariate normal random values with given mean and covariance.
The values are drawn according to the probability density function:
\[f(x;\mu, \Sigma) = (2\pi)^{-k/2} \det(\Sigma)^{-1/2} e^{-\frac{1}{2}(x - \mu)^T \Sigma^{-1} (x - \mu)}\]
where \(k\) is the dimension, \(\mu\) is the mean (given by mean) and \(\Sigma\) is the covariance matrix (given by cov).
- Parameters:
key (KeyArrayLike) – a PRNG key used as the random key.
mean (RealArray) – a mean vector of shape (..., n).
cov (RealArray) – a positive definite covariance matrix of shape (..., n, n). The batch shape ... must be broadcast-compatible with that of mean.
shape (Shape | None) – optional, a tuple of nonnegative integers specifying the result batch shape; that is, the prefix of the result shape excluding the last axis. Must be broadcast-compatible with mean.shape[:-1] and cov.shape[:-2]. The default (None) produces a result batch shape by broadcasting together the batch shapes of mean and cov.
dtype (DTypeLikeFloat | None) – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
method (str) – optional, a method to compute the factor of cov. Must be one of ‘svd’, ‘eigh’, or ‘cholesky’. Default ‘cholesky’. For singular covariance matrices, use ‘svd’ or ‘eigh’.
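A minimal usage sketch of the parameters described above. The mean, covariance, seed, and variable names are made-up values chosen for illustration; only the call signature comes from this page. It shows the default batch shape, an explicit shape, a batched mean broadcast against a single covariance, and a rank-deficient covariance sampled with method='svd'.

```python
import jax
import jax.numpy as jnp

# Illustrative inputs (assumed values, not from the documentation).
key = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(key, 4)

mean = jnp.array([1.0, -2.0])          # shape (n,) with n = 2
cov = jnp.array([[2.0, 0.5],
                 [0.5, 1.0]])          # shape (n, n), positive definite

# shape=None: the batch shape is the broadcast of mean.shape[:-1] and
# cov.shape[:-2], which is () here, so one vector of length n is returned.
x = jax.random.multivariate_normal(k1, mean, cov)
print(x.shape)

# Explicit batch shape: the result has shape `shape + mean.shape[-1:]`.
samples = jax.random.multivariate_normal(k2, mean, cov, shape=(1000,))
print(samples.shape)

# Batched means of shape (3, n) broadcast against one (n, n) covariance.
means = jnp.zeros((3, 2))
batched = jax.random.multivariate_normal(k3, means, cov)
print(batched.shape)

# Rank-1 (singular) covariance: the page recommends 'svd' or 'eigh'
# instead of the default 'cholesky' for this case.
singular_cov = jnp.array([[1.0, 1.0],
                          [1.0, 1.0]])
degenerate = jax.random.multivariate_normal(
    k4, jnp.zeros(2), singular_cov, shape=(5,), method='svd'
)
print(degenerate.shape)
```

Splitting the key before each call is a general JAX convention for obtaining independent draws rather than something this function requires.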
- Returns:
A random array with the specified dtype and shape given by shape + mean.shape[-1:] if shape is not None, or else broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:].
- Return type: