jax.profiler.start_trace

jax.profiler.start_trace(log_dir)

Starts a profiler trace.

The trace will capture CPU, GPU, and/or TPU activity, including Python functions and JAX on-device operations. Use stop_trace() to end the trace and save the results to log_dir.
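Below is a minimal sketch of a typical capture. It assumes `/tmp/jax-trace` as a hypothetical log directory and uses a matrix multiply as a stand-in workload. It calls `block_until_ready()` because JAX dispatches work asynchronously, and without that call the computation might not finish before `stop_trace()` runs. The `try`/`finally` stops the trace even if the workload raises. Otherwise the trace would be left running and would block later `start_trace()` calls.

```python
import os

import jax
import jax.numpy as jnp

# Hypothetical output directory; any writable path should work.
log_dir = "/tmp/jax-trace"

jax.profiler.start_trace(log_dir)
try:
    # Stand-in workload: a 256x256 matrix product of ones.
    x = jnp.ones((256, 256))
    # Wait for the asynchronously dispatched result so it lands inside the trace.
    y = (x @ x).block_until_ready()
finally:
    # Always end the trace so a later start_trace() is not rejected.
    jax.profiler.stop_trace()

print("trace directory exists:", os.path.isdir(log_dir))
```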

The resulting trace can be viewed with TensorBoard. Note that TensorBoard doesn’t need to be running when collecting the trace.

Only one trace may be collected at a time. A RuntimeError will be raised if start_trace() is called while another trace is running.
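The single-trace restriction can be observed directly. This sketch uses hypothetical `/tmp` paths. It starts one trace, attempts a second, and records whether `RuntimeError` was raised. The outer `finally` stops the first trace either way.

```python
import jax

second_start_failed = False

jax.profiler.start_trace("/tmp/jax-trace-a")  # hypothetical path
try:
    try:
        # A second trace while the first is active should be rejected.
        jax.profiler.start_trace("/tmp/jax-trace-b")  # hypothetical path
    except RuntimeError as err:
        second_start_failed = True
        print("second start_trace() rejected:", err)
finally:
    # Stop the trace that did start, leaving the profiler free for later use.
    jax.profiler.stop_trace()

print("RuntimeError raised:", second_start_failed)
```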

Parameters

log_dir – The directory to save the profiler trace to (usually the TensorBoard log directory).