jax.profiler.annotate_function

jax.profiler.annotate_function(func, name=None, **decorator_kwargs)

Decorator that generates a trace event for the execution of a function.

For example:

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> @jax.profiler.annotate_function
... def f(x):
...   return jnp.dot(x, x.T).block_until_ready()
>>>
>>> result = f(jnp.ones((1000, 1000)))

This will cause an “f” event (the event name defaults to the function's name) to appear on the trace timeline whenever the function executes while the process is being traced, for example by TensorBoard.
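The event is only recorded while a trace is actually being collected. The sketch below shows one way to do that locally with the jax.profiler.trace() context manager instead of TensorBoard's remote capture; the temporary output directory is an arbitrary choice for illustration, not something the profiler requires.

```python
import os
import tempfile

import jax
import jax.numpy as jnp


@jax.profiler.annotate_function
def f(x):
  return jnp.dot(x, x.T).block_until_ready()


# Arbitrary, illustrative output directory for the trace files.
log_dir = tempfile.mkdtemp(prefix="jax-trace-")

# Only calls made while this context is active are captured; the call
# to f below should therefore appear as an "f" event in the trace.
with jax.profiler.trace(log_dir):
  result = f(jnp.ones((1000, 1000)))

# List whatever files the profiler wrote under log_dir.
trace_files = [
    os.path.join(root, name)
    for root, _, names in os.walk(log_dir)
    for name in names
]
```

The resulting directory can then be opened in TensorBoard's profile plugin, where the “f” event should be visible on the timeline.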

Arguments such as name can be passed to the decorator via functools.partial(). In the example below, the event is recorded as “event_name” instead of “f”.

>>> from functools import partial
>>> @partial(jax.profiler.annotate_function, name="event_name")
... def f(x):
...   return jnp.dot(x, x.T).block_until_ready()
>>> result = f(jnp.ones((1000, 1000)))
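For intuition, a decorator with a similar effect can be written by hand on top of jax.profiler.TraceAnnotation, which opens a named trace event for the duration of a with block. This is a sketch under assumptions, not JAX's actual implementation: annotate is a hypothetical helper, and forwarding extra keyword arguments to TraceAnnotation as event metadata is our reading of what **decorator_kwargs is for.

```python
import functools

import jax
import jax.numpy as jnp


def annotate(func, name=None, **metadata):
  """Hypothetical re-creation of a function-annotating decorator.

  Each call to `func` runs inside a jax.profiler.TraceAnnotation, so it
  shows up as a named event while a trace is being collected.
  """
  # Fall back to the function's qualified name when no name is given.
  event_name = name or func.__qualname__

  @functools.wraps(func)  # keep func's __name__, __doc__, etc.
  def wrapper(*args, **kwargs):
    # Assumption: extra keyword arguments become event metadata.
    with jax.profiler.TraceAnnotation(event_name, **metadata):
      return func(*args, **kwargs)

  return wrapper


# Same calling convention as above: arguments supplied via functools.partial.
@functools.partial(annotate, name="event_name")
def f(x):
  return jnp.dot(x, x.T).block_until_ready()


result = f(jnp.ones((4, 4)))
```

Because the wrapper takes func as its first positional argument, it cannot be used as @annotate(name=...) directly, which is why functools.partial is needed to bind name first.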

Parameters