jax.profiler.device_memory_profile#

jax.profiler.device_memory_profile(backend=None)[源代码]#

捕获 JAX 设备内存配置作为 pprof 格式的协议缓冲区。

设备内存配置是内存状态的快照,描述了 JAX Array 和存在于内存中的可执行对象及其分配位置。

有关如何使用设备内存配置器的更多信息,请参阅分析设备内存

该分析系统通过检测 JAX 设备上分配来实现,从而捕获每次分配的 Python 堆栈跟踪。 该检测始终启用; device_memory_profile() 提供了一个 API 来捕获它。

device_memory_profile() 的输出是一个二进制协议缓冲区,可以使用 pprof 工具进行解释和可视化。

参数:

backend ( str | None ) – 可选;应该为其收集设备内存配置文件的 JAX 后端的名称。

返回:

包含二进制 pprof 格式协议缓冲区的字节字符串。

返回类型:

bytes