持久化编译缓存#
JAX 拥有一个可选的已编译程序的磁盘缓存。启用后,JAX 会将已编译程序的副本存储在磁盘上,这可以节省重复运行时相同的或相似任务的重新编译时间。
注意:如果编译缓存不在本地文件系统上,则需要安装 etils。
pip install etils
用法#
快速入门#
import jax
import jax.numpy as jnp
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update("jax_persistent_cache_enable_xla_caches", "xla_gpu_per_fusion_autotune_cache_dir")
@jax.jit
def f(x):
return x + 1
x = jnp.zeros((2, 2))
f(x)
设置缓存目录#
当设置了 缓存位置 时,编译缓存即被启用。这应该在第一次编译之前完成。按以下方式设置位置:
(1) 使用环境变量
在 shell 中,运行脚本之前
export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache"
或在 Python 脚本的顶部
import os
os.environ["JAX_COMPILATION_CACHE_DIR"] = "/tmp/jax_cache"
(2) 使用 jax.config.update()
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
(3) 使用 set_cache_dir()
from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir("/tmp/jax_cache")
缓存阈值#
jax_persistent_cache_min_compile_time_secs: 只有当编译时间超过指定值时,计算才会被写入持久化缓存。默认值为 1.0 秒。jax_persistent_cache_min_entry_size_bytes: 将在持久化编译缓存中缓存的条目的最小尺寸(以字节为单位)。-1: 禁用尺寸限制并防止覆盖。将其保留为默认值(
0)以允许覆盖。覆盖通常会确保最小尺寸对于用于缓存的文件系统来说是最优的。> 0: 期望的实际最小尺寸;无覆盖。
请注意,函数必须满足这两个条件才会被缓存。
其他缓存#
XLA 支持额外的缓存机制,可以与 JAX 的持久化编译缓存一起启用,以进一步缩短重新编译时间。
jax_persistent_cache_enable_xla_caches: 可能的值all: 启用所有 XLA 缓存功能none: 不启用任何额外的 XLA 缓存功能xla_gpu_kernel_cache_file: 仅启用内核缓存xla_gpu_per_fusion_autotune_cache_dir: (默认值)仅启用自动调优缓存
Google Cloud#
在 Google Cloud 上运行时,可以将编译缓存放置在 Google Cloud Storage (GCS) 存储桶中。我们推荐以下配置:
在与工作负载运行的区域相同的区域创建存储桶。
在与工作负载的 VM 相同的项目中创建存储桶。确保设置了权限,以便 VM 可以写入存储桶。
对于较小的工作负载,无需复制。较大的工作负载可以从复制中受益。
将存储桶的默认存储类别设置为“Standard”。
将软删除策略设置为最短:7 天。
将对象生命周期设置为工作负载运行的预期持续时间。例如,如果工作负载预计运行 10 天,则将对象生命周期设置为 10 天。这应该可以涵盖整个运行期间发生的重启。使用
age作为生命周期条件,使用Delete作为操作。有关详细信息,请参阅 对象生命周期管理。如果未设置对象生命周期,缓存将继续增长,因为没有实现逐出机制。所有加密策略均受支持。
假设 gs://jax-cache 是 GCS 存储桶,则按以下方式设置缓存位置:
jax.config.update("jax_compilation_cache_dir", "gs://jax-cache")
工作原理#
缓存键是已编译函数的签名,包含以下参数:
函数执行的计算,通过待哈希的 JAX 函数的非优化 HLO 捕获。
jaxlib 版本。
相关的 XLA 编译标志。
设备配置,通常通过设备数量和设备拓扑捕获。目前对于 GPU,拓扑仅包含 GPU 名称的字符串表示。
用于压缩已编译可执行文件的压缩算法。
由
jax._src.cache_key.custom_hook()生成的字符串。此函数可以重新分配给用户定义的函数,以便修改生成的字符串。默认情况下,此函数始终返回一个空字符串。
多节点缓存#
第一次运行程序时(持久化缓存是冷的/空的),所有进程都将编译,但只有全局通信组中的 rank 0 进程才会写入持久化缓存。在后续运行中,所有进程都将尝试从持久化缓存读取,因此持久化缓存必须位于共享文件系统(例如 NFS)或远程存储(例如 GFS)上。如果持久化缓存是 rank 0 本地缓存,那么在后续运行中,除 rank 0 之外的所有进程都将因缓存未命中而再次编译。
在单节点上预编译多节点程序#
JAX 可以使用单节点上的已编译程序填充多节点编译缓存。在单节点上准备缓存有助于减少集群上的昂贵编译时间。要在单节点上编译和运行多节点程序,用户可以使用 jax_mock_gpu_topology 配置选项创建假的远程设备。
例如,下面的代码片段指示 JAX 模拟一个拥有四个节点(每个节点运行八个进程,每个进程连接到一个 GPU)的集群。
jax.config.update("jax_mock_gpu_topology", "4x8x1")
使用此配置填充缓存后,用户可以在四个节点、每个节点八个进程、每个进程一个 GPU 的情况下运行程序而无需重新编译。
重要提示
运行模拟程序的进程必须具有与将使用缓存的节点相同的 GPU 数量和相同的 GPU 型号。例如,模拟拓扑
8x4x2必须在具有两个 GPU 的进程中运行。在运行带有模拟拓扑的程序时,与其他节点的通信结果是未定义的,因此在模拟环境中运行的 JAX 程序的输出很可能是不正确的。
记录缓存活动#
检查持久化编译缓存的确切活动对于调试可能很有帮助。以下是一些入门建议:
用户可以通过在脚本顶部添加以下内容来启用相关源文件的日志记录:
import os
os.environ["JAX_DEBUG_LOG_MODULES"] = "jax._src.compiler,jax._src.lru_cache"
或者,您可以通过以下方式更改全局 jax 日志级别:
import os
os.environ["JAX_LOGGING_LEVEL"] = "DEBUG"
# or locally with
jax.config.update("jax_logging_level", "DEBUG")
检查缓存未命中#
为了检查和理解为什么会发生缓存未命中,JAX 提供了一个配置标志,该标志可以启用所有缓存未命中(包括持久化编译缓存未命中)的日志记录及其解释。虽然目前这仅适用于跟踪缓存未命中,但最终目标是解释所有缓存未命中。可以通过设置以下配置来启用此功能:
jax.config.update("jax_explain_cache_misses", True)
陷阱#
目前已发现几个陷阱:
目前,持久化缓存不适用于包含主机回调的函数。在这种情况下,会完全避免缓存。
这是因为 HLO 包含指向回调的指针,并且每次运行都会发生变化,即使计算和计算基础设施完全相同。
目前,持久化缓存不适用于使用实现自身
custom_partitioning的原语的函数。函数 HLO 包含指向
custom_partitioning回调的指针,导致相同计算在不同运行中产生不同的缓存键。在这种情况下,缓存仍然会进行,但每次都会生成不同的键,使缓存无效。
规避 custom_partitioning#
如前所述,编译缓存不适用于由实现 custom_partitioning 的原语组成的函数。但是,可以使用 shard_map 来规避那些实现 custom_partitioning 的原语,并使编译缓存按预期工作。
假设我们有一个函数 F,它使用实现 custom_partitioning 的原语 LayerNorm 来执行层归一化后跟矩阵乘法。
import jax
def F(x1, x2, gamma, beta):
ln_out = LayerNorm(x1, gamma, beta)
return ln_out @ x2
如果我们只是在没有 shard_map 的情况下编译此函数,则每次运行相同代码时,layernorm_matmul_without_shard_map 的缓存键都会不同。
layernorm_matmul_without_shard_map = jax.jit(F, in_shardings=(...), out_sharding=(...))(x1, x2, gamma, beta)
但是,如果我们使用 shard_map 包装层归一化原语并定义执行相同计算的函数 G,那么即使 LayerNorm 实现 custom_partitioning,layernorm_matmul_with_shard_map 的缓存键每次也会相同。
import jax
def G(x1, x2, gamma, beta, mesh, ispecs, ospecs):
ln_out = jax.shard_map(LayerNorm, mesh=mesh, in_specs=ispecs, out_specs=ospecs, check_vma=False)(x1, x2, gamma, beta)
return ln_out @ x2
ispecs = jax.sharding.PartitionSpec(...)
ospecs = jax.sharding.PartitionSpec(...)
mesh = jax.sharding.Mesh(...)
layernorm_matmul_with_shard_map = jax.jit(G, static_argnames=['mesh', 'ispecs', 'ospecs'])(x1, x2, gamma, beta, mesh, ispecs, ospecs)
请注意,实现 custom_partitioning 的原语必须用 shard_map 包装才能实现此规避。仅仅用 shard_map 包装外部函数 F 是不够的。