持久化编译缓存#

JAX 提供一个可选的磁盘缓存用于编译后的程序。如果启用,JAX 会将编译后的程序副本存储在磁盘上,这可以在重复运行相同或类似任务时节省重新编译的时间。

注意:如果编译缓存不在本地文件系统上,则需要安装 etils

pip install etils

用法#

快速开始#

import jax
import jax.numpy as jnp

jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update("jax_persistent_cache_enable_xla_caches", "xla_gpu_per_fusion_autotune_cache_dir")

@jax.jit
def f(x):
  return x + 1

x = jnp.zeros((2, 2))
f(x)

设置缓存目录#

当设置了缓存位置时,编译缓存才会被启用。这应该在第一次编译之前完成。设置位置如下:

(1) 使用环境变量

在 Shell 中,运行脚本之前

export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache"

或者在 Python 脚本的开头

import os
os.environ["JAX_COMPILATION_CACHE_DIR"] = "/tmp/jax_cache"

(2) 使用 jax.config.update()

jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")

(3) 使用 set_cache_dir()

from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir("/tmp/jax_cache")

缓存阈值#

  • jax_persistent_cache_min_compile_time_secs:只有当编译时间超过指定值时,计算结果才会被写入持久化缓存。默认值为 1.0 秒。

  • jax_persistent_cache_min_entry_size_bytes:持久化编译缓存中条目被缓存的最小大小(以字节为单位)

    • -1:禁用大小限制并阻止覆盖。

    • 保持默认值(0)以允许覆盖。覆盖通常会确保最小大小对用于缓存的文件系统是最佳的。

    • > 0:所需的实际最小大小;不允许覆盖。

请注意,函数要被缓存,必须满足这两个条件。

附加缓存#

XLA 支持额外的缓存机制,可以与 JAX 的持久化编译缓存一起启用,以进一步缩短重新编译时间。

  • jax_persistent_cache_enable_xla_caches:可能的值

    • all:启用所有 XLA 缓存功能

    • none:不启用任何额外的 XLA 缓存功能

    • xla_gpu_kernel_cache_file:仅启用内核缓存

    • xla_gpu_per_fusion_autotune_cache_dir:(默认值)仅启用自动调优缓存

Google Cloud#

在 Google Cloud 上运行时,编译缓存可以放置在 Google Cloud Storage (GCS) 存储桶中。我们推荐以下配置:

  • 在与工作负载运行相同的区域创建存储桶。

  • 在与工作负载的虚拟机 (VM) 相同的项目中创建存储桶。确保设置了权限,以便虚拟机可以写入存储桶。

  • 对于较小的工作负载,无需复制。较大的工作负载可以从复制中受益。

  • 将存储桶的默认存储类别设置为“Standard”。

  • 将软删除策略设置为最短:7 天。

  • 将对象生命周期设置为工作负载运行的预期持续时间。例如,如果工作负载预计运行 10 天,则将对象生命周期设置为 10 天。这应该覆盖在整个运行期间发生的重启。对于生命周期条件使用 age,对于操作使用 Delete。详情请参阅对象生命周期管理。如果未设置对象生命周期,缓存将继续增长,因为没有实现驱逐机制。

  • 支持所有加密策略。

假设 gs://jax-cache 是 GCS 存储桶,按如下方式设置缓存位置:

jax.config.update("jax_compilation_cache_dir", "gs://jax-cache")

工作原理#

缓存键是编译函数的签名,包含以下参数:

  • 函数执行的计算,由被哈希的 JAX 函数的非优化 HLO 捕获

  • jaxlib 版本

  • 相关的 XLA 编译标志

  • 设备配置,通常由设备数量和设备拓扑捕获。目前对于 GPU,拓扑仅包含 GPU 名称的字符串表示

  • 用于压缩编译后的可执行文件的压缩算法

  • jax._src.cache_key.custom_hook() 生成的字符串。此函数可以重新分配给用户自定义函数,以便可以更改生成的字符串。默认情况下,此函数始终返回一个空字符串。

多节点缓存#

程序首次运行(持久化缓存冷/空)时,所有进程都将编译,但只有全局通信组中 rank 为 0 的进程才会写入持久化缓存。在后续运行中,所有进程都将尝试从持久化缓存中读取,因此持久化缓存位于共享文件系统(例如:NFS)或远程存储(例如:GFS)中非常重要。如果持久化缓存位于 rank 0 的本地,那么除了 rank 0 之外的所有进程在后续运行中将因编译缓存未命中而再次编译。

在单节点上预编译多节点程序#

