# Source code for jax._src.checkify

# Copyright 2021 The JAX Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from collections.abc import Sequence
import dataclasses
import functools
import itertools as it
from typing import Callable, TypeVar, Any, Union

import numpy as np

import jax.numpy as jnp
from jax import dtypes
from jax import lax

from jax._src import api
from jax._src import linear_util as lu
from jax._src import config
from jax._src import core
from jax._src import custom_derivatives
from jax._src import effects
from jax._src import pjit
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import tree_util as jtu
from jax._src.ad_util import SymbolicZero
from jax._src.api_util import flatten_fun
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.tree_util import tree_flatten
from jax._src.tree_util import tree_map
from jax._src.tree_util import tree_unflatten
from jax._src.typing import Array
from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip,
                           unzip3, weakref_lru_cache, HashableWrapper)


# Keep handles on the builtin (lazy, length-agnostic) versions before the
# module-level names are rebound to the length-checking safe variants.
unsafe_map, unsafe_zip = map, zip
map, zip = safe_map, safe_zip

# Type aliases used throughout this module.
Bool = Union[bool, Array]
Int = Union[int, Array]
ErrorCategory = type['JaxException']
Payload = list[Union[np.ndarray, Array]]
PyTreeDef = jtu.PyTreeDef
Out = TypeVar('Out')

## Utils

def popattr(obj, attrname):
  """Deletes attribute `attrname` from `obj` and returns the value it held."""
  popped = getattr(obj, attrname)  # AttributeError if it does not exist
  delattr(obj, attrname)
  return popped

def setnewattr(obj, name, val):
  """Sets `obj.<name> = val`, asserting the attribute was not already present."""
  # hasattr() and getattr-with-default both treat only AttributeError as
  # "absent", so this is the same membership test as a sentinel lookup.
  assert not hasattr(obj, name)
  setattr(obj, name, val)

# Concrete errors

class JaxException(Exception):
  """Python exception which can contain an error message with JAX run-time info.

  Every subclass is registered as a pytree node class when it is defined, so
  an error instance can be flattened into its array payload (leaves) plus
  static metadata (e.g. traceback info), and rebuilt via `tree_unflatten`.
  """

  def __init__(self, traceback_info):
    self.traceback_info = traceback_info
    # TODO(lenamartens): re-enable tracebacks when they don't leak tracers.
    # self.with_traceback(self.traceback_info)

  def __init_subclass__(cls, **kwargs):
    # Register each error type as a pytree so `tree_flatten`/`tree_unflatten`
    # dispatch to the (possibly overridden) methods below.
    super().__init_subclass__(**kwargs)
    jtu.register_pytree_node_class(cls)

  def tree_flatten(self):
    # No array leaves; the traceback info is static aux data.
    return ([], self.traceback_info)

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    del payload  # this base error carries no array payload
    return cls(metadata)

  def get_effect_type(self) -> ErrorEffect:
    """Returns the ErrorEffect for this error; subclasses must override."""
    raise NotImplementedError

@dataclasses.dataclass(eq=True, frozen=True)
class ErrorEffect(effects.Effect):
  """Effect tagging a computation that may raise a given JaxException type.

  `shape_dtypes` describes the array payload carried by errors of this type.
  `<` orders effects by (str(error_type), per-payload (shape, str(dtype))).
  """
  error_type: type[JaxException]
  shape_dtypes: tuple[api.ShapeDtypeStruct, ...]

  def __lt__(self, other: ErrorEffect):
    def sort_key(effect):
      # dtypes do not support `<`, so compare their string forms instead.
      payload_key = tuple((sd.shape, str(sd.dtype))
                          for sd in effect.shape_dtypes)
      return str(effect.error_type), payload_key
    return sort_key(self) < sort_key(other)


class DivisionByZeroError(JaxException):
  """Error recorded when a divisor contains a zero."""

  def get_effect_type(self):
    # This error type carries no array payload.
    return ErrorEffect(error_type=DivisionByZeroError, shape_dtypes=())

  def __str__(self):
    return 'division by zero'

class NaNError(JaxException):
  """Error recorded when a primitive produces a NaN in its output."""

  def __init__(self, traceback_info, primitive_name):
    # Must initialize the base so `traceback_info` exists for tree_flatten.
    super().__init__(traceback_info)
    self.prim = primitive_name

  def tree_flatten(self):
    # No array leaves: traceback info and primitive name are static metadata.
    return ([], (self.traceback_info, self.prim))

  @classmethod
  def tree_unflatten(cls, metadata, _):
    return cls(*metadata)

  def get_effect_type(self):
    return ErrorEffect(NaNError, ())

  def __str__(self):
    return f'nan generated by primitive: {self.prim}.'

class OOBError(JaxException):
  """Error recorded when an indexing primitive uses an out-of-bounds index.

  `payload` holds three integers read by `__str__`: the offending index, the
  axis it indexes, and that axis' size (declared as a (3,) int32 array in
  `get_effect_type`).
  """

  def __init__(self, traceback_info, primitive_name, operand_shape, payload):
    # Must initialize the base so `traceback_info` exists for tree_flatten.
    super().__init__(traceback_info)
    self.prim = primitive_name
    self.operand_shape = operand_shape
    self._payload = payload

  def tree_flatten(self):
    # The payload array is the only leaf; everything else is static.
    return ([self._payload], (self.traceback_info, self.prim, self.operand_shape))

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    return cls(*metadata, payload[0])

  def __str__(self):
    return (f'out-of-bounds indexing for array of '
            f'shape {self.operand_shape}: '
            f'index {self._payload[0]} is out of bounds for axis '
            f'{self._payload[1]} with size {self._payload[2]}. ')

  def get_effect_type(self):
    return ErrorEffect(OOBError, (api.ShapeDtypeStruct((3,), jnp.int32),))

class FailedCheckError(JaxException):
  """Error recorded when a `check` predicate fails.

  The message is `fmt_string.format(*args, **kwargs)`. The format arguments
  are the pytree leaves of this error, so they travel as its payload.
  """

  def __init__(self, traceback_info, fmt_string, *a, **k):
    # Must initialize the base so `traceback_info` exists for tree_flatten.
    super().__init__(traceback_info)
    self.fmt_string = fmt_string
    # Deliberately overrides Exception.args with just the format arguments.
    self.args = a
    self.kwargs = k

  def tree_flatten(self):
    return ((self.args, self.kwargs),  # leaves
            (self.traceback_info, self.fmt_string))  # treedef

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    args, kwargs = payload
    return cls(*metadata, *args, **kwargs)

  def __str__(self):
    return (self.fmt_string.format(*self.args, **self.kwargs)
            + ' (`check` failed)')

  def get_effect_type(self):
    # The effect is keyed on the error type plus the shape/dtype of every
    # format argument, since those are the payload arrays.
    vals = jtu.tree_leaves((self.args, self.kwargs))
    return ErrorEffect(
        FailedCheckError,
        tuple(api.ShapeDtypeStruct(x.shape, x.dtype) for x in vals))

@dataclasses.dataclass
class BatchedError(JaxException):
  """Collection of errors from a batched (vmapped) error value.

  `error_mapping` maps each mapped index (a tuple of ints) at which an error
  occurred to the exception raised at that index. It must be non-empty; the
  traceback info of the first entry is used for this exception.
  """
  error_mapping: dict[tuple[int, ...], JaxException]

  def __post_init__(self):
    # The dataclass __init__ does not call JaxException.__init__, so set the
    # traceback info here from the first contained error.
    traceback_info = list(self.error_mapping.values())[0].traceback_info
    super().__init__(traceback_info)

  def __str__(self):
    return '\n'.join(f'at mapped index {", ".join(map(str, idx))}: {e}'
                     for idx, e in self.error_mapping.items())

# Error Value

