jax.remat / jax.checkpoint 变更:你需要知道什么#
目录#
发生了什么?#
从 #11830 开始,我们正在切换到 jax.checkpoint() 的新实现,也称为 jax.remat()(这两个名称是彼此的别名)。**对于大多数代码,不会有任何改变。** 但在某些边缘情况下可能会观察到一些差异;请参阅 升级后可能出现哪些问题?
如何暂时禁用此变更并恢复到旧行为?#
如果您在此变更后遇到问题,**在 jax==0.3.16 版本之前**,可以通过以下任一方式将 jax_new_checkpoint 配置选项设置为 False 来禁用新实现:
设置 shell 环境变量
JAX_NEW_CHECKPOINT=0;执行
jax.config.update('jax_new_checkpoint', False);如果您使用
absl解析标志,请传递--jax_new_checkpoint=False选项。
如果您需要恢复到旧实现,**请在 GitHub issue 上联系我们**,以便我们能够为您在新实现上工作。
从 jax==0.3.17 版本开始,jax_new_checkpoint 配置选项不再可用。如果您遇到问题,请在 issue tracker 上联系我们,以便我们帮助解决!
我们为什么要这样做?#
在撰写本文时,JAX 有两个并行的 jax.checkpoint 实现。新实现已被 Pax 和 Flaxformer/T5X 等用户选择性地使用了数月。但它并非默认开启。
我们希望将新实现设置为默认开启,然后删除旧实现。使用新实现并移除旧实现将为用户带来多项优势。
用户可自定义的重构策略#
新实现的主要优势在于一个与 policy 参数相对应的新功能。其思想是为自动微分前向传播过程中哪些中间值需要保存(而不是重构)提供精确的用户控制。通过对内存使用与重新计算之间的权衡进行控制,用户可以获得显著的性能提升,尤其是在大型模型和我们的 LLM MLPerf 提交中!
该功能的完整文档即将发布,但这里有一个简短的示例
from functools import partial
import jax
def apply_layer(W, x):
return jnp.sin(jnp.dot(W, x))
@partial(jax.checkpoint, policy=jax.checkpoint_policies.checkpoint_dots)
def predict(params, x):
for W in params[:-1]:
x = apply_layer(W, x)
return jnp.dot(params[-1], x)
通过在此处应用带有 policy=jax.checkpoint_policies.checkpoint_dots 的 jax.checkpoint,我们确保在前向传播过程中只允许保存矩阵乘法的结果。从 cos 应用中计算出的雅可比系数以及计算它们所需的 sin 应用的值不会从前向传播中保存,而是在后向传播中重新计算。(此类策略在 TPU 上可能非常有效,因为逐元素计算几乎是免费的,但保存矩阵单元的结果是有价值的。)
能够重构常量,而不仅仅是依赖于参数的操作#
旧的 jax.checkpoint 实现实际上无法重构没有数据依赖于被装饰函数参数的操作。考虑这个简单的例子
@jax.checkpoint
def f(x):
a = some_function(jnp.arange(10_000_000)) # `a` does not depend on `x`
return a * x
旧的 jax.checkpoint 实现被迫保存 a 的值,这可能需要大量内存。新的 jax.checkpoint 实现可以重构 a 的值,而不是保存它。
在某些情况下显著减少 Python 开销#
新的 jax.checkpoint 在某些情况下会显著减少 Python 开销。简单的开销基准测试显示速度提高了 10 倍。这些开销仅在使用 eager op-by-op 执行时出现,因此在通常的 jax.checkpoint 在 jax.jit 或类似情况下使用时,速度提升并不显著。但仍然很棒!
通过简化内部机制来启用新的 JAX 功能#
此变更还为未来带来了巨大的用户收益,例如自定义批处理规则(vmap 对应于 custom_vjp)以及 custom_vjp 的前向可微分升级。它还显著降低了 JAX 代码库某些部分的复杂性,这总体上有利于可维护性和 bug 修复。
升级后可能出现哪些问题?#
无害的数值变化#
由于新实现可以重构更多的计算,包括可能的大常量,因此某些代码可能会出现细微的数值变化。任何数值变化的幅度都应该在我们期望的编译器优化(如浮点运算重排序)的范围内。但一些过于严格的测试容差可能需要稍微放宽。
已移除 concrete=True 选项。#
旧的 jax.checkpoint 实现有一个布尔型 concrete 选项,它允许在具体 Python 值上进行跟踪(而不是延迟所有计算并仅在抽象值上跟踪)。该选项很少使用,并且在使用的场景中存在更简单的替代方案。因此,我们在新的 jax.checkpoint 中移除了该选项。
例如,在 Google 代码中,concrete=True 最常见的用途是支持传递诸如 is_training 这样的参数
@partial(jax.checkpoint, concrete=True) # OLD jax.checkpoint API
def foo(x, is_training):
if is_training:
return g(x)
else:
return h(x)
使用新的 jax.checkpoint 实现,我们可以使用 static\_argnums 选项来实现相同的功能
@partial(jax.checkpoint, static_argnums=(1,)) # NEW jax.checkpoint API
def foo(x, is_training):
if is_training:
...
如果 jax.numpy 操作需要应用于静态参数,并且其数值结果在 Python 跟踪期间计算而不是延迟计算,我们可以使用 static_argnums 和 jax.ensure_compile_time_eval()。但似乎不太可能需要这样做!