jax.experimental.checkify.checkify(f, errors=frozenset({<class 'jax._src.checkify.FailedCheckError'>}))

Functionalize check calls in fun, and optionally add run-time error checks.

Run-time errors are either user-added check() assertions, or automatically added checks like NaN checks, depending on the errors argument.

The returned function will return an Error object err along with the output of the original function. err.get() will either return None (if no error occurred) or a string containing an error message. This error message will correspond to the first error which occurred. err.throw() will raise a ValueError with the error message if an error occurred.
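As a minimal sketch of the Error object described above, the following wraps a function containing one user-added check and inspects the result with err.get() and err.throw(). The function name safe_log and its message are illustrative assumptions, not part of the library.

```python
import jax.numpy as jnp
from jax.experimental import checkify

def safe_log(x):
    # User-added assertion (hypothetical example function and message).
    checkify.check(x > 0, "x must be positive")
    return jnp.log(x)

# With the default errors argument, only user-added checks are enabled.
checked_log = checkify.checkify(safe_log)

err_ok, out_ok = checked_log(jnp.float32(2.0))
err_bad, out_bad = checked_log(jnp.float32(-1.0))

msg_ok = err_ok.get()    # None, because no check failed
msg_bad = err_bad.get()  # a string containing the error message

# throw() raises only when an error was recorded.
try:
    err_bad.throw()
    raised = False
except ValueError:
    raised = True
```

Note that the wrapped call itself never raises; the failure only surfaces when err is inspected.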

By default only user-added check() assertions are enabled. You can enable automatic checks through the errors argument.

The check sets which can be enabled, and the condition under which each one generates an error:
  • user_checks: a check() evaluated to False.

  • nan_checks: a floating-point operation generated a NaN value as output.

  • div_checks: a division by zero.

  • index_checks: an index was out-of-bounds.
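The sketch below enables two of the sets listed above one at a time, to show the kind of operation each one reports. The function names scale and pick are hypothetical, and the exact wording of the error messages may differ between JAX versions, so the comments only describe what each message is about.

```python
import jax.numpy as jnp
from jax.experimental import checkify

def scale(x, d):
    # Division whose divisor may be zero (hypothetical example function).
    return x / d

def pick(xs, i):
    # Indexing whose index may be out of bounds (hypothetical example function).
    return xs[i]

# div_checks: report a division by zero.
err_div, _ = checkify.checkify(scale, errors=checkify.div_checks)(
    jnp.array([1.0, 2.0]), jnp.float32(0.0))
msg_div = err_div.get()

# index_checks: report an out-of-bounds index.
err_idx, _ = checkify.checkify(pick, errors=checkify.index_checks)(
    jnp.arange(3.0), 5)
msg_idx = err_idx.get()

# The same calls with in-range inputs produce no error.
err_fine, _ = checkify.checkify(pick, errors=checkify.index_checks)(
    jnp.arange(3.0), 1)
msg_fine = err_fine.get()
```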

Multiple categories can be enabled together by passing in an error set (e.g. errors=float_checks, which enables both NaN and division checks). Multiple sets can be combined through set operations (e.g. errors=float_checks | user_checks).
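A short sketch of how the errors argument changes what is reported, assuming a hypothetical function f that contains one user check and one floating-point operation. With the default argument only the user check is active; the union float_checks | user_checks reports both kinds of failure.

```python
import jax.numpy as jnp
from jax.experimental import checkify

def f(x):
    # User check plus a float operation that yields NaN for infinite input.
    checkify.check(x != 0, "x must be nonzero")
    return x + jnp.sin(x)

# Default: user checks only, so the NaN from sin(inf) is not reported.
err_default, _ = checkify.checkify(f)(jnp.inf)
msg_default = err_default.get()

# Union of two sets: NaN and division checks together with user checks.
errors = checkify.float_checks | checkify.user_checks
checked_f = checkify.checkify(f, errors=errors)

err_nan, _ = checked_f(jnp.inf)
msg_nan = err_nan.get()    # reports the NaN generated by sin

err_user, _ = checked_f(0.0)
msg_user = err_user.get()  # reports the failed user check
```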

Parameters:

  • f (Callable[[...], Out]) – Callable which can contain user checks (see check()). This is the function referred to as fun in this description.

  • errors (frozenset[type[JaxException]]) – A set of ErrorCategory values which defines the set of enabled checks. By default only explicit checks are enabled (user_checks). You can also, for example, enable NaN and division-by-zero errors by passing the float_checks set, or combine multiple sets through set operations (float_checks | user_checks).

Returns:

A function which accepts the same arguments as fun and returns as output a pair where the first element is an Error value, representing the first failed check(), and the second element is the original output of fun.

Return type:

Callable[[…], tuple[Error, Out]]

For example:

>>> import jax
>>> import jax.numpy as jnp
>>> from jax.experimental import checkify
>>> @jax.jit
... def f(x):
...   y = jnp.sin(x)
...   return x+y
>>> err, out = checkify.checkify(f, errors=checkify.float_checks)(jnp.inf)
>>> err.throw()  
Traceback (most recent call last):
jax._src.checkify.JaxRuntimeError: nan generated by primitive: sin
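The example above applies checkify to a function that is already jitted. As a complementary sketch, the checkified function can also be passed to jax.jit, since the Error value is returned as a regular output. The function normalize and its message are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def normalize(x):
    total = jnp.sum(x)
    # User check on a traced value (hypothetical example function).
    checkify.check(total != 0, "sum of x must be nonzero")
    return x / total

# checkify first, then jit the resulting (Error, output) function.
checked_normalize = jax.jit(checkify.checkify(normalize))

err_ok, out_ok = checked_normalize(jnp.array([1.0, 3.0]))
err_bad, out_bad = checked_normalize(jnp.zeros(2))

msg_ok = err_ok.get()    # None: the check passed
msg_bad = err_bad.get()  # message from the failed user check
```

Only the user check is enabled here, so the NaN values produced by dividing zero by zero in the second call are not reported on their own; the error comes from the explicit check.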