jax.debug module
Runtime value debugging utilities
The guide "Compiled prints and breakpoints" describes how to use JAX's runtime value debugging features.
jax.debug.callback | Calls a stageable Python callback.
jax.debug.print | Prints values and works in staged-out JAX functions.
jax.debug.breakpoint | Enters a breakpoint at a point in a program.
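As a rough illustration of the first two utilities, the sketch below calls `jax.debug.print` and `jax.debug.callback` inside a `jax.jit`-compiled function, where a plain Python `print` would only see tracers. The function `f`, the helper `record`, and the list `seen` are hypothetical names introduced for this example. `jax.debug.breakpoint` is left out because it opens an interactive prompt.

```python
import jax
import jax.numpy as jnp

seen = []  # hypothetical host-side log, filled by the callback below


def record(x):
    # Runs on the host and receives the concrete runtime value, not a tracer.
    seen.append(float(x))


@jax.jit
def f(x):
    y = jnp.sin(x) ** 2 + jnp.cos(x) ** 2
    # Formats and prints the runtime value of y, even though f is staged out.
    jax.debug.print("y = {y}", y=y)
    # Passes the runtime value of y to an arbitrary Python function.
    jax.debug.callback(record, y)
    return y * 2.0


out = f(jnp.float32(0.5))
out.block_until_ready()
# Wait for pending side effects so the callback has run before reading `seen`.
jax.effects_barrier()
```

Both calls are side effects of the compiled computation, so their timing relative to the returned value is not guaranteed unless the program waits for them. That is why the sketch calls `jax.effects_barrier()` before reading `seen`.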