jax.checkpoint#
- jax.checkpoint(fun, *, prevent_cse=True, policy=None, static_argnums=())#
Make fun recompute internal linearization points when differentiated.

The jax.checkpoint() decorator, aliased to jax.remat(), provides a way to trade off computation time and memory cost in the context of automatic differentiation, especially with reverse-mode autodiff like jax.grad() and jax.vjp(), but also with jax.linearize().

When differentiating a function in reverse mode, by default all the linearization points (e.g. inputs to elementwise nonlinear primitive operations) are stored when evaluating the forward pass so that they can be reused on the backward pass. This evaluation strategy can lead to a high memory cost, or even to poor performance on hardware accelerators where memory access is much more expensive than FLOPs.
An alternative evaluation strategy is for some of the linearization points to be recomputed (i.e. rematerialized) rather than stored. This approach can reduce memory usage at the cost of increased computation.
This function decorator produces a new version of fun which follows the rematerialization strategy rather than the default store-everything strategy. That is, it returns a new version of fun which, when differentiated, doesn't store any of its intermediate linearization points. Instead, these linearization points are recomputed from the function's saved inputs.

See the examples below.
- Parameters
  - fun (Callable) – Function for which the autodiff evaluation strategy is to be changed from the default of storing all intermediate linearization points to recomputing them. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof.
  - prevent_cse (bool) – Optional, boolean keyword-only argument indicating whether to prevent common subexpression elimination (CSE) optimizations in the HLO generated from differentiation. This CSE prevention has costs because it can foil other optimizations, and because it can incur high overheads on some backends, especially GPU. The default is True because otherwise, under a jit() or pmap(), CSE can defeat the purpose of this decorator. But in some settings, like when used inside a scan(), this CSE prevention mechanism is unnecessary, in which case prevent_cse can be set to False.
  - static_argnums (Union[int, Tuple[int, ...]]) – Optional, int or sequence of ints, a keyword-only argument indicating which argument values to specialize on for tracing and caching purposes. Specifying arguments as static can avoid ConcretizationTypeErrors when tracing, but at the cost of more retracing overheads. See the example below.
  - policy (Optional[Callable[..., bool]]) – Optional, callable keyword-only argument. It should be one of the attributes of jax.checkpoint_policies. The callable takes as input a type-level specification of a first-order primitive application and returns a boolean indicating whether the corresponding output value(s) can be saved as residuals (or instead must be recomputed in the (co)tangent computation if needed). See the sketch below.
- Return type
  Callable
- Returns
  A function (callable) with the same input/output behavior as fun but which, when differentiated using e.g. jax.grad(), jax.vjp(), or jax.linearize(), recomputes rather than stores intermediate linearization points, thus potentially saving memory at the cost of extra computation.
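As a hedged sketch of the policy argument (the particular choice of jax.checkpoint_policies.dots_with_no_batch_dims_saveable below is illustrative; which policy is appropriate depends on the model):

import jax
import jax.numpy as jnp

def layer(W, x):
  return jnp.tanh(jnp.dot(W, x))

# Illustrative policy choice: non-batched matrix-multiply outputs may be
# saved as residuals; other linearization points are recomputed as needed.
remat_layer = jax.checkpoint(
    layer, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)

W = jnp.ones((4, 4))
x = jnp.ones(4)
loss = lambda W, x: jnp.sum(remat_layer(W, x))
print(jax.grad(loss)(W, x))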
Here is a simple example:
>>> import jax
>>> import jax.numpy as jnp

>>> @jax.checkpoint
... def g(x):
...   y = jnp.sin(x)
...   z = jnp.sin(y)
...   return z
...
>>> jax.value_and_grad(g)(2.0)
(Array(0.78907233, dtype=float32, weak_type=True), Array(-0.2556391, dtype=float32, weak_type=True))
Here, the same value is produced whether or not the jax.checkpoint() decorator is present. When the decorator is not present, the values jnp.cos(2.0) and jnp.cos(jnp.sin(2.0)) are computed on the forward pass and are stored for use in the backward pass, because they are needed on the backward pass and depend only on the primal inputs. When using jax.checkpoint(), the forward pass will compute only the primal outputs, and only the primal inputs (2.0) will be stored for the backward pass. At that time, the value jnp.sin(2.0) is recomputed, along with the values jnp.cos(2.0) and jnp.cos(jnp.sin(2.0)).
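This can be checked with jax.ad_checkpoint.print_saved_residuals, which reports the values a function would save as residuals under differentiation (treating the exact import path and output format as an assumption about recent JAX versions):

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import print_saved_residuals  # assumed available in recent JAX

def g(x):
  return jnp.sin(jnp.sin(x))

# Without the decorator the cosine values appear as saved residuals;
# with it, only the primal input is saved and the rest is recomputed.
print_saved_residuals(g, 2.0)
print_saved_residuals(jax.checkpoint(g), 2.0)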
While jax.checkpoint() controls what values are stored from the forward pass to be used on the backward pass, the total amount of memory required to evaluate a function or its VJP depends on many additional internal details of that function. Those details include which numerical primitives are used, how they're composed, where jit and control flow primitives like scan are used, and other factors.

The jax.checkpoint() decorator can be applied recursively to express sophisticated autodiff rematerialization strategies. For example:

>>> def recursive_checkpoint(funs):
...   if len(funs) == 1:
...     return funs[0]
...   elif len(funs) == 2:
...     f1, f2 = funs
...     return lambda x: f1(f2(x))
...   else:
...     f1 = recursive_checkpoint(funs[:len(funs)//2])
...     f2 = recursive_checkpoint(funs[len(funs)//2:])
...     return lambda x: f1(jax.checkpoint(f2)(x))
...
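For instance, a minimal usage sketch of the recursive strategy above (the chain of eight sines is purely illustrative):

fs = [jnp.sin] * 8            # illustrative chain of functions to compose
f = recursive_checkpoint(fs)  # second half of each sub-chain is rematerialized
print(jax.grad(f)(3.0))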
If fun involves Python control flow that depends on argument values, it may be necessary to use the static_argnums parameter. For example, consider a boolean flag argument:

from functools import partial

@partial(jax.checkpoint, static_argnums=(1,))
def foo(x, is_training):
  if is_training:
    ...
  else:
    ...
Here, the use of static_argnums allows the if statement's condition to depend on the value of is_training. The cost of using static_argnums is that it introduces re-tracing overheads across calls: in the example, foo is re-traced every time it is called with a new value of is_training. In some situations, jax.ensure_compile_time_eval is needed as well:

@partial(jax.checkpoint, static_argnums=(1,))
def foo(x, y):
  with jax.ensure_compile_time_eval():
    y_pos = y > 0
  if y_pos:
    ...
  else:
    ...
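To make the mechanism concrete, here is a hedged, runnable variant of the first static_argnums example (the branch bodies are made up for illustration):

import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.checkpoint, static_argnums=(1,))
def foo(x, is_training):
  # is_training is treated as a static Python value, so an ordinary `if` works.
  if is_training:
    return jnp.sin(x) * 2.0
  else:
    return jnp.sin(x)

print(jax.grad(foo)(1.0, True))   # traced with is_training=True
print(jax.grad(foo)(1.0, False))  # re-traced with is_training=False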
As an alternative to using static_argnums (and jax.ensure_compile_time_eval), it may be easier to compute some values outside the jax.checkpoint()-decorated function and then close over them.
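For instance, a minimal sketch of that closing-over pattern (the function names are illustrative):

import jax
import jax.numpy as jnp

def apply(x, is_training):
  # is_training is an ordinary Python value closed over by the decorated
  # function, so no static_argnums is needed for the `if` below.
  @jax.checkpoint
  def body(x):
    if is_training:
      return jnp.sin(x) * 2.0
    return jnp.sin(x)
  return body(x)

print(jax.grad(apply)(1.0, True))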