jax.debug module

Runtime value debugging utilities

The guide “jax.debug.print and jax.debug.breakpoint” describes how to use JAX’s runtime value debugging features.

callback(callback, *args[, ordered])

Calls a stageable Python callback.
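A minimal sketch of jax.debug.callback inside jax.jit, assuming a recent JAX release where jax.effects_barrier is available. The names log_value, recorded and f are hypothetical. The callback runs on the host with the concrete value of y each time the compiled function runs. Its return value is discarded, so it should only perform side effects.

```python
import jax
import jax.numpy as jnp

# Hypothetical host-side log filled in by the callback.
recorded = []

def log_value(y):
    # Called on the host with a concrete (non-traced) array value.
    recorded.append(float(y))

@jax.jit
def f(x):
    y = x * 2.0
    # Stage a call to log_value; it runs every time the compiled f runs.
    jax.debug.callback(log_value, y)
    return y + 1.0

out = f(jnp.float32(3.0))
# Wait for pending side effects (including the callback) to complete.
jax.effects_barrier()
```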

print(fmt, *args[, ordered])

Prints values and works in staged-out JAX functions.
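A minimal sketch contrasting jax.debug.print with Python's built-in print inside jax.jit; the function name f is hypothetical. The format string uses str.format-style placeholders, which can be filled positionally or by keyword.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    y = x * 2.0
    # A plain print(y) here would only show a tracer at trace time;
    # jax.debug.print shows the runtime values on every call.
    jax.debug.print("x = {x}, y = {y}", x=x, y=y)
    return y

out = f(jnp.arange(3.0))
# Wait for the print callback to complete before continuing.
jax.effects_barrier()
```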

breakpoint(*[, backend, filter_frames, ...])

Enters a breakpoint at a point in a program.
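A sketch of a conditional breakpoint: placing jax.debug.breakpoint in one branch of jax.lax.cond means the debugger is entered only when the condition selects that branch at runtime. The names checked_log, _no_op and _stop are hypothetical. The call below uses finite input, so the breakpoint is never reached.

```python
import jax
import jax.numpy as jnp
from jax import lax

def _no_op(y):
    # All values finite: nothing to inspect.
    pass

def _stop(y):
    # Some value is inf or nan: pause here in JAX's debugger.
    jax.debug.breakpoint()

@jax.jit
def checked_log(x):
    y = jnp.log(x)
    # Select the branch at runtime based on the computed values.
    lax.cond(jnp.isfinite(y).all(), _no_op, _stop, y)
    return y

out = checked_log(jnp.array([1.0, 2.0]))
jax.effects_barrier()
```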

Sharding debugging utilities

Functions that enable inspecting and visualizing array shardings inside (and outside) staged functions.

inspect_array_sharding(value, *, callback)

Enables inspecting an array's sharding inside jit-compiled functions.
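A minimal sketch, assuming a computation that runs on a single device; the names seen and f are hypothetical. Inside jax.jit the callback fires when the function is lowered, so the ahead-of-time lowering API is enough to trigger it without executing anything. The callback receives a jax.sharding.Sharding describing how y is laid out.

```python
import jax
import jax.numpy as jnp

# Hypothetical list collecting the shardings reported by the callback.
seen = []

@jax.jit
def f(x):
    y = x + 1.0
    # Report the sharding chosen for the intermediate value y.
    jax.debug.inspect_array_sharding(y, callback=seen.append)
    return y

# Lowering is enough to trigger the callback; no execution is needed.
lowered = f.lower(jnp.arange(4.0))
```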

visualize_array_sharding(arr, **kwargs)

Visualizes an array's sharding.

visualize_sharding(shape, sharding, *[, ...])

Visualizes a Sharding using the rich library.