# jax.numpy.linalg.cholesky

jax.numpy.linalg.cholesky(a, *, upper=False)

Compute the Cholesky decomposition of a matrix.

JAX implementation of numpy.linalg.cholesky().

The Cholesky decomposition of a matrix A is:

$A = U^HU$

or

$A = LL^H$

where U is an upper-triangular matrix, L is a lower-triangular matrix, and $X^H$ denotes the Hermitian (conjugate) transpose of X.
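For complex input, the $^H$ in these formulas matters: the factors reconstruct the input only with the conjugate transpose, not the plain transpose. The sketch below checks both forms on a small complex Hermitian positive-definite matrix chosen for illustration (the matrix `a` is an assumption, not part of the original page).

```python
import jax.numpy as jnp

# Complex Hermitian positive-definite matrix: a equals its own conjugate
# transpose, and its eigenvalues (1 and 4) are both positive.
a = jnp.array([[2.0 + 0.0j, 1.0 - 1.0j],
               [1.0 + 1.0j, 3.0 + 0.0j]], dtype=jnp.complex64)

L = jnp.linalg.cholesky(a)              # lower-triangular factor (default)
U = jnp.linalg.cholesky(a, upper=True)  # upper-triangular factor

# X^H is the conjugate transpose: transpose the matrix and conjugate each entry.
L_H = L.conj().T
U_H = U.conj().T

print(jnp.allclose(a, L @ L_H, atol=1e-5))  # A = L L^H  -> True
print(jnp.allclose(a, U_H @ U, atol=1e-5))  # A = U^H U  -> True

# The plain transpose is not enough for complex input: L @ L.T != a here.
print(jnp.allclose(a, L @ L.T, atol=1e-5))  # -> False
```

For real input, $X^H$ reduces to the ordinary transpose, which is why the real-valued example further down can use `L.T`.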

Parameters:
• a (jax.typing.ArrayLike) – input array representing a (batched) Hermitian positive-definite matrix. Must have shape (..., N, N).

• upper (bool) – if True, compute the upper Cholesky decomposition U. If False (default), compute the lower Cholesky decomposition L.
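Both parameters can be seen together in a small batched sketch. It assumes a stack of three scaled copies of a 2×2 positive-definite matrix (an illustrative choice) and checks the output shapes, the triangular structure selected by `upper`, and that the upper factor is the Hermitian transpose of the lower one over the last two axes.

```python
import jax.numpy as jnp

# A batch of three symmetric positive-definite 2x2 matrices, shape (3, 2, 2).
# Scaling a positive-definite matrix by a positive constant keeps it
# positive-definite.
base = jnp.array([[2., 1.],
                  [1., 2.]])
batch = jnp.stack([base, 2 * base, 3 * base])

L = jnp.linalg.cholesky(batch)              # upper=False: lower factors L
U = jnp.linalg.cholesky(batch, upper=True)  # upper=True: upper factors U

print(L.shape, U.shape)  # (3, 2, 2) (3, 2, 2)

# Each L is lower-triangular and each U is upper-triangular.
print(jnp.allclose(L, jnp.tril(L)), jnp.allclose(U, jnp.triu(U)))  # True True

# U is the Hermitian transpose of L, taken over the last two axes.
L_H = jnp.conj(jnp.swapaxes(L, -1, -2))
print(jnp.allclose(U, L_H))  # True
```

The leading `...` batch dimensions pass through unchanged; only the trailing (N, N) axes are factorized.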

Returns:

array of shape (..., N, N) representing the Cholesky decomposition of the input. If the input is not Hermitian positive-definite, the result will contain NaN entries.
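Because a failed factorization is reported through NaN entries, callers that need to detect it can test the result for finiteness. The sketch below wraps this in a hypothetical helper, `cholesky_with_flag` (not part of JAX), that returns one success flag per matrix and stays compatible with `jax.jit`. The two test matrices are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

@jax.jit
def cholesky_with_flag(a):
    """Hypothetical helper: lower Cholesky factor plus a per-matrix success flag.

    A non-positive-definite input yields NaN entries in the factor, so the
    flag is True only when every entry of that matrix's factor is finite.
    """
    L = jnp.linalg.cholesky(a)
    ok = jnp.all(jnp.isfinite(L), axis=(-2, -1))  # one bool per matrix
    return L, ok

good = jnp.array([[2., 1.], [1., 2.]])  # eigenvalues 1 and 3: positive-definite
bad = jnp.array([[1., 2.], [2., 1.]])   # eigenvalues -1 and 3: indefinite

_, ok = cholesky_with_flag(jnp.stack([good, bad]))
print(ok)  # per-matrix flags: True for `good`, False for `bad`
```

Checking finiteness with array operations, rather than branching in Python, keeps the helper traceable under `jax.jit` and `jax.vmap`.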

Return type:

Array

Example

A small real Hermitian positive-definite matrix:

>>> import jax.numpy as jnp
>>> x = jnp.array([[2., 1.],
...                [1., 2.]])


Lower Cholesky factorization:

>>> jnp.linalg.cholesky(x)
Array([[1.4142135 , 0.        ],
       [0.70710677, 1.2247449 ]], dtype=float32)
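
The entries of this lower factor can be checked by hand. Writing a general 2×2 lower-triangular $L$ and matching the entries of $LL^H$ (here $LL^T$, since x is real) against x, taking the positive square root at each step:

$$L = \begin{pmatrix} \ell_{11} & 0 \\ \ell_{21} & \ell_{22} \end{pmatrix}, \qquad LL^T = \begin{pmatrix} \ell_{11}^2 & \ell_{11}\ell_{21} \\ \ell_{11}\ell_{21} & \ell_{21}^2 + \ell_{22}^2 \end{pmatrix} = \begin{pmatrix} 2 & 1 \\ 1 & 2 \end{pmatrix}$$

$$\ell_{11} = \sqrt{2} \approx 1.41421, \qquad \ell_{21} = \frac{1}{\ell_{11}} = \frac{1}{\sqrt{2}} \approx 0.70711, \qquad \ell_{22} = \sqrt{2 - \ell_{21}^2} = \sqrt{3/2} \approx 1.22474$$

These values agree with the float32 output above to the displayed precision.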


Upper Cholesky factorization:

>>> jnp.linalg.cholesky(x, upper=True)
Array([[1.4142135 , 0.70710677],
       [0.        , 1.2247449 ]], dtype=float32)


Reconstructing x from its factorization:

>>> L = jnp.linalg.cholesky(x)
>>> jnp.allclose(x, L @ L.T)
Array(True, dtype=bool)
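A common use of the factor is solving a linear system with two triangular solves instead of a general solve. The sketch below assumes an illustrative right-hand side `b` and uses `jax.scipy.linalg.solve_triangular`; it is one way to use the factorization, not something `jnp.linalg.cholesky` does itself.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

x = jnp.array([[2., 1.],
               [1., 2.]])
b = jnp.array([1., 0.])  # illustrative right-hand side

L = jnp.linalg.cholesky(x)

# Solve x @ sol = b with x = L L^T: first L y = b, then L^T sol = y.
y = solve_triangular(L, b, lower=True)
sol = solve_triangular(L.T, y, lower=False)

print(jnp.allclose(x @ sol, b, atol=1e-6))                   # True
print(jnp.allclose(sol, jnp.linalg.solve(x, b), atol=1e-6))  # True
```

`jax.scipy.linalg.cho_solve((L, True), b)` packages the same two triangular solves; computing the factor once pays off when the same matrix is solved against many right-hand sides.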