- jax.lax.custom_linear_solve(matvec, b, solve, transpose_solve=None, symmetric=False, has_aux=False)
Perform a matrix-free linear solve with implicitly defined gradients.
This function allows for overriding or defining gradients for a linear solve directly via implicit differentiation at the solution, rather than by differentiating through the solve operation. This can sometimes be much faster or more numerically stable, or differentiating through the solve operation may not even be implemented (e.g., if
x = solve(matvec, b)  # solve the linear equation
assert matvec(x) == b  # not checked
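As a rough illustration of this invariant, the sketch below wraps a hand-written iterative solver built on `lax.while_loop`, which JAX cannot differentiate in reverse mode. The matrix `A`, the vector `b`, the helper `richardson_solve`, and its `step`, `tol`, and `max_iter` parameters are all assumptions made up for this example; the step size is only valid for this particular matrix.

```python
import jax
import jax.numpy as jnp

# Illustrative 2x2 symmetric positive-definite system (values are made up).
A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])

def richardson_solve(matvec, rhs, step=0.2, tol=1e-6, max_iter=1000):
    """Hypothetical matrix-free solver: x <- x + step * (rhs - matvec(x))."""
    def not_converged(state):
        x, k = state
        return (jnp.linalg.norm(rhs - matvec(x)) > tol) & (k < max_iter)

    def update(state):
        x, k = state
        return x + step * (rhs - matvec(x)), k + 1

    x, _ = jax.lax.while_loop(not_converged, update, (jnp.zeros_like(rhs), 0))
    return x

def solve_system(A, b):
    matvec = lambda x: A @ x
    # symmetric=True lets the same solver handle the transposed system.
    return jax.lax.custom_linear_solve(matvec, b, richardson_solve, symmetric=True)

x = solve_system(A, b)

# Reverse-mode gradient of sum(x) with respect to b. It is obtained from a
# second linear solve rather than by differentiating through the loop.
grad_b = jax.grad(lambda rhs: jnp.sum(solve_system(A, rhs)))(b)
```

Here `grad_b` should agree, up to solver tolerance, with a dense solve of the transposed system against a vector of ones.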
matvec – linear function to invert. Must be differentiable.
b – constant right-hand side of the equation. May be any nested structure of arrays.
solve – higher-level function that computes the solution to the linear equation, i.e., matvec(solve(matvec, x)) == x for all x of the same form as b. This function need not be differentiable.
transpose_solve – higher-level function for solving the transpose linear equation, i.e., vecmat(transpose_solve(vecmat, x)) == x, where vecmat is the transpose of the linear map matvec (computed automatically with autodiff). Required for reverse-mode automatic differentiation, unless symmetric=True, in which case solve provides the default value.
symmetric – bool indicating if it is safe to assume the linear map corresponds to a symmetric matrix, i.e., matvec == vecmat.
has_aux – bool indicating whether the solve and transpose_solve functions return auxiliary data, such as solver diagnostics, as a second return value.
- Result of solve(matvec, b), with gradients defined assuming that the solution x satisfies the linear equation matvec(x) == b.
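The following sketch shows one way the `transpose_solve` and `has_aux` arguments might be combined for a non-symmetric system. The helper `dense_solve_with_aux`, the `residual_norm` diagnostic, and the matrix values are assumptions for illustration only; the helper materializes the matrix with `jax.jacobian`, so it is not matrix-free and serves just to keep the example short.

```python
import jax
import jax.numpy as jnp

# Illustrative non-symmetric system (values are made up).
A = jnp.array([[3.0, 1.0], [0.5, 2.0]])
b = jnp.array([1.0, -1.0])

def dense_solve_with_aux(linear_map, rhs):
    """Hypothetical solver: build the dense matrix of linear_map, then solve."""
    matrix = jax.jacobian(linear_map)(rhs)  # Jacobian of a linear map is its matrix
    x = jnp.linalg.solve(matrix, rhs)
    aux = {"residual_norm": jnp.linalg.norm(linear_map(x) - rhs)}
    return x, aux

def solve_general(A, b):
    matvec = lambda x: A @ x
    # The same helper serves as transpose_solve; JAX passes it the transposed map.
    return jax.lax.custom_linear_solve(
        matvec, b, dense_solve_with_aux,
        transpose_solve=dense_solve_with_aux, has_aux=True)

# With has_aux=True the result is a (solution, aux) pair.
x, aux = solve_general(A, b)

# Reverse-mode gradient uses transpose_solve; the aux output is ignored here.
grad_b = jax.grad(lambda rhs: jnp.sum(solve_general(A, rhs)[0]))(b)
```

Because the matrix is not symmetric, omitting `transpose_solve` here would be expected to make the `jax.grad` call fail, while the forward solve alone would still work.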