@jtu.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Error:
  """Functional error value threaded through a checkified computation.

  Each dict is keyed by ErrorEffect (one entry per error type/payload layout):
    _pred: whether an error of that effect occurred (bool or bool array).
    _code: code of the recorded error, used as a key into `_metadata`.
    _metadata: maps a code to the treedef of the JaxException to rebuild.
    _payload: the array leaves of the recorded exception.
  """
  _pred: dict[ErrorEffect, Bool]
  _code: dict[ErrorEffect, Int]
  _metadata: dict[Int, PyTreeDef]  # mapping of code to JaxException treedef.
  _payload: dict[ErrorEffect, Payload]

  def get(self) -> str | None:
    """Returns error message if error happened, None if no error happened."""
    exp = self.get_exception()
    if exp is not None:
      return str(exp)
    return None

  def get_exception(self) -> JaxException | None:
    """Returns Python exception if error happened, None if no error happened.

    When several effects fired, the one with the smallest code is reported.
    If any predicate is non-scalar (batched), a BatchedError is returned.
    """
    if any(map(np.shape, self._pred.values())):
      return self._get_batched_exception()
    min_code = None
    cur_effect = None
    for error_effect, code in self._code.items():
      if self._pred[error_effect]:
        if min_code is None or code < min_code:
          min_code = code
          cur_effect = error_effect
    if cur_effect is not None:
      return tree_unflatten(self._metadata[int(min_code)],  # type: ignore
                            self._payload[cur_effect])
    return None

  def throw(self):
    _check_error(self)

  def __str__(self):
    return f'Error({self.get()})'

  # Internal helpers

  def _get_batched_exception(self) -> BatchedError | None:
    """Builds a BatchedError over every mapped index that holds an error."""
    # NOTE(review): assumes all predicates share the first one's batch shape.
    shape = np.shape(list(self._pred.values())[0])
    error_mapping = {}
    for idx in np.ndindex(*shape):
      min_code = None
      cur_effect = None
      for error_effect, code in self._code.items():
        if self._pred[error_effect][idx]:  # type: ignore
          if min_code is None or code[idx] < min_code:
            min_code = code[idx]  # type: ignore
            cur_effect = error_effect
      if cur_effect is not None:
        payload = tree_map(lambda x, i=idx: x[i], self._payload[cur_effect])
        jax_error = tree_unflatten(self._metadata[int(min_code)], payload)  # type: ignore
        error_mapping[idx] = jax_error
    if error_mapping:
      return BatchedError(error_mapping)
    return None

  def _update(self, effect_type: ErrorEffect, pred, code, metadata, payload):
    """Returns a new Error with `effect_type`'s entries replaced."""
    new_errs = {**self._pred, effect_type: pred}
    new_codes = {**self._code, effect_type: code}
    new_payload = {**self._payload, effect_type: payload}
    new_metadata = {**self._metadata, **metadata}
    return Error(new_errs, new_codes, new_metadata, new_payload)

  def _add_placeholder_effects(self, effects: set[ErrorEffect]):
    """Fill out Error with `effects` and jnp.ones arrays of their payloads."""
    new_err = self._pred.copy()
    new_code = self._code.copy()
    new_payload = self._payload.copy()
    for effect in effects:
      if effect not in self._pred:
        new_err[effect] = False
        new_payload[effect] = list(
            tree_map(lambda a: jnp.ones(a.shape, a.dtype), effect.shape_dtypes))
        # The error value associated with this effect will never become True, so
        # we don't need to set a meaningful code.
        new_code[effect] = -1
    return Error(new_err, new_code, self._metadata, new_payload)

  def _replace(self, *args, **kwargs):
    return dataclasses.replace(self, *args, **kwargs)

  # PyTree methods

  def tree_flatten(self):
    # Predicates, codes and payloads are leaves; the metadata dict is static.
    return ((self._pred, self._code, self._payload), (self._metadata))

  @classmethod
  def tree_unflatten(cls, metadata, data):
    pred, code, payload = data
    return cls(pred, code, metadata, payload)
init_error = Error({}, {}, {}, {})  # value used as initial (empty) error.
next_code = it.count(1).__next__  # globally unique ids, could be uuid4


def assert_func(error: Error, pred: Bool, new_error: JaxException) -> Error:
  """Returns `error` updated to record `new_error` where `pred` is true.

  A fresh code is allocated; the exception's treedef is stored as metadata
  under that code and its flattened leaves become the payload.
  """
  code = next_code()
  effect_type = new_error.get_effect_type()
  new_payload, new_metadata = tree_flatten(new_error)
  return update_error(error, pred, code, {code: new_metadata}, new_payload,
                      effect_type)


def update_error(error, pred, code, metadata, payload, effect_type):
  """Merges a (pred, code, payload) triple into `error` under `effect_type`.

  If an error of this effect was already recorded (its predicate is set),
  the existing code and payload are kept; otherwise the new ones are used.
  """
  err_of_type = error._pred.get(effect_type, False)
  out_err = err_of_type | pred
  out_code = lax.select(err_of_type, error._code.get(effect_type, -1), code)
  cur_payload = error._payload.get(effect_type, None)
  if cur_payload is not None:
    out_payload = tree_map(functools.partial(lax.select, err_of_type),
                           cur_payload, payload)
  else:
    out_payload = payload
  return error._update(effect_type, out_err, out_code, metadata, out_payload)


## Checkify transformation for plumbing functional error values.


@lu.transformation_with_aux
def _flatten_and_get_error_metadata_thunk(*invals):
  # Flattens (error, out) and reports the treedef plus the error's effects.
  error, out = yield invals, {}
  out_vals, out_tree = jtu.tree_flatten((error, out))
  yield out_vals, (out_tree, set(error._pred.keys()))


def default_checkify_rule(primitive: core.Primitive, error: Error,
                          enabled_errors, *invals: core.Value,
                          **params: Any) -> tuple[Error, Sequence[core.Value]]:
  """Default rule for primitives in `checkify` interpreter."""
  if 'call_jaxpr' not in params:
    # Default non-HOP case: just call primitive and don't update error.
    return error, primitive.bind(*invals, **params)

  # Code below handles call- and map-primitives, by recursively calling
  # checkify_jaxpr.
  err_vals, err_tree = jtu.tree_flatten(error)
  num_error_vals = len(err_vals)
  if 'donated_invars' in params:
    params = dict(params, donated_invars=(*[False]*num_error_vals,
                                          *params['donated_invars']))

  # call_jaxpr handling
  call_jaxpr = params.pop('call_jaxpr')
  if isinstance(call_jaxpr, core.ClosedJaxpr):  # handle closed_call_p
    jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts
  else:
    jaxpr, consts = call_jaxpr, ()
  consts_ = tuple(HashableWrapper(c) for c in consts)
  partial_checkify = lu.hashable_partial(lu.wrap_init(
      checkify_jaxpr_flat_hashable), jaxpr, consts_, enabled_errors, err_tree)
  partial_checkify, metadata = _flatten_and_get_error_metadata_thunk(
      partial_checkify)

  # map-specific params handling.
  if isinstance(primitive, core.MapPrimitive):
    # Update `in_axes` and `out_axes_thunk` params for map primitive.
    out_val_axes = params.pop('out_axes')

    @as_hashable_function(closure=out_val_axes)
    def out_axes_thunk():
      out_err_num = metadata()[0].num_leaves - len(out_val_axes)
      return (*(0,)*out_err_num, *out_val_axes)

    params = dict(params,
                  in_axes=(*(None,)*num_error_vals, *params['in_axes']),
                  out_axes_thunk=out_axes_thunk)

  all_vals = primitive.bind(partial_checkify, *err_vals, *invals, **params)

  out_tree, _ = metadata()
  error, out_vals = tree_unflatten(out_tree, all_vals)
  if isinstance(primitive, core.MapPrimitive):
    error = _reduce_any_error(error)
  return error, out_vals


