jax.checkpoint#

jax.checkpoint(fun, *, prevent_cse=True, policy=None, static_argnums=())[源代码]#

使 fun 在微分时重新计算内部线性化点。

jax.checkpoint() 装饰器(别名为 jax.remat())提供了一种在自动微分(特别是像 jax.grad()jax.vjp() 这样的反向模式自动微分,以及 jax.linearize())上下文中权衡计算时间和内存成本的方法。

在反向模式下对函数进行微分时,默认情况下,所有线性化点(例如,逐元素非线性原语操作的输入)在正向传递评估时都会被存储起来,以便在反向传递时重复使用。这种评估策略可能导致高内存成本,甚至在内存访问比浮点运算(FLOPs)昂贵得多的硬件加速器上导致性能不佳。

另一种评估策略是重新计算(即重新具体化)而不是存储一些线性化点。这种方法可以减少内存使用,但代价是增加了计算量。

此函数装饰器会生成 fun 的新版本,该版本遵循重新具体化策略而非默认的存储所有内容策略。也就是说,它返回 fun 的一个新版本,该版本在微分时不会存储任何中间线性化点。相反,这些线性化点会从函数的保存输入中重新计算。

请参阅下面的示例。

参数:
  • fun (Callable) – 要将其自动微分评估策略从默认的存储所有中间线性化点更改为重新计算的函数。其参数和返回值应为数组、标量,或(嵌套的)标准 Python 容器(元组/列表/字典)及其组合。

  • prevent_cse (bool) – 可选的布尔型仅限关键字参数,指示是否阻止微分生成的 HLO 中的公共子表达式消除(CSE)优化。阻止 CSE 会带来成本,因为它可能阻碍其他优化,并且在某些后端(特别是 GPU)上可能产生高开销。默认值为 True,因为否则在 jit()pmap() 下,CSE 可能会破坏此装饰器的目的。但在某些情况下,例如在 scan() 内部使用时,这种 CSE 预防机制是不必要的,在这种情况下,可以将 prevent_cse 设置为 False。

  • static_argnums (int | tuple[int, ...]) – 可选的整数或整数序列,一个仅限关键字的参数,指示要对哪些参数值进行专门化以用于追踪和缓存。将参数指定为静态可以避免追踪时的 ConcretizationTypeErrors,但代价是更多的重新追踪开销。请参阅下面的示例。

  • policy (Callable[..., bool] | None) – 可选的、可调用的仅限关键字参数。它应该是 jax.checkpoint_policies 的属性之一。此可调用对象以一阶原始应用程序的类型级规范作为输入,并返回一个布尔值,指示相应的输出值是否可以作为残差保存(或者如果需要,是否必须在(余)切线计算中重新计算)。

返回:

一个与 fun 具有相同输入/输出行为的函数(可调用对象),但当使用例如 jax.grad()jax.vjp()jax.linearize() 进行微分时,它会重新计算而不是存储中间线性化点,从而可能以额外的计算为代价节省内存。

返回类型:

可调用对象

这是一个简单的示例

>>> import jax
>>> import jax.numpy as jnp
>>> @jax.checkpoint
... def g(x):
...   y = jnp.sin(x)
...   z = jnp.sin(y)
...   return z
...
>>> jax.value_and_grad(g)(2.0)
(Array(0.78907233, dtype=float32, weak_type=True), Array(-0.2556391, dtype=float32, weak_type=True))

无论是否存在 jax.checkpoint() 装饰器,这里都会产生相同的值。当装饰器不存在时,jnp.cos(2.0)jnp.cos(jnp.sin(2.0)) 的值在正向传递中计算并存储以供反向传递使用,因为它们在反向传递中需要并且仅依赖于原始输入。当使用 jax.checkpoint() 时,正向传递将只计算原始输出,并且只有原始输入(2.0)会被存储以供反向传递使用。那时,jnp.sin(2.0) 的值会被重新计算,同时也会重新计算 jnp.cos(2.0)jnp.cos(jnp.sin(2.0)) 的值。

虽然 jax.checkpoint() 控制从正向传递存储哪些值以供反向传递使用,但评估函数或其 VJP 所需的总内存量取决于该函数的许多额外内部细节。这些细节包括使用了哪些数值原语、它们如何组合、在哪里使用了 jit 和 scan 等控制流原语,以及其他因素。

jax.checkpoint() 装饰器可以递归应用,以表达复杂的自动微分重新具体化策略。例如

>>> def recursive_checkpoint(funs):
...   if len(funs) == 1:
...     return funs[0]
...   elif len(funs) == 2:
...     f1, f2 = funs
...     return lambda x: f1(f2(x))
...   else:
...     f1 = recursive_checkpoint(funs[:len(funs)//2])
...     f2 = recursive_checkpoint(funs[len(funs)//2:])
...     return lambda x: f1(jax.checkpoint(f2)(x))
...

如果 fun 涉及依赖参数值的 Python 控制流,则可能需要使用 static_argnums 参数。例如,考虑一个布尔标志参数

from functools import partial

@partial(jax.checkpoint, static_argnums=(1,))
def foo(x, is_training):
  if is_training:
    ...
  else:
    ...

这里,static_argnums 的使用允许 if 语句的条件取决于 is_training 的值。使用 static_argnums 的代价是它会在多次调用中引入重新追踪开销:在示例中,每次 foo 使用 is_training 的新值调用时都会被重新追踪。在某些情况下,也需要 jax.ensure_compile_time_eval

@partial(jax.checkpoint, static_argnums=(1,))
def foo(x, y):
  with jax.ensure_compile_time_eval():
    y_pos = y > 0
  if y_pos:
    ...
  else:
    ...

作为使用 static_argnums(和 jax.ensure_compile_time_eval)的替代方案,在 jax.checkpoint() 装饰的函数外部计算某些值,然后对其进行闭包可能更容易。