jax.scipy.linalg.cholesky
- jax.scipy.linalg.cholesky(a, lower=False, overwrite_a=False, check_finite=True)
Compute the Cholesky decomposition of a matrix.
JAX implementation of scipy.linalg.cholesky().
The Cholesky decomposition of a Hermitian positive-definite matrix A is:
\[A = U^H U = L L^H\]
where U is an upper-triangular matrix and L is a lower-triangular matrix, related by U = L^H.
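As a quick numerical check of the identity above, the sketch below factors a small complex Hermitian matrix (chosen for illustration; its eigenvalues are 1 and 3) both ways and confirms that U is the conjugate transpose of L and that both factors reconstruct A. It assumes a standard `jax` install; the variable names are illustrative only.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cholesky

# Illustrative complex Hermitian positive-definite matrix (eigenvalues 1 and 3).
a = jnp.array([[2.0, 1.0j],
               [-1.0j, 2.0]])

L = cholesky(a, lower=True)  # lower-triangular factor
U = cholesky(a)              # upper-triangular factor (lower=False is the default)

# The two factors are conjugate transposes of each other: U = L^H.
print(jnp.allclose(U, L.conj().T))      # True

# Both reconstruct A: A = L L^H = U^H U.
print(jnp.allclose(a, L @ L.conj().T))  # True
print(jnp.allclose(a, U.conj().T @ U))  # True
```

For real input the conjugation is a no-op, so the same checks reduce to U = L^T and A = L L^T.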
- Parameters:
a (ArrayLike) – input array, representing a (batched) positive-definite Hermitian matrix. Must have shape (..., N, N).
lower (bool) – if True, compute the lower Cholesky decomposition L. If False (default), compute the upper Cholesky decomposition U.
overwrite_a (bool) – unused by JAX
check_finite (bool) – unused by JAX
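The `(..., N, N)` shape means any leading dimensions are treated as a batch. A minimal sketch, assuming random test matrices built as B B^T + 3I (a hypothetical construction chosen only to guarantee positive definiteness), factors four 3x3 matrices in one call and also passes the SciPy-compatibility arguments, which JAX ignores.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cholesky

# Illustrative batch of 4 random 3x3 symmetric positive-definite matrices:
# B @ B^T is positive semi-definite, and adding 3*I makes it strictly definite.
key = jax.random.PRNGKey(0)
b = jax.random.normal(key, (4, 3, 3))
a = b @ jnp.swapaxes(b, -1, -2) + 3.0 * jnp.eye(3)

# The leading dimension is treated as a batch dimension.
L = cholesky(a, lower=True)
print(L.shape)  # (4, 3, 3)

# Real input, so L^H = L^T: check A = L L^T for every matrix in the batch.
print(jnp.allclose(a, L @ jnp.swapaxes(L, -1, -2), atol=1e-5))  # True

# overwrite_a and check_finite are accepted for SciPy compatibility but unused.
L2 = cholesky(a, lower=True, overwrite_a=True, check_finite=False)
print(jnp.allclose(L, L2))  # True
```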
- Returns:
array of shape (..., N, N) representing the Cholesky factor of the input: the upper-triangular U by default, or the lower-triangular L if lower=True.
- Return type:
Array
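Because the returned factor is triangular, some quantities of A become cheap to compute from it. As one illustration (not something this function does for you), det(A) = det(L) det(L^H) = prod(diag(L))^2, so log det(A) = 2 * sum(log(diag(L))). The sketch below assumes the same 2x2 matrix used in the examples and cross-checks against `jnp.linalg.slogdet`.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cholesky

x = jnp.array([[2., 1.],
               [1., 2.]])

L = cholesky(x, lower=True)

# The factor is triangular: every entry above the diagonal is zero.
print(jnp.all(jnp.triu(L, 1) == 0))  # True

# log det(x) = 2 * sum(log(diag(L))); here det(x) = 2*2 - 1*1 = 3.
logdet = 2.0 * jnp.sum(jnp.log(jnp.diagonal(L)))
print(jnp.allclose(logdet, jnp.log(3.0)))  # True

# Cross-check against the general-purpose routine.
print(jnp.allclose(logdet, jnp.linalg.slogdet(x)[1]))  # True
```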
See also
jax.numpy.linalg.cholesky(): NumPy-style Cholesky API
jax.lax.linalg.cholesky(): XLA-style Cholesky API
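The three APIs differ mainly in which factor they return. A minimal sketch comparing them on the example matrix, assuming the default arguments of each: `jax.numpy.linalg.cholesky` and `jax.lax.linalg.cholesky` both return the lower factor, so they should agree with `jax.scipy.linalg.cholesky(x, lower=True)`.

```python
import jax.lax.linalg
import jax.numpy as jnp
import jax.scipy.linalg

x = jnp.array([[2., 1.],
               [1., 2.]])

# SciPy-style: upper factor by default, lower factor on request.
L_scipy = jax.scipy.linalg.cholesky(x, lower=True)

# NumPy-style: returns the lower factor.
L_numpy = jnp.linalg.cholesky(x)

# XLA-style: returns the lower factor.
L_lax = jax.lax.linalg.cholesky(x)

print(jnp.allclose(L_scipy, L_numpy), jnp.allclose(L_scipy, L_lax))  # True True
```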
Examples
A small real Hermitian positive-definite matrix:
>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.array([[2., 1.],
...                [1., 2.]])
Upper Cholesky factorization:
>>> jax.scipy.linalg.cholesky(x)
Array([[1.4142135 , 0.70710677],
       [0.        , 1.2247449 ]], dtype=float32)
Lower Cholesky factorization:
>>> jax.scipy.linalg.cholesky(x, lower=True)
Array([[1.4142135 , 0.        ],
       [0.70710677, 1.2247449 ]], dtype=float32)
Reconstructing x from its factorization:
>>> L = jax.scipy.linalg.cholesky(x, lower=True)
>>> jnp.allclose(x, L @ L.T)
Array(True, dtype=bool)
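The same reconstruction works for the upper factor and inside JAX transformations. A minimal sketch, assuming you want the upper factor under `jax.jit`: the hypothetical helper `upper_factor` fixes `lower` inside a lambda so it stays a Python constant rather than a traced argument, and the check uses x = U^T U, the real-valued form of A = U^H U.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cholesky

x = jnp.array([[2., 1.],
               [1., 2.]])

# Hypothetical helper: `lower` is closed over, so jit only traces the matrix.
upper_factor = jax.jit(lambda a: cholesky(a, lower=False))

U = upper_factor(x)

# Upper factor: everything below the diagonal is zero, and x = U^T U.
print(jnp.all(jnp.tril(U, -1) == 0))  # True
print(jnp.allclose(x, U.T @ U))       # True
```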