使用 jax.checkpoint
(jax.remat
) 进行梯度检查点#
在本教程中,您将学习如何使用 jax.checkpoint()
(也称为 jax.remat()
)来控制 JAX 自动微分的保存值,这在机器学习中特别有用。
如果您不熟悉自动微分 (autodiff) 或者需要温习一下,JAX 提供了自动微分和高级自动微分教程。
总结 结合jax.grad()
使用jax.checkpoint()
装饰器(别名为jax.remat()
),可以控制在正向传播中保存哪些中间值,以及在反向传播中重新计算哪些中间值,从而在内存和浮点运算之间进行权衡。
如果您不使用 jax.checkpoint()
,则 jax.grad(f)(x)
的正向传播会存储雅可比系数和其他中间值,以便在反向传播中使用。这些保存的值称为残差。
注意: 不要错过实用注意事项中关于jax.checkpoint()
如何与jax.jit()
交互的讨论。
import jax
import jax.numpy as jnp
def g(W, x):
    """Single layer: elementwise sine of the matrix-vector product W @ x."""
    preactivation = jnp.dot(W, x)
    return jnp.sin(preactivation)
def f(W1, W2, W3, x):
    """Three-layer chain: feed x through g with each weight matrix in turn."""
    for W in (W1, W2, W3):
        x = g(W, x)
    return x
