jax.scipy.sparse.linalg.cg(A, b, x0=None, *, tol=1e-05, atol=0.0, maxiter=None, M=None)

Use Conjugate Gradient iteration to solve Ax = b.
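To show what the iteration itself does, here is a minimal textbook sketch of unpreconditioned CG for a single real 1-D vector, written with `lax.while_loop`. It illustrates the algorithm under those assumptions and is not JAX's actual implementation, which additionally handles preconditioning, trees of arrays and implicit differentiation. `textbook_cg` is a hypothetical name. Its stopping rule mirrors the `tol` criterion from the parameter list, with `atol` fixed at 0.

```python
import jax.numpy as jnp
from jax import lax
from jax.scipy.sparse.linalg import cg


def textbook_cg(matvec, b, tol=1e-5, maxiter=100):
    """Unpreconditioned CG for one real 1-D right-hand side b.

    Assumes matvec applies a real symmetric positive definite matrix.
    Stops once ||r|| <= tol * ||b|| or after maxiter iterations.
    """
    x = jnp.zeros_like(b)
    r = b - matvec(x)            # initial residual
    p = r                        # first search direction
    rs = jnp.vdot(r, r)          # squared residual norm
    threshold = (tol * jnp.linalg.norm(b)) ** 2

    def cond(state):
        _, _, _, rs, k = state
        return (rs > threshold) & (k < maxiter)

    def body(state):
        x, r, p, rs, k = state
        Ap = matvec(p)
        alpha = rs / jnp.vdot(p, Ap)     # exact line search along p
        x = x + alpha * p
        r = r - alpha * Ap               # recursively updated residual
        rs_new = jnp.vdot(r, r)
        p = r + (rs_new / rs) * p        # next A-conjugate direction
        return x, r, p, rs_new, k + 1

    x, *_ = lax.while_loop(cond, body, (x, r, p, rs, jnp.int32(0)))
    return x


A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])
x_sketch = textbook_cg(lambda v: A @ v, b)
x_jax, _ = cg(lambda v: A @ v, b)
```

On this 2 x 2 system CG reaches the exact solution [1/11, 7/11] in at most two iterations in exact arithmetic, so the sketch and `cg` should agree to float32 precision.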

The numerics of JAX’s cg should exactly match SciPy’s cg (up to numerical precision), but note that the interface is slightly different: you need to supply the linear operator A as a function instead of a sparse matrix or LinearOperator.
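Because A is passed as a function, the matrix never has to be materialized. The sketch below is a minimal example of this. It assumes a 1-D discrete Laplacian tridiag(-1, 2, -1) as the operator, which was chosen only for illustration. `laplacian_matvec` is a hypothetical helper and is not part of JAX. The dense matrix is assembled only to cross-check the answer.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

n = 20


def laplacian_matvec(x):
    # Applies the n x n tridiagonal matrix tridiag(-1, 2, -1) to x without
    # ever forming it; zero padding supplies the missing boundary neighbours.
    padded = jnp.pad(x, 1)
    return 2.0 * x - padded[:-2] - padded[2:]


b = jnp.ones(n)
x, info = cg(laplacian_matvec, b)

# The same matrix assembled densely, only to cross-check the result.
A_dense = 2.0 * jnp.eye(n) - jnp.eye(n, k=1) - jnp.eye(n, k=-1)
x_dense = jnp.linalg.solve(A_dense, b)
```

For b = ones(n), the exact solution is x_i = i(n + 1 - i)/2 for i = 1, ..., n, so both results can be checked against that formula.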

Derivatives of cg are implemented via implicit differentiation with another cg solve, rather than by differentiating through the solver. They will be accurate only if both solves converge.
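The following is a small sketch of differentiating through cg with `jax.grad`. It assumes a 2 x 2 symmetric positive definite matrix chosen for illustration. Since x = A^{-1} b, the gradient of sum(x) with respect to b is A^{-T} 1. Here that equals A^{-1} 1 because A is symmetric. The gradient returned by `jax.grad` comes from the implicit-differentiation solve rather than from unrolling the iterations, and it should match the analytic value up to solver tolerance.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

A = jnp.array([[4.0, 1.0], [1.0, 3.0]])   # symmetric positive definite


def loss(b):
    # x solves A x = b; cg's derivative is defined by implicit differentiation.
    x, _ = cg(lambda v: A @ v, b, tol=1e-6)
    return jnp.sum(x)


b = jnp.array([1.0, 2.0])
grad_b = jax.grad(loss)(b)

# Analytic gradient: d(sum x)/db = A^{-T} 1 = A^{-1} 1 because A is symmetric.
grad_exact = jnp.linalg.solve(A, jnp.ones(2))
```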

Parameters
  • A (function) – Function that calculates the matrix-vector product Ax when called like A(x). A must represent a Hermitian, positive definite matrix, and must return array(s) with the same structure and shape as its argument.

  • b (array or tree of arrays) – Right-hand side of the linear system, representing a single vector. Can be stored as an array or a Python container of array(s) with any shape.

Returns
  • x (array or tree of arrays) – The converged solution. Has the same structure as b.

  • info (None) – Placeholder for convergence information. In the future, JAX will report the number of iterations when convergence is not achieved, like SciPy.
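The following sketch illustrates the "tree of arrays" case by storing b as a dict. The operator is a made-up diagonal map, and therefore symmetric positive definite. It was chosen only so the answer is easy to check. As A is required to do, it returns a dict with the same keys and leaf shapes as its input.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Right-hand side stored as a dict of arrays rather than a single flat vector.
b = {"u": jnp.array([1.0, 2.0]), "v": jnp.array([[3.0]])}


def matvec(x):
    # Illustrative operator: scales the "u" leaf by 2 and the "v" leaf by 4.
    # Viewed as one flattened vector this is diag(2, 2, 4): symmetric positive
    # definite, and the output has the same tree structure as the input.
    return {"u": 2.0 * x["u"], "v": 4.0 * x["v"]}


x, info = cg(matvec, b)
```

Here x comes back as a dict with the same keys and shapes as b (x["u"] ≈ [0.5, 1.0], x["v"] ≈ [[0.75]]). info is None, matching the placeholder described above.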

Other Parameters
  • x0 (array or tree of arrays) – Starting guess for the solution. Must have the same structure as b. If omitted, zeros with the structure of b are used.

  • tol, atol (float, optional) – Tolerances for convergence: iteration stops once norm(residual) <= max(tol*norm(b), atol). We do not implement SciPy’s “legacy” behavior, so JAX’s tolerance will differ from SciPy’s unless you explicitly pass atol to SciPy’s cg.

  • maxiter (integer) – Maximum number of iterations. Iteration will stop after maxiter steps even if the specified tolerance has not been achieved.

  • M (function) – Preconditioner for A. The preconditioner should approximate the inverse of A. Effective preconditioning can dramatically improve the rate of convergence, so that fewer iterations are needed to reach a given error tolerance.
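The sketch below shows how to pass the optional arguments. It assumes a Jacobi (diagonal) preconditioner as the choice of M. The docs do not prescribe a particular preconditioner, and `jacobi_preconditioner` is a hypothetical helper. The matrix is an illustrative symmetric positive definite matrix whose diagonal spans three orders of magnitude, which is the situation where diagonal scaling tends to help.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

d = jnp.array([1.0, 10.0, 100.0, 1000.0])
# Diagonally dominant with a positive diagonal, hence symmetric positive definite.
A = jnp.diag(d) + 0.1 * (jnp.eye(4, k=1) + jnp.eye(4, k=-1))
b = jnp.ones(4)


def jacobi_preconditioner(r):
    # Approximates A^{-1} r by diag(A)^{-1} r; returns the same shape as r.
    return r / d


x_pre, _ = cg(lambda v: A @ v, b, tol=1e-5, atol=0.0, maxiter=100,
              M=jacobi_preconditioner)
x_plain, _ = cg(lambda v: A @ v, b, tol=1e-5, atol=0.0, maxiter=100)
x_ref = jnp.linalg.solve(A, b)
```

Because info is currently None, this example cannot report iteration counts. It only checks that the preconditioned and unpreconditioned solves agree with a dense solve. M receives a residual-shaped input and must return an output with the same structure.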