jax.profiler.device_memory_profile
- jax.profiler.device_memory_profile(backend=None)
Captures a JAX device memory profile as a pprof-format protocol buffer.
A device memory profile is a snapshot of the state of memory that describes the JAX Array and executable objects present in memory and their allocation sites.
For more information on how to use the device memory profiler, see Profiling device memory.
The profiling system works by instrumenting JAX on-device allocations, capturing a Python stack trace for each allocation. The instrumentation is always enabled; device_memory_profile() provides an API to capture it.
The output of device_memory_profile() is a binary protocol buffer that can be interpreted and visualized by the pprof tool.
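As a minimal sketch of the capture step described above, the snippet below allocates one array, takes a snapshot, and writes the returned bytes to disk. The helper name capture_profile, the array shape, and the file name memory.prof are illustrative assumptions, not part of the JAX API; only jax.profiler.device_memory_profile() comes from this page.

```python
from pathlib import Path

import jax
import jax.numpy as jnp


def capture_profile(path):
    """Snapshot device memory and write the pprof data to `path`.

    `capture_profile` and `path` are hypothetical names used for
    illustration; they are not part of the JAX API.
    """
    # Keep one live array so the snapshot has an allocation to report.
    x = jnp.ones((1024, 1024))
    x.block_until_ready()  # make sure the allocation has happened

    # backend=None (the default) profiles the default backend.
    profile = jax.profiler.device_memory_profile()

    # The result is binary data, so write it in bytes mode.
    Path(path).write_bytes(profile)
    return profile


profile = capture_profile("memory.prof")
```

Assuming the pprof tool is installed, the saved file can then be opened with a command such as pprof --web memory.prof.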