def get_shaped_aval(val):
  return core.raise_to_shaped(core.get_aval(val))


def checkify_jaxpr(jaxpr: core.ClosedJaxpr, enabled_errors,
                   error: Error, *args) -> tuple[Error, list[core.Value]]:
  err_vals, err_tree = jtu.tree_flatten(error)
  return checkify_jaxpr_flat(jaxpr.jaxpr, jaxpr.consts,
                             enabled_errors, err_tree, *err_vals, *args)


def checkify_jaxpr_flat(jaxpr: core.Jaxpr, consts: Sequence[core.Value],
                        enabled_errors, err_tree: PyTreeDef,
                        *args: core.Value) -> tuple[Error, list[Any]]:
  """Interprets `jaxpr`, applying error-check rules to every equation.

  `args` holds the flattened error leaves (per `err_tree`) followed by the
  jaxpr's inputs. Returns the final error and the jaxpr's outputs.
  """
  env: dict[core.Var, Any] = {}
  err_vals, in_args = split_list(args, [err_tree.num_leaves])
  error = jtu.tree_unflatten(err_tree, err_vals)

  last_used = core.last_used(jaxpr)

  def read_env(var: core.Atom):
    if isinstance(var, core.Literal):
      return var.val
    return env[var]

  def write_env(var: core.Var, val: Any):
    env[var] = val

  map(write_env, jaxpr.constvars, consts)
  map(write_env, jaxpr.invars, in_args)

  # interpreter loop
  for eqn in jaxpr.eqns:
    invals = map(read_env, eqn.invars)
    checkify_rule = error_checks.get(
        eqn.primitive, functools.partial(default_checkify_rule, eqn.primitive))
    name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
    with source_info_util.user_context(eqn.source_info.traceback,
                                       name_stack=name_stack):
      error, outvals = checkify_rule(error, enabled_errors,
                                     *invals, **eqn.params)
    if eqn.primitive.multiple_results:
      map(write_env, eqn.outvars, outvals)
    else:
      write_env(eqn.outvars[0], outvals)
    core.clean_up_dead_vars(eqn, env, last_used)

  return error, map(read_env, jaxpr.outvars)


def checkify_jaxpr_flat_hashable(jaxpr, hashable_consts, enabled_errors,
                                 err_tree, *args):
  # Unwraps HashableWrapper-boxed consts before interpreting.
  consts = tuple(c.x for c in hashable_consts)
  return checkify_jaxpr_flat(jaxpr, consts, enabled_errors, err_tree, *args)


@lu.transformation_with_aux
def flatten_fun_output(*args):
  ans = yield args, {}
  yield tree_flatten(ans)


def _reduce_any_error(error: Error):
  """Collapses a mapped Error by picking, per effect, an index whose pred is set."""
  out_error = init_error
  for error_effect in error._pred.keys():
    errs, codes, payloads = (error._pred[error_effect],
                             error._code[error_effect],
                             error._payload[error_effect])
    # argsort places True predicates last, so [-1] selects one that fired
    # (if any fired at all).
    reduced_idx = jnp.argsort(errs)[-1]
    pred, code, payload = tree_map(lambda x, idx=reduced_idx: x[idx],
                                   (errs, codes, payloads))
    out_error = out_error._update(error_effect, pred, code, {}, payload)

  out_error = out_error._replace(_metadata=error._metadata)
  return out_error


## check_p primitive

check_p = core.Primitive('check')
check_p.multiple_results = True  # zero results


def _pp_check(eqn, context, settings) -> core.pp.Doc:
  """Pretty-prints a check equation, omitting the bulky `err_tree` param."""
  annotation = (source_info_util.summarize(eqn.source_info)
                if settings.source_info else None)
  name_stack_annotation = (f'[{eqn.source_info.name_stack}]'
                           if settings.name_stack else None)
  trimmed_params = sorted((k, v) for (k, v) in eqn.params.items()
                          if k != "err_tree")
  rhs = [core.pp.text(eqn.primitive.name, annotation=name_stack_annotation),
         core.pp_kv_pairs(trimmed_params, context, settings),
         core.pp.text(" ") + core.pp_vars(eqn.invars, context)]
  return core.pp.concat([core.pp.text("", annotation), *rhs])

core.pp_eqn_rules[check_p] = _pp_check

# TODO(lenamartens): inherit from Exception instead of ValueError.
class JaxRuntimeError(ValueError):
  """Exception raised when a checkify error is found to have occurred at run time."""
@check_p.def_impl
def check_impl(*args, err_tree, debug):
  """Eager implementation of check_p: raises if the error value is set."""
  if debug:
    # NOOP (check will only trigger when discharged)
    return []
  error = tree_unflatten(err_tree, args)
  exc = error.get_exception()
  if exc is not None:
    filtered_tb = traceback_util.filter_traceback(
        exc.traceback_info.as_python_traceback())
    exc.with_traceback(filtered_tb)
    raise JaxRuntimeError(str(exc)) from exc
  return []


@check_p.def_effectful_abstract_eval
def check_abstract_eval(*args, err_tree, debug):
  # No outputs; the effects are the error's effect keys.
  del debug
  return [], set(tree_unflatten(err_tree, args)._pred.keys())


# TODO(lenamartens) add in-depth error explanation to link to in module docs.
functionalization_error = ValueError(
    'Cannot abstractly evaluate a checkify.check which was not'
    ' functionalized. This probably means you tried to stage'
    ' (jit/scan/pmap/...) a `check` without functionalizing it'
    ' through `checkify.checkify`.'
    )


def check_lowering_rule(ctx, *args, err_tree, debug):
  """Lowers check_p to a host callback when XLA runtime errors are enabled."""
  if debug:
    # NOOP (check will only trigger when discharged)
    return []
  if not config.xla_runtime_errors.value:
    raise functionalization_error

  out_op, _, _ = mlir.emit_python_callback(
      ctx, callback=functools.partial(python_err, err_tree),
      token=None,
      operands=args,
      operand_avals=list(ctx.avals_in),
      result_avals=list(ctx.avals_out),
      has_side_effect=True)
  return out_op


def check_lowering_rule_unsupported(*a, debug, **k):
  if debug:
    return []
  raise functionalization_error


def python_err(err_tree, *args):
  # Host-side callback: rebuild the error and raise if it is set.
  error = tree_unflatten(err_tree, args)
  _check_error(error)
  return []

mlir.register_lowering(check_p, check_lowering_rule_unsupported,
                       platform='tpu')
mlir.register_lowering(check_p, check_lowering_rule,
                       platform='cpu')
mlir.register_lowering(check_p, check_lowering_rule,
                       platform='gpu')


def check_batching_rule(batched_args, batch_dims, *, err_tree, debug):
  # Move every batch dim to the front (broadcasting unmapped args), then check.
  size = next(x.shape[dim] for x, dim in zip(batched_args, batch_dims)
              if dim is not batching.not_mapped)
  batched_args = (batching.bdim_at_front(a, d, size)
                  for a, d in zip(batched_args, batch_dims))
  err = tree_unflatten(err_tree, batched_args)
  _check_error(err, debug=debug)
  return [], []
batching.primitive_batchers[check_p] = check_batching_rule


def check_jvp_rule(primals, _, *, err_tree, debug):
  # Check primals, discard tangents.
  check_p.bind(*primals, err_tree=err_tree, debug=debug)
  return [], []
ad.primitive_jvps[check_p] = check_jvp_rule

## checkify rules

ErrorCheckRule = Callable  # (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (Any, Error)
error_checks: dict[core.Primitive, ErrorCheckRule] = {}


def get_traceback():
  return source_info_util.current().traceback


