The checkify transformation

TL;DR Checkify lets you add jit-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the checkify.checkify transformation together with the assert-like checkify.check function to add runtime checks to JAX code:

from jax.experimental import checkify
import jax
import jax.numpy as jnp

def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative, got {i}", i=i)
  y = x[i]
  z = jnp.sin(y)
  return z

jittable_f = checkify.checkify(f)

err, z = jax.jit(jittable_f)(jnp.ones((5,)), -2)
print(err.get())
# >> index needs to be non-negative, got -2! (check failed at <...>:6 (f))

You can also use checkify to automatically add common checks:

errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)

err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)

err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative, got -1! (check failed at <...>:6 (f))

err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)

err, z = checked_f(jnp.array([5, 1]), 0)
err.throw()  # if no error occurred, throw does nothing!

Functionalizing checks

The assert-like check API by itself is not functionally pure: like assert, it can raise a Python exception as a side effect. So it can’t be staged out with jit, pmap, pjit, or scan:

jax.jit(f)(jnp.ones((5,)), -1)  # checkify transformation not used
# ValueError: Cannot abstractly evaluate a checkify.check which was not functionalized.

But the checkify transformation functionalizes (or discharges) these effects. A checkify-transformed function returns an error value as a new output and remains functionally pure. That functionalization means checkify-transformed functions can be freely composed with staging and other transformations:

# pmap needs one device per mapped element (3 here).
err, z = jax.pmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
err.throw()
"""
ValueError:
  at mapped index 0: index needs to be non-negative, got -1! (check failed at <...>:6 (f))
  at mapped index 2: out-of-bounds indexing at <...>:7 (f)
"""

Why does JAX need checkify?

Under some JAX transformations you can express runtime error checks with ordinary Python assertions, for example when only using jax.grad and jax.numpy:

def f(x):
  assert x > 0., "must be positive!"
  return jnp.log(x)

jax.grad(f)(0.)
# AssertionError: must be positive!

But ordinary assertions don’t work inside jit, pmap, pjit, or scan. In those cases, numeric computations are staged out rather than evaluated eagerly during Python execution, and as a result numeric values aren’t available:

jax.jit(f)(0.)
# ConcretizationTypeError: "Abstract tracer value encountered ..."

JAX transformation semantics rely on functional purity, especially when composing multiple transformations, so how can we provide an error mechanism without disrupting all that? Beyond needing a new API, the situation is trickier still: XLA HLO doesn’t support assertions or throwing errors, so even if we had a JAX API which was able to stage out assertions, how would we lower these assertions to XLA?

You could imagine manually adding run-time checks to your function and plumbing out values representing errors:

def f_checked(x):
  error = x <= 0.
  result = jnp.log(x)
  return error, result

err, y = jax.jit(f_checked)(0.)
if err:
  raise ValueError("must be positive!")
# ValueError: "must be positive!"

The error is a regular value computed by the function, and the error is raised outside of f_checked. f_checked is functionally pure, so we know by construction that it’ll already work with jit, pmap, pjit, scan, and all of JAX’s transformations. The only problem is that this plumbing can be a pain!
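With more than one check, the hand-written version also has to merge errors and remember which one fired. A minimal sketch of that bookkeeping, using a hypothetical integer error code where 0 means "no error" and the first failing check wins (an illustration of the manual approach, not a description of how checkify is implemented internally):

```python
import jax
import jax.numpy as jnp

# Hypothetical error codes, chosen for this sketch.
MESSAGES = {1: "log input must be positive!", 2: "sqrt input must be non-negative!"}

def f_checked(x):
  code = jnp.int32(0)
  # Check 1: log needs a positive input.
  code = jnp.where((code == 0) & (x <= 0.), 1, code)
  y = jnp.log(x)
  # Check 2: sqrt needs a non-negative input; only recorded if nothing failed yet.
  code = jnp.where((code == 0) & (y < 0.), 2, code)
  z = jnp.sqrt(y)
  return code, z

code, z = jax.jit(f_checked)(0.5)  # log(0.5) < 0, so check 2 fires
if int(code) != 0:
  print(MESSAGES[int(code)])
```

Every extra check means another `jnp.where` and another entry in the message table, which is exactly the plumbing checkify automates.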

checkify does this rewrite for you: that includes plumbing the error value through the function, rewriting checks to boolean operations and merging the result with the tracked error value, and returning the final error value as an output to the checkified function:

def f(x):
  checkify.check(x > 0., "{} must be positive!", x)  # convenient but effectful API
  return jnp.log(x)

f_checked = checkify.checkify(f)

err, x = jax.jit(f_checked)(-1.)
err.throw()
# ValueError: -1.0 must be positive! (check failed at <...>:2 (f))

We call this functionalizing or discharging the effect introduced by calling check. (In the “manual” example above the error value is just a boolean. checkify’s error values are conceptually similar but also track error messages and expose throw and get methods; see jax.experimental.checkify.) checkify.check also lets you include run-time values in the error message by passing them as format arguments.
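A short sketch of the error-value API just described, using a hypothetical `safe_div` function: `get()` returns the formatted message (or `None` when no check failed), and `throw()` raises only when a check failed.

```python
import jax
from jax.experimental import checkify

def safe_div(a, b):
  # Positional format args fill "{}" placeholders, as with str.format.
  checkify.check(b != 0, "cannot divide {} by zero", a)
  return a / b

checked_div = jax.jit(checkify.checkify(safe_div))

err, out = checked_div(3.0, 0.0)
print(err.get())  # the formatted message plus the location of the failed check

err, out = checked_div(3.0, 1.5)
print(err.get())  # None
err.throw()       # no-op: nothing failed
```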

You could now manually instrument your code with run-time checks, but checkify can also automatically add checks for common errors! Consider these error cases:

jnp.arange(3)[5]                # out of bounds
jnp.sin(jnp.inf)                # NaN generated
jnp.ones((5,)) / jnp.arange(5)  # division by zero

By default checkify only discharges checkify.checks, and won’t do anything to catch errors like the above. But if you ask it to, checkify will also instrument your code with checks automatically.

def f(x, i):
  y = x[i]        # i could be out of bounds.
  z = jnp.sin(y)  # z could become NaN
  return z

errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)

err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)

err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)

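The API for selecting which automatic checks to enable is based on sets: predefined error sets such as checkify.user_checks, checkify.index_checks and checkify.float_checks are combined with |. See jax.experimental.checkify for more details.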
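A minimal sketch of that set algebra, assuming (as in current JAX releases) that the predefined error sets are plain Python frozensets, so `-`, `<=` and `isdisjoint` work as usual. Here NaN checking is toggled on and off for a hypothetical one-line function:

```python
import jax.numpy as jnp
from jax.experimental import checkify

# Everything except the floating-point (NaN and division) checks.
no_float_checks = checkify.all_checks - checkify.float_checks

def f(x):
  return jnp.sin(x)  # sin(inf) is NaN

with_nan = checkify.checkify(f, errors=checkify.nan_checks)
without_nan = checkify.checkify(f, errors=no_float_checks)

err, _ = with_nan(jnp.inf)
print(err.get())  # reports the NaN generated by sin

err, _ = without_nan(jnp.inf)
print(err.get())  # None: NaN checks were not enabled
```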

checkify under JAX transformations

As demonstrated in the examples above, a checkified function can be happily jitted. Here are a few more examples of checkify with other JAX transformations. Note that checkified functions are functionally pure, and should trivially compose with all JAX transformations!

jit

You can safely add jax.jit to a checkified function or checkify a jitted function; either order works.

def f(x, i):
  return x[i]

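# Index checks must be requested explicitly: by default checkify only discharges checkify.check.
checkify_of_jit = checkify.checkify(jax.jit(f), errors=checkify.index_checks)
jit_of_checkify = jax.jit(checkify.checkify(f, errors=checkify.index_checks))
err, _ = checkify_of_jit(jnp.ones((5,)), 100)
err.get()
# out-of-bounds indexing at <...>:2 (f)
err, _ = jit_of_checkify(jnp.ones((5,)), 100)
err.get()
# out-of-bounds indexing at <...>:2 (f)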

vmap/pmap

You can vmap and pmap checkified functions (or checkify mapped functions). Mapping a checkified function will give you a mapped error, which can contain different errors for every element of the mapped dimension.

def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  return x[i]

checked_f = checkify.checkify(f, errors=checkify.all_checks)
errs, out = jax.vmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
errs.throw()
"""
ValueError:
  at mapped index 0: index needs to be non-negative! (check failed at <...>:2 (f))
  at mapped index 2: out-of-bounds indexing at <...>:3 (f)
"""

However, a checkify-of-vmap will produce a single (unmapped) error!

@jax.vmap
def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  return x[i]

checked_f = checkify.checkify(f, errors=checkify.all_checks)
err, out = checked_f(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
err.throw()
# ValueError: index needs to be non-negative! (check failed at <...>:2 (f))

pjit

pjit of a checkified function just works; you only need to specify an additional out_shardings entry of None for the error value output.

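import numpy as np
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec

def f(x):
  return x / x

f = checkify.checkify(f, errors=checkify.float_checks)
f = pjit(
  f,
  in_shardings=PartitionSpec('x', None),
  out_shardings=(None, PartitionSpec('x', None)))

# A 1D mesh over all devices. `input_data` is assumed to be a 2D float array
# containing a zero, with a leading dimension divisible by the device count.
mesh = Mesh(np.array(jax.devices()), ('x',))
with mesh:
  err, data = f(input_data)
err.throw()
# ValueError: divided by zero at <...>:4 (f)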

grad

Your gradient computation will also be instrumented if you checkify-of-grad:

def f(x):
  return x / (1 + jnp.sqrt(x))

grad_f = jax.grad(f)

err, _ = checkify.checkify(grad_f, errors=checkify.nan_checks)(0.)
print(err.get())
# >> nan generated by primitive mul at <...>:3 (f)

Note that there’s no multiply in f, but there is a multiply in its gradient computation (and this is where the NaN is generated!). So use checkify-of-grad to add automatic checks to both forward and backward pass operations.

checkify.check calls are only applied to the primal computation of your function. If you want to check a gradient value, use jax.custom_vjp:

@jax.custom_vjp
def assert_gradient_negative(x):
  return x

def fwd(x):
  return assert_gradient_negative(x), None

def bwd(_, grad):
  checkify.check(grad < 0, "gradient needs to be negative!")
  return (grad,)

assert_gradient_negative.defvjp(fwd, bwd)

jax.grad(assert_gradient_negative)(-1.)
# ValueError: gradient needs to be negative!

Strengths and limitations of jax.experimental.checkify

Strengths

  • You can use it everywhere (errors are “just values” and behave intuitively under transformations like other values)

  • Automatic instrumentation: you don’t need to make local modifications to your code. Instead, checkify can instrument all of it!

Limitations

  • Adding a lot of runtime checks can be expensive (e.g. adding a NaN check to every primitive adds many operations to your computation)

  • Requires threading error values out of functions and manually throwing the error. If the error is not explicitly thrown, you might miss out on errors!

  • Throwing an error value will materialize that error value on the host, meaning it’s a blocking operation which defeats JAX’s async run-ahead.
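One way to soften the last point, sketched under the assumption that you can tolerate learning about a failure a few steps late: keep the error values returned by each step and call `throw()` only after the loop, so the host waits on the device once instead of on every step. `step` is a hypothetical function made up for illustration.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def step(x):
  # Hypothetical update: halve x, checking that it stays positive.
  checkify.check(jnp.all(x > 0), "x must stay positive")
  return x / 2

checked_step = jax.jit(checkify.checkify(step))

x = jnp.ones((4,))
errs = []
for _ in range(10):
  err, x = checked_step(x)
  errs.append(err)  # not inspected yet, so no blocking host transfer here

# Materialize the errors once, after the loop.
for err in errs:
  err.throw()
```

The trade-off is latency of detection: a failure in an early step is only reported after the remaining steps have been dispatched.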