jax.scipy.linalg.cho_solve
jax.scipy.linalg.cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True)
Solve a linear system using a Cholesky factorization.
JAX implementation of scipy.linalg.cho_solve(). Uses the output of jax.scipy.linalg.cho_factor().
Parameters:
- c_and_lower (tuple[ArrayLike, bool]) – (c, lower), where c is an array of shape (..., N, N) representing the lower or upper Cholesky decomposition of the matrix, and lower is a boolean specifying whether this is the lower or upper decomposition.
- b (ArrayLike) – right-hand side of the linear system. Must have shape (..., N).
- overwrite_b (bool) – unused by JAX.
- check_finite (bool) – unused by JAX.
Returns:
Array of shape (..., N) representing the solution of the linear system.
Return type:
Array
Examples
A small real Hermitian positive-definite matrix:
>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.array([[2., 1.],
...                [1., 2.]])
Compute the Cholesky factorization via cho_factor(), and use it to solve a linear equation via cho_solve().
>>> b = jnp.array([3., 4.])
>>> cfac = jax.scipy.linalg.cho_factor(x)
>>> y = jax.scipy.linalg.cho_solve(cfac, b)
>>> y
Array([0.6666666, 1.6666666], dtype=float32)
Check that the result is consistent:
>>> jnp.allclose(x @ y, b)
Array(True, dtype=bool)
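To make the role of the factorization concrete, the following is a minimal pure-Python sketch of the underlying math, not JAX's implementation. It assumes a real symmetric positive-definite matrix and the lower factor L with x = L Lᵀ, then solves L z = b by forward substitution and Lᵀ y = z by back substitution. The helper names cholesky_lower and solve_with_lower_factor are hypothetical and exist only for this illustration.

```python
import math


def cholesky_lower(a):
    """Return lower-triangular L with a = L @ L^T (a: real symmetric positive-definite)."""
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            # Dot product of the already-computed parts of rows i and j.
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = math.sqrt(a[i][i] - s)
            else:
                L[i][j] = (a[i][j] - s) / L[j][j]
    return L


def solve_with_lower_factor(L, b):
    """Solve (L @ L^T) y = b using two triangular solves."""
    n = len(b)
    # Forward substitution: L z = b.
    z = [0.0] * n
    for i in range(n):
        z[i] = (b[i] - sum(L[i][k] * z[k] for k in range(i))) / L[i][i]
    # Back substitution: L^T y = z (entry (i, k) of L^T is L[k][i]).
    y = [0.0] * n
    for i in reversed(range(n)):
        y[i] = (z[i] - sum(L[k][i] * y[k] for k in range(i + 1, n))) / L[i][i]
    return y


# Same matrix and right-hand side as in the example above.
x = [[2.0, 1.0], [1.0, 2.0]]
b = [3.0, 4.0]
L = cholesky_lower(x)
y = solve_with_lower_factor(L, b)
print(y)  # approximately [2/3, 5/3]
```

Because the factor L depends only on the matrix, it can be computed once and reused for many right-hand sides, which is the usual reason to pair cho_factor() with cho_solve() rather than calling a general solver each time.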