jax.profiler.save_device_memory_profile#

jax.profiler.save_device_memory_profile(filename, backend=None)[源]#

收集设备内存配置文件并将其写入文件。

save_device_memory_profile() 是一个方便的封装函数,封装了 device_memory_profile(),将其输出保存到 filename。有关更多信息,请参阅 device_memory_profile() 文档。

参数:
  • filename – 应写入配置文件的文件名。

  • backend (str | None) – 可选;应为其收集设备内存配置文件的 JAX 后端名称。

返回类型: