# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for the custom linear solve and utilities."""
import collections
from functools import partial
import operator

import jax
from jax.tree_util import (tree_flatten, treedef_children, tree_leaves,
                           tree_unflatten, treedef_tuple)
from jax._src import ad_util
from jax._src import core
from jax._src import linear_util as lu
from jax._src.core import raise_to_shaped
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.traceback_util import api_boundary
from jax._src.util import split_list, safe_map
import numpy as np

from jax._src.lax.control_flow.common import (
    _abstractify,
    _check_tree,
    _initial_style_jaxpr,
    )

_map = safe_map

_RootTuple = collections.namedtuple('_RootTuple', 'f, solve, l_and_s')


def _split_root_args(args, const_lengths):
  params_list = split_list(args, list(const_lengths))
  return _RootTuple(*params_list[:-1]), params_list[-1]


@api_boundary
def custom_root(f, initial_guess, solve, tangent_solve, has_aux=False):
  """Differentiably solve for the roots of a function.

  This is a low-level routine, mostly intended for internal use in JAX.
  Gradients of custom_root() are defined with respect to closed-over variables
  from the provided function ``f`` via the implicit function theorem:
  https://en.wikipedia.org/wiki/Implicit_function_theorem

  Args:
    f: function for which to find a root. Should accept a single argument and
      return a tree of arrays with the same structure as its input.
    initial_guess: initial guess for a zero of f.
    solve: function to solve for the roots of f. Should take two positional
      arguments, f and initial_guess, and return a solution with the same
      structure as initial_guess such that ``f(solution) = 0``. In other
      words, the following is assumed to be true (but not checked)::

        solution = solve(f, initial_guess)
        error = f(solution)
        assert all(error == 0)

    tangent_solve: function to solve the tangent system. Should take two
      positional arguments, a linear function ``g`` (the function ``f``
      linearized at its root) and a tree of array(s) ``y`` with the same
      structure as initial_guess, and return a solution ``x`` such that
      ``g(x) = y``:

      - For scalar ``y``, use ``lambda g, y: y / g(1.0)``.
      - For vector ``y``, you could use a linear solve with the Jacobian, if
        the dimensionality of ``y`` is not too large:
        ``lambda g, y: np.linalg.solve(jacobian(g)(y), y)``.

    has_aux: bool indicating whether the ``solve`` function returns auxiliary
      data like solver diagnostics as a second argument.

  Returns:
    The result of calling solve(f, initial_guess) with gradients defined via
    implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``.
  """
  guess_flat, in_args_tree = tree_flatten((initial_guess,))
  guess_avals = tuple(_map(_abstractify, guess_flat))
  f_jaxpr, f_consts, out_tree = _initial_style_jaxpr(
      f, in_args_tree, guess_avals)
  in_tree, = treedef_children(in_args_tree)
  _check_tree("f", "initial_guess", out_tree, in_tree, False)

  solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr(
      partial(solve, f), in_args_tree, guess_avals)
  _check_tree("solve", "initial_guess", solution_tree, in_tree, has_aux)

  def linearize_and_solve(x, b):
    unchecked_zeros, f_jvp = jax.linearize(f, x)
    return tangent_solve(f_jvp, b)

  l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr(
      linearize_and_solve, treedef_tuple((in_tree,) * 2), guess_avals * 2)
  _check_tree("tangent_solve", "x", out_tree, in_tree, False)

  all_consts = [f_consts, solve_consts, l_and_s_consts]
  const_lengths = _RootTuple(*_map(len, all_consts))
  jaxprs = _RootTuple(f_jaxpr, solve_jaxpr, l_and_s_jaxpr)

  solution_flat = _custom_root(
      const_lengths, jaxprs, *(_flatten(all_consts) + guess_flat))
  return tree_unflatten(solution_tree, solution_flat)
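
# Example usage (an illustrative sketch, not part of this module): computing
# a square root by Newton's method, with the gradient defined implicitly by
# custom_root. The name `newton_sqrt` is hypothetical.
#
#   def newton_sqrt(a):
#     def f(x):
#       return x ** 2 - a
#
#     def solve(f, x0):
#       x = x0
#       for _ in range(10):  # a fixed number of Newton steps
#         x = x - f(x) / (2.0 * x)
#       return x
#
#     # For scalar y, the tangent system g(x) = y has solution y / g(1.0).
#     return custom_root(f, 1.0, solve, lambda g, y: y / g(1.0))
#
#   jax.grad(newton_sqrt)(2.0)  # ~0.35355 == 1 / (2 * sqrt(2))
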
@partial(jax.custom_jvp, nondiff_argnums=(0, 1))
def _custom_root(const_lengths, jaxprs, *args):
  params, initial_guess = _split_root_args(args, const_lengths)
  solution = core.jaxpr_as_fun(jaxprs.solve)(*(params.solve + initial_guess))
  return solution


@_custom_root.defjvp
def _root_jvp(const_lengths, jaxprs, primals, tangents):
  params, _ = _split_root_args(primals, const_lengths)
  sol = _custom_root(const_lengths, jaxprs, *primals)

  f_out_vals = len(jaxprs.f.out_avals)
  solution, aux = split_list(sol, [f_out_vals])

  params_dot, _ = _split_root_args(tangents, const_lengths)

  # F(m, u) = 0      # system of equations in u, parameterized by m
  #                  # solution is u*(m) defined in a neighborhood
  # F(m, u*(m)) = 0  # satisfied in a neighborhood
  #
  # ∂_0 F(m, u*(m)) + ∂_1 F(m, u*(m)) ∂ u*(m) = 0       # implied by line above
  # ∂ u*(m) = - (∂_1 F(m, u*(m)))^{-1} ∂_0 F(m, u*(m))  # rearrange
  #
  # ∂ u*(m)[v] = - (∂_1 F(m, u*(m)))^{-1} [∂_0 F(m, u*(m))[v]]  # jvp

  f = core.jaxpr_as_fun(jaxprs.f)
  linearize_and_solve = partial(
      core.jaxpr_as_fun(jaxprs.l_and_s), *params.l_and_s)
  f_at_solution = lambda *params: f(*params, *solution)
  _, rhs = ad.jvp(lu.wrap_init(f_at_solution)).call_wrapped(
      params.f, params_dot.f)
  solution_dot = _map(
      operator.neg, linearize_and_solve(*solution, *rhs))
  # append aux, create symbolic zero tangents for the aux values
  solution += aux
  solution_dot += _map(ad_util.zeros_like_jaxval, aux)

  return solution, solution_dot


class _LinearSolveTuple(collections.namedtuple(
    '_LinearSolveTuple', 'matvec, vecmat, solve, transpose_solve')):

  def transpose(self):
    return type(self)(self.vecmat, self.matvec, self.transpose_solve,
                      self.solve)


def _split_linear_solve_args(args, const_lengths):
  params_list = split_list(args, list(const_lengths))
  return _LinearSolveTuple(*params_list[:-1]), params_list[-1]


def _transpose_one_output(linear_fun, primals):
  transpose_fun = jax.linear_transpose(linear_fun, primals)
  def transposed_fun(x):
    (y,) = transpose_fun(x)
    return y
  return transposed_fun


def _flatten(args):
  return [x for arg in args for x in arg]


def _check_shapes(func_name, expected_name, actual, expected):
  actual_shapes = _map(np.shape, tree_leaves(actual))
  expected_shapes = _map(np.shape, tree_leaves(expected))
  if actual_shapes != expected_shapes:
    raise ValueError(
        f"{func_name}() output shapes must match {expected_name}, "
        f"got {actual_shapes} and {expected_shapes}")
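
# A small sketch (illustrative, not part of this module) of the
# jax.linear_transpose behavior that _transpose_one_output unpacks: the
# transposed function always returns a tuple of cotangents, one per
# positional primal.
#
#   f = lambda x: 2.0 * x               # linear in x
#   f_t = jax.linear_transpose(f, 3.0)  # 3.0 only supplies shape/dtype
#   f_t(1.0)                            # -> (2.0,), a 1-tuple
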
@api_boundary
def custom_linear_solve(
    matvec, b, solve, transpose_solve=None, symmetric=False, has_aux=False):
  """Perform a matrix-free linear solve with implicitly defined gradients.

  This function allows for overriding or defining gradients for a linear
  solve directly via implicit differentiation at the solution, rather than by
  differentiating *through* the solve operation. This can sometimes be much
  faster or more numerically stable, or differentiating through the solve
  operation may not even be implemented (e.g., if ``solve`` uses
  ``lax.while_loop``).

  Required invariant::

      x = solve(matvec, b)  # solve the linear equation
      assert matvec(x) == b  # not checked

  Args:
    matvec: linear function to invert. Must be differentiable.
    b: constant right hand side of the equation. May be any nested structure
      of arrays.
    solve: higher level function that solves for the solution to the linear
      equation, i.e., ``matvec(solve(matvec, x)) == x`` for all ``x`` of the
      same form as ``b``. This function need not be differentiable.
    transpose_solve: higher level function for solving the transpose linear
      equation, i.e., ``vecmat(transpose_solve(vecmat, x)) == x``, where
      ``vecmat`` is the transpose of the linear map ``matvec`` (computed
      automatically with autodiff). Required for backwards mode automatic
      differentiation, unless ``symmetric=True``, in which case ``solve``
      provides the default value.
    symmetric: bool indicating if it is safe to assume the linear map
      corresponds to a symmetric matrix, i.e., ``matvec == vecmat``.
    has_aux: bool indicating whether the ``solve`` and ``transpose_solve``
      functions return auxiliary data like solver diagnostics as a second
      argument.

  Returns:
    Result of ``solve(matvec, b)``, with gradients defined assuming that the
    solution ``x`` satisfies the linear equation ``matvec(x) == b``.
""" if transpose_solve is None and symmetric: transpose_solve = solve b_flat, in_args_tree = tree_flatten((b,)) b_avals = tuple(_map(_abstractify, b_flat)) tree, = treedef_children(in_args_tree) def _shape_checked(fun, name, has_aux): def f(x): y = fun(x) _check_shapes(name, "b", y, b_flat) return y def f_aux(x): y, aux = fun(x) _check_shapes(name, "b", y, b_flat) return y, aux return f_aux if has_aux else f # no auxiliary data assumed for matvec matvec_jaxpr, matvec_consts, out_tree = _initial_style_jaxpr( _shape_checked(matvec, "matvec", False), in_args_tree, b_avals, 'custom_linear_solve') _check_tree("matvec", "b", out_tree, tree, False) solve_jaxpr, solve_consts, out_tree = _initial_style_jaxpr( _shape_checked(partial(solve, matvec), "solve", has_aux), in_args_tree, b_avals, 'custom_linear_solve') _check_tree("solve", "b", out_tree, tree, has_aux) if transpose_solve is None: vecmat_jaxpr = tr_solve_jaxpr = None vecmat_consts = tr_solve_consts = [] else: if symmetric: vecmat = matvec vecmat_jaxpr = matvec_jaxpr vecmat_consts = matvec_consts else: vecmat = _transpose_one_output(matvec, b) vecmat_jaxpr, vecmat_consts, out_tree = _initial_style_jaxpr( vecmat, in_args_tree, b_avals, 'custom_linear_solve') assert out_tree == tree tr_solve_jaxpr, tr_solve_consts, out_tree = _initial_style_jaxpr( _shape_checked(partial(transpose_solve, vecmat), "transpose_solve", has_aux), in_args_tree, b_avals, 'custom_linear_solve') _check_tree("transpose_solve", "b", out_tree, tree, has_aux) all_consts = [matvec_consts, vecmat_consts, solve_consts, tr_solve_consts] const_lengths = _LinearSolveTuple(*_map(len, all_consts)) jaxprs = _LinearSolveTuple( matvec_jaxpr, vecmat_jaxpr, solve_jaxpr, tr_solve_jaxpr) out_flat = linear_solve_p.bind( *(_flatten(all_consts) + b_flat), const_lengths=const_lengths, jaxprs=jaxprs) return tree_unflatten(out_tree, out_flat)
def _linear_solve_abstract_eval(*args, const_lengths, jaxprs):
  args_to_raise = args[sum(const_lengths):]

  # raise aux_args to shaped arrays as well if present
  # number of aux args is the difference in out_avals
  # of solve and matvec (since they map to the same vector space)
  num_aux = len(jaxprs.solve.out_avals) - len(jaxprs.matvec.out_avals)
  if num_aux > 0:
    args_to_raise += tuple(jaxprs.solve.out_avals[-num_aux:])
  return _map(raise_to_shaped, args_to_raise)


def _custom_linear_solve_impl(*args, const_lengths, jaxprs):
  params, b = _split_linear_solve_args(args, const_lengths)
  x = core.jaxpr_as_fun(jaxprs.solve)(*(params.solve + b))
  return x


def _tangent_linear_map(func, params, params_dot, *x):
  """Compute the tangent of a linear map.

  Assuming ``func(*params, *x)`` is linear in ``x`` and computes ``A @ x``,
  this function computes ``∂A @ x``.
  """
  assert any(type(p) is not ad_util.Zero for p in params_dot)
  zeros = _map(ad_util.Zero.from_value, x)
  _, out_tangent = ad.jvp(lu.wrap_init(func)).call_wrapped(
      params + list(x), params_dot + zeros)
  return out_tangent


def _custom_linear_solve_jvp(primals, tangents, const_lengths, jaxprs):
  # A x - b = 0
  # ∂A x + A ∂x - ∂b = 0
  # ∂x = A^{-1} (∂b - ∂A x)

  kwargs = dict(const_lengths=const_lengths, jaxprs=jaxprs)
  x = linear_solve_p.bind(*primals, **kwargs)

  params, _ = _split_linear_solve_args(primals, const_lengths)
  params_dot, b_dot = _split_linear_solve_args(tangents, const_lengths)

  num_x_leaves = len(b_dot)
  # x is a flat tree with possible aux values appended
  # since x_tree == b_tree == b_dot_tree, we can cut off
  # aux values with len info provided by b_dot tree here
  x_leaves, _ = split_list(x, [num_x_leaves])

  if all(type(p) is ad_util.Zero for p in params_dot.matvec):
    # no need to evaluate matvec_tangents
    rhs = b_dot
  else:
    matvec_tangents = _tangent_linear_map(
        core.jaxpr_as_fun(jaxprs.matvec), params.matvec, params_dot.matvec,
        *x_leaves)
    rhs = _map(ad.add_tangents, b_dot, _map(operator.neg, matvec_tangents))

  x_dot = linear_solve_p.bind(*(_flatten(params) + rhs), **kwargs)

  # split into x tangents and aux tangents (these become zero)
  dx_leaves, daux_leaves = split_list(x_dot, [num_x_leaves])

  daux_leaves = _map(ad_util.Zero.from_value, daux_leaves)

  x_dot = dx_leaves + daux_leaves

  return x, x_dot


def _linear_solve_transpose_rule(cotangent, *primals, const_lengths, jaxprs):
  if jaxprs.transpose_solve is None:
    raise TypeError('transpose_solve required for backwards mode automatic '
                    'differentiation of custom_linear_solve')

  params, b = _split_linear_solve_args(primals, const_lengths)
  # split off symbolic zeros in the cotangent if present
  x_cotangent, _ = split_list(cotangent, [len(b)])
  assert all(ad.is_undefined_primal(x) for x in b)
  cotangent_b_full = linear_solve_p.bind(
      *(_flatten(params.transpose()) + x_cotangent),
      const_lengths=const_lengths.transpose(), jaxprs=jaxprs.transpose())
  # drop aux values in cotangent computation
  cotangent_b, _ = split_list(cotangent_b_full, [len(b)])
  return [None] * sum(const_lengths) + cotangent_b


def _linear_solve_batching_rule(spmd_axis_name, axis_size, axis_name,
                                main_type, args, dims, const_lengths, jaxprs):
  orig_bat = [d is not batching.not_mapped for d in dims]

  params, b = _split_linear_solve_args(args, const_lengths)
  params_dims, b_dims = _split_linear_solve_args(dims, const_lengths)
  params_bat, orig_b_bat = _split_linear_solve_args(orig_bat, const_lengths)

  (matvec, vecmat, solve, solve_t) = jaxprs
  (matvec_bat, vecmat_bat, solve_bat, solve_t_bat) = params_bat

  # number of operator out avals is assumed to be the same for matvec/vecmat
  num_operator_out_avals = len(matvec.out_avals)
  num_aux = len(solve.out_avals) - num_operator_out_avals
  # Fixpoint computation of which parts of x and b are batched; we need to
  # ensure this is consistent between all four jaxprs
  b_bat = orig_b_bat
  x_bat = [False] * len(solve.out_avals)
  for i in range(1 + len(orig_b_bat) + len(solve.out_avals)):
    # Apply vecmat and solve -> new batched parts of x
    solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr(
        solve, axis_size, solve_bat + b_bat, instantiate=x_bat,
        axis_name=axis_name, spmd_axis_name=spmd_axis_name,
        main_type=main_type)
    if vecmat is None:
      vecmat_jaxpr_batched = None
      x_bat_out = solve_x_bat
    else:
      vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr(
          vecmat, axis_size, vecmat_bat + b_bat, instantiate=b_bat,
          axis_name=axis_name, spmd_axis_name=spmd_axis_name,
          main_type=main_type)
      # batch all aux data by default
      x_bat_out = _map(operator.or_, vecmat_x_bat + [True] * num_aux,
                       solve_x_bat)
    # keep a slice of only the linear operator part of solve's avals
    x_bat_noaux = x_bat_out[:num_operator_out_avals]

    # Apply matvec and solve_t -> new batched parts of b
    matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr(
        matvec, axis_size, matvec_bat + x_bat_noaux, instantiate=b_bat,
        axis_name=axis_name, spmd_axis_name=spmd_axis_name,
        main_type=main_type)
    if solve_t is None:
      solve_t_jaxpr_batched = None
      b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat)
    else:
      solve_t_jaxpr_batched, solve_t_b_aux_bat = batching.batch_jaxpr(
          solve_t, axis_size, solve_t_bat + x_bat_noaux,
          instantiate=x_bat_out,
          axis_name=axis_name, spmd_axis_name=spmd_axis_name,
          main_type=main_type)
      assert len(solve_t_b_aux_bat) == len(orig_b_bat) + num_aux
      solve_t_b_bat, _ = split_list(solve_t_b_aux_bat, [len(orig_b_bat)])
      b_bat_out = _map(lambda m, s, o: m or s or o, matvec_b_bat,
                       solve_t_b_bat, orig_b_bat)
    if x_bat_out == x_bat and b_bat_out == b_bat:
      break
    else:
      x_bat = x_bat_out
      b_bat = b_bat_out
  else:
    assert False, "Fixedpoint not reached"

  batched_jaxprs = _LinearSolveTuple(matvec_jaxpr_batched,
                                     vecmat_jaxpr_batched,
                                     solve_jaxpr_batched,
                                     solve_t_jaxpr_batched)

  # Move batched axes to the front
  new_params = [
      batching.moveaxis(x, d, 0)
      if d is not batching.not_mapped and d != 0 else x
      for x, d in zip(_flatten(params), _flatten(params_dims))
  ]
  # Broadcast out b if necessary
  new_b = [
      batching.broadcast(x, axis_size, 0) if now_bat and not was_bat else
      batching.moveaxis(x, d, 0) if now_bat and d != 0 else x
      for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat)
  ]

  outs = linear_solve_p.bind(
      *(new_params + new_b),
      const_lengths=const_lengths, jaxprs=batched_jaxprs)
  out_dims = [0 if batched else batching.not_mapped for batched in solve_x_bat]
  return outs, out_dims


linear_solve_p = core.AxisPrimitive('custom_linear_solve')
linear_solve_p.multiple_results = True
linear_solve_p.def_impl(_custom_linear_solve_impl)
linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval)
ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp
xla.register_initial_style_primitive(linear_solve_p)
mlir.register_lowering(
    linear_solve_p, mlir.lower_fun(_custom_linear_solve_impl,
                                   multiple_results=True))
ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule
batching.axis_primitive_batchers[linear_solve_p] = partial(
    _linear_solve_batching_rule, None)
batching.spmd_axis_primitive_batchers[linear_solve_p] = (
    _linear_solve_batching_rule)
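
# A quick sketch (illustrative only) of the batching rule above in action:
# vmap over the right hand side turns into a single batched linear_solve_p
# call rather than a Python loop over solves.
#
#   import jax.numpy as jnp
#   from jax.scipy.sparse.linalg import cg
#
#   A = jnp.array([[2.0, 1.0], [1.0, 3.0]])
#   solve_one = lambda b: custom_linear_solve(
#       lambda v: A @ v, b, lambda mv, b: cg(mv, b)[0], symmetric=True)
#   jax.vmap(solve_one)(jnp.eye(2))  # rows of A^{-1}, since A is symmetric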