jax.random.multivariate_normal
- jax.random.multivariate_normal(key, mean, cov, shape=None, dtype=None, method='cholesky')
Sample multivariate normal random values with given mean and covariance.
The values are returned according to the probability density function:
\[f(x;\mu, \Sigma) = (2\pi)^{-k/2} \det(\Sigma)^{-1/2} e^{-\frac{1}{2}(x - \mu)^T \Sigma^{-1} (x - \mu)}\]
where \(k\) is the dimension (the size n of the last axis of mean), \(\mu\) is the mean (given by mean), and \(\Sigma\) is the covariance matrix (given by cov).
- Parameters
  - key (Union[Array, PRNGKeyArray]) – a PRNG key used as the random key.
  - mean (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – a mean vector of shape (..., n).
  - cov (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – a positive definite covariance matrix of shape (..., n, n). The batch shape ... must be broadcast-compatible with that of mean.
  - shape (Optional[Sequence[int]]) – optional, a tuple of nonnegative integers specifying the result batch shape; that is, the prefix of the result shape excluding the last axis. Must be broadcast-compatible with mean.shape[:-1] and cov.shape[:-2]. The default (None) produces a result batch shape by broadcasting together the batch shapes of mean and cov.
  - dtype (Union[Any, str, dtype, SupportsDType, None]) – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
  - method (str) – optional, the method used to compute the factor of cov. Must be one of ‘svd’, ‘eigh’, or ‘cholesky’. Default ‘cholesky’. For singular covariance matrices, use ‘svd’ or ‘eigh’.
- Return type
  Array
- Returns
  A random array with the specified dtype and shape given by shape + mean.shape[-1:] if shape is not None, or else broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:].
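The shape rules for shape and the return value can be illustrated with a short usage sketch. The mean and covariance values below are arbitrary illustrations, not taken from any reference; the sketch only exercises the documented behavior: an explicit shape yields shape + mean.shape[-1:], while shape=None broadcasts the batch shapes of mean and cov.

```python
import jax
import jax.numpy as jnp

# Arbitrary illustrative parameters for a 3-dimensional Gaussian (n = 3).
mean = jnp.array([0.0, 1.0, -1.0])            # shape (n,)
cov = jnp.array([[2.0, 0.5, 0.0],
                 [0.5, 1.0, 0.3],
                 [0.0, 0.3, 1.5]])             # shape (n, n), positive definite

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)          # independent keys for independent draws

# Explicit batch shape: the result shape is shape + mean.shape[-1:].
samples = jax.random.multivariate_normal(k1, mean, cov, shape=(10_000,), dtype=jnp.float32)
print(samples.shape)  # (10000, 3)

# shape=None: the batch shapes of mean (4,) and cov () broadcast to (4,).
batch_mean = jnp.stack([mean + i for i in range(4)])  # shape (4, 3)
batched = jax.random.multivariate_normal(k2, batch_mean, cov)
print(batched.shape)  # (4, 3)

# An explicit shape must be broadcast-compatible with the batch shape (4,).
grid = jax.random.multivariate_normal(k3, batch_mean, cov, shape=(5, 4))
print(grid.shape)  # (5, 4, 3)

# For a large sample, the sample mean and covariance should approach mean and cov.
emp_mean = samples.mean(axis=0)
emp_cov = jnp.cov(samples, rowvar=False)
```

Splitting the key before each call follows the usual JAX convention of never reusing a key for draws that should be independent.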
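As the method parameter notes, singular covariance matrices call for ‘svd’ or ‘eigh’ rather than the default ‘cholesky’. The sketch below uses a rank-1 covariance, chosen here purely for illustration, whose two coordinates are perfectly correlated. Every draw should therefore satisfy x[0] == x[1] up to floating-point error. It uses ‘svd’; ‘eigh’ is the other option the documentation lists.

```python
import jax
import jax.numpy as jnp

# Rank-1 (singular) covariance: both coordinates share a single source of variance,
# so each sample lies on the line x[0] == x[1].
mean = jnp.zeros(2)
cov = jnp.array([[1.0, 1.0],
                 [1.0, 1.0]])

key = jax.random.PRNGKey(42)
# 'svd' factors cov without requiring it to be strictly positive definite.
samples = jax.random.multivariate_normal(key, mean, cov, shape=(1_000,), method="svd")

# Largest deviation from the line x[0] == x[1]; should be tiny (floating-point noise).
max_gap = jnp.max(jnp.abs(samples[:, 0] - samples[:, 1]))
```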
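In log form, the multivariate normal density reads log f = -(k/2) log(2π) - (1/2) log det(Σ) - (1/2)(x - μ)^T Σ^{-1} (x - μ). The sketch below evaluates this expression term by term and compares it against jax.scipy.stats.multivariate_normal.logpdf. The helper mvn_logpdf and the numeric inputs are hypothetical, written only for this illustration.

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal as mvn

def mvn_logpdf(x, mean, cov):
    """Hypothetical helper: log of the multivariate normal density, term by term."""
    k = mean.shape[-1]
    diff = x - mean
    _, logdet = jnp.linalg.slogdet(cov)        # log det(Sigma); sign is +1 for positive definite cov
    maha = diff @ jnp.linalg.solve(cov, diff)  # (x - mu)^T Sigma^{-1} (x - mu)
    # The det(Sigma)^{-1/2} factor becomes -0.5 * logdet in log space.
    return -0.5 * (k * jnp.log(2 * jnp.pi) + logdet + maha)

mean = jnp.array([0.0, 1.0])
cov = jnp.array([[2.0, 0.3],
                 [0.3, 1.0]])
x = jnp.array([0.5, 0.2])

ours = mvn_logpdf(x, mean, cov)
reference = mvn.logpdf(x, mean, cov)
```

Working in log space with slogdet and solve avoids forming det(Σ) and Σ^{-1} explicitly, which is the numerically safer choice.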