jax.closure_convert
- jax.closure_convert(fun, *example_args)
Closure conversion utility, for use with higher-order custom derivatives.
To define custom derivatives such as with `jax.custom_vjp(f)`, the target function `f` must take, as formal arguments, all values involved in differentiation. If `f` is a higher-order function, in that it accepts as an argument a Python function `g`, then values stored away in `g`’s closure will not be visible to the custom derivative rules, and attempts at AD involving these values will fail. One way around this is to convert the closure by extracting these values and passing them as explicit formal arguments across the custom derivative boundary. This utility carries out that conversion. More precisely, it closure-converts the function `fun` specialized to the types of the arguments given in `example_args`.

When we refer here to “values in the closure” of `fun`, we do not mean the values that are captured by Python directly when `fun` is defined (e.g. the Python objects in `fun.__closure__`, if the attribute exists). Rather, we mean values encountered during the execution of `fun` on `example_args` that determine its output. This may include, for instance, arrays captured transitively in Python closures, i.e. in the Python closure of functions called by `fun`, the closures of the functions that they call, and so forth.

The function `fun` must be a pure function.

Example usage:
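```python
from functools import partial

from jax import closure_convert, custom_vjp


def minimize(objective_fn, x0):
  # Hoist the values objective_fn closes over into explicit arguments.
  converted_fn, aux_args = closure_convert(objective_fn, x0)
  return _minimize(converted_fn, x0, *aux_args)

# objective_fn is nondifferentiable; x0 and the hoisted *args are differentiable.
@partial(custom_vjp, nondiff_argnums=(0,))
def _minimize(objective_fn, x0, *args):
  z = objective_fn(x0, *args)
  # ... find minimizer x_opt ...
  return x_opt

def fwd(objective_fn, x0, *args):
  y = _minimize(objective_fn, x0, *args)
  return y, (y, args)  # keep the optimum and hoisted values as residuals

def rev(objective_fn, res, g):
  y, args = res
  y_bar = g
  # ... custom reverse-mode AD ...
  # Return one cotangent for x0 and one for each hoisted argument.
  return x0_bar, *args_bars

_minimize.defvjp(fwd, rev)
```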
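The example above leaves the minimizer and the reverse rule as placeholders. The sketch below fills them in under stated assumptions, and it is only one possible choice.

- The forward pass is fixed-step gradient descent (step size 0.1, 200 iterations, via `jax.lax.scan`).
- The reverse rule applies the implicit function theorem at the optimum by solving against the dense Hessian. This assumes `x0` is a 1-D array and the Hessian is invertible.

The helpers `solve`, `loss`, and `penalty` and the quadratic objective are hypothetical. The differentiated array `a` is reached only through `penalty`, so it is not among the objects in `objective_fn.__closure__`. `closure_convert` finds it by tracing `objective_fn`, which is the transitive case described above.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import closure_convert, custom_vjp


def minimize(objective_fn, x0):
  # Hoist closed-over values into explicit arguments of converted_fn.
  converted_fn, aux_args = closure_convert(objective_fn, x0)
  return _minimize(converted_fn, x0, *aux_args)


@partial(custom_vjp, nondiff_argnums=(0,))
def _minimize(objective_fn, x0, *args):
  # Assumption: fixed-step gradient descent (step 0.1, 200 iterations).
  grad_x = jax.grad(objective_fn)

  def step(x, _):
    return x - 0.1 * grad_x(x, *args), None

  x_opt, _ = jax.lax.scan(step, x0, None, length=200)
  return x_opt


def fwd(objective_fn, x0, *args):
  y = _minimize(objective_fn, x0, *args)
  return y, (y, args)


def rev(objective_fn, res, g):
  # Implicit function theorem at the optimum y, where grad_x f(y, args) = 0:
  # solve H v = g (H: Hessian in x, assumed symmetric and invertible),
  # then args_bar = -(d grad_x f / d args)^T v.
  y, args = res
  grad_x = jax.grad(objective_fn)
  hess = jax.hessian(objective_fn)(y, *args)  # shape (n, n); y assumed 1-D
  v = jnp.linalg.solve(hess, g)
  _, vjp_fn = jax.vjp(lambda theta: grad_x(y, *theta), args)
  (args_bars,) = vjp_fn(-v)
  x0_bar = jnp.zeros_like(y)  # the optimum is treated as independent of x0
  return (x0_bar, *args_bars)


_minimize.defvjp(fwd, rev)


# Hypothetical usage: `a` reaches the objective only through `penalty`, so it
# is not in objective_fn.__closure__; closure_convert finds it by tracing.
def solve(a):
  def penalty(x):
    return jnp.sum((x - a) ** 2)

  def objective_fn(x):
    return penalty(x) + 0.5 * jnp.sum(x ** 2)  # minimized at x = 2a/3

  return minimize(objective_fn, jnp.zeros_like(a))


def loss(a):
  return jnp.sum(solve(a))


a0 = jnp.array([1.0, -2.0, 3.0])
print(solve(a0))           # approximately 2 * a0 / 3
print(jax.grad(loss)(a0))  # approximately [2/3, 2/3, 2/3]
```

One behavior worth knowing comes from the current JAX implementation rather than from the text above. Only closure values that may carry derivatives (for example, tracers under `jax.grad`) appear to be hoisted. As a result, `aux_args` can be empty when `minimize` runs outside any transformation.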
- Parameters:
fun (Callable) – Python callable to be converted. Must be a pure function.
example_args – Arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e., pytrees) thereof, used to determine the types of the formal arguments to `fun`. This type-specialized form of `fun` is the function that will be closure converted.
- Returns:
A pair comprising (i) a Python callable, accepting the same arguments as `fun` followed by arguments corresponding to the values hoisted from its closure, and (ii) a list of values hoisted from the closure.
- Return type: