# Source code for jax._src.scipy.stats.multivariate_normal

# Copyright 2018 The JAX Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
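# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.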
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

import numpy as np
import scipy.stats as osp_stats

from jax import lax
from jax import numpy as jnp
from jax._src.numpy.util import _wraps, promote_dtypes_inexact
from jax._src.typing import Array, ArrayLike

@_wraps(osp_stats.multivariate_normal.logpdf, update_doc=False, lax_description="""
In the JAX version, the `allow_singular` argument is not implemented.
""")
def logpdf(x: ArrayLike, mean: ArrayLike, cov: ArrayLike, allow_singular: None = None) -> Array:
  # Log of the (multivariate) normal probability density evaluated at `x`.
  #
  # The formula used is selected from the shapes of `mean` and `cov` after
  # promotion to a common inexact dtype:
  #   * `mean` is 0-d            -> univariate normal with variance `cov`.
  #   * `mean` is >=1-d, `cov` 0-d -> isotropic covariance `cov * I`, with the
  #                                  event dimension n = mean.shape[-1].
  #   * otherwise                -> full covariance; the trailing two axes of
  #                                  `cov` must be (n, n).
  #
  # Raises:
  #   NotImplementedError: if `allow_singular` is anything other than None.
  #   ValueError: in the full-covariance case, if `cov` has fewer than two
  #     dimensions or its trailing shape is not (n, n).
  #
  # NOTE(review): the docstring is expected to be supplied by `_wraps` from the
  # SciPy function, which is why the contract is written as comments here.
  if allow_singular is not None:
    raise NotImplementedError("allow_singular argument of multivariate_normal.logpdf")

  x, mean, cov = promote_dtypes_inexact(x, mean, cov)

  # Univariate case: -(x - mean)^2 / (2 cov) - 1/2 * log(2 pi cov).
  if not mean.shape:
    return (-1/2 * jnp.square(x - mean) / cov
            - 1/2 * (jnp.log(2*np.pi) + jnp.log(cov)))

  n = mean.shape[-1]

  # Isotropic case: the quadratic form reduces to |x - mean|^2 / cov and
  # log det(cov * I) to n * log(cov).
  if not np.shape(cov):
    y = x - mean
    return (-1/2 * jnp.einsum('...i,...i->...', y, y) / cov
            - n/2 * (jnp.log(2*np.pi) + jnp.log(cov)))

  # Full-covariance case. Validate before doing any arithmetic so that a bad
  # `cov` is reported with this message rather than a downstream shape error.
  if cov.ndim < 2 or cov.shape[-2:] != (n, n):
    raise ValueError("multivariate_normal.logpdf got incompatible shapes")

  chol = lax.linalg.cholesky(cov)
  # Whitened residual: one triangular solve against the Cholesky factor per
  # (matrix, vector) pair, broadcast over leading batch axes by `vectorize`.
  # Its squared norm is then used as the quadratic form below.
  whiten = jnp.vectorize(
      partial(lax.linalg.triangular_solve, lower=True, transpose_a=True),
      signature="(n,n),(n)->(n)",
  )
  y = whiten(chol, x - mean)
  # det(cov) is the squared product of the Cholesky diagonal, so
  # 1/2 * log det(cov) equals the sum of the log-diagonal.
  half_log_det = jnp.log(chol.diagonal(axis1=-1, axis2=-2)).sum(-1)
  return (-1/2 * jnp.einsum('...i,...i->...', y, y) - n/2 * jnp.log(2*np.pi)
          - half_log_det)
@_wraps(osp_stats.multivariate_normal.pdf, update_doc=False)
def pdf(x: ArrayLike, mean: ArrayLike, cov: ArrayLike) -> Array:
  # Probability density of the (multivariate) normal distribution at `x`.
  #
  # Computed as exp(logpdf(x, mean, cov)), so the accepted shapes of `mean`
  # and `cov`, and any exception raised for them, are exactly those of
  # `logpdf`. Unlike `logpdf`, this function takes no `allow_singular`
  # parameter.
  log_density = logpdf(x, mean, cov)
  return lax.exp(log_density)