def nan_error_check(prim, error, enabled_errors, *in_vals, **params):
  """Binds `prim` and records a NaNError if its output contains a NaN."""
  out = prim.bind(*in_vals, **params)
  err = check_nans(prim, error, enabled_errors, out)
  return err, out


def check_nans(prim, error, enabled_errors, out):
  """Returns `error` updated with a NaNError if `out` contains any NaN."""
  if NaNError not in enabled_errors:
    return error

  def isnan(x):
    # PRNG key arrays have no notion of NaN.
    if jnp.issubdtype(x.dtype, dtypes.prng_key):
      return False
    return jnp.any(jnp.isnan(x))

  any_nans = (jnp.any(jnp.array([isnan(x) for x in out]))
              if prim.multiple_results else isnan(out))
  return assert_func(error, any_nans, NaNError(get_traceback(), prim.name))

# All primitives which can generate a NaN.
nan_primitives = [lax.acos_p, lax.acosh_p, lax.add_p, lax.asin_p, lax.asinh_p,
                  lax.atan2_p, lax.atan_p, lax.atanh_p, lax.bessel_i0e_p,
                  lax.bessel_i1e_p, lax.cbrt_p, lax.conv_general_dilated_p,
                  lax.cos_p, lax.cosh_p, lax.cumlogsumexp_p, lax.cummax_p,
                  lax.cummin_p, lax.cumprod_p, lax.cumsum_p, lax.digamma_p,
                  lax.dot_general_p, lax.erf_inv_p, lax.erf_p, lax.erfc_p,
                  lax.exp_p, lax.expm1_p, lax.fft_p, lax.igamma_grad_a_p,
                  lax.igamma_p, lax.igammac_p, lax.integer_pow_p, lax.lgamma_p,
                  lax.linear_solve_p, lax.log1p_p, lax.log_p, lax.logistic_p,
                  lax.mul_p, lax.pad_p, lax.pow_p, lax.psum_p,
                  lax.random_gamma_grad_p, lax.reduce_p, lax.reduce_prod_p,
                  lax.reduce_sum_p, lax.reduce_window_p,
                  lax.reduce_window_sum_p, lax.regularized_incomplete_beta_p,
                  lax.rem_p, lax.rng_uniform_p, lax.rsqrt_p, lax.sin_p,
                  lax.sinh_p, lax.sqrt_p, lax.sub_p, lax.tan_p, lax.tanh_p]

for _prim in nan_primitives:
  error_checks[_prim] = functools.partial(nan_error_check, _prim)


def dynamic_slice_error_check(error, enabled_errors, operand, *start_indices,
                              slice_sizes):
  """Records an OOBError if a dynamic_slice would read past `operand`."""
  out = lax.dynamic_slice_p.bind(operand, *start_indices,
                                 slice_sizes=slice_sizes)

  if OOBError not in enabled_errors:
    return error, out

  operand_dims = np.array(operand.shape)
  slice_sizes = np.array(slice_sizes)
  start_indices = jnp.array(start_indices)
  oob_mask = (start_indices < 0) | (start_indices + slice_sizes > operand_dims)

  payload = oob_payload(oob_mask, start_indices, range(operand.ndim),
                        operand.shape)
  error = assert_func(error, jnp.any(oob_mask),
                      OOBError(get_traceback(), "dynamic_slice",
                               operand.shape, payload))
  return error, out
error_checks[lax.dynamic_slice_p] = dynamic_slice_error_check


def dynamic_update_slice_error_check(error, enabled_errors, operand, update,
                                     *start_indices):
  """Records an OOBError if a dynamic_update_slice would write past `operand`."""
  out = lax.dynamic_update_slice_p.bind(operand, update, *start_indices)

  if OOBError not in enabled_errors:
    return error, out

  operand_dims = np.array(operand.shape)
  update_dims = np.array(update.shape)
  start_indices = jnp.array(start_indices)
  oob_mask = (start_indices < 0) | (start_indices + update_dims > operand_dims)

  payload = oob_payload(oob_mask, start_indices, range(operand.ndim),
                        operand.shape)
  error = assert_func(error, jnp.any(oob_mask),
                      OOBError(get_traceback(), "dynamic_update_slice",
                               operand.shape, payload))
  return error, out
error_checks[lax.dynamic_update_slice_p] = dynamic_update_slice_error_check


def gather_error_check(error, enabled_errors, operand, start_indices, *,
                       dimension_numbers, slice_sizes, unique_indices,
                       indices_are_sorted, mode, fill_value):
  """Records an OOBError if any gather start index is out of bounds."""
  out = lax.gather_p.bind(
      operand, start_indices, dimension_numbers=dimension_numbers,
      slice_sizes=slice_sizes, unique_indices=unique_indices,
      indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value)

  if OOBError not in enabled_errors:
    return error, out

  # compare to OOB masking logic in lax._gather_translation_rule
  dnums = dimension_numbers
  operand_dims = np.array(operand.shape)
  num_batch_dims = len(start_indices.shape) - 1

  upper_bound = operand_dims[np.array(dnums.start_index_map)]
  upper_bound -= np.array(slice_sizes)[np.array(dnums.start_index_map)]
  upper_bound = jnp.expand_dims(upper_bound, axis=tuple(range(num_batch_dims)))
  oob_mask = ((start_indices < 0) |
              (start_indices > upper_bound.astype(start_indices.dtype)))

  payload = oob_payload(oob_mask, start_indices, dnums.start_index_map,
                        operand.shape)
  error = assert_func(error, jnp.any(oob_mask),
                      OOBError(get_traceback(), "gather", operand.shape,
                               payload))
  return error, out
error_checks[lax.gather_p] = gather_error_check


def div_error_check(error, enabled_errors, x, y):
  """Checks for division by zero and NaN."""
  if DivisionByZeroError in enabled_errors:
    any_zero = jnp.any(jnp.equal(y, 0))
    error = assert_func(error, any_zero, DivisionByZeroError(get_traceback()))
  return nan_error_check(lax.div_p, error, enabled_errors, x, y)
error_checks[lax.div_p] = div_error_check


def oob_payload(oob_mask, indices, dims_map, operand_shape):
  """Returns int32 [oob_index, oob_axis, oob_axis_size] for the first OOB entry.

  `dims_map` maps the last axis of `indices` to operand axes.
  """
  if oob_mask.size == 0:
    # No indices means nothing can be out of bounds (the predicate built from
    # this mask is False), and argmin over an empty array is not defined.
    return jnp.zeros((3,), dtype=jnp.int32)
  # Get first OOB index, axis and axis size so it can be added to the error msg.
  flat_idx = jnp.argmin(jnp.logical_not(oob_mask))
  multi_idx = jnp.unravel_index(flat_idx, indices.shape)
  oob_axis = jnp.array(dims_map)[multi_idx[-1]]
  oob_axis_size = jnp.array(operand_shape)[oob_axis]
  oob_index = jnp.ravel(indices)[flat_idx]
  payload = jnp.array([oob_index, oob_axis, oob_axis_size], dtype=jnp.int32)
  return payload


def scatter_oob(operand, indices, updates, dnums):
  """Returns (any_oob, payload) for a scatter's `indices`."""
  # Ref: see clamping code used in scatter_translation_rule
  slice_sizes = []
  pos = 0
  for i in range(len(operand.shape)):
    if i in dnums.inserted_window_dims:
      slice_sizes.append(1)
    else:
      slice_sizes.append(updates.shape[dnums.update_window_dims[pos]])
      pos += 1

  upper_bound = np.array([operand.shape[i] - slice_sizes[i]
                          for i in dnums.scatter_dims_to_operand_dims],
                         np.int64)
  upper_bound = np.minimum(upper_bound, np.iinfo(indices.dtype).max)
  upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape,
                                     (len(indices.shape) - 1,))

  lower_oob = jnp.less(indices, 0)
  upper_oob = jnp.greater(indices, upper_bound.astype(indices.dtype))
  oob_mask = jnp.logical_or(lower_oob, upper_oob)
  payload = oob_payload(oob_mask, indices,
                        dnums.scatter_dims_to_operand_dims, operand.shape)
  return jnp.any(oob_mask), payload


