jax.profiler.StepTraceContext

class jax.profiler.StepTraceContext(*args, **kwargs)[source]
__init__(self: jaxlib.xla_extension.profiler.TraceMe, arg0: str, **kwargs) -> None[source]

Methods

__init__(self, arg0, **kwargs)

is_enabled()

set_metadata(self, **kwargs)