jax.debug package
Debugging utilities
The guide "jax.debug.print and jax.debug.breakpoint" describes how to use JAX's debugging features.
jax.debug.callback | Calls a stageable Python callback.
jax.debug.print | Prints values and works in staged-out JAX functions.
jax.debug.breakpoint | Enters a breakpoint at a point in a program.
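As a rough illustration of how these utilities might be used inside a jitted function, here is a minimal sketch. The function `f`, the helper `record`, and the list `seen` are hypothetical names introduced only for this example; they are not part of the JAX API.

```python
import jax
import jax.numpy as jnp

# Hypothetical host-side log, used only for this illustration.
seen = []

def record(x):
    # Runs on the host and receives the concrete runtime value of x.
    seen.append(float(x))

@jax.jit
def f(x):
    y = jnp.sin(x)
    # Formats the message with the runtime value, even though f is traced by jit.
    jax.debug.print("y = {y}", y=y)
    # Stages a call to an arbitrary Python function with the runtime value.
    jax.debug.callback(record, y)
    return y * 2

out = f(jnp.float32(0.0))
# Wait for the computation and any pending debug side effects to finish.
out.block_until_ready()
jax.effects_barrier()
```

`jax.debug.breakpoint()` can be placed inside `f` in the same way, but it opens an interactive prompt, so it is left out of this sketch.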