使用 jax.checkpoint (jax.remat) 进行梯度检查点#

在本教程中,您将学习如何使用 jax.checkpoint()(也称为 jax.remat())来控制 JAX 自动微分的已保存值,这在机器学习中特别有用。

如果您是自动微分(autodiff)的新手,或者需要复习一下,JAX 提供了自动微分高级自动微分教程。

要点总结:使用 jax.checkpoint() 装饰器(别名为 jax.remat())和 jax.grad() 来控制在前向传播中保存哪些中间值,以及在反向传播中重新计算哪些中间值,从而在内存和 FLOPs 之间进行权衡。

如果您不使用 jax.checkpoint(),则 jax.grad(f)(x) 的前向传播会存储雅可比系数和其他中间值,以在反向传播期间使用。这些保存的值称为残差

注意:不要错过实用提示,其中讨论了 jax.checkpoint() 如何与 jax.jit() 交互。

import jax
import jax.numpy as jnp

def g(W, x):
  y = jnp.dot(W, x)
  return jnp.sin(y)

def f(W1, W2, W3, x):
  x = g(W1, x)
  x = g(W2, x)
  x = g(W3, x)
  return x

W1 = jnp.ones((5, 4))
W2 = jnp.ones((6, 5))
W3 = jnp.ones((7, 6))
x = jnp.ones(4)

