jax.profiler.TraceAnnotation
class jax.profiler.TraceAnnotation
Context manager that generates a trace event in the profiler.
The trace event spans the duration of the code enclosed by the context.
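Because JAX dispatches computations asynchronously, "the duration of the code enclosed by the context" can be shorter than the time the computation actually takes. The sketch below is illustrative only; the function name `train_step` and the labels are hypothetical. It contrasts an annotation that may cover little more than dispatch with one that waits for the result via `block_until_ready()`, which is why the example that follows calls it.

```python
import jax
import jax.numpy as jnp


def train_step(x):
    # This event may close before tanh has finished: with asynchronous
    # dispatch, the call can return before the computation completes.
    with jax.profiler.TraceAnnotation("dispatch_only"):
        y = jnp.tanh(x)
    # block_until_ready() keeps this event open until the result is computed,
    # so the event spans the actual work.
    with jax.profiler.TraceAnnotation("dispatch_and_wait"):
        z = jnp.dot(y, y.T).block_until_ready()
    return z


out = train_step(jnp.ones((256, 256)))
```

Outside an active trace the annotations simply have no visible effect; the code runs the same either way.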
For example:
>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.ones((1000, 1000))
>>> with jax.profiler.TraceAnnotation("my_label"):
...   result = jnp.dot(x, x.T).block_until_ready()
This will cause a “my_label” event to show up on the trace timeline if the event occurs while the process is being traced.
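The event is recorded only while a trace is active. One way to make the process "being traced" is the `jax.profiler.trace` context manager, which starts tracing on entry and writes the collected events to a log directory on exit. This is a minimal sketch under one assumption: a temporary directory from `tempfile.mkdtemp()` stands in for a real log directory.

```python
import os
import tempfile

import jax
import jax.numpy as jnp

# Assumption: a throwaway temporary directory is used as the log directory.
log_dir = tempfile.mkdtemp()

x = jnp.ones((1000, 1000))
# Tracing is active inside this block, so the "my_label" event is captured
# and written under log_dir when the block exits.
with jax.profiler.trace(log_dir):
    with jax.profiler.TraceAnnotation("my_label"):
        result = jnp.dot(x, x.T).block_until_ready()

# Collect whatever files the profiler wrote.
written = [f for _, _, files in os.walk(log_dir) for f in files]
```

The resulting directory can then be opened in a trace viewer such as TensorBoard's profile plugin, where "my_label" should appear on the timeline.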
Attributes
is_enabled
set_metadata