# Toy weights with compatible shapes for the 3-layer chain f:
# x (4,) -> W1 (5,4) -> (5,) -> W2 (6,5) -> (6,) -> W3 (7,6) -> (7,)
W1 = jnp.ones((5, 4))
W2 = jnp.ones((6, 5))
W3 = jnp.ones((7, 6))
x = jnp.ones(4)
# Inspect the 'residual' values to be saved on the forward pass
# if you were to evaluate `jax.grad(f)(W1, W2, W3, x)`
from jax.ad_checkpoint import print_saved_residuals
jax.ad_checkpoint.print_saved_residuals(f, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of sin from /tmp/ipykernel_1778/1801108376.py:6:9 (g)
f32[5] output of cos from /tmp/ipykernel_1778/1801108376.py:6:9 (g)
f32[6] output of sin from /tmp/ipykernel_1778/1801108376.py:6:9 (g)
f32[6] output of cos from /tmp/ipykernel_1778/1801108376.py:6:9 (g)
f32[7] output of cos from /tmp/ipykernel_1778/1801108376.py:6:9 (g)
通过将 jax.checkpoint()
应用于子函数,作为装饰器或在特定应用点,您可以强制 JAX 不保存该子函数的任何残差。相反,只保存 jax.checkpoint()
装饰函数的输入,并且在反向传播中消耗的任何残差都会根据需要从这些输入中重新计算。
def f2(W1, W2, W3, x):
    """Same chain as f, but each stage is wrapped in jax.checkpoint so that
    stage's internal residuals are recomputed on the backward pass."""
    for W in (W1, W2, W3):
        x = jax.checkpoint(g)(W, x)
    return x
jax.ad_checkpoint.print_saved_residuals(f2, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of sin from /tmp/ipykernel_1778/1801108376.py:6:9 (g)
f32[6] output of sin from /tmp/ipykernel_1778/1801108376.py:6:9 (g)
在这里,两个 sin
应用的值被保存,因为它们是后续应用于 jax.checkpoint()
装饰的 g
函数的参数,并且 jax.checkpoint()
装饰函数的输入可以被保存。但是 cos
应用的值都没有被保存。
为了控制哪些值可以被保存,而无需编辑要微分的函数的定义,您可以使用一个重计算(rematerialization)策略。这里有一个示例,它只保存没有批量维度的 dot
运算的结果(因为它们通常受 FLOP 限制,因此值得保存而不是重新计算)
f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
jax.ad_checkpoint.print_saved_residuals(f3, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of reduce_precision from /tmp/ipykernel_1778/1801108376.py:5:6 (g)
f32[6] output of reduce_precision from /tmp/ipykernel_1778/1801108376.py:5:6 (g)
f32[7] output of reduce_precision from /tmp/ipykernel_1778/1801108376.py:5:6 (g)
您还可以使用策略来引用您使用 jax.ad_checkpoint.checkpoint_name()
命名的中间值
from jax.ad_checkpoint import checkpoint_name
def f4(W1, W2, W3, x):
    """Chain of g applications whose intermediates are tagged with
    checkpoint_name so a rematerialization policy can refer to them."""
    for W, label in zip((W1, W2, W3), ('a', 'b', 'c')):
        x = checkpoint_name(g(W, x), name=label)
    return x
f4 = jax.checkpoint(f4, policy=jax.checkpoint_policies.save_only_these_names('a'))
jax.ad_checkpoint.print_saved_residuals(f4, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of reduce_precision from /tmp/ipykernel_1778/2296542172.py:4:6 (f4)
在尝试这些玩具示例时,您可以使用此笔记本中定义的自定义 print_fwd_bwd
实用工具,更详细地了解正在发生的事情
from jax.tree_util import tree_flatten, tree_unflatten
from rich.console import Console
from rich.table import Table
import rich.text
def print_fwd_bwd(f, *args, **kwargs) -> None:
    """Print the jaxprs of f's forward (VJP) and backward passes side by side.

    Flattens the arguments so f can be traced positionally, stages out the
    forward pass (jax.vjp) and the backward pass (applying the VJP function
    to a cotangent), then renders the two jaxprs in a two-column rich table.
    """
    args, in_tree = tree_flatten((args, kwargs))

    def f_(*args):
        # Re-bundle the flat leaves into f's original (args, kwargs) structure.
        args, kwargs = tree_unflatten(in_tree, args)
        return f(*args, **kwargs)

    # Forward pass: jaxpr of computing (primal output, residuals) via jax.vjp.
    fwd = jax.make_jaxpr(lambda *args: jax.vjp(f_, *args))(*args).jaxpr

    y, f_vjp = jax.vjp(f_, *args)
    res, in_tree = tree_flatten(f_vjp)

    def g_(*args):
        # Rebuild the VJP closure from its flattened residuals and apply it
        # to the cotangent, so the backward pass can be staged to a jaxpr.
        *res, y = args
        f_vjp = tree_unflatten(in_tree, res)
        return f_vjp(y)

    # Backward pass: jaxpr that consumes the residuals plus the cotangent y.
    bwd = jax.make_jaxpr(g_)(*res, y).jaxpr

    table = Table(show_header=False, show_lines=True, padding=(1, 2, 0, 2), box=None)
    table.add_row("[bold green]forward computation:",
                  "[bold green]backward computation:")
    table.add_row(rich.text.Text.from_ansi(str(fwd)),
                  rich.text.Text.from_ansi(str(bwd)))
    console = Console(width=240, force_jupyter=True)
    console.print(table)

def _renderable_repr(self):
    # Let rich renderables display their HTML form in Jupyter outputs.
    return self.html
rich.jupyter.JupyterRenderable._repr_html_ = _renderable_repr
# Without using `jax.checkpoint`:
print_fwd_bwd(f, W1, W2, W3, x)
forward computation: backward computation: { lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let { lambda ; a:f32[4] b:f32[5,4] c:f32[5] d:f32[5] e:f32[6,5] f:f32[6] g:f32[6] h:f32[7,6] e:f32[5] = dot_general[ i:f32[7] j:f32[7]. let dimension_numbers=(([1], [0]), ([], [])) k:f32[7] = mul j i preferred_element_type=float32 l:f32[6] = dot_general[ ] a d dimension_numbers=(([0], [0]), ([], [])) f:f32[5] = sin e preferred_element_type=float32 g:f32[5] = cos e ] k h h:f32[6] = dot_general[ m:f32[7,6] = dot_general[ dimension_numbers=(([1], [0]), ([], [])) dimension_numbers=(([], []), ([], [])) preferred_element_type=float32 preferred_element_type=float32 ] b f ] k g i:f32[6] = sin h n:f32[6] = mul l f j:f32[6] = cos h o:f32[5] = dot_general[ k:f32[7] = dot_general[ dimension_numbers=(([0], [0]), ([], [])) dimension_numbers=(([1], [0]), ([], [])) preferred_element_type=float32 preferred_element_type=float32 ] n e ] c i p:f32[6,5] = dot_general[ l:f32[7] = sin k dimension_numbers=(([], []), ([], [])) m:f32[7] = cos k preferred_element_type=float32 in (l, d, a, g, f, b, j, i, c, m) } ] n d q:f32[5] = mul o c r:f32[4] = dot_general[ dimension_numbers=(([0], [0]), ([], [])) preferred_element_type=float32 ] q b s:f32[5,4] = dot_general[ dimension_numbers=(([], []), ([], [])) preferred_element_type=float32 ] q a in (s, p, m, r) }
# Using `jax.checkpoint` with policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable:
print_fwd_bwd(f3, W1, W2, W3, x)
forward computation: backward computation: { lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let { lambda ; a:f32[5] b:f32[6] c:f32[7] d:f32[5,4] e:f32[6,5] f:f32[7,6] g:f32[4] h:f32[7]. let e:f32[5] = dot_general[ i:f32[5,4] j:f32[6,5] k:f32[7,6] l:f32[4] = remat2[ dimension_numbers=(([1], [0]), ([], [])) differentiated=True preferred_element_type=float32 jaxpr={ lambda ; m:f32[5] n:f32[6] o:f32[7] p:f32[5,4] q:f32[6,5] r:f32[7,6] ] a d s:f32[4] t:f32[7]. let f:f32[5] = reduce_precision[exponent_bits=8 mantissa_bits=23] e u:f32[5] = sin m g:f32[5] = sin f v:f32[5] = cos m h:f32[6] = dot_general[ w:f32[6] = sin n dimension_numbers=(([1], [0]), ([], [])) x:f32[6] = cos n preferred_element_type=float32 y:f32[7] = cos o ] b g z:f32[7] = mul t y i:f32[6] = reduce_precision[exponent_bits=8 mantissa_bits=23] h ba:f32[6] = dot_general[ j:f32[6] = sin i dimension_numbers=(([0], [0]), ([], [])) k:f32[7] = dot_general[ preferred_element_type=float32 dimension_numbers=(([1], [0]), ([], [])) ] z r preferred_element_type=float32 bb:f32[7,6] = dot_general[ ] c j dimension_numbers=(([], []), ([], [])) l:f32[7] = reduce_precision[exponent_bits=8 mantissa_bits=23] k preferred_element_type=float32 m:f32[7] = sin l ] z w in (m, f, i, l, a, b, c, d) } bc:f32[6] = mul ba x bd:f32[5] = dot_general[ dimension_numbers=(([0], [0]), ([], [])) preferred_element_type=float32 ] bc q be:f32[6,5] = dot_general[ dimension_numbers=(([], []), ([], [])) preferred_element_type=float32 ] bc u bf:f32[5] = mul bd v bg:f32[4] = dot_general[ dimension_numbers=(([0], [0]), ([], [])) preferred_element_type=float32 ] bf p bh:f32[5,4] = dot_general[ dimension_numbers=(([], []), ([], [])) preferred_element_type=float32 ] bf s in (bh, be, bb, bg) } policy=<function dots_with_no_batch_dims_saveable at 0x790d9ca26f20> prevent_cse=True ] a b c d e f g h in (i, j, k, l) }
让我们逐步思考#
注意: 在此处继续之前,查阅高级自动微分教程可能会有帮助。
jax.checkpoint
基本原理#
在jax.linearize()
和jax.vjp()
中,某些值的计算方式和时间上存在灵活性。不同的选择可以在内存使用和浮点运算之间进行权衡。JAX 通过 jax.checkpoint()
提供对这些选择的控制。
一个这样的选择是,是在正向传播中(一旦输入可用)执行雅可比系数计算,还是在反向传播中(就在需要系数之前)执行。考虑 sin_vjp
的例子
def sin_vjp(x):
    """VJP for sin that evaluates cos(x) eagerly on the forward pass."""
    primal_out = jnp.sin(x)
    cos_x = jnp.cos(x)  # residual saved for the backward pass

    def backward(y_bar):
        return cos_x * y_bar

    return primal_out, backward
另一种有效的实现方式是在反向传播中而不是在正向传播中计算 jnp.cos(x)
的值
def sin_vjp2(x):
    """VJP for sin that defers the cos(x) evaluation to the backward pass."""
    def backward(y_bar):
        # cos(x) is recomputed here instead of being saved as a residual.
        return jnp.cos(x) * y_bar

    return jnp.sin(x), backward
对于这个特定的函数,两个版本使用的内存量是相同的,尽管您减少了原始计算(正向传播)的浮点运算,并增加了余切计算(反向传播)的浮点运算。
当涉及到函数组合时,还有另一种选择。回想一下两个函数组合的 VJP 规则
def f(x):
    """Two-stage composition: apply g first, then h."""
    return h(g(x))
def f_vjp(x):
    """Reference VJP for f = h . g: both stages' residuals are built eagerly
    on the forward pass and closed over by the backward function."""
    y, g_vjp = jax.vjp(g, x)
    z, h_vjp = jax.vjp(h, y)

    def f_bwd(z_bar):
        # Pull the cotangent back through h, then through g.
        (y_bar,) = h_vjp(z_bar)
        (x_bar,) = g_vjp(y_bar)
        return x_bar

    return z, f_bwd
另一种选择是
def f_vjp_checkpoint(x):
    """Checkpointed VJP for f = h . g: g's residuals are NOT built on the
    forward pass; they are recomputed inside the backward function."""
    y = g(x)  # plain forward through g -- no residuals kept
    z, h_vjp = jax.vjp(h, y)

    def f_bwd2(z_bar):
        (y_bar,) = h_vjp(z_bar)
        # Re-run g's forward pass just to obtain its VJP (primal discarded).
        _, g_vjp = jax.vjp(g, x)
        (x_bar,) = g_vjp(y_bar)
        return x_bar

    return z, f_bwd2
换句话说,这种替代实现方式在正向传播中不计算 g_vjp
或其闭包中的残差值。相反,它只在反向传播 f_bwd2
中计算它们。这意味着 f_vjp_checkpoint
需要更少的内存:如果 g
和 h
各自需要相似的残差内存量,且都远大于 x
,那么 f_vjp_checkpoint(x)
产生的函数所需的内存是 f_vjp(x)
的一半!
您付出的代价是冗余工作:在 f_bwd2
中,您必须重新评估 g(x)
作为 jax.vjp(g, x)
的一部分,仅仅是为了丢弃其值(在行 _, g_vjp = jax.vjp(g, x)
中的下划线变量中)。
您可以在自动微分中获得这种 VJP 行为 — 无需直接编写 VJP 函数 — 而是通过在原始函数 f
的替代定义中使用 jax.checkpoint()
。
def f_checkpoint(x):
    """Like f, but g's residuals are dropped on the forward pass and
    recomputed during the backward pass (h's residuals are still saved)."""
    return h(jax.checkpoint(g)(x))
换句话说,您将 jax.checkpoint()
应用于 g
— f
的第一阶段 — 而不是 f
本身。这样,当您评估 jax.grad(f_checkpoint)(x)
时,您将得到一个类似于以下的计算:
运行
g
的正向传播,丢弃残差值。运行
h
的正向传播,保存残差。运行
h
的反向传播,消耗步骤 2 中的残差。重新运行
g
的正向传播,保存残差。运行
g
的反向传播,消耗步骤 4 中的残差。
也就是说,通过评估 jax.grad(f_checkpoint)(x)
,我们将得到与以下相同的计算:
def f_checkpoint_grad(x):
    """Spell out the computation performed by jax.grad(f_checkpoint)(x).

    Fix: `jax.vjp(h)(y)` was wrong -- jax.vjp takes the primal arguments
    directly, `jax.vjp(h, y)`, consistent with steps 4 and the narrative.
    """
    y = g(x)                   # step 1: forward through g, residuals discarded
    _, h_vjp = jax.vjp(h, y)   # step 2: forward through h, residuals saved
    y_bar, = h_vjp(1.0)        # step 3: backward through h
    _, g_vjp = jax.vjp(g, x)   # step 4: re-run g's forward, saving residuals
    x_bar, = g_vjp(y_bar)      # step 5: backward through g
    return x_bar
通常,jax.checkpoint(foo)
是一个新函数,它与 foo
具有相同的输入-输出行为,但它在自动微分下表现不同,特别是在 jax.linearize()
和 jax.vjp()
(以及它们的包装器,例如 jax.grad()
)下,但不包括 jax.jvp()
。当进行微分时,只有 jax.checkpoint()
微分函数的输入在正向传播中被存储。在反向传播中,残差(来自 foo
的中间值及其反向传播所需的雅可比系数)会重新计算。
请注意,如果 f = lambda x: h(g(x))
是您想要微分的函数(换句话说,如果您想应用 jax.grad(f)
),那么通过将 jax.checkpoint()
应用于 f
本身,您不会获得任何内存节省。那是因为评估 jax.grad(jax.checkpoint(f))(x)
将导致如下计算:
运行正向传播,丢弃所有残差。
立即重新运行正向传播,保存残差。
运行反向传播,消耗步骤 2 中的残差。
在代码中,您会看到类似以下的代码:
def f_grad_bad(x):
    """Spell out jax.grad(jax.checkpoint(f))(x) -- no memory is saved."""
    _ = f(x)                   # step 1: forward pass, all residuals discarded
    _, f_vjp = jax.vjp(f, x)   # step 2: immediately re-run forward, saving residuals
    x_bar, = f_vjp(1.0)        # step 3: backward pass consuming step 2's residuals
    return x_bar
通过将 jax.checkpoint()
应用于 h
(f
的第二阶段),您也不会获得任何内存节省。那是因为评估 jax.grad(lambda x: jax.checkpoint(h)(g(x)))
将导致如下计算:
运行
g
的正向传播,保存残差。运行
h
的正向传播,丢弃残差。立即重新运行
h
的正向传播,保存残差。运行
h
的反向传播,消耗步骤 3 中的残差。运行
g
的反向传播,消耗步骤 1 中的残差。
在代码中,您会看到类似以下的代码:
def f_grad_bad2(x):
    """Spell out jax.grad(lambda x: jax.checkpoint(h)(g(x))) -- no memory saved."""
    y, g_vjp = jax.vjp(g, x)   # step 1: forward through g, residuals saved
    z = h(y)                   # step 2: forward through h, residuals discarded
    _, h_vjp = jax.vjp(h, y)   # step 3: immediately re-run h's forward, saving residuals
    y_bar, = h_vjp(1.0)        # step 4: backward through h (was mislabeled "step 3")
    x_bar, = g_vjp(y_bar)      # step 5: backward through g
    return x_bar
更一般地,如果您有一个函数链式组合,例如 f = lambda x: f3(f2(f1(x)))
,并且对评估 jax.grad(f)
感兴趣,您可以说:
不应该将
jax.checkpoint()
应用于整个函数f
,因为那不会节省任何内存(并且会执行浪费的重新计算)。不应该将
jax.checkpoint()
应用于最后一个子函数f3
,因为那不会节省任何内存(并且会执行浪费的重新计算)。可以将
jax.checkpoint()
应用于f1
、f2
,或它们的组合lambda x: f2(f1(x))
,因为其中任何一个都可能节省内存并表达不同的内存/重新计算权衡。
可保存项的自定义策略#
如上所示,使用 jax.checkpoint()
会从一个极端转向另一个极端:
没有
jax.checkpoint()
,JAX 的自动微分倾向于在正向传播中计算所有可能的值并将其存储以供反向传播使用。使用
jax.checkpoint()
装饰器,您则在正向传播中尽可能少地计算,并在反向传播中根据需要重新计算值。
为了在这两个极端之间操作,即保存一些东西而不保存另一些,您可以仔细地将 jax.checkpoint()
装饰器放置在子函数上。但这需要编辑要微分的函数,例如模型代码,这可能不方便。试验不同的变体也可能很困难。
因此,另一种选择是使用 jax.checkpoint()
的 policy
参数。策略是一个可调用对象(即一个函数),其输入是对一阶原语(primitive)应用的类型级描述,其返回值是一个布尔值,指示相应的输出值是否允许作为残差保存(或者是否需要在(余)切线计算中按需重新计算)。为了编写健壮的代码,策略应该从 jax.checkpoint_policies
的属性中选择,例如 jax.checkpoint_policies.dots_with_no_batch_dims_saveable()
,因为编写自定义策略可调用对象的 API 被认为是内部的。
例如,考虑这个要微分的函数:
def loss(params, x, y):
    """Sum of squared errors between predict(params, x) and targets y."""
    err = predict(params, x) - y
    return jnp.sum(err**2)
def predict(params, x):
    """Run x through every hidden layer; the final weight matrix is applied
    as a pure linear map with no nonlinearity."""
    *hidden, Wlast = params
    for W in hidden:
        x = layer(W, x)
    return jnp.dot(Wlast, x)
def layer(W, x):
    """One dense layer: sine nonlinearity applied to W @ x."""
    preactivation = jnp.dot(W, x)
    return jnp.sin(preactivation)
# Three identical 4x4 layers plus a toy input/target pair for loss().
W1 = W2 = W3 = jnp.ones((4, 4))
params = [W1, W2, W3]
x = jnp.ones(4)
y = jnp.ones(4)
print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] output of sin from /tmp/ipykernel_1778/4230705069.py:12:9 (layer)
f32[4] output of cos from /tmp/ipykernel_1778/4230705069.py:12:9 (layer)
f32[4] output of sin from /tmp/ipykernel_1778/4230705069.py:12:9 (layer)
f32[4] output of cos from /tmp/ipykernel_1778/4230705069.py:12:9 (layer)
f32[4] output of mul from /tmp/ipykernel_1778/4230705069.py:2:17 (loss)
与其在正向传播中保存这么多值,也许您只想保存没有批量维度的矩阵乘法的结果(因为它们可能受浮点运算限制而非内存限制)。您可以使用策略 jax.checkpoint_policies.dots_with_no_batch_dims_saveable()
来实现这一点。
loss_checkpoint = jax.checkpoint(loss, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
print_saved_residuals(loss_checkpoint, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] from the argument y
f32[4] output of reduce_precision from /tmp/ipykernel_1778/4230705069.py:12:17 (layer)
f32[4] output of reduce_precision from /tmp/ipykernel_1778/4230705069.py:12:17 (layer)
f32[4] output of reduce_precision from /tmp/ipykernel_1778/4230705069.py:8:6 (predict)
另请注意,通过提供策略,您无需编辑定义 loss
、predict
或 layer
的代码。如果您想在调用代码(例如训练脚本)中试验策略而无需更改库代码(例如神经网络库),这尤其方便。
一些策略可以引用使用 jax.ad_checkpoint.checkpoint_name()
命名的值
from jax.ad_checkpoint import checkpoint_name
def predict(params, x):
    """Forward pass that tags each hidden layer's output with a checkpoint
    name ('layer0_output', 'layer1_output', ...) so policies can refer to it."""
    *hidden, Wlast = params
    for i, W in enumerate(hidden):
        x = checkpoint_name(layer(W, x), name=f'layer{i}_output')
    return jnp.dot(Wlast, x)
就其本身而言,jax.ad_checkpoint.checkpoint_name()
只是一个恒等函数。但由于某些策略函数知道如何查找它们,您可以使用这些名称来控制 jax.ad_checkpoint.checkpoint_name()
输出的某些值是否被认为是可保存的。
print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] output of cos from /tmp/ipykernel_1778/4230705069.py:12:9 (layer)
f32[4] named 'layer0_output' from /tmp/ipykernel_1778/178264713.py:7:8 (predict)
f32[4] output of cos from /tmp/ipykernel_1778/4230705069.py:12:9 (layer)
f32[4] named 'layer1_output' from /tmp/ipykernel_1778/178264713.py:7:8 (predict)
f32[4] output of mul from /tmp/ipykernel_1778/4230705069.py:2:17 (loss)
loss_checkpoint2 = jax.checkpoint(loss, policy=jax.checkpoint_policies.save_any_names_but_these('layer1_output'))
print_saved_residuals(loss_checkpoint2, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] from the argument y
另一个引用名称的策略是 jax.checkpoint_policies.save_only_these_names
。
卸载的自定义策略#
在检查点设置时,您可以考虑将数据卸载到 CPU 内存而不是重新计算,以节省加速器内存。jax.checkpoint_policies.offload_dot_with_no_batch_dims
可以将没有批量维度的矩阵乘法结果卸载到 CPU。
from jax.ad_checkpoint import checkpoint
def checkpoint_offload_dot_with_no_batch_dims(self):
    """Demonstrate the offload_dot_with_no_batch_dims checkpoint policy.

    NOTE(review): excerpted from a test class -- `self` is unused here, and
    `functools` and `lax` must already be in scope at the call site (they
    are not imported in this snippet); confirm against the original source.
    """
    # Offload batch-dim-free dot results from device memory to pinned host
    # memory instead of recomputing them on the backward pass.
    policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(
        "device", "pinned_host")

    @functools.partial(checkpoint, policy=policy)
    def f(x):
        x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
        x = jnp.sin(x)
        x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
        x = jnp.sin(x)
        x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
        x = jnp.sin(x)
        x = jnp.sum(x)
        return x
JAX 的检查点策略之一允许将指定的检查点名称卸载到 CPU。此策略通过 jax.checkpoint_policies.save_and_offload_only_these_names
实现,它有四个参数:names_which_can_be_saved
、names_which_can_be_offloaded
、卸载源和目标。列在 names_which_can_be_saved
中的名称保留在设备上,列在 names_which_can_be_offloaded
中的名称移动到 CPU 内存,其余未命名的值或操作则会按需重新计算。例如,如果我们有检查点名称 y
、z
和 w
,则 y
可以保存在设备上,z
可以卸载到 CPU 内存,而 w
可以重新计算。
from jax.ad_checkpoint import checkpoint, checkpoint_name
from jax._src import test_util as jtu
def checkpoint_names_saved_offloaded_recomputed(self):
    """Demonstrate save_and_offload_only_these_names: values named 'y' stay
    on device, 'z' values are offloaded to pinned host memory, and 'w'
    values are recomputed.

    NOTE(review): excerpted from a test class -- `self` is unused, and
    `np`, `math`, `functools`, `NamedSharding`, and `P` must already be in
    scope at the call site (not imported in this snippet); confirm against
    the original source.
    """
    mesh = jtu.create_mesh((2,), ("x",))
    shape = (256, 128)
    np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
    s = NamedSharding(mesh, P("x"))
    inp = jax.device_put(np_inp, s)
    policy = jax.checkpoint_policies.save_and_offload_only_these_names(
        names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z"],
        offload_src='device', offload_dst='pinned_host')

    @functools.partial(checkpoint, policy=policy)
    def f(x):
        def g(ys, _):
            # Core per-step computation; each intermediate is tagged so the
            # policy above can decide its fate (save / offload / recompute).
            y, _ = ys
            y = checkpoint_name(jnp.sin(y), "y")
            z = checkpoint_name(jnp.sin(y), "z")
            z = z.T
            w = checkpoint_name(jnp.sin(z), "w")
            return (w.T, jnp.sum(w)), None
        # scan applies g repeatedly; [0] selects the final carry tuple.
        _, scan_out = jax.lax.scan(g, (x, np.array(1, dtype=np.float32)), [np_inp])[0]
        return scan_out
该代码定义了一个函数 f
,它使用自定义策略应用检查点。该策略决定了哪些计算可以在执行期间保存或卸载。在 f
内部,有一个嵌套函数 g
执行核心计算。jax.lax.scan
函数用于在输入数据上重复应用 g
。
策略列表#
策略包括:
everything_saveable
(默认策略,就好像根本没有使用jax.checkpoint
一样)nothing_saveable
(即重新计算所有内容,就好像根本没有使用自定义策略一样)dots_saveable
或其别名checkpoint_dots
dots_with_no_batch_dims_saveable
或其别名checkpoint_dots_with_no_batch_dims
save_anything_but_these_names
(保存除checkpoint_name
输出的给定名称之外的任何值)save_any_names_but_these
(只保存命名值,即checkpoint_name
的任何输出,但排除给定名称的值)save_only_these_names
(只保存命名值,且仅限于给定名称中的值)offload_dot_with_no_batch_dims
与dots_with_no_batch_dims_saveable
相同,但卸载到 CPU 内存而不是重新计算。save_and_offload_only_these_names
与save_only_these_names
相同,但卸载到 CPU 内存而不是重新计算。save_from_both_policies(policy_1, policy_2)
(类似于逻辑or
,因此如果残差根据policy_1
或policy_2
可保存,则该残差可保存)
策略只指示什么是可保存的;值只有在反向传播实际需要时才会被保存。
高级:递归 jax.checkpoint
#
通过正确应用 jax.checkpoint()
,可以在内存使用和(重新)计算之间表达出许多权衡。一个令人惊讶的例子是递归检查点,您将 jax.checkpoint()
应用于一个函数,该函数本身以一种方式调用 jax.checkpoint()
装饰的函数,使得 \(D\) 个函数的链式组合的内存使用量以 \(\mathcal{O}(\log_2 D)\) 而非 \(\mathcal{O}(D)\) 的方式扩展。
作为一个玩具示例,考虑多个 jax.numpy.sin()
函数的链式组合:
def chain_compose(funs):
    """Return a function that applies each element of funs in order
    (funs[0] first, funs[-1] last)."""
    def composed(x):
        for fun in funs:
            x = fun(x)
        return x
    return composed
f = chain_compose([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] output of cos from /tmp/ipykernel_1778/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1778/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1778/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1778/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1778/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1778/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1778/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1778/410288286.py:4:10 (chain_compose.<locals>.f)
通常,存储的残差数量与链的长度呈线性关系:
f = chain_compose([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] output of cos from /tmp/ipykernel_1778/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1778/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1778/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1778/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1778/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1778/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1778/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1778/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1778/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1778/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1778/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1778/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1778/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1778/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1778/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1778/410288286.py:4:10 (chain_compose.<locals>.f)
但您可以通过递归应用 jax.checkpoint()
来改善扩展性:
def recursive_checkpoint(funs):
    """Compose funs with jax.checkpoint applied at binary-split boundaries,
    so saved residuals scale as O(log2 D) in the chain length D.

    Note: funs[-1] is applied first (innermost), funs[0] last.
    """
    if len(funs) == 1:
        return funs[0]
    if len(funs) == 2:
        outer, inner = funs
        return lambda x: outer(inner(x))
    mid = len(funs) // 2
    outer = recursive_checkpoint(funs[:mid])
    inner = recursive_checkpoint(funs[mid:])
    # Only the second half is checkpointed; the recursion places further
    # checkpoints inside each half.
    return lambda x: outer(jax.checkpoint(inner)(x))
f = recursive_checkpoint([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] from the argument x
f32[] output of sin from /tmp/ipykernel_1778/1943107544.py:6:21 (recursive_checkpoint.<locals>.<lambda>)
f32[] output of cos from /tmp/ipykernel_1778/1943107544.py:6:24 (recursive_checkpoint.<locals>.<lambda>)
f32[] output of cos from /tmp/ipykernel_1778/1943107544.py:6:21 (recursive_checkpoint.<locals>.<lambda>)
f = recursive_checkpoint([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] from the argument x
f32[] output of sin from /tmp/ipykernel_1778/1943107544.py:6:21 (recursive_checkpoint.<locals>.<lambda>)
f32[] output of sin from /tmp/ipykernel_1778/1943107544.py:6:21 (recursive_checkpoint.<locals>.<lambda>)
f32[] output of cos from /tmp/ipykernel_1778/1943107544.py:6:24 (recursive_checkpoint.<locals>.<lambda>)
f32[] output of cos from /tmp/ipykernel_1778/1943107544.py:6:21 (recursive_checkpoint.<locals>.<lambda>)
这里的代价,一如既往,是重新计算:特别是,您最终会执行 \(\mathcal{O}(\log_2 D)\) 倍的浮点运算。
f = chain_compose([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
forward computation: backward computation: { lambda ; a:f32[]. let { lambda ; a:f32[] b:f32[] c:f32[] d:f32[] e:f32[] f:f32[] g:f32[] h:f32[] i:f32[]. let b:f32[] = sin a j:f32[] = mul i h c:f32[] = cos a k:f32[] = mul j g d:f32[] = sin b l:f32[] = mul k f e:f32[] = cos b m:f32[] = mul l e f:f32[] = sin d n:f32[] = mul m d g:f32[] = cos d o:f32[] = mul n c h:f32[] = sin f p:f32[] = mul o b i:f32[] = cos f q:f32[] = mul p a j:f32[] = sin h in (q,) } k:f32[] = cos h l:f32[] = sin j m:f32[] = cos j n:f32[] = sin l o:f32[] = cos l p:f32[] = sin n q:f32[] = cos n in (p, c, e, g, i, k, m, o, q) }
f = recursive_checkpoint([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
forward computation: backward computation: { lambda ; a:f32[]. let { lambda ; a:f32[] b:f32[] c:f32[] d:f32[]. let b:f32[] = remat2[ e:f32[] = mul d c differentiated=False f:f32[] = mul e b jaxpr={ lambda ; c:f32[]. let d:f32[] = sin c; e:f32[] = sin d in (e,) } g:f32[] = remat2[ policy=None differentiated=True prevent_cse=True jaxpr={ lambda ; h:f32[] i:f32[]. let ] a j:f32[] = sin h f:f32[] = sin b k:f32[] = cos h g:f32[] = sin f l:f32[] = cos j h:f32[] = sin g m:f32[] = mul i l i:f32[] = sin h n:f32[] = mul m k j:f32[] = sin i in (n,) } k:f32[] = cos i policy=None l:f32[] = sin j prevent_cse=True m:f32[] = cos j ] a f in (l, a, g, k, m) } o:f32[] = remat2[ differentiated=True jaxpr={ lambda ; p:f32[] q:f32[]. let r:f32[] = sin p s:f32[] = sin r t:f32[] = sin s u:f32[] = cos s v:f32[] = cos t w:f32[] = mul q v x:f32[] = mul w u y:f32[] = remat2[ differentiated=True jaxpr={ lambda ; z:f32[] ba:f32[]. let bb:f32[] = sin z bc:f32[] = cos z bd:f32[] = cos bb be:f32[] = mul ba bd bf:f32[] = mul be bc in (bf,) } policy=None prevent_cse=True ] p x in (y,) } policy=None prevent_cse=True ] 3.0:f32[] g in (o,) }
实用注意事项#
当微分函数被分阶段提交到 XLA 进行编译时 — 例如通过将 jax.jit()
应用于包含 jax.grad()
调用的函数时 — XLA 将自动优化计算,包括何时计算或重新具体化值的决策。因此,在 jax.jit()
下的微分函数通常不需要 jax.checkpoint()
。XLA 会为您优化。
一个例外是使用分阶段控制流时,例如 jax.lax.scan()
。跨多个控制流原语的自动编译器优化(例如,跨正向传播 scan
和相应的反向传播 scan
)通常不够彻底。因此,通常最好在传递给 jax.lax.scan()
的主体函数上使用 jax.checkpoint()
。
例如,大型Transformer 模型中一个常见模式是将架构表示为在层上进行jax.lax.scan()
,以减少编译时间。也就是说,以一个简单的全连接网络为例,而不是像这样编写代码:
# Type aliases for a fully connected network's parameters.
LayerParam = tuple[jnp.ndarray, jnp.ndarray]  # Weights-bias pair for a layer.
ParamsList = list[LayerParam]

def net(params: ParamsList, x: jnp.ndarray):
    """ReLU MLP written as a plain Python loop over (W, b) layer pairs."""
    for weights, bias in params:
        x = jnp.maximum(jnp.dot(x, weights) + bias, 0.)
    return x
相反,使用 jax.lax.scan()
迭代层应用:
# Two identical toy (W, b) layers, then stacked along a new leading axis so
# jax.lax.scan can iterate over the layers.
params = [(jnp.array([[0.5, 0.5], [1., 1.]]), jnp.array([0.5, 0.5])),
          (jnp.array([[0.5, 0.5], [1., 1.]]), jnp.array([0.5, 0.5]))]
all_weights = jnp.stack([W for W, _ in params])
all_biases = jnp.stack([b for _, b in params])
def layer(x, W_b_pair):
    """Scan body for one ReLU layer: returns (new_carry, None) as scan expects."""
    weights, bias = W_b_pair
    activated = jnp.maximum(jnp.dot(x, weights) + bias, 0.)
    return activated, None
def net(all_weights, all_biases, x):
    """Apply the stacked layers by scanning `layer` over their leading axis."""
    final_x, _ = jax.lax.scan(layer, x, (all_weights, all_biases))
    return final_x
这种分层扫描的版本减少了编译时间,但通过阻止某些编译器优化,可能导致梯度计算效率低下。为了缓解这个问题,您可以在扫描函数上使用 jax.checkpoint()
:
from functools import partial
@partial(jax.checkpoint,
         policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def layer(x, W_b_pair):
    """Checkpointed scan body: only batch-dim-free dot results may be saved
    as residuals; everything else is recomputed on the backward pass."""
    weights, bias = W_b_pair
    activated = jnp.maximum(jnp.dot(x, weights) + bias, 0.)
    return activated, None
通过以这种方式使用 jax.checkpoint()
,您可以手动控制 JAX 自动微分在正向和反向传播之间保存哪些值,因此不需要依赖 XLA 优化来替您选择。