JAX 可以在单个节点上使用多节点编译程序填充编译缓存。在单个节点上准备缓存有助于减少集群上昂贵的编译时间。要在单个节点上编译和运行多节点程序,用户可以使用 jax_mock_gpu_topology 配置选项创建伪远程设备。

例如,下面的代码片段指示 JAX 模拟一个包含四个节点的集群,每个节点运行八个进程,每个进程连接到一个 GPU。

jax.config.update("jax_mock_gpu_topology", "4x8x1")

使用此配置填充缓存后,用户可以在四个节点上运行程序,无需重新编译,每个节点八个进程,每个进程一个 GPU。

重要注意事项

  • 运行模拟程序的进程必须具有与将使用缓存的节点相同数量的 GPU 和相同的 GPU 模型。例如,模拟拓扑 8x4x2 必须在具有两个 GPU 的进程中运行。

  • 当运行具有模拟拓扑的程序时,与其他节点通信的结果是未定义的,因此在模拟环境中运行的 JAX 程序的输出很可能是不正确的。

日志缓存活动#

检查持久化编译缓存的具体情况对调试很有帮助。这里有一些开始的建议。

用户可以通过放置以下内容来启用相关源文件的日志记录:

import os
os.environ["JAX_DEBUG_LOG_MODULES"] = "jax._src.compiler,jax._src.lru_cache"

在脚本的开头。或者,您可以使用以下命令更改全局 JAX 日志级别:

import os
os.environ["JAX_LOGGING_LEVEL"] = "DEBUG"
# or locally with
jax.config.update("jax_logging_level", "DEBUG")

检查缓存未命中#

为了检查和理解为什么会出现缓存未命中,JAX 包含一个配置标志,可以启用所有缓存未命中(包括持久化编译缓存未命中)及其解释的日志记录。尽管目前这只实现了追踪缓存未命中,但最终目标是解释所有缓存未命中。可以通过设置以下配置来启用此功能。

jax.config.update("jax_explain_cache_misses", True)

常见问题#

目前已发现以下几个常见问题:

  • 目前,持久化缓存不适用于具有主机回调的函数。在这种情况下,缓存被完全避免。

    • 这是因为 HLO 包含指向回调的指针,即使计算和计算基础设施完全相同,它也会在每次运行中发生变化。

  • 目前,持久化缓存不适用于使用实现了自定义分区 (custom_partitioning) 的原语的函数。

    • 函数的 HLO 包含指向 custom_partitioning 回调的指针,这导致在不同运行中相同的计算产生不同的缓存键。

    • 在这种情况下,缓存仍然进行,但每次都会生成不同的键,使得缓存失效。

规避 custom_partitioning 问题#

如前所述,编译缓存不适用于由实现了 custom_partitioning 的原语组成的函数。但是,可以使用 shard_map 来规避那些实现了 custom_partitioning 的原语,并使编译缓存按预期工作。

假设我们有一个函数 F,它使用实现了 custom_partitioning 的原语 LayerNorm 来实现层归一化 (layernorm) 和矩阵乘法。

import jax

def F(x1, x2, gamma, beta):
   ln_out = LayerNorm(x1, gamma, beta)
   return ln_out @ x2

如果我们只是在没有 shard_map 的情况下编译此函数,每次运行相同的代码时,layernorm_matmul_without_shard_map 的缓存键都会不同。

layernorm_matmul_without_shard_map = jax.jit(F, in_shardings=(...), out_sharding=(...))(x1, x2, gamma, beta)

然而,如果我们将 layernorm 原语封装在 shard_map 中,并定义一个执行相同计算的函数 G,那么尽管 LayerNorm 实现了 custom_partitioninglayernorm_matmul_with_shard_map 的缓存键每次都将相同。

import jax

def G(x1, x2, gamma, beta, mesh, ispecs, ospecs):
   ln_out = jax.shard_map(LayerNorm, mesh=mesh, in_specs=ispecs, out_specs=ospecs, check_vma=False)(x1, x2, gamma, beta)
   return ln_out @ x2

ispecs = jax.sharding.PartitionSpec(...)
ospecs = jax.sharding.PartitionSpec(...)
mesh = jax.sharding.Mesh(...)
layernorm_matmul_with_shard_map = jax.jit(G, static_argnames=['mesh', 'ispecs', 'ospecs'])(x1, x2, gamma, beta, mesh, ispecs, ospecs)

请注意,实现 custom_partitioning 的原语必须封装在 shard_map 中才能实现此规避方法。仅仅将外部函数 F 封装在 shard_map 中是不够的。