jax.debug.print

jax.debug.print(fmt, *args, ordered=False, **kwargs)

Prints values and works in staged-out JAX functions, such as those transformed by jax.jit.

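This function does not work with f-strings because formatting is delayed until the staged-out computation runs. An f-string would be evaluated immediately, at trace time, when the arguments are still abstract tracers. So instead of jax.debug.print(f"hello {bar}"), write jax.debug.print("hello {bar}", bar=bar).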
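Here is a minimal sketch of the difference, assuming a CPU backend. The function `f` is invented for illustration. The values are passed as keyword arguments and are formatted only when the compiled function actually runs.

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    y = 2 * x
    # Avoid: jax.debug.print(f"x={x}")
    # The f-string would be evaluated at trace time and embed the tracer's
    # repr instead of the runtime value.
    jax.debug.print("x={x}, y={y}", x=x, y=y)
    return y


f(jnp.arange(3))  # prints: x=[0 1 2], y=[0 2 4]
jax.effects_barrier()  # wait for pending debug callbacks (useful in short scripts)
```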

This function is a thin convenience wrapper around jax.debug.callback(). The implementation is essentially:

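import jax

def debug_print(fmt: str, *args, ordered: bool = False, **kwargs):
  # fmt is captured by the closure; formatting runs later, on concrete values.
  jax.debug.callback(
      lambda *args, **kwargs: print(fmt.format(*args, **kwargs)),
      *args, ordered=ordered, **kwargs)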
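jax.debug.print is built on jax.debug.callback(), so the same mechanism can deliver runtime values to any host-side Python function. The sketch below is illustrative only: `record`, `seen`, and `double` are hypothetical names. The non-array label is bound with functools.partial instead of being passed through jax.debug.callback(), because the arguments given to jax.debug.callback() are treated as array values.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical sink for this example: collects values observed at runtime.
seen = []


def record(label, x):
    # x is a concrete array here, not a tracer, so it can be converted freely.
    seen.append((label, np.asarray(x).tolist()))


@jax.jit
def double(x):
    # The string label is bound with functools.partial; only the array x is
    # passed through jax.debug.callback itself.
    jax.debug.callback(functools.partial(record, "input"), x)
    return 2 * x


out = double(jnp.arange(3))
jax.effects_barrier()  # wait for the callback before inspecting `seen`
print(seen)  # [('input', [0, 1, 2])]
```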

It may be useful to call jax.debug.callback() directly instead of this convenience wrapper. For example, to send debug output to logs rather than standard output, you might use jax.debug.callback() together with logging.log.
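One way to do this is sketched below. It assumes a standard-library logger named `jax_debug_example` and a `normalize` function, both hypothetical and chosen for illustration. The callback uses logging's own lazy `%s` formatting, which is safe because it runs on the host with concrete values.

```python
import logging

import jax
import jax.numpy as jnp

logging.basicConfig(level=logging.INFO)
# Hypothetical logger name, chosen only for this example.
logger = logging.getLogger("jax_debug_example")


def log_stats(mean, maximum):
    # Runs on the host after the compiled computation produces concrete values.
    logger.log(logging.INFO, "mean=%s max=%s", mean, maximum)


@jax.jit
def normalize(x):
    jax.debug.callback(log_stats, jnp.mean(x), jnp.max(x))
    return x / jnp.max(x)


result = normalize(jnp.arange(4.0))
jax.effects_barrier()  # make sure the log call has happened before moving on
```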

Parameters:
  • fmt (str) – A format string, e.g. "hello {x}", that will be used to format input arguments, like str.format. See the Python docs on string formatting and format string syntax.

  • *args – Positional arguments to be formatted, as if passed to fmt.format.

  • ordered (bool) – Keyword-only. Whether the staged-out computation enforces the order of this jax.debug.print call with respect to other ordered jax.debug.print calls.

  • **kwargs – Additional keyword arguments to be formatted, as if passed to fmt.format.
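The sketch below uses positional arguments and `ordered=True` together. The function `step` is hypothetical. With `ordered=True`, the two prints are expected to appear in program order. jax.effects_barrier() blocks until pending callbacks have run, which matters in short scripts and tests.

```python
import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # Positional arguments fill "{}" or "{0}"-style fields, as with str.format.
    jax.debug.print("first: {}", x, ordered=True)
    jax.debug.print("second: {0} (then {1})", x + 1, x + 2, ordered=True)
    return x + 2


step(jnp.int32(1))  # prints "first: 1", then "second: 2 (then 3)"
jax.effects_barrier()
```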

Return type:

None