jax.profiler.TraceAnnotation


class jax.profiler.TraceAnnotation

Context manager that generates a trace event in the profiler.

The trace event spans the duration of the code enclosed by the context.
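Because the event covers exactly the enclosed code, annotations can be nested to break a larger region into labelled sub-spans. The sketch below assumes a hypothetical jitted step function and hypothetical labels ("train_loop", "step_0", ...). It also shows why block_until_ready matters: JAX dispatches work asynchronously, so without it the span can close before the device computation has finished.

```python
import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # Hypothetical workload: one matmul followed by an elementwise op.
    return jnp.tanh(x @ x.T)


def run(n_steps: int = 3):
    x = jnp.ones((256, 256))
    # Outer span: covers every iteration of the loop.
    with jax.profiler.TraceAnnotation("train_loop"):
        for i in range(n_steps):
            # Inner span: one event per iteration, nested inside "train_loop".
            with jax.profiler.TraceAnnotation(f"step_{i}"):
                # Block so the device work finishes inside this span rather
                # than after it has already closed.
                x = step(x).block_until_ready()
    return x
```

On the timeline, each "step_i" event should appear underneath the single "train_loop" event.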

For example:

>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.ones((1000, 1000))
>>> with jax.profiler.TraceAnnotation("my_label"):
...   result = jnp.dot(x, x.T).block_until_ready()

This causes a “my_label” event to appear on the trace timeline if the annotated code runs while the process is being traced (for example, inside a jax.profiler.trace context); otherwise no event is recorded.
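To actually see the event, the annotated code has to run while a trace is being collected. A minimal sketch, assuming jax.profiler.trace is used to capture into a temporary directory (the function name capture_trace is hypothetical); the written trace can then be opened with a viewer such as TensorBoard's profile plugin.

```python
import os
import tempfile

import jax
import jax.numpy as jnp


def capture_trace(log_dir: str):
    x = jnp.ones((1000, 1000))
    # TraceAnnotation only records an event while tracing is active,
    # so wrap it in jax.profiler.trace, which writes the trace to log_dir.
    with jax.profiler.trace(log_dir):
        with jax.profiler.TraceAnnotation("my_label"):
            result = jnp.dot(x, x.T).block_until_ready()
    return result


with tempfile.TemporaryDirectory() as log_dir:
    result = capture_trace(log_dir)
    # List every file the profiler wrote under log_dir.
    written = [f for _, _, files in os.walk(log_dir) for f in files]
```

Running the same TraceAnnotation block outside jax.profiler.trace still executes the computation; it simply produces no event.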

__init__(self, arg0: str, /, **kwargs) → None
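The signature takes the label as a positional-only string (the "/" after arg0), so it cannot be passed by keyword. The sketch below assumes, as the **kwargs in the signature suggests, that extra keyword arguments are attached to the trace event as metadata; the helper name annotated_matmul and the keys rows and cols are hypothetical.

```python
import jax
import jax.numpy as jnp


def annotated_matmul(x):
    # The label goes positionally; the keyword arguments (assumed here to be
    # recorded as event metadata) describe the operand shape.
    with jax.profiler.TraceAnnotation("matmul", rows=x.shape[0], cols=x.shape[1]):
        return jnp.dot(x, x.T).block_until_ready()
```

Encoding shapes or batch sizes this way keeps the label itself stable, which makes events easier to group on the timeline.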

Attributes

is_enabled

set_metadata
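This page lists is_enabled and set_metadata only by name. The sketch below assumes that is_enabled is a static method reporting whether trace events are currently being recorded, and that set_metadata accepts keyword arguments like the constructor, adding metadata to an annotation that is already open. The helper name and the checksum key are hypothetical.

```python
import jax
import jax.numpy as jnp


def matmul_with_checksum(x):
    ann = jax.profiler.TraceAnnotation("matmul")
    with ann:
        y = jnp.dot(x, x.T).block_until_ready()
        # Assumed: is_enabled() is True only while a trace is being collected,
        # so the extra host-side reduction is skipped otherwise.
        if jax.profiler.TraceAnnotation.is_enabled():
            # Attach metadata that is only known after the work has run.
            ann.set_metadata(checksum=float(y.sum()))
    return y
```

Guarding with is_enabled keeps the cost of computing metadata off the normal, untraced path.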