def scatter_error_check(prim, error, enabled_errors, operand, indices, updates,
                        *, update_jaxpr, update_consts, dimension_numbers,
                        indices_are_sorted, unique_indices, mode):
  """Checks if indices are within bounds and update does not generate NaN."""
  out = prim.bind(
      operand, indices, updates, update_jaxpr=update_jaxpr,
      update_consts=update_consts, dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
      mode=mode)

  if OOBError in enabled_errors:
    out_of_bounds, payload = scatter_oob(operand, indices, updates,
                                         dimension_numbers)
    oob_error = OOBError(get_traceback(), prim.name, operand.shape, payload)
    error = assert_func(error, out_of_bounds, oob_error)
  # The NaN check is independent of OOB checking: check_nans does its own
  # `enabled_errors` gating, so it must not be skipped when OOB is disabled.
  error = check_nans(prim, error, enabled_errors, out)
  return error, out
for _scatter_prim in (lax.scatter_p, lax.scatter_add_p, lax.scatter_mul_p,
                      lax.scatter_min_p, lax.scatter_max_p):
  error_checks[_scatter_prim] = functools.partial(scatter_error_check,
                                                  _scatter_prim)

# HOP error check rules

@weakref_lru_cache
def jaxpr_to_checkify_jaxpr(
    jaxpr: core.ClosedJaxpr, enabled_errors, err_tree: PyTreeDef,
    *flat_err_and_in_vals) -> tuple[core.ClosedJaxpr, PyTreeDef, set[ErrorEffect]]:
  """Traces a checkified version of `jaxpr`.

  `flat_err_and_in_vals` are the avals of the flattened error followed by the
  jaxpr's inputs. Returns (checked jaxpr, output treedef, error effects).
  """
  checkify_jaxpr_partial = functools.partial(checkify_jaxpr_flat, jaxpr.jaxpr,
                                             jaxpr.consts, enabled_errors,
                                             err_tree)
  fun = lu.wrap_init(checkify_jaxpr_partial)
  fun, metadata = _flatten_and_get_error_metadata_thunk(fun)

  new_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, flat_err_and_in_vals)
  checked_jaxpr = core.ClosedJaxpr(new_jaxpr, consts)
  out_tree, error_effects = metadata()
  return checked_jaxpr, out_tree, error_effects


def cond_error_check(error: Error, enabled_errors, index, *ops, branches, linear):
  """Checkify rule for cond: checkifies every branch with a shared error type."""
  # Get the error-effects out of all branches so the cond can be called with
  # a merged error with all these effects.
  err_vals, err_tree = jtu.tree_flatten(error)
  in_avals = map(get_shaped_aval, [*err_vals, *ops])

  def get_error_effects_from_jaxpr(jxpr):
    _, _, jxpr_effects = jaxpr_to_checkify_jaxpr(jxpr, enabled_errors, err_tree,
                                                 *in_avals)
    return jxpr_effects

  branch_effects = [get_error_effects_from_jaxpr(jxpr) for jxpr in branches]
  merged_error = error._add_placeholder_effects(set().union(*branch_effects))
  err_vals, err_tree = jtu.tree_flatten(merged_error)
  new_linear = (*[False] * len(err_vals), *linear)

  # Update branch jaxprs to be checkified jaxprs.
  in_avals = map(get_shaped_aval, [*err_vals, *ops])
  new_branches, out_trees, _ = unzip3(
      jaxpr_to_checkify_jaxpr(
          jxpr, enabled_errors, err_tree, *in_avals) for jxpr in branches)

  err_and_outs = lax.cond_p.bind(
      index, *err_vals, *ops,
      branches=tuple(new_branches), linear=new_linear)

  # we need to merge metadata across out_trees (a tuple)
  err0, out = tree_unflatten(out_trees[0], err_and_outs)
  merged_metadata = err0._metadata
  for tr in out_trees[1:]:
    err, _ = tree_unflatten(tr, err_and_outs)
    merged_metadata = {**merged_metadata, **err._metadata}
  return err0._replace(_metadata=merged_metadata), out
error_checks[lax.cond_p] = cond_error_check


def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
                     num_consts, num_carry, linear, unroll):
  """Checkify rule for scan: threads the error through the scan carry."""
  consts, carry, xs = split_list(in_flat, [num_consts, num_carry])
  xs_mapped = [core.mapped_aval(length, 0, get_shaped_aval(val)) for val in xs]
  # Query body effects to create a merged error containing all effects (such
  # that in and out carried error are of the same type).
  err_vals, err_tree = jtu.tree_flatten(error)
  new_in_aval = map(get_shaped_aval, [*err_vals, *consts, *carry]) + xs_mapped
  _, _, body_effects = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
                                               err_tree, *new_in_aval)

  merged_error = error._add_placeholder_effects(body_effects)
  err_vals, err_tree = jtu.tree_flatten(merged_error)

  # Create checked-jaxpr, with the needed pre-processing on the inputs.
  new_in_aval = map(get_shaped_aval, [*err_vals, *consts, *carry]) + xs_mapped
  checked_jaxpr_, out_tree, _ = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
                                                        err_tree, *new_in_aval)

  tomove = ([False] * len(err_vals) + [True] * len(consts)
            + [False] * (len(carry) + len(xs)))
  checked_jaxpr = pe.move_binders_to_front(checked_jaxpr_, tomove)
  new_in_flat = [*consts, *err_vals, *carry, *xs]
  # `linear` is positional per operand, so the error flags must sit between
  # the consts and the carry, matching `new_in_flat` above.
  new_linear = (*linear[:len(consts)], *[False] * len(err_vals),
                *linear[len(consts):])
  err_and_out = lax.scan_p.bind(
      *new_in_flat, reverse=reverse, length=length, jaxpr=checked_jaxpr,
      num_consts=len(consts), num_carry=len(carry)+len(err_vals),
      linear=new_linear, unroll=unroll)
  err, out = tree_unflatten(out_tree, err_and_out)
  return err, out

error_checks[lax.scan_p] = scan_error_check


def checkify_while_body_jaxpr(
    cond_jaxpr: core.ClosedJaxpr, body_jaxpr: core.ClosedJaxpr,
    enabled_errors, error: Error,
    c_consts_num: int) -> tuple[core.ClosedJaxpr, PyTreeDef, set[ErrorEffect]]:
  """Checkifies a while body that also evaluates the next cond application.

  The traced body's inputs are [error leaves, cond consts, body inputs].
  """
  cond_f = core.jaxpr_as_fun(cond_jaxpr)
  body_f = core.jaxpr_as_fun(body_jaxpr)

  def new_body_f(*c_consts_and_vals):
    c_consts, vals = split_list(c_consts_and_vals, [c_consts_num])
    out = body_f(*vals)
    # This checks if the next cond application will error
    _ = cond_f(*c_consts, *out)
    return out

  new_body_f_ = lu.wrap_init(new_body_f)
  c_consts_avals = cond_jaxpr.in_avals[:c_consts_num]
  jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(new_body_f_, [*c_consts_avals,
                                                             *body_jaxpr.in_avals])
  closed_jaxpr = pe.close_jaxpr(jaxpr)
  err_vals, err_tree = jtu.tree_flatten(error)
  err_vals = map(get_shaped_aval, err_vals)
  flat_err_and_in_vals = [*err_vals, *c_consts_avals, *body_jaxpr.in_avals]
  jaxpr, out_tree, error_effects = jaxpr_to_checkify_jaxpr(
      closed_jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals)
  return jaxpr, out_tree, error_effects


