Gradient checkpointing with jax.checkpoint (jax.remat)#
In this tutorial, you will learn how to control JAX automatic differentiation's saved values using jax.checkpoint() (also known as jax.remat()), which can be particularly helpful in machine learning.
If you are new to automatic differentiation (autodiff) or need to refresh your memory, JAX has Automatic differentiation and Advanced automatic differentiation tutorials.
TL;DR Use the jax.checkpoint() decorator (aliased as jax.remat()) with jax.grad() to control which intermediates are saved on the forward pass and which are recomputed on the backward pass, trading off memory and FLOPs.
If you don't use jax.checkpoint(), the jax.grad(f)(x) forward pass stores Jacobian coefficients and other intermediates to use during the backward pass. These saved values are called residuals.
Note: Don't miss the Practical notes for a discussion about how jax.checkpoint() interacts with jax.jit().
import jax
import jax.numpy as jnp
def g(W, x):
  y = jnp.dot(W, x)
  return jnp.sin(y)

def f(W1, W2, W3, x):
  x = g(W1, x)
  x = g(W2, x)
  x = g(W3, x)
  return x
W1 = jnp.ones((5, 4))
W2 = jnp.ones((6, 5))
W3 = jnp.ones((7, 6))
x = jnp.ones(4)
# Inspect the 'residual' values to be saved on the forward pass
# if you were to evaluate `jax.grad(f)(W1, W2, W3, x)`
from jax.ad_checkpoint import print_saved_residuals
jax.ad_checkpoint.print_saved_residuals(f, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of sin from /tmp/ipykernel_1130/1801108376.py:6 (g)
f32[5] output of cos from /tmp/ipykernel_1130/1801108376.py:6 (g)
f32[6] output of sin from /tmp/ipykernel_1130/1801108376.py:6 (g)
f32[6] output of cos from /tmp/ipykernel_1130/1801108376.py:6 (g)
f32[7] output of cos from /tmp/ipykernel_1130/1801108376.py:6 (g)
By applying jax.checkpoint() to sub-functions, as a decorator or at specific application sites, you force JAX not to save any of that sub-function's residuals. Instead, only the inputs of a jax.checkpoint()-decorated function might be saved, and any residuals consumed on the backward pass are re-computed from those inputs as needed:
def f2(W1, W2, W3, x):
  x = jax.checkpoint(g)(W1, x)
  x = jax.checkpoint(g)(W2, x)
  x = jax.checkpoint(g)(W3, x)
  return x
jax.ad_checkpoint.print_saved_residuals(f2, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of sin from /tmp/ipykernel_1130/1801108376.py:6 (g)
f32[6] output of sin from /tmp/ipykernel_1130/1801108376.py:6 (g)
Here, the values of two sin applications are saved because they are arguments in subsequent applications of the jax.checkpoint()-decorated g function, and inputs to a jax.checkpoint()-decorated function may be saved. But no values of cos applications are saved.
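Note that jax.checkpoint() changes only what is saved and recomputed, not the result. As a quick sanity check (a minimal sketch, not part of the original tutorial; total_f and total_f2 are helper names introduced here), you can compare gradients of f and f2 through a scalar-valued wrapper:
# Sanity check: checkpointing trades memory for recomputation, but the
# gradients of the checkpointed f2 match those of the original f.
def total_f(W1, W2, W3, x):
  return jnp.sum(f(W1, W2, W3, x))

def total_f2(W1, W2, W3, x):
  return jnp.sum(f2(W1, W2, W3, x))

grads_f = jax.grad(total_f, argnums=(0, 1, 2, 3))(W1, W2, W3, x)
grads_f2 = jax.grad(total_f2, argnums=(0, 1, 2, 3))(W1, W2, W3, x)
assert all(jnp.allclose(a, b) for a, b in zip(grads_f, grads_f2))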
To control which values are saveable without having to edit the definition of the function to be differentiated, you can use a rematerialization policy. Here is an example that saves only the results of dot operations with no batch dimensions (since they are often FLOP-bound, and hence worth saving rather than recomputing):
f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
jax.ad_checkpoint.print_saved_residuals(f3, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of reduce_precision from /tmp/ipykernel_1130/1801108376.py:5 (g)
f32[6] output of reduce_precision from /tmp/ipykernel_1130/1801108376.py:5 (g)
f32[7] output of reduce_precision from /tmp/ipykernel_1130/1801108376.py:5 (g)
You can also use policies to refer to intermediate values you name using jax.ad_checkpoint.checkpoint_name():
from jax.ad_checkpoint import checkpoint_name
def f4(W1, W2, W3, x):
  x = checkpoint_name(g(W1, x), name='a')
  x = checkpoint_name(g(W2, x), name='b')
  x = checkpoint_name(g(W3, x), name='c')
  return x
f4 = jax.checkpoint(f4, policy=jax.checkpoint_policies.save_only_these_names('a'))
jax.ad_checkpoint.print_saved_residuals(f4, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of reduce_precision from /tmp/ipykernel_1130/2296542172.py:4 (f4)
When playing around with these toy examples, you can get a closer look at what's going on using a custom print_fwd_bwd utility defined in this notebook:
from jax.tree_util import tree_flatten, tree_unflatten
from rich.console import Console
from rich.table import Table
import rich.text
def print_fwd_bwd(f, *args, **kwargs) -> None:
  args, in_tree = tree_flatten((args, kwargs))

  def f_(*args):
    args, kwargs = tree_unflatten(in_tree, args)
    return f(*args, **kwargs)

  fwd = jax.make_jaxpr(lambda *args: jax.vjp(f_, *args))(*args).jaxpr

  y, f_vjp = jax.vjp(f_, *args)
  res, in_tree = tree_flatten(f_vjp)

  def g_(*args):
    *res, y = args
    f_vjp = tree_unflatten(in_tree, res)
    return f_vjp(y)

  bwd = jax.make_jaxpr(g_)(*res, y).jaxpr

  table = Table(show_header=False, show_lines=True, padding=(1, 2, 0, 2), box=None)
  table.add_row("[bold green]forward computation:",
                "[bold green]backward computation:")
  table.add_row(rich.text.Text.from_ansi(str(fwd)),
                rich.text.Text.from_ansi(str(bwd)))
  console = Console(width=240, force_jupyter=True)
  console.print(table)

def _renderable_repr(self):
  return self.html
rich.jupyter.JupyterRenderable._repr_html_ = _renderable_repr
# Without using `jax.checkpoint`:
print_fwd_bwd(f, W1, W2, W3, x)
[Output: the forward and backward jaxprs, printed side by side. Without jax.checkpoint, the forward pass computes and saves the cos residuals of each layer, and the backward pass consumes them directly.]
# Using `jax.checkpoint` with policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable:
print_fwd_bwd(f3, W1, W2, W3, x)
[Output: the forward and backward jaxprs, printed side by side. With the dots_with_no_batch_dims_saveable policy, the forward pass saves only the dot_general results, and the backward pass recomputes the sin/cos values inside a remat2 call.]
Let’s think step by step#
Note: It may help to check out the Advanced automatic differentiation tutorial prior to continuing here.
jax.checkpoint fundamentals#
In both jax.linearize() and jax.vjp(), there is flexibility in how and when some values are computed. Different choices can trade off memory use against FLOPs. JAX provides control over these choices with jax.checkpoint().
One such choice is whether to perform Jacobian coefficient computations on the forward pass, as soon as the inputs are available, or on the backward pass, just before the coefficients are needed. Consider the example of sin_vjp:
def sin_vjp(x):
  y = jnp.sin(x)
  cos_x = jnp.cos(x)
  return y, lambda y_bar: cos_x * y_bar
Another valid implementation would compute the value of jnp.cos(x) on the backward pass rather than on the forward pass:
def sin_vjp2(x):
  y = jnp.sin(x)
  return y, lambda y_bar: jnp.cos(x) * y_bar
For this particular function, the amount of memory used by the two versions is the same, though you’ve reduced the FLOPs for the primal computation (the forward pass) and increased the FLOPs for the cotangent computation (the backward pass).
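Both rules compute the same cotangent; they differ only in when jnp.cos(x) is evaluated. A quick check against jax.vjp() (a minimal sketch, not from the original tutorial):
# Both hand-written rules agree with the VJP that JAX derives for jnp.sin.
x_val = 3.0
_, bwd1 = sin_vjp(x_val)
_, bwd2 = sin_vjp2(x_val)
_, bwd3 = jax.vjp(jnp.sin, x_val)
assert jnp.allclose(bwd1(1.0), bwd2(1.0))
assert jnp.allclose(bwd1(1.0), bwd3(1.0)[0])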
There’s another choice when it comes to function composition. Recall the VJP rule for a composition of two functions:
def f(x):
  y = g(x)
  z = h(y)
  return z

def f_vjp(x):
  y, g_vjp = jax.vjp(g, x)
  z, h_vjp = jax.vjp(h, y)
  def f_bwd(z_bar):
    y_bar, = h_vjp(z_bar)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd
An alternative is:
def f_vjp_checkpoint(x):
  y = g(x)
  z, h_vjp = jax.vjp(h, y)
  def f_bwd2(z_bar):
    y_bar, = h_vjp(z_bar)
    _, g_vjp = jax.vjp(g, x)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd2
In words, this alternative implementation doesn't compute g_vjp, or the residual values in its closure, on the forward pass. Instead, it only computes them in the backward pass f_bwd2. That means f_vjp_checkpoint requires less memory: if g and h each required similar amounts of memory for their residuals, each much larger than x, then the function produced by f_vjp_checkpoint(x) requires half the memory of f_vjp(x)!
The cost you pay is redundant work: in f_bwd2 you must re-evaluate g(x) as part of jax.vjp(g, x) just to discard its value (in the underscore variable on the line _, g_vjp = jax.vjp(g, x)).
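To make the trade concrete, here is a minimal sketch (not part of the original tutorial) that instantiates the same pattern with hypothetical stand-ins, jnp.sin for g and jnp.exp for h, and checks the result against jax.grad():
# The f_vjp_checkpoint pattern, parameterized so it does not depend on the
# module-level g and h defined earlier in this notebook.
def make_vjp_checkpoint(g, h):
  def f_vjp_checkpoint(x):
    y = g(x)
    z, h_vjp = jax.vjp(h, y)
    def f_bwd2(z_bar):
      y_bar, = h_vjp(z_bar)
      _, g_vjp = jax.vjp(g, x)  # re-evaluates g(x) on the backward pass
      x_bar, = g_vjp(y_bar)
      return x_bar
    return z, f_bwd2
  return f_vjp_checkpoint

x_val = jnp.float32(0.7)
z, f_bwd2 = make_vjp_checkpoint(jnp.sin, jnp.exp)(x_val)
x_bar = f_bwd2(jnp.float32(1.0))
assert jnp.allclose(x_bar, jax.grad(lambda x: jnp.exp(jnp.sin(x)))(x_val))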
You can get this VJP behavior in autodiff, without having to write VJP functions directly, by instead using jax.checkpoint() in an alternative definition of the original function f:
def f_checkpoint(x):
  y = jax.checkpoint(g)(x)
  z = h(y)
  return z
In other words, you apply jax.checkpoint() to g — the first stage of f — rather than to f itself. This way, when you evaluate jax.grad(f_checkpoint)(x), you'd get a computation like:
1. Run the forward pass of g, discarding residual values.
2. Run the forward pass of h, saving residuals.
3. Run the backward pass of h, consuming residuals from step 2.
4. Re-run the forward pass of g, saving residuals.
5. Run the backward pass of g, consuming residuals from step 4.
That is, by evaluating jax.grad(f_checkpoint)(x) we'd get the same computation as:
def f_checkpoint_grad(x):
  y = g(x)                  # step 1
  _, h_vjp = jax.vjp(h, y)  # step 2
  y_bar, = h_vjp(1.0)       # step 3
  _, g_vjp = jax.vjp(g, x)  # step 4
  x_bar, = g_vjp(y_bar)     # step 5
  return x_bar
In general, jax.checkpoint(foo) is a new function which has the same input-output behavior as foo, but behaves differently under autodiff, particularly under jax.linearize() and jax.vjp() (and their wrappers, like jax.grad()), but not under jax.jvp(). When differentiated, only the input to a jax.checkpoint()-decorated function is stored on the forward pass. On the backward pass, the residuals (intermediates from foo and its Jacobian coefficient values needed for the backward pass) are recomputed.
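For example, with a hypothetical foo (a minimal sketch, not from the original tutorial):
# jax.checkpoint(foo) matches foo in both values and gradients; only the
# save/recompute schedule under autodiff changes.
def foo(x):
  return jnp.sin(jnp.sin(x))

foo_remat = jax.checkpoint(foo)
x_val = 1.5
assert jnp.allclose(foo(x_val), foo_remat(x_val))
assert jnp.allclose(jax.grad(foo)(x_val), jax.grad(foo_remat)(x_val))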
Notice that if f = lambda x: h(g(x)) is the function you want to differentiate (in other words, if you want to apply jax.grad(f)), you don't get any memory savings by applying jax.checkpoint() to f itself. That's because evaluating jax.grad(jax.checkpoint(f))(x) would lead to a computation such as:
1. Run the forward pass, discarding all residuals.
2. Immediately re-run the forward pass, saving residuals.
3. Run the backward pass, consuming residuals from step 2.
In code, you’d have something like:
def f_grad_bad(x):
  _ = f(x)                  # step 1
  _, f_vjp = jax.vjp(f, x)  # step 2
  x_bar, = f_vjp(1.0)       # step 3
  return x_bar
You also wouldn't get any memory savings by applying jax.checkpoint() to h, the second stage of f. That's because evaluating jax.grad(lambda x: jax.checkpoint(h)(g(x))) would lead to a computation such as:
1. Run the forward pass of g, saving residuals.
2. Run the forward pass of h, discarding residuals.
3. Immediately re-run the forward pass of h, saving residuals.
4. Run the backward pass of h, consuming residuals from step 3.
5. Run the backward pass of g, consuming residuals from step 1.
In code you’d have something like:
def f_grad_bad2(x):
  y, g_vjp = jax.vjp(g, x)  # step 1
  z = h(y)                  # step 2
  _, h_vjp = jax.vjp(h, y)  # step 3
  y_bar, = h_vjp(1.0)       # step 4
  x_bar, = g_vjp(y_bar)     # step 5
  return x_bar
Slightly more generally, if you had a chain composition of functions, such as f = lambda x: f3(f2(f1(x))), and were interested in evaluating jax.grad(f), you could say that you:
- Shouldn't apply jax.checkpoint() to the whole function f, since that wouldn't save any memory (and would perform wasteful recomputation).
- Shouldn't apply jax.checkpoint() to the last sub-function f3, since that wouldn't save any memory (and would perform wasteful recomputation).
- Could apply jax.checkpoint() to f1, f2, or their composition lambda x: f2(f1(x)), since any of those might save memory and would express different memory/recompute tradeoffs (see the sketch after this list).
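As a rough sketch of the third option (not from the original tutorial; stage1, stage2, and stage3 are hypothetical stand-ins for f1, f2, and f3):
# Checkpoint the composition of the first two stages, leaving the last alone.
stage1, stage2, stage3 = jnp.sin, jnp.cos, jnp.tanh

chain_plain = lambda x: stage3(stage2(stage1(x)))
chain_ckpt = lambda x: stage3(jax.checkpoint(lambda x: stage2(stage1(x)))(x))

x_val = 0.3
assert jnp.allclose(jax.grad(chain_plain)(x_val), jax.grad(chain_ckpt)(x_val))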
Custom policies for what’s saveable#
As shown so far, using jax.checkpoint() switches from one extreme to another:
- Without jax.checkpoint(), JAX's autodiff tends to compute everything possible on the forward pass and store it for the backward pass.
- With a jax.checkpoint() decorator, you instead compute as little as possible on the forward pass and recompute values as needed on the backward pass.
To operate between these two extremes, saving some things and not others, you can carefully place jax.checkpoint() decorators on sub-functions. But that requires editing the function to be differentiated, e.g. model code, which may be inconvenient. It can also be hard to experiment with variations.
So an alternative is to use the policy argument to jax.checkpoint(). A policy is a callable (i.e. a function) which takes as input a type-level specification of a first-order primitive application and returns a boolean indicating whether the corresponding output value(s) are allowed to be saved as residuals (or instead must be recomputed in the (co)tangent computation as needed). To write robust code, a policy should be selected from the attributes on jax.checkpoint_policies, such as jax.checkpoint_policies.dots_with_no_batch_dims_saveable, since the API for writing custom policy callables is considered internal.
For example, consider this function to be differentiated:
def loss(params, x, y):
  return jnp.sum((predict(params, x) - y)**2)

def predict(params, x):
  *Ws, Wlast = params
  for W in Ws:
    x = layer(W, x)
  x = jnp.dot(Wlast, x)
  return x

def layer(W, x):
  return jnp.sin(jnp.dot(W, x))
W1 = W2 = W3 = jnp.ones((4, 4))
params = [W1, W2, W3]
x = jnp.ones(4)
y = jnp.ones(4)
print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] output of sin from /tmp/ipykernel_1130/4230705069.py:12 (layer)
f32[4] output of cos from /tmp/ipykernel_1130/4230705069.py:12 (layer)
f32[4] output of sin from /tmp/ipykernel_1130/4230705069.py:12 (layer)
f32[4] output of cos from /tmp/ipykernel_1130/4230705069.py:12 (layer)
f32[4] output of mul from /tmp/ipykernel_1130/4230705069.py:2 (loss)
Instead of saving so many values on the forward pass, perhaps you only want to save the results of matrix multiplications with no batch dimension (since they may be FLOP- rather than memory-bound). You can do that using the policy jax.checkpoint_policies.dots_with_no_batch_dims_saveable:
loss_checkpoint = jax.checkpoint(loss, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
print_saved_residuals(loss_checkpoint, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] from the argument y
f32[4] output of reduce_precision from /tmp/ipykernel_1130/4230705069.py:12 (layer)
f32[4] output of reduce_precision from /tmp/ipykernel_1130/4230705069.py:12 (layer)
f32[4] output of reduce_precision from /tmp/ipykernel_1130/4230705069.py:8 (predict)
Notice also that by providing a policy, you didn't need to edit the code defining loss, predict, or layer. That is particularly convenient if you want to experiment with policies in calling code (such as a training script) without changing library code (for example, the neural network library).
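As before, the policy only affects what is saved; the gradients are unchanged. A minimal sketch (not part of the original tutorial):
# Gradients of the checkpointed loss match those of the original loss.
grads_plain = jax.grad(loss)(params, x, y)
grads_ckpt = jax.grad(loss_checkpoint)(params, x, y)
assert all(jnp.allclose(a, b) for a, b in zip(grads_plain, grads_ckpt))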
Some policies can refer to values named with jax.ad_checkpoint.checkpoint_name():
from jax.ad_checkpoint import checkpoint_name
def predict(params, x):
  *Ws, Wlast = params
  for i, W in enumerate(Ws):
    x = layer(W, x)
    x = checkpoint_name(x, name=f'layer{i}_output')
  x = jnp.dot(Wlast, x)
  return x
By itself, jax.ad_checkpoint.checkpoint_name() is just an identity function. But because some policy functions know to look for them, you can use the names to control whether certain values output by checkpoint_name() are considered saveable:
print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] output of cos from /tmp/ipykernel_1130/4230705069.py:12 (layer)
f32[4] named 'layer0_output' from /tmp/ipykernel_1130/178264713.py:7 (predict)
f32[4] output of cos from /tmp/ipykernel_1130/4230705069.py:12 (layer)
f32[4] named 'layer1_output' from /tmp/ipykernel_1130/178264713.py:7 (predict)
f32[4] output of mul from /tmp/ipykernel_1130/4230705069.py:2 (loss)
loss_checkpoint2 = jax.checkpoint(loss, policy=jax.checkpoint_policies.save_any_names_but_these('layer1_output'))
print_saved_residuals(loss_checkpoint2, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] from the argument y
Another policy which refers to names is jax.checkpoint_policies.save_only_these_names.
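For instance, continuing with the named predict above, a minimal sketch (not from the original tutorial) that keeps only the first layer's named output:
# Save only the residual named 'layer0_output'; other intermediates are
# recomputed (inputs of the checkpointed function may still be saved).
loss_checkpoint3 = jax.checkpoint(
    loss, policy=jax.checkpoint_policies.save_only_these_names('layer0_output'))
print_saved_residuals(loss_checkpoint3, params, x, y)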
Some of the policies are:
- everything_saveable: The default strategy, as if jax.checkpoint() were not being used at all.
- nothing_saveable: That is, rematerialize everything, as if a custom policy were not being used at all.
- dots_saveable: Or its alias checkpoint_dots.
- dots_with_no_batch_dims_saveable: Or its alias checkpoint_dots_with_no_batch_dims.
- save_anything_but_these_names: Save any values except for the output of checkpoint_name with any of the names given.
- save_any_names_but_these: Save only named values (that is, any outputs of checkpoint_name), except for those with the names given.
- save_only_these_names: Save only named values, and only among the names given.
Policies only indicate what is saveable; a value is only saved if it’s actually needed by the backward pass.
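Policies can also be combined. Here is a minimal sketch (not from the original tutorial, and assuming jax.checkpoint_policies.save_from_both_policies is available in your JAX version) that saves residuals allowed by either of two policies:
# Save a residual if either policy allows it: no-batch-dim dot results, or
# the value named 'layer0_output'. (Assumes save_from_both_policies exists
# in your JAX version.)
combined_policy = jax.checkpoint_policies.save_from_both_policies(
    jax.checkpoint_policies.dots_with_no_batch_dims_saveable,
    jax.checkpoint_policies.save_only_these_names('layer0_output'))
loss_checkpoint4 = jax.checkpoint(loss, policy=combined_policy)
print_saved_residuals(loss_checkpoint4, params, x, y)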
Advanced: Recursive jax.checkpoint#
By applying jax.checkpoint() in the right way, there are many tradeoffs between memory usage and (re)computation that can be expressed. One surprising example is recursive checkpointing, where you apply jax.checkpoint() to a function which itself calls jax.checkpoint()-decorated functions, in such a way that memory usage from the chain composition of \(D\) functions scales like \(\mathcal{O}(\log_2 D)\) rather than \(\mathcal{O}(D)\).
As a toy example, consider the chain composition of multiple jax.numpy.sin() functions:
def chain_compose(funs):
  def f(x):
    for fun in funs:
      x = fun(x)
    return x
  return f
f = chain_compose([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
In general, the number of stored residuals scales linearly with the length of the chain:
f = chain_compose([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1130/410288286.py:4 (f)
But you can apply jax.checkpoint() recursively to improve the scaling:
def recursive_checkpoint(funs):
  if len(funs) == 1:
    return funs[0]
  elif len(funs) == 2:
    f1, f2 = funs
    return lambda x: f1(f2(x))
  else:
    f1 = recursive_checkpoint(funs[:len(funs)//2])
    f2 = recursive_checkpoint(funs[len(funs)//2:])
    return lambda x: f1(jax.checkpoint(f2)(x))
f = recursive_checkpoint([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] from the argument x
f32[] output of sin from /tmp/ipykernel_1130/1943107544.py:6 (<lambda>)
f32[] output of cos from /tmp/ipykernel_1130/1943107544.py:6 (<lambda>)
f32[] output of cos from /tmp/ipykernel_1130/1943107544.py:6 (<lambda>)
f = recursive_checkpoint([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] from the argument x
f32[] output of sin from /tmp/ipykernel_1130/1943107544.py:6 (<lambda>)
f32[] output of sin from /tmp/ipykernel_1130/1943107544.py:6 (<lambda>)
f32[] output of cos from /tmp/ipykernel_1130/1943107544.py:6 (<lambda>)
f32[] output of cos from /tmp/ipykernel_1130/1943107544.py:6 (<lambda>)
The cost here, as usual, is recomputation: in particular, you end up performing \(\mathcal{O}(\log_2 D)\) times as many FLOPs:
f = chain_compose([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
[Output: the forward and backward jaxprs for the plain 8-fold sin chain; the forward pass saves one cos residual per sin application, and the backward pass is a chain of mul operations over those residuals.]
f = recursive_checkpoint([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
[Output: the forward and backward jaxprs for the recursively checkpointed chain; only a few residuals are saved on the forward pass, and the backward pass contains nested remat2 calls that recompute the rest.]
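The recursive scheme computes the same gradient as the plain chain; only the memory/recompute profile differs. A quick check (a minimal sketch, not from the original tutorial):
# Same gradient, different save/recompute schedule.
f_chain = chain_compose([jnp.sin] * 8)
f_remat = recursive_checkpoint([jnp.sin] * 8)
assert jnp.allclose(jax.grad(f_chain)(3.), jax.grad(f_remat)(3.))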
Practical notes#
When differentiated functions are staged out to XLA for compilation — for example by applying jax.jit() to a function which contains a jax.grad() call — XLA will automatically optimize the computation, including decisions about when to compute or rematerialize values. As a result, jax.checkpoint() often isn't needed for differentiated functions under a jax.jit(). XLA will optimize things for you.
One exception is when using staged-out control flow, like jax.lax.scan(). Automatic compiler optimizations across multiple control flow primitives (for example, across a forward-pass scan and the corresponding backward-pass scan) typically aren't as thorough. As a result, it's often a good idea to use jax.checkpoint() on the body function passed to jax.lax.scan().
For example, one common pattern in large Transformer models is to express the architecture as a jax.lax.scan() over layers so as to reduce compilation times. That is, using a simple fully-connected network as an analogy, instead of writing something like this:
LayerParam = tuple[jnp.ndarray, jnp.ndarray]  # Weights-bias pair for a layer.
ParamsList = list[LayerParam]

def net(params: ParamsList, x: jnp.ndarray):
  for W, b in params:
    x = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return x
Instead, iterate over the layer application with jax.lax.scan():
params = [(jnp.array([[0.5, 0.5], [1., 1.]]), jnp.array([0.5, 0.5])),
          (jnp.array([[0.5, 0.5], [1., 1.]]), jnp.array([0.5, 0.5]))]

all_weights = jnp.stack([W for W, _ in params])
all_biases = jnp.stack([b for _, b in params])

def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None

def net(all_weights, all_biases, x):
  x, _ = jax.lax.scan(layer, x, (all_weights, all_biases))
  return x
This scan-over-layers version reduces compile times, but by foiling some compiler optimizations it can lead to inefficient computation of gradients. To mitigate the issue, you can use jax.checkpoint() on the scanned function:
from functools import partial

@partial(jax.checkpoint,
         policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None
By using jax.checkpoint() this way, you're manually controlling which values JAX's autodiff saves between the forward and backward passes, and therefore not relying on XLA optimizations to choose for you.
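To round this out, here is a minimal sketch (not part of the original tutorial) of differentiating through the scanned, checkpointed layer under jax.jit; scan_loss is a hypothetical scalar stand-in for a real training loss:
# Differentiate through the scan whose body is the checkpointed `layer`.
def scan_loss(all_weights, all_biases, x):
  return jnp.sum(net(all_weights, all_biases, x))

x_in = jnp.ones(2)
grads = jax.jit(jax.grad(scan_loss, argnums=(0, 1)))(all_weights, all_biases, x_in)
print(jax.tree_util.tree_map(jnp.shape, grads))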