jax.lax.custom_linear_solve

jax.lax.custom_linear_solve(matvec, b, solve, transpose_solve=None, symmetric=False)

Perform a matrix-free linear solve with implicitly defined gradients.
This function allows for overriding or defining gradients for a linear solve directly via implicit differentiation at the solution, rather than by differentiating through the solve operation. This can sometimes be much faster or more numerically stable, or differentiating through the solve operation may not even be implemented (e.g., if solve uses lax.while_loop).
Required invariant:

    x = solve(matvec, b)  # solve the linear equation
    assert matvec(x) == b  # not checked

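Here is a minimal sketch of the invariant and of the lax.while_loop case mentioned above. It assumes a small illustrative symmetric positive-definite matrix A and a hypothetical Richardson-iteration solver, richardson_solve, which is not part of JAX. Reverse-mode differentiation of lax.while_loop is not supported, so jax.grad cannot go through richardson_solve directly. Wrapping the solver in custom_linear_solve supplies the gradient by implicit differentiation instead.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Illustrative symmetric positive-definite matrix (an assumption for this sketch).
A = jnp.array([[4.0, 1.0],
               [1.0, 3.0]])


def matvec(x):
    return A @ x


def richardson_solve(matvec, b, omega=0.2, tol=1e-5, max_iters=1000):
    # Hypothetical matrix-free solver: repeat x <- x + omega * (b - matvec(x)).
    # For symmetric positive-definite A it converges when 0 < omega < 2 / lambda_max(A);
    # omega = 0.2 is chosen for this particular A (lambda_max is about 4.62).
    def cond(state):
        x, k = state
        return (jnp.linalg.norm(b - matvec(x)) > tol) & (k < max_iters)

    def body(state):
        x, k = state
        return x + omega * (b - matvec(x)), k + 1

    x, _ = lax.while_loop(cond, body, (jnp.zeros_like(b), jnp.array(0)))
    return x


b = jnp.array([1.0, 2.0])

# The required invariant, checked by hand here (JAX does not check it).
x = richardson_solve(matvec, b)
print(jnp.allclose(matvec(x), b, atol=1e-4))


def loss(b):
    # Gradients come from implicit differentiation, not from the while_loop.
    x = lax.custom_linear_solve(matvec, b, richardson_solve, symmetric=True)
    return jnp.sum(x)


print(jax.grad(loss)(b))                 # d/db sum(A^{-1} b) = A^{-T} 1 = A^{-1} 1 here
print(jnp.linalg.solve(A, jnp.ones(2)))  # dense reference for comparison
```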
Parameters
matvec - linear function to invert. Must be differentiable.
b - constant right-hand side of the equation. May be any nested structure of arrays.
solve - higher-order function that solves for the solution to the linear equation, i.e., solve(matvec, matvec(x)) == x for all x of the same form as b. This function need not be differentiable.
transpose_solve - higher-order function for solving the transpose linear equation, i.e., transpose_solve(vecmat, vecmat(x)) == x, where vecmat is the transpose of the linear map matvec (computed automatically with autodiff). Required for reverse-mode automatic differentiation, unless symmetric=True, in which case solve provides the default value.
symmetric - bool indicating whether it is safe to assume the linear map corresponds to a symmetric matrix, i.e., matvec == vecmat. This assumption is not checked.
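A non-symmetric matvec needs transpose_solve for reverse-mode differentiation. The sketch below assumes an illustrative upper-triangular matrix A and two hypothetical callbacks, solve and transpose_solve. As a simplification, both callbacks ignore the linear map they are given and solve with the dense matrix directly. The sketch also shows that symmetric=True is taken on trust. Declaring this map symmetric makes JAX reuse solve for the transposed system, which produces the wrong gradient.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Illustrative non-symmetric (upper-triangular) matrix; an assumption for this sketch.
A = jnp.array([[2.0, 1.0],
               [0.0, 3.0]])


def matvec(x):
    return A @ x


# Hypothetical callbacks. As a simplification they ignore the linear map they
# receive and solve with the dense matrix directly.
def solve(matvec, b):
    return jnp.linalg.solve(A, b)  # solves A @ x == b


def transpose_solve(vecmat, b):
    # vecmat(y) computes A.T @ y; this callback must solve A.T @ y == b.
    return jnp.linalg.solve(A.T, b)


def loss(b, **solve_kwargs):
    return jnp.sum(lax.custom_linear_solve(matvec, b, solve, **solve_kwargs))


b = jnp.array([1.0, 1.0])
g_ok = jax.grad(lambda b: loss(b, transpose_solve=transpose_solve))(b)
# Wrong on purpose: A is not symmetric, but symmetric=True is not checked,
# so solve is reused for the transposed system.
g_bad = jax.grad(lambda b: loss(b, symmetric=True))(b)

reference = jnp.linalg.solve(A.T, jnp.ones(2))  # d/db sum(A^{-1} b) = A^{-T} 1
print(g_ok, g_bad, reference)
```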
Returns
Result of solve(matvec, b), with gradients defined assuming that the solution x satisfies the linear equation matvec(x) == b.
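The gradients assume that matvec(x) == b holds. Write A for the linear map and differentiate that equation: (dA) x + A dx = db, so dx = A^{-1}(db - (dA) x). Gradients therefore also reach any parameters that matvec closes over. The sketch below assumes a hypothetical matrix-free tridiagonal operator with a scalar parameter s. It solves with jax.scipy.sparse.linalg.cg, then compares the gradient with respect to s against differentiating a dense jnp.linalg.solve reference. The helpers make_matvec, cg_solve and dense_loss are illustrative names only.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.scipy.sparse.linalg import cg

n = 5
b = jnp.arange(1.0, n + 1.0)  # illustrative right-hand side [1, 2, 3, 4, 5]


def make_matvec(s):
    # Hypothetical matrix-free operator: (A x)_i = (2 + s) x_i - x_{i-1} - x_{i+1},
    # with zero boundary values. It is symmetric, and positive definite for s >= 0.
    def matvec(x):
        left = jnp.concatenate([jnp.zeros(1, x.dtype), x[:-1]])  # x_{i-1}
        right = jnp.concatenate([x[1:], jnp.zeros(1, x.dtype)])  # x_{i+1}
        return (2.0 + s) * x - left - right
    return matvec


def cg_solve(matvec, rhs):
    x, _ = cg(matvec, rhs)  # conjugate gradients with JAX's default tolerances
    return x


def loss(s):
    # matvec closes over s, so jax.grad reaches s through the implicit gradient.
    x = lax.custom_linear_solve(make_matvec(s), b, cg_solve, symmetric=True)
    return jnp.sum(x)


def dense_loss(s):
    # Reference: the same operator as a dense matrix, differentiated through jnp.linalg.solve.
    A = (2.0 + s) * jnp.eye(n) - jnp.eye(n, k=1) - jnp.eye(n, k=-1)
    return jnp.sum(jnp.linalg.solve(A, b))


print(jax.grad(loss)(0.5))
print(jax.grad(dense_loss)(0.5))  # should agree up to the CG tolerance
```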