jax.debug.breakpoint

jax.debug.breakpoint(*, backend=None, filter_frames=True, num_frames=None, ordered=False, **kwargs)

Enters a breakpoint at a point in a program.
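Inside jit-compiled code, a common pattern is to guard the breakpoint with jax.lax.cond so that execution pauses only when a condition holds. The sketch below is illustrative: the names `breakpoint_if_nonfinite` and `safe_divide` are hypothetical, not part of JAX. It pauses only if a division produces a NaN or infinity, so with finite results the breakpoint branch never runs.

```python
import jax
import jax.numpy as jnp
from jax import lax


def breakpoint_if_nonfinite(x):
    """Hypothetical helper: pause in the debugger only if `x` has a NaN or inf."""
    is_finite = jnp.isfinite(x).all()

    def true_fn(x):
        # All values are finite: do nothing.
        pass

    def false_fn(x):
        # Some value is NaN or inf: enter the interactive debugger.
        jax.debug.breakpoint()

    # lax.cond stages out both branches, but only the selected branch
    # executes at runtime, so the breakpoint fires only when is_finite is False.
    lax.cond(is_finite, true_fn, false_fn, x)


@jax.jit
def safe_divide(x, y):
    z = x / y
    breakpoint_if_nonfinite(z)
    return z
```

Here `safe_divide(6.0, 3.0)` returns 2.0 without pausing, while `safe_divide(1.0, 0.0)` produces `inf` and opens the default debugger (the CLI debugger unless another backend is registered).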

Parameters
  • backend (Optional[str]) – The debugger backend to use. By default, picks the highest-priority registered debugger and, in the absence of other registered debuggers, falls back to the CLI debugger.

  • filter_frames (bool) – Whether or not to filter out JAX-internal stack frames from the traceback. Since some libraries, like Flax, also make use of JAX’s stack frame filtering system, this option can also affect whether stack frames from those libraries are filtered.

  • num_frames (Optional[int]) – The number of frames above the current stack frame to make available for inspection in the interactive debugger.

  • ordered (bool) – A keyword-only argument indicating whether the staged-out computation enforces the ordering of this breakpoint with respect to other ordered debug calls (for example, jax.debug.print calls made with ordered=True).
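The keyword-only arguments above combine with the same conditional pattern. In this sketch, `pause_if` and `normalize` are hypothetical names; the call passes `num_frames=1` to keep the interactive session focused on the frame nearest the breakpoint, and spells out the default `filter_frames=True` for clarity. Exactly which frames appear may vary across JAX versions.

```python
import jax
import jax.numpy as jnp
from jax import lax


def pause_if(pred, x):
    """Hypothetical helper: enter the debugger only when `pred` is True."""

    def enter_debugger(x):
        # num_frames=1 limits inspection to the frame closest to this call;
        # filter_frames=True (the default) hides JAX-internal frames.
        jax.debug.breakpoint(num_frames=1, filter_frames=True)

    def skip(x):
        pass

    # lax.cond runs the first branch when the predicate is True.
    lax.cond(pred, enter_debugger, skip, x)


@jax.jit
def normalize(x):
    total = jnp.sum(x)
    # Pause if the sum is zero, since dividing by it would yield NaN/inf.
    pause_if(total == 0, x)
    return x / total
```

For example, `normalize(jnp.array([1.0, 3.0]))` returns `[0.25, 0.75]` without pausing; an all-zero input makes the predicate True and pauses execution inside `normalize`.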

Returns

None.