jax.profiler module

Tracing and time profiling

See Profiling JAX programs for how to use JAX’s tracing and time profiling features.


Starts a profiler server listening on the given port.
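The function name is missing from this listing. The sketch below assumes it is jax.profiler.start_server, as in recent JAX releases; port 9999 is an arbitrary choice. While the server runs, a profiling client such as TensorBoard's profile plugin can connect to localhost:9999 and capture a trace of the work the program does.

```python
import jax
import jax.numpy as jnp

# Start the profiler server in this process (port 9999 is an arbitrary choice).
jax.profiler.start_server(9999)

# Do some work that a connected profiling client could capture.
x = jnp.ones((512, 512))
y = (x @ x).block_until_ready()

# Each entry of y is a sum of 512 ones.
print(y[0, 0])
```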

trace_function(func[, name])

Decorator that generates a trace event for the execution of a function.
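The sketch below applies the decorator to a small function. It assumes a JAX version that still provides trace_function; recent releases expose a decorator with this description as jax.profiler.annotate_function, so the code falls back to that name. The trace event is only recorded while a profile is being captured; otherwise the decorated function simply runs.

```python
import jax
import jax.numpy as jnp

# Older JAX releases call this decorator trace_function; newer ones call it
# annotate_function. Use whichever this installation provides.
_annotate = getattr(jax.profiler, "trace_function", None) or jax.profiler.annotate_function


@_annotate
def normalize(v):
    # Scale v to unit L2 norm. While a profile is being captured, each call
    # shows up as a trace event (named after the function by default).
    return v / jnp.linalg.norm(v)


out = normalize(jnp.array([3.0, 4.0]))
print(out)
```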


Context manager that generates a trace event in the profiler.
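The name of this context manager is missing from this listing. The sketch assumes it is jax.profiler.TraceAnnotation, the name used in recent JAX releases. It labels a region of host code so that each pass through the with-block appears as a named event in a captured trace; train_step is a hypothetical toy function used only to have something to annotate.

```python
import jax
import jax.numpy as jnp


def train_step(params, batch):
    # Hypothetical toy update, not a real training step.
    return params - 0.1 * jnp.mean(batch)


params = jnp.float32(1.0)
for step in range(3):
    # Each pass through this block is recorded as a trace event named
    # "train_step" while a profile is being captured.
    with jax.profiler.TraceAnnotation("train_step"):
        params = train_step(params, jnp.ones((8,)))

print(params)
```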

Device memory profiling

See Device Memory Profiling for an introduction to JAX’s device memory profiling features.


Captures a JAX device memory profile as a pprof-format protocol buffer.
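The function name is missing from this listing. The sketch assumes it is jax.profiler.device_memory_profile, as in recent JAX releases, and captures the profile in memory as a bytes object that could later be written to disk and inspected with pprof.

```python
import jax
import jax.numpy as jnp

# Keep some arrays alive on the default device so the profile has something to report.
live = [jnp.ones((256, 256)) for _ in range(4)]

# Capture the current device memory state as serialized pprof protocol buffer bytes.
profile = jax.profiler.device_memory_profile()
print(type(profile), len(profile))
```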

save_device_memory_profile(filename[, backend])

Collects a device memory profile and writes it to a file.
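A minimal sketch of writing the profile to a file. The filename memory.prof and the temporary directory are arbitrary choices for illustration. The resulting file can then be opened with pprof, for example with pprof --web memory.prof, as described in the Device Memory Profiling guide.

```python
import os
import tempfile

import jax
import jax.numpy as jnp

# Allocate something on the device so the profile is not empty.
x = jnp.arange(1_000_000, dtype=jnp.float32)

# Write the profile into a temporary directory (an arbitrary location for this sketch).
out_dir = tempfile.mkdtemp()
path = os.path.join(out_dir, "memory.prof")
jax.profiler.save_device_memory_profile(path)

print(os.path.getsize(path), "bytes written to", path)
```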