jax.debug.print

jax.debug.print(fmt, *args, ordered=False, **kwargs)

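Prints values, and works inside staged-out (traced) JAX functions, such as functions transformed by jax.jit, where it prints the runtime values rather than abstract tracers.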
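A minimal sketch of the basic use case, assuming a recent JAX release. The function name `scale` is hypothetical. Inside a `jax.jit`-compiled function, a regular `print` runs only once at trace time and shows a tracer, whereas `jax.debug.print` runs each time the compiled function executes and shows the concrete value. `jax.effects_barrier()` is called so that pending debug output is flushed before the script continues.

```python
import jax
import jax.numpy as jnp

@jax.jit
def scale(x):  # hypothetical example function
    y = x * 2.0
    # A plain print(y) here would show an abstract tracer at trace time;
    # jax.debug.print prints the concrete value when the function runs.
    jax.debug.print("y = {y}", y=y)
    return y

result = scale(jnp.arange(3.0))
jax.effects_barrier()  # wait for pending debug callbacks to finish printing
```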

Parameters
  • fmt (str) – A format string, e.g. "hello {x}", used to format the input arguments.

  • *args – Positional arguments to be formatted.

  • ordered (bool) – A keyword-only argument (default False) indicating whether the staged-out computation enforces the ordering of this debug_print call with respect to other ordered debug_print calls.

  • **kwargs – Additional keyword arguments to be formatted.
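The sketch below exercises each parameter, assuming a recent JAX release; the function name `step` is hypothetical. Positional arguments fill `{}` placeholders in order, and keyword arguments fill named placeholders such as `{total}`. Because formatting is deferred until the values are available at runtime, the arguments are passed separately rather than interpolated with an f-string. Both calls use `ordered=True`, so the second print is sequenced after the first.

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(x):  # hypothetical example function
    # Positional args fill the "{}" placeholders in order.
    jax.debug.print("x[0]={} x[1]={}", x[0], x[1], ordered=True)
    total = jnp.sum(x)
    # Keyword args fill named placeholders; ordered=True sequences this
    # print after the previous ordered debug print.
    jax.debug.print("total={total}", total=total, ordered=True)
    return total

total = step(jnp.array([1.0, 2.0, 3.0]))
jax.effects_barrier()  # flush pending debug output
```

With the default `ordered=False`, JAX does not guarantee the relative order in which separate debug prints appear.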

Return type

None