@weakref_lru_cache
def ignore_error_output_jaxpr(jaxpr, num_error_vals: int):
  """Constructs a checked jaxpr which does not output its error value."""
  consts = jaxpr.consts
  jaxpr = jaxpr.jaxpr
  new_jaxpr = jaxpr.replace(outvars=jaxpr.outvars[num_error_vals:])
  return core.ClosedJaxpr(new_jaxpr, consts)


def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
                           cond_jaxpr, body_nconsts, body_jaxpr):
  """Checkify rule for while_loop: the error is carried through the loop."""
  if cond_jaxpr.out_avals[0].shape:
    # TODO(lenamartens, sharadmv): support batched while.
    raise ValueError('Checkify does not support batched while-loops '
                     '(checkify-of-vmap-of-while). \nHint: if possible, move '
                     'the vmap to the outer level to get '
                     'vmap-of-checkify-of-while.')

  c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts])
  # Check if the first cond application will error.
  error, _ = checkify_jaxpr(cond_jaxpr, enabled_errors, error, *c_consts, *carry)

  _, _, error_effects = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr,
                                                  enabled_errors, error,
                                                  cond_nconsts)
  # merged error!
  error = error._add_placeholder_effects(error_effects)
  err_vals, err_tree = jtu.tree_flatten(error)
  checked_body_jaxpr_, body_out_tree, _ = checkify_while_body_jaxpr(
      cond_jaxpr, body_jaxpr, enabled_errors, error, cond_nconsts)
  num_error_vals = len(err_vals)
  to_move = ([False] * num_error_vals + [True] * cond_nconsts
             + [True] * body_nconsts + [False] * len(carry))
  checked_body_jaxpr = pe.move_binders_to_front(checked_body_jaxpr_, to_move)

  cond_in_flat = [*err_vals, *c_consts, *carry]
  cond_in_flat = map(get_shaped_aval, cond_in_flat)
  checked_cond_jaxpr, _, _ = jaxpr_to_checkify_jaxpr(cond_jaxpr, enabled_errors,
                                                     err_tree, *cond_in_flat)
  compat_cond_jaxpr_ = ignore_error_output_jaxpr(checked_cond_jaxpr,
                                                 num_error_vals)
  to_move = [False] * num_error_vals + [True] * cond_nconsts + [False] * len(carry)
  compat_cond_jaxpr = pe.move_binders_to_front(compat_cond_jaxpr_, to_move)

  # The body's consts are the cond consts followed by the body consts.
  new_in_flat = [*c_consts, *c_consts, *b_consts, *err_vals, *carry]
  all_out_vals = lax.while_p.bind(
      *new_in_flat, cond_nconsts=cond_nconsts, cond_jaxpr=compat_cond_jaxpr,
      body_nconsts=cond_nconsts+body_nconsts, body_jaxpr=checked_body_jaxpr)
  # body_out_tree will have all the metadata of cond because it executes a cond!
  error, out = tree_unflatten(body_out_tree, all_out_vals)
  return error, out
error_checks[lax.while_p] = while_loop_error_check


def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
                     in_shardings, out_shardings, resource_env,
                     donated_invars, name,
                     inline, keep_unused):
  """Checkify rule for pjit: passes the error in and out as extra operands."""
  # jaxpr to checked_jaxpr
  err_vals, err_tree = jtu.tree_flatten(error)
  new_vals_in = [*err_vals, *vals_in]
  in_avals = tuple(map(get_shaped_aval, new_vals_in))
  checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
                                                       err_tree, *in_avals)

  # Update pjit params to account for extra error values.
  num_error_vals = len(err_vals)
  num_out_error_vals = out_tree.num_leaves - len(out_shardings)
  sharding = sharding_impls.UNSPECIFIED

  new_in_shardings = (*[sharding] * num_error_vals, *in_shardings)
  new_out_shardings = (*[sharding] * num_out_error_vals, *out_shardings)
  new_donated_invars = (*[False] * num_error_vals, *donated_invars)

  err_and_out = pjit.pjit_p.bind(
      *new_vals_in,
      jaxpr=checked_jaxpr,
      in_shardings=new_in_shardings,
      out_shardings=new_out_shardings,
      resource_env=resource_env,
      donated_invars=new_donated_invars,
      name=name,
      inline=inline,
      keep_unused=keep_unused,
  )
  return tree_unflatten(out_tree, err_and_out)
error_checks[pjit.pjit_p] = pjit_error_check


def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts,
                         jvp_jaxpr_thunk, call_jaxpr, **params):
  # The types to have in mind are:
  #   jvp : (a -> b) -> (a, T a) -> (b, T b)
  #   checkify : (a -> b) -> a -> Err b
  #   jvp-of-checkify : (a -> b) -> (a, T a) -> (Err b, T (Err b))
  # where because Err is a pytree, we necessarily have T (Err b) = Err' (T b)
  # where the other Err' components are trivial (of float0 dtype).
  # Semantically, we don't add checks to the JVP rule. To check the result of a
  # JVP rule, one must instead use checkify-of-jvp. Thus this implementation
  # just forwards the input error and code (and trivial tangents) to the output.
  err_vals, err_tree = jtu.tree_flatten(in_err)
  partial_checkify = lu.wrap_init(
      functools.partial(checkify_jaxpr_flat, call_jaxpr.jaxpr,
                        call_jaxpr.consts, enabled_errors, err_tree))
  partial_checkify, f_metadata = _flatten_and_get_error_metadata_thunk(
      partial_checkify)
  jvp = lift_jvp(err_tree.num_leaves, num_consts, jvp_jaxpr_thunk)
  jvp, jvp_out_tree = flatten_fun_output(jvp)
  all_outs = custom_derivatives.custom_jvp_call_p.bind(
      partial_checkify, jvp, *err_vals, *in_vals, **params)
  fst, out_metadata = lu.merge_linear_aux(f_metadata, jvp_out_tree)
  if fst:
    err_and_out_tree, _ = out_metadata
    out_err, out_vals = tree_unflatten(err_and_out_tree, all_outs)
  else:
    err_vals, out_vals = split_list(all_outs, [len(err_vals)])
    # forward input error to output
    out_err = jtu.tree_unflatten(err_tree, err_vals)
  return out_err, out_vals
error_checks[custom_derivatives.custom_jvp_call_p] = custom_jvp_call_rule

