jax.profiler.save_device_memory_profile(filename, backend=None)

Collects a device memory profile and writes it to a file.

save_device_memory_profile() is a convenience wrapper around device_memory_profile() that writes its output to the file named by filename. See the device_memory_profile() documentation for more information.
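To illustrate what "convenience wrapper" means here, the sketch below approximates the wrapper's behaviour by calling device_memory_profile() directly and writing the returned bytes in binary mode. The helper name save_profile_manually and the file name "manual.prof" are hypothetical and used only for illustration; this is not the library's own implementation.

```python
import os
import tempfile

import jax
import jax.numpy as jnp


def save_profile_manually(filename, backend=None):
    # Hypothetical helper: fetch the raw pprof-format bytes from JAX
    # and write them unchanged to `filename` in binary mode.
    profile = jax.profiler.device_memory_profile(backend)
    with open(filename, "wb") as f:
        f.write(profile)


# Allocate a device array so the profile has at least one live buffer to report.
x = jnp.arange(1_000_000, dtype=jnp.float32)
x.block_until_ready()

path = os.path.join(tempfile.mkdtemp(), "manual.prof")
save_profile_manually(path)
print(f"wrote {os.path.getsize(path)} bytes to {path}")
```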

  • filename – the filename to which the profile should be written.

  • backend (Optional[str]) – optional; the name of the JAX backend (for example "cpu", "gpu", or "tpu") for which the device memory profile should be collected. If None, the default backend is used.

Return type

None
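A minimal usage sketch, assuming a working JAX installation: keep a device array alive, then write a profile for the current default backend by passing the string returned by jax.default_backend() as the backend argument. The output name "memory.prof" is an arbitrary choice.

```python
import os
import tempfile

import jax
import jax.numpy as jnp

# Keep a large array alive so it shows up in the memory profile.
x = jnp.ones((1024, 1024), dtype=jnp.float32)
x.block_until_ready()

# Write the profile for the default backend (e.g. "cpu", "gpu", or "tpu").
out_path = os.path.join(tempfile.mkdtemp(), "memory.prof")
jax.profiler.save_device_memory_profile(out_path, backend=jax.default_backend())

print(f"wrote {os.path.getsize(out_path)} bytes to {out_path}")
```

The resulting file is in pprof format, so it can typically be inspected with Google's pprof tool (for example `pprof --web memory.prof`), assuming pprof is installed separately.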