jax.lax.custom_root

jax.lax.custom_root(f, initial_guess, solve, tangent_solve, has_aux=False)

Differentiably solve for the roots of a function.

This is a low-level routine, mostly intended for internal use in JAX. Gradients of custom_root() are defined with respect to closed-over variables from the provided function f via the implicit function theorem: https://en.wikipedia.org/wiki/Implicit_function_theorem
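
To make "via the implicit function theorem" concrete, write θ for the variables that f closes over (θ is notation introduced here for illustration, not a parameter of custom_root), treat f as a function f(x, θ), and let x*(θ) be the root returned by solve. Differentiating the identity f(x*(θ), θ) = 0 with respect to θ gives:

$$
\frac{\partial f}{\partial x}\bigg|_{x^*}\,\frac{\partial x^*}{\partial \theta}
+ \frac{\partial f}{\partial \theta}\bigg|_{x^*} = 0
\quad\Longrightarrow\quad
\frac{\partial x^*}{\partial \theta}
= -\left(\frac{\partial f}{\partial x}\bigg|_{x^*}\right)^{-1}
\frac{\partial f}{\partial \theta}\bigg|_{x^*}
$$

Under this reading, applying the inverse of ∂f/∂x at x* is the job of tangent_solve (its argument g is exactly that linearization), which is why solve itself does not need to be differentiable.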

Parameters:
  • f – function for which to find a root. Should accept a single argument and return a tree of arrays with the same structure as its input.

  • initial_guess – initial guess for a zero of f.

  • solve –

    function to solve for the roots of f. Should take two positional arguments, f and initial_guess, and return a solution with the same structure as initial_guess such that f(solution) = 0. In other words, the following is assumed to be true (but not checked):

    solution = solve(f, initial_guess)
    error = f(solution)
    assert all(error == 0)
    

  • tangent_solve –

    function to solve the tangent system. Should take two positional arguments, a linear function g (the function f linearized at its root) and a tree of array(s) y with the same structure as initial_guess, and return a solution x such that g(x)=y:

    • For scalar y, use lambda g, y: y / g(1.0).

    • For vector y, you could use a linear solve with the Jacobian, if the dimensionality of y is not too large: lambda g, y: np.linalg.solve(jacobian(g)(y), y).

  • has_aux – bool indicating whether the solve function returns auxiliary data, such as solver diagnostics, as a second output.
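
As a minimal sketch of how f, initial_guess, solve and a scalar tangent_solve fit together, the example below computes sqrt(a) as the positive root of f(x) = x**2 - a. The helpers newton_solve, scalar_tangent_solve and my_sqrt are hypothetical names written for illustration, and the choice of Newton's method as the solver is an assumption; only jax.lax.custom_root and jax.grad come from JAX.

```python
import jax

def newton_solve(f, x0, num_iters=20):
    # Hypothetical solver: plain Newton iteration, unrolled in Python.
    # Assumes f maps a scalar to a scalar; the derivative comes from jax.grad.
    x = x0
    for _ in range(num_iters):
        x = x - f(x) / jax.grad(f)(x)
    return x

def scalar_tangent_solve(g, y):
    # g is linear in its argument, so g(1.0) is its slope and y / g(1.0) solves g(x) = y.
    return y / g(1.0)

def my_sqrt(a):
    # sqrt(a) is the positive root of f(x) = x**2 - a, reached from the guess 1.0.
    f = lambda x: x ** 2 - a
    return jax.lax.custom_root(f, 1.0, newton_solve, scalar_tangent_solve)

print(my_sqrt(2.0))            # approximately 1.41421
print(jax.grad(my_sqrt)(2.0))  # approximately 0.35355, i.e. 1 / (2 * sqrt(2))
```

The gradient comes from implicit differentiation rather than from differentiating the unrolled loop: at the root, f linearizes to g(t) = 2x·t, so scalar_tangent_solve returns y / (2x) and the derivative of sqrt(a) is 1 / (2·sqrt(a)).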

Returns:

The result of calling solve(f, initial_guess), with gradients defined via implicit differentiation assuming f(solve(f, initial_guess)) == 0. If has_aux is True, this is the (solution, aux) pair returned by solve.
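
The vector form of tangent_solve and the has_aux flag can be sketched together as follows, assuming the goal is elementwise cube roots, i.e. the root of f(x) = x**3 - a for a vector a. The helpers newton_solve_with_aux, jacobian_tangent_solve and cube_root are hypothetical, and returning the final residual norm as the auxiliary output is an arbitrary choice of diagnostic.

```python
import jax
import jax.numpy as jnp

def newton_solve_with_aux(f, x0, num_iters=20):
    # Hypothetical solver: multivariate Newton iteration using the dense Jacobian of f.
    x = x0
    for _ in range(num_iters):
        x = x - jnp.linalg.solve(jax.jacobian(f)(x), f(x))
    # Auxiliary output (an arbitrary diagnostic): norm of the final residual.
    return x, jnp.linalg.norm(f(x))

def jacobian_tangent_solve(g, y):
    # g is linear, so its Jacobian is the same at every point; evaluating it at y
    # and solving densely is only reasonable when y is small.
    return jnp.linalg.solve(jax.jacobian(g)(y), y)

def cube_root(a):
    # Elementwise cube roots of a vector a, as the root of f(x) = x**3 - a.
    f = lambda x: x ** 3 - a
    return jax.lax.custom_root(f, jnp.ones_like(a), newton_solve_with_aux,
                               jacobian_tangent_solve, has_aux=True)

a = jnp.array([8.0, 27.0])
x, residual_norm = cube_root(a)  # x is approximately [2., 3.]
# Differentiate only the solution; the auxiliary output is not part of the loss.
grads = jax.grad(lambda a: cube_root(a)[0].sum())(a)  # approximately [1/12, 1/27]
```

Because has_aux=True makes custom_root hand back solve's (solution, aux) pair, the caller indexes [0] before differentiating. Materializing the full Jacobian costs O(n^2) memory, which is why this style of tangent_solve only suits low-dimensional y.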