jax.scipy.linalg.cho_factor

jax.scipy.linalg.cho_factor(a, lower=False, overwrite_a=False, check_finite=True)

Factorization for Cholesky-based linear solves

JAX implementation of scipy.linalg.cho_factor(). This function returns a result suitable for use with jax.scipy.linalg.cho_solve(). For direct Cholesky decompositions, prefer jax.scipy.linalg.cholesky().
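The factorization is the expensive step, so a common pattern is to compute it once with cho_factor() and then reuse it for several right-hand sides. The sketch below assumes an illustrative 3x3 symmetric positive-definite matrix and uses jax.vmap to map cho_solve() over the rows of a stacked right-hand-side array. The names `a`, `rhs`, and `solve_one` are hypothetical. jnp.linalg.solve() appears only as a reference to compare against.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

# Illustrative symmetric, diagonally dominant (hence positive-definite) matrix.
a = jnp.array([[4., 1., 0.],
               [1., 3., 1.],
               [0., 1., 2.]])

# Factor once; cfac is the (c, lower) tuple that cho_solve expects.
cfac = cho_factor(a)

# Three hypothetical right-hand sides, one per row.
rhs = jnp.array([[1., 0., 0.],
                 [0., 1., 0.],
                 [1., 2., 3.]])

def solve_one(b):
    # Every call reuses the same precomputed factorization.
    return cho_solve(cfac, b)

# Map the solve over the rows of rhs: solutions[k] solves a @ y = rhs[k].
solutions = jax.vmap(solve_one)(rhs)

# Reference result from a general-purpose solver, for comparison only.
reference = jnp.stack([jnp.linalg.solve(a, b) for b in rhs])
print(jnp.allclose(solutions, reference, atol=1e-5))
```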

Parameters:
  • a (ArrayLike) – input array representing a (batched) positive-definite Hermitian matrix. Must have shape (..., N, N).

  • lower (bool) – if True, compute the lower triangular Cholesky decomposition (default: False).

  • overwrite_a (bool) – unused by JAX

  • check_finite (bool) – unused by JAX
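The following sketch shows what `lower` selects. It factors the example matrix from this page both ways and reconstructs the input: with `lower=True` the factor L satisfies x = L @ L.conj().T, and with the default `lower=False` the factor U satisfies x = U.conj().T @ U. As a precaution, the sketch extracts the relevant triangle with jnp.tril() and jnp.triu() rather than relying on the contents of the opposite triangle of c.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor

x = jnp.array([[2., 1.],
               [1., 2.]])

c_lo, lower_lo = cho_factor(x, lower=True)
c_up, lower_up = cho_factor(x)  # default: lower=False

L = jnp.tril(c_lo)  # keep only the lower-triangular factor
U = jnp.triu(c_up)  # keep only the upper-triangular factor

# x is real, so the conjugate transpose reduces to the plain transpose.
print(jnp.allclose(L @ L.T, x), jnp.allclose(U.T @ U, x))
```

For a real positive-definite input the two factors are transposes of each other, so the choice of `lower` mainly determines which triangle is stored.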

Returns:

c is an array of shape (..., N, N) containing the lower or upper Cholesky factor of the input; lower is a boolean indicating which of the two factors c holds.

Return type:

(c, lower)
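The input may carry leading batch dimensions, and the returned c keeps the full (..., N, N) shape. The sketch below assumes a hypothetical batch of three 2x2 positive-definite matrices built from the example matrix on this page. It unpacks the (c, lower) tuple, checks the shape, and reconstructs each batch element from its lower factor.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor

x = jnp.array([[2., 1.],
               [1., 2.]])

# Hypothetical batch of three positive-definite matrices, shape (3, 2, 2).
batch = jnp.stack([x, 2.0 * x, x + jnp.eye(2)])

c, lower = cho_factor(batch, lower=True)
print(c.shape, lower)

# jnp.tril and matmul both act on the last two axes, so this is batched.
L = jnp.tril(c)
recon = L @ jnp.swapaxes(L, -1, -2)
print(jnp.allclose(recon, batch, atol=1e-5))
```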

Examples

A small real Hermitian positive-definite matrix:

>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.array([[2., 1.],
...                [1., 2.]])

Compute the Cholesky factorization via cho_factor(), and use it to solve a linear system via cho_solve().

>>> b = jnp.array([3., 4.])
>>> cfac = jax.scipy.linalg.cho_factor(x)
>>> y = jax.scipy.linalg.cho_solve(cfac, b)
>>> y
Array([0.6666666, 1.6666666], dtype=float32)

Check that the result satisfies x @ y = b:

>>> jnp.allclose(x @ y, b)
Array(True, dtype=bool)
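The input is described as Hermitian rather than only real symmetric. The sketch below applies the same factor-then-solve workflow to a small complex Hermitian positive-definite matrix. The matrix `z` and right-hand side `b` are chosen here for illustration; the eigenvalues of `z` are 1 and 3.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

# Hermitian: the (1, 0) entry is the complex conjugate of the (0, 1) entry.
z = jnp.array([[2.0, 1j],
               [-1j, 2.0]])
b = jnp.array([1.0 + 0j, 2.0 + 1j])

cfac = cho_factor(z)
y = cho_solve(cfac, b)

# The solution should satisfy z @ y = b up to floating-point error.
print(jnp.allclose(z @ y, b, atol=1e-5))
```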