jax.scipy.stats.multivariate_normal.pdf

jax.scipy.stats.multivariate_normal.pdf(x, mean, cov)

Multivariate normal probability distribution function.

JAX implementation of scipy.stats.multivariate_normal pdf.

The multivariate normal PDF is defined as

\[f(x) = \frac{1}{\sqrt{(2\pi)^k\det\Sigma}}\exp\left(-\frac{(x-\mu)^T\Sigma^{-1}(x-\mu)}{2} \right)\]

where \(\mu\) is the mean, \(\Sigma\) is the covariance matrix (cov), and \(k\) is the rank of \(\Sigma\).
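As a rough sanity check, the sketch below evaluates the density by hand, using the normalizing constant \(1/\sqrt{(2\pi)^k\det\Sigma}\), and compares it with the library call. The values of `mean`, `cov`, and `x` are made up for illustration, and the covariance is assumed to be symmetric positive definite, so that \(k\) equals the dimension of `x`.

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

# Illustrative inputs (assumed values, not from the documentation).
mean = jnp.array([0.0, 0.0])
cov = jnp.array([[2.0, 0.5],
                 [0.5, 1.0]])  # symmetric positive definite
x = jnp.array([1.0, -1.0])

# Hand-written density: exp(-q / 2) / sqrt((2*pi)^k * det(cov)),
# where q = (x - mean)^T cov^{-1} (x - mean).
k = mean.shape[0]
diff = x - mean
quad = diff @ jnp.linalg.solve(cov, diff)
norm = jnp.sqrt((2 * jnp.pi) ** k * jnp.linalg.det(cov))
manual = jnp.exp(-0.5 * quad) / norm

# Library evaluation of the same density.
library = multivariate_normal.pdf(x, mean, cov)

print(manual, library)
```

The two values should agree up to floating-point error. Using `jnp.linalg.solve` rather than forming an explicit inverse is a choice made for this sketch only and says nothing about how the library computes the result internally.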

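Parameters:

x: value at which to evaluate the PDF.
mean: mean \(\mu\) of the distribution.
cov: covariance matrix \(\Sigma\) of the distribution.

Returns: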

array of pdf values.

Return type:

Array
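The following sketch illustrates the returned array for a batch of points. It assumes that `x` may carry a leading batch dimension that broadcasts against a single `mean` and `cov`, and it cross-checks that assumption against an explicit `jax.vmap` over the points. All input values are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

# Illustrative inputs (assumed values, not from the documentation).
mean = jnp.array([1.0, -1.0])
cov = jnp.array([[1.0, 0.3],
                 [0.3, 2.0]])  # symmetric positive definite
xs = jnp.array([[1.0, -1.0],
                [0.0, 0.0],
                [2.5, 1.0]])   # three points in two dimensions

# Direct call on the whole batch: expected to give one value per point.
batched = multivariate_normal.pdf(xs, mean, cov)

# Explicit per-point mapping, for comparison.
mapped = jax.vmap(lambda xi: multivariate_normal.pdf(xi, mean, cov))(xs)

print(batched.shape, mapped.shape)
```

If the broadcasting assumption does not hold in a particular setup, the `jax.vmap` form is the more explicit way to evaluate the PDF at many points.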