Custom JVP/VJP rules for JAX-transformable functions#

This is a design document, explaining some of the thinking behind the design and implementation of jax.custom_jvp and jax.custom_vjp. For user-oriented documentation, see the tutorial notebook.

There are two ways to define differentiation rules in JAX:

  1. using jax.custom_jvp and jax.custom_vjp to define custom differentiation rules for Python functions that are already JAX-transformable; and

  2. defining new core.Primitive instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems.

This document is about #1 only.

Goals#

We want users to customize the forward- and/or reverse-mode differentiation behavior of their code. This customization

  1. should have a clear and consistent semantics in how it works and how it composes with other JAX transformations; and

  2. should be flexible in supporting use cases and workflows like in Autograd and PyTorch, including cases involving differentiation of Python control flow and workflows for NaN debugging.

As JAX developers we want to write library functions, like logit and expit, that are defined in terms of other primitives but behave like primitives for the purposes of differentiation: we want to give them custom differentiation rules, which may be more numerically stable or performant. In particular, we don’t want to have to specify vmap or jit rules for functions like logit and expit.

As a stretch goal, we’d like to make JAX a great environment for power users looking to add custom differentiation rules for higher-order functions like fixed_point, odeint, etc.; this design doc won’t solve that problem, but we want to be confident we’re not going to preclude good solutions to that problem.

That is, our primary goals are

  1. solve the vmap-removes-custom-jvp semantics problem (#1249), and

  2. allow Python in custom VJPs, e.g. to debug NaNs (#1275).

Secondary goals are

  3. clean up and simplify user experience (symbolic zeros, kwargs, etc); and

  4. make progress towards a world where users can easily add fixed_point, odeint, root, etc.

Overall, we want to close #116, #1097, #1249, #1275, #1366, #1723, #1670, #1875, #1938, and replace the custom_transforms machinery (from #636, #818, and others).

Non-goals#

Here are objectives we’re not aiming to achieve:

  1. The custom_transforms machinery aimed to provide a transformation-generic mechanism for customizing behavior, in principle (though never really used in practice) allowing users to customize rules for any transformation while somehow inheriting the “transparent” behavior for others. We are instead only going to solve the customization problem for differentiation (JVP and VJP, separately). Differentiation is the only case actually requested, and by specializing to differentiation we can reduce complexity and improve flexibility. To control all rules one can just write a primitive.

  2. We’re not going to prioritize mathematical aesthetics over flexibility and clarity on the user side, and simplicity on the implementation side. In particular, while the custom VJP signature a -> (b, CT b --o CT a) is mathematically pleasing, if it’s hard to implement in a Python mechanism because of the closure in the return type, we’re fine doing something that handles residuals more explicitly.

  3. Serialization support, of the form where the staged-out serialized program representation can be loaded and further JAX-transformed as opposed to just evaluated, is currently out of scope for these custom JVP/VJP transformation rules. Serialization may be useful not only for researchers who want to save some representation of their computation (and transform it after loading it), but also for future considerations like having jaxpr transformations implemented outside Python, or having jaxprs as an MLIR dialect. By defining this as a non-goal for the purpose of this design, we have fewer constraints on where we can stash Python callables.

Main problem descriptions#

The vmap-removes-custom-jvp semantics problem#

The vmap-removes-custom-jvp semantics problem is that vmap does not compose properly with differentiation of functions with custom_transforms rules:

# old custom_transforms api to be replaced
@jax.custom_transforms
def f(x):
  return 2. * x

# f_vjp :: a -> (b, CT b --o CT a)
def f_vjp(x):
  return f(x), lambda g: (3. * g,)  # 3 instead of 2

jax.defvjp_all(f, f_vjp)

grad(f)(1.)  # 3.
vmap(grad(f))(np.ones(4))  # [3., 3., 3., 3.]
grad(lambda x: vmap(f)(x).sum())(np.ones(4))  # [2., 2., 2., 2.]

The last grad-of-vmap line has an unexpected result! In general, applying vmap, or really any non-differentiation transformation, has the effect of removing the custom differentiation rule. (Applying jvp causes a failure when a custom VJP rule is defined.)

The problem exists because transformations are like rewrites, and the vmap transformation effectively rewrites the function to no longer call the newly-introduced primitive for which there is a custom rule (and hence grad then doesn’t produce the custom rule’s result). In more detail, the custom_transforms machinery sets things up so that evaluating f(x) applies the function

{ lambda  ; ; a.
  let b = f_primitive a
  in [b] }

where f_primitive is a new primitive (introduced for every custom_transforms function and in fact for every call of the function) to which the custom VJP rule is associated. When we evaluate grad(f)(x), the differentiation machinery encounters f_primitive and processes it with the custom rule.

However, because f_primitive is transparent to vmap, in the sense that vmap operates on (effectively by inlining) the definition of f_primitive, the function vmap(f) is effectively

{ lambda  ; ; a.
  let b = mul 2. a
  in [b] }

In words, vmap rewrites the function in terms of its underlying primitives and their transformation rules, removing f_primitive entirely.

More generally, because vmap(f) has semantics defined in terms of calls to f, it is semantically inconsistent to remove the custom derivative rule. That is, since we define

vmap(f)(xs) == np.stack([f(x) for x in xs])

we must have

jvp(vmap(f))(xs) == jvp(lambda xs: np.stack([f(x) for x in xs]))(xs)

yet this property is not observed when f has a custom derivative rule defined, as the custom derivative rule is used in the right-hand version but not the left-hand one.

This issue isn’t specific to vmap; it applies to all transformations for which the semantics of transforming a function f are defined in terms of calls to the function f, rather than rewriting it into another function. The mask transformation also falls into this class. Differentiation transforms and the hypothetical all-unary-functions-become-cosine transform are not in this class.

(The interaction between additional custom rules, like custom vmap rules, is likely to get even more complex, suggesting the problem framing of custom_transforms is too broad.)

The Python flexibility problem#

In JAX, as in Autograd and PyTorch but not TF1, differentiation of a Python function is performed while the function is being executed and traced. This behavior delights users for a few reasons.

First and most importantly, it enables pdb-based workflows, e.g. for inspecting numerics or catching NaNs. That is, users can employ the standard Python debugger and other Python-native tools to debug their code, even being able to inspect runtime values to understand numerical behavior on examples and to catch fundamentally runtime errors like NaNs. In fact, just while working on the PR corresponding to this design, especially on the odeint primitive, I used runtime value inspection to debug issues many times, increasing my confidence that this is a key user workflow in Python. One especially handy trick, which I’ve used in both JAX and Autograd many times, is the ability to insert a debugger breakpoint in a custom VJP rule to enter a debugger at a specific point in the backward pass.

Second, it allows differentiation of Python native control flow. We’re not sure how often this is used in practice in finalized software artifacts, but when users first poke around JAX or Autograd they’re often impressed by this freedom. There’s a reason we include it at the top of our JAX and Autograd READMEs, slide decks, and demos. Ceding this capability would be a step backward from Autograd. We want JAX to have the best automatic differentiation.

However, the custom_transforms machinery does not provide this Python-support flexibility. That is, because it’s implemented in terms of up-front jaxpr formation from the Python code for both the user function and custom differentiation rules, code like this leads to an abstract value tracing error:

# old custom_transforms api to be replaced
@jax.custom_transforms
def f(x):
  if x > 0:
    return x
  else:
    return 0.

def f_vjp(x):
  return ...

jax.defvjp_all(f, f_vjp)

grad(f)(1.)  # Error!

Solution idea#

The main idea is that dougalm@ already solved these problems with core.call. That is, we can frame the task of specifying a custom JVP rule for a user function in terms of a new Python-level call primitive (not to be added to the jaxpr language; see below). This new call primitive has a user Python function associated with it just like core.call, but additionally has a second Python callable representing the JVP rule. Let’s refer to this new call primitive as custom_jvp_call.

Transformations like vmap interact with custom_jvp_call as with core.call: they effectively pass right through it and are applied to the underlying Python callables. Schematically, writing in terms of curried versions of the primitives for convenience, analogously to how vmap interacts with core.call by applying to the function to be called:

vmap(call(f)) == call(vmap(f))

for the new primitive custom_jvp_call we simply apply vmap to the two functions it entails:

vmap(custom_jvp_call(f, f_jvp)) == custom_jvp_call(vmap(f), vmap(f_jvp))

This behavior means we’ve solved the vmap-removes-custom-jvp semantics problem.
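As a rough check of this claim, the three expressions from the problem description can be rerun against the released jax.custom_jvp decorator. This is only a sketch: the deliberately wrong coefficient 3 mirrors the earlier example, `jnp` is an alias for jax.numpy, and the result variable names are made up for illustration.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
  return 2. * x

@f.defjvp
def f_jvp(primals, tangents):
  x, = primals
  t, = tangents
  return f(x), 3. * t  # deliberately 3 instead of 2, so the rule is detectable

xs = jnp.ones(4)

# All three should now report the custom coefficient 3.
g_single = jax.grad(f)(1.)
g_vmap_of_grad = jax.vmap(jax.grad(f))(xs)
g_grad_of_vmap = jax.grad(lambda x: jax.vmap(f)(x).sum())(xs)

print(g_single, g_vmap_of_grad, g_grad_of_vmap)
```

The grad-of-vmap line is the one that previously lost the custom rule; here it should agree with the other two.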

The jvp transformation interacts as one might expect: it just calls f_jvp,

jvp(call(f)) == call(jvp(f))

jvp(custom_jvp_call(f, f_jvp)) == f_jvp

Because custom_jvp_call acts like core.call (and not like xla.xla_call), it doesn’t raise the abstraction level of its inputs: it isn’t delaying anything or staging anything out. That solves the Python flexibility problem: there are no constraints on the user Python function beyond the usual functional programming constraints required by jvp or vjp.
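To make the Python flexibility point concrete, here is a small sketch in which both the function and its JVP rule branch on a runtime value with an ordinary Python `if`. The function body and the tangent coefficients 2 and 3 are arbitrary choices for illustration, and the example assumes no enclosing jit, so the primal values seen by the rule are concrete.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
  if x > 0:  # ordinary Python control flow on a runtime value
    return jnp.sin(x)
  else:
    return jnp.cos(x)

@f.defjvp
def f_jvp(primals, tangents):
  x, = primals
  t, = tangents
  ans = f(x)
  if x > 0:  # the rule itself may branch in Python too
    return ans, 2. * t
  else:
    return ans, 3. * t

d_pos = jax.grad(f)(1.)
d_neg = jax.grad(f)(-1.)
print(d_pos, d_neg)
```

Because the rule runs as plain Python during differentiation, a debugger breakpoint or a print statement could be placed inside f_jvp in the same way.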

What about evaluation and compilation? These are two ways to “exit” the JAX system, in the sense that no additional transformations can be applied after these steps. As a result, their rules are trivial:

eval(call(f)) == eval(f)
jit(call(f)) == hlo_call(jit(f))

eval(custom_jvp_call(f, f_jvp)) == eval(f)
jit(custom_jvp_call(f, f_jvp)) == hlo_call(jit(f))

In words, if a JVP rule hasn’t already rewritten custom_jvp_call(f, f_jvp) into f_jvp, when we get to the point of evaluation with eval or staging out to XLA with jit, differentiation is never going to be applied, so we just ignore f_jvp and behave just like core.call. However, due to the wrinkle discussed next, the partial eval rule for custom_jvp_call must be a bit more complex, since partial evaluation isn’t just used to stage out to XLA with jit.
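A minimal sketch of what these rules mean for a user: when no differentiation is applied, the result is just that of f, while differentiation, with or without jit around it, picks up the rule. The coefficient 3 is deliberately inconsistent with f so that the two cases are distinguishable; the variable names are illustrative.

```python
import jax

@jax.custom_jvp
def f(x):
  return 2. * x

@f.defjvp
def f_jvp(primals, tangents):
  x, = primals
  t, = tangents
  return f(x), 3. * t  # deliberately not the true derivative

y_eager = f(2.)                     # plain evaluation: the value of f
y_jit = jax.jit(f)(2.)              # compilation without differentiation: same value
dy_jit = jax.jit(jax.grad(f))(2.)   # differentiation applied: the custom rule is used

print(y_eager, y_jit, dy_jit)
```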

The only remaining wrinkle has to do with “initial-style” jaxpr-forming primitives, like lax.scan, and their transformation rules. These represent a different kind of “staging out to a jaxpr” than that for compilation because we can perform additional transformations on the staged-out jaxpr. That is, when lax.scan forms a jaxpr, it does not exit the transformation system, since when we apply a jvp or vmap to a lax.scan we need to apply it to the function represented by the jaxpr.

Another way to state the wrinkle is that initial-style primitives like lax.scan rely on the ability to round-trip to a jaxpr and back to a Python callable while preserving semantics. That must mean preserving custom differentiation rule semantics too.

The solution is to use a bit of dynamic scoping: when we’re staging out to a jaxpr for an initial-style primitive, like those in lax_control_flow.py, we set a bit on the global trace state. When that bit is set, instead of using the final-style custom_jvp_call primitive, we use an initial-style custom_jvp_call_jaxpr primitive, and trace the functions f and f_jvp to jaxprs up-front to make initial-style processing easier. The custom_jvp_call_jaxpr primitive is otherwise similar to the final-style version.

(Footnote: while morally we form jaxprs for both f and f_jvp before binding custom_jvp_call_jaxpr, we need to delay the formation of the jaxpr of f_jvp because it may call the custom-JVP function and thus eager processing would lead to an infinite recursion. We delay that jaxpr formation in a thunk.)
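The user-visible consequence of the initial-style handling can be sketched as follows: a custom-JVP function called inside a lax.scan body should keep its rule when the scan is differentiated. The helper `total` and the deliberately wrong coefficient 3 are assumptions made for illustration; the sketch says nothing about which internal primitive is bound.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.custom_jvp
def f(x):
  return 2. * x

@f.defjvp
def f_jvp(primals, tangents):
  x, = primals
  t, = tangents
  return f(x), 3. * t  # deliberately 3 instead of 2

def total(xs):
  # The scan body is staged out to a jaxpr before differentiation reaches it.
  def body(carry, x):
    return carry + f(x), None
  out, _ = lax.scan(body, jnp.zeros(()), xs)
  return out

value = total(jnp.ones(3))
grads = jax.grad(total)(jnp.ones(3))
print(value, grads)
```

If the round trip through the jaxpr dropped the rule, the gradient entries would be 2 rather than 3.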

If we gave up on the Python flexibility problem, we could get away with only having custom_jvp_call_jaxpr and not having the separate Python-level primitive custom_jvp_call.

API#

The custom JVP for an a -> b function is specified with an (a, T a) -> (b, T b) function:

import jax
import jax.numpy as np  # np is jax.numpy here, so np.sin is traceable

# f :: a -> b
@jax.custom_jvp
def f(x):
  return np.sin(x)

# f_jvp :: (a, T a) -> (b, T b)
def f_jvp(primals, tangents):
  x, = primals
  t, = tangents
  return f(x), np.cos(x) * t

f.defjvp(f_jvp)

(Interesting autodiff aside: for the rule to apply to higher-order differentiation, one must call f in the body of f_jvp; that precludes some kinds of work sharing between the internals of f and the tangent calculation.)
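As a usage sketch of the custom JVP API, the sin example can be exercised with jax.jvp and with first- and second-order grad. It is restated here so the block runs on its own, with `jnp` as an alias for jax.numpy; the result names are illustrative. Because f_jvp calls f, the rule is applied again at second order.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
  return jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
  x, = primals
  t, = tangents
  # Calling f here (rather than jnp.sin) keeps the rule active at higher order.
  return f(x), jnp.cos(x) * t

y, y_dot = jax.jvp(f, (1.,), (1.,))   # primal output and tangent output
d1 = jax.grad(f)(1.)                  # expected to match cos(1)
d2 = jax.grad(jax.grad(f))(1.)        # expected to match -sin(1)
print(y, y_dot, d1, d2)
```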

The custom VJP for an a -> b function is specified with an a -> (b, c) forward pass function paired with a (c, CT b) -> CT a backward pass function:

import jax
import jax.numpy as np  # np is jax.numpy here, so np.sin is traceable

# f :: a -> b
@jax.custom_vjp
def f(x):
  return np.sin(x)

# f_fwd :: a -> (b, c)
def f_fwd(x):
  return f(x), np.cos(x)

# f_bwd :: (c, CT b) -> CT a
def f_bwd(cos_x, g):
  return (cos_x * g,)

f.defvjp(f_fwd, f_bwd)
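A short usage sketch of the forward/backward pair, restated so that it runs on its own (with `jnp` as an alias for jax.numpy and illustrative result names): jax.vjp runs f_fwd and keeps the residual c, and the returned pullback feeds that residual and the output cotangent to f_bwd.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
  return jnp.sin(x)

def f_fwd(x):
  return f(x), jnp.cos(x)   # (b, c): the output and the residual

def f_bwd(cos_x, g):
  return (cos_x * g,)       # one cotangent per positional argument of f

f.defvjp(f_fwd, f_bwd)

y, f_pullback = jax.vjp(f, 1.)
(x_bar,) = f_pullback(jnp.ones_like(y))  # cotangent with the output's shape and dtype
d = jax.grad(f)(1.)
print(y, x_bar, d)
```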

The signature a -> (b, CT b --o CT a) is more aesthetically pleasing, but supporting it would make the implementation more complex and might require compromising expressibility desiderata. The basic reason is that Python callables are opaque (unless we trace them to a jaxpr eagerly, which places constraints on expressiveness), and in this case we may be returning a callable with vmap tracers inside its closure that we need to know about during the forward pass.

We could add convenience wrappers, for example to define the JVP rule for a single argument at a time (like we do internally for primitives). But because this proposal is complicated enough as it is, I decided against convenience layers; let’s keep things minimal for now.

There are some other bells and whistles to the API:

  • Inputs and output types a, b, and c can be arbitrary pytrees of jaxtypes.

  • Passing arguments by name (keyword arguments) is supported when they can be resolved to positions using the inspect module. This is a bit of an experiment with Python 3’s improved ability to programmatically inspect argument signatures. I believe it is sound but not complete, which is a fine place to be. (See also #2069.)

  • Arguments can be marked non-differentiable using nondiff_argnums, and as with jit’s static_argnums these arguments don’t have to be JAX types. We need to set a convention for how these arguments are passed to the rules. For a primal function with type signature (d, a) -> b where d represents the non-differentiable type, the JVP rule’s signature is (a, T a, d) -> T b and the VJP rule’s reverse component signature is (d, c, CT b) -> CT a. That is, the non-differentiable arguments are passed in order after primals and tangents for a custom JVP rule, and passed in order preceding the residuals in a custom VJP rule’s reverse function.
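The custom VJP convention (d, c, CT b) -> CT a from the last bullet can be sketched with a gradient-clipping identity function. Here d is the pair (lo, hi), a is x, and the residual c is None; the function names and bounds are illustrative choices. The sketch covers only the VJP case and makes no claim about the argument order for custom JVP rules.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def clip_gradient(lo, hi, x):
  return x  # identity on the forward pass

def clip_gradient_fwd(lo, hi, x):
  return x, None  # no residuals are needed

def clip_gradient_bwd(lo, hi, _, g):
  # Non-differentiable arguments come first, then the residual, then CT b.
  return (jnp.clip(g, lo, hi),)  # cotangents only for the differentiable argument

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

# The incoming cotangent is 10, which the backward pass clips to 1.
d = jax.grad(lambda x: 10. * clip_gradient(-1., 1., x))(2.)
print(d)
```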

Implementation notes#

  • Updated jax.experimental.odeint

    • Since odeint is a pretty complex user of a custom VJP rule, in addition to just updating it to work at all, I wanted to revise it to be a canonical user of the new custom VJP API as a way to test that the API was a good one.

    • Along the way I made other improvements to the odeint implementation:

      • remove raveling/unraveling boilerplate

      • make use of lax.scan to remove the index-update logic

      • speed up by 20+% on the simple pendulum benchmark

  • Added a custom bind method on each transform for the custom derivative call primitives, custom_jvp_call and custom_vjp_call. It’s like core.call_bind, except we don’t process env traces: those are just errors.

  • Added custom_lin primitive, which gets staged out into linear jaxprs to be transposed when using a custom VJP rule.

    • Because our reverse-mode autodiff is decomposed into linearization, partial evaluation, and transposition, our custom VJP rules are processed in two separate steps: one during linearization and one during transposition.

    • The linearization step, i.e. the JVP rule for custom_vjp_call, applies custom_lin to the tangent values; custom_lin carries with it the user’s custom backward-pass function, and as a primitive it only has a transpose rule.

    • This mechanism is described more in #636.

  • To prevent