# jax.lax.custom_root

jax.lax.custom_root(f, initial_guess, solve, tangent_solve, has_aux=False)

Differentiably solve for the roots of a function.

This is a low-level routine, mostly intended for internal use in JAX. Gradients of custom_root() are defined with respect to closed-over variables from the provided function `f` via the implicit function theorem: https://en.wikipedia.org/wiki/Implicit_function_theorem
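As a sketch of what the implicit function theorem provides here, write the closed-over variables as θ and the root as x*(θ). Both symbols are introduced for illustration and are not part of the API. Assuming f is differentiable and its Jacobian with respect to x is invertible at the root:

```latex
f(\theta, x^*(\theta)) = 0
\quad\Longrightarrow\quad
\partial_\theta f + \partial_x f \, \frac{\partial x^*}{\partial \theta} = 0
\quad\Longrightarrow\quad
\frac{\partial x^*}{\partial \theta} = -\left(\partial_x f\right)^{-1} \partial_\theta f
```

The inverse of ∂ₓf is never formed explicitly; the linear system it defines is what `tangent_solve` is asked to solve.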

Parameters:
• f – function for which to find a root. Should accept a single argument and return a tree of arrays with the same structure as its input.

• initial_guess – initial guess for a zero of f.

• solve –

function to solve for the roots of f. Should take two positional arguments, f and initial_guess, and return a solution with the same structure as initial_guess such that f(solution) == 0. In other words, the following is assumed to be true (but not checked):

```
solution = solve(f, initial_guess)
error = f(solution)
assert all(error == 0)
```

• tangent_solve –

function to solve the tangent system. Should take two positional arguments, a linear function `g` (the function `f` linearized at its root) and a tree of array(s) `y` with the same structure as initial_guess, and return a solution `x` such that `g(x)=y`:

• For scalar `y`, use `lambda g, y: y / g(1.0)`.

• For vector `y`, you could use a linear solve with the Jacobian, if the dimensionality of `y` is not too large: `lambda g, y: jnp.linalg.solve(jax.jacobian(g)(y), y)`.

• has_aux – bool indicating whether the `solve` function returns auxiliary data, such as solver diagnostics, as a second return value.
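The parameters above can be sketched together for the scalar case. The example below is a minimal illustration, not a reference implementation: `cube_root`, the Newton iteration, the tolerance `tol`, and the iteration cap `max_iter` are all assumptions chosen for the example. It also shows that an iterative solver typically satisfies `f(solution) == 0` only to within a floating-point tolerance.

```python
import jax
import jax.numpy as jnp
from jax import lax

def cube_root(a):
    # f closes over `a`; gradients of the root are taken with respect to it.
    f = lambda x: x ** 3 - a

    def solve(f, initial_guess, tol=1e-5, max_iter=50):
        # Newton iteration (an illustrative choice of root finder).
        def cond(state):
            i, x = state
            return (jnp.abs(f(x)) > tol) & (i < max_iter)

        def body(state):
            i, x = state
            return i + 1, x - f(x) / jax.grad(f)(x)

        _, x = lax.while_loop(cond, body, (0, initial_guess))
        # With has_aux=True, return (solution, aux); aux here is the residual.
        return x, jnp.abs(f(x))

    # Scalar case: g is linear, so g(x) = g(1.0) * x and x = y / g(1.0).
    tangent_solve = lambda g, y: y / g(1.0)

    return lax.custom_root(
        f, jnp.asarray(1.0), solve, tangent_solve, has_aux=True)

root, residual = cube_root(8.0)
# Differentiate only the solution, not the auxiliary output.
droot_da = jax.grad(lambda a: cube_root(a)[0])(8.0)
```

Here `root` should be close to 2 and `droot_da` close to 1/12, the derivative of the cube root at 8. The `while_loop` inside `solve` is never differentiated; the gradient comes from `tangent_solve` alone.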

Returns:

The result of calling solve(f, initial_guess) with gradients defined via implicit differentiation assuming `f(solve(f, initial_guess)) == 0`.
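As a hedged illustration of the implicitly defined gradients for a vector-valued root, the sketch below treats the solution of a small linear system as the root of `f(x) = a @ x - b`. The wrapper name `solve_linear_system`, the matrix `a`, and the vector `b` are made up for the example, and a direct `jnp.linalg.solve` stands in for the root finder.

```python
import jax
import jax.numpy as jnp
from jax import lax

def solve_linear_system(a, b):
    # The solution of a @ x = b is the root of f(x) = a @ x - b.
    f = lambda x: a @ x - b

    # A direct solve stands in for the root finder; its arguments are unused.
    solve = lambda _f, _initial_guess: jnp.linalg.solve(a, b)

    def tangent_solve(g, y):
        # g is linear, so its Jacobian is the same matrix at any input.
        return jnp.linalg.solve(jax.jacobian(g)(y), y)

    return lax.custom_root(f, jnp.zeros_like(b), solve, tangent_solve)

a = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([9.0, 8.0])

x = solve_linear_system(a, b)
# Implicit differentiation of a @ x - b = 0 gives dx/db = inverse(a).
dx_db = jax.jacobian(solve_linear_system, argnums=1)(a, b)
```

For this system `x` should be close to `[2, 3]`, and `dx_db` should agree with `jnp.linalg.inv(a)` up to floating-point error.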