jax.scipy.sparse.linalg.bicgstab

jax.scipy.sparse.linalg.bicgstab(A, b, x0=None, *, tol=1e-05, atol=0.0, maxiter=None, M=None)

Use Bi-Conjugate Gradient Stabilized (BiCGSTAB) iteration to solve Ax = b.

The numerics of JAX’s bicgstab should exactly match SciPy’s bicgstab (up to numerical precision), but note that the interface is slightly different: you need to supply the linear operator A as a function instead of a sparse matrix or LinearOperator.
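
For example, a minimal sketch of both call forms on a small system (the matrix values here are arbitrary and for illustration only):

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import bicgstab

# Small, arbitrary nonsymmetric system for demonstration.
A = jnp.array([[4.0, 1.0],
               [2.0, 3.0]])
b = jnp.array([1.0, 2.0])

# A supplied as a dense array (used via A @ x) ...
x, info = bicgstab(A, b)

# ... or as a matvec function, the form SciPy users may not expect.
x, info = bicgstab(lambda v: A @ v, b)
```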

As with cg, derivatives of bicgstab are implemented via implicit differentiation with another bicgstab solve, rather than by differentiating through the solver. They will be accurate only if both solves converge.
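
A minimal sketch of differentiating through a solve with respect to the right-hand side (the quadratic loss below is an arbitrary choice for illustration):

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import bicgstab

A = jnp.array([[4.0, 1.0],
               [2.0, 3.0]])

def loss(b):
    # The backward pass runs a second bicgstab solve (implicit
    # differentiation) rather than unrolling the iterations.
    x, _ = bicgstab(A, b)
    return jnp.sum(x ** 2)

grad_b = jax.grad(loss)(jnp.array([1.0, 2.0]))
```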

Parameters:
  • A (ndarray, function, or matmul-compatible object) – 2D array or function that calculates the linear map (matrix-vector product) Ax when called like A(x) or A @ x. A can represent any general (nonsymmetric) linear operator, and the function must return array(s) with the same structure and shape as its argument.

  • b (array or tree of arrays) – Right-hand side of the linear system, representing a single vector. Can be stored as an array or Python container of array(s) with any shape.

  • x0 (array or tree of arrays) – Starting guess for the solution. Must have the same structure as b.

  • tol (float, optional) – Relative tolerance for convergence, norm(residual) <= max(tol*norm(b), atol). We do not implement SciPy’s “legacy” behavior, so JAX’s tolerance will differ from SciPy’s unless you explicitly pass atol to SciPy’s bicgstab.

  • atol (float, optional) – Absolute tolerance for convergence, entering the same stopping criterion norm(residual) <= max(tol*norm(b), atol).

  • maxiter (integer) – Maximum number of iterations. Iteration will stop after maxiter steps even if the specified tolerance has not been achieved.

  • M (ndarray, function, or matmul-compatible object) – Preconditioner for A. The preconditioner should approximate the inverse of A. Effective preconditioning dramatically improves the rate of convergence, meaning that fewer iterations are needed to reach a given error tolerance; see the sketch following this parameter list.
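
As an illustration only, a Jacobi (diagonal) preconditioner on an arbitrary diagonally dominant system:

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import bicgstab

# Arbitrary diagonally dominant system, for demonstration.
A = jnp.array([[10.0,  1.0,  0.0],
               [ 2.0,  8.0,  1.0],
               [ 0.0,  1.0,  5.0]])
b = jnp.array([1.0, 2.0, 3.0])

# Jacobi preconditioner: approximate A^-1 by the inverse of A's diagonal.
inv_diag = 1.0 / jnp.diag(A)
x, info = bicgstab(A, b, M=lambda r: inv_diag * r)
```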

Returns:

  • x (array or tree of arrays) – The converged solution. Has the same structure as b.

  • info (None) – Placeholder for convergence information. In the future, JAX will report the number of iterations when convergence is not achieved, like SciPy.
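
Since info is currently a placeholder, one way to verify convergence is to check the residual against the stopping criterion yourself; a minimal sketch, with arbitrary values:

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import bicgstab

A = jnp.array([[4.0, 1.0],
               [2.0, 3.0]])
b = jnp.array([1.0, 2.0])

x, info = bicgstab(A, b, tol=1e-5)  # info is currently always None

# Manual check mirroring norm(residual) <= max(tol*norm(b), atol).
residual = jnp.linalg.norm(b - A @ x)
converged = residual <= jnp.maximum(1e-5 * jnp.linalg.norm(b), 0.0)
```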