jax.experimental.checkify.check_error

jax.experimental.checkify.check_error(error)

Raise an Exception if error represents a failure. Functionalized by checkify().

The semantics of this function are equivalent to:

>>> def check_error(err: Error) -> None:
...   err.throw()  # can raise ValueError

But unlike that implementation, check_error can be functionalized using the checkify() transformation.
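Here is a minimal sketch of the eager behavior described above. It uses an Error obtained from a checkified function `f`, a hypothetical helper defined only for illustration. On a passing Error, `check_error` does nothing. On a failing one, it raises, just as `err.throw()` would. The exception is caught generically because its exact class may differ between JAX versions.

```python
from jax.experimental import checkify

def f(x):
    # Hypothetical example function with a single user-level check.
    checkify.check(x > 0, "must be positive!")
    return x

checked_f = checkify.checkify(f)

# Passing input: the Error records no failure, so check_error raises nothing.
ok_err, _ = checked_f(1.0)
checkify.check_error(ok_err)

# Failing input: check_error raises, as calling bad_err.throw() would.
bad_err, _ = checked_f(-1.0)
try:
    checkify.check_error(bad_err)
    raised_msg = None
except Exception as e:  # exact exception class is version-dependent
    raised_msg = str(e)

print(raised_msg)
```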

This function is similar to check() but has a different signature. check() takes a boolean predicate and a message string for a new error, whereas this function takes an existing Error value as its argument. Both check() and this function raise a Python Exception on failure (a side effect), so on their own they cannot be staged out by jit(), pmap(), scan(), etc. Both can, however, be functionalized with checkify(), and the resulting function can then be staged out.
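The difference in signatures can be sketched as follows, using two hypothetical helpers. `via_check` builds its error from a predicate and a message. `via_check_error` instead receives a ready-made Error value (here, the one produced by `via_check`) and re-raises it. Both are functionalized by checkify(), and both report the same failure.

```python
from jax.experimental import checkify

def via_check(x):
    # check(): boolean predicate plus a message string.
    checkify.check(x > 0, "must be positive!")
    return x

def via_check_error(err, x):
    # check_error(): an already-built Error value.
    checkify.check_error(err)
    return x

# Functionalize the predicate-based check to obtain an Error value.
err_from_check, _ = checkify.checkify(via_check)(-1.0)

# Functionalize check_error; the incoming Error ends up in the output Error.
err_from_check_error, out = checkify.checkify(via_check_error)(err_from_check, -1.0)

print(err_from_check.get())
print(err_from_check_error.get())
```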

Unlike check(), however, this function acts as a direct inverse of checkify(). checkify() takes a function that can raise a Python Exception and produces a new function without that effect, which instead returns an Error value as an extra output. check_error does the reverse: it accepts an Error value as input and can produce the side effect of raising an Exception. In other words, checkify() converts a functionalizable Exception effect into an error value, while check_error converts an error value back into a functionalizable Exception effect.
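The inverse relationship can be shown with a round trip. In this minimal sketch, `f` is a hypothetical function. Checkifying `check_error` itself maps an Error value to an equivalent Error value. Calling `check_error` eagerly on the result then raises again.

```python
from jax.experimental import checkify

def f(x):
    # Hypothetical function whose check fails for non-positive x.
    checkify.check(x > 0, "must be positive!")
    return x

# checkify(): Exception effect -> Error value.
err, _ = checkify.checkify(f)(-1.0)
ok_err, _ = checkify.checkify(f)(1.0)

# checkify(check_error): Error value -> Exception effect -> Error value again.
round_trip_err, _ = checkify.checkify(checkify.check_error)(err)
round_trip_ok, _ = checkify.checkify(checkify.check_error)(ok_err)

# Plain check_error: Error value -> raised Exception.
try:
    checkify.check_error(round_trip_err)
    raised = False
except Exception:  # exact exception class is version-dependent
    raised = True
```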

check_error is useful when you want to turn checks represented by an Error value (produced by functionalizing checks via checkify()) back into Python Exceptions.

Parameters:

error (Error) – Error to check.

Return type:

None

For example, you might want to functionalize part of your program with checkify(), stage out the functionalized code with jit(), and then re-inject the error value as an Exception outside of the jit():

>>> import jax
>>> from jax.experimental import checkify
>>> def f(x):
...   checkify.check(x > 0, "must be positive!")
...   return x
>>> def with_inner_jit(x):
...   checked_f = checkify.checkify(f)
...   # a checkified function can be jitted
...   error, out = jax.jit(checked_f)(x)
...   checkify.check_error(error)
...   return out
>>> _ = with_inner_jit(1)  # no failed check
>>> with_inner_jit(-1)  
Traceback (most recent call last):
  ...
jax._src.JaxRuntimeError: must be positive!
>>> # with_inner_jit can itself be checkified again
>>> error, _ = checkify.checkify(with_inner_jit)(-1)
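This minimal sketch shows what re-checkifying gives you. No Exception is raised; the failure is recorded in the returned Error instead. You can inspect it with `Error.get()`, which returns a message string on failure and `None` otherwise. The definitions are repeated so the snippet runs on its own.

```python
import jax
from jax.experimental import checkify

def f(x):
    checkify.check(x > 0, "must be positive!")
    return x

def with_inner_jit(x):
    # A checkified function can be jitted; the Error comes back as a value.
    error, out = jax.jit(checkify.checkify(f))(x)
    checkify.check_error(error)  # re-inject the error as an Exception effect
    return out

# Re-checkifying turns the Exception effect back into an Error value.
error, _ = checkify.checkify(with_inner_jit)(-1)
msg = error.get()  # message string if a check failed, else None

ok_error, out = checkify.checkify(with_inner_jit)(1)
```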