class jax.profiler.TraceAnnotation

Context manager that generates a trace event in the profiler.

The trace event spans the duration of the code enclosed by the context.

For example:

>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.ones((1000, 1000))
>>> with jax.profiler.TraceAnnotation("my_label"):
...   result = jnp.dot(x, x.T).block_until_ready()

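This causes a “my_label” event to appear on the trace timeline if the annotated code runs while the process is being traced (for example, inside a `jax.profiler.trace` context, or between calls to `jax.profiler.start_trace` and `jax.profiler.stop_trace`).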
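The annotation only produces an event when a trace is being recorded. A minimal sketch that captures one, assuming a temporary directory is an acceptable place to write the trace (the directory is illustrative; any writable path works), wraps the annotated code in `jax.profiler.trace`:

```python
import tempfile

import jax
import jax.numpy as jnp

# Illustrative output location for the captured trace.
log_dir = tempfile.mkdtemp()

x = jnp.ones((1000, 1000))

# jax.profiler.trace starts the profiler on entry and writes the trace on exit,
# so the "my_label" event below is recorded on the timeline.
with jax.profiler.trace(log_dir):
    with jax.profiler.TraceAnnotation("my_label"):
        # block_until_ready() keeps the device work inside the annotated span;
        # without it, JAX's asynchronous dispatch could let the event end early.
        result = jnp.dot(x, x.T).block_until_ready()
```

The resulting trace in `log_dir` can then be opened in a trace viewer such as TensorBoard's profile plugin or Perfetto.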

__init__(self: jaxlib.xla_extension.profiler.TraceMe, arg0: str, **kwargs) -> None


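The constructor takes the event name as `arg0` plus arbitrary keyword arguments. A sketch under the assumption that these keyword arguments are attached to the event as metadata (the key `iteration` and the function `train_step` are hypothetical, chosen only for illustration):

```python
import jax
import jax.numpy as jnp


def train_step(step, x):
    # `iteration` is a hypothetical metadata key passed through **kwargs;
    # it is assumed to be recorded alongside the "train_step" event.
    with jax.profiler.TraceAnnotation("train_step", iteration=step):
        return jnp.tanh(x).sum().block_until_ready()


x = jnp.zeros((4,))
totals = [train_step(i, x) for i in range(3)]
```

Giving every call the same name while varying the metadata keeps related events grouped under one label on the timeline while still letting individual calls be told apart.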


set_metadata(self, **kwargs)
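`set_metadata` accepts keyword arguments after construction. A minimal sketch, assuming it adds metadata to the already-open event (useful for values only known once the annotated work has run) and that it does nothing when no trace is active. The sketch binds the annotation to a name so the method can be called inside the block; the label "reduce" and key `total` are illustrative:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(8.0)

# Bind the annotation to a name so set_metadata can be called on it below.
annotation = jax.profiler.TraceAnnotation("reduce")
with annotation:
    total = jnp.sum(x).block_until_ready()
    # Assumed behavior: attaches `total` to the open "reduce" event.
    annotation.set_metadata(total=float(total))
```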