jax.profiler.annotate_function

jax.profiler.annotate_function(func, name=None, **kwargs)

Decorator that generates a trace event for the execution of a function.

For example:

>>> import jax
>>> import jax.numpy as jnp
>>> @jax.profiler.annotate_function
... def f(x):
...   return jnp.dot(x, x.T).block_until_ready()
>>>
>>> result = f(jnp.ones((1000, 1000)))

This will cause an “f” event to show up on the trace timeline if the function executes while the process is being traced by TensorBoard.
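The timeline event only exists while a trace is being captured. The following is a minimal sketch of capturing one programmatically with jax.profiler.trace, which writes the profile into a log directory that TensorBoard can read. The temporary directory is an arbitrary choice for illustration, and viewing the result assumes TensorBoard with its profile plugin is installed.

```python
import os
import tempfile

import jax
import jax.numpy as jnp


@jax.profiler.annotate_function
def f(x):
    # block_until_ready() keeps the device work inside the annotated span,
    # because JAX dispatches computations asynchronously.
    return jnp.dot(x, x.T).block_until_ready()


# Arbitrary output directory chosen for this sketch; any writable path works.
log_dir = tempfile.mkdtemp(prefix="jax-trace-")

# Calls made inside this block are recorded, so the "f" event is captured.
with jax.profiler.trace(log_dir):
    result = f(jnp.ones((1000, 1000)))

# The profiler exports its trace files somewhere below log_dir.
trace_files = [
    os.path.join(root, name)
    for root, _, names in os.walk(log_dir)
    for name in names
]
print(result.shape, len(trace_files) > 0)
```

Running `tensorboard --logdir` on that directory and opening the Profile tab should then show an “f” event on the timeline.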

Arguments such as name can be passed to the decorator via functools.partial():

>>> from functools import partial
>>> @partial(jax.profiler.annotate_function, name="event_name")
... def f(x):
...   return jnp.dot(x, x.T).block_until_ready()
>>> result = f(jnp.ones((1000, 1000)))
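Because func is the first parameter in the signature above, annotate_function can also be called directly instead of through decorator syntax, which is handy for annotating a function you did not define. The sketch below does that, and for comparison spells out a hypothetical hand-written decorator, annotated_as, built on jax.profiler.TraceAnnotation (the context-manager form of a trace event). It illustrates the idea and is not claimed to be a copy of JAX's implementation.

```python
import functools

import jax
import jax.numpy as jnp


def gram(x):
    return jnp.dot(x, x.T).block_until_ready()


# Direct call: same effect as @partial(jax.profiler.annotate_function, name="gram_event").
annotated_gram = jax.profiler.annotate_function(gram, name="gram_event")


def annotated_as(name=None):
    """Hypothetical decorator factory: opens one trace event per call."""
    def decorator(func):
        # Default the event name to the function's qualified name.
        event_name = name or func.__qualname__

        @functools.wraps(func)  # preserve __name__, __doc__, etc.
        def wrapper(*args, **kwargs):
            # The event covers exactly the duration of this call.
            with jax.profiler.TraceAnnotation(event_name):
                return func(*args, **kwargs)

        return wrapper

    return decorator


@annotated_as(name="event_name")
def f(x):
    return jnp.dot(x, x.T).block_until_ready()


x = jnp.ones((4, 4))
print(annotated_gram(x)[0, 0], f(x)[0, 0], f.__name__)
```

Either form only shows up on the timeline when the call happens while a trace is being captured.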
Parameters