cg(A, b, x0=None, *, tol=1e-05, atol=0.0, maxiter=None, M=None)
Use Conjugate Gradient iteration to solve Ax = b.
The numerics of JAX’s cg should exactly match SciPy’s cg (up to numerical precision), but note that the interface is slightly different: you need to supply the linear operator A as a function instead of a sparse matrix or LinearOperator.

Derivatives of cg are implemented via implicit differentiation with another cg solve, rather than by differentiating through the solver. They will be accurate only if both solves converge.
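As a minimal sketch of the functional interface and of differentiating through a solve (the 2x2 matrix below is a made-up illustration, not part of the API):

    import jax
    import jax.numpy as jnp
    from jax.scipy.sparse.linalg import cg

    # A small symmetric positive definite matrix, chosen only for illustration.
    mat = jnp.array([[4.0, 1.0],
                     [1.0, 3.0]])
    b = jnp.array([1.0, 2.0])

    def A(x):
        # The linear operator is supplied as a matvec function.
        return mat @ x

    x, info = cg(A, b, tol=1e-6)

    # Differentiating through the solve triggers a second cg solve under the
    # hood (implicit differentiation); both solves must converge for the
    # gradient to be accurate.
    def loss(rhs):
        sol, _ = cg(A, rhs)
        return jnp.sum(sol ** 2)

    g = jax.grad(loss)(b)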
- Parameters
A (function) – Function that calculates the matrix-vector product Ax when called like A(x). A must represent a hermitian, positive definite matrix, and must return array(s) with the same structure and shape as its argument.
b (array or tree of arrays) – Right hand side of the linear system representing a single vector. Can be stored as an array or Python container of array(s) with any shape.
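To make the pytree interface concrete, a hedged sketch (the dict layout below is an arbitrary choice): b can be a container of arrays, provided A returns a container with the same structure.

    import jax.numpy as jnp
    from jax.scipy.sparse.linalg import cg

    b = {"u": jnp.array([1.0, 2.0]), "v": jnp.array([3.0])}

    def A(x):
        # A diagonal (hence hermitian, positive definite) operator applied
        # leaf-wise, so the output tree matches the input tree.
        return {"u": 2.0 * x["u"], "v": 5.0 * x["v"]}

    x, info = cg(A, b)
    # x has the same {"u": ..., "v": ...} structure as b.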
- Returns
x (array or tree of arrays) – The converged solution. Has the same structure as b.
info (None) – Placeholder for convergence information. In the future, JAX will report the number of iterations when convergence is not achieved, like SciPy.
- Other Parameters
x0 (array) – Starting guess for the solution. Must have the same structure as b.
tol, atol (float, optional) – Tolerances for convergence, norm(residual) <= max(tol*norm(b), atol). We do not implement SciPy’s “legacy” behavior, so JAX’s tolerance will differ from SciPy unless you explicitly pass atol to SciPy’s cg.
maxiter (integer) – Maximum number of iterations. Iteration will stop after maxiter steps even if the specified tolerance has not been achieved.
M (function) – Preconditioner for A. The preconditioner should approximate the inverse of A. Effective preconditioning dramatically improves the rate of convergence, which implies that fewer iterations are needed to reach a given error tolerance.
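As a hedged sketch of preconditioning, here is a Jacobi (diagonal) preconditioner; the matrix is illustrative, and any M that approximates the inverse of A would serve.

    import jax.numpy as jnp
    from jax.scipy.sparse.linalg import cg

    mat = jnp.array([[10.0, 1.0, 0.0],
                     [1.0, 8.0, 1.0],
                     [0.0, 1.0, 6.0]])
    b = jnp.array([1.0, 2.0, 3.0])

    def A(x):
        return mat @ x

    def M(x):
        # Jacobi preconditioning: divide by the diagonal of A, a cheap
        # approximation to applying A's inverse.
        return x / jnp.diag(mat)

    x, info = cg(A, b, M=M, tol=1e-6)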