jax.debug module
Runtime value debugging utilities
The guide "jax.debug.print and jax.debug.breakpoint" describes how to use JAX's runtime value debugging features.
jax.debug.callback | Calls a stageable Python callback. |
jax.debug.print | Prints values and works in staged-out JAX functions. |
jax.debug.breakpoint | Enters a breakpoint at a point in a program. |
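
As a rough illustration of how these utilities might be used inside a jitted function, the sketch below prints a runtime value and forwards it to a host-side Python callback. The function `f`, the helper `record`, and the list `seen` are hypothetical names introduced only for this example, and it assumes a JAX version that provides `jax.effects_barrier`.

```python
import jax
import jax.numpy as jnp

# Hypothetical host-side log, used only for this illustration.
seen = []

def record(value):
    # Runs on the host and receives the concrete runtime value.
    seen.append(float(value))

@jax.jit
def f(x):
    y = jnp.sin(x) * 2.0
    # A plain Python print here would show a tracer, not the runtime value.
    jax.debug.print("y = {y}", y=y)
    # Stage a call to an arbitrary Python function.
    jax.debug.callback(record, y)
    # jax.debug.breakpoint() could be placed here to pause interactively.
    return y + 1.0

out = f(jnp.float32(0.0))
out.block_until_ready()
# Wait for any pending debug side effects before inspecting `seen`.
jax.effects_barrier()
```

Because debug callbacks can run asynchronously with respect to the computation, the sketch waits on `jax.effects_barrier()` before reading anything the callback wrote.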