jax.experimental.checkify.check_error
- jax.experimental.checkify.check_error(error)
Raise an Exception if error represents a failure. Functionalized by checkify().

The semantics of this function are equivalent to:

>>> def check_error(err: Error) -> None:
...   err.throw()  # can raise ValueError

But unlike that implementation, check_error can be functionalized using the checkify() transformation.

This function is similar to
check() but with a different signature: whereas check() takes as arguments a boolean predicate and a new error message string, this function takes an Error value as argument. Both check() and this function raise a Python Exception on failure (a side effect), and thus cannot be staged out by jit(), pmap(), scan(), etc. Both can also be functionalized using checkify().

But unlike
check(), this function is like a direct inverse of checkify(): whereas checkify() takes as input a function which can raise a Python Exception and produces a new function without that effect but which produces an Error value as output, check_error accepts an Error value as input and can produce the side effect of raising an Exception. That is, while checkify() goes from functionalizable Exception effect to error value, check_error goes from error value to functionalizable Exception effect.

check_error is useful when you want to turn checks represented by an Error value (produced by functionalizing checks via checkify()) back into Python Exceptions.

For example, you might want to functionalize part of your program through checkify, stage out your functionalized code through
jit(), then re-inject your error value outside of the jit():

>>> import jax
>>> from jax.experimental import checkify
>>> def f(x):
...   checkify.check(x > 0, "must be positive!")
...   return x
>>> def with_inner_jit(x):
...   checked_f = checkify.checkify(f)
...   # a checkified function can be jitted
...   error, out = jax.jit(checked_f)(x)
...   checkify.check_error(error)
...   return out
>>> _ = with_inner_jit(1)  # no failed check
>>> with_inner_jit(-1)
Traceback (most recent call last):
  ...
jax._src.JaxRuntimeError: must be positive!
>>> # can re-checkify
>>> error, _ = checkify.checkify(with_inner_jit)(-1)
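
The round trip described above (exception effect to Error value, then back again) can also be inspected without letting the exception propagate. The following is a minimal sketch, not part of the original docstring. The helper name `reraise` and the variable names are hypothetical, and the exact exception class raised may differ between JAX versions, so the sketch catches a generic `Exception`.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def f(x):
    # Same user check as in the example above.
    checkify.check(x > 0, "must be positive!")
    return x


# checkify() replaces the exception effect with a returned Error value,
# so the resulting function can be staged out with jit().
checked_f = jax.jit(checkify.checkify(f))

ok_err, _ = checked_f(jnp.asarray(1.0))
bad_err, _ = checked_f(jnp.asarray(-1.0))

# Error.get() returns None if no check failed, otherwise a message string.
ok_msg = ok_err.get()
bad_msg = bad_err.get()

# check_error() does nothing for a passing Error and raises for a failing one.
checkify.check_error(ok_err)
try:
    checkify.check_error(bad_err)
    raised_by_check_error = False
except Exception:  # exception class intentionally not assumed
    raised_by_check_error = True

# Outside of any transformation, Error.throw() behaves the same way.
try:
    bad_err.throw()
    raised_by_throw = False
except Exception:
    raised_by_throw = True


def reraise(x):
    # Hypothetical helper: re-injects the error as an exception effect.
    err, out = checked_f(x)
    checkify.check_error(err)
    return out


# Because check_error() is functionalizable, checkify() can turn the
# re-injected effect back into an Error value instead of raising.
err_again, _ = checkify.checkify(reraise)(jnp.asarray(-1.0))
again_msg = err_again.get()
```

Here `ok_msg` is `None`, while `bad_msg` and `again_msg` both contain the text passed to `check()`. Depending on the JAX version, the message may also carry additional location information, so matching on a substring is safer than comparing for equality.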