Gradient checkpointing with jax.checkpoint (jax.remat)#

In this tutorial, you will learn how to control JAX automatic differentiation’s saved values using jax.checkpoint() (also known as jax.remat()), which can be particularly helpful in machine learning.

If you are new to automatic differentiation (autodiff) or need to refresh your memory, JAX has Automatic differentiation and Advanced automatic differentiation tutorials.

TL;DR Use the jax.checkpoint() decorator (aliased as jax.remat()) with jax.grad() to control which intermediates are saved on the forward pass and which are recomputed on the backward pass, trading off memory and FLOPs.

If you don’t use jax.checkpoint(), the jax.grad(f)(x) forward pass stores Jacobian coefficients and other intermediates to use during the backward pass. These saved values are called residuals.

Note: Don’t miss the Practical notes for a discussion about how jax.checkpoint() interacts with jax.jit().

import jax
import jax.numpy as jnp

def g(W, x):
  y = jnp.dot(W, x)
  return jnp.sin(y)

def f(W1, W2, W3, x):
  x = g(W1, x)
  x = g(W2, x)
  x = g(W3, x)
  return x

W1 = jnp.ones((5, 4))
W2 = jnp.ones((6, 5))
W3 = jnp.ones((7, 6))
x = jnp.ones(4)

# Inspect the 'residual' values to be saved on the forward pass
# if you were to evaluate `jax.grad(f)(W1, W2, W3, x)`
from jax.ad_checkpoint import print_saved_residuals
jax.ad_checkpoint.print_saved_residuals(f, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of sin from /tmp/ipykernel_1130/1801108376.py:6 (g)
f32[5] output of cos from /tmp/ipykernel_1130/1801108376.py:6 (g)
f32[6] output of sin from /tmp/ipykernel_1130/1801108376.py:6 (g)
f32[6] output of cos from /tmp/ipykernel_1130/1801108376.py:6 (g)
f32[7] output of cos from /tmp/ipykernel_1130/1801108376.py:6 (g)

By applying jax.checkpoint() to sub-functions, as a decorator or at specific application sites, you force JAX not to save any of that sub-function’s residuals. Instead, only the inputs of a jax.checkpoint()-decorated function might be saved, and any residuals consumed on the backward pass are re-computed from those inputs as needed:

def f2(W1, W2, W3, x):
  x = jax.checkpoint(g)(W1, x)
  x = jax.checkpoint(g)(W2, x)
  x = jax.checkpoint(g)(W3, x)
  return x

jax.ad_checkpoint.print_saved_residuals(f2, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of sin from /tmp/ipykernel_1130/1801108376.py:6 (g)
f32[6] output of sin from /tmp/ipykernel_1130/1801108376.py:6 (g)

Here, the values of two sin applications are saved because they are arguments in subsequent applications of the jax.checkpoint()-decorated g function, and inputs to a jax.checkpoint()-decorated function may be saved. But no values of cos applications are saved.

To control which values are saveable without having to edit the definition of the function to be differentiated, you can use a rematerialization policy. Here is an example that saves only the results of dot operations with no batch dimensions (since they are often FLOP-bound, and hence worth saving rather than recomputing):

f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
jax.ad_checkpoint.print_saved_residuals(f3, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of reduce_precision from /tmp/ipykernel_1130/1801108376.py:5 (g)
f32[6] output of reduce_precision from /tmp/ipykernel_1130/1801108376.py:5 (g)
f32[7] output of reduce_precision from /tmp/ipykernel_1130/1801108376.py:5 (g)

You can also use policies to refer to intermediate values you name using jax.ad_checkpoint.checkpoint_name():

from jax.ad_checkpoint import checkpoint_name

def f4(W1, W2, W3, x):
  x = checkpoint_name(g(W1, x), name='a')
  x = checkpoint_name(g(W2, x), name='b')
  x = checkpoint_name(g(W3, x), name='c')
  return x

f4 = jax.checkpoint(f4, policy=jax.checkpoint_policies.save_only_these_names('a'))
jax.ad_checkpoint.print_saved_residuals(f4, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of reduce_precision from /tmp/ipykernel_1130/2296542172.py:4 (f4)

When playing around with these toy examples, you can get a closer look at what’s going on using a custom print_fwd_bwd utility defined in this notebook:

from jax.tree_util import tree_flatten, tree_unflatten

from rich.console import Console
from rich.table import Table
import rich.text

def print_fwd_bwd(f, *args, **kwargs) -> None:
  args, in_tree = tree_flatten((args, kwargs))

  def f_(*args):
    args, kwargs = tree_unflatten(in_tree, args)
    return f(*args, **kwargs)

  fwd = jax.make_jaxpr(lambda *args: jax.vjp(f_, *args))(*args).jaxpr

  y, f_vjp = jax.vjp(f_, *args)
  res, in_tree = tree_flatten(f_vjp)

  def g_(*args):
    *res, y = args
    f_vjp = tree_unflatten(in_tree, res)
    return f_vjp(y)

  bwd = jax.make_jaxpr(g_)(*res, y).jaxpr

  table = Table(show_header=False, show_lines=True, padding=(1, 2, 0, 2), box=None)
  table.add_row("[bold green]forward computation:",
                "[bold green]backward computation:")
  table.add_row(rich.text.Text.from_ansi(str(fwd)),
                rich.text.Text.from_ansi(str(bwd)))
  console = Console(width=240, force_jupyter=True)
  console.print(table)

def _renderable_repr(self):
  return self.html
rich.jupyter.JupyterRenderable._repr_html_ = _renderable_repr
# Without using `jax.checkpoint`:
print_fwd_bwd(f, W1, W2, W3, x)
                                                                                                                                                         
  forward computation:                                         backward computation:                                                                     
                                                                                                                                                         
  { lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let    { lambda ; a:f32[4] b:f32[5,4] c:f32[5] d:f32[5] e:f32[6,5] f:f32[6] g:f32[6] h:f32[7,6]  
      e:f32[5] = dot_general[                                      i:f32[7] j:f32[7]. let                                                                
        dimension_numbers=(([1], [0]), ([], []))                   k:f32[7] = mul j i                                                                    
        preferred_element_type=float32                             l:f32[6] = dot_general[                                                               
      ] a d                                                          dimension_numbers=(([0], [0]), ([], []))                                            
      f:f32[5] = sin e                                               preferred_element_type=float32                                                      
      g:f32[5] = cos e                                             ] k h                                                                                 
      h:f32[6] = dot_general[                                      m:f32[7,6] = dot_general[                                                             
        dimension_numbers=(([1], [0]), ([], []))                     dimension_numbers=(([], []), ([], []))                                              
        preferred_element_type=float32                               preferred_element_type=float32                                                      
      ] b f                                                        ] k g                                                                                 
      i:f32[6] = sin h                                             n:f32[6] = mul l f                                                                    
      j:f32[6] = cos h                                             o:f32[5] = dot_general[                                                               
      k:f32[7] = dot_general[                                        dimension_numbers=(([0], [0]), ([], []))                                            
        dimension_numbers=(([1], [0]), ([], []))                     preferred_element_type=float32                                                      
        preferred_element_type=float32                             ] n e                                                                                 
      ] c i                                                        p:f32[6,5] = dot_general[                                                             
      l:f32[7] = sin k                                               dimension_numbers=(([], []), ([], []))                                              
      m:f32[7] = cos k                                               preferred_element_type=float32                                                      
    in (l, d, a, g, f, b, j, i, c, m) }                            ] n d                                                                                 
                                                                   q:f32[5] = mul o c                                                                    
                                                                   r:f32[4] = dot_general[                                                               
                                                                     dimension_numbers=(([0], [0]), ([], []))                                            
                                                                     preferred_element_type=float32                                                      
                                                                   ] q b                                                                                 
                                                                   s:f32[5,4] = dot_general[                                                             
                                                                     dimension_numbers=(([], []), ([], []))                                              
                                                                     preferred_element_type=float32                                                      
                                                                   ] q a                                                                                 
                                                                 in (s, p, m, r) }                                                                       
# Using `jax.checkpoint` with policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable:
print_fwd_bwd(f3, W1, W2, W3, x)
                                                                                                                                                                        
  forward computation:                                                   backward computation:                                                                          
                                                                                                                                                                        
  { lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let              { lambda ; a:f32[5] b:f32[6] c:f32[7] d:f32[5,4] e:f32[6,5] f:f32[7,6] g:f32[4] h:f32[7]. let  
      e:f32[5] = dot_general[                                                i:f32[5,4] j:f32[6,5] k:f32[7,6] l:f32[4] = remat2[                                        
        dimension_numbers=(([1], [0]), ([], []))                               differentiated=True                                                                      
        preferred_element_type=float32                                         jaxpr={ lambda ; m:f32[5] n:f32[6] o:f32[7] p:f32[5,4] q:f32[6,5] r:f32[7,6]             
      ] a d                                                                        s:f32[4] t:f32[7]. let                                                               
      f:f32[5] = reduce_precision[exponent_bits=8 mantissa_bits=23] e              u:f32[5] = sin m                                                                     
      g:f32[5] = sin f                                                             v:f32[5] = cos m                                                                     
      h:f32[6] = dot_general[                                                      w:f32[6] = sin n                                                                     
        dimension_numbers=(([1], [0]), ([], []))                                   x:f32[6] = cos n                                                                     
        preferred_element_type=float32                                             y:f32[7] = cos o                                                                     
      ] b g                                                                        z:f32[7] = mul t y                                                                   
      i:f32[6] = reduce_precision[exponent_bits=8 mantissa_bits=23] h              ba:f32[6] = dot_general[                                                             
      j:f32[6] = sin i                                                               dimension_numbers=(([0], [0]), ([], []))                                           
      k:f32[7] = dot_general[                                                        preferred_element_type=float32                                                     
        dimension_numbers=(([1], [0]), ([], []))                                   ] z r                                                                                
        preferred_element_type=float32                                             bb:f32[6] = mul ba x                                                                 
      ] c j                                                                        bc:f32[5] = dot_general[                                                             
      l:f32[7] = reduce_precision[exponent_bits=8 mantissa_bits=23] k                dimension_numbers=(([0], [0]), ([], []))                                           
      m:f32[7] = sin l                                                               preferred_element_type=float32                                                     
    in (m, f, i, l, a, b, c, d) }                                                  ] bb q                                                                               
                                                                                   bd:f32[5] = mul bc v                                                                 
                                                                                   be:f32[4] = dot_general[                                                             
                                                                                     dimension_numbers=(([0], [0]), ([], []))                                           
                                                                                     preferred_element_type=float32                                                     
                                                                                   ] bd p                                                                               
                                                                                   bf:f32[5,4] = dot_general[                                                           
                                                                                     dimension_numbers=(([], []), ([], []))                                             
                                                                                     preferred_element_type=float32                                                     
                                                                                   ] bd s                                                                               
                                                                                   bg:f32[6,5] = dot_general[                                                           
                                                                                     dimension_numbers=(([], []), ([], []))                                             
                                                                                     preferred_element_type=float32                                                     
                                                                                   ] bb u                                                                               
                                                                                   bh:f32[7,6] = dot_general[                                                           
                                                                                     dimension_numbers=(([], []), ([], []))                                             
                                                                                     preferred_element_type=float32                                                     
                                                                                   ] z w                                                                                
                                                                                 in (bf, bg, bh, be) }                                                                  
                                                                               policy=<function dot_with_no_batch_dims_saveable at 0x7f4d0491aef0>                      
                                                                               prevent_cse=True                                                                         
                                                                             ] a b c d e f g h                                                                          
                                                                           in (i, j, k, l) }                                                                            

Let’s think step by step#

Note: It may help to check out the Advanced automatic differentiation tutorial prior to continuing here.

jax.checkpoint fundamentals#

In both jax.linearize() and jax.vjp(), there is flexibility in how and when some values are computed. Different choices can trade off memory use against FLOPs. JAX provides control over these choices with jax.checkpoint().

One such choice is whether to perform Jacobian coefficient computations on the forward pass, as soon as the inputs are available, or on the backward pass, just before the coefficients are needed. Consider the example of sin_vjp:

def sin_vjp(x):
  y = jnp.sin(x)
  cos_x = jnp.cos(x)
  return y, lambda y_bar: cos_x * y_bar

Another valid implementation would compute the value of jnp.cos(x) on the backward pass rather than on the forward pass:

def sin_vjp2(x):
  y = jnp.sin(x)
  return y, lambda y_bar: jnp.cos(x) * y_bar

For this particular function, the amount of memory used by the two versions is the same, though you’ve reduced the FLOPs for the primal computation (the forward pass) and increased the FLOPs for the cotangent computation (the backward pass).
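Both implementations compute the same values; the choice only affects where the FLOPs land. As a quick sanity check (a sketch, not part of the original notebook), you can compare their cotangents against jax.grad():

# Both versions produce the same cotangent as jax.grad(jnp.sin).
y1, sin_bwd = sin_vjp(3.0)
y2, sin_bwd2 = sin_vjp2(3.0)
print(sin_bwd(1.0), sin_bwd2(1.0), jax.grad(jnp.sin)(3.0))  # all equal cos(3.0)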

There’s another choice when it comes to function composition. Recall the VJP rule for a composition of two functions:

def f(x):
  y = g(x)
  z = h(y)
  return z

def f_vjp(x):
  y, g_vjp = jax.vjp(g, x)
  z, h_vjp = jax.vjp(h, y)
  def f_bwd(z_bar):
    y_bar, = h_vjp(z_bar)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd

An alternative is:

def f_vjp_checkpoint(x):
  y = g(x)
  z, h_vjp = jax.vjp(h, y)
  def f_bwd2(z_bar):
    y_bar, = h_vjp(z_bar)
    _, g_vjp = jax.vjp(g, x)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd2

In words, this alternative implementation doesn’t compute g_vjp, or the residual values in its closure, on the forward pass. Instead, it computes them only in the backward pass f_bwd2. That means f_vjp_checkpoint requires less memory: if g and h each require similar amounts of memory for their residuals, each much larger than x, then the function produced by f_vjp_checkpoint(x) requires half the memory of f_vjp(x)!

The cost you pay is redundant work: in f_bwd2 you must re-evaluate g(x) as part of jax.vjp(g, x) just to discard its value (in the underscore variable on the line _, g_vjp = jax.vjp(g, x)).

You can get this VJP behavior in autodiff, without having to write VJP functions directly, by instead using jax.checkpoint() in an alternative definition of the original function f:

def f_checkpoint(x):
  y = jax.checkpoint(g)(x)
  z = h(y)
  return z

In other words, you apply jax.checkpoint() to g — the first stage of f — rather than to f itself. This way, when you evaluate jax.grad(f_checkpoint)(x), you’d get a computation like:

  1. Run the forward pass of g, discarding residual values.

  2. Run the forward pass of h, saving residuals.

  3. Run the backward pass of h, consuming residuals from step 2.

  4. Re-run the forward pass of g, saving residuals.

  5. Run the backward pass of g, consuming residuals from step 4.

That is, by evaluating jax.grad(f_checkpoint)(x) you’d get the same computation as:

def f_checkpoint_grad(x):
  y = g(x)                  # step 1
  _, h_vjp = jax.vjp(h, y)  # step 2
  y_bar, = h_vjp(1.0)       # step 3
  _, g_vjp = jax.vjp(g, x)  # step 4
  x_bar, = g_vjp(y_bar)     # step 5
  return x_bar

In general, jax.checkpoint(foo) is a new function which has the same input-output behavior as foo, but behaves differently under autodiff, particularly under jax.linearize() and jax.vjp() (and their wrappers, like jax.grad()), but not under jax.jvp(). When differentiated, only the input to a jax.checkpoint()-decorated function is stored on the forward pass. On the backward pass, the residuals (intermediates from foo and its Jacobian coefficient values needed for the backward pass) are recomputed.
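To make the first part of that concrete, here is a minimal sketch (foo is an illustrative name, not from the original): jax.checkpoint() changes what gets saved, not the values or gradients computed:

def foo(x):  # illustrative function, not from the original
  return jnp.sin(jnp.sin(x))

print(jnp.allclose(foo(3.0), jax.checkpoint(foo)(3.0)))                       # same outputs
print(jnp.allclose(jax.grad(foo)(3.0), jax.grad(jax.checkpoint(foo))(3.0)))  # same gradients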

Notice that if f = lambda x: h(g(x)) is the function you want to differentiate (in other words, if you want to apply jax.grad(f)), you don’t get any memory savings by applying jax.checkpoint() to f itself. That’s because evaluating jax.grad(jax.checkpoint(f))(x) would lead to a computation such as:

  1. Run the forward pass, discarding all residuals.

  2. Immediately re-run the forward pass, saving residuals.

  3. Run the backward pass, consuming residuals from step 2.

In code, you’d have something like:

def f_grad_bad(x):
  _ = f(x)                  # step 1
  _, f_vjp = jax.vjp(f, x)  # step 2
  x_bar, = f_vjp(1.0)       # step 3
  return x_bar

You also wouldn’t get any memory savings by applying jax.checkpoint() to h, the second stage of f. That’s because evaluating jax.grad(lambda x: jax.checkpoint(h)(g(x))) would lead to a computation such as:

  1. Run the forward pass of g, saving residuals.

  2. Run the forward pass of h, discarding residuals.

  3. Immediately re-run the forward pass of h, saving residuals.

  4. Run the backward pass of h, consuming residuals from step 3.

  5. Run the backward pass of g, consuming residuals from step 1.

In code you’d have something like:

def f_grad_bad2(x):
  y, g_vjp = jax.vjp(g, x)  # step 1
  z = h(y)                  # step 2
  _, h_vjp = jax.vjp(h, y)  # step 3
  y_bar, = h_vjp(1.0)       # step 4
  x_bar, = g_vjp(y_bar)     # step 5
  return x_bar

Slightly more generally, if you had a chain composition of functions, such as f = lambda x: f3(f2(f1(x))), and were interested in evaluating jax.grad(f), you could say that you:

  • Shouldn’t apply jax.checkpoint() to the whole function f, since that wouldn’t save any memory (and would perform wasteful recomputation).

  • Shouldn’t apply jax.checkpoint() to the last sub-function f3, since that wouldn’t save any memory (and would perform wasteful recomputation).

  • Could apply jax.checkpoint() to f1, f2, or their composition lambda x: f2(f1(x)), since any of those might save memory and would express different memory/recompute tradeoffs (a sketch follows this list).
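For instance, here is a sketch of the third option, with jnp.sin standing in for each stage (an illustrative choice, not from the original):

f1 = f2 = f3 = jnp.sin  # illustrative stand-ins for the stages

f_plain = lambda x: f3(f2(f1(x)))
f_remat = lambda x: f3(jax.checkpoint(lambda x: f2(f1(x)))(x))
print(jnp.allclose(jax.grad(f_plain)(3.0), jax.grad(f_remat)(3.0)))  # True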

Custom policies for what’s saveable#

As shown so far, using jax.checkpoint() switches from one extreme to another:

  • Without jax.checkpoint(), JAX’s autodiff tends to compute everything possible on the forward pass and store it for the backward pass.

  • With a jax.checkpoint() decorator, you instead compute as little as possible on the forward pass and recompute values as needed on the backward pass.

To operate between these two extremes, saving some things and not others, you can carefully place jax.checkpoint() decorators on sub-functions. But that requires editing the function to be differentiated, e.g. model code, which may be inconvenient. It can also be hard to experiment with variations.

So an alternative is to use the policy argument to jax.checkpoint(). A policy is a callable (that is, a function) which takes as input a type-level specification of a first-order primitive application and returns a boolean indicating whether the corresponding output value(s) are allowed to be saved as residuals (or instead must be recomputed in the (co)tangent computation as needed). To write robust code, a policy should be selected from the attributes of jax.checkpoint_policies, such as jax.checkpoint_policies.dots_with_no_batch_dims_saveable, since the API for writing custom policy callables is considered internal.

For example, consider this function to be differentiated:

def loss(params, x, y):
  return jnp.sum((predict(params, x) - y)**2)

def predict(params, x):
  *Ws, Wlast = params
  for W in Ws:
    x = layer(W, x)
  x = jnp.dot(Wlast, x)
  return x

def layer(W, x):
  return jnp.sin(jnp.dot(W, x))
W1 = W2 = W3 = jnp.ones((4, 4))
params = [W1, W2, W3]
x = jnp.ones(4)
y = jnp.ones(4)
print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] output of sin from /tmp/ipykernel_1130/4230705069.py:12 (layer)
f32[4] output of cos from /tmp/ipykernel_1130/4230705069.py:12 (layer)
f32[4] output of sin from /tmp/ipykernel_1130/4230705069.py:12 (layer)
f32[4] output of cos from /tmp/ipykernel_1130/4230705069.py:12 (layer)
f32[4] output of mul from /tmp/ipykernel_1130/4230705069.py:2 (loss)

Instead of saving so many values on the forward pass, perhaps you only want to save the results of matrix multiplications with no batch dimension (since they may be FLOP- rather than memory-bound). You can do that using the policy jax.checkpoint_policies.dots_with_no_batch_dims_saveable():

loss_checkpoint = jax.checkpoint(loss, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
print_saved_residuals(loss_checkpoint, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] from the argument y
f32[4] output of reduce_precision from /tmp/ipykernel_1130/4230705069.py:12 (layer)
f32[4] output of reduce_precision from /tmp/ipykernel_1130/4230705069.py:12 (layer)
f32[4] output of reduce_precision from /tmp/ipykernel_1130/4230705069.py:8 (predict)

Notice also that by providing a policy, you didn’t need to edit the code defining loss, predict, or layer. That is particularly convenient if you want to experiment with policies in calling code (such as a training script) without changing library code (for example, the neural network library).

Some policies can refer to values named with jax.ad_checkpoint.checkpoint_name():

from jax.ad_checkpoint import checkpoint_name

def predict(params, x):
  *Ws, Wlast = params
  for i, W in enumerate(Ws):
    x = layer(W, x)
    x = checkpoint_name(x, name=f'layer{i}_output')
  x = jnp.dot(Wlast, x)
  return x

By itself, jax.ad_checkpoint.checkpoint_name() is just an identity function. But because some policy functions know to look for them, you can use the names to control whether certain values output by checkpoint_name() are considered saveable:

print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] output of cos from /tmp/ipykernel_1130/4230705069.py:12 (layer)
f32[4] named 'layer0_output' from /tmp/ipykernel_1130/178264713.py:7 (predict)
f32[4] output of cos from /tmp/ipykernel_1130/4230705069.py:12 (layer)
f32[4] named 'layer1_output' from /tmp/ipykernel_1130/178264713.py:7 (predict)
f32[4] output of mul from /tmp/ipykernel_1130/4230705069.py:2 (loss)
loss_checkpoint2 = jax.checkpoint(loss, policy=jax.checkpoint_policies.save_any_names_but_these('layer1_output'))
print_saved_residuals(loss_checkpoint2, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] from the argument y

Another policy which refers to names is jax.checkpoint_policies.save_only_these_names.
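As a sketch (reusing loss, params, x, and y from above), you would expect only the value named 'layer1_output', along with the inputs needed on the backward pass, to appear among the saved residuals:

loss_checkpoint3 = jax.checkpoint(loss, policy=jax.checkpoint_policies.save_only_these_names('layer1_output'))
print_saved_residuals(loss_checkpoint3, params, x, y)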

Some of the policies are:

  • everything_saveable: The default strategy, as if jax.checkpoint() were not being used at all.

  • nothing_saveable: That is, rematerialize everything, as if a custom policy were not being used at all.

  • dots_saveable: Or its alias checkpoint_dots.

  • dots_with_no_batch_dims_saveable: Or its alias checkpoint_dots_with_no_batch_dims.

  • save_anything_except_these_names: Save any values except for the outputs of checkpoint_name with any of the names given.

  • save_any_names_but_these: Save only named values (that is, any outputs of checkpoint_name), except for those with the names given.

  • save_only_these_names: Save only named values, and only among the names given.

Policies only indicate what is saveable; a value is only saved if it’s actually needed by the backward pass.
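For example, in this sketch (illustrative, not from the original), the value named 'unused' is saveable under the policy, but nothing on the backward pass consumes it, so you wouldn’t expect it to appear among the saved residuals:

def f5(x):
  _ = checkpoint_name(jnp.sin(x), name='unused')  # dead value: saveable, but not needed
  return jnp.cos(x)

f5 = jax.checkpoint(f5, policy=jax.checkpoint_policies.save_only_these_names('unused'))
print_saved_residuals(f5, 3.0)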

Advanced: Recursive jax.checkpoint#

By applying jax.checkpoint() in the right way, you can express many tradeoffs between memory usage and (re)computation. One surprising example is recursive checkpointing, where you apply jax.checkpoint() to a function that itself calls jax.checkpoint()-decorated functions, so that memory usage from a chain composition of \(D\) functions scales like \(\mathcal{O}(\log_2 D)\) rather than \(\mathcal{O}(D)\).

As a toy example, consider the chain composition of multiple jax.numpy.sin() functions:

def chain_compose(funs):
  def f(x):
    for fun in funs:
      x = fun(x)
    return x
  return f

f = chain_compose([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)

In general, the number of stored residuals scales linearly with the length of the chain:

f = chain_compose([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)

But you can apply jax.checkpoint() recursively to improve the scaling:

def recursive_checkpoint(funs):
  if len(funs) == 1:
    return funs[0]
  elif len(funs) == 2:
    f1, f2 = funs
    return lambda x: f1(f2(x))
  else:
    f1 = recursive_checkpoint(funs[:len(funs)//2])
    f2 = recursive_checkpoint(funs[len(funs)//2:])
    return lambda x: f1(jax.checkpoint(f2)(x))
f = recursive_checkpoint([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] from the argument x
f32[] output of sin from /tmp/ipykernel_1130/1943107544.py:6 (<lambda>)
f32[] output of cos from /tmp/ipykernel_1130/1943107544.py:6 (<lambda>)
f32[] output of cos from /tmp/ipykernel_1130/1943107544.py:6 (<lambda>)
f = recursive_checkpoint([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] from the argument x
f32[] output of sin from /tmp/ipykernel_1130/1943107544.py:6 (<lambda>)
f32[] output of sin from /tmp/ipykernel_1130/1943107544.py:6 (<lambda>)
f32[] output of cos from /tmp/ipykernel_1130/1943107544.py:6 (<lambda>)
f32[] output of cos from /tmp/ipykernel_1130/1943107544.py:6 (<lambda>)

The cost here, as usual, is recomputation: in particular, you end up performing \(\mathcal{O}(\log_2 D)\) times as many FLOPs:

f = chain_compose([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
                                                                                                                                 
  forward computation:                  backward computation:                                                                    
                                                                                                                                 
  { lambda ; a:f32[]. let               { lambda ; a:f32[] b:f32[] c:f32[] d:f32[] e:f32[] f:f32[] g:f32[] h:f32[] i:f32[]. let  
      b:f32[] = sin a                       j:f32[] = mul i h                                                                    
      c:f32[] = cos a                       k:f32[] = mul j g                                                                    
      d:f32[] = sin b                       l:f32[] = mul k f                                                                    
      e:f32[] = cos b                       m:f32[] = mul l e                                                                    
      f:f32[] = sin d                       n:f32[] = mul m d                                                                    
      g:f32[] = cos d                       o:f32[] = mul n c                                                                    
      h:f32[] = sin f                       p:f32[] = mul o b                                                                    
      i:f32[] = cos f                       q:f32[] = mul p a                                                                    
      j:f32[] = sin h                     in (q,) }                                                                              
      k:f32[] = cos h                                                                                                            
      l:f32[] = sin j                                                                                                            
      m:f32[] = cos j                                                                                                            
      n:f32[] = sin l                                                                                                            
      o:f32[] = cos l                                                                                                            
      p:f32[] = sin n                                                                                                            
      q:f32[] = cos n                                                                                                            
    in (p, c, e, g, i, k, m, o, q) }                                                                                             
f = recursive_checkpoint([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
                                                                                                                                        
  forward computation:                                                              backward computation:                               
                                                                                                                                        
  { lambda ; a:f32[]. let                                                           { lambda ; a:f32[] b:f32[] c:f32[] d:f32[]. let     
      b:f32[] = remat2[                                                                 e:f32[] = mul d c                               
        differentiated=False                                                            f:f32[] = mul e b                               
        jaxpr={ lambda ; c:f32[]. let d:f32[] = sin c; e:f32[] = sin d in (e,) }        g:f32[] = remat2[                               
        policy=None                                                                       differentiated=True                           
        prevent_cse=True                                                                  jaxpr={ lambda ; h:f32[] i:f32[]. let         
      ] a                                                                                     j:f32[] = sin h                           
      f:f32[] = sin b                                                                         k:f32[] = cos h                           
      g:f32[] = sin f                                                                         l:f32[] = cos j                           
      h:f32[] = sin g                                                                         m:f32[] = mul i l                         
      i:f32[] = sin h                                                                         n:f32[] = mul m k                         
      j:f32[] = sin i                                                                       in (n,) }                                   
      k:f32[] = cos i                                                                     policy=None                                   
      l:f32[] = sin j                                                                     prevent_cse=True                              
      m:f32[] = cos j                                                                   ] a f                                           
    in (l, g, a, k, m) }                                                                o:f32[] = remat2[                               
                                                                                          differentiated=True                           
                                                                                          jaxpr={ lambda ; p:f32[] q:f32[]. let         
                                                                                              r:f32[] = sin p                           
                                                                                              s:f32[] = sin r                           
                                                                                              t:f32[] = sin s                           
                                                                                              u:f32[] = cos s                           
                                                                                              v:f32[] = cos t                           
                                                                                              w:f32[] = mul q v                         
                                                                                              x:f32[] = mul w u                         
                                                                                              y:f32[] = remat2[                         
                                                                                                differentiated=True                     
                                                                                                jaxpr={ lambda ; z:f32[] ba:f32[]. let  
                                                                                                    bb:f32[] = sin z                    
                                                                                                    bc:f32[] = cos z                    
                                                                                                    bd:f32[] = cos bb                   
                                                                                                    be:f32[] = mul ba bd                
                                                                                                    bf:f32[] = mul be bc                
                                                                                                  in (bf,) }                            
                                                                                                policy=None                             
                                                                                                prevent_cse=True                        
                                                                                              ] p x                                     
                                                                                            in (y,) }                                   
                                                                                          policy=None                                   
                                                                                          prevent_cse=True                              
                                                                                        ] 3.0 g                                         
                                                                                      in (o,) }                                         

Practical notes#

When differentiated functions are staged out to XLA for compilation — for example by applying jax.jit() to a function which contains a jax.grad() call — XLA will automatically optimize the computation, including decisions about when to compute or rematerialize values. As a result, jax.checkpoint() often isn’t needed for differentiated functions under a jax.jit(). XLA will optimize things for you.
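For example, a minimal sketch (reusing loss, params, x, and y from above):

# Under jit, XLA decides what to save or rematerialize, so jax.checkpoint
# is often unnecessary here.
loss_grad = jax.jit(jax.grad(loss))
_ = loss_grad(params, x, y)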

One exception is when using staged-out control flow, like jax.lax.scan(). Automatic compiler optimizations across multiple control flow primitives (for example, across a forward-pass scan and the corresponding backward-pass scan) typically aren’t as thorough. As a result, it’s often a good idea to use jax.checkpoint() on the body function passed to jax.lax.scan().

For example, one common pattern in large Transformer models is to express the architecture as a jax.lax.scan() over layers so as to reduce compilation times. That is, using a simple fully-connected network as an analogy, instead of writing something like this:

LayerParam = tuple[jnp.ndarray, jnp.ndarray]  # Weights-bias pair for a layer.
ParamsList = list[LayerParam]

def net(params: ParamsList, x: jnp.ndarray):
  for W, b in params:
    x = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return x

You would instead iterate over the layer application with jax.lax.scan():

params = [(jnp.array([[0.5, 0.5], [1., 1.]]), jnp.array([0.5, 0.5])), 
          (jnp.array([[0.5, 0.5], [1., 1.]]), jnp.array([0.5, 0.5]))]

all_weights = jnp.stack([W for W, _ in params])
all_biases = jnp.stack([b for _, b in params])

def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None

def net(all_weights, all_biases, x):
  x, _ = jax.lax.scan(layer, x, (all_weights, all_biases))
  return x

This scan-over-layers version reduces compile times, but by foiling some compiler optimizations it can lead to inefficient computation of gradients. To mitigate the issue, you can use jax.checkpoint() on the scanned function:

from functools import partial

@partial(jax.checkpoint,
         policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None

By using jax.checkpoint() this way, you’re manually controlling which values JAX’s autodiff saves between the forward and backward passes, and therefore not relying on XLA optimizations to choose for you.
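As an end-to-end sketch (reusing the scan-based net and the checkpointed layer above; scan_loss and y_target are illustrative names), gradients flow through the scan while the policy controls what each layer saves:

def scan_loss(all_weights, all_biases, x, y):  # illustrative name
  preds = net(all_weights, all_biases, x)
  return jnp.sum((preds - y)**2)

x = jnp.ones(2)
y_target = jnp.ones(2)
grads = jax.grad(scan_loss, argnums=(0, 1))(all_weights, all_biases, x, y_target)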