# Control autodiff's saved values with `jax.checkpoint` (aka `jax.remat`)

```
import jax
import jax.numpy as jnp
```

## TL;DR

Use the `jax.checkpoint` decorator (aliased as `jax.remat`) with `jax.grad` to control which intermediates are saved on the forward pass versus recomputed on the backward pass, trading off memory and FLOPs.

**Don't miss the practical notes for a discussion about how `jax.checkpoint` interacts with `jax.jit`.**

Without using `jax.checkpoint`, the forward pass of `jax.grad(f)(x)` saves, for use on the backward pass, the values of Jacobian coefficients and other intermediates. We call these saved values *residuals*:

```
def g(W, x):
  y = jnp.dot(W, x)
  return jnp.sin(y)

def f(W1, W2, W3, x):
  x = g(W1, x)
  x = g(W2, x)
  x = g(W3, x)
  return x

W1 = jnp.ones((5, 4))
W2 = jnp.ones((6, 5))
W3 = jnp.ones((7, 6))
x = jnp.ones(4)

# Inspect the 'residual' values to be saved on the forward pass
# if we were to evaluate `jax.grad(f)(W1, W2, W3, x)`
from jax.ad_checkpoint import print_saved_residuals
jax.ad_checkpoint.print_saved_residuals(f, W1, W2, W3, x)
```

```
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
f32[5] output of cos from <ipython-input-4-f510dde58e22>:3 (g)
f32[6] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
f32[6] output of cos from <ipython-input-4-f510dde58e22>:3 (g)
f32[7] output of cos from <ipython-input-4-f510dde58e22>:3 (g)
```

By applying `jax.checkpoint` to sub-functions, as a decorator or at specific application sites, we force JAX not to save any of that sub-function's residuals. Instead, only the inputs of a `jax.checkpoint`-decorated function might be saved, and any residuals consumed on the backward pass are re-computed from those inputs as needed:

```
def f2(W1, W2, W3, x):
  x = jax.checkpoint(g)(W1, x)
  x = jax.checkpoint(g)(W2, x)
  x = jax.checkpoint(g)(W3, x)
  return x

jax.ad_checkpoint.print_saved_residuals(f2, W1, W2, W3, x)
```

```
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
f32[6] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
```

Here the values of two `sin` applications are saved because they are arguments in subsequent applications of the `jax.checkpoint`-decorated `g` function, and inputs to a `jax.checkpoint`-decorated function may be saved. But no values of `cos` applications are saved.
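Since `jax.checkpoint` only changes what is stored versus recomputed, the gradients themselves should not change. A quick sanity check, repeating `g`, `f` and `f2` from above so the snippet is self-contained; reducing the output with `jnp.sum` to get a scalar loss is our own choice for illustration:

```python
import jax
import jax.numpy as jnp

def g(W, x):
  return jnp.sin(jnp.dot(W, x))

def f(W1, W2, W3, x):
  x = g(W1, x)
  x = g(W2, x)
  x = g(W3, x)
  return x

def f2(W1, W2, W3, x):
  x = jax.checkpoint(g)(W1, x)
  x = jax.checkpoint(g)(W2, x)
  x = jax.checkpoint(g)(W3, x)
  return x

def loss(W1, W2, W3, x):
  return jnp.sum(f(W1, W2, W3, x))

def loss_ckpt(W1, W2, W3, x):
  return jnp.sum(f2(W1, W2, W3, x))

W1, W2, W3 = jnp.ones((5, 4)), jnp.ones((6, 5)), jnp.ones((7, 6))
x = jnp.ones(4)

# Differentiate with respect to all four arguments in both versions.
grads = jax.grad(loss, argnums=(0, 1, 2, 3))(W1, W2, W3, x)
grads_ckpt = jax.grad(loss_ckpt, argnums=(0, 1, 2, 3))(W1, W2, W3, x)

for a, b in zip(grads, grads_ckpt):
  # Same shapes and (up to float rounding) the same values.
  assert a.shape == b.shape
  assert jnp.allclose(a, b, rtol=1e-5, atol=1e-6)
```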

To control which values are saveable without having to edit the definition of the function to be differentiated, you can use a rematerialization *policy*. Here is an example that saves only the results of `dot` operations with no batch dimensions (since they are often FLOP-bound, and hence worth saving rather than recomputing):

```
f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
jax.ad_checkpoint.print_saved_residuals(f3, W1, W2, W3, x)
```

```
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)
f32[6] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)
f32[7] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)
```

You can also use policies to refer to intermediate values you name using `jax.ad_checkpoint.checkpoint_name`:

```
from jax.ad_checkpoint import checkpoint_name

def f4(W1, W2, W3, x):
  x = checkpoint_name(g(W1, x), name='a')
  x = checkpoint_name(g(W2, x), name='b')
  x = checkpoint_name(g(W3, x), name='c')
  return x

f4 = jax.checkpoint(f4, policy=jax.checkpoint_policies.save_only_these_names('a'))
jax.ad_checkpoint.print_saved_residuals(f4, W1, W2, W3, x)
```

```
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] named 'a' from <ipython-input-7-fc0ed1c14b8d>:4 (f4)
```

When playing around with these toy examples, we can get a closer look at what's going on using the `print_fwd_bwd` utility defined in this notebook:

```
from jax.tree_util import tree_flatten, tree_unflatten

from rich.console import Console
from rich.table import Table
import rich.jupyter
import rich.text

def print_fwd_bwd(f, *args, **kwargs) -> None:
  args, in_tree = tree_flatten((args, kwargs))

  def f_(*args):
    args, kwargs = tree_unflatten(in_tree, args)
    return f(*args, **kwargs)

  # Forward pass: the jaxpr of computing the output and the VJP closure.
  fwd = jax.make_jaxpr(lambda *args: jax.vjp(f_, *args))(*args).jaxpr

  # Backward pass: flatten the VJP closure into its residuals, then trace it.
  y, f_vjp = jax.vjp(f_, *args)
  res, res_tree = tree_flatten(f_vjp)

  def g_(*args):
    *res, y = args
    f_vjp = tree_unflatten(res_tree, res)
    return f_vjp(y)

  bwd = jax.make_jaxpr(g_)(*res, y).jaxpr

  table = Table(show_header=False, show_lines=True, padding=(1, 2, 0, 2), box=None)
  table.add_row("[bold green]forward computation:",
                "[bold green]backward computation:")
  table.add_row(rich.text.Text.from_ansi(str(fwd)),
                rich.text.Text.from_ansi(str(bwd)))
  console = Console(width=240, force_jupyter=True)
  console.print(table)

def _renderable_repr(self):
  return self.html
rich.jupyter.JupyterRenderable._repr_html_ = _renderable_repr
```

```
# no use of jax.checkpoint:
print_fwd_bwd(f, W1, W2, W3, x)
```

forward computation:

    { lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let
        e:f32[5] = dot_general[dimension_numbers=(([1], [0]), ([], []))] a d
        f:f32[5] = sin e
        g:f32[5] = cos e
        h:f32[6] = dot_general[dimension_numbers=(([1], [0]), ([], []))] b f
        i:f32[6] = sin h
        j:f32[6] = cos h
        k:f32[7] = dot_general[dimension_numbers=(([1], [0]), ([], []))] c i
        l:f32[7] = sin k
        m:f32[7] = cos k
      in (l, m, i, c, j, f, b, g, d, a) }

backward computation:

    { lambda ; a:f32[7] b:f32[6] c:f32[7,6] d:f32[6] e:f32[5] f:f32[6,5] g:f32[5] h:f32[4]
        i:f32[5,4] j:f32[7]. let
        k:f32[7] = mul j a
        l:f32[6] = dot_general[dimension_numbers=(([0], [0]), ([], []))] k c
        m:f32[7,6] = dot_general[dimension_numbers=(([], []), ([], []))] k b
        n:f32[6] = mul l d
        o:f32[5] = dot_general[dimension_numbers=(([0], [0]), ([], []))] n f
        p:f32[6,5] = dot_general[dimension_numbers=(([], []), ([], []))] n e
        q:f32[5] = mul o g
        r:f32[4] = dot_general[dimension_numbers=(([0], [0]), ([], []))] q i
        s:f32[5,4] = dot_general[dimension_numbers=(([], []), ([], []))] q h
      in (s, p, m, r) }

```
# using jax.checkpoint with policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable:
print_fwd_bwd(f3, W1, W2, W3, x)
```

forward computation:

    { lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let
        e:f32[5] = dot_general[dimension_numbers=(([1], [0]), ([], []))] a d
        f:f32[5] = sin e
        g:f32[6] = dot_general[dimension_numbers=(([1], [0]), ([], []))] b f
        h:f32[6] = sin g
        i:f32[7] = dot_general[dimension_numbers=(([1], [0]), ([], []))] c h
        j:f32[7] = sin i
      in (j, e, g, i, a, b, c, d) }

backward computation:

    { lambda ; a:f32[5] b:f32[6] c:f32[7] d:f32[5,4] e:f32[6,5] f:f32[7,6] g:f32[4] h:f32[7]. let
        i:f32[5,4] j:f32[6,5] k:f32[7,6] l:f32[4] = remat2[
          differentiated=True
          jaxpr={ lambda ; m:f32[5] n:f32[6] o:f32[7] p:f32[5,4] q:f32[6,5] r:f32[7,6]
              s:f32[4] t:f32[7]. let
              u:f32[5] = sin m
              v:f32[5] = cos m
              w:f32[6] = sin n
              x:f32[6] = cos n
              y:f32[7] = cos o
              z:f32[7] = mul t y
              ba:f32[6] = dot_general[dimension_numbers=(([0], [0]), ([], []))] z r
              bb:f32[6] = mul ba x
              bc:f32[5] = dot_general[dimension_numbers=(([0], [0]), ([], []))] bb q
              bd:f32[5] = mul bc v
              be:f32[4] = dot_general[dimension_numbers=(([0], [0]), ([], []))] bd p
              bf:f32[5,4] = dot_general[dimension_numbers=(([], []), ([], []))] bd s
              bg:f32[6,5] = dot_general[dimension_numbers=(([], []), ([], []))] bb u
              bh:f32[7,6] = dot_general[dimension_numbers=(([], []), ([], []))] z w
            in (bf, bg, bh, be) }
          policy=<function dot_with_no_batch_dims at 0x7f5e469b1700>
          prevent_cse=True
        ] a b c d e f g h
      in (i, j, k, l) }

## Let's think step by step

You might want to first (re)read the Autodiff Cookbook Part 1.

### Fundamentals of `jax.checkpoint`

In both `jax.linearize` and `jax.vjp` there is flexibility in how and when some values are computed. Different choices can trade off memory use against FLOPs. JAX provides control over these choices with `jax.checkpoint`.

One such choice is whether to perform Jacobian coefficient computations on the forward pass, as soon as the inputs are available, or on the backward pass, just before the coefficients are needed. Consider the example of `sin_vjp`:

```
def sin_vjp(x):
  y = jnp.sin(x)
  cos_x = jnp.cos(x)
  return y, lambda y_bar: cos_x * y_bar
```

Another valid implementation would compute the value of `jnp.cos(x)` on the backward pass rather than on the forward pass:

```
def sin_vjp2(x):
  y = jnp.sin(x)
  return y, lambda y_bar: jnp.cos(x) * y_bar
```
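To confirm that both rules compute the same cotangent, we can compare them against JAX's own `jax.vjp(jnp.sin, x)` on a small input (the input values here are arbitrary):

```python
import jax
import jax.numpy as jnp

def sin_vjp(x):
  y = jnp.sin(x)
  cos_x = jnp.cos(x)  # Jacobian coefficient computed on the forward pass
  return y, lambda y_bar: cos_x * y_bar

def sin_vjp2(x):
  y = jnp.sin(x)
  # Jacobian coefficient recomputed on the backward pass
  return y, lambda y_bar: jnp.cos(x) * y_bar

x = jnp.array([0.0, 0.5, 1.0])
y_bar = jnp.ones_like(x)

y1, bwd1 = sin_vjp(x)
y2, bwd2 = sin_vjp2(x)
y_ref, bwd_ref = jax.vjp(jnp.sin, x)  # jax.vjp's pullback returns a tuple

assert jnp.allclose(y1, y_ref) and jnp.allclose(y2, y_ref)
assert jnp.allclose(bwd1(y_bar), bwd_ref(y_bar)[0])
assert jnp.allclose(bwd2(y_bar), bwd_ref(y_bar)[0])
```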

For this particular function, the amount of memory used by the two versions is the same, though we've reduced the FLOPs for the primal computation (i.e. the forward pass) and increased the FLOPs for the cotangent computation (i.e. the backward pass).

There's another choice when it comes to function composition. Recall our VJP rule for a composition of two functions:

```
def f(x):
  y = g(x)
  z = h(y)
  return z

def f_vjp(x):
  y, g_vjp = jax.vjp(g, x)
  z, h_vjp = jax.vjp(h, y)
  def f_bwd(z_bar):
    y_bar, = h_vjp(z_bar)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd
```

An alternative is:

```
def f_vjp_checkpoint(x):
  y = g(x)
  z, h_vjp = jax.vjp(h, y)
  def f_bwd2(z_bar):
    y_bar, = h_vjp(z_bar)
    _, g_vjp = jax.vjp(g, x)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd2
```
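To make the two VJP rules concrete, here is a minimal runnable sketch with hypothetical stand-ins `g = jnp.sin` and `h(y) = y ** 2`; any differentiable pair would do. Both rules should produce the same output and the same cotangent, which for this choice is `2 * sin(x) * cos(x)`:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the two stages of f.
def g(x):
  return jnp.sin(x)

def h(y):
  return y ** 2

def f_vjp(x):
  y, g_vjp = jax.vjp(g, x)    # g's residuals kept from the forward pass
  z, h_vjp = jax.vjp(h, y)
  def f_bwd(z_bar):
    y_bar, = h_vjp(z_bar)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd

def f_vjp_checkpoint(x):
  y = g(x)                    # g's residuals are not kept
  z, h_vjp = jax.vjp(h, y)
  def f_bwd2(z_bar):
    y_bar, = h_vjp(z_bar)
    _, g_vjp = jax.vjp(g, x)  # recompute g's residuals from x
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd2

x = jnp.linspace(-1.0, 1.0, 5)
z1, bwd1 = f_vjp(x)
z2, bwd2 = f_vjp_checkpoint(x)
z_bar = jnp.ones_like(z1)

assert jnp.allclose(z1, z2)
assert jnp.allclose(bwd1(z_bar), bwd2(z_bar))
# Analytic check: d/dx sin(x)**2 = 2 sin(x) cos(x)
assert jnp.allclose(bwd1(z_bar), 2 * jnp.sin(x) * jnp.cos(x))
```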

In words, this alternative implementation doesn't compute `g_vjp`, or the residual values in its closure, on the forward pass. Instead it only computes them in the backward pass `f_bwd2`. That means `f_vjp_checkpoint` requires less memory: if `g` and `h` each required similar amounts of memory for their residuals, each much larger than `x`, then the function produced by `f_vjp_checkpoint(x)` requires half the memory of the one produced by `f_vjp(x)`!

The cost we pay is redundant work: in `f_bwd2` we must re-evaluate `g(x)` as part of `jax.vjp(g, x)` just to discard its value (in the underscore variable on the line `_, g_vjp = jax.vjp(g, x)`).

We can get this VJP behavior in autodiff, without having to write VJP functions directly, by instead using `jax.checkpoint` in an alternative definition of the original function `f`:

```
def f_checkpoint(x):
  y = jax.checkpoint(g)(x)
  z = h(y)
  return z
```

In other words, we apply `jax.checkpoint` to `g`, the first stage of `f`, rather than to `f` itself. This way, when we evaluate `jax.grad(f_checkpoint)(x)`, we'd get a computation like:

1. run the forward pass of `g`, discarding residual values;
2. run the forward pass of `h`, saving residuals;
3. run the backward pass of `h`, consuming residuals from step 2;
4. re-run the forward pass of `g`, saving residuals;
5. run the backward pass of `g`, consuming residuals from step 4.

That is, by evaluating `jax.grad(f_checkpoint)(x)` we'd get the same computation as:

```
def f_checkpoint_grad(x):
  y = g(x)                  # step 1
  _, h_vjp = jax.vjp(h, y)  # step 2
  y_bar, = h_vjp(1.0)       # step 3
  _, g_vjp = jax.vjp(g, x)  # step 4
  x_bar, = g_vjp(y_bar)     # step 5
  return x_bar
```
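The five-step schedule can be checked numerically against `jax.grad(f_checkpoint)`. A minimal sketch, assuming hypothetical stages `g = jnp.sin` and `h` a sum of squares (so the output is a scalar, as `jax.grad` requires); we seed the backward pass with `jnp.ones_like(z)` rather than a bare `1.0` so the cotangent's shape and dtype match `z` exactly:

```python
import jax
import jax.numpy as jnp

# Hypothetical stages: g is elementwise, h reduces to a scalar.
def g(x):
  return jnp.sin(x)

def h(y):
  return jnp.sum(y ** 2)

def f_checkpoint(x):
  y = jax.checkpoint(g)(x)
  z = h(y)
  return z

def f_checkpoint_grad(x):
  y = g(x)                          # step 1: forward g, residuals discarded
  z, h_vjp = jax.vjp(h, y)          # step 2: forward h, residuals saved
  y_bar, = h_vjp(jnp.ones_like(z))  # step 3: backward h
  _, g_vjp = jax.vjp(g, x)          # step 4: re-run forward g, residuals saved
  x_bar, = g_vjp(y_bar)             # step 5: backward g
  return x_bar

x = jnp.linspace(-1.0, 1.0, 5)
assert jnp.allclose(jax.grad(f_checkpoint)(x), f_checkpoint_grad(x))
```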

In general, `jax.checkpoint(foo)` is a new function which has the same input-output behavior as `foo`, but behaves differently under autodiff, particularly under `jax.linearize` and `jax.vjp` (and their wrappers, like `jax.grad`) but not `jax.jvp`. When differentiated, only the input to a `jax.checkpoint`-decorated function is stored on the forward pass; on the backward pass, residuals (i.e. intermediates from `foo` and its Jacobian coefficient values needed for the backward pass) are recomputed.
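A quick way to see that `jax.checkpoint(foo)` computes the same values as `foo`, in both forward and reverse mode, using a hypothetical `foo`:

```python
import jax
import jax.numpy as jnp

def foo(x):
  return jnp.sin(jnp.cos(x))  # hypothetical example function

x = jnp.array(0.7)
t = jnp.array(1.0)

# Same input-output behavior.
assert jnp.allclose(jax.checkpoint(foo)(x), foo(x))

# Same forward-mode result: jax.jvp is unaffected by jax.checkpoint.
y1, dy1 = jax.jvp(foo, (x,), (t,))
y2, dy2 = jax.jvp(jax.checkpoint(foo), (x,), (t,))
assert jnp.allclose(y1, y2) and jnp.allclose(dy1, dy2)

# Reverse mode gives the same numbers too; only what gets stored differs.
assert jnp.allclose(jax.grad(foo)(x), jax.grad(jax.checkpoint(foo))(x))
```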

Notice that if `f = lambda x: h(g(x))` is the function we want to differentiate, i.e. if we want to apply `jax.grad(f)`, we don't get any memory savings by applying `jax.checkpoint` to `f` itself. That's because evaluating `jax.grad(jax.checkpoint(f))(x)` would lead to a computation like:

1. run the forward pass, discarding all residuals;
2. immediately re-run the forward pass, saving residuals;
3. run the backward pass, consuming residuals from step 2.

That is, in code we'd have something like:

```
def f_grad_bad(x):
  _ = f(x)                  # step 1
  _, f_vjp = jax.vjp(f, x)  # step 2
  x_bar, = f_vjp(1.0)       # step 3
  return x_bar
```

We also wouldn't get any memory savings by applying `jax.checkpoint` to `h`, the second stage of `f`. That's because evaluating `jax.grad(lambda x: jax.checkpoint(h)(g(x)))` would lead to a computation like:

1. run the forward pass of `g`, saving residuals;
2. run the forward pass of `h`, discarding residuals;
3. immediately re-run the forward pass of `h`, saving residuals;
4. run the backward pass of `h`, consuming residuals from step 3;
5. run the backward pass of `g`, consuming residuals from step 1.

That is, in code we'd have something like:

```
def f_grad_bad2(x):
  y, g_vjp = jax.vjp(g, x)  # step 1
  z = h(y)                  # step 2
  _, h_vjp = jax.vjp(h, y)  # step 3
  y_bar, = h_vjp(1.0)       # step 4
  x_bar, = g_vjp(y_bar)     # step 5
  return x_bar
```

Slightly more generally, if we had a chain composition of functions, like `f = lambda x: f3(f2(f1(x)))`, and we were interested in evaluating `jax.grad(f)`, we could say that:

- we shouldn't apply `jax.checkpoint` to the whole function `f`, since that wouldn't save any memory (and will perform wasteful recomputation);
- we shouldn't apply `jax.checkpoint` to the last sub-function `f3`, since that wouldn't save any memory (and will perform wasteful recomputation);
- we could apply `jax.checkpoint` to `f1`, `f2`, or their composition `lambda x: f2(f1(x))`, since any of those might save memory and would express different memory/recompute tradeoffs.
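As an illustration of the last option, here is a minimal sketch that checkpoints the composition of the first two stages of a three-stage chain. The stages `f1 = jnp.sin`, `f2 = jnp.cos` and `f3 = jnp.tanh` are hypothetical choices; the two `print_saved_residuals` calls let you compare what each version stores, and the gradient is unchanged:

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import print_saved_residuals

# Hypothetical stages of the chain f(x) = f3(f2(f1(x))).
f1 = jnp.sin
f2 = jnp.cos
f3 = jnp.tanh

def f(x):
  return f3(f2(f1(x)))

def f_ckpt_12(x):
  # Checkpoint only the composition of the first two stages.
  y = jax.checkpoint(lambda x: f2(f1(x)))(x)
  return f3(y)

x = jnp.float32(0.5)
print_saved_residuals(f, x)          # residuals saved without checkpointing
print_saved_residuals(f_ckpt_12, x)  # residuals saved with f2(f1(x)) checkpointed
assert jnp.allclose(jax.grad(f)(x), jax.grad(f_ckpt_12)(x))
```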

### Custom policies for what's saveable

As shown so far, using `jax.checkpoint` switches from one extreme to another:

- without `jax.checkpoint`, JAX's autodiff tends to compute everything possible on the forward pass and store it for the backward pass;
- with a `jax.checkpoint` decorator, we instead compute as little as possible on the forward pass and recompute values as needed on the backward pass.

To operate between these two extremes, saving some things and not others, we can carefully place `jax.checkpoint` decorators on sub-functions. But that requires editing the function to be differentiated, e.g. model code, which may be inconvenient. It can also be hard to experiment with variations.

So an alternative is to use the `policy` argument to `jax.checkpoint`. A policy is a callable (i.e. a function) which takes as input a type-level specification of a first-order primitive application and returns a boolean indicating whether the corresponding output value(s) are allowed to be saved as residuals (or instead must be recomputed in the (co)tangent computation as needed). To write robust code, a policy should be selected from the attributes on `jax.checkpoint_policies`, like `jax.checkpoint_policies.dots_with_no_batch_dims_saveable`, since the API for writing custom policy callables is considered internal.

For example, consider this function to be differentiated:

```
def loss(params, x, y):
  return jnp.sum((predict(params, x) - y)**2)

def predict(params, x):
  *Ws, Wlast = params
  for W in Ws:
    x = layer(W, x)
  x = jnp.dot(Wlast, x)
  return x

def layer(W, x):
  return jnp.sin(jnp.dot(W, x))
```

```
W1 = W2 = W3 = jnp.ones((4, 4))
params = [W1, W2, W3]
x = jnp.ones(4)
y = jnp.ones(4)
```

```
print_saved_residuals(loss, params, x, y)
```

```
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] output of sin from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of sin from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of mul from <ipython-input-18-3808b5023c3d>:2 (loss)
```

Instead of saving so many values on the forward pass, perhaps we only want to save the results of matrix multiplications with no batch dimension (since they may be FLOP- rather than memory-bound). We can do that using the policy `jax.checkpoint_policies.dots_with_no_batch_dims_saveable`:

```
loss_checkpoint = jax.checkpoint(loss, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
print_saved_residuals(loss_checkpoint, params, x, y)
```

```
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] from the argument 'y'
f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:8 (predict)
```

Notice also that by providing a policy, we didn't need to edit the code defining `loss`, `predict`, or `layer`. That is particularly convenient if we want to experiment with policies in calling code (e.g. a training script) without changing library code (e.g. the neural network library).

Some policies can refer to values named with `jax.ad_checkpoint.checkpoint_name`:

```
from jax.ad_checkpoint import checkpoint_name

def predict(params, x):
  *Ws, Wlast = params
  for i, W in enumerate(Ws):
    x = layer(W, x)
    x = checkpoint_name(x, name=f'layer{i}_output')
  x = jnp.dot(Wlast, x)
  return x
```

By itself, `checkpoint_name` is just an identity function. But because some policy functions know to look for them, we can use the names to control whether certain values output by `checkpoint_name` are considered saveable:

```
print_saved_residuals(loss, params, x, y)
```

```
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] named 'layer0_output' from <ipython-input-22-e48aedf368ad>:7 (predict)
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] named 'layer1_output' from <ipython-input-22-e48aedf368ad>:7 (predict)
f32[4] output of mul from <ipython-input-18-3808b5023c3d>:2 (loss)
```

```
loss_checkpoint2 = jax.checkpoint(loss, policy=jax.checkpoint_policies.save_any_names_but_these('layer1_output'))
print_saved_residuals(loss_checkpoint2, params, x, y)
```

```
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] from the argument 'y'
```
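Outside of a policy, `checkpoint_name` really is just the identity, for values and for derivatives. A small check on an arbitrary array (the names `'my_value'` and `'sq'` are arbitrary labels chosen for illustration):

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

x = jnp.arange(3.0)

# The value passes through unchanged.
assert jnp.array_equal(checkpoint_name(x, name='my_value'), x)

# Differentiation sees straight through the name: d/dv sum(v**2) = 2v.
grad_x = jax.grad(lambda v: jnp.sum(checkpoint_name(v ** 2, name='sq')))(x)
assert jnp.allclose(grad_x, 2 * x)
```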

Another policy which refers to names is `jax.checkpoint_policies.save_only_these_names`.

Some of the policies are:

- `everything_saveable` (the default strategy, as if `jax.checkpoint` were not being used at all)
- `nothing_saveable` (i.e. rematerialize everything, as if a custom policy were not being used at all)
- `dots_saveable` or its alias `checkpoint_dots`
- `dots_with_no_batch_dims_saveable` or its alias `checkpoint_dots_with_no_batch_dims`
- `save_anything_but_these_names` (save any values except for the output of `checkpoint_name` with any of the names given)
- `save_any_names_but_these` (save only named values, i.e. any outputs of `checkpoint_name`, except for those with the names given)
- `save_only_these_names` (save only named values, and only among the names given)

Policies only indicate what is saveable; a value is only saved if it's actually needed by the backward pass.
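Because policies only affect what is stored versus recomputed, switching policies should never change the gradient. A minimal check over several of the policies listed above, repeating the `loss`, `predict` (with named layer outputs) and `layer` definitions from earlier so the snippet runs on its own; the parameter values are the same all-ones toy inputs:

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def layer(W, x):
  return jnp.sin(jnp.dot(W, x))

def predict(params, x):
  *Ws, Wlast = params
  for i, W in enumerate(Ws):
    x = layer(W, x)
    x = checkpoint_name(x, name=f'layer{i}_output')
  return jnp.dot(Wlast, x)

def loss(params, x, y):
  return jnp.sum((predict(params, x) - y) ** 2)

params = [jnp.ones((4, 4))] * 3
x = jnp.ones(4)
y = jnp.ones(4)

cp = jax.checkpoint_policies
policies = [
    cp.everything_saveable,
    cp.nothing_saveable,
    cp.dots_saveable,
    cp.dots_with_no_batch_dims_saveable,
    cp.save_any_names_but_these('layer1_output'),
    cp.save_only_these_names('layer0_output'),
]

reference = jax.grad(loss)(params, x, y)
for policy in policies:
  grads = jax.grad(jax.checkpoint(loss, policy=policy))(params, x, y)
  # Policies change memory use and recomputation, not the gradient values.
  for g_ref, g in zip(reference, grads):
    assert jnp.allclose(g_ref, g, rtol=1e-5, atol=1e-5)
```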

### Advanced: recursive `jax.checkpoint`

By applying `jax.checkpoint` in the right way, there are many tradeoffs between memory usage and (re)computation that can be expressed. One surprising example is *recursive* checkpointing, where we apply `jax.checkpoint` to a function which itself calls `jax.checkpoint`-decorated functions in a way so that memory usage from the chain composition of \(D\) functions scales like \(\mathcal{O}(\log_2 D)\) rather than \(\mathcal{O}(D)\).

As a toy example, consider the chain composition of multiple `jnp.sin` functions:

```
def chain_compose(funs):
  def f(x):
    for fun in funs:
      x = fun(x)
    return x
  return f

f = chain_compose([jnp.sin] * 8)
print_saved_residuals(f, 3.)
```

```
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
```

In general, the number of stored residuals scales linearly with the length of the chain:

```
f = chain_compose([jnp.sin] * 16)
print_saved_residuals(f, 3.)
```

```
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
```

But we can apply `jax.checkpoint` recursively to improve the scaling:

```
def recursive_checkpoint(funs):
  if len(funs) == 1:
    return funs[0]
  elif len(funs) == 2:
    f1, f2 = funs
    return lambda x: f1(f2(x))
  else:
    f1 = recursive_checkpoint(funs[:len(funs)//2])
    f2 = recursive_checkpoint(funs[len(funs)//2:])
    return lambda x: f1(jax.checkpoint(f2)(x))
```

```
f = recursive_checkpoint([jnp.sin] * 8)
print_saved_residuals(f, 3.)
```

```
f32[] from the argument 'x'
f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
```

```
f = recursive_checkpoint([jnp.sin] * 16)
print_saved_residuals(f, 3.)
```

```
f32[] from the argument 'x'
f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
```

The cost here, as usual, is recomputation: in particular, we end up performing \(\mathcal{O}(\log_2 D)\) times as many FLOPs:

```
f = chain_compose([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
```

forward computation:

    { lambda ; a:f32[]. let
        b:f32[] = sin a
        c:f32[] = cos a
        d:f32[] = sin b
        e:f32[] = cos b
        f:f32[] = sin d
        g:f32[] = cos d
        h:f32[] = sin f
        i:f32[] = cos f
        j:f32[] = sin h
        k:f32[] = cos h
        l:f32[] = sin j
        m:f32[] = cos j
        n:f32[] = sin l
        o:f32[] = cos l
        p:f32[] = sin n
        q:f32[] = cos n
      in (p, q, o, m, k, i, g, e, c) }

backward computation:

    { lambda ; a:f32[] b:f32[] c:f32[] d:f32[] e:f32[] f:f32[] g:f32[] h:f32[] i:f32[]. let
        j:f32[] = mul i a
        k:f32[] = mul j b
        l:f32[] = mul k c
        m:f32[] = mul l d
        n:f32[] = mul m e
        o:f32[] = mul n f
        p:f32[] = mul o g
        q:f32[] = mul p h
      in (q,) }

```
f = recursive_checkpoint([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
```

forward computation:

    { lambda ; a:f32[]. let
        b:f32[] = remat2[
          differentiated=False
          jaxpr={ lambda ; c:f32[]. let d:f32[] = sin c; e:f32[] = sin d in (e,) }
          policy=None
          prevent_cse=True
        ] a
        f:f32[] = sin b
        g:f32[] = sin f
        h:f32[] = sin g
        i:f32[] = sin h
        j:f32[] = sin i
        k:f32[] = cos i
        l:f32[] = sin j
        m:f32[] = cos j
      in (l, m, k, g, a) }

backward computation:

    { lambda ; a:f32[] b:f32[] c:f32[] d:f32[]. let
        e:f32[] = mul d a
        f:f32[] = mul e b
        g:f32[] = remat2[
          differentiated=True
          jaxpr={ lambda ; h:f32[] i:f32[]. let
              j:f32[] = sin h
              k:f32[] = cos h
              l:f32[] = cos j
              m:f32[] = mul i l
              n:f32[] = mul m k
            in (n,) }
          policy=None
          prevent_cse=True
        ] c f
        o:f32[] = remat2[
          differentiated=True
          jaxpr={ lambda ; p:f32[] q:f32[]. let
              r:f32[] = sin p
              s:f32[] = sin r
              t:f32[] = sin s
              u:f32[] = cos s
              v:f32[] = cos t
              w:f32[] = mul q v
              x:f32[] = mul w u
              y:f32[] = remat2[
                differentiated=True
                jaxpr={ lambda ; z:f32[] ba:f32[]. let
                    bb:f32[] = sin z
                    bc:f32[] = cos z
                    bd:f32[] = cos bb
                    be:f32[] = mul ba bd
                    bf:f32[] = mul be bc
                  in (bf,) }
                policy=None
                prevent_cse=True
              ] p x
            in (y,) }
          policy=None
          prevent_cse=True
        ] 3.0 g
      in (o,) }
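The extra recomputation does not change the result. A quick check, repeating `chain_compose` and `recursive_checkpoint` from above so the snippet is self-contained, that both give the same gradient at a few chain lengths. Note that `recursive_checkpoint` applies the second half of the list first, which makes no difference here because every function is `jnp.sin`:

```python
import jax
import jax.numpy as jnp

def chain_compose(funs):
  def f(x):
    for fun in funs:
      x = fun(x)
    return x
  return f

def recursive_checkpoint(funs):
  if len(funs) == 1:
    return funs[0]
  elif len(funs) == 2:
    f1, f2 = funs
    return lambda x: f1(f2(x))
  else:
    f1 = recursive_checkpoint(funs[:len(funs) // 2])
    f2 = recursive_checkpoint(funs[len(funs) // 2:])
    return lambda x: f1(jax.checkpoint(f2)(x))

for depth in (2, 8, 16):
  funs = [jnp.sin] * depth
  plain = jax.grad(chain_compose(funs))(3.0)
  recursive = jax.grad(recursive_checkpoint(funs))(3.0)
  # Same derivative; only memory use and FLOP count differ.
  assert jnp.allclose(plain, recursive, rtol=1e-5, atol=1e-6)
```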

## Practical notes

When differentiated functions are staged out to XLA for compilation, for example by applying `jax.jit` to a function which contains a `jax.grad` call, XLA will automatically optimize the computation, including decisions about when to compute or rematerialize values. As a result, **`jax.checkpoint` often isn't needed for differentiated functions under a `jax.jit`**. XLA will optimize things for you.

One exception is when using staged-out control flow, like `jax.lax.scan`. Automatic compiler optimizations across multiple control flow primitives, e.g. across a forward-pass `scan` and the corresponding backward-pass `scan`, typically aren't as thorough. As a result, it's often a good idea to use `jax.checkpoint` on the body function passed to `jax.lax.scan`.

For example, one common pattern in large Transformer models is to express the architecture as a `jax.lax.scan` over layers so as to reduce compilation times. That is, using a simple fully-connected network as an analogy, instead of writing something like this:

```
from typing import Tuple, List

LayerParam = Tuple[jnp.ndarray, jnp.ndarray]  # weights, bias pair for a layer
ParamsList = List[LayerParam]

def net(params: ParamsList, x: jnp.ndarray):
  for W, b in params:
    x = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return x
```

We would instead iterate over the layer application with `jax.lax.scan`:

```
StackedWeights = jnp.ndarray  # all weight matrices stacked together
StackedBiases = jnp.ndarray  # all bias vectors stacked together

# `params` is a ParamsList of (W, b) pairs, as in the loop version above
all_weights = jnp.stack([W for W, _ in params])
all_biases = jnp.stack([b for _, b in params])

def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None

def net(all_weights, all_biases, x):
  x, _ = jax.lax.scan(layer, x, (all_weights, all_biases))
  return x
```
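To check that the scan-over-layers version computes the same thing as the Python loop, here is a small sketch with hypothetical sizes (3 layers of width 8) and random weights. Stacking with `jnp.stack` requires every layer to have the same shapes, which is also what `jax.lax.scan` needs:

```python
import jax
import jax.numpy as jnp

def net_loop(params, x):
  for W, b in params:
    x = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return x

def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None

def net_scan(all_weights, all_biases, x):
  x, _ = jax.lax.scan(layer, x, (all_weights, all_biases))
  return x

# Hypothetical sizes and initialization, for illustration only.
keys = jax.random.split(jax.random.PRNGKey(0), 4)
params = [(jax.random.normal(k, (8, 8)) * 0.3, jnp.full((8,), 0.1)) for k in keys[:3]]
x = jax.random.normal(keys[3], (8,))

all_weights = jnp.stack([W for W, _ in params])  # shape (3, 8, 8)
all_biases = jnp.stack([b for _, b in params])   # shape (3, 8)

assert all_weights.shape == (3, 8, 8)
assert jnp.allclose(net_loop(params, x), net_scan(all_weights, all_biases, x), atol=1e-5)
```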

This scan-over-layers version reduces compile times, but by foiling some compiler optimizations it can lead to inefficient computation of gradients. To mitigate the issue, we would use `jax.checkpoint` on the scanned function:

```
from functools import partial

@partial(jax.checkpoint,
         policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None
```

By using `jax.checkpoint` this way, we're manually controlling which values JAX's autodiff saves between the forward and backward passes, and hence not relying on XLA optimizations to choose for us.
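Putting the pieces together, here is a minimal end-to-end sketch: a jitted gradient of a scan-over-layers loss, once with the plain body and once with the checkpointed body. The sizes, random initialization and sum-of-squares loss are hypothetical choices for illustration; the point is that the gradients agree while the checkpointed version controls what is stored:

```python
import jax
import jax.numpy as jnp

def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None

# The same body, with saved residuals restricted by a policy.
layer_ckpt = jax.checkpoint(
    layer, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)

def make_loss(body):
  def loss(all_weights, all_biases, x):
    out, _ = jax.lax.scan(body, x, (all_weights, all_biases))
    return jnp.sum(out ** 2)  # hypothetical scalar loss
  return loss

# Hypothetical sizes: 4 stacked layers of width 16.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
all_weights = jax.random.normal(k1, (4, 16, 16)) * 0.2
all_biases = jnp.zeros((4, 16))
x = jax.random.normal(k2, (16,))

grad_plain = jax.jit(jax.grad(make_loss(layer), argnums=(0, 1)))
grad_ckpt = jax.jit(jax.grad(make_loss(layer_ckpt), argnums=(0, 1)))

gW, gb = grad_plain(all_weights, all_biases, x)
gW_c, gb_c = grad_ckpt(all_weights, all_biases, x)

assert gW_c.shape == all_weights.shape and gb_c.shape == all_biases.shape
assert jnp.allclose(gW, gW_c, rtol=1e-4, atol=1e-4)
assert jnp.allclose(gb, gb_c, rtol=1e-4, atol=1e-4)
```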