设备内存分析#

注意

2025 年 6 月更新:我们建议使用 XProf 分析 进行设备内存分析。在进行分析后,打开 Tensorboard 分析器的 memory_viewer 选项卡,以获得更详细和易于理解的设备内存使用情况。

JAX 设备内存分析器使我们能够探索 JAX 程序如何以及为何使用 GPU 或 TPU 内存。例如,它可以用于

  • 找出在给定时间哪些数组和可执行文件位于 GPU 内存中,或

  • 追踪内存泄漏。

安装#

JAX 设备内存分析器会产生可使用 pprof (google/pprof) 进行解释的输出。首先,请按照其 安装说明 安装 pprof。在撰写本文时,安装 pprof 需要先安装 1.16+ 版本的 GoGraphviz,然后运行

go install github.com/google/pprof@latest

这将把 pprof 安装为 $GOPATH/bin/pprof,其中 GOPATH 默认为 ~/go

注意

google/pprof 提供的 pprof 版本与作为 gperftools 包一部分分发的同名旧工具不同。 gperftools 版本的 pprof 将无法与 JAX 一起使用。

了解 JAX 程序如何使用 GPU 或 TPU 内存#

设备内存分析器的一个常见用途是找出 JAX 程序为何使用大量 GPU 或 TPU 内存,例如在尝试调试内存不足问题时。

要将设备内存配置文件保存到磁盘,请使用 jax.profiler.save_device_memory_profile()。例如,考虑以下 Python 程序

import jax
import jax.numpy as jnp
import jax.profiler

def func1(x):
  return jnp.tile(x, 10) * 0.5

def func2(x):
  y = func1(x)
  return y, jnp.tile(x, 10) + 1

x = jax.random.normal(jax.random.key(42), (1000, 1000))
y, z = func2(x)

z.block_until_ready()

jax.profiler.save_device_memory_profile("memory.prof")

如果我们首先运行上面的程序,然后执行

pprof --web memory.prof

pprof 会打开一个网页浏览器,其中包含调用图格式的设备内存配置文件的以下可视化效果。

Device memory profiling example

调用图是 Python 堆栈在分配每个活动缓冲区时所处状态的可视化。例如,在本例中,可视化显示 func2 及其被调用函数分配了 76.30MB,其中 38.15MB 是在 func1func2 的调用中分配的。有关如何解释调用图可视化的更多信息,请参阅 pprof 文档

通过 jax.jit() 编译的函数对设备内存分析器来说是不可见的。也就是说,在 jit 编译的函数内部分配的任何内存都将被归因于该函数整体。

在示例中,调用 block_until_ready() 是为了确保在收集设备内存配置文件之前 func2 已完成。有关更多详细信息,请参阅 异步调度

调试内存泄漏#

我们还可以使用 JAX 设备内存分析器通过使用 pprof 可视化在不同时间拍摄的两个设备内存配置文件之间的内存使用量变化来追踪内存泄漏。例如,考虑以下程序,该程序将 JAX 数组累积到一个不断增长的 Python 列表中。

import jax
import jax.numpy as jnp
import jax.profiler

def afunction():
  return jax.random.normal(jax.random.key(77), (1000000,))

z = afunction()

def anotherfunc():
  arrays = []
  for i in range(1, 10):
    x = jax.random.normal(jax.random.key(42), (i, 10000))
    arrays.append(x)
    x.block_until_ready()
    jax.profiler.save_device_memory_profile(f"memory{i}.prof")

anotherfunc()

如果我们仅仅可视化执行结束时的设备内存配置文件 (memory9.prof),可能不会明显看出 anotherfunc 中循环的每次迭代都会累积更多的设备内存分配。

pprof --web memory9.prof

Device memory profile at end of execution

afunction 中,大量但固定的分配会主导整个配置文件,但不会随时间增长。

通过使用 pprof--diff_base 功能 来可视化循环迭代中的内存使用量变化,我们可以确定程序的内存使用量为何随时间增加。

pprof --web --diff_base memory1.prof memory9.prof

Device memory profile at end of execution

可视化显示,内存增长可以归因于 anotherfunc 中对 normal 的调用。