jax.experimental.checkify.check

jax.experimental.checkify.check(pred, msg, *fmt_args, **fmt_kwargs)

Check a predicate and add an error with msg if the predicate is False.

This is an effectful operation, so it cannot be staged (jitted, scanned, etc.) on its own. Before staging a function that contains checks, transform it with checkify().
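
A minimal sketch of the staging rule above. It assumes a recent JAX release in which XLA runtime errors are disabled by default; the function f is a hypothetical example. Staging f directly with jax.jit is expected to fail, while wrapping it with checkify.checkify first returns the error as an ordinary value alongside the output.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def f(x):  # hypothetical example function
  checkify.check(x > 0, "x must be positive")
  return jnp.log(x)

# Staging f without checkify: the effectful check has not been
# functionalized, so tracing/lowering under jit is expected to raise.
staged_without_checkify_failed = False
try:
  jax.jit(f)(1.0)
except Exception:
  staged_without_checkify_failed = True

# checkify first, then stage: the error comes back as a value.
err, out = jax.jit(checkify.checkify(f))(1.0)
no_error = err.get()  # None, because 1.0 > 0
```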

Parameters:
  • pred (Union[bool, Array]) – if False, a FailedCheckError is added.

  • msg (str) – error message used if the error is added. Can be a format string.

  • fmt_args – Positional formatting arguments for msg, e.g. in check(.., "check failed on values {} and {named_arg}", x, named_arg=y), x fills the {} slot. These arguments can be traced values, which lets you include run-time values in the error message. Note that tracking these run-time arrays increases memory usage, even if no error occurs.

  • fmt_kwargs – Keyword formatting arguments for msg, e.g. named_arg=y fills the {named_arg} slot in the call shown for fmt_args. Like fmt_args, these can be traced values, with the same memory cost.
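
A minimal sketch of passing both positional (fmt_args) and keyword (fmt_kwargs) formatting arguments. The function g and the names bound, x and y are hypothetical. Because x and y are traced under jax.jit, their run-time values are substituted into the message returned by err.get().

```python
import jax
from jax.experimental import checkify

def g(x, y):  # hypothetical example function
  # x fills the positional {} slot; y fills the named {bound} slot.
  checkify.check(x < y, "expected {} < {bound}", x, bound=y)
  return y - x

checked_g = jax.jit(checkify.checkify(g))

err_ok, _ = checked_g(1, 2)
passing_message = err_ok.get()  # None: the check held

err_bad, _ = checked_g(5, 2)
failing_message = err_bad.get()  # formatted with the run-time values 5 and 2
```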

Return type:

None

For example:

>>> import jax
>>> import jax.numpy as jnp
>>> from jax.experimental import checkify
>>> def f(x):
...   checkify.check(x>0, "{x} needs to be positive!", x=x)
...   return 1/x
>>> checked_f = checkify.checkify(f)
>>> err, out = jax.jit(checked_f)(-3.)
>>> err.throw()  
Traceback (most recent call last):
  ...
jax._src.checkify.JaxRuntimeError: -3. needs to be positive!
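
As a variation on the example above, the sketch below reads the error with err.get() instead of raising it with err.throw(), and checks a whole array at once. It assumes, as we understand the current implementation, that check expects a scalar boolean predicate, so the elementwise comparison is reduced with jnp.all. The function safe_log is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def safe_log(x):  # hypothetical example function
  # Reduce the elementwise test to a single scalar boolean for check.
  checkify.check(jnp.all(x > 0), "all inputs must be positive, got min {}", jnp.min(x))
  return jnp.log(x)

checked_log = jax.jit(checkify.checkify(safe_log))

err, out = checked_log(jnp.array([1.0, 2.0, 3.0]))
ok_message = err.get()  # None when no check failed

err_bad, _ = checked_log(jnp.array([1.0, -2.0, 3.0]))
bad_message = err_bad.get()  # a string describing the failed check
```

Using err.get() is convenient when you want to log or inspect a failure without interrupting the program; err.throw() raises it instead.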