GPU 内存分配#
当首次运行 JAX 操作时,JAX 将预分配 75% 的 GPU 总内存。 预分配最大限度地减少了分配开销和内存碎片,但有时会导致内存不足 (OOM) 错误。如果您的 JAX 进程因 OOM 失败,可以使用以下环境变量来覆盖默认行为
XLA_PYTHON_CLIENT_PREALLOCATE=false
这会禁用预分配行为。JAX 将改为根据需要分配 GPU 内存,从而可能减少总体内存使用量。但是,此行为更容易导致 GPU 内存碎片,这意味着使用大部分可用 GPU 内存的 JAX 程序在禁用预分配的情况下可能会出现 OOM。
XLA_PYTHON_CLIENT_MEM_FRACTION=.XX
如果启用了预分配,这将使 JAX 预分配 XX% 的 GPU 总内存,而不是默认的 75%。降低预分配量可以修复 JAX 程序启动时发生的 OOM。
XLA_PYTHON_CLIENT_ALLOCATOR=platform
这使得 JAX 按需分配所需的确切内存,并释放不再需要的内存(请注意,这是唯一会释放 GPU 内存而不是重用它的配置)。这非常慢,因此不建议用于一般用途,但可能适用于以最小可能的 GPU 内存占用量运行或调试 OOM 失败。
OOM 失败的常见原因#
- 并发运行多个 JAX 进程。
可以使用
XLA_PYTHON_CLIENT_MEM_FRACTION
为每个进程提供适当的内存量,或者设置XLA_PYTHON_CLIENT_PREALLOCATE=false
。- 并发运行 JAX 和 GPU TensorFlow。
TensorFlow 也默认进行预分配,因此这类似于并发运行多个 JAX 进程。
一种解决方案是使用仅 CPU 的 TensorFlow(例如,如果您仅使用 TF 进行数据加载)。您可以使用命令
tf.config.experimental.set_visible_devices([], "GPU")
阻止 TensorFlow 使用 GPU或者,使用
XLA_PYTHON_CLIENT_MEM_FRACTION
或XLA_PYTHON_CLIENT_PREALLOCATE
。还有类似的选项可以配置 TensorFlow 的 GPU 内存分配(TF1 中的 gpu_memory_fraction 和 allow_growth,应在传递给tf.Session
的tf.ConfigProto
中设置)。有关 TF2,请参阅 使用 GPU:限制 GPU 内存增长)。- 在显示 GPU 上运行 JAX。
使用
XLA_PYTHON_CLIENT_MEM_FRACTION
或XLA_PYTHON_CLIENT_PREALLOCATE
。- 禁用重物化 HLO 传递
有时禁用自动重物化 HLO 传递有利于避免编译器做出较差的重物化选择。可以通过分别设置
jax.config.update('enable_remat_opt_pass', True)
或jax.config.update('enable_remat_opt_pass', False)
来启用/禁用该传递。启用或禁用自动重物化传递会在计算和内存之间产生不同的权衡。但请注意,该算法是基本的,并且通常可以通过禁用自动重物化传递并使用 jax.remat API 手动执行来获得更好的计算和内存之间的权衡
实验性功能#
此处的特性是实验性的,必须谨慎尝试。
TF_GPU_ALLOCATOR=cuda_malloc_async
这会将 XLA 自己的 BFC 内存分配器替换为 cudaMallocAsync。这将删除大的固定预分配,并使用不断增长的内存池。预期的好处是不需要设置 XLA_PYTHON_CLIENT_MEM_FRACTION。
风险是
内存碎片化有所不同,因此,如果您接近限制,则由于碎片化而导致的精确 OOM 情况将有所不同。
分配时间不会在开始时全部支付,而是在需要增加内存池时产生。因此,您可能会在开始时遇到较少的速度稳定性,并且对于基准测试,忽略前几次迭代将更加重要。
可以通过预先分配一个重要的块来减轻风险,并且仍然可以获得不断增长的内存池的好处。这可以使用 TF_CUDA_MALLOC_ASYNC_SUPPORTED_PREALLOC=N 完成。如果 N 为 -1,它将预分配与默认分配量相同的量。否则,它是您要预分配的大小(以字节为单位)。