GPU 内存分配#

当第一次 JAX 操作运行时,JAX 将预分配 75% 的 GPU 总内存。 预分配可以最大限度地减少分配开销和内存碎片,但有时可能会导致内存不足 (OOM) 错误。如果您的 JAX 进程因 OOM 而失败,可以使用以下环境变量来覆盖默认行为

XLA_PYTHON_CLIENT_PREALLOCATE=false

这会禁用预分配行为。JAX 将根据需要分配 GPU 内存,从而可能降低整体内存使用量。然而,这种行为更容易导致 GPU 内存碎片化,这意味着禁用了预分配的 JAX 程序在使用大部分可用 GPU 内存时可能会发生 OOM。

XLA_PYTHON_CLIENT_MEM_FRACTION=.XX

如果启用了预分配,这将使 JAX 预分配总 GPU 内存的 XX%,而不是默认的 75%。降低预分配量可以解决 JAX 程序启动时发生的 OOM 问题。

XLA_PYTHON_CLIENT_ALLOCATOR=platform

这使得 JAX 能够按需精确分配所需内存,并释放不再需要的内存(请注意,这是唯一会释放 GPU 内存而不是重用它的配置)。这种方式非常慢,因此不建议普遍使用,但对于以最小的 GPU 内存占用运行或调试 OOM 故障可能会很有用。

常见的 OOM 故障原因#

同时运行多个 JAX 进程。

要么使用 XLA_PYTHON_CLIENT_MEM_FRACTION 为每个进程分配适当的内存量,要么设置 XLA_PYTHON_CLIENT_PREALLOCATE=false

同时运行 JAX 和 GPU TensorFlow。

TensorFlow 也默认预分配内存,因此这类似于同时运行多个 JAX 进程。

一种解决方案是仅使用 CPU 版本的 TensorFlow(例如,如果您只使用 TF 进行数据加载)。您可以使用命令 tf.config.experimental.set_visible_devices([], "GPU") 阻止 TensorFlow 使用 GPU。

或者,使用 XLA_PYTHON_CLIENT_MEM_FRACTIONXLA_PYTHON_CLIENT_PREALLOCATE。还有类似选项可以配置 TensorFlow 的 GPU 内存分配(TF1 中的 gpu_memory_fractionallow_growth,应在传递给 tf.Sessiontf.ConfigProto 中设置。TF2 请参阅 使用 GPU:限制 GPU 内存增长)。

在显示 GPU 上运行 JAX。

使用 XLA_PYTHON_CLIENT_MEM_FRACTIONXLA_PYTHON_CLIENT_PREALLOCATE

禁用重物化 HLO 传递

有时,禁用自动重物化 HLO 传递有利于避免编译器做出糟糕的重物化选择。通过分别设置 jax.config.update('jax_compiler_enable_remat_pass', True)jax.config.update('jax_compiler_enable_remat_pass', False) 可以启用/禁用该传递。启用或禁用自动重物化传递会在计算和内存之间产生不同的权衡。然而请注意,该算法是基本的,通常可以通过禁用自动重物化传递并使用 jax.remat API 手动执行来获得计算和内存之间更好的权衡。

实验性功能#

此处的功能是实验性的,必须谨慎尝试。

TF_GPU_ALLOCATOR=cuda_malloc_async

这会将 XLA 自己的 BFC 内存分配器替换为 cudaMallocAsync。这将取消大型固定预分配,并使用一个可增长的内存池。预期的好处是无需设置 XLA_PYTHON_CLIENT_MEM_FRACTION

风险包括

  • 内存碎片化情况不同,因此如果接近内存限制,因碎片化导致的具体 OOM 情况也会不同。

  • 分配时间不会在开始时全部支付,而是在内存池需要增加时产生。因此,您可能会在开始时体验到较差的速度稳定性,对于基准测试而言,忽略前几次迭代会更加重要。

通过预分配一个显著的内存块,仍然可以获得可增长内存池的好处,从而降低风险。这可以通过 TF_CUDA_MALLOC_ASYNC_SUPPORTED_PREALLOC=N 来实现。如果 N 为 -1,它将预分配与默认分配量相同的大小。否则,它就是您希望预分配的字节大小。