jax.profiler module

Tracing and time profiling

Profiling JAX programs describes how to use JAX’s tracing and time profiling features.

start_server(port)

Starts a profiler server on the given port.
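A minimal sketch of starting the server near program start so that a client (for example TensorBoard's profile capture) can attach while work runs. The port number 9999 is an arbitrary choice for illustration, and the matrix multiply stands in for real work.

```python
import jax
import jax.numpy as jnp

# Arbitrary port chosen for illustration; any free TCP port should work.
PORT = 9999

# Start the profiler server once, early in the program.
server = jax.profiler.start_server(PORT)

# Work that a connected client could capture while it runs.
x = jnp.ones((1000, 1000))
y = (x @ x).block_until_ready()
```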

start_trace(log_dir)

Starts a profiler trace. The results are written to log_dir when stop_trace() is called.
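A minimal sketch of pairing start_trace() with stop_trace() around a piece of work, assuming a temporary directory is an acceptable log_dir (in practice it is often a TensorBoard log directory). The os.walk at the end only lists whatever files the profiler wrote; their exact layout is not assumed.

```python
import os
import tempfile

import jax
import jax.numpy as jnp

# Hypothetical output directory for this example.
log_dir = tempfile.mkdtemp()

jax.profiler.start_trace(log_dir)

x = jnp.arange(1000, dtype=jnp.float32)
# block_until_ready() waits for the asynchronously dispatched computation,
# so it finishes before the trace is stopped.
total = jnp.sum(x * 2.0).block_until_ready()

jax.profiler.stop_trace()

# Collect every file the profiler wrote under log_dir.
trace_files = [
    os.path.join(root, name)
    for root, _, names in os.walk(log_dir)
    for name in names
]
```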

stop_trace()

Stops the currently-running profiler trace.
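Because only one trace can run at a time, a trace left running after an exception would block the next start_trace() call. A minimal sketch of guarding against that with try/finally; profile_call and failing are hypothetical helpers, not part of jax.profiler.

```python
import tempfile

import jax
import jax.numpy as jnp


def profile_call(fn, log_dir):
    """Hypothetical helper: trace one call to fn and always stop the trace."""
    jax.profiler.start_trace(log_dir)
    try:
        # Wait for asynchronous dispatch so the work lands inside the trace.
        return fn().block_until_ready()
    finally:
        # Runs even if fn raises, so the trace is never left running.
        jax.profiler.stop_trace()


result = profile_call(lambda: jnp.ones(4) * 3.0, tempfile.mkdtemp())


def failing():
    raise ValueError("simulated failure")


try:
    profile_call(failing, tempfile.mkdtemp())
except ValueError:
    pass

# Starting another trace works because the failed call still stopped its trace.
second = profile_call(lambda: jnp.arange(3.0), tempfile.mkdtemp())
```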

trace(log_dir)

Context manager to take a profiler trace.
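A minimal sketch of the context-manager form, which starts a trace on entry and stops it on exit. The jitted step function and the temporary log_dir are illustrative assumptions.

```python
import os
import tempfile

import jax
import jax.numpy as jnp

log_dir = tempfile.mkdtemp()  # hypothetical output directory


@jax.jit
def step(x):
    return jnp.tanh(x @ x.T)


x = jnp.ones((256, 256))

# The trace starts when the block is entered and stops when it exits.
with jax.profiler.trace(log_dir):
    y = step(x).block_until_ready()

# Names of all files written under log_dir by the profiler.
written = [name for _, _, names in os.walk(log_dir) for name in names]
```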

annotate_function(func[, name])

Decorator that generates a trace event for the execution of a function.
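A minimal sketch of the decorator in both forms: bare, where the event is labelled after the function by default, and with an explicit name supplied through functools.partial. The function names and the label "loss_fn" are illustrative.

```python
import functools

import jax
import jax.numpy as jnp


@jax.profiler.annotate_function
def preprocess(x):
    # Each call emits a trace event labelled after the function by default.
    return (x - x.mean()) / (x.std() + 1e-6)


# functools.partial supplies the optional name while keeping decorator syntax.
@functools.partial(jax.profiler.annotate_function, name="loss_fn")
def loss(x):
    return jnp.sum(x ** 2)


value = loss(preprocess(jnp.arange(4.0)))
```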

TraceAnnotation

Context manager that generates a trace event in the profiler.
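A minimal sketch of labelling regions of Python code so they appear as named events in any trace being captured at the time. The labels "normalize" and "reduce" are arbitrary.

```python
import jax
import jax.numpy as jnp

x = jnp.linspace(0.0, 1.0, 1024)

# Each block is recorded as a named event while a trace is active.
with jax.profiler.TraceAnnotation("normalize"):
    y = x / jnp.max(x)

with jax.profiler.TraceAnnotation("reduce"):
    total = jnp.sum(y).block_until_ready()
```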

StepTraceAnnotation(name, **kwargs)

Context manager that generates a step trace event in the profiler.
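A minimal sketch of marking each iteration of a training loop as a step, so the profiler can report performance per step. The toy quadratic loss, the learning rate 0.1, and the name "train" are illustrative assumptions; step_num is passed through the keyword arguments.

```python
import jax
import jax.numpy as jnp


@jax.jit
def train_step(w, x):
    # Toy update: one gradient-descent step on a quadratic loss.
    grad = jax.grad(lambda w: jnp.sum((w * x) ** 2))(w)
    return w - 0.1 * grad


w = jnp.ones(3)
x = jnp.array([1.0, 0.5, 0.25])

for step in range(5):
    # Each iteration is recorded as one step event, numbered by step_num.
    with jax.profiler.StepTraceAnnotation("train", step_num=step):
        w = train_step(w, x)

w = w.block_until_ready()
```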

Device memory profiling

See Device Memory Profiling for an introduction to JAX’s device memory profiling features.

device_memory_profile([backend])

Captures a JAX device memory profile as a pprof-format protocol buffer.
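A minimal sketch of capturing the profile in memory, for example to upload or post-process it rather than write it to disk. The arrays are kept alive only so the profile has live buffers to report; no particular contents of the returned bytes are assumed.

```python
import jax
import jax.numpy as jnp

# Keep a few arrays alive so they appear as live device buffers.
params = [jnp.zeros((1024, 1024)) for _ in range(4)]

# Raw bytes of the pprof-format protocol buffer for the default backend.
profile = jax.profiler.device_memory_profile()
```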

save_device_memory_profile(filename[, backend])

Collects a device memory profile and writes it to a file.
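A minimal sketch of writing the profile to a file; the path under a temporary directory is hypothetical. The resulting file can then be inspected with the pprof tool, for example `pprof --web memory.prof`.

```python
import os
import tempfile

import jax
import jax.numpy as jnp

# Allocate a buffer and keep it alive so the profile has something to show.
big = jnp.ones((2048, 2048)).block_until_ready()

# Hypothetical output location for the profile.
path = os.path.join(tempfile.mkdtemp(), "memory.prof")
jax.profiler.save_device_memory_profile(path)
```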

Deprecated functions

trace_function(*args, **kwargs)

TraceContext(*args, **kwargs)

StepTraceContext(*args, **kwargs)