jax.debug package

Debugging utilities

The documentation page “jax.debug.print and jax.debug.breakpoint” describes how to use JAX’s debugging features.

callback(callback, *args[, ordered])

Calls a stageable Python callback.
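A minimal sketch of `jax.debug.callback` inside a jitted function. The callback runs on the host with concrete values, so it can record data that a plain Python side effect would only see as a tracer. The names `record` and `recorded` are hypothetical and exist only for this illustration. The sketch calls `jax.effects_barrier()` on the assumption that callbacks may still be pending when the jitted call returns.

```python
import jax
import jax.numpy as jnp

# Hypothetical host-side log, used only for this illustration.
recorded = []

def record(x):
    # Runs on the host with a concrete value, even though the caller is jitted.
    # The callback should return None.
    recorded.append(float(x))

@jax.jit
def f(x):
    y = jnp.sin(x) * 2.0
    jax.debug.callback(record, y)  # staged into the compiled computation
    return y + 1.0

out = f(1.0)
jax.effects_barrier()  # wait for any outstanding debug callbacks to finish
```

Because the callback is part of the staged-out program, it fires on every execution of `f`, not only during tracing.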

print(fmt, *args[, ordered])

Prints values and works in staged-out JAX functions.
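A short sketch contrasting `jax.debug.print` with Python’s built-in `print`. Inside `jax.jit`, a built-in `print(y)` runs once at trace time and shows a tracer. `jax.debug.print` instead formats the runtime value each time the compiled function executes, using `str.format`-style placeholders. The function name `g` is illustrative only.

```python
import jax
import jax.numpy as jnp

@jax.jit
def g(x):
    y = x * 2
    # Built-in print(y) here would show a tracer, and only while tracing.
    # jax.debug.print shows the concrete value on every call.
    jax.debug.print("y = {y}", y=y)
    return y.sum()

total = g(jnp.arange(3.0))
jax.effects_barrier()  # make sure the print has been flushed before continuing
```

By default, the relative order of separate `jax.debug.print` calls is not guaranteed. Passing `ordered=True` asks JAX to keep ordered prints in program order with respect to one another.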

breakpoint(*[, backend, filter_frames, ...])

Enters a breakpoint at a point in a program.
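An unconditional `jax.debug.breakpoint()` pauses on every execution, which is often too much. The sketch below places it in one branch of `lax.cond` so that the debugger opens only when a value becomes non-finite. This mirrors a pattern from the JAX debugging guide. The helper names `breakpoint_if_nonfinite` and `safe_log` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax

def breakpoint_if_nonfinite(x):
    # Enter the debugger only when x contains NaN or inf.
    is_finite = jnp.isfinite(x).all()

    def true_fn(x):
        pass  # finite values: nothing to do

    def false_fn(x):
        jax.debug.breakpoint()  # staged out; pauses at runtime, not at trace time

    lax.cond(is_finite, true_fn, false_fn, x)

@jax.jit
def safe_log(x):
    y = jnp.log(x)
    breakpoint_if_nonfinite(y)
    return y

# Finite inputs: the breakpoint branch is never taken.
result = safe_log(jnp.array([1.0, 2.0]))
```

Calling `safe_log(jnp.array([1.0, 0.0]))` would produce `-inf` and drop into an interactive prompt. There you can inspect values with pdb-style commands such as `p`, and resume with `c`. Because the prompt waits for input, keep such calls out of non-interactive runs.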