jax.remat / jax.checkpoint 变更:您需要了解的内容#

目录#

发生了什么?#

#11830 起,我们正在启用 jax.checkpoint() 的新实现,也称 jax.remat()(这两个名称是彼此的别名)。对于大多数代码,将不会有任何更改。 但在边缘情况下可能会出现一些可观察到的差异;请参阅 升级后可能出现哪些问题?

如何禁用此更改并暂时恢复旧行为?#

如果您在使用此更改时遇到问题,**在 jax==0.3.16 版本之前**,可以通过以下任何一种方式将 `jax_new_checkpoint` 配置选项设置为 `False` 来关闭新实现:

  1. 设置 shell 环境变量 JAX_NEW_CHECKPOINT=0

  2. 执行 jax.config.update('jax_new_checkpoint', False)

  3. 如果您使用 absl 解析标志,请传递 --jax_new_checkpoint=False 选项。

如果您需要恢复到旧实现,**请在 GitHub issue 上联系我们**,以便我们能让新实现为您工作。

jax==0.3.17 起,jax_new_checkpoint 配置选项不再可用。如果您遇到问题,请在问题跟踪器上联系我们,以便我们能帮助解决!

我们为什么要这样做?#

在撰写本文时,JAX 有两个并行的 jax.checkpoint 实现。新的实现已经以选择加入(opt-in)的方式使用了数月(例如,由 Pax 和 Flaxformer/T5X 使用)。但它还没有默认启用。

我们希望将新实现设为默认启用,然后删除旧实现。使用新实现并移除旧实现,将为用户带来多项好处。

用户可定制的重物化策略#

新实现的主要优点是一个与 policy 参数对应的新功能。其思想是让用户精确控制在自动微分的前向传播过程中哪些中间值被保存(而不是被重物化)。通过对内存使用与重新计算之间的权衡进行控制,用户可以获得显著的性能提升,尤其是在大型模型和我们的 LLM MLPerf 提交中!

此功能的完整文档仍在准备中,但这里有一个快速示例:

from functools import partial
import jax

def apply_layer(W, x):
  return jnp.sin(jnp.dot(W, x))

@partial(jax.checkpoint, policy=jax.checkpoint_policies.checkpoint_dots)
def predict(params, x):
  for W in params[:-1]:
    x = apply_layer(W, x)
  return jnp.dot(params[-1], x)

在这里应用带有 policy=jax.checkpoint_policies.checkpoint_dotsjax.checkpoint,我们确保在前向传播过程中只允许保存矩阵乘法的结果。cos 应用的雅可比系数(Jacobian coefficient)值,以及计算它们所需的 sin 应用值,都不会从前向传播中保存,而是在反向传播期间重新计算。(像这样的策略在 TPU 上可能很有效,因为逐元素计算实际上是免费的,但矩阵单元的结果值得保存。)

能够重物化常量,而不仅仅是依赖于参数数据的操作#

旧的 jax.checkpoint 实现实际上无法重物化那些与被装饰函数的参数没有数据依赖关系的计算。考虑这个简单的例子:

@jax.checkpoint
def f(x):
  a = some_function(jnp.arange(10_000_000))  # `a` does not depend on `x`
  return a * x

旧的 jax.checkpoint 实现被迫保存 a 的值,这可能需要大量内存。新的 jax.checkpoint 实现可以重物化 a 的值,而不是保存它。

在某些情况下显著减少 Python 开销#

新的 jax.checkpoint 在某些情况下会显著减少 Python 开销。简单的开销基准测试速度提高了 10 倍。这些开销只出现在即时(eager)逐操作执行中,所以在 jax.jit 或类似机制下使用 jax.checkpoint 的常见情况下,这些速度提升并不相关。但仍然很棒!

通过简化内部机制启用新的 JAX 功能#

这一更改也为未来用户带来了巨大的好处,例如自定义批处理规则(vmapcustom_vjp 模拟)以及 custom_vjp 的前向可微分升级。它还显著降低了 JAX 代码库部分内容的复杂性,这将有利于整体的可维护性和错误修复。

升级后可能出现哪些问题?#

无害的数值变化#

由于新实现可以重物化更多的计算,包括那些可能较大的常量,因此某些代码可能会出现小的数值变化。任何数值变化的幅度应在我们预期编译器优化更改(例如浮点运算的重新排序)所产生的范围内。但一些过于严格的测试容差可能需要稍微放宽。

`concrete=True` 选项已被移除。#

旧的 jax.checkpoint 实现有一个布尔型的 `concrete` 选项,它允许对具体的 Python 值进行跟踪(而不是延迟所有计算并仅对抽象值进行跟踪)。该选项很少使用,并且在它被使用的情况下,有更简单的替代方案。因此,我们在新的 jax.checkpoint 中移除了该选项。

例如,在 Google 代码中,concrete=True 的主要常见用途是支持传递像 is_training 这样的参数:

@partial(jax.checkpoint, concrete=True)  # OLD jax.checkpoint API
def foo(x, is_training):
  if is_training:
    return g(x)
  else:
    return h(x)

使用新的 jax.checkpoint 实现,我们可以通过 static\_argnums 选项实现相同的功能:

@partial(jax.checkpoint, static_argnums=(1,))  # NEW jax.checkpoint API
def foo(x, is_training):
  if is_training:
    ...

如果 jax.numpy 操作需要对静态参数执行,并且它们的数值结果在 Python 跟踪期间而不是延迟计算,我们可以将 static_argnumsjax.ensure_compile_time_eval() 一起使用。但这似乎不太可能需要!