# Compared to custom_derivatives.lift_jvp, we're handling the extra inputs and
# outputs that checkify adds (just forwarding the error data's primal and
# tangent components). The jaxpr in jvp_jaxpr_thunk doesn't expect those.
# TODO(mattjj): can we simplify this, or dedup with custom_derivatives.lift_jvp?
# Adding another layer of lu.transformation was tricky, though maybe doable.
def lift_jvp(num_errs, num_consts, jvp_jaxpr_thunk):
  """Wrap a custom JVP rule so that it also forwards checkify's error values.

  ``custom_jvp_call_rule`` binds the returned function with primal operands
  ordered ``[*errs, *consts, *args]``: ``num_errs`` error leaves first, then
  the equation's ``in_vals``, whose leading ``num_consts`` entries are
  constants. The primals are followed by tangents in the same order. Only
  ``args`` and their tangents are fed to the user's JVP jaxpr. The error
  primals and error tangents are forwarded unchanged as leading outputs.

  Args:
    num_errs: number of flattened error leaves.
    num_consts: number of leading constant operands in ``in_vals``.
    jvp_jaxpr_thunk: called with per-tangent symbolic-zero flags; returns
      ``(jvp_jaxpr, jvp_consts, out_zeros)``.

  Returns:
    A ``lu.WrappedFun`` producing
    ``[*primal_errs, *out_primals, *tangent_errs, *out_tangents]``.
  """
  @lu.wrap_init
  def jvp(*xs):
    n, ragged = divmod(len(xs), 2)
    assert not ragged
    num_skip = num_errs + num_consts
    primals, tangents = xs[num_skip:n], xs[n+num_skip:]
    zeros = [type(t) is SymbolicZero for t in tangents]
    jvp_jaxpr, jvp_consts, out_zeros = jvp_jaxpr_thunk(*zeros)
    nonzero_tangents = [t for t in tangents if type(t) is not SymbolicZero]
    out = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *nonzero_tangents)
    out_primals, nz_out_tangents = split_list(out, [len(out_zeros)])
    nz_out_tangents_ = iter(nz_out_tangents)
    out_tangents = [SymbolicZero(core.get_aval(p).at_least_vspace())
                    if z else next(nz_out_tangents_)
                    for p, z in zip(out_primals, out_zeros)]
    assert next(nz_out_tangents_, None) is None
    # Error leaves are bound *before* the constants, so they are the first
    # num_errs entries of each half (previously sliced after num_consts, which
    # returned constants instead of errors whenever num_consts > 0).
    primal_errs = xs[:num_errs]
    tangent_errs = xs[n:n+num_errs]
    return [*primal_errs, *out_primals, *tangent_errs, *out_tangents]
  return jvp


def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, fun_jaxpr,
                               fwd_jaxpr_thunk, num_consts, bwd, out_trees,
                               symbolic_zeros):
  """Checkify rule for ``custom_vjp_call_jaxpr_p``.

  The primal function is checkified. The forward rule is evaluated without
  added checks, and the backward rule is extended to return ``None`` for each
  error input. When the aux data comes from the forward rule, the input error
  is returned unchanged.
  """
  err_vals, err_tree = jtu.tree_flatten(in_err)
  num_errs = err_tree.num_leaves
  checkified_fun = lu.wrap_init(
      functools.partial(checkify_jaxpr_flat, fun_jaxpr.jaxpr,
                        fun_jaxpr.consts, enabled_errors, err_tree))
  checkified_fun, fun_metadata = _flatten_and_get_error_metadata_thunk(
      checkified_fun)

  @lu.wrap_init
  def checkified_fwd(*args):
    # TODO(lenamartens, sharadmv): why not checkify here?
    # args are split as interleaved (value, zeros) pairs; the leading error
    # entries are dropped before calling the user's forward jaxpr.
    xs, zeros = args[::2], args[1::2]
    xs, zeros = xs[num_errs:], zeros[num_errs:]
    fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk(*zeros)
    xs_without_consts = xs[num_consts:]
    return core.eval_jaxpr(fwd_jaxpr, fwd_consts, *xs_without_consts)

  # Error inputs receive no cotangent.
  bwd_ = lambda *args: (*(None,)*num_errs, *bwd(*args))

  checkified_fwd, fwd_out_tree = flatten_fun_output(checkified_fwd)
  all_outs = custom_derivatives.custom_vjp_call_p.bind(
      checkified_fun, checkified_fwd, bwd_, *err_vals, *in_vals,
      out_trees=out_trees, symbolic_zeros=symbolic_zeros)
  fst, out_metadata = lu.merge_linear_aux(fun_metadata, fwd_out_tree)
  if fst:
    err_and_out_tree, _ = out_metadata
    out_err, out_vals = tree_unflatten(err_and_out_tree, all_outs)
  else:
    out_err, out_vals = in_err, all_outs
  return out_err, out_vals
error_checks[custom_derivatives.custom_vjp_call_jaxpr_p] = custom_vjp_call_jaxpr_rule


def check_discharge_rule(error, enabled_errors, *args, err_tree, debug):
  """Discharge a ``check_p`` into the functional error value.

  Error effects whose ``error_type`` is in ``enabled_errors`` are merged into
  ``error``. The others are collected into ``recharged_error``, which is
  currently computed but not used (see the TODO below).

  Returns:
    ``(discharged_error, [])``.
  """
  del debug
  new_error = tree_unflatten(err_tree, args)
  # Split up new_error into error to be functionalized if it's included in
  # enabled_errors (=discharged_error) and an error to be defunctionalized if
  # it's not included (=recharged_error)
  discharged_error = error
  recharged_error = init_error
  for error_effect, pred in new_error._pred.items():
    code = new_error._code[error_effect]
    payload = new_error._payload[error_effect]
    if error_effect.error_type in enabled_errors:
      discharged_error = update_error(discharged_error, pred, code, {}, payload,
                                      error_effect)
    else:
      recharged_error = update_error(recharged_error, pred, code, {}, payload,
                                     error_effect)

  discharged_error = discharged_error._replace(
      _metadata={**new_error._metadata, **discharged_error._metadata})
  recharged_error = recharged_error._replace(_metadata=new_error._metadata)
  # TODO(lenamartens): we actually need to recharge, but this would be a
  # breaking API change so leaving for a follow-up.
  # check_error(recharged_error)
  return discharged_error, []
error_checks[check_p] = check_discharge_rule

## checkify public api

user_checks = frozenset({FailedCheckError})
nan_checks = frozenset({NaNError})
index_checks = frozenset({OOBError})
div_checks = frozenset({DivisionByZeroError})
float_checks = nan_checks | div_checks
automatic_checks = float_checks | index_checks
all_checks = automatic_checks | user_checks
def checkify(f: Callable[..., Out],
             errors: frozenset[ErrorCategory] = user_checks
             ) -> Callable[..., tuple[Error, Out]]:
  """Functionalize ``check`` calls in ``f``, optionally adding run-time checks.

  Run-time errors are either user-added :func:`~check` assertions, or
  automatically inserted checks such as NaN checks, depending on the
  ``errors`` argument.

  The returned function returns an ``Error`` value ``err`` together with the
  output of ``f``. ``err.get()`` returns ``None`` if no error occurred, and
  otherwise a string containing the message of the first error that occurred.
  ``err.throw()`` raises a ValueError with that message if an error occurred.

  By default only user-added :func:`~check` assertions are enabled. Automatic
  checks are enabled through the ``errors`` argument. The available check
  sets, and the condition under which each generates an error, are:

  - ``user_checks``: a :func:`~check` evaluated to False.
  - ``nan_checks``: a floating-point operation generated a NaN value as output.
  - ``div_checks``: a division by zero.
  - ``index_checks``: an index was out-of-bounds.

  Enable a category by passing its set (e.g. ``errors=nan_checks``). Sets can
  be combined with set operations (e.g. ``errors=float_checks|user_checks``).

  Args:
    f: Callable which can contain user checks (see :func:`~check`).
    errors: A set of ErrorCategory values which defines the set of enabled
      checks. By default only explicit ``checks`` are enabled
      (``user_checks``). You can also for example enable NAN and DIV errors by
      passing the ``float_checks`` set, or combine multiple sets through set
      operations (``float_checks | user_checks``).

  Returns:
    A function which accepts the same arguments as ``f`` and returns a pair:
    the first element is an ``Error`` value representing the first failed
    :func:`~check`, and the second element is the original output of ``f``.

  For example:

  >>> import jax
  >>> import jax.numpy as jnp
  >>> from jax.experimental import checkify
  >>>
  >>> @jax.jit
  ... def f(x):
  ...   y = jnp.sin(x)
  ...   return x+y
  >>> err, out = checkify.checkify(f, errors=checkify.float_checks)(jnp.inf)
  >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
  Traceback (most recent call last):
    ...
  jax._src.checkify.JaxRuntimeError: nan generated by primitive: sin
  """
  @traceback_util.api_boundary
  def checked_fun(*args, **kwargs):
    # Trace f as a nullary function that closes over all arguments, so none
    # of them is turned into an abstract value.
    empty_in_tree = jtu.tree_structure(((), {}))
    nullary_f = lambda: f(*args, **kwargs)

    # Stage f out to a closed jaxpr.
    flat_fun, out_tree_thunk = flatten_fun(lu.wrap_init(nullary_f),
                                           empty_in_tree)
    dbg = pe.debug_info(nullary_f, empty_in_tree, out_tree_thunk, False,
                        'checkify')
    open_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, (), dbg)
    closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(open_jaxpr))

    # Functionalize the checks in the staged jaxpr.
    err, flat_outs = checkify_jaxpr(closed_jaxpr, errors, init_error, *consts)
    return err, jtu.tree_unflatten(out_tree_thunk(), flat_outs)
  return checked_fun
