jax.experimental.checkify.checkify
- jax.experimental.checkify.checkify(f, errors=frozenset({<class 'jax._src.checkify.FailedCheckError'>}))
Functionalize check calls in f, and optionally add run-time error checks.
Run-time errors are either user-added check() assertions or automatically added checks such as NaN checks, depending on the errors argument.
The returned function returns an Error object err along with the output of the original function.
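As an illustration of the functionalized call, the sketch below wraps a hypothetical function safe_log (not part of JAX) that contains one user check. The checkified function never raises on its own; it returns the pair (err, out), and err.get() is None when the check passes and a message string when it fails.

```python
import jax.numpy as jnp
from jax.experimental import checkify


def safe_log(x):
    # Hypothetical example function. When this user-added assertion fails,
    # checkify records it in the returned Error value instead of raising here.
    checkify.check(x > 0, "x must be positive")
    return jnp.log(x)


# checkify wraps safe_log; the wrapped function returns a pair (err, out).
checked_log = checkify.checkify(safe_log)

err_ok, out_ok = checked_log(jnp.array(2.0))
err_bad, out_bad = checked_log(jnp.array(-1.0))

print(err_ok.get())   # None, because the check passed
print(err_bad.get())  # a string describing the failed check
print(out_ok)         # log(2.0), the original output of safe_log
```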
err.get() returns either None (if no error occurred) or a string containing an error message. The message corresponds to the first error that occurred. err.throw() raises a JaxRuntimeError (a subclass of ValueError) with that message if an error occurred.
By default only user-added check() assertions are enabled. You can enable automatic checks through the errors argument.
- The automatic check sets that can be enabled, and when each one generates an error:
  - user_checks: a check() evaluated to False.
  - nan_checks: a floating-point operation generated a NaN value as output.
  - div_checks: a division by zero.
  - index_checks: an index was out-of-bounds.
Categories are enabled by passing a set of them as errors (e.g. errors=nan_checks). Sets can be combined with set union to enable several categories at once (e.g. errors=float_checks | user_checks).
- Parameters:
  - f – Callable which can contain user checks (see check()).
  - errors (frozenset[type[JaxException]]) – A set of ErrorCategory values which defines the set of enabled checks. By default only explicit checks are enabled (user_checks). You can also, for example, enable NaN and division-by-zero errors by passing the float_checks set, or combine multiple sets through set operations (float_checks | user_checks).
- Return type:
  Callable[..., tuple[Error, Out]]
- Returns:
  A function which accepts the same arguments as f and returns a pair: the first element is an Error value representing the first error that occurred (for example, the first failed check()), and the second element is the original output of f.
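To show how the two elements of the returned pair are used, the sketch below checkifies a hypothetical checked_sqrt function (not part of JAX). The original output is still computed when the check fails, and err.throw() raises the recorded error on the host. It assumes, as in current JAX releases, that the raised JaxRuntimeError subclasses ValueError.

```python
import jax.numpy as jnp
from jax.experimental import checkify


def checked_sqrt(x):
    # Hypothetical example function with one user check.
    checkify.check(x >= 0, "sqrt of a negative number")
    return jnp.sqrt(x)


err, out = checkify.checkify(checked_sqrt)(jnp.array(-4.0))

# The second element is f's own output; only user_checks are enabled by
# default, so the NaN from sqrt(-4.0) is returned without being flagged.
print(out)  # nan

# err.throw() raises the first recorded error.
raised = None
try:
    err.throw()
except ValueError as e:  # assumption: JaxRuntimeError subclasses ValueError
    raised = e
print("caught:", raised)
```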
For example:
>>> import jax
>>> import jax.numpy as jnp
>>> from jax.experimental import checkify
>>>
>>> @jax.jit
... def f(x):
...   y = jnp.sin(x)
...   return x+y
>>> err, out = checkify.checkify(f, errors=checkify.float_checks)(jnp.inf)
>>> err.throw()
Traceback (most recent call last):
  ...
jax._src.checkify.JaxRuntimeError: nan generated by primitive: sin
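Because checkify functionalizes errors into a returned value, the checkified function can itself be passed to jax.jit, with the Error value coming back as an ordinary output of the compiled computation. A minimal sketch reusing the function from the example above:

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def f(x):
    return x + jnp.sin(x)


# jit the checkified function; err is returned alongside the output.
checked_f = jax.jit(checkify.checkify(f, errors=checkify.float_checks))

err_inf, _ = checked_f(jnp.inf)
err_one, out_one = checked_f(1.0)

print(err_inf.get())  # message naming sin as the NaN-producing primitive
print(err_one.get())  # None
```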