# jax.lax.custom_root

jax.lax.custom_root(f, initial_guess, solve, tangent_solve)

Differentiably solve for the roots of a function.

This is a low-level routine, mostly intended for internal use in JAX. Gradients of custom_root() are defined with respect to closed-over variables from the provided function f via the implicit function theorem: https://en.wikipedia.org/wiki/Implicit_function_theorem
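
As a sketch of what the implicit function theorem provides here: write the closed-over variables of f as a parameter θ and the root as x*(θ). Both symbols are notation introduced for this illustration and are not part of the API. Differentiating the root condition with respect to θ gives a linear system for the derivative of the root, which is the kind of system tangent_solve is asked to solve:

```latex
f\bigl(x^*(\theta), \theta\bigr) = 0
\quad\Longrightarrow\quad
\frac{\partial f}{\partial x}\,\frac{\mathrm{d}x^*}{\mathrm{d}\theta}
  + \frac{\partial f}{\partial \theta} = 0
\quad\Longrightarrow\quad
\frac{\mathrm{d}x^*}{\mathrm{d}\theta}
  = -\left(\frac{\partial f}{\partial x}\right)^{-1}
    \frac{\partial f}{\partial \theta}
```

All partial derivatives are evaluated at the root, and the result assumes that the Jacobian of f with respect to x is invertible there.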

Parameters
• f – function for which to find a root. Should accept a single argument and return a tree of arrays with the same structure as its input.

• initial_guess – initial guess for a zero of f.

• solve –

function to solve for the roots of f. Should take two positional arguments, f and initial_guess, and return a solution with the same structure as initial_guess such that f(solution) == 0. In other words, the following is assumed to be true (but not checked):

solution = solve(f, initial_guess)
error = f(solution)
assert all(error == 0)


• tangent_solve –

function to solve the tangent system. Should take two positional arguments, a linear function g (the function f linearized at its root) and a tree of array(s) y with the same structure as initial_guess, and return a solution x such that g(x)=y:

• For scalar y, use lambda g, y: y / g(1.0).

• For vector y, you could use a linear solve with the Jacobian, if the dimensionality of y is not too large: lambda g, y: np.linalg.solve(jacobian(g)(y), y).
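
The scalar case can be illustrated with a minimal sketch that computes a square root as the root of x ** 2 - a. The helper names newton_solve, scalar_tangent_solve and sqrt_via_root, the use of Newton's method, and the fixed count of 30 iterations are all assumptions made for this example; any solve that returns an accurate root should work the same way.

```python
import jax
import jax.numpy as jnp
from jax import lax


def newton_solve(f, x0):
    # Hypothetical root finder: a fixed number of Newton steps on a scalar.
    def body(_, x):
        return x - f(x) / jax.grad(f)(x)
    return lax.fori_loop(0, 30, body, x0)


def scalar_tangent_solve(g, y):
    # g is linear, so g(x) = g(1.0) * x and g(x) = y is solved by y / g(1.0).
    return y / g(1.0)


def sqrt_via_root(a):
    # a is closed over by f, so custom_root defines gradients with respect to it.
    f = lambda x: x ** 2 - a
    x0 = jnp.ones_like(a)  # initial guess with the same shape and dtype as a
    return lax.custom_root(f, x0, newton_solve, scalar_tangent_solve)


value, grad = jax.value_and_grad(sqrt_via_root)(4.0)
print(value, grad)  # expected to be close to 2.0 and 0.25, i.e. 1 / (2 * sqrt(4))
```

The gradient is not obtained by differentiating through the Newton iterations. It comes from scalar_tangent_solve applied to f linearized at the returned root.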

Returns

The result of calling solve(f, initial_guess) with gradients defined via implicit differentiation assuming f(solve(f, initial_guess)) == 0.
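
For a vector-valued root, the returned solution can be differentiated like any other JAX value. The following is a sketch under stated assumptions: newton_solve, dense_tangent_solve and elementwise_cbrt are hypothetical names, the solver is a plain Newton iteration with a dense Jacobian, and the problem (an elementwise cube root) is chosen only because its Jacobian has a closed form to compare against.

```python
import jax
import jax.numpy as jnp
from jax import lax


def newton_solve(f, x0):
    # Hypothetical root finder: fixed number of Newton steps with a dense Jacobian.
    def body(_, x):
        return x - jnp.linalg.solve(jax.jacobian(f)(x), f(x))
    return lax.fori_loop(0, 30, body, x0)


def dense_tangent_solve(g, y):
    # g is linear, so its Jacobian is the same matrix at every input point.
    return jnp.linalg.solve(jax.jacobian(g)(y), y)


def elementwise_cbrt(theta):
    # Root of x ** 3 - theta, with theta closed over by f.
    f = lambda x: x ** 3 - theta
    x0 = jnp.ones_like(theta)
    return lax.custom_root(f, x0, newton_solve, dense_tangent_solve)


theta = jnp.array([1.0, 8.0, 27.0])
roots = elementwise_cbrt(theta)
jac = jax.jacobian(elementwise_cbrt)(theta)

# Closed-form Jacobian of the cube root: diag(1 / (3 * x ** 2)) at the roots.
expected_jac = jnp.diag(1.0 / (3.0 * jnp.array([1.0, 2.0, 3.0]) ** 2))
print(roots)
print(jnp.max(jnp.abs(jac - expected_jac)))
```

Forming the dense Jacobian of g is only reasonable for small problems. For larger ones an iterative linear solver would be a more natural choice for tangent_solve.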