jax.scipy.sparse.linalg.cg

jax.scipy.sparse.linalg.cg(A, b, x0=None, *, tol=1e-05, atol=0.0, maxiter=None, M=None)

Use Conjugate Gradient iteration to solve Ax = b.

The numerics of JAX’s cg should exactly match SciPy’s cg (up to numerical precision), but note that the interface is slightly different: you need to supply the linear operator A as a function instead of a sparse matrix or LinearOperator.

Derivatives of cg are implemented via implicit differentiation with another cg solve, rather than by differentiating through the solver. They will be accurate only if both solves converge.

Parameters
A (function) – Function that calculates the matrix-vector product Ax when called like A(x). A must represent a Hermitian, positive definite matrix, and must return array(s) with the same structure and shape as its argument.
b (array or tree of arrays) – Right-hand side of the linear system representing a single vector. Can be stored as an array or Python container of array(s) with any shape.
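To illustrate that b may be a pytree whose leaves have any shape, the sketch below assumes a hypothetical diagonal operator acting on a dict of arrays. The operator returns a dict with the same keys and leaf shapes as its input, which is what cg requires of A. The names `A`, `weights`, and the keys "u" and "v" are illustrative.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Positive weights applied elementwise to the 2x2 leaf "v".
weights = jnp.array([[1.0, 2.0],
                     [3.0, 4.0]])

# Hypothetical symmetric positive definite operator on a dict of arrays:
# scales "u" by 2 and multiplies "v" elementwise by positive weights.
# Each output leaf has the same shape as the matching input leaf.
def A(x):
    return {"u": 2.0 * x["u"], "v": weights * x["v"]}

b = {"u": jnp.array([2.0, 4.0, 6.0]),
     "v": jnp.array([[1.0, 2.0],
                     [3.0, 4.0]])}

# The solution x is a dict with the same structure and leaf shapes as b.
x, _ = cg(A, b)
```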
Returns
x (array or tree of arrays) – The converged solution. Has the same structure as b.
info (None) – Placeholder for convergence information. In the future, JAX will report the number of iterations when convergence is not achieved, like SciPy.
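Because info is currently None, a caller who needs a convergence check has to compute one. A minimal sketch, assuming an illustrative 2x2 system and applying the documented stopping rule with the default tol=1e-5 and atol=0.0 to the returned solution:

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Illustrative 2x2 symmetric positive definite system.
A_mat = jnp.array([[2.0, 0.5],
                   [0.5, 1.0]])
b = jnp.array([1.0, 1.0])

x, info = cg(lambda v: A_mat @ v, b)

# info carries no convergence information, so measure the residual of the
# returned solution directly and compare it with max(tol*norm(b), atol).
residual = float(jnp.linalg.norm(A_mat @ x - b))
threshold = max(1e-5 * float(jnp.linalg.norm(b)), 0.0)
converged = residual <= threshold
```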
Other Parameters
x0 (array) – Starting guess for the solution. Must have the same structure as b.
tol, atol (float, optional) – Tolerances for convergence, norm(residual) <= max(tol*norm(b), atol). We do not implement SciPy’s “legacy” behavior, so JAX’s tolerance will differ from SciPy unless you explicitly pass atol to SciPy’s cg.
.maxiter (integer) â€“ Maximum number of iterations. Iteration will stop after maxiter steps even if the specified tolerance has not been achieved.
M (function) â€“ Preconditioner for A. The preconditioner should approximate the inverse of A. Effective preconditioning dramatically improves the rate of convergence, which implies that fewer iterations are needed to reach a given error tolerance.
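A minimal sketch tying together tol, atol, maxiter, and M, assuming an illustrative tridiagonal symmetric positive definite matrix whose diagonal varies widely. The preconditioner is a Jacobi (diagonal) preconditioner, one simple choice of M that approximates the inverse of A by inverting only its diagonal. `A_mat`, `matvec`, and `jacobi` are hypothetical names.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

n = 50
# Illustrative SPD tridiagonal matrix with a diagonal spanning 1 to 100.
diag = jnp.linspace(1.0, 100.0, n)
off = 0.5 * jnp.ones(n - 1)
A_mat = jnp.diag(diag) + jnp.diag(off, 1) + jnp.diag(off, -1)
b = jnp.ones(n)

def matvec(x):
    return A_mat @ x

# Jacobi preconditioner: approximates A^{-1} by dividing by diag(A).
def jacobi(x):
    return x / diag

tol, atol = 1e-5, 0.0
x, _ = cg(matvec, b, tol=tol, atol=atol, maxiter=200, M=jacobi)

# Documented stopping rule: norm(residual) <= max(tol*norm(b), atol).
# The recomputed residual may differ slightly from the one cg tracks
# internally because of floating-point round-off.
threshold = max(tol * float(jnp.linalg.norm(b)), atol)
residual = float(jnp.linalg.norm(matvec(x) - b))

# With maxiter=1 the solver stops after a single iteration, before the
# tolerance is reached, and still returns without raising an error.
x_short, _ = cg(matvec, b, tol=tol, atol=atol, maxiter=1, M=jacobi)
residual_short = float(jnp.linalg.norm(matvec(x_short) - b))
```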