jax.scipy.stats.multivariate_normal.logpdf

jax.scipy.stats.multivariate_normal.logpdf(x, mean, cov, allow_singular=None)

Multivariate normal log probability density function.

JAX implementation of scipy.stats.multivariate_normal logpdf.
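Because the function is meant to mirror SciPy, a quick sanity check is to evaluate both on the same inputs. The sketch below assumes SciPy is installed alongside JAX; the particular point, mean and covariance are arbitrary illustrative values. JAX computes in float32 by default, so the comparison uses a loose tolerance.

```python
import numpy as np
import scipy.stats
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

# Arbitrary illustrative inputs: a 2-D Gaussian with correlated components.
x = np.array([0.5, -1.0])
mean = np.array([0.0, 0.0])
cov = np.array([[2.0, 0.3],
                [0.3, 1.0]])

# Same call signature in both libraries: logpdf(x, mean, cov).
jax_val = multivariate_normal.logpdf(jnp.asarray(x), jnp.asarray(mean), jnp.asarray(cov))
scipy_val = scipy.stats.multivariate_normal.logpdf(x, mean, cov)

# JAX defaults to float32, SciPy to float64, so expect small rounding differences.
print(float(jax_val), float(scipy_val))
```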

The multivariate normal PDF is defined as

\[f(x) = \frac{1}{\sqrt{(2\pi)^k\det\Sigma}}\exp\left(-\frac{(x-\mu)^T\Sigma^{-1}(x-\mu)}{2} \right)\]

where \(\mu\) is the mean (mean), \(\Sigma\) is the covariance matrix (cov), and \(k\) is the dimension of \(x\), which equals the rank of \(\Sigma\) since \(\Sigma\) must be nonsingular.
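Taking the logarithm of the density gives \(\log f(x) = -\tfrac{1}{2}\left(k\log 2\pi + \log\det\Sigma + (x-\mu)^T\Sigma^{-1}(x-\mu)\right)\). The sketch below evaluates that expression directly for a single point using a Cholesky factor \(\Sigma = LL^T\), so that the quadratic form becomes \(\|L^{-1}(x-\mu)\|^2\) and \(\log\det\Sigma = 2\sum_i \log L_{ii}\). The helper `mvn_logpdf_from_formula` is hypothetical, written only for illustration; it is not part of the JAX API, and it is not claimed to be how JAX implements `logpdf` internally.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cholesky, solve_triangular
from jax.scipy.stats import multivariate_normal

def mvn_logpdf_from_formula(x, mean, cov):
    """Hypothetical helper: log-density of a single point from the closed form."""
    k = mean.shape[-1]
    L = cholesky(cov, lower=True)                   # cov = L @ L.T
    z = solve_triangular(L, x - mean, lower=True)   # z = L^{-1} (x - mu)
    maha = jnp.sum(z ** 2)                          # (x - mu)^T cov^{-1} (x - mu)
    log_det = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))   # log det cov
    return -0.5 * (k * jnp.log(2 * jnp.pi) + log_det + maha)

# Illustrative inputs; cov is symmetric positive definite.
x = jnp.array([1.0, 0.0, -0.5])
mean = jnp.array([0.2, -0.1, 0.0])
cov = jnp.array([[1.5, 0.2, 0.0],
                 [0.2, 1.0, 0.1],
                 [0.0, 0.1, 0.8]])

manual = mvn_logpdf_from_formula(x, mean, cov)
builtin = multivariate_normal.logpdf(x, mean, cov)
print(manual, builtin)
```

Working through a triangular solve rather than forming \(\Sigma^{-1}\) explicitly is the usual choice for numerical stability.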

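Parameters:

x (ArrayLike): point(s) at which to evaluate the log-density. The trailing axis has length k; any leading axes are treated as batch dimensions.
mean (ArrayLike): mean \(\mu\) of the distribution.
cov (ArrayLike): covariance matrix \(\Sigma\) of the distribution; it should be positive definite.
allow_singular (None): not supported; must be left as None.
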
Returns:

array of log-density values \(\log f(x)\), one per point in x.

Return type:

Array | ndarray | bool_ | number | bool | int | float | complex
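The following sketch shows batched evaluation, `jax.jit`, and `jax.grad` together. It assumes a standard normal in k = 3 dimensions (zero mean, identity covariance), chosen because the log-density then reduces to \(-\tfrac{1}{2}(k\log 2\pi + \|x\|^2)\) and the gradient with respect to the mean reduces to \(x-\mu\), which makes both results easy to cross-check. The random inputs are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

k = 3
mean = jnp.zeros(k)
cov = jnp.eye(k)

# A batch of 5 points: the trailing axis (length k) is the event dimension,
# the leading axis is a batch dimension.
xs = jax.random.normal(jax.random.PRNGKey(0), (5, k))
logp = multivariate_normal.logpdf(xs, mean, cov)
print(logp.shape)  # (5,)

# Closed form for a standard normal, used as a cross-check.
expected = -0.5 * (k * jnp.log(2 * jnp.pi) + jnp.sum(xs ** 2, axis=-1))

# The function is traceable, so it can be compiled with jax.jit ...
logp_jit = jax.jit(multivariate_normal.logpdf)(xs, mean, cov)

# ... and differentiated: d/d(mu) log f(x) = cov^{-1} (x - mu), i.e. x - mu here.
grad_mean = jax.grad(lambda m: multivariate_normal.logpdf(xs[0], m, cov))(mean)
```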