jax.scipy.linalg.cholesky

jax.scipy.linalg.cholesky(a, lower=False, overwrite_a=False, check_finite=True)

Compute the Cholesky decomposition of a matrix.

JAX implementation of scipy.linalg.cholesky().

The Cholesky decomposition of a matrix A is:

\[A = U^HU = LL^H\]

where U is an upper-triangular matrix, L is a lower-triangular matrix, and the superscript H denotes the conjugate (Hermitian) transpose.
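As a minimal sketch of the identity above for complex input, the following checks that both factors reconstruct the input and that the lower factor equals the conjugate transpose of the upper one. The matrix `a` is an arbitrary Hermitian positive-definite example chosen for illustration, and the tolerances are assumptions suited to `complex64` precision.

```python
import jax.numpy as jnp
import jax.scipy.linalg

# An arbitrary 2x2 complex Hermitian positive-definite matrix (illustrative only).
a = jnp.array([[4.0 + 0.0j, 1.0 - 1.0j],
               [1.0 + 1.0j, 3.0 + 0.0j]], dtype=jnp.complex64)

U = jax.scipy.linalg.cholesky(a)              # upper factor (default lower=False)
L = jax.scipy.linalg.cholesky(a, lower=True)  # lower factor

# A = U^H U and A = L L^H, where ^H is the conjugate transpose.
upper_ok = jnp.allclose(a, U.conj().T @ U, atol=1e-5)
lower_ok = jnp.allclose(a, L @ L.conj().T, atol=1e-5)

# The factorization is unique (positive diagonal), so L is the conjugate transpose of U.
same_factor = jnp.allclose(L, U.conj().T, atol=1e-5)
```

For a real input, the conjugate transpose reduces to the ordinary transpose, which is why the real-valued example further down reconstructs `x` with `L @ L.T`.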

Parameters:
  • a (jax.typing.ArrayLike) – input array representing a (batched) Hermitian positive-definite matrix. Must have shape (..., N, N).

  • lower (bool) – if True, compute the lower Cholesky factor L. If False (default), compute the upper Cholesky factor U.

  • overwrite_a (bool) – unused by JAX.

  • check_finite (bool) – unused by JAX.
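Because `a` may have shape (..., N, N), a stack of matrices can be factored in one call. The sketch below assumes an illustrative batch of three 2x2 symmetric positive-definite matrices, built as the identity plus a scaled positive-definite matrix; this construction is only for demonstration.

```python
import jax.numpy as jnp
import jax.scipy.linalg

# Illustrative batch: I + k * base for k = 0, 1, 2, giving shape (3, 2, 2).
# `base` has eigenvalues 1.5 and 0.5, so every matrix in the batch is positive-definite.
base = jnp.array([[1.0, 0.5],
                  [0.5, 1.0]])
scales = jnp.arange(3.0)[:, None, None]  # shape (3, 1, 1) for broadcasting
batch = jnp.eye(2) + scales * base       # shape (3, 2, 2)

# Each matrix along the leading batch dimension is factored independently.
L = jax.scipy.linalg.cholesky(batch, lower=True)

# Reconstruct all matrices at once; swapaxes transposes only the last two axes.
recon = L @ jnp.swapaxes(L, -1, -2)
```

Using `jnp.swapaxes(L, -1, -2)` rather than `L.T` matters here: `.T` would reverse all three axes of the batched array.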

Returns:

array of shape (..., N, N) representing the Cholesky factor of the input: U by default, or L if lower=True.

Return type:

Array
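A quick way to inspect the returned array is to check that it has the input's shape and dtype, is triangular, and (under the usual Cholesky convention) has a positive diagonal. The 3x3 matrix below is an arbitrary symmetric positive-definite example assumed for illustration.

```python
import jax.numpy as jnp
import jax.scipy.linalg

# Illustrative symmetric positive-definite matrix (all leading minors are positive).
x = jnp.array([[4.0, 2.0, 0.4],
               [2.0, 5.0, 1.0],
               [0.4, 1.0, 3.0]])

U = jax.scipy.linalg.cholesky(x)  # upper factor, same shape as x

is_upper = jnp.allclose(U, jnp.triu(U))        # entries below the diagonal are zero
diag_positive = jnp.all(jnp.diagonal(U) > 0)   # diagonal entries are positive
reconstructs = jnp.allclose(x, U.T @ U, atol=1e-5)  # real input: U^H is U.T
```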

Example

A small real Hermitian positive-definite matrix:

>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.array([[2., 1.],
...                [1., 2.]])

Upper Cholesky factorization:

>>> jax.scipy.linalg.cholesky(x)
Array([[1.4142135 , 0.70710677],
       [0.        , 1.2247449 ]], dtype=float32)

Lower Cholesky factorization:

>>> jax.scipy.linalg.cholesky(x, lower=True)
Array([[1.4142135 , 0.        ],
       [0.70710677, 1.2247449 ]], dtype=float32)

Reconstructing x from its factorization:

>>> L = jax.scipy.linalg.cholesky(x, lower=True)
>>> jnp.allclose(x, L @ L.T)
Array(True, dtype=bool)
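For intuition about what the factorization computes, here is a sketch of the textbook Cholesky-Banachiewicz algorithm for real symmetric positive-definite input, compared against `jax.scipy.linalg.cholesky`. The helper `naive_cholesky_lower` is hypothetical and written only for illustration; it is not how JAX computes the factorization, and its per-element updates make it far slower than the library routine.

```python
import jax.numpy as jnp
import jax.scipy.linalg


def naive_cholesky_lower(a):
    """Hypothetical helper: textbook Cholesky-Banachiewicz algorithm.

    Fills in L row by row so that a == L @ L.T for a real symmetric
    positive-definite 2D array `a`. Illustration only.
    """
    n = a.shape[-1]
    L = jnp.zeros_like(a)
    for i in range(n):
        for j in range(i + 1):
            # Dot product of the already-computed parts of rows i and j
            # (an empty slice when j == 0, which gives 0.0).
            s = jnp.dot(L[i, :j], L[j, :j])
            if i == j:
                # Diagonal entry: subtract earlier contributions, then take the square root.
                L = L.at[i, i].set(jnp.sqrt(a[i, i] - s))
            else:
                # Below-diagonal entry: divide the remainder by the diagonal entry L[j, j].
                L = L.at[i, j].set((a[i, j] - s) / L[j, j])
    return L


x = jnp.array([[2., 1.],
               [1., 2.]])
ours = naive_cholesky_lower(x)
ref = jax.scipy.linalg.cholesky(x, lower=True)
```

Both results should agree with the lower factor shown above, up to floating-point rounding.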