jax.debug.breakpoint

jax.debug.breakpoint(*, backend=None, filter_frames=True, num_frames=None, ordered=False, token=None, **kwargs)

Enters a breakpoint at a point in a program.
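A minimal sketch of typical use, assuming a recent JAX release. The function name scale is hypothetical. Inside a jax.jit-compiled function the breakpoint is staged out during tracing, so tracing alone (for example with jax.make_jaxpr) does not pause. Execution stops only when the compiled function actually runs, and the debugger then sees the concrete values of the local arrays.

```python
import jax
import jax.numpy as jnp


@jax.jit
def scale(x, factor):  # hypothetical example function
    y = x * factor
    # Staged out while tracing; pauses only when the compiled code runs.
    jax.debug.breakpoint()
    return jnp.sum(y)


# Tracing records the breakpoint as a callback in the jaxpr without stopping.
jaxpr = jax.make_jaxpr(scale)(jnp.arange(3.0), 2.0)
print(jaxpr)

# Uncomment to open the interactive debugger (in the CLI debugger, `c` continues):
# scale(jnp.arange(3.0), 2.0)
```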

Parameters:
  • backend (str | None) – The debugger backend to use. By default, picks the highest-priority registered debugger and, in the absence of other registered debuggers, falls back to the CLI debugger.

  • filter_frames (bool) – Whether to filter out JAX-internal stack frames from the traceback. Because some libraries, such as Flax, also make use of JAX’s stack frame filtering system, this option can also affect whether stack frames from those libraries are filtered.

  • num_frames (int | None) – The number of frames above the current stack frame to make available for inspection in the interactive debugger.

  • ordered (bool) – A keyword-only argument that indicates whether the staged-out computation enforces the ordering of this jax.debug.breakpoint with respect to other ordered jax.debug.breakpoint and jax.debug.print calls.

  • token – A keyword-only argument and an alternative to ordered. If used, a JAX array (or pytree of JAX arrays) should be passed, and the breakpoint runs once its value has been computed. The value is returned unchanged and should be passed back into the computation. If the return value is unused later in the computation, the whole computation is pruned and this breakpoint does not run.
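The sketch below, assuming a recent JAX release, combines two common patterns around the ordered parameter. It guards the breakpoint with jax.lax.cond so the debugger opens only when a result contains NaN or inf values, and it passes ordered=True so the breakpoint is sequenced between the ordered jax.debug.print calls around it. The names checked_log, _pause, and _no_op are hypothetical.

```python
import jax
import jax.numpy as jnp


def _no_op(y):  # hypothetical helper
    # Taken when every value is finite: nothing to inspect.
    return None


def _pause(y):  # hypothetical helper
    # Ordered relative to the ordered jax.debug.print calls in checked_log.
    jax.debug.breakpoint(ordered=True)


@jax.jit
def checked_log(x):  # hypothetical example function
    jax.debug.print("input: {}", x, ordered=True)
    y = jnp.log(x)
    # Only the selected branch runs, so the debugger opens only on NaN/inf.
    jax.lax.cond(jnp.isfinite(y).all(), _no_op, _pause, y)
    jax.debug.print("output: {}", y, ordered=True)
    return y


print(checked_log(jnp.array([1.0, 2.0, 4.0])))
```

With all-positive inputs the predicate is true, so the breakpoint branch never executes and the call returns normally. A zero or negative entry would make jnp.log produce -inf or NaN and open the debugger.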

Returns:

If token is passed, then its value is returned unchanged. Otherwise, returns None.
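A minimal sketch of the token form, assuming a JAX release that supports the token argument. The names safe_normalize, _pause_on, and _pass_through are hypothetical. The breakpoint returns its token unchanged, and that returned value is fed back into the rest of the computation; if it were discarded, the breakpoint would be pruned and would never run.

```python
import jax
import jax.numpy as jnp


def _pause_on(total):  # hypothetical helper
    # Runs once `total` is computed and returns it unchanged.
    return jax.debug.breakpoint(token=total)


def _pass_through(total):  # hypothetical helper
    return total


@jax.jit
def safe_normalize(x):  # hypothetical example function
    total = jnp.sum(x)
    # Use the returned token downstream so the breakpoint is not pruned.
    total = jax.lax.cond(total == 0, _pause_on, _pass_through, total)
    return x / total


print(safe_normalize(jnp.array([1.0, 3.0])))  # total is 4.0, so no pause
```

Both branches of the cond return a value with the same shape and dtype as total, which is what lets the token be threaded through lax.cond.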