持久编译缓存#
JAX 为已编译程序提供可选的磁盘缓存。如果启用,JAX 会将已编译程序的副本存储在磁盘上,这可以在重复运行相同或相似任务时节省重新编译时间。
注意:如果编译缓存不在本地文件系统上,则需要安装 etils。
pip install etils
用法#
快速开始#
import jax
import jax.numpy as jnp
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update("jax_persistent_cache_enable_xla_caches", "xla_gpu_per_fusion_autotune_cache_dir")
@jax.jit
def f(x):
return x + 1
x = jnp.zeros((2, 2))
f(x)
设置缓存目录#
当设置了缓存位置时,启用编译缓存。这应该在第一次编译之前完成。按如下方式设置位置
(1) 使用环境变量
在 shell 中,在运行脚本之前
export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache"
或在 Python 脚本的顶部
import os
os.environ["JAX_COMPILATION_CACHE_DIR"] = "/tmp/jax_cache"
(2) 使用 jax.config.update()
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
(3) 使用 set_cache_dir()
from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir("/tmp/jax_cache")
缓存阈值#
jax_persistent_cache_min_compile_time_secs
:只有当编译时间长于指定值时,才会将计算写入持久缓存。默认值为 1.0 秒。jax_persistent_cache_min_entry_size_bytes
:持久编译缓存中缓存的条目的最小大小(以字节为单位)。-1
:禁用大小限制并阻止覆盖。保持默认值 (
0
) 以允许覆盖。覆盖通常会确保最小大小对于用于缓存的文件系统是最佳的。> 0
:所需的实际最小大小;不进行覆盖。
请注意,要缓存函数,需要满足这两个条件。
额外的缓存#
XLA 支持额外的缓存机制,可以与 JAX 的持久编译缓存一起启用,以进一步缩短重新编译时间。
jax_persistent_cache_enable_xla_caches
:可能的值all
:启用所有 XLA 缓存功能none
:不启用任何额外的 XLA 缓存功能xla_gpu_kernel_cache_file
:仅启用内核缓存xla_gpu_per_fusion_autotune_cache_dir
:(默认值)仅启用自动调优缓存
Google 云#
在 Google 云上运行时,可以将编译缓存放置在 Google Cloud Storage (GCS) 存储桶中。我们建议以下配置
在与工作负载运行区域相同的区域中创建存储桶。
在与工作负载 VM 所在的同一项目中创建存储桶。确保设置权限,以便 VM 可以写入存储桶。
较小的工作负载不需要复制。较大的工作负载可以从复制中受益。
对存储桶的默认存储类别使用“标准”。
将软删除策略设置为最短:7 天。
将对象生命周期设置为工作负载运行的预期持续时间。例如,如果预期工作负载运行 10 天,则将对象生命周期设置为 10 天。这应该涵盖整个运行期间发生的重启。对生命周期条件使用
age
,对操作使用Delete
。有关详细信息,请参阅对象生命周期管理。如果未设置对象生命周期,则缓存将继续增长,因为没有实现驱逐机制。支持所有加密策略。
假设 gs://jax-cache
是 GCS 存储桶,则按如下方式设置缓存位置
jax.config.update("jax_compilation_cache_dir", "gs://jax-cache")
工作原理#
缓存键是已编译函数的签名,其中包含以下参数
由被哈希的 JAX 函数的非优化 HLO 捕获的函数执行的计算
jaxlib 版本
相关的 XLA 编译标志
设备配置通常通过设备的数量和设备的拓扑结构来捕获。目前,对于 GPU,拓扑结构仅包含 GPU 名称的字符串表示形式
用于压缩已编译可执行文件的压缩算法
由
jax._src.cache_key.custom_hook()
生成的字符串。可以将此函数重新分配给用户定义的函数,以便可以更改生成的字符串。默认情况下,此函数始终返回一个空字符串。
在多个节点上缓存#
首次运行程序时(持久缓存为冷/空),所有进程都将进行编译,但只有全局通信组中等级为 0 的进程才会写入持久缓存。在后续运行中,所有进程都将尝试从持久缓存中读取,因此持久缓存必须位于共享文件系统(例如:NFS)或远程存储(例如:GFS)中,这一点非常重要。如果持久缓存是等级 0 的本地缓存,则由于编译缓存未命中,除等级 0 之外的所有进程将在后续运行中再次编译。
记录缓存活动#
检查持久编译缓存的具体工作方式对于调试很有帮助。以下是一些关于如何开始的建议。
用户可以通过放置以下代码来启用相关源文件的日志记录
import os
os.environ["JAX_DEBUG_LOG_MODULES"] = "jax._src.compiler,jax._src.lru_cache"
在脚本顶部。或者,您可以使用以下命令更改全局 jax 日志级别
import os
os.environ["JAX_LOGGING_LEVEL"] = "DEBUG"
# or locally with
jax.config.update("jax_logging_level", "DEBUG")
检查缓存未命中#
为了检查和了解为什么会出现缓存未命中,JAX 包含一个配置标志,该标志可以启用对所有缓存未命中(包括持久编译缓存未命中)及其解释的日志记录。尽管目前这仅针对跟踪缓存未命中实现,但最终目标是解释所有缓存未命中。可以通过设置以下配置来启用此功能。
jax.config.update("jax_explain_cache_misses", True)
陷阱#
目前已经发现了一些陷阱
目前,持久缓存不适用于具有主机回调的函数。在这种情况下,完全避免缓存。
这是因为 HLO 包含一个指向回调的指针,即使计算和计算基础结构完全相同,指针也会在运行之间发生变化。
目前,持久缓存不适用于使用实现自己的 custom_partitioning 的原语的函数。
该函数的 HLO 包含一个指向 custom_partitioning 回调的指针,并且会导致同一计算在不同运行中产生不同的缓存键。
在这种情况下,缓存仍然会继续,但每次都会生成不同的键,从而使缓存失效。
绕过 custom_partitioning
#
如前所述,编译缓存不适用于由实现 custom_partitioning
的原语组成的函数。但是,可以对那些实现 custom_partitioning
的原语使用 shard_map 来规避 custom_partitioning
,并使编译缓存按预期工作
假设我们有一个函数 F
,它使用实现 custom_partitioning
的原语 LayerNorm
实现层归一化,然后进行矩阵乘法
import jax
def F(x1, x2, gamma, beta):
ln_out = LayerNorm(x1, gamma, beta)
return ln_out @ x2
如果我们只是编译此函数而不使用 shard_map,则 layernorm_matmul_without_shard_map
的缓存键每次运行相同的代码时都会不同
layernorm_matmul_without_shard_map = jax.jit(F, in_shardings=(...), out_sharding=(...))(x1, x2, gamma, beta)
但是,如果我们将层归一化原语包装在 shard_map 中,并定义一个执行相同计算的函数 G,则尽管 LayerNorm
实现了 custom_partitioning
,layernorm_matmul_with_shard_map
的缓存键每次都将相同
import jax
from jax.experimental.shard_map import shard_map
def G(x1, x2, gamma, beta, mesh, ispecs, ospecs):
ln_out = shard_map(LayerNorm, mesh, in_specs=ispecs, out_specs=ospecs, check_rep=False)(x1, x2, gamma, beta)
return ln_out @ x2
ispecs = jax.sharding.PartitionSpec(...)
ospecs = jax.sharding.PartitionSpec(...)
mesh = jax.sharding.Mesh(...)
layernorm_matmul_with_shard_map = jax.jit(G, static_argnames=['mesh', 'ispecs', 'ospecs'])(x1, x2, gamma, beta, mesh, ispecs, ospecs)
请注意,必须将实现 custom_partitioning
的原语包装在 shard_map 中才能进行此操作。将外部函数 F
包装在 shard_map 中是不够的。