jax.profiler.trace
jax.profiler.trace(log_dir)
Context manager to take a profiler trace.
The trace will capture CPU, GPU, and/or TPU activity, including Python functions and JAX on-device operations.
The resulting trace can be viewed with TensorBoard. Note that TensorBoard doesn’t need to be running when collecting the trace.
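A minimal usage sketch, assuming JAX is installed. The temporary directory and the small matrix-multiply workload are illustrative choices, not part of the API. `block_until_ready()` is called so that the asynchronously dispatched computation finishes while the trace is still active.

```python
import os
import tempfile

import jax
import jax.numpy as jnp

# Illustrative output location; in practice this is usually your TensorBoard log directory.
log_dir = tempfile.mkdtemp(prefix="jax-trace-")

with jax.profiler.trace(log_dir):
    x = jnp.ones((256, 256))
    # Wait for the on-device work to finish inside the traced region.
    y = (x @ x).block_until_ready()

# Collect every file the profiler wrote under log_dir.
trace_files = [
    os.path.join(root, name)
    for root, _, names in os.walk(log_dir)
    for name in names
]
print(trace_files)
```

To inspect the result, point TensorBoard at the same directory, for example `tensorboard --logdir <log_dir>`; this assumes the TensorBoard profile plugin (`tensorboard-plugin-profile`) is installed.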
Only one trace may be collected at a time. A RuntimeError will be raised if a trace is started while another trace is running.
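A short sketch of the one-trace-at-a-time rule, assuming JAX is installed; the variable names are illustrative. The nested `jax.profiler.trace` fails with a RuntimeError, while the outer trace keeps running and stops normally. After it has stopped, a new trace can be started.

```python
import tempfile

import jax

log_dir = tempfile.mkdtemp(prefix="jax-trace-")

nested_error = None
with jax.profiler.trace(log_dir):
    try:
        # Starting a second trace while the first is still running is an error.
        with jax.profiler.trace(log_dir):
            pass
    except RuntimeError as err:
        nested_error = err

# The first trace has stopped, so a new one can be started sequentially.
with jax.profiler.trace(log_dir):
    pass

print(type(nested_error).__name__)
```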
Parameters
log_dir – The directory to save the profiler trace to (usually the TensorBoard log directory).