jax.debug.print and jax.debug.breakpoint#

The jax.debug package offers some useful tools for inspecting values inside of JIT-ted functions.

Debugging with jax.debug.print and other debugging callbacks#

TL;DR Use jax.debug.print() to print values to stdout in jax.jit-, jax.pmap-, and pjit-decorated functions:

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  jax.debug.print("🤯 {x} 🤯", x=x)
  y = jnp.sin(x)
  jax.debug.print("🤯 {y} 🤯", y=y)
  return y
  
f(2.)
# Prints:
# 🤯 2.0 🤯
# 🤯 0.9092974662780762 🤯

With some transformations, like jax.grad and jax.vmap, you can use Python’s builtin print function to print out numerical values. But print won’t work with jax.jit or jax.pmap because those transformations delay numerical evaluation. So use jax.debug.print instead!
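As a minimal sketch of what "delayed evaluation" means here, the example below counts how often the Python body of a jitted function runs and compares that with how often a debug callback fires. The names trace_count and seen_values are illustrative only, and jax.debug.callback is used alongside jax.debug.print just so the values can be collected in a list.

```python
import jax
import jax.numpy as jnp

trace_count = 0   # incremented by ordinary Python code inside f
seen_values = []  # filled by a debug callback at run time

@jax.jit
def f(x):
    global trace_count
    trace_count += 1           # plain Python side effect: runs only while tracing
    print("traced with:", x)   # shows an abstract tracer, not a number
    jax.debug.print("ran with: {}", x)  # staged out, prints on every call
    jax.debug.callback(lambda v: seen_values.append(float(v)), x)
    return jnp.sin(x)

f(1.0)
f(2.0)  # same input type, so the cached compilation is reused
jax.effects_barrier()  # wait for any pending debug callbacks

print(trace_count)          # 1
print(sorted(seen_values))  # [1.0, 2.0]
```

The builtin print runs once, during tracing, while the staged-out debug calls run on each invocation with concrete values.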

Semantically, jax.debug.print is roughly equivalent to the following Python function

def debug_print(fmt: str, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
  print(fmt.format(*args, **kwargs))

except that it can be staged out and transformed by JAX. See the API reference for more details.

Why “debug” print?#

In the name of debugging, jax.debug.print can reveal information about how computations are evaluated:

xs = jnp.arange(3.)

def f(x):
  jax.debug.print("x: {}", x)
  y = jnp.sin(x)
  jax.debug.print("y: {}", y)
  return y
jax.vmap(f)(xs)
# Prints: x: 0.0
#         x: 1.0
#         x: 2.0
#         y: 0.0
#         y: 0.841471
#         y: 0.9092974
jax.lax.map(f, xs)
# Prints: x: 0.0
#         y: 0.0
#         x: 1.0
#         y: 0.841471
#         x: 2.0
#         y: 0.9092974

Notice that the printed results are in different orders!

By revealing these inner workings, the output of jax.debug.print doesn’t respect JAX’s usual semantic guarantees, such as the guarantee that jax.vmap(f)(xs) and jax.lax.map(f, xs) compute the same thing (in different ways). Yet these evaluation-order details are exactly what we might want to see when debugging!

So use jax.debug.print for debugging, and not when semantic guarantees are important.

More examples of jax.debug.print#

In addition to the above examples using jit and vmap, here are a few more to have in mind.

Printing under jax.pmap#

Under jax.pmap, the output of jax.debug.print calls might be reordered!

xs = jnp.arange(2.)

def f(x):
  jax.debug.print("x: {}", x)
  return x
jax.pmap(f)(xs)
# Prints: x: 0.0
#         x: 1.0
# OR
# Prints: x: 1.0
#         x: 0.0

Printing under jax.grad#

Under a jax.grad, jax.debug.prints will only print on the forward pass:

def f(x):
  jax.debug.print("x: {}", x)
  return x * 2.

jax.grad(f)(1.)
# Prints: x: 1.0

This behavior is similar to how Python’s builtin print works under a jax.grad. But by using jax.debug.print here, the behavior is the same even if the caller applies a jax.jit.

To print on the backward pass, just use a jax.custom_vjp:

@jax.custom_vjp
def print_grad(x):
  return x

def print_grad_fwd(x):
  return x, None

def print_grad_bwd(_, x_grad):
  jax.debug.print("x_grad: {}", x_grad)
  return (x_grad,)

print_grad.defvjp(print_grad_fwd, print_grad_bwd)


def f(x):
  x = print_grad(x)
  return x * 2.
jax.grad(f)(1.)
# Prints: x_grad: 2.0

Printing in other transformations#

jax.debug.print also works in other transformations like xmap and pjit.

More control with jax.debug.callback#

In fact, jax.debug.print is a thin convenience wrapper around jax.debug.callback, which can be used directly for greater control over string formatting, or even the kind of output.

Semantically, jax.debug.callback is roughly equivalent to the following Python function

def callback(fun: Callable, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
  fun(*args, **kwargs)
  return None

As with jax.debug.print, these callbacks should only be used for debugging output, like printing or plotting. Printing and plotting are pretty harmless, but if you use jax.debug.callback for anything else, its behavior might surprise you under transformations. For example, it’s not safe to use jax.debug.callback for timing operations, since callbacks might be reordered and might run asynchronously (see below).
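A small sketch of using jax.debug.callback for a different kind of output: collecting intermediate values in a host-side list, for example to plot them later. The names history and record are hypothetical; the label is bound with functools.partial so that only array values are passed through jax.debug.callback.

```python
import functools

import jax
import jax.numpy as jnp

history = []  # host-side log, e.g. for plotting after the run

def record(name, value):
    # Runs on the host and receives a concrete value, not a tracer.
    history.append((name, float(value)))

@jax.jit
def f(x):
    y = jnp.sin(x)
    # Only the array y goes through the callback; the label is bound up front.
    jax.debug.callback(functools.partial(record, "y"), y)
    return y * 2.0

out = f(0.0)
jax.effects_barrier()  # make sure the callback has run before reading history
print(history)
```

Because the callback may run asynchronously, the barrier before reading history matters; without it the list could still be empty.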

Sharp bits#

Like most JAX APIs, jax.debug.print can cut you if you’re not careful.

Ordering of printed results#

When distinct calls to jax.debug.print involve arguments which don’t depend on one another, they might be reordered when staged out, e.g. by jax.jit:

@jax.jit
def f(x, y):
  jax.debug.print("x: {}", x)
  jax.debug.print("y: {}", y)
  return x + y

f(2., 3.)
# Prints: x: 2.0
#         y: 3.0
# OR
# Prints: y: 3.0
#         x: 2.0

Why? Under the hood, the compiler gets a functional representation of the staged-out computation, where the imperative order of the Python function is lost and only data dependence remains. This change is invisible to users with functionally pure code, but in the presence of side-effects like printing, it’s noticeable.

To preserve the original order of jax.debug.prints as written in your Python function, you can use jax.debug.print(..., ordered=True), which will ensure the relative order of prints is preserved. But using ordered=True will raise an error under jax.pmap and other JAX transformations involving parallelism, since ordering can’t be guaranteed under parallel execution.
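The sketch below illustrates the ordered=True flag. It uses jax.debug.callback, which accepts the same keyword, so that the order of events can be recorded in a list (the name events is illustrative) rather than read off the screen.

```python
import jax

events = []  # filled in the order the callbacks actually run

@jax.jit
def f(x, y):
    # x and y do not depend on each other, so without ordered=True
    # the two callbacks could be scheduled in either order.
    jax.debug.callback(lambda v: events.append(("x", float(v))), x, ordered=True)
    jax.debug.callback(lambda v: events.append(("y", float(v))), y, ordered=True)
    return x + y

out = f(2.0, 3.0)
jax.effects_barrier()
print(events)  # [('x', 2.0), ('y', 3.0)]
```

The same pattern applies to jax.debug.print("x: {}", x, ordered=True).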

Asynchronous callbacks#

Depending on the backend, jax.debug.prints may happen asynchronously, i.e. not in your main program thread. This means that values could be printed to your screen even after your JAX function has returned a value.

@jax.jit
def f(x):
  jax.debug.print("x: {}", x)
  return x
f(2.).block_until_ready()
# <do something else>
# Prints: x: 2.0

To block on the jax.debug.prints in a function, you can call jax.effects_barrier(), which will wait until any remaining side-effects in the function have completed as well:

@jax.jit
def f(x):
  jax.debug.print("x: {}", x)
  return x
f(2.).block_until_ready()
jax.effects_barrier()
# Prints: x: 2.0
# <do something else>

Performance impacts#

Unnecessary materialization#

While jax.debug.print was designed to have a minimal performance footprint, it can interfere with compiler optimizations and potentially affect the memory profile of your JAX programs.

def f(w, b, x):
  logits = w.dot(x) + b
  jax.debug.print("logits: {}", logits)
  return jax.nn.relu(logits)

In this example, we are printing intermediate values in between a linear layer and the activation function. Compilers like XLA can perform fusion optimizations, which might avoid materializing logits in memory. But when we use jax.debug.print on logits, we are forcing those intermediates to be materialized, potentially slowing down the program and increasing memory usage.
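One way to keep this cost out of normal runs, sketched here under the assumption that a Python-level flag is acceptable, is to guard the print with an ordinary Python branch. The debug parameter and the names f_quiet and f_debug are hypothetical. The branch is resolved at trace time, so the version compiled with debug=False contains no print at all.

```python
from functools import partial

import jax
import jax.numpy as jnp

def f(w, b, x, debug=False):
    logits = w.dot(x) + b
    if debug:  # ordinary Python branch, decided while tracing
        jax.debug.print("logits: {}", logits)
    return jax.nn.relu(logits)

# Two separately compiled variants of the same function.
f_quiet = jax.jit(partial(f, debug=False))
f_debug = jax.jit(partial(f, debug=True))

w = jnp.ones((2, 3))
b = jnp.zeros(2)
x = jnp.arange(3.0)

out_quiet = f_quiet(w, b, x)  # no print staged into this computation
out_debug = f_debug(w, b, x)  # prints the logits
jax.effects_barrier()
```

Both variants return the same values; only the debug variant asks for logits to be sent to the host.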

Furthermore, when using jax.debug.print with jax.pjit, a global synchronization occurs that will materialize values on a single device.

Callback overhead#

jax.debug.print inherently incurs communication between an accelerator and its host. The underlying mechanism differs from backend to backend (e.g. GPU vs TPU) but in all cases, we’ll need to copy the printed values from device to host. In the CPU case, this overhead is smaller.

Furthermore, when using jax.debug.print with jax.pjit, a global synchronization occurs that adds some overhead.
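If only a rough picture of a large array is needed, one option is to reduce it on the device and pass just the summary to the debug call, so that less data has to be copied to the host. This is a sketch, not a measured optimization: the names received and record are illustrative, and the array being summarized may still need to be materialized on the device.

```python
import jax
import jax.numpy as jnp
import numpy as np

received = []  # (shape, value) of whatever reaches the host

def record(value):
    received.append((np.shape(value), float(value)))

@jax.jit
def f(w, b, x):
    logits = w.dot(x) + b
    # Reduce on the device first; only this scalar is handed to the callback.
    jax.debug.callback(record, jnp.max(jnp.abs(logits)))
    return jax.nn.relu(logits)

out = f(jnp.ones((2, 3)), jnp.zeros(2), jnp.arange(3.0))
jax.effects_barrier()
print(received)
```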

Strengths and limitations of jax.debug.print#

Strengths#

  • Print debugging is simple and intuitive

  • jax.debug.callback can be used for other innocuous side-effects

Limitations#

  • Adding print statements is a manual process

  • Can have performance impacts

Interactive inspection with jax.debug.breakpoint()#

TL;DR Use jax.debug.breakpoint() to pause the execution of your JAX program to inspect values:

@jax.jit
def f(x):
  y, z = jnp.sin(x), jnp.cos(x)
  jax.debug.breakpoint()
  return y * z
f(2.) # ==> Pauses during execution!


jax.debug.breakpoint() is actually just an application of jax.debug.callback(...) that captures information about the call stack. It has the same transformation behaviors as jax.debug.print as a result (e.g. vmap-ing jax.debug.breakpoint() unrolls it across the mapped axis).

Usage#

Calling jax.debug.breakpoint() in a compiled JAX function will pause your program when it hits the breakpoint. You’ll be presented with a pdb-like prompt that allows you to inspect the values in the call stack. Unlike pdb, you will not be able to step through the execution, but you are allowed to resume it.

Debugger commands:

  • help - prints out available commands

  • p - evaluates an expression and prints its result

  • pp - evaluates an expression and pretty-prints its result

  • u(p) - go up a stack frame

  • d(own) - go down a stack frame

  • w(here)/bt - print out a backtrace

  • l(ist) - print out code context

  • c(ont(inue)) - resumes the execution of the program

  • q(uit)/exit - exits the program (does not work on TPU)

Examples#

Usage with jax.lax.cond#

When combined with jax.lax.cond, the debugger can become a useful tool for detecting nans or infs.

def breakpoint_if_nonfinite(x):
  is_finite = jnp.isfinite(x).all()
  def true_fn(x):
    pass
  def false_fn(x):
    jax.debug.breakpoint()
  jax.lax.cond(is_finite, true_fn, false_fn, x)
 
@jax.jit
def f(x, y):
  z = x / y
  breakpoint_if_nonfinite(z)
  return z
f(2., 0.) # ==> Pauses during execution!
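When an interactive prompt is not available (for example in a batch job), the same jax.lax.cond pattern can report non-finite values instead of pausing. The following is a sketch with hypothetical names (reports, report_if_nonfinite); it records the offending value with jax.debug.callback, and jax.debug.print could be used in the same place.

```python
import jax
import jax.numpy as jnp

reports = []  # non-finite values seen at run time

def report_if_nonfinite(x):
    def ok_fn(x):
        pass  # nothing to do when every entry is finite
    def bad_fn(x):
        jax.debug.callback(lambda v: reports.append(float(v)), x)
    jax.lax.cond(jnp.isfinite(x).all(), ok_fn, bad_fn, x)

@jax.jit
def f(x, y):
    z = x / y
    report_if_nonfinite(z)
    return z

f(2.0, 1.0)  # finite result, nothing is recorded
f(2.0, 0.0)  # division by zero gives inf, which is recorded
jax.effects_barrier()
print(reports)  # [inf]
```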

Sharp bits#

Because jax.debug.breakpoint is just an application of jax.debug.callback, it has the same sharp bits as jax.debug.print, with a few more caveats:

  • jax.debug.breakpoint materializes even more intermediates than jax.debug.print because it forces materialization of all values in the call stack

  • jax.debug.breakpoint has more runtime overhead than a jax.debug.print because it has to potentially copy all the intermediate values in a JAX program from device to host.

Strengths and limitations of jax.debug.breakpoint()#

Strengths#

  • Simple, intuitive and (somewhat) standard

  • Can inspect many values at the same time, up and down the call stack

Limitations#

  • Need to potentially use many breakpoints to pinpoint the source of an error

  • Materializes many intermediates