Collects a device memory profile and writes it to a file.
save_device_memory_profile() is a convenience wrapper around device_memory_profile()
that writes the resulting profile to the file named by filename. See the
device_memory_profile() documentation for more information.
filename – the filename to which the profile should be written.
backend (Optional[str]) – optional; the name of the JAX backend for which the device memory
profile should be collected.
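The following is a minimal usage sketch, not an official example. It assumes JAX is installed and that the default backend is acceptable, so backend is left unset. The output path and the array x are hypothetical and chosen only for illustration.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import jax.profiler

# Allocate a device array so the profile has a live buffer to report.
# (The array itself is an illustrative assumption.)
x = jnp.ones((1024, 1024))
x.block_until_ready()

# Hypothetical output path; any writable location works.
profile_path = os.path.join(tempfile.mkdtemp(), "memory.prof")

# Collect the device memory profile and write it to profile_path.
# backend is omitted here, so the default JAX backend is profiled.
jax.profiler.save_device_memory_profile(profile_path)

profile_size = os.path.getsize(profile_path)
print(f"wrote {profile_size} bytes to {profile_path}")
```

The written file is in pprof format, so it can typically be inspected with the pprof tool. To profile a specific backend, pass its name as the backend argument.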