# Inspect the 'residual' values to be saved on the forward pass
# if you were to evaluate `jax.grad(f)(W1, W2, W3, x)`
from jax.ad_checkpoint import print_saved_residuals
jax.ad_checkpoint.print_saved_residuals(f, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of sin from /tmp/ipykernel_1011/1801108376.py:6 (g)
f32[5] output of cos from /tmp/ipykernel_1011/1801108376.py:6 (g)
f32[6] output of sin from /tmp/ipykernel_1011/1801108376.py:6 (g)
f32[6] output of cos from /tmp/ipykernel_1011/1801108376.py:6 (g)
f32[7] output of cos from /tmp/ipykernel_1011/1801108376.py:6 (g)

通过将 jax.checkpoint() 应用于子函数,作为装饰器或在特定应用位置,您可以强制 JAX 不保存该子函数的任何残差。相反,只可能保存 jax.checkpoint() 修饰函数的输入,并且反向传播中使用的任何残差都根据需要从这些输入重新计算。

def f2(W1, W2, W3, x):
  x = jax.checkpoint(g)(W1, x)
  x = jax.checkpoint(g)(W2, x)
  x = jax.checkpoint(g)(W3, x)
  return x

jax.ad_checkpoint.print_saved_residuals(f2, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of sin from /tmp/ipykernel_1011/1801108376.py:6 (g)
f32[6] output of sin from /tmp/ipykernel_1011/1801108376.py:6 (g)

在这里,两个 sin 应用的值被保存,因为它们是 jax.checkpoint() 修饰的 g 函数的后续应用中的参数,并且可能会保存 jax.checkpoint() 修饰的函数的输入。但是不会保存 cos 应用的值。

为了控制哪些值是可保存的,而无需编辑要微分的函数的定义,您可以使用重物化策略。这是一个示例,它仅保存没有批次维度的 dot 操作的结果(因为它们通常受 FLOP 限制,因此值得保存而不是重新计算)

f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
jax.ad_checkpoint.print_saved_residuals(f3, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of reduce_precision from /tmp/ipykernel_1011/1801108376.py:5 (g)
f32[6] output of reduce_precision from /tmp/ipykernel_1011/1801108376.py:5 (g)
f32[7] output of reduce_precision from /tmp/ipykernel_1011/1801108376.py:5 (g)

您还可以使用策略来引用您使用 jax.ad_checkpoint.checkpoint_name() 命名的中间值。

from jax.ad_checkpoint import checkpoint_name

def f4(W1, W2, W3, x):
  x = checkpoint_name(g(W1, x), name='a')
  x = checkpoint_name(g(W2, x), name='b')
  x = checkpoint_name(g(W3, x), name='c')
  return x

f4 = jax.checkpoint(f4, policy=jax.checkpoint_policies.save_only_these_names('a'))
jax.ad_checkpoint.print_saved_residuals(f4, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of reduce_precision from /tmp/ipykernel_1011/2296542172.py:4 (f4)

在尝试这些玩具示例时,您可以使用此笔记本中定义的自定义 print_fwd_bwd 实用程序更仔细地查看发生了什么。

from jax.tree_util import tree_flatten, tree_unflatten

from rich.console import Console
from rich.table import Table
import rich.text

def print_fwd_bwd(f, *args, **kwargs) -> None:
  args, in_tree = tree_flatten((args, kwargs))

  def f_(*args):
    args, kwargs = tree_unflatten(in_tree, args)
    return f(*args, **kwargs)

  fwd = jax.make_jaxpr(lambda *args: jax.vjp(f_, *args))(*args).jaxpr

  y, f_vjp = jax.vjp(f_, *args)
  res, in_tree = tree_flatten(f_vjp)

  def g_(*args):
    *res, y = args
    f_vjp = tree_unflatten(in_tree, res)
    return f_vjp(y)

  bwd = jax.make_jaxpr(g_)(*res, y).jaxpr

  table = Table(show_header=False, show_lines=True, padding=(1, 2, 0, 2), box=None)
  table.add_row("[bold green]forward computation:",
                "[bold green]backward computation:")
  table.add_row(rich.text.Text.from_ansi(str(fwd)),
                rich.text.Text.from_ansi(str(bwd)))
  console = Console(width=240, force_jupyter=True)
  console.print(table)

def _renderable_repr(self):
  return self.html
rich.jupyter.JupyterRenderable._repr_html_ = _renderable_repr
# Without using `jax.checkpoint`:
print_fwd_bwd(f, W1, W2, W3, x)
                                                                                                                                                         
  forward computation:                                         backward computation:                                                                     
                                                                                                                                                         
  { lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let    { lambda ; a:f32[4] b:f32[5,4] c:f32[5] d:f32[5] e:f32[6,5] f:f32[6] g:f32[6] h:f32[7,6]  
      e:f32[5] = dot_general[                                      i:f32[7] j:f32[7]. let                                                                
        dimension_numbers=(([1], [0]), ([], []))                   k:f32[7] = mul j i                                                                    
        preferred_element_type=float32                             l:f32[6] = dot_general[                                                               
      ] a d                                                          dimension_numbers=(([0], [0]), ([], []))                                            
      f:f32[5] = sin e                                               preferred_element_type=float32                                                      
      g:f32[5] = cos e                                             ] k h                                                                                 
      h:f32[6] = dot_general[                                      m:f32[7,6] = dot_general[                                                             
        dimension_numbers=(([1], [0]), ([], []))                     dimension_numbers=(([], []), ([], []))                                              
        preferred_element_type=float32                               preferred_element_type=float32                                                      
      ] b f                                                        ] k g                                                                                 
      i:f32[6] = sin h                                             n:f32[6] = mul l f                                                                    
      j:f32[6] = cos h                                             o:f32[5] = dot_general[                                                               
      k:f32[7] = dot_general[                                        dimension_numbers=(([0], [0]), ([], []))                                            
        dimension_numbers=(([1], [0]), ([], []))                     preferred_element_type=float32                                                      
        preferred_element_type=float32                             ] n e                                                                                 
      ] c i                                                        p:f32[6,5] = dot_general[                                                             
      l:f32[7] = sin k                                               dimension_numbers=(([], []), ([], []))                                              
      m:f32[7] = cos k                                               preferred_element_type=float32                                                      
    in (l, d, a, g, f, b, j, i, c, m) }                            ] n d                                                                                 
                                                                   q:f32[5] = mul o c                                                                    
                                                                   r:f32[4] = dot_general[                                                               
                                                                     dimension_numbers=(([0], [0]), ([], []))                                            
                                                                     preferred_element_type=float32                                                      
                                                                   ] q b                                                                                 
                                                                   s:f32[5,4] = dot_general[                                                             
                                                                     dimension_numbers=(([], []), ([], []))                                              
                                                                     preferred_element_type=float32                                                      
                                                                   ] q a                                                                                 
                                                                 in (s, p, m, r) }                                                                       
# Using `jax.checkpoint` with policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable:
print_fwd_bwd(f3, W1, W2, W3, x)
                                                                                                                                                                        
  forward computation:                                                   backward computation:                                                                          
                                                                                                                                                                        
  { lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let              { lambda ; a:f32[5] b:f32[6] c:f32[7] d:f32[5,4] e:f32[6,5] f:f32[7,6] g:f32[4] h:f32[7]. let  
      e:f32[5] = dot_general[                                                i:f32[5,4] j:f32[6,5] k:f32[7,6] l:f32[4] = remat2[                                        
        dimension_numbers=(([1], [0]), ([], []))                               differentiated=True                                                                      
        preferred_element_type=float32                                         jaxpr={ lambda ; m:f32[5] n:f32[6] o:f32[7] p:f32[5,4] q:f32[6,5] r:f32[7,6]             
      ] a d                                                                        s:f32[4] t:f32[7]. let                                                               
      f:f32[5] = reduce_precision[exponent_bits=8 mantissa_bits=23] e              u:f32[5] = sin m                                                                     
      g:f32[5] = sin f                                                             v:f32[5] = cos m                                                                     
      h:f32[6] = dot_general[                                                      w:f32[6] = sin n                                                                     
        dimension_numbers=(([1], [0]), ([], []))                                   x:f32[6] = cos n                                                                     
        preferred_element_type=float32                                             y:f32[7] = cos o                                                                     
      ] b g                                                                        z:f32[7] = mul t y                                                                   
      i:f32[6] = reduce_precision[exponent_bits=8 mantissa_bits=23] h              ba:f32[6] = dot_general[                                                             
      j:f32[6] = sin i                                                               dimension_numbers=(([0], [0]), ([], []))                                           
      k:f32[7] = dot_general[                                                        preferred_element_type=float32                                                     
        dimension_numbers=(([1], [0]), ([], []))                                   ] z r                                                                                
        preferred_element_type=float32                                             bb:f32[6] = mul ba x                                                                 
      ] c j                                                                        bc:f32[5] = dot_general[                                                             
      l:f32[7] = reduce_precision[exponent_bits=8 mantissa_bits=23] k                dimension_numbers=(([0], [0]), ([], []))                                           
      m:f32[7] = sin l                                                               preferred_element_type=float32                                                     
    in (m, f, i, l, a, b, c, d) }                                                  ] bb q                                                                               
                                                                                   bd:f32[5] = mul bc v                                                                 
                                                                                   be:f32[4] = dot_general[                                                             
                                                                                     dimension_numbers=(([0], [0]), ([], []))                                           
                                                                                     preferred_element_type=float32                                                     
                                                                                   ] bd p                                                                               
                                                                                   bf:f32[5,4] = dot_general[                                                           
                                                                                     dimension_numbers=(([], []), ([], []))                                             
                                                                                     preferred_element_type=float32                                                     
                                                                                   ] bd s                                                                               
                                                                                   bg:f32[6,5] = dot_general[                                                           
                                                                                     dimension_numbers=(([], []), ([], []))                                             
                                                                                     preferred_element_type=float32                                                     
                                                                                   ] bb u                                                                               
                                                                                   bh:f32[7,6] = dot_general[                                                           
                                                                                     dimension_numbers=(([], []), ([], []))                                             
                                                                                     preferred_element_type=float32                                                     
                                                                                   ] z w                                                                                
                                                                                 in (bf, bg, bh, be) }                                                                  
                                                                               policy=<function dot_with_no_batch_dims_saveable at 0x7f79ce6a7f40>                      
                                                                               prevent_cse=True                                                                         
                                                                             ] a b c d e f g h                                                                          
                                                                           in (i, j, k, l) }                                                                            

让我们逐步思考#

注意:在继续之前,查看高级自动微分教程可能会有所帮助。

jax.checkpoint 基础#

jax.linearize()jax.vjp() 中,在如何以及何时计算某些值方面具有灵活性。不同的选择可以在内存使用和 FLOP 之间进行权衡。JAX 使用 jax.checkpoint() 提供对这些选择的控制。

一种选择是在前向传播中,在输入可用后立即执行雅可比系数计算,还是在反向传播中,在需要系数之前执行雅可比系数计算。考虑 sin_vjp 的示例

def sin_vjp(x):
  y = jnp.sin(x)
  cos_x = jnp.cos(x)
  return y, lambda y_bar: cos_x * y_bar

另一个有效的实现是在反向传播而不是前向传播中计算 jnp.cos(x) 的值

def sin_vjp2(x):
  y = jnp.sin(x)
  return y, lambda y_bar: jnp.cos(x) * y_bar

对于此特定函数,两个版本使用的内存量相同,尽管您减少了原始计算(前向传播)的 FLOP,并增加了余切计算(反向传播)的 FLOP。

在函数组合方面还有另一个选择。回想一下两个函数组合的 VJP 规则

def f(x):
  y = g(x)
  z = h(y)
  return z

def f_vjp(x):
  y, g_vjp = jax.vjp(g, x)
  z, h_vjp = jax.vjp(h, y)
  def f_bwd(z_bar):
    y_bar, = h_vjp(z_bar)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd

另一种选择是

def f_vjp_checkpoint(x):
  y = g(x)
  z, h_vjp = jax.vjp(h, y)
  def f_bwd2(z_bar):
    y_bar, = h_vjp(z_bar)
    _, g_vjp = jax.vjp(g, x)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd2

用文字来说,这种替代实现不会在前向传播中计算 g_vjp 或其闭包中的残差值。相反,它只在反向传播 f_bwd2 中计算它们。这意味着 f_vjp_checkpoint 需要更少的内存:如果 gh 各自的残差都需要相似的内存量,并且都远大于 x,那么 f_vjp_checkpoint(x) 产生的函数所需的内存是 f_vjp(x) 的一半!

您付出的代价是冗余的工作:在 f_bwd2 中,您必须重新评估 g(x),作为 jax.vjp(g, x) 的一部分,只是为了丢弃它的值(在 _, g_vjp = jax.vjp(g, x) 行中的下划线变量中)。

您可以通过在原始函数 f 的替代定义中使用 jax.checkpoint(),在自动微分中获得这种 VJP 行为,而无需直接编写 VJP 函数。

def f_checkpoint(x):
  y = jax.checkpoint(g)(x)
  z = h(y)
  return z

换句话说,您将 jax.checkpoint() 应用于 gf 的第一阶段 — 而不是应用于 f 本身。这样,当您评估 jax.grad(f_checkpoint)(x) 时,您会得到类似以下的计算

  1. 运行 g 的前向传播,丢弃残差值。

  2. 运行 h 的前向传播,保存残差。

  3. 运行 h 的反向传播,消耗步骤 2 中的残差。

  4. 重新运行 g 的前向传播,保存残差。

  5. 运行 g 的反向传播,消耗步骤 4 中的残差。

也就是说,通过评估 jax.grad(f_checkpoint)(x),我们得到的计算结果与

def f_checkpoint_grad(x):
  y = g(x)                  # step 1
  _, h_vjp = jax.vjp(h)(y)  # step 2
  y_bar, = h_vjp(1.0)       # step 3
  _, g_vjp = jax.vjp(g, x)  # step 4
  x_bar, = g_vjp(y_bar)     # step 5
  return x_bar

一般来说,jax.checkpoint(foo) 是一个新函数,它具有与 foo 相同的输入-输出行为,但在自动微分下,特别是在 jax.linearize()jax.vjp() (及其包装器,如 jax.grad())下行为不同,但在 jax.jvp() 下行为相同。当被微分时,仅在正向传播中存储 jax.checkpoint()-微分函数的输入。在反向传播中,重新计算残差(来自 foo 的中间值及其反向传播所需的雅可比系数)。

请注意,如果 f = lambda x: h(g(x)) 是您要微分的函数(换句话说,如果您想应用 jax.grad(f)),则通过将 jax.checkpoint() 应用于 f 本身,您不会获得任何内存节省。这是因为评估 jax.grad(jax.checkpoint(f))(x) 会导致如下计算:

  1. 运行前向传播,丢弃所有残差。

  2. 立即重新运行前向传播,保存残差。

  3. 运行反向传播,消耗步骤 2 中的残差。

在代码中,您将有类似以下的代码

def f_grad_bad(x):
  _ = f(x)                  # step 1
  _, f_vjp = jax.vjp(f, x)  # step 2
  x_bar, = f_vjp(1.0)       # step 3
  return x_bar

通过将 jax.checkpoint() 应用于 f 的第二阶段 h,您也不会获得任何内存节省。这是因为评估 jax.grad(lambda x: jax.checkpoint(h)(g(x))) 会导致如下计算:

  1. 运行 g 的前向传播,保存残差。

  2. 运行 h 的前向传播,丢弃残差。

  3. 立即重新运行 h 的前向传播,保存残差。

  4. 运行 h 的反向传播,消耗来自步骤 3 的残差。

  5. 运行 g 的反向传播,消耗来自步骤 1 的残差。

在代码中,你可能会看到这样的代码:

def f_grad_bad2(x):
  y, g_vjp = jax.vjp(g, x)  # step 1
  z = h(y)                  # step 2
  _, h_vjp = jax.vjp(h, y)  # step 3
  y_bar, = h_vjp(1.0)       # step 3
  x_bar, = g_vjp(y_bar)     # step 5
  return x_bar

更一般地来说,如果你有一个函数链式组合,例如 f = lambda x: f3(f2(f1(x))),并且对评估 jax.grad(f) 感兴趣,你可以说你

  • 不应该将 jax.checkpoint() 应用于整个函数 f,因为这不会节省任何内存(并且会执行浪费的重复计算)。

  • 不应该将 jax.checkpoint() 应用于最后一个子函数 f3,因为这不会节省任何内存(并且会执行浪费的重复计算)。

  • 可以将 jax.checkpoint() 应用于 f1f2 或它们的组合 lambda x: f2(f1(x)),因为这些中的任何一个都可能节省内存,并且会表达不同的内存/重新计算的权衡。

可保存内容的自定义策略#

如目前所示,使用 jax.checkpoint() 会从一个极端切换到另一个极端

  • 不使用 jax.checkpoint(),JAX 的自动微分倾向于在正向传播中计算所有可能的内容,并将其存储起来以用于反向传播。

  • 使用 jax.checkpoint() 装饰器,你反而会在正向传播中计算尽可能少的内容,并在反向传播中根据需要重新计算值。

为了在这两个极端之间操作,保存一些东西,而不保存其他东西,你可以小心地将 jax.checkpoint() 装饰器放置在子函数上。但这需要编辑要微分的函数,例如模型代码,这可能不方便。尝试不同的变体也可能很困难。

因此,另一种方法是使用 jax.checkpoint()policy 参数。策略是一个可调用对象(即函数),它将一阶原始应用的类型级规范作为输入,并返回一个布尔值,指示是否允许将相应的输出值保存为残差(或者是否必须在(余)切线计算中根据需要重新计算)。为了编写健壮的代码,应该从 jax.checkpoint_policies() 上的属性中选择策略,例如 jax.checkpoint_policies.dots_with_no_batch_dims_saveable(),因为编写自定义策略可调用对象的 API 被认为是内部的。

例如,考虑这个要微分的函数

def loss(params, x, y):
  return jnp.sum((predict(params, x) - y)**2)

def predict(params, x):
  *Ws, Wlast = params
  for W in Ws:
    x = layer(W, x)
  x = jnp.dot(Wlast, x)
  return x

def layer(W, x):
  return jnp.sin(jnp.dot(W, x))
W1 = W2 = W3 = jnp.ones((4, 4))
params = [W1, W2, W3]
x = jnp.ones(4)
y = jnp.ones(4)
print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] output of sin from /tmp/ipykernel_1011/4230705069.py:12 (layer)
f32[4] output of cos from /tmp/ipykernel_1011/4230705069.py:12 (layer)
f32[4] output of sin from /tmp/ipykernel_1011/4230705069.py:12 (layer)
f32[4] output of cos from /tmp/ipykernel_1011/4230705069.py:12 (layer)
f32[4] output of mul from /tmp/ipykernel_1011/4230705069.py:2 (loss)

与其在正向传播中保存这么多值,不如你只想保存没有批处理维度的矩阵乘法的结果(因为它们可能是 FLOP 而不是内存受限的)。你可以使用策略 jax.checkpoint_policies.dots_with_no_batch_dims_saveable() 来实现这一点

loss_checkpoint = jax.checkpoint(loss, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
print_saved_residuals(loss_checkpoint, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] from the argument y
f32[4] output of reduce_precision from /tmp/ipykernel_1011/4230705069.py:12 (layer)
f32[4] output of reduce_precision from /tmp/ipykernel_1011/4230705069.py:12 (layer)
f32[4] output of reduce_precision from /tmp/ipykernel_1011/4230705069.py:8 (predict)

另请注意,通过提供策略,你不需要编辑定义 losspredictlayer 的代码。如果你想在调用代码(例如训练脚本)中尝试策略,而无需更改库代码(例如神经网络库),这将特别方便。

某些策略可以引用使用 jax.ad_checkpoint.checkpoint_name() 命名的值

from jax.ad_checkpoint import checkpoint_name

def predict(params, x):
  *Ws, Wlast = params
  for i, W in enumerate(Ws):
    x = layer(W, x)
    x = checkpoint_name(x, name=f'layer{i}_output')
  x = jnp.dot(Wlast, x)
  return x

本身,jax.ad_checkpoint import.checkpoint_name() 只是一个恒等函数。但由于某些策略函数知道要查找它们,你可以使用这些名称来控制是否将 jax.ad_checkpoint import.checkpoint_name() 输出的某些值视为可保存的

print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] output of cos from /tmp/ipykernel_1011/4230705069.py:12 (layer)
f32[4] named 'layer0_output' from /tmp/ipykernel_1011/178264713.py:7 (predict)
f32[4] output of cos from /tmp/ipykernel_1011/4230705069.py:12 (layer)
f32[4] named 'layer1_output' from /tmp/ipykernel_1011/178264713.py:7 (predict)
f32[4] output of mul from /tmp/ipykernel_1011/4230705069.py:2 (loss)
loss_checkpoint2 = jax.checkpoint(loss, policy=jax.checkpoint_policies.save_any_names_but_these('layer1_output'))
print_saved_residuals(loss_checkpoint2, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] from the argument y

另一个引用名称的策略是 jax.checkpoint_policies.save_only_these_names

策略列表#

策略如下:

  • everything_saveable(默认策略,就像根本没有使用 jax.checkpoint 一样)

  • nothing_saveable(即重新物化所有内容,就像根本没有使用自定义策略一样)

  • dots_saveable 或其别名 checkpoint_dots

  • dots_with_no_batch_dims_saveable 或其别名 checkpoint_dots_with_no_batch_dims

  • save_anything_but_these_names(保存除带有给定名称的 checkpoint_name 的输出之外的任何值)

  • save_any_names_but_these(仅保存命名值,即 checkpoint_name 的任何输出,但具有给定名称的输出除外)

  • save_only_these_names(仅保存命名值,并且仅在给定的名称中保存)

  • offload_dot_with_no_batch_dimsdots_with_no_batch_dims_saveable 相同,但卸载到 CPU 内存而不是重新计算。

  • save_and_offload_only_these_namessave_only_these_names 相同,但卸载到 CPU 内存而不是重新计算。

  • save_from_both_policies(policy_1, policy_2)(类似于逻辑 or,因此如果根据 policy_1 *或* policy_2 可以保存残差,则该残差是可保存的)

策略仅指示什么是可保存的;只有当反向传播实际需要某个值时,才会保存该值。

高级:递归 jax.checkpoint#

通过以正确的方式应用 jax.checkpoint(),可以在内存使用和(重新)计算之间表达许多权衡。一个令人惊讶的例子是*递归*检查点,其中你将 jax.checkpoint() 应用于一个函数,该函数本身以某种方式调用 jax.checkpoint() 修饰的函数,以便来自 \(D\) 个函数的链式组合的内存使用量以 \(\mathcal{O}(\log_2 D)\) 的比例缩放,而不是 \(\mathcal{O}(D)\)

作为一个玩具示例,考虑多个 jax.numpy.sin() 函数的链式组合

def chain_compose(funs):
  def f(x):
    for fun in funs:
      x = fun(x)
    return x
  return f

f = chain_compose([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] output of cos from /tmp/ipykernel_1011/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1011/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1011/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1011/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1011/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1011/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1011/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1011/410288286.py:4 (f)

通常,存储的残差的数量与链的长度成线性比例

f = chain_compose([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] output of cos from /tmp/ipykernel_1011/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1011/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1011/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1011/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1011/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1011/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1011/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1011/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1011/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1011/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1011/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1011/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1011/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1011/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1011/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1011/410288286.py:4 (f)

但你可以递归地应用 jax.checkpoint() 以改善缩放

def recursive_checkpoint(funs):
  if len(funs) == 1:
    return funs[0]
  elif len(funs) == 2:
    f1, f2 = funs
    return lambda x: f1(f2(x))
  else:
    f1 = recursive_checkpoint(funs[:len(funs)//2])
    f2 = recursive_checkpoint(funs[len(funs)//2:])
    return lambda x: f1(jax.checkpoint(f2)(x))
f = recursive_checkpoint([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] from the argument x
f32[] output of sin from /tmp/ipykernel_1011/1943107544.py:6 (<lambda>)
f32[] output of cos from /tmp/ipykernel_1011/1943107544.py:6 (<lambda>)
f32[] output of cos from /tmp/ipykernel_1011/1943107544.py:6 (<lambda>)
f = recursive_checkpoint([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] from the argument x
f32[] output of sin from /tmp/ipykernel_1011/1943107544.py:6 (<lambda>)
f32[] output of sin from /tmp/ipykernel_1011/1943107544.py:6 (<lambda>)
f32[] output of cos from /tmp/ipykernel_1011/1943107544.py:6 (<lambda>)
f32[] output of cos from /tmp/ipykernel_1011/1943107544.py:6 (<lambda>)

这里的代价,和往常一样,是重新计算:特别是,你最终执行了 \(\mathcal{O}(\log_2 D)\) 倍的 FLOP

f = chain_compose([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
                                                                                                                                 
  forward computation:                  backward computation:                                                                    
                                                                                                                                 
  { lambda ; a:f32[]. let               { lambda ; a:f32[] b:f32[] c:f32[] d:f32[] e:f32[] f:f32[] g:f32[] h:f32[] i:f32[]. let  
      b:f32[] = sin a                       j:f32[] = mul i h                                                                    
      c:f32[] = cos a                       k:f32[] = mul j g                                                                    
      d:f32[] = sin b                       l:f32[] = mul k f                                                                    
      e:f32[] = cos b                       m:f32[] = mul l e                                                                    
      f:f32[] = sin d                       n:f32[] = mul m d                                                                    
      g:f32[] = cos d                       o:f32[] = mul n c                                                                    
      h:f32[] = sin f                       p:f32[] = mul o b                                                                    
      i:f32[] = cos f                       q:f32[] = mul p a                                                                    
      j:f32[] = sin h                     in (q,) }                                                                              
      k:f32[] = cos h                                                                                                            
      l:f32[] = sin j                                                                                                            
      m:f32[] = cos j                                                                                                            
      n:f32[] = sin l                                                                                                            
      o:f32[] = cos l                                                                                                            
      p:f32[] = sin n                                                                                                            
      q:f32[] = cos n                                                                                                            
    in (p, c, e, g, i, k, m, o, q) }                                                                                             
f = recursive_checkpoint([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
                                                                                                                                        
  forward computation:                                                              backward computation:                               
                                                                                                                                        
  { lambda ; a:f32[]. let                                                           { lambda ; a:f32[] b:f32[] c:f32[] d:f32[]. let     
      b:f32[] = remat2[                                                                 e:f32[] = mul d c                               
        differentiated=False                                                            f:f32[] = mul e b                               
        jaxpr={ lambda ; c:f32[]. let d:f32[] = sin c; e:f32[] = sin d in (e,) }        g:f32[] = remat2[                               
        policy=None                                                                       differentiated=True                           
        prevent_cse=True                                                                  jaxpr={ lambda ; h:f32[] i:f32[]. let         
      ] a                                                                                     j:f32[] = sin h                           
      f:f32[] = sin b                                                                         k:f32[] = cos h                           
      g:f32[] = sin f                                                                         l:f32[] = cos j                           
      h:f32[] = sin g                                                                         m:f32[] = mul i l                         
      i:f32[] = sin h                                                                         n:f32[] = mul m k                         
      j:f32[] = sin i                                                                       in (n,) }                                   
      k:f32[] = cos i                                                                     policy=None                                   
      l:f32[] = sin j                                                                     prevent_cse=True                              
      m:f32[] = cos j                                                                   ] a f                                           
    in (l, g, a, k, m) }                                                                o:f32[] = remat2[                               
                                                                                          differentiated=True                           
                                                                                          jaxpr={ lambda ; p:f32[] q:f32[]. let         
                                                                                              r:f32[] = sin p                           
                                                                                              s:f32[] = sin r                           
                                                                                              t:f32[] = sin s                           
                                                                                              u:f32[] = cos s                           
                                                                                              v:f32[] = cos t                           
                                                                                              w:f32[] = mul q v                         
                                                                                              x:f32[] = mul w u                         
                                                                                              y:f32[] = remat2[                         
                                                                                                differentiated=True                     
                                                                                                jaxpr={ lambda ; z:f32[] ba:f32[]. let  
                                                                                                    bb:f32[] = sin z                    
                                                                                                    bc:f32[] = cos z                    
                                                                                                    bd:f32[] = cos bb                   
                                                                                                    be:f32[] = mul ba bd                
                                                                                                    bf:f32[] = mul be bc                
                                                                                                  in (bf,) }                            
                                                                                                policy=None                             
                                                                                                prevent_cse=True                        
                                                                                              ] p x                                     
                                                                                            in (y,) }                                   
                                                                                          policy=None                                   
                                                                                          prevent_cse=True                              
                                                                                        ] 3.0 g                                         
                                                                                      in (o,) }                                         

实用注意事项#

当微分函数被分段到 XLA 进行编译时 - 例如,通过将 jax.jit() 应用于包含 jax.grad() 调用的函数 - XLA 将自动优化计算,包括何时计算或重新物化值的决策。因此,**对于 jax.jit() 下的微分函数,通常不需要 jax.checkpoint()**。XLA 会为你优化一切。

一个例外是使用分段控制流时,例如 jax.lax.scan()。跨多个控制流原语(例如,跨正向传播 scan 和相应的反向传播 scan)的自动编译器优化通常不如彻底。因此,通常最好在传递给 jax.lax.scan() 的主体函数上使用 jax.checkpoint()

例如,大型 Transformer 模型中的一种常见模式是将架构表示为对层进行 jax.lax.scan(),以减少编译时间。也就是说,使用简单的全连接网络作为类比,而不是编写如下内容

LayerParam = tuple[jnp.ndarray, jnp.ndarray]  # Weights-bias pair for a layer.
ParamsList = list[LayerParam]

def net(params: ParamsList, x: jnp.ndarray):
  for W, b in params:
    x = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return x

而是使用 jax.lax.scan() 迭代层应用程序

params = [(jnp.array([[0.5, 0.5], [1., 1.]]), jnp.array([0.5, 0.5])), 
          (jnp.array([[0.5, 0.5], [1., 1.]]), jnp.array([0.5, 0.5]))]

all_weights = jnp.stack([W for W, _ in params])
all_biases = jnp.stack([b for _, b in params])

def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None

def net(all_weights, all_biases, x):
  x, _ = jax.lax.scan(layer, x, (all_weights, all_biases))
  return x

这个基于层的扫描版本减少了编译时间,但通过阻止一些编译器优化,它可能会导致梯度计算效率低下。为了缓解这个问题,你可以在扫描的函数上使用 jax.checkpoint()

from functools import partial

@partial(jax.checkpoint,
         policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None

通过这种方式使用 jax.checkpoint(),您正在手动控制 JAX 的自动微分在正向和反向传递之间保存哪些值,因此不依赖 XLA 优化来为您选择。