jax.debug.print
- jax.debug.print(fmt, *args, ordered=False, **kwargs)
Prints values and works in staged-out JAX functions (for example, functions transformed with jax.jit).
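Below is a minimal sketch of using it inside jax.jit, assuming JAX is installed and running on its default backend. The function name scale_and_sum is hypothetical. Inside jit the arguments are tracers. Python's built-in print would therefore only show abstract values, and only once, while the function is being traced. jax.debug.print instead emits the runtime values every time the compiled function runs. The sketch uses a positional placeholder, a named (keyword) placeholder, and ordered=True, which are all parameters described below.

```python
import jax
import jax.numpy as jnp


@jax.jit
def scale_and_sum(x, scale):
    y = x * scale
    # Positional argument fills the "{}" placeholder with y's runtime value.
    jax.debug.print("y = {}", y, ordered=True)
    total = jnp.sum(y)
    # Keyword argument fills the named "{total}" placeholder.
    jax.debug.print("total = {total}", total=total, ordered=True)
    return total


result = scale_and_sum(jnp.arange(3.0), 2.0)
```

Because both calls pass ordered=True, the staged-out computation keeps them in program order relative to each other.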
Note: This function does not work with f-strings because the formatting is done lazily.
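The following sketch illustrates why lazy formatting matters. It assumes JAX is installed, and the names g and trace_time_strings are hypothetical. An f-string is evaluated eagerly while JAX traces the function, so its placeholder is filled with the tracer's abstract description instead of a number. If that f-string were passed to jax.debug.print, it would print that description. A plain format string, by contrast, is filled in only when the concrete value exists at runtime.

```python
import jax
import jax.numpy as jnp

trace_time_strings = []  # filled by ordinary Python code, which runs only at trace time


@jax.jit
def g(x):
    # Eager: {x} is replaced by the tracer's text during tracing.
    trace_time_strings.append(f"x = {x}")
    # Lazy: the "{x}" placeholder is filled with the runtime value on every call.
    jax.debug.print("x = {x}", x=x)
    return x + 1


a = g(jnp.float32(3.0))
b = g(jnp.float32(5.0))  # same shape/dtype, so the cached trace is reused
```

After both calls, trace_time_strings holds a single entry, because tracing happened only once. jax.debug.print, however, printed on each call.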
- Parameters
fmt (str) – A format string, e.g. "hello {x}", that will be used to format input arguments.
*args – A list of positional arguments to be formatted.
ordered (bool) – A keyword-only argument used to indicate whether or not the staged-out computation will enforce ordering of this jax.debug.print w.r.t. other ordered jax.debug.print calls.
**kwargs – Additional keyword arguments to be formatted.
- Return type