jax.remat
/ jax.checkpoint
更改:您需要了解的内容#
目录#
发生了什么?#
截至 #11830,我们正在切换 jax.checkpoint()
(又名 jax.remat()
,这两个名称是彼此的别名)的新实现。 对于大多数代码,不会有任何更改。 但在边缘情况下可能会有一些可观察到的差异;请参阅升级后可能出现哪些问题?
如何禁用此更改,并暂时恢复旧的行为?#
如果您对此更改有疑问,在 jax==0.3.16
版本中,可以通过将 jax_new_checkpoint
配置选项设置为 False 来关闭新实现,可以通过以下任何一种方式
设置 shell 环境变量
JAX_NEW_CHECKPOINT=0
;执行
jax.config.update('jax_new_checkpoint', False)
;如果您使用
absl
解析标志,请传递--jax_new_checkpoint=False
选项。
如果您需要恢复到旧的实现,请在 GitHub 问题上联系我们,以便我们可以让新的实现为您工作。
截至 jax==0.3.17
,jax_new_checkpoint
配置选项不再可用。 如果您有问题,请在 问题跟踪器 上联系我们,以便我们帮助修复它!
我们为什么要这样做?#
在撰写本文时,JAX 有两个并行的 jax.checkpoint
实现。 新的实现已经使用了几个月(例如,Pax 和 Flaxformer/T5X),并且是可选的。 但它尚未默认启用。
我们希望将新的实现切换为默认启用,然后删除旧的实现。 使用新的实现并删除旧的实现,可以为用户带来以下好处。
用户可自定义的重物化策略#
新实现的主要优势是对应于 policy
参数的新功能。 这个想法是让用户精确控制在自动微分的前向传递期间保存(与重物化)哪些中间变量。 通过控制内存使用与重新计算之间的权衡,用户可以获得显着的性能提升,尤其是在大型模型和我们的 LLM MLPerf 提交中!
此功能的完整文档仍在制定中,但这是一个快速示例
from functools import partial
import jax
def apply_layer(W, x):
return jnp.sin(jnp.dot(W, x))
@partial(jax.checkpoint, policy=jax.checkpoint_policies.checkpoint_dots)
def predict(params, x):
for W in params[:-1]:
x = apply_layer(W, x)
return jnp.dot(params[-1], x)
通过在此处应用带有 policy=jax.checkpoint_policies.checkpoint_dots
的 jax.checkpoint
,我们确保在前向传递期间仅允许保存矩阵乘法的结果。 来自 cos
应用的 Jacobian 系数值,以及计算它们所需的 sin
应用的值,不会从前向传递中保存,而是在后向传递期间重新计算。 (像这样的策略在 TPU 上可能有效,在 TPU 上,逐元素计算实际上是免费的,但矩阵单元的结果值得保存。)
能够重物化常量,而不仅仅是具有数据依赖于参数的操作#
旧的 jax.checkpoint
实现实际上无法重物化不依赖于修饰函数参数的数据的计算。 考虑这个玩具示例
@jax.checkpoint
def f(x):
a = some_function(jnp.arange(10_000_000)) # `a` does not depend on `x`
return a * x
旧的 jax.checkpoint
实现被迫保存 a
的值,这可能需要大量内存。 新的 jax.checkpoint
实现可以重物化而不是保存 a
的值。
在某些情况下,Python 开销显着降低#
在某些情况下,新的 jax.checkpoint
产生的 Python 开销显着降低。 简单的开销基准测试 速度提高了 10 倍。 这些开销仅在 eager op-by-op 执行中出现,因此在使用 jax.jit
或类似的 jax.checkpoint
的常见情况下,速度提升并不相关。 但仍然,不错!
通过简化内部结构启用新的 JAX 功能#
此更改还解锁了未来更大的用户利益,例如自定义批处理规则(vmap
类似于 custom_vjp
)和 custom_vjp
的前向可微分升级。 它还显着降低了 JAX 代码库某些部分的复杂性,这通常有利于可维护性和错误修复。
升级后可能出现哪些问题?#
无害的数值更改#
由于新的实现可以重物化更多计算,包括可能很大的常量的计算,因此某些代码可能会看到小的数值更改。 任何数值更改的幅度都应在我们期望的编译器优化更改范围内,例如浮点运算的重新排序。 但是一些过于严格的测试容差可能需要稍微放宽。
concrete=True
选项已删除。#
旧的 jax.checkpoint
实现有一个布尔值 concrete
选项,该选项允许在具体的 Python 值上进行跟踪(而不是延迟所有计算并仅在抽象值上进行跟踪)。 该选项很少使用,并且在使用该选项的情况下,有更简单的替代方案。 因此,我们在新的 jax.checkpoint
中删除了该选项。
例如,Google 代码中 concrete=True
最常见的用途是支持传递像 is_training
这样的参数
@partial(jax.checkpoint, concrete=True) # OLD jax.checkpoint API
def foo(x, is_training):
if is_training:
return g(x)
else:
return h(x)
使用新的 jax.checkpoint
实现,我们可以使用 static\_argnums
选项完成相同的操作
@partial(jax.checkpoint, static_argnums=(1,)) # NEW jax.checkpoint API
def foo(x, is_training):
if is_training:
...
如果需要在静态参数上执行 jax.numpy
操作,并在 Python 跟踪期间而不是延迟计算其数值结果,我们可以将 static_argnums
与 jax.ensure_compile_time_eval()
一起使用。 但您似乎不太可能需要这个!