jax.profiler module

Tracing and time profiling

The guide "Profiling JAX programs" describes how to use JAX’s tracing and time profiling features.


start_server(port)

Starts a profiler server on the given port.
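
As a minimal sketch of starting the profiler server, the snippet below asks the operating system for an unused port and passes it to jax.profiler.start_server. The find_free_port helper is hypothetical and only for illustration; in practice a fixed, known port is usually chosen so that an external profiling client can connect to it.

```python
import socket

import jax


def find_free_port() -> int:
    """Ask the OS for an unused TCP port (illustrative helper, not part of JAX)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


port = find_free_port()

# Keep a reference to the returned server object so it stays alive
# for as long as the program should be profilable.
server = jax.profiler.start_server(port)
print(f"profiler server listening on port {port}")
```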


start_trace(log_dir[, create_perfetto_link, ...])

Starts a profiler trace.


stop_trace()

Stops the currently running profiler trace.
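
A manual trace might be taken as sketched below, pairing jax.profiler.start_trace with jax.profiler.stop_trace. The temporary log directory and the toy matrix product are assumptions made for illustration. The try/finally guard makes sure the trace is stopped even if the profiled code raises.

```python
import os
import tempfile

import jax
import jax.numpy as jnp

# Hypothetical location for the trace output.
log_dir = tempfile.mkdtemp(prefix="jax-trace-")

jax.profiler.start_trace(log_dir)
try:
    x = jnp.ones((256, 256))
    y = jnp.dot(x, x)
    # JAX dispatches work asynchronously; block so the computation
    # finishes while the trace is still running.
    y.block_until_ready()
finally:
    jax.profiler.stop_trace()

print("trace written under", log_dir)
```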


trace(log_dir[, create_perfetto_link, ...])

Context manager to take a profiler trace.
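
The context-manager form, jax.profiler.trace, can replace an explicit start/stop pair. The following is a small sketch under the assumption that a temporary directory is an acceptable place for the trace; the computation is a placeholder.

```python
import os
import tempfile

import jax
import jax.numpy as jnp

# Hypothetical location for the trace output.
log_dir = tempfile.mkdtemp(prefix="jax-trace-")

with jax.profiler.trace(log_dir):
    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (512, 512))
    # Block inside the context so the work is captured by the trace.
    y = (x @ x.T).block_until_ready()

print("trace written under", log_dir)
```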

annotate_function(func[, name])

Decorator that generates a trace event for the execution of a function.
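
The decorator can be applied directly or, to give the event a custom name, through functools.partial. The function names and the event name "scaled_sum" below are illustrative assumptions. The events only appear in a profile when a trace is active, but the decorated functions run normally either way.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Bare form: the trace event is named after the function.
@jax.profiler.annotate_function
def normalize(x):
    return x / jnp.linalg.norm(x)


# Named form: the trace event uses the given name.
@partial(jax.profiler.annotate_function, name="scaled_sum")
def scaled_sum(x, scale):
    return scale * jnp.sum(x)


v = normalize(jnp.array([3.0, 4.0]))
total = scaled_sum(jnp.arange(4.0), 2.0)
```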


TraceAnnotation

Context manager that generates a trace event in the profiler.
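
As a rough sketch, jax.profiler.TraceAnnotation can wrap arbitrary regions of code so that they show up as named events in a trace. The region names "preprocess" and "reduce" are made up for this example.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(1024.0)

# Each block becomes a named event when a profiler trace is active.
with jax.profiler.TraceAnnotation("preprocess"):
    centered = (x - x.mean()).block_until_ready()

with jax.profiler.TraceAnnotation("reduce"):
    result = jnp.sum(centered ** 2).block_until_ready()
```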

StepTraceAnnotation(name, **kwargs)

Context manager that generates a step trace event in the profiler.
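
A step annotation is typically placed around each iteration of a training loop, with the step number passed as a keyword argument. The toy update rule below is an assumption chosen only to keep the sketch self-contained.

```python
import jax
import jax.numpy as jnp


@jax.jit
def train_step(w, x):
    # Toy update: one gradient-descent step on 0.5 * ||w - x||^2.
    return w - 0.1 * (w - x)


w = jnp.zeros(4)
x = jnp.ones(4)

for step in range(5):
    # Marks the boundaries of one step in the profiler timeline.
    with jax.profiler.StepTraceAnnotation("train", step_num=step):
        w = train_step(w, x).block_until_ready()
```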

Device memory profiling

See Device Memory Profiling for an introduction to JAX’s device memory profiling features.


device_memory_profile([backend])

Captures a JAX device memory profile as a pprof-format protocol buffer.
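
The in-memory form can be sketched as follows: a few arrays are allocated so there is something to report, and jax.profiler.device_memory_profile returns the profile as a byte string. The array sizes are arbitrary.

```python
import jax
import jax.numpy as jnp

# Allocate a few live buffers on the default device.
buffers = [jnp.ones((128, 128)) * i for i in range(4)]
for b in buffers:
    b.block_until_ready()

# The result is a byte string holding the serialized profile.
profile = jax.profiler.device_memory_profile()
print(f"captured {len(profile)} bytes of pprof data")
```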

save_device_memory_profile(filename[, backend])

Collects a device memory profile and writes it to a file.
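
Writing the profile straight to disk might look like the sketch below. The file name memory.prof and the temporary directory are assumptions for the example.

```python
import os
import tempfile

import jax
import jax.numpy as jnp

# Hypothetical output path for the profile.
path = os.path.join(tempfile.mkdtemp(prefix="jax-mem-"), "memory.prof")

# Keep a reference to the array so its buffer is live when profiled.
x = jnp.ones((1024, 1024)).block_until_ready()

jax.profiler.save_device_memory_profile(path)
print("memory profile saved to", path)
```

The resulting file can then be inspected with the pprof tool.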

Deprecated functions

trace_function(*args, **kwargs)

TraceContext(*args, **kwargs)

StepTraceContext(*args, **kwargs)