jax.profiler.device_memory_profile


jax.profiler.device_memory_profile(backend=None)

Captures a JAX device memory profile as a pprof-format protocol buffer.

A device memory profile is a snapshot of device memory that describes the JAX Array and executable objects present in memory, along with their allocation sites.

For more information on how to use the device memory profiler, see Device Memory Profiling.

The profiling system works by instrumenting JAX on-device allocations, capturing a Python stack trace for each allocation. The instrumentation is always enabled; device_memory_profile() provides an API to capture a snapshot of the current allocations.

The output of device_memory_profile() is a binary protocol buffer that can be interpreted and visualized by the pprof tool.
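As a minimal sketch, the snippet below allocates an array, captures a profile while that array is still alive, and writes the returned bytes to a file for pprof. The array shape and the file name memory.prof are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

# Allocate an array on the default device; keeping `x` alive means the
# allocation is still present when the snapshot is taken.
x = jnp.ones((1000, 1000))

# Capture the current device memory state as pprof-format bytes.
profile = jax.profiler.device_memory_profile()

# Persist the raw bytes so the pprof tool can read them.
with open("memory.prof", "wb") as f:
    f.write(profile)
```

The saved file can then be opened with pprof, for example with `pprof --web memory.prof`. JAX also provides jax.profiler.save_device_memory_profile(filename), which combines the capture and write steps.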

Parameters:

backend (str | None) – optional; the name of the JAX backend (for example, "cpu") for which the device memory profile should be collected. If None, the default backend is used.

Return type:

bytes

Returns:

A byte string containing a binary pprof-format protocol buffer.
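The following sketch shows the backend parameter and the bytes return value together. It assumes a CPU backend is available, which is the case in a standard JAX installation; the array contents are arbitrary.

```python
import jax
import jax.numpy as jnp

# Explicitly place an array on the first CPU device.
cpu_device = jax.devices("cpu")[0]
x = jax.device_put(jnp.arange(1024.0), cpu_device)

# Restrict the snapshot to allocations made by the "cpu" backend.
profile = jax.profiler.device_memory_profile(backend="cpu")

# The result is an opaque byte string, not a parsed profile object.
print(type(profile).__name__, len(profile))
```

Because the return value is raw bytes, inspecting its contents requires pprof or another tool that understands the pprof protocol buffer format.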