# jax.scipy.linalg.cholesky

jax.scipy.linalg.cholesky(a, lower=False, overwrite_a=False, check_finite=True)

Compute the Cholesky decomposition of a matrix.

JAX implementation of scipy.linalg.cholesky().

The Cholesky decomposition of a Hermitian positive-definite matrix A is:

$A = U^HU = LL^H$

where U is an upper-triangular matrix, L is a lower-triangular matrix, and $^H$ denotes the conjugate (Hermitian) transpose. The two factors are related by $L = U^H$.
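As a quick numerical check of this identity, the sketch below builds a complex Hermitian positive-definite matrix and confirms that both factors reproduce it and that $L = U^H$. The construction $A = BB^H + 3I$ from a random matrix B is only an illustrative assumption for obtaining a positive-definite input; it is not part of the API.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cholesky

# Illustrative input (assumption): A = B B^H + 3 I is Hermitian positive-definite.
key_re, key_im = jax.random.split(jax.random.PRNGKey(0))
B = (jax.random.normal(key_re, (3, 3))
     + 1j * jax.random.normal(key_im, (3, 3)))
A = B @ B.conj().T + 3 * jnp.eye(3)

U = cholesky(A)              # upper factor (default, lower=False)
L = cholesky(A, lower=True)  # lower factor

# A = U^H U = L L^H, and the two factors are conjugate transposes of each other.
upper_ok = jnp.allclose(U.conj().T @ U, A, atol=1e-4)
lower_ok = jnp.allclose(L @ L.conj().T, A, atol=1e-4)
same_factor = jnp.allclose(L, U.conj().T, atol=1e-5)
print(upper_ok, lower_ok, same_factor)
```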

Parameters:
• a (jax.typing.ArrayLike) – input array, representing a (batched) positive-definite Hermitian matrix. Must have shape (..., N, N).

• lower (bool) – if True, compute the lower Cholesky decomposition L. If False (default), compute the upper Cholesky decomposition U.

• overwrite_a (bool) – unused by JAX

• check_finite (bool) – unused by JAX
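The (..., N, N) shape means a stack of matrices is factored in one call, one factorization per leading index. The sketch below assumes a small hand-picked batch of symmetric positive-definite matrices. It also passes overwrite_a and check_finite explicitly only to show that they are accepted; per the parameter list above, JAX ignores them.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cholesky

# A batch of three 2x2 symmetric positive-definite matrices: shape (3, 2, 2).
batch = jnp.stack([
    jnp.array([[2., 1.], [1., 2.]]),
    jnp.array([[4., 2.], [2., 3.]]),
    jnp.array([[1., 0.], [0., 9.]]),
])

# Each matrix along the leading axis is factored independently.
# overwrite_a and check_finite are accepted but have no effect in JAX.
L = cholesky(batch, lower=True, overwrite_a=True, check_finite=False)
print(L.shape)  # (3, 2, 2)

# Reconstruct every matrix: L @ L^T, transposing only the last two axes.
recon = L @ jnp.swapaxes(L, -1, -2)
print(jnp.allclose(recon, batch, atol=1e-5))
```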

Returns:

Array of shape (..., N, N) containing the Cholesky factor of the input: the upper-triangular U by default, or the lower-triangular L if lower=True.

Return type:

Array
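The returned factor has the same shape as the input. Entries outside its triangle are zero, and for a positive-definite input its diagonal is positive. The sketch below checks these properties on an illustrative 3×3 symmetric positive-definite matrix. The specific matrix is an assumption chosen only for the demonstration.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cholesky

# Illustrative symmetric positive-definite input (leading minors 4, 16, 44 > 0).
a = jnp.array([[4., 2., 0.],
               [2., 5., 1.],
               [0., 1., 3.]])

U = cholesky(a)               # upper factor, same shape as a
L = cholesky(a, lower=True)   # lower factor

# Entries outside the triangle are zero.
print(jnp.allclose(U, jnp.triu(U)), jnp.allclose(L, jnp.tril(L)))
# The diagonal of the factor is positive.
print(jnp.all(jnp.diag(U) > 0))
```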

Examples

A small real Hermitian positive-definite matrix:

>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.array([[2., 1.],
...                [1., 2.]])


Upper Cholesky factorization:

>>> jax.scipy.linalg.cholesky(x)
Array([[1.4142135 , 0.70710677],
       [0.        , 1.2247449 ]], dtype=float32)
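These numbers can be checked by hand. The following worked derivation uses only the definition $A = U^HU$, which for this real x reduces to $U^TU$. Write U with unknown entries and match $U^TU$ against x term by term, taking the positive square root on the diagonal:

$$
U = \begin{pmatrix} u_{11} & u_{12} \\ 0 & u_{22} \end{pmatrix}, \qquad
U^T U = \begin{pmatrix} u_{11}^2 & u_{11} u_{12} \\ u_{11} u_{12} & u_{12}^2 + u_{22}^2 \end{pmatrix}
      = \begin{pmatrix} 2 & 1 \\ 1 & 2 \end{pmatrix}
$$

$$
u_{11} = \sqrt{2} \approx 1.4142, \qquad
u_{12} = \frac{1}{u_{11}} = \frac{1}{\sqrt{2}} \approx 0.7071, \qquad
u_{22} = \sqrt{2 - u_{12}^2} = \sqrt{3/2} \approx 1.2247
$$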


Lower Cholesky factorization:

>>> jax.scipy.linalg.cholesky(x, lower=True)
Array([[1.4142135 , 0.        ],
       [0.70710677, 1.2247449 ]], dtype=float32)


Reconstructing x from its factorization:

>>> L = jax.scipy.linalg.cholesky(x, lower=True)
>>> jnp.allclose(x, L @ L.T)
Array(True, dtype=bool)
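One common use of the factor is solving a linear system $xy = b$ with two triangular solves instead of forming an inverse. The sketch below assumes jax.scipy.linalg.solve_triangular for those two steps. The right-hand side b is an arbitrary illustrative vector.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cholesky, solve_triangular

x = jnp.array([[2., 1.],
               [1., 2.]])
b = jnp.array([1., 0.])   # illustrative right-hand side (assumption)

L = cholesky(x, lower=True)
# Forward substitution: solve L z = b for z.
z = solve_triangular(L, b, lower=True)
# Back substitution: solve L^T y = z for y (L^T is upper-triangular).
y = solve_triangular(L.T, z, lower=False)

print(y)                                  # solution of x y = b
print(jnp.allclose(x @ y, b, atol=1e-5))
```

For this x the exact solution is $y = (2/3, -1/3)$. jax.scipy.linalg.cho_factor and jax.scipy.linalg.cho_solve package the same factor-then-solve pattern.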