Captures a JAX device memory profile as a pprof-format protocol buffer.
A device memory profile is a snapshot of the state of memory that describes the JAX jax.DeviceArray and executable objects present in memory and their allocation sites.
For more information on how to use the device memory profiler, see Device Memory Profiling.
The profiling system works by instrumenting JAX on-device allocations, capturing a Python stack trace for each allocation. The instrumentation is always enabled; device_memory_profile() provides an API to capture it.
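Because the instrumentation is always on, capturing a snapshot needs no profiler setup beforehand. The sketch below assumes a working JAX installation (any backend, including CPU); the array shapes and the variable names a, b, c and profile are arbitrary illustrations, not part of the API.

```python
import jax
import jax.numpy as jnp

# No profiler setup is needed: allocation tracking is always enabled.
# Create a few live device arrays so the snapshot has allocations to report.
a = jnp.zeros((512, 512), dtype=jnp.float32)
b = jnp.arange(1_000_000, dtype=jnp.float32)
c = (a + 1.0).block_until_ready()

# Take a snapshot of current device memory. The result is the raw
# serialized pprof protocol buffer, returned as a bytes object.
profile = jax.profiler.device_memory_profile()
print(type(profile).__name__, len(profile))
```

Only arrays that are still alive when the call runs appear in the snapshot, which is why a, b and c are kept in scope here.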
The output of device_memory_profile() is a binary protocol buffer that can be interpreted and visualized by the pprof tool.
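To hand the output to pprof, the bytes only need to be written to a file. A minimal sketch, assuming a recent JAX release: it uses jax.profiler.save_device_memory_profile, a convenience wrapper that writes the same bytes to disk, and also shows the equivalent manual write. The file names memory.prof and memory_manual.prof and the helper make_workload are hypothetical choices for illustration.

```python
import os

import jax
import jax.numpy as jnp


def make_workload():
    # Hypothetical workload; the returned arrays stay alive so they
    # show up in the snapshot.
    x = jnp.ones((2048, 2048), dtype=jnp.float32)
    return x, (x @ x).block_until_ready()


arrays = make_workload()

# Convenience wrapper: captures the profile and writes the raw bytes to a file.
jax.profiler.save_device_memory_profile("memory.prof")

# Equivalent manual form using device_memory_profile() directly.
with open("memory_manual.prof", "wb") as f:
    f.write(jax.profiler.device_memory_profile())

print(os.path.getsize("memory.prof"), os.path.getsize("memory_manual.prof"))
```

The saved file can then be opened with pprof, for example pprof --web memory.prof. pprof is installed separately from JAX (one option is go install github.com/google/pprof@latest).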