jax.scipy.linalg.cho_solve

jax.scipy.linalg.cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True)

Solve a linear system using a Cholesky factorization.

JAX implementation of scipy.linalg.cho_solve(). The factorization is passed as the (c, lower) tuple returned by jax.scipy.linalg.cho_factor().
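If c is a lower factor with A = c c^T, solving A y = b reduces to two triangular solves: c z = b, then c^T y = z. The sketch below spells out that reduction with jax.scipy.linalg.solve_triangular and compares the result with cho_solve(). It illustrates the math only and is not a description of how JAX implements cho_solve() internally; the matrix and right-hand side are illustrative values.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve, solve_triangular

# Illustrative symmetric positive-definite system (same values as the example below).
a = jnp.array([[2., 1.],
               [1., 2.]])
b = jnp.array([3., 4.])

# Request the lower factor, so that a == c @ c.T.
c, lower = cho_factor(a, lower=True)

# Step 1: forward substitution, solving c @ z = b.
z = solve_triangular(c, b, lower=True)
# Step 2: back substitution with the transposed factor, solving c.T @ y = z.
y_manual = solve_triangular(c, z, lower=True, trans='T')

# cho_solve performs the same solve from the (c, lower) tuple.
y = cho_solve((c, lower), b)
print(jnp.allclose(y, y_manual, atol=1e-6))
```

For complex inputs, the second step would use trans='C' (conjugate transpose) instead of trans='T'.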

Parameters:
  • c_and_lower (tuple[jax.typing.ArrayLike, bool]) – (c, lower), where c is an array of shape (..., N, N) holding the lower or upper Cholesky factor of the matrix, and lower is a boolean specifying whether c is the lower or the upper factor.

  • b (jax.typing.ArrayLike) – right-hand side of the linear system. Must have shape (..., N).

  • overwrite_b (bool) – unused by JAX

  • check_finite (bool) – unused by JAX
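The lower flag in c_and_lower tells cho_solve() which triangle of c holds the factor. As a sketch, assuming the factors are built directly with jax.scipy.linalg.cholesky rather than with cho_factor() (the 3x3 matrix is an illustrative value), either convention gives the same solution as long as the flag matches the factor:

```python
import jax.numpy as jnp
from jax.scipy.linalg import cholesky, cho_solve

# Illustrative symmetric positive-definite matrix and right-hand side.
a = jnp.array([[4., 2., 0.],
               [2., 5., 1.],
               [0., 1., 3.]])
b = jnp.array([1., 2., 3.])

# Build (c, lower) tuples by hand.
c_lower = cholesky(a, lower=True)   # a == c_lower @ c_lower.T
c_upper = cholesky(a, lower=False)  # a == c_upper.T @ c_upper

# The flag must describe the factor it is paired with.
y_from_lower = cho_solve((c_lower, True), b)
y_from_upper = cho_solve((c_upper, False), b)
print(jnp.allclose(y_from_lower, y_from_upper, atol=1e-5))
```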

Returns:

Array of shape (..., N) representing the solution of the linear system.

Return type:

Array
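The leading ... in the shapes above are batch dimensions. A minimal sketch with a batch of two 2x2 systems (illustrative values): the factor has shape (2, 2, 2), b has shape (2, 2), and the solution has shape (2, 2).

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

# Batch of two symmetric positive-definite 2x2 matrices: shape (2, 2, 2).
a = jnp.array([[[2., 1.], [1., 2.]],
               [[3., 0.], [0., 1.]]])
# One right-hand side per matrix: shape (2, 2).
b = jnp.array([[3., 4.],
               [6., 5.]])

cfac = cho_factor(a)     # factors every matrix in the batch
y = cho_solve(cfac, b)   # solves every system in the batch
print(y.shape)  # (2, 2)
```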

Example

A small real Hermitian positive-definite matrix:

>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.array([[2., 1.],
...                [1., 2.]])

Compute the Cholesky factorization with cho_factor(), then pass it to cho_solve() to solve the linear system x @ y = b:

>>> b = jnp.array([3., 4.])
>>> cfac = jax.scipy.linalg.cho_factor(x)
>>> y = jax.scipy.linalg.cho_solve(cfac, b)
>>> y
Array([0.6666666, 1.6666666], dtype=float32)

Check that the result satisfies the original system:

>>> jnp.allclose(x @ y, b)
Array(True, dtype=bool)
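Because the factorization is computed separately from the solve, a single cho_factor() result can be reused for several right-hand sides. A sketch, assuming the right-hand sides are stacked as the rows of an illustrative array bs and mapped over with jax.vmap:

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

x = jnp.array([[2., 1.],
               [1., 2.]])
# Factor once; cfac is the (c, lower) tuple.
cfac = cho_factor(x)

# Illustrative stack of right-hand sides, one per row: shape (3, 2).
bs = jnp.array([[3., 4.],
                [1., 0.],
                [0., 1.]])

# Map the solve over the rows of bs, reusing the same factor each time.
ys = jax.vmap(lambda b: cho_solve(cfac, b))(bs)
print(ys.shape)  # (3, 2)
```

Since x is symmetric, each row of ys @ x.T should reproduce the corresponding row of bs.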