custom_root(f, initial_guess, solve, tangent_solve)
Differentiably solve for the roots of a function.
This is a low-level routine, mostly intended for internal use in JAX. Gradients of custom_root() are defined with respect to closed-over variables from the provided function f via the implicit function theorem: https://en.wikipedia.org/wiki/Implicit_function_theorem
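Spelled out, the implicit-function-theorem step reads as follows. The notation is introduced here for illustration: θ stands for the variables that f closes over, x*(θ) for the root, and all derivatives are evaluated at (x*(θ), θ). The step assumes that ∂ₓf is invertible at the root, as the theorem requires. The linear map ∂ₓf is f linearized at its root, which is the operator that tangent_solve is asked to invert.

```latex
% Differentiate the identity f(x^*(\theta), \theta) = 0 with respect to \theta:
\partial_x f \,\frac{\partial x^*}{\partial \theta} + \partial_\theta f = 0
\quad\Longrightarrow\quad
\frac{\partial x^*}{\partial \theta} = -\left(\partial_x f\right)^{-1} \partial_\theta f
% In JVP form, for a tangent \dot\theta of the closed-over variables:
\dot{x}^* = -\left(\partial_x f\right)^{-1}\left(\partial_\theta f\,\dot{\theta}\right)
```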
f – function for which to find a root. Should accept a single argument and return a tree of arrays with the same structure as its input.
initial_guess – initial guess for a zero of f.
solve – function to solve for the roots of f. Should take two positional arguments, f and initial_guess, and return a solution with the same structure as initial_guess such that f(solution) = 0. In other words, the following is assumed to be true (but not checked):
    solution = solve(f, initial_guess)
    error = f(solution)
    assert all(error == 0)
tangent_solve – function to solve the tangent system. Should take two positional arguments, a linear function g (the function f linearized at its root) and a tree of array(s) y with the same structure as initial_guess, and return a solution x such that g(x) = y:
For scalar y, use lambda g, y: y / g(1.0).
For vector y, you could use a linear solve with the Jacobian, if the dimensionality of y is not too large: lambda g, y: np.linalg.solve(jacobian(g)(y), y).
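The two callback contracts can be checked in isolation, before wiring them into custom_root. This is a minimal sketch under stated assumptions: `bisection_solve`, its `low`/`high`/`num_iters` parameters, and the 2x2 matrix `A` are hypothetical illustrations, not part of JAX, and the solver assumes a scalar f with a known bracket f(low) < 0 < f(high). Bisection is derivative-free, so the solver never differentiates f. The sketch writes `jnp.linalg.solve` and `jax.jacobian` for the `np.linalg.solve` and `jacobian` used in the text.

```python
import jax
import jax.numpy as jnp
from jax import lax


def bisection_solve(f, initial_guess, low=0.0, high=10.0, num_iters=50):
    """Hypothetical `solve` callback: bisection on a fixed bracket.

    Assumes a scalar f with f(low) < 0 < f(high); initial_guess is unused.
    """
    del initial_guess  # bisection only needs the bracket

    def step(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        # Keep the half-interval on which f changes sign.
        root_is_left = f(mid) > 0
        return (jnp.where(root_is_left, lo, mid),
                jnp.where(root_is_left, mid, hi))

    lo, hi = lax.fori_loop(0, num_iters, step,
                           (jnp.float32(low), jnp.float32(high)))
    return 0.5 * (lo + hi)


# solve contract: f(solution) should be (approximately) zero.
f = lambda x: x ** 2 - 2.0
solution = bisection_solve(f, 1.0)
error = f(solution)  # small, but not exactly 0, in floating point

# tangent_solve contract: return x such that g(x) == y for the linear map g.
scalar_tangent_solve = lambda g, y: y / g(1.0)
vector_tangent_solve = lambda g, y: jnp.linalg.solve(jax.jacobian(g)(y), y)

# Scalar case: f linearized at its root r is g(t) = 2 * r * t.
g_scalar = lambda t: 2.0 * solution * t
x_scalar = scalar_tangent_solve(g_scalar, 3.0)

# Vector case: a linear map g(t) = A @ t with a hypothetical 2x2 matrix A.
A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
g_vector = lambda t: A @ t
y = jnp.array([1.0, -1.0])
x_vector = vector_tangent_solve(g_vector, y)

print(solution, error)
print(g_scalar(x_scalar), g_vector(x_vector))
```

The Jacobian-based tangent_solve materializes a dense n-by-n matrix, which is why it is only suggested when y is low-dimensional. Evaluating `jacobian(g)` at y itself is harmless: g is linear, so its Jacobian is the same at every point.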
Returns the result of calling solve(f, initial_guess), with gradients defined via implicit differentiation assuming f(solve(f, initial_guess)) == 0.
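An end-to-end sketch, assuming a scalar problem: the square root of a closed-over `a` is found as the root of f(x) = x**2 - a, and `jax.value_and_grad` differentiates through `lax.custom_root` with respect to `a`. The names `sqrt_via_root` and `heron_solve` are hypothetical. `heron_solve` is a deliberately specialised solve that ignores its f argument. Because the gradient is defined via implicit differentiation of f, it should not depend on how solve reached the root.

```python
import jax
import jax.numpy as jnp
from jax import lax


def sqrt_via_root(a):
    # f closes over `a`, so custom_root defines d(root)/da implicitly.
    f = lambda x: x ** 2 - a

    def heron_solve(f, x0):
        # Hypothetical specialised solver: Heron's iteration x <- (x + a/x) / 2.
        # It ignores its `f` argument; the gradient comes from f, not from here.
        return lax.fori_loop(0, 30, lambda _, x: 0.5 * (x + a / x), x0)

    def scalar_tangent_solve(g, y):
        # g is f linearized at the root; for scalars g(t) = g(1.0) * t.
        return y / g(1.0)

    return lax.custom_root(f, jnp.float32(1.0), heron_solve,
                           scalar_tangent_solve)


value, grad = jax.value_and_grad(sqrt_via_root)(2.0)
# Implicit differentiation gives d/da sqrt(a) = -(2x)^-1 * (-1) = 1 / (2 * sqrt(a)).
print(value, grad)
```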