使用 jax.checkpoint(又名 jax.remat)控制自动微分的保存值#
import jax
import jax.numpy as jnp
概述#
使用 jax.grad 的 jax.checkpoint 装饰器(别名为 jax.remat)来控制在正向传播中保存哪些中间值,以及在反向传播中重新计算哪些值,从而在内存和浮点运算次数之间进行权衡。
不要错过实用说明,其中讨论了 jax.checkpoint 与 jax.jit 的交互方式。
如果不使用 jax.checkpoint,jax.grad(f)(x) 的正向传播会保存雅可比系数和其他中间值,供反向传播使用。我们将这些保存的值称为残差。
def g(W, x):
y = jnp.dot(W, x)
return jnp.sin(y)
def f(W1, W2, W3, x):
x = g(W1, x)
x = g(W2, x)
x = g(W3, x)
return x
W1 = jnp.ones((5, 4))
W2 = jnp.ones((6, 5))
W3 = jnp.ones((7, 6))
x = jnp.ones(4)
# Inspect the 'residual' values to be saved on the forward pass
# if we were to evaluate `jax.grad(f)(W1, W2, W3, x)`
from jax.ad_checkpoint import print_saved_residuals
jax.ad_checkpoint.print_saved_residuals(f, W1, W2, W3, x)
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
f32[5] output of cos from <ipython-input-4-f510dde58e22>:3 (g)
f32[6] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
f32[6] output of cos from <ipython-input-4-f510dde58e22>:3 (g)
f32[7] output of cos from <ipython-input-4-f510dde58e22>:3 (g)
通过将 jax.checkpoint 应用于子函数,作为装饰器或在特定应用点,我们迫使 JAX 不保存该子函数的任何残差。相反,只有 jax.checkpoint 装饰函数的输入可能会被保存,并且在反向传播中消耗的任何残差都会根据需要从这些输入重新计算。
def f2(W1, W2, W3, x):
x = jax.checkpoint(g)(W1, x)
x = jax.checkpoint(g)(W2, x)
x = jax.checkpoint(g)(W3, x)
return x
jax.ad_checkpoint.print_saved_residuals(f2, W1, W2, W3, x)
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
f32[6] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
在这里,两个 sin 应用的值被保存,因为它们是 jax.checkpoint 装饰的 g 函数后续应用的参数,并且 jax.checkpoint 装饰函数的输入可能会被保存。但 cos 应用的任何值都不会被保存。
要控制哪些值可以保存,而无需编辑要微分的函数的定义,您可以使用重构策略。以下示例保存了仅与没有批处理维度的 dot 操作相关的结果(因为它们通常是浮点运算密集型的,因此值得保存而不是重新计算)。
f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
jax.ad_checkpoint.print_saved_residuals(f3, W1, W2, W3, x)
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)
f32[6] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)
f32[7] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)
您还可以使用策略引用使用 jax.ad_checkpoint.checkpoint_name 命名的中间值。
from jax.ad_checkpoint import checkpoint_name
def f4(W1, W2, W3, x):
x = checkpoint_name(g(W1, x), name='a')
x = checkpoint_name(g(W2, x), name='b')
x = checkpoint_name(g(W3, x), name='c')
return x
f4 = jax.checkpoint(f4, policy=jax.checkpoint_policies.save_only_these_names('a'))
jax.ad_checkpoint.print_saved_residuals(f4, W1, W2, W3, x)
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] named 'a' from <ipython-input-7-fc0ed1c14b8d>:4 (f4)
在把玩这些示例时,我们可以使用此笔记本中定义的 print_fwd_bwd 工具来更详细地了解正在发生的情况。
from jax.tree_util import tree_flatten, tree_unflatten
from rich.console import Console
from rich.table import Table
import rich.text
def print_fwd_bwd(f, *args, **kwargs) -> None:
args, in_tree = tree_flatten((args, kwargs))
def f_(*args):
args, kwargs = tree_unflatten(in_tree, args)
return f(*args, **kwargs)
fwd = jax.make_jaxpr(lambda *args: jax.vjp(f_, *args))(*args).jaxpr
y, f_vjp = jax.vjp(f_, *args)
res, in_tree = tree_flatten(f_vjp)
def g_(*args):
*res, y = args
f_vjp = tree_unflatten(in_tree, res)
return f_vjp(y)
bwd = jax.make_jaxpr(g_)(*res, y).jaxpr
table = Table(show_header=False, show_lines=True, padding=(1, 2, 0, 2), box=None)
table.add_row("[bold green]forward computation:",
"[bold green]backward computation:")
table.add_row(rich.text.Text.from_ansi(str(fwd)),
rich.text.Text.from_ansi(str(bwd)))
console = Console(width=240, force_jupyter=True)
console.print(table)
def _renderable_repr(self):
return self.html
rich.jupyter.JupyterRenderable._repr_html_ = _renderable_repr
# no use of jax.checkpoint:
print_fwd_bwd(f, W1, W2, W3, x)
forward computation: backward computation: { lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let { lambda ; a:f32[7] b:f32[6] c:f32[7,6] d:f32[6] e:f32[5] f:f32[6,5] g:f32[5] h:f32[4] e:f32[5] = dot_general[dimension_numbers=(([1], [0]), ([], []))] a d i:f32[5,4] j:f32[7]. let f:f32[5] = sin e k:f32[7] = mul j a g:f32[5] = cos e l:f32[6] = dot_general[dimension_numbers=(([0], [0]), ([], []))] k c h:f32[6] = dot_general[dimension_numbers=(([1], [0]), ([], []))] b f m:f32[7,6] = dot_general[dimension_numbers=(([], []), ([], []))] k b i:f32[6] = sin h n:f32[6] = mul l d j:f32[6] = cos h o:f32[5] = dot_general[dimension_numbers=(([0], [0]), ([], []))] n f k:f32[7] = dot_general[dimension_numbers=(([1], [0]), ([], []))] c i p:f32[6,5] = dot_general[dimension_numbers=(([], []), ([], []))] n e l:f32[7] = sin k q:f32[5] = mul o g m:f32[7] = cos k r:f32[4] = dot_general[dimension_numbers=(([0], [0]), ([], []))] q i in (l, m, i, c, j, f, b, g, d, a) } s:f32[5,4] = dot_general[dimension_numbers=(([], []), ([], []))] q h in (s, p, m, r) }
# using jax.checkpoint with policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable:
print_fwd_bwd(f3, W1, W2, W3, x)
forward computation: backward computation: { lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let { lambda ; a:f32[5] b:f32[6] c:f32[7] d:f32[5,4] e:f32[6,5] f:f32[7,6] g:f32[4] h:f32[7]. let e:f32[5] = dot_general[dimension_numbers=(([1], [0]), ([], []))] a d i:f32[5,4] j:f32[6,5] k:f32[7,6] l:f32[4] = remat2[ f:f32[5] = sin e differentiated=True g:f32[6] = dot_general[dimension_numbers=(([1], [0]), ([], []))] b f jaxpr={ lambda ; m:f32[5] n:f32[6] o:f32[7] p:f32[5,4] q:f32[6,5] r:f32[7,6] h:f32[6] = sin g s:f32[4] t:f32[7]. let i:f32[7] = dot_general[dimension_numbers=(([1], [0]), ([], []))] c h u:f32[5] = sin m j:f32[7] = sin i v:f32[5] = cos m in (j, e, g, i, a, b, c, d) } w:f32[6] = sin n x:f32[6] = cos n y:f32[7] = cos o z:f32[7] = mul t y ba:f32[6] = dot_general[dimension_numbers=(([0], [0]), ([], []))] z r bb:f32[6] = mul ba x bc:f32[5] = dot_general[dimension_numbers=(([0], [0]), ([], []))] bb q bd:f32[5] = mul bc v be:f32[4] = dot_general[dimension_numbers=(([0], [0]), ([], []))] bd p bf:f32[5,4] = dot_general[dimension_numbers=(([], []), ([], []))] bd s bg:f32[6,5] = dot_general[dimension_numbers=(([], []), ([], []))] bb u bh:f32[7,6] = dot_general[dimension_numbers=(([], []), ([], []))] z w in (bf, bg, bh, be) } policy=<function dot_with_no_batch_dims at 0x7f5e469b1700> prevent_cse=True ] a b c d e f g h in (i, j, k, l) }
让我们一步一步来思考#
您可能想先(重新)阅读自动微分食谱第一部分。
jax.checkpoint 的基础知识#
在 jax.linearize 和 jax.vjp 中,计算某些值的方式和时间都具有灵活性。不同的选择可以在内存使用与浮点运算次数之间进行权衡。JAX 通过 jax.checkpoint 控制这些选择。
其中一种选择是,是在正向传播中,一旦输入可用,就执行雅可比系数的计算,还是在反向传播中,在需要系数之前才执行。考虑 sin_vjp 的例子。
def sin_vjp(x):
y = jnp.sin(x)
cos_x = jnp.cos(x)
return y, lambda y_bar: cos_x * y_bar
另一个有效的实现是在反向传播而非正向传播中计算 jnp.cos(x) 的值。
def sin_vjp2(x):
y = jnp.sin(x)
return y, lambda y_bar: jnp.cos(x) * y_bar
对于这个特定的函数,两个版本的内存使用量是相同的,尽管我们减少了原始计算(即正向传播)的浮点运算次数,并增加了余切计算(即反向传播)的浮点运算次数。
在函数组合方面还有另一个选择。回想一下我们对两个函数组合的 VJP 规则。
def f(x):
y = g(x)
z = h(y)
return z
def f_vjp(x):
y, g_vjp = jax.vjp(g, x)
z, h_vjp = jax.vjp(h, y)
def f_bwd(z_bar):
y_bar, = h_vjp(z_bar)
x_bar, = g_vjp(y_bar)
return x_bar
return z, f_bwd
一种替代方法是:
def f_vjp_checkpoint(x):
y = g(x)
z, h_vjp = jax.vjp(h, y)
def f_bwd2(z_bar):
y_bar, = h_vjp(z_bar)
_, g_vjp = jax.vjp(g, x)
x_bar, = g_vjp(y_bar)
return x_bar
return z, f_bwd2
换句话说,这种替代实现方式在正向传播中不计算 g_vjp 或其闭包中的残差值。相反,它仅在反向传播 f_bwd2 中计算它们。这意味着 f_vjp_checkpoint 需要更少的内存:如果 g 和 h 为其残差各需要相似的内存量,且都远大于 x,那么 f_vjp_checkpoint(x) 生成的函数所需的内存是 f_vjp(x) 的一半!
我们付出的代价是冗余工作:在 f_bwd2 中,我们必须重新评估 g(x),作为 jax.vjp(g, x) 的一部分,仅仅是为了丢弃其值(在 _, g_vjp = jax.vjp(g, x) 行中的下划线变量)。
我们可以通过使用 jax.checkpoint 替代原始函数 f 的定义来实现自动微分中的这种 VJP 行为——而无需直接编写 VJP 函数。
def f_checkpoint(x):
y = jax.checkpoint(g)(x)
z = h(y)
return z
换句话说,我们将 jax.checkpoint 应用于 f 的第一个阶段 g,而不是 f 本身。这样,当我们计算 jax.grad(f_checkpoint)(x) 时,我们会得到类似以下的计算:
运行
g的正向传播,丢弃残差值;运行
h的正向传播,保存残差;运行
h的反向传播,消耗来自步骤 2 的残差;重新运行
g的正向传播,保存残差;运行
g的反向传播,消耗来自步骤 4 的残差。
也就是说,通过计算 jax.grad(f_checkpoint)(x),我们会得到与以下计算相同的计算:
def f_checkpoint_grad(x):
y = g(x) # step 1
_, h_vjp = jax.vjp(h)(y) # step 2
y_bar, = h_vjp(1.0) # step 3
_, g_vjp = jax.vjp(g, x) # step 4
x_bar, = g_vjp(y_bar) # step 5
return x_bar
总之,jax.checkpoint(foo) 是一个新函数,它具有与 foo 相同的输入-输出行为,但在自动微分下表现不同,尤其是在 jax.linearize 和 jax.vjp(及其包装器,如 jax.grad)下,但在 jax.jvp 下不表现不同。微分时,只有 jax.checkpoint 装饰函数的输入会在正向传播中存储;在反向传播中,残差(即 foo 的中间值及其反向传播所需的雅可比系数)会被重新计算。
请注意,如果我们想要微分的函数是 f = lambda x: h(g(x)),即如果我们想计算 jax.grad(f),那么将 jax.checkpoint 应用于 f 本身并不会节省内存。这是因为计算 jax.grad(jax.checkpoint(f))(x) 将导致类似以下计算:
运行正向传播,丢弃所有残差;
立即重新运行正向传播,保存残差;
运行反向传播,消耗来自步骤 2 的残差。
也就是说,在代码中会有类似这样的内容:
def f_grad_bad1(x):
_ = f(x) # step 1
_, f_vjp = jax.vjp(f, x) # step 2
x_bar, = f_vjp(1.0) # step 3
return x_bar
将 jax.checkpoint 应用于 f 的第二个阶段 h 也不会节省内存。这是因为计算 jax.grad(lambda x: jax.checkpoint(h)(g(x))) 将导致类似以下计算:
运行
g的正向传播,保存残差;运行
h的正向传播,丢弃残差;立即重新运行
h的正向传播,保存残差;运行
h的反向传播,消耗来自步骤 3 的残差;运行
g的反向传播,消耗来自步骤 1 的残差。
也就是说,在代码中会有类似这样的内容:
def f_grad_bad2(x):
y, g_vjp = jax.vjp(g, x) # step 1
z = h(y) # step 2
_, h_vjp = jax.vjp(h, y) # step 3
y_bar, = h_vjp(1.0) # step 3
x_bar, = g_vjp(y_bar) # step 5
return x_bar
更一般地说,如果我们有一个函数链组合,例如 f = lambda x: f3(f2(f1(x))),并且我们有兴趣计算 jax.grad(f),我们可以说:
我们不应该将
jax.checkpoint应用于整个函数f,因为这不会节省任何内存(并且会进行浪费性的重新计算);我们不应该将
jax.checkpoint应用于最后一个子函数f3,因为这不会节省任何内存(并且会进行浪费性的重新计算);我们可以将
jax.checkpoint应用于f1、f2或它们的组合lambda x: f2(f1(x)),因为其中任何一个都可能节省内存,并且会表达不同的内存/重新计算权衡。
可保存内容的自定义策略#
到目前为止,使用 jax.checkpoint 在两个极端之间切换:
不使用
jax.checkpoint时,JAX 的自动微分倾向于在正向传播中计算所有可能的值,并将其存储以备反向传播使用;使用
jax.checkpoint装饰器时,我们反而会在正向传播中计算尽可能少的值,并在反向传播中按需重新计算它们。
要在这两个极端之间操作,保存一些东西而不保存另一些,我们可以小心地将 jax.checkpoint 装饰器应用到子函数上。但这需要编辑要微分的函数,例如模型代码,这可能很不方便。它也可能难以尝试不同的变体。
因此,一种替代方法是使用 jax.checkpoint 的 policy 参数。策略是一个可调用对象(即一个函数),它以一阶基本操作的类型级规范作为输入,并返回一个布尔值,指示是否允许将相应的输出值保存为残差(或者在(共)切线计算中按需重新计算)。为了编写健壮的代码,应从 jax.checkpoint_policies 的属性中选择一个策略,例如 jax.checkpoint_policies.dots_with_no_batch_dims_saveable,因为编写自定义策略可调用对象的 API 被认为是内部的。
例如,考虑这个要微分的函数:
def loss(params, x, y):
return jnp.sum((predict(params, x) - y)**2)
def predict(params, x):
*Ws, Wlast = params
for W in Ws:
x = layer(W, x)
x = jnp.dot(Wlast, x)
return x
def layer(W, x):
return jnp.sin(jnp.dot(W, x))
W1 = W2 = W3 = jnp.ones((4, 4))
params = [W1, W2, W3]
x = jnp.ones(4)
y = jnp.ones(4)
print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] output of sin from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of sin from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of mul from <ipython-input-18-3808b5023c3d>:2 (loss)
与其在正向传播中保存如此多的值,不如我们只希望保存没有批处理维度的矩阵乘法的结果(因为它们可能更受浮点运算限制而不是内存限制)。我们可以使用策略 jax.checkpoint_policies.dots_with_no_batch_dims_saveable 来实现这一点。
loss_checkpoint = jax.checkpoint(loss, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
print_saved_residuals(loss_checkpoint, params, x, y)
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] from the argument 'y'
f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:8 (predict)
另请注意,通过提供策略,我们无需编辑定义 loss、predict 或 layer 的代码。如果要在调用代码(例如训练脚本)中尝试策略,而不更改库代码(例如神经网络库),这一点尤其方便。
一些策略可以引用使用 jax.ad_checkpoint.checkpoint_name 命名的值。
def predict(params, x):
*Ws, Wlast = params
for i, W in enumerate(Ws):
x = layer(W, x)
x = checkpoint_name(x, name=f'layer{i}_output')
x = jnp.dot(Wlast, x)
return x
checkpoint_name 本身只是一个恒等函数。但由于一些策略函数知道要查找它们,我们可以使用这些名称来控制 checkpoint_name 输出的某些值是否被视为可保存。
print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] named 'layer0_output' from <ipython-input-22-e48aedf368ad>:7 (predict)
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] named 'layer1_output' from <ipython-input-22-e48aedf368ad>:7 (predict)
f32[4] output of mul from <ipython-input-18-3808b5023c3d>:2 (loss)
loss_checkpoint2 = jax.checkpoint(loss, policy=jax.checkpoint_policies.save_any_names_but_these('layer1_output'))
print_saved_residuals(loss_checkpoint2, params, x, y)
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] from the argument 'y'
另一个引用名称的策略是 jax.checkpoint_policies.save_only_these_names。
策略列表可以在此处找到。
策略仅指示什么可以保存;只有当反向传播实际需要某个值时,该值才会被保存。
高级:递归 jax.checkpoint#
通过以正确的方式应用 jax.checkpoint,可以表达许多内存使用量与(重新)计算之间的权衡。一个令人惊讶的例子是递归检查点,其中我们将 jax.checkpoint 应用于一个函数,该函数本身调用 jax.checkpoint 装饰的函数,从而使得 \(D\) 个函数的链式组合的内存使用量缩放为 \(\mathcal{O}(\log_2 D)\) 而不是 \(\mathcal{O}(D)\)。
作为示例,考虑多个 jnp.sin 函数的链式组合。
def chain_compose(funs):
def f(x):
for fun in funs:
x = fun(x)
return x
return f
f = chain_compose([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
一般来说,存储的残差数量与链的长度呈线性关系。
f = chain_compose([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
但是,我们可以递归地应用 jax.checkpoint 来改进缩放。
def recursive_checkpoint(funs):
if len(funs) == 1:
return funs[0]
elif len(funs) == 2:
f1, f2 = funs
return lambda x: f1(f2(x))
else:
f1 = recursive_checkpoint(funs[:len(funs)//2])
f2 = recursive_checkpoint(funs[len(funs)//2:])
return lambda x: f1(jax.checkpoint(f2)(x))
f = recursive_checkpoint([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] from the argument 'x'
f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f = recursive_checkpoint([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] from the argument 'x'
f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
这里的成本,和往常一样,是重新计算:特别是,我们最终会执行 \(\mathcal{O}(\log_2 D)\) 倍的浮点运算。
f = chain_compose([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
forward computation: backward computation: { lambda ; a:f32[]. let { lambda ; a:f32[] b:f32[] c:f32[] d:f32[] e:f32[] f:f32[] g:f32[] h:f32[] i:f32[]. let b:f32[] = sin a j:f32[] = mul i a c:f32[] = cos a k:f32[] = mul j b d:f32[] = sin b l:f32[] = mul k c e:f32[] = cos b m:f32[] = mul l d f:f32[] = sin d n:f32[] = mul m e g:f32[] = cos d o:f32[] = mul n f h:f32[] = sin f p:f32[] = mul o g i:f32[] = cos f q:f32[] = mul p h j:f32[] = sin h in (q,) } k:f32[] = cos h l:f32[] = sin j m:f32[] = cos j n:f32[] = sin l o:f32[] = cos l p:f32[] = sin n q:f32[] = cos n in (p, q, o, m, k, i, g, e, c) }
f = recursive_checkpoint([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
forward computation: backward computation: { lambda ; a:f32[]. let { lambda ; a:f32[] b:f32[] c:f32[] d:f32[]. let b:f32[] = remat2[ e:f32[] = mul d a differentiated=False f:f32[] = mul e b jaxpr={ lambda ; c:f32[]. let d:f32[] = sin c; e:f32[] = sin d in (e,) } g:f32[] = remat2[ policy=None differentiated=True prevent_cse=True jaxpr={ lambda ; h:f32[] i:f32[]. let ] a j:f32[] = sin h f:f32[] = sin b k:f32[] = cos h g:f32[] = sin f l:f32[] = cos j h:f32[] = sin g m:f32[] = mul i l i:f32[] = sin h n:f32[] = mul m k j:f32[] = sin i in (n,) } k:f32[] = cos i policy=None l:f32[] = sin j prevent_cse=True m:f32[] = cos j ] c f in (l, m, k, g, a) } o:f32[] = remat2[ differentiated=True jaxpr={ lambda ; p:f32[] q:f32[]. let r:f32[] = sin p s:f32[] = sin r t:f32[] = sin s u:f32[] = cos s v:f32[] = cos t w:f32[] = mul q v x:f32[] = mul w u y:f32[] = remat2[ differentiated=True jaxpr={ lambda ; z:f32[] ba:f32[]. let bb:f32[] = sin z bc:f32[] = cos z bd:f32[] = cos bb be:f32[] = mul ba bd bf:f32[] = mul be bc in (bf,) } policy=None prevent_cse=True ] p x in (y,) } policy=None prevent_cse=True ] 3.0 g in (o,) }
实用说明#
当微分函数被分阶段编译到 XLA 时,例如通过将包含 jax.grad 调用的函数应用 jax.jit,XLA 将自动优化计算,包括何时计算或重新构建值。因此,**对于 jax.jit 下的微分函数,通常不需要 jax.checkpoint**。XLA 会为您优化。
一个例外是使用分阶段的控制流时,例如 jax.lax.scan。跨多个控制流原语的自动编译器优化,例如跨正向传播的 scan 和相应的反向传播 scan,通常不如那么彻底。因此,通常最好将 jax.checkpoint 应用于传递给 jax.lax.scan 的函数体。
例如,大型Transformer模型中的一个常见模式是将架构表示为层上的 jax.lax.scan,以减少编译时间。也就是说,以简单的全连接网络为例,而不是写成这样:
LayerParam = tuple[jnp.ndarray, jnp.ndarray] # weights, bias pair for a layer
ParamsList = list[LayerParam]
def net(params: ParamsList, x: jnp.ndarray):
for W, b in params:
x = jnp.maximum(jnp.dot(x, W) + b, 0.)
return x
我们将改为使用 jax.lax.scan 迭代层应用:
StackedWeights = jnp.ndarray # all weight matrices stacked together
StackedBiases = jnp.ndarray # all bias vectors stacked together
all_weights = jnp.stack([W for W, _ in params])
all_biases = jnp.stack([b for _, b in params])
def layer(x, W_b_pair):
W, b = W_b_pair
out = jnp.maximum(jnp.dot(x, W) + b, 0.)
return out, None
def net(all_weights, all_biases, x):
x, _ = jax.lax.scan(layer, x, (all_weights, all_biases))
return x
这种逐层扫描版本减少了编译时间,但通过阻碍某些编译器优化,它可能会导致梯度计算效率低下。为了缓解这个问题,我们将 jax.checkpoint 应用于扫描的函数:
from functools import partial
@partial(jax.checkpoint,
policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def layer(x, W_b_pair):
W, b = W_b_pair
out = jnp.maximum(jnp.dot(x, W) + b, 0.)
return out, None
通过以这种方式使用 jax.checkpoint,我们手动控制 JAX 的自动微分在正向传播和反向传播之间保存哪些值,因此不依赖 XLA 优化来为我们选择。