jax.profiler module

Tracing and time profiling

The “Profiling JAX programs” guide describes how to use JAX’s tracing and time profiling features.


start_server(port)

Starts the profiler server on the given port.

start_trace(log_dir[, create_perfetto_link, ...])

Starts a profiler trace.


stop_trace()

Stops the currently running profiler trace.
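As an illustration, starting and stopping a trace by hand might look like the sketch below. The temporary log directory and the toy computation are assumptions made for the example and are not part of the API. The `try`/`finally` is one way to make sure the trace is stopped even if the profiled code raises.

```python
import os
import tempfile

import jax
import jax.numpy as jnp

# Hypothetical output location; any writable directory should work.
log_dir = tempfile.mkdtemp(prefix="jax-trace-")

jax.profiler.start_trace(log_dir)
try:
    x = jnp.arange(1024.0)
    y = jnp.dot(x, x)
    # JAX dispatches work asynchronously, so wait for the result before
    # stopping; otherwise the computation may be missing from the trace.
    y.block_until_ready()
finally:
    # Stop the trace even if the profiled code raised an exception.
    jax.profiler.stop_trace()
```

The captured trace is written under `log_dir`, where a trace viewer such as TensorBoard's profiler plugin can read it.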

trace(log_dir[, create_perfetto_link, ...])

Context manager to take a profiler trace.

annotate_function(func[, name])

Decorator that generates a trace event for the execution of a function.


TraceAnnotation(name, **kwargs)

Context manager that generates a trace event in the profiler.

StepTraceAnnotation(name, **kwargs)

Context manager that generates a step trace event in the profiler.
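The sketch below combines the `trace` context manager with the three annotation helpers. The function `normalize`, the event names, and the temporary log directory are hypothetical and chosen only for illustration. The `step_num` keyword follows the usage shown in the JAX profiling guide.

```python
import tempfile

import jax
import jax.numpy as jnp


@jax.profiler.annotate_function
def normalize(x):
    # Each call is recorded as a trace event named after the function.
    return x / jnp.linalg.norm(x)


# Hypothetical output location; any writable directory should work.
log_dir = tempfile.mkdtemp(prefix="jax-trace-")

with jax.profiler.trace(log_dir):
    x = jnp.ones((256,))
    for step in range(3):
        # Marks one step of a loop so viewers can group events per step.
        with jax.profiler.StepTraceAnnotation("train_step", step_num=step):
            # Marks an arbitrary region of host code inside the step.
            with jax.profiler.TraceAnnotation("normalize_block"):
                x = normalize(x)
    # Wait for asynchronous work to finish before the trace ends.
    x.block_until_ready()
```

Compared with calling `start_trace` and `stop_trace` directly, the context manager ends the trace automatically when the block exits.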

Device memory profiling

See the “Device Memory Profiling” guide for an introduction to JAX’s device memory profiling features.


device_memory_profile([backend])

Captures a JAX device memory profile as a pprof-format protocol buffer.

save_device_memory_profile(filename[, backend])

Collects a device memory profile and writes it to a file.
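A minimal sketch of both memory profiling calls follows. The array allocation and the file name `memory.prof` are assumptions made for the example. The profile reflects whatever buffers are live on the default backend at the moment of the call.

```python
import os
import tempfile

import jax
import jax.numpy as jnp

# Allocate something so the profile has a live buffer to report.
x = jnp.ones((1024, 1024))
x.block_until_ready()

# In-memory form: the serialized pprof profile, returned as bytes.
profile_bytes = jax.profiler.device_memory_profile()

# On-disk form: the same kind of profile written to a file.
# The file name is hypothetical; any writable path should work.
profile_path = os.path.join(tempfile.mkdtemp(), "memory.prof")
jax.profiler.save_device_memory_profile(profile_path)
```

The saved file can then be inspected with the pprof tool.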