# jax.random.multivariate_normal

jax.random.multivariate_normal(key, mean, cov, shape=None, dtype=None, method='cholesky')

Sample multivariate normal random values with given mean and covariance.

The values are returned according to the probability density function:

$f(x;\mu, \Sigma) = (2\pi)^{-k/2} \det(\Sigma)^{-1/2} e^{-\frac{1}{2}(x - \mu)^T \Sigma^{-1} (x - \mu)}$

where $k$ is the dimension, $\mu$ is the mean (given by mean) and $\Sigma$ is the covariance matrix (given by cov).
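As a numerical sanity check of the density above, the following minimal sketch evaluates the formula term by term and compares it with `jax.scipy.stats.multivariate_normal.pdf`. The helper `mvn_density` is hypothetical (not a JAX API), and the values of `mu`, `sigma`, and `x` are arbitrary assumptions chosen for illustration.

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal as mvn


def mvn_density(x, mu, sigma):
    """Hypothetical helper: evaluate f(x; mu, Sigma) directly from the formula."""
    k = mu.shape[-1]
    diff = x - mu
    # Quadratic form (x - mu)^T Sigma^{-1} (x - mu), using a linear solve
    # instead of forming Sigma^{-1} explicitly.
    quad = diff @ jnp.linalg.solve(sigma, diff)
    # Normalizing constant (2 pi)^{-k/2} det(Sigma)^{-1/2}.
    norm = (2 * jnp.pi) ** (-k / 2) * jnp.linalg.det(sigma) ** (-0.5)
    return norm * jnp.exp(-0.5 * quad)


# Arbitrary example parameters (assumptions, not from the API docs).
mu = jnp.array([0.0, 1.0])
sigma = jnp.array([[2.0, 0.3], [0.3, 1.0]])
x = jnp.array([0.5, 0.5])

manual = mvn_density(x, mu, sigma)
reference = mvn.pdf(x, mu, sigma)
print(manual, reference)
```

The two printed values should agree up to floating-point rounding, which confirms the $\det(\Sigma)^{-1/2}$ normalization.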

Parameters:
• key (KeyArrayLike) – a PRNG key used as the random key.

• mean (RealArray) – a mean vector of shape (..., n).

• cov (RealArray) – a positive definite covariance matrix of shape (..., n, n). The batch shape ... must be broadcast-compatible with that of mean.

• shape (Shape | None) – optional, a tuple of nonnegative integers specifying the result batch shape; that is, the prefix of the result shape excluding the last axis. Must be broadcast-compatible with mean.shape[:-1] and cov.shape[:-2]. The default (None) produces a result batch shape by broadcasting together the batch shapes of mean and cov.

• dtype (DTypeLikeFloat | None) – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

• method (str) – optional, a method to compute the factor of cov. Must be one of 'svd', 'eigh', or 'cholesky'. Default 'cholesky'. For singular covariance matrices, use 'svd' or 'eigh'.
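A minimal usage sketch of these parameters, assuming a recent JAX installation; the seed, mean, and covariance values are arbitrary. It shows shape supplying the batch prefix (the event dimension n is appended), and method='svd' handling a rank-deficient covariance for which the default Cholesky factorization is not suitable.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
key_a, key_b = jax.random.split(key)

# Arbitrary example parameters with event dimension n = 3.
mean = jnp.array([0.0, 1.0, -1.0])
cov = jnp.array([[1.0, 0.5, 0.0],
                 [0.5, 2.0, 0.3],
                 [0.0, 0.3, 1.5]])

# shape=(4, 5) is the batch prefix; mean.shape[-1:] == (3,) is appended.
samples = jax.random.multivariate_normal(key_a, mean, cov, shape=(4, 5))
print(samples.shape)  # (4, 5, 3)

# Singular (rank-1) covariance: both coordinates are perfectly correlated,
# so use an SVD-based factor instead of Cholesky.
singular_cov = jnp.array([[1.0, 1.0],
                          [1.0, 1.0]])
singular_samples = jax.random.multivariate_normal(
    key_b, jnp.zeros(2), singular_cov, shape=(1000,), method='svd')
```

With this covariance, each draw in singular_samples should have (numerically) equal coordinates, since all of the variance lies along the direction (1, 1).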

Return type:

Array

Returns:

A random array with the specified dtype and shape given by shape + mean.shape[-1:] if shape is not None, or else broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:].
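The shape rule for shape=None can be checked directly. In this sketch (the arrays are arbitrary examples), mean has batch shape (3,) and cov has batch shape (4, 1), so the batch shapes broadcast to (4, 3) and the event dimension 2 is appended.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
n = 2

# Batch of 3 mean vectors: batch shape (3,).
means = jnp.zeros((3, n))
# Batch of 4 x 1 identity covariances: batch shape (4, 1).
covs = jnp.broadcast_to(jnp.eye(n), (4, 1, n, n))

# shape=None: the result batch shape comes from broadcasting the two batch shapes.
out = jax.random.multivariate_normal(key, means, covs)
print(out.shape)  # (4, 3, 2)

# The same shape, computed from the rule stated above.
expected = jnp.broadcast_shapes(means.shape[:-1], covs.shape[:-2]) + means.shape[-1:]
```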