# jax.lax.custom_linear_solve

jax.lax.custom_linear_solve(matvec, b, solve, transpose_solve=None, symmetric=False, has_aux=False)

Perform a matrix-free linear solve with implicitly defined gradients.

This function allows for overriding or defining gradients for a linear solve directly via implicit differentiation at the solution, rather than by differentiating through the solve operation. This can be much faster or more numerically stable, and it also works when differentiating through the solve operation is not implemented at all (e.g., if `solve` uses `lax.while_loop`, which does not support reverse-mode differentiation).

Required invariant:

```
x = solve(matvec, b)  # solve the linear equation
assert matvec(x) == b  # not checked
```
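As an illustration of the `lax.while_loop` case mentioned above, here is a minimal sketch, not part of the JAX API. `richardson_solve` is a hypothetical matrix-free solver that only calls `matvec`, and the 2×2 matrix `A` is an arbitrary symmetric positive-definite example, so `symmetric=True` lets `solve` double as the transpose solver.

```python
import jax
import jax.numpy as jnp
from jax import lax


def richardson_solve(matvec, b, omega=0.2, tol=1e-5, max_iter=500):
    # Hypothetical matrix-free solver: Richardson iteration x <- x + omega * (b - A x).
    # It only calls matvec, and lax.while_loop makes it non-reverse-differentiable.
    # Converges for symmetric positive-definite A when 0 < omega < 2 / lambda_max(A).
    def cond(state):
        x, i = state
        return (jnp.linalg.norm(b - matvec(x)) > tol) & (i < max_iter)

    def body(state):
        x, i = state
        return x + omega * (b - matvec(x)), i + 1

    x, _ = lax.while_loop(cond, body, (jnp.zeros_like(b), jnp.array(0)))
    return x


# Illustrative symmetric positive-definite matrix (eigenvalues ~2.38 and ~4.62).
A = jnp.array([[4.0, 1.0], [1.0, 3.0]])


def solve_spd(b):
    # symmetric=True lets custom_linear_solve reuse `solve` as transpose_solve.
    return lax.custom_linear_solve(lambda v: A @ v, b, richardson_solve, symmetric=True)


def loss(b):
    return jnp.sum(solve_spd(b))


b = jnp.array([1.0, 2.0])
x = solve_spd(b)       # approximately A^{-1} b
g = jax.grad(loss)(b)  # approximately A^{-1} @ [1, 1] = [2/11, 3/11]
```

Because the gradient comes from solving the transposed system rather than from backpropagating through the loop, `jax.grad` works here even though `lax.while_loop` itself cannot be reverse-differentiated.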
Parameters:
• matvec – linear function to invert. Must be differentiable.

• b – constant right-hand side of the equation. May be any nested structure of arrays.

• solve – higher-level function that solves for the solution to the linear equation, i.e., `matvec(solve(matvec, x)) == x` for all `x` of the same form as `b`. This function need not be differentiable.

• transpose_solve – higher-level function for solving the transpose linear equation, i.e., `vecmat(transpose_solve(vecmat, x)) == x`, where `vecmat` is the transpose of the linear map `matvec` (computed automatically with autodiff). Required for reverse-mode automatic differentiation, unless `symmetric=True`, in which case `solve` provides the default value.

• symmetric – bool indicating if it is safe to assume the linear map corresponds to a symmetric matrix, i.e., `matvec == vecmat`.

• has_aux – bool indicating whether the `solve` and `transpose_solve` functions return auxiliary data, such as solver diagnostics, as a second output, i.e., each returns a pair `(x, aux)`.
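The following sketch exercises `transpose_solve` and `has_aux` together for a nonsymmetric matrix. The helpers `solve` and `transpose_solve` are hypothetical: for brevity they call `jnp.linalg.solve` on the closed-over matrix instead of iterating, and they report the residual norm purely as illustrative auxiliary data.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Illustrative nonsymmetric matrix, so an explicit transpose_solve is needed
# for reverse-mode differentiation.
A = jnp.array([[4.0, 1.0], [2.0, 3.0]])


def matvec(v):
    return A @ v


def solve(matvec, rhs):
    # Hypothetical direct solver; returns (solution, residual norm) because has_aux=True.
    x = jnp.linalg.solve(A, rhs)
    return x, jnp.linalg.norm(matvec(x) - rhs)


def transpose_solve(vecmat, rhs):
    # Solves A^T y = rhs; vecmat is the transpose of matvec, derived by JAX.
    y = jnp.linalg.solve(A.T, rhs)
    return y, jnp.linalg.norm(vecmat(y) - rhs)


def loss(b):
    x, _ = lax.custom_linear_solve(matvec, b, solve, transpose_solve, has_aux=True)
    return jnp.sum(x)


b = jnp.array([1.0, 2.0])
x, residual = lax.custom_linear_solve(matvec, b, solve, transpose_solve, has_aux=True)
g = jax.grad(loss)(b)  # equals A^{-T} @ [1, 1] up to float32 rounding
```

With `has_aux=True` the call returns the pair `(x, aux)`; only `x` carries gradients, and `jax.grad` obtains them through `transpose_solve`.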

Returns:

Result of `solve(matvec, b)`, with gradients defined assuming that the solution `x` satisfies the linear equation `matvec(x) == b`.
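For a matrix `A`, differentiating `A x = b` gives `dA x + A dx = db`, so `dx = A^{-1}(db - dA x)`; this is the implicit rule the returned value's gradients follow. The sketch below checks it against differentiating `jnp.linalg.solve` directly, including the gradient with respect to a matrix that `matvec` closes over. `jacobian_solve` and `linear_solve` are hypothetical helpers, and materializing the matrix with `jax.jacobian` is only reasonable for small illustrative problems.

```python
import jax
import jax.numpy as jnp
from jax import lax


def jacobian_solve(matvec, b):
    # Hypothetical solver: build the dense matrix of the linear map from matvec
    # alone, then solve. custom_linear_solve never differentiates through it.
    M = jax.jacobian(matvec)(b)
    return jnp.linalg.solve(M, b)


def linear_solve(A, b):
    # The same helper serves as transpose_solve: given vecmat, it builds A^T.
    return lax.custom_linear_solve(
        lambda v: A @ v, b, jacobian_solve, transpose_solve=jacobian_solve)


def loss_custom(A, b):
    return jnp.sum(linear_solve(A, b))


def loss_reference(A, b):
    return jnp.sum(jnp.linalg.solve(A, b))


A = jnp.array([[4.0, 1.0], [2.0, 3.0]])
b = jnp.array([1.0, 2.0])

# Gradients with respect to both the matrix and the right-hand side.
grads_custom = jax.grad(loss_custom, argnums=(0, 1))(A, b)
grads_reference = jax.grad(loss_reference, argnums=(0, 1))(A, b)
```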