jax.profiler.stop_trace

jax.profiler.stop_trace()

Stops the currently-running profiler trace.

The trace is saved to the log_dir passed to the corresponding start_trace() call. Raises a RuntimeError if no trace is currently running, for example if start_trace() was never called or the trace has already been stopped.
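A minimal sketch of pairing start_trace() with stop_trace() follows. It assumes a temporary directory created with tempfile is an acceptable log_dir, and the matrix multiply is only placeholder work. The try/finally makes sure the trace is stopped and written even if the profiled code raises. The last part shows the RuntimeError from calling stop_trace() when no trace is running.

```python
import tempfile

import jax
import jax.numpy as jnp

# Temporary directory used as log_dir for this sketch; any writable path works.
log_dir = tempfile.mkdtemp()

jax.profiler.start_trace(log_dir)
try:
    # Placeholder workload to profile.
    x = jnp.ones((512, 512))
    # Block on the result so the asynchronously dispatched work completes inside the trace.
    y = (x @ x).block_until_ready()
finally:
    # Stop the trace and write it to log_dir, even if the workload raised.
    jax.profiler.stop_trace()

# No trace is running anymore, so a second call raises RuntimeError.
try:
    jax.profiler.stop_trace()
except RuntimeError as err:
    print("RuntimeError:", err)
```

JAX also provides the jax.profiler.trace context manager, which wraps this start/stop pairing in a with block.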