jax.random.multivariate_normal

jax.random.multivariate_normal(key, mean, cov, shape=None, dtype=None, method='cholesky')

Sample multivariate normal random values with given mean and covariance.

Samples are drawn from the distribution with probability density function:

\[f(x;\mu, \Sigma) = (2\pi)^{-k/2} \det(\Sigma)^{-1/2} e^{-\frac{1}{2}(x - \mu)^T \Sigma^{-1} (x - \mu)}\]

where \(k\) is the dimension, \(\mu\) is the mean (given by mean) and \(\Sigma\) is the covariance matrix (given by cov).
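As a quick sanity check of the definition above, the following sketch draws many samples and compares their empirical mean and covariance with mean and cov. The particular values of mean and cov, the seed, and the sample count are illustrative assumptions, not part of the API.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Illustrative parameters (assumed for this example): k = 2.
mean = jnp.array([1.0, -2.0])
cov = jnp.array([[2.0, 0.6],
                 [0.6, 1.0]])

# Draw 20,000 independent samples; the result has shape (20000, 2).
samples = jax.random.multivariate_normal(key, mean, cov, shape=(20_000,))

# Empirical moments; these should be close to mean and cov.
sample_mean = samples.mean(axis=0)
sample_cov = jnp.cov(samples, rowvar=False)
```

With this many samples, sample_mean and sample_cov should agree with mean and cov to within a few hundredths; the exact values depend on the key.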

Parameters:
  • key (KeyArrayLike) – a PRNG key used as the random key.

  • mean (RealArray) – a mean vector of shape (..., n).

  • cov (RealArray) – a positive definite covariance matrix of shape (..., n, n). A singular (positive semi-definite) matrix can be used with method 'svd' or 'eigh'. The batch shape ... must be broadcast-compatible with that of mean.

  • shape (Shape | None) – optional, a tuple of nonnegative integers specifying the result batch shape; that is, the prefix of the result shape excluding the last axis. Must be broadcast-compatible with mean.shape[:-1] and cov.shape[:-2]. The default (None) produces a result batch shape by broadcasting together the batch shapes of mean and cov.

  • dtype (DTypeLikeFloat | None) – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • method (str) – optional, the decomposition used to compute the factor of cov. Must be one of 'svd', 'eigh', or 'cholesky'. Default 'cholesky'. For singular covariance matrices, use 'svd' or 'eigh'.
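The note on singular covariance matrices can be illustrated with a rank-one cov. This is a minimal sketch: the matrix, seed, and sample count are made up for illustration. Because the two coordinates are perfectly correlated, every sample should satisfy x[0] ≈ x[1] up to floating-point error.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(1)
mean = jnp.zeros(2)

# Rank-one (singular) covariance: both coordinates are perfectly correlated.
cov = jnp.array([[1.0, 1.0],
                 [1.0, 1.0]])

# Request the SVD-based factorization, as suggested for singular matrices.
samples = jax.random.multivariate_normal(
    key, mean, cov, shape=(1000,), method='svd')

# Largest gap between the two coordinates over all samples; expected to be tiny.
max_gap = jnp.max(jnp.abs(samples[:, 0] - samples[:, 1]))
```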

Return type:

Array

Returns:

A random array with the specified dtype and shape given by shape + mean.shape[-1:] if shape is not None, or else broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:].
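The shape rule above can be checked directly. In this sketch the batch sizes (3 and 5) and the dimension n = 2 are arbitrary choices for illustration; jnp.broadcast_shapes stands in for the broadcast_shapes mentioned in the description.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(2)
mean = jnp.zeros((3, 2))   # batch shape (3,), n = 2
cov = jnp.eye(2)           # batch shape (),   matrix shape (n, n)

# shape=None: the batch shape is broadcast from those of mean and cov.
default_batch = jnp.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
x_default = jax.random.multivariate_normal(key, mean, cov)

# Explicit shape: must be broadcast-compatible with the batch shapes above.
x_explicit = jax.random.multivariate_normal(
    key, mean, cov, shape=(5, 3), dtype=jnp.float32)
```

Here x_default.shape equals default_batch + mean.shape[-1:], and x_explicit.shape equals (5, 3) + mean.shape[-1:].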