def check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None:
  """Check a predicate; add an error with ``msg`` if the predicate is False.

  This is an effectful operation and cannot be staged (jitted/scanned/...).
  Apply :func:`~checkify` to a function containing checks before staging it.

  Args:
    pred: if False, a FailedCheckError error is added.
    msg: error message if error is added. Can be a format string.
    fmt_args, fmt_kwargs: Positional and keyword formatting arguments for
      `msg`, eg.:
      ``check(.., "check failed on values {} and {named_arg}", x, named_arg=y)``
      Note that these arguments can be traced values allowing you to add
      run-time values to the error message.
      Note that tracking these run-time arrays will increase your memory usage,
      even if no error happens.

  For example:

  >>> import jax
  >>> import jax.numpy as jnp
  >>> from jax.experimental import checkify
  >>> def f(x):
  ...   checkify.check(x>0, "{x} needs to be positive!", x=x)
  ...   return 1/x
  >>> checked_f = checkify.checkify(f)
  >>> err, out = jax.jit(checked_f)(-3.)
  >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
  Traceback (most recent call last):
    ...
  jax._src.checkify.JaxRuntimeError: -3. needs to be positive!
  """
  is_debug = False  # a regular check, as opposed to debug_check
  _check(pred, msg, is_debug, *fmt_args, **fmt_kwargs)
def _check(pred, msg, debug, *fmt_args, **fmt_kwargs):
  """Shared implementation of ``check`` and ``debug_check``.

  Raises:
    TypeError: if ``pred`` is not a scalar boolean, or if a formatting argument
      leaf is not a JAX/NumPy array.
  """
  if not is_scalar_pred(pred):
    prim_name = 'debug_check' if debug else 'check'
    raise TypeError(f'{prim_name} takes a scalar pred as argument, got {pred}')
  for arg in jtu.tree_leaves((fmt_args, fmt_kwargs)):
    if not isinstance(arg, (Array, np.ndarray)):
      raise TypeError('Formatting arguments to checkify.check need to be '
                      'PyTrees of arrays, but got '
                      f'{arg!r} of type {type(arg)}.')
  new_error = FailedCheckError(get_traceback(), msg, *fmt_args, **fmt_kwargs)
  # The error fires where the predicate is False.
  error = assert_func(init_error, jnp.logical_not(pred), new_error)
  _check_error(error, debug=debug)


def _check_error(error, *, debug=False):
  """Bind ``check_p`` on the flattened leaves of ``error``."""
  # Any non-scalar predicate is first reduced via _reduce_any_error.
  if any(map(np.shape, error._pred.values())):
    error = _reduce_any_error(error)
  err_args, tree_def = tree_flatten(error)

  return check_p.bind(*err_args, err_tree=tree_def, debug=debug)


def is_scalar_pred(pred) -> bool:
  """Return True if ``pred`` is a scalar boolean usable as a check predicate.

  Accepts Python ``bool`` values and 0-d boolean arrays: JAX ``Array``s
  (including tracers), NumPy arrays, and NumPy scalars such as ``np.bool_``
  (e.g. the result of ``np.all(...)``).
  """
  if isinstance(pred, bool):
    return True
  if isinstance(pred, (Array, np.ndarray, np.generic)):
    return pred.shape == () and pred.dtype == jnp.dtype('bool')
  return False


def debug_check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None:
  """Check a predicate when running under checkify, otherwise is a no-op.

  A `debug_check` will only be run if it is transformed by :func:`~checkify`,
  otherwise the check will be dropped.

  Args:
    pred: if False, a FailedCheckError error is added.
    msg: error message if error is added.
    fmt_args, fmt_kwargs: Positional and keyword formatting arguments for
      `msg`, eg.:
      ``debug_check(.., "check failed on values {} and {named}", x, named=y)``
      Note that these arguments can be traced values allowing you to add
      run-time values to the error message.
      Note that tracking these run-time arrays will increase your memory usage,
      even if no error happens.

  For example:

  >>> import jax
  >>> import jax.numpy as jnp
  >>> from jax.experimental import checkify
  >>> def f(x):
  ...   checkify.debug_check(x!=0, "cannot be zero!")
  ...   return x
  >>> _ = f(0)  # running without checkify means no debug_check is run.
  >>> checked_f = checkify.checkify(f)
  >>> err, out = jax.jit(checked_f)(0)  # running with checkify runs debug_check.
  >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
  Traceback (most recent call last):
    ...
  jax._src.checkify.JaxRuntimeError: cannot be zero!
  """
  _check(pred, msg, True, *fmt_args, **fmt_kwargs)
def check_error(error: Error) -> None:
  """Raise an Exception if ``error`` represents a failure. Functionalized by :func:`~checkify`.

  The semantics of this function are equivalent to:

  >>> def check_error(err: Error) -> None:
  ...   err.throw()  # can raise ValueError

  Unlike that implementation, however, ``check_error`` can be functionalized
  with the :func:`~checkify` transformation.

  This function is similar to :func:`~check`, but with a different signature.
  :func:`~check` takes a boolean predicate and a new error message string,
  whereas this function takes an ``Error`` value. Both raise a Python Exception
  on failure (a side-effect), so neither can be staged out by :func:`~jax.jit`,
  :func:`~jax.pmap`, :func:`~jax.lax.scan`, etc. Both can be functionalized
  with :func:`~checkify`.

  Unlike :func:`~check`, this function is a direct inverse of
  :func:`~checkify`. :func:`~checkify` turns a function that can raise a Python
  Exception into one that returns an ``Error`` value instead. ``check_error``
  turns an ``Error`` value back into that functionalizable Exception effect.

  ``check_error`` is useful for turning checks represented by an ``Error``
  value (produced by functionalizing ``checks`` via :func:`~checkify`) back
  into Python Exceptions.

  Args:
    error: Error to check.

  Raises:
    ValueError: if ``error`` is not an ``Error`` instance.

  For example, you might want to functionalize part of your program through
  checkify, stage out your functionalized code through :func:`~jax.jit`, then
  re-inject your error value outside of the :func:`~jax.jit`:

  >>> import jax
  >>> from jax.experimental import checkify
  >>> def f(x):
  ...   checkify.check(x>0, "must be positive!")
  ...   return x
  >>> def with_inner_jit(x):
  ...   checked_f = checkify.checkify(f)
  ...   # a checkified function can be jitted
  ...   error, out = jax.jit(checked_f)(x)
  ...   checkify.check_error(error)
  ...   return out
  >>> _ = with_inner_jit(1)  # no failed check
  >>> with_inner_jit(-1)  # doctest: +IGNORE_EXCEPTION_DETAIL
  Traceback (most recent call last):
    ...
  jax._src.checkify.JaxRuntimeError: must be positive!
  >>> # can re-checkify
  >>> error, _ = checkify.checkify(with_inner_jit)(-1)
  """
  if isinstance(error, Error):
    _check_error(error, debug=False)
  else:
    raise ValueError('check_error takes an Error as argument, '
                     f'got type {type(error)} instead.')