高效转置复制诱导集合运算#
mattjj@, dougalm@
2023 年 8 月
动机#
我们在自动转置包含某些集合运算的 shmap
时存在效率问题。问题出现在 psum
和 all_gather
中,特别是当集合运算的输出作为未映射的输出返回给调用者时。这并非边缘情况:例如,当将 grad
应用于基于 shmap
的批数据并行神经网络损失函数时,就会出现这种情况,该函数使用 psum
来计算总损失。
我们已经知道这个问题一段时间了。 pmap
也存在类似的问题,尽管通过将 grad
保留在 pmap
内部而不是外部来解决。未完成的 avals-with-names 工作的主要目标是解决此转置效率问题的某个版本。本文档借鉴了这些想法,同时扩展和修订了它们,以处理更多情况并更容易实现。实际上,这里提出的解决方案仅影响 shmap
的实现。系统的其余部分无需更改(目前)。
本文档的主要目的是定义此转置效率问题并提出一个易于实现的解决方案。
本文档不涉及
数组上的逻辑轴名称(这里的轴名称与
shmap
和 OGpmap
中的轴名称完全相同);更改自动微分语义(所有的数字和(非)错误都保持不变,我们只是使事情更有效率);
允许用户代码反映任何新信息,或者真正影响用户代码。
问题:psum
或 all_gather
的高效转置取决于余切是否在设备之间是不变的#
考虑这个半真实的示例,旨在类似于复制参数的批数据并行损失函数
devices = jax.devices() # 8 devices
@partial(shmap, mesh=Mesh(devices, ('batch',)),
in_specs=(P(None, None), P('batch', None)),
out_specs=P())
def loss(params, batch):
inputs, targets = batch
predictions = predict(params, inputs)
local_loss = jnp.mean(jnp.sum(predictions - targets, -1))
global_loss = lax.pmean(local_loss, 'batch'))
return global_loss
请注意 out_specs=P()
,这表示未映射的输出。如果您不熟悉未映射输出的概念,请参阅本文档底部的附录。
loss
示例中的大多数细节并不重要。对于我们的目的而言,重要的是我们在最后应用 psum
(或者更确切地说是 pmean = lambda x, name: psum(x, name) / psum(1, name)
)。因此,精简版本如下所示
# Example 1: shmap involving psum and unmapped output with inefficient transpose
f1 = shmap(lambda x: psum(g(x), 'i'),
in_specs=P('i'), out_specs=P())
我们甚至通过抑制 mesh
参数简化了符号。在接下来的示例中,可以从上下文中推断出来。
转置是什么样的?写入 t
表示函数转置,我们可以通过应用下面的 ¿f1_transpose?
函数来有效地评估任何 ybar
的 t(f1)(ybar)
# An efficient "transpose" of Example 1 (but don't transpose this again!)
¿f1_transpose? = shmap(t(g), in_specs=P(), out_specs=P('i'))
但这并不是我们目前作为 t(f1) 获取的转置。
相反,当前的转置方法大致是我们切换 in_specs
和 out_specs
,对未映射的输出进行一些除法重新缩放,并转置主体。由于 psum
是其自身的转置(作为 all-reduce sum),因此我们最终生成了此转置
# The transpose we currently get for Example 1 (which is fine to transpose again)
t(f1) = shmap(lambda ybar: t(g)(psum(ybar / 8, 'i')),
in_specs=P(), out_specs=P('i'))
此转置得到了正确的数字,但很浪费。我们从转置的 in_specs=P()
静态地知道,对于每个函数实例,ybar
具有相同的值,即,对于沿名为 i
的网格轴的设备,其值是设备不变的,但我们仍然对其应用 psum
!这使用了昂贵的通信,仅仅是将每个设备上的值乘以 8。(这里 8 指的是轴 i 的大小。除以 8 来自原始函数的 out_specs=P()
;它和微不足道的 psum
基本上相互抵消。)
我们哪里做错了?我们没有利用与 f1
的未映射输出相对应的余切 ybar
保证是设备不变的这一事实;相反,我们防御性地 psum
它们,就好像它们不是因为 psum
的转置无法确定它所拥有的本地信息。有时 psum
是必要的,就像转置 f2
相对于其第一个参数一样
# Example 2: shmap involving psum and *mapped* output with efficient transpose
f2 = shmap(lambda x, y: psum(g(x), 'i') * y,
in_specs=(P('i'), P('i')), out_specs=P('i'))
# The transpose we currently get for Example 2 is efficient
t(f2, 0) = shmap(lambda y, zbar: t(g)(psum(zbar * y, 'i')),
in_specs=(P('i'), P('i')), out_specs=P('i'))
直观地说,如果我们的转置机制可以区分示例 1 和示例 2,那么我们可以通过尽可能避免 psum 和除法来做得更好。
低效的示例甚至可以更小。考虑转置这个被诅咒的恒等函数
# Example 3: cursed identity
cursed_identity = shmap(lambda x: x, P(), P())
# Currently we get these inefficient transposes
t(cursed_identity) = shmap(lambda x: psum(x / 8, 'i'), P(), P())
t(t(cursed_identity)) = shmap(lambda x: psum(psum(x / 8 / 8, 'i'), 'i')), P(), P())
...
我们转置得越多,它就变得越大。多么尴尬!
并且 psum
不是唯一的罪魁祸首。 all_gather
也存在类似的情况
# Example 4: all_gather to an unmapped output
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())
# Currently we get this inefficient transpose
t(f4) = shmap(lambda ybar: psum_scatter(ybar / 8, 'i'), P(), P('i'))
此程序有点人为。为什么进行 all_gather
并将结果馈送到未映射的输出中,而不是跳过主体中的 all_gather
并仅使用 out_specs=P('i')
来收集结果?但即使它是人为捏造的,此示例仍然展示了一个不必要地执行通信的转置(我们可以只执行非通信切片),类似于 psum
的示例 1。
同样类似于 psum
示例,在某些情况下,防御性 psum_scatter
是必要的
# Example 5: all_gather to a mapped output
f5 = shmap(lambda x, y: all_gather(x, 'i') * y,
in_specs=(P('i'), P('i')), out_specs=P('i'))
# Currently we get this efficient transpose
t(f5, 0) = shmap(lambda y, zbar: psum_scatter(zbar * y, 'i'),
in_specs=(P('i'), P('i')), out_specs=P('i'))
那么我们如何避免这些低效的转置呢?
解决方案#
以下是两个解决方案的想法。它们并非互斥的。但是(剧透)第二个更好,而且它是我们所需要的。
部分解决方案 “P-sum”:构建将 psum
表示为 out_specs
的能力#
此解决方案有点像稻草人,因为它只会提供一种笨拙的程序编写方式。它甚至无法解决所有问题!但值得考虑,即使只是为了激励更完整的解决方案。
上面的示例 4 是人为的,因为我们可以只使用 out_specs
而不是主体中的 all_gather
# Example 4 again
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())
# Why didn't we just write it like this?
f4_better = shmap(lambda x: x, P('i'), P('i'))
f4_better
版本没有任何转置问题,因为转置问题是由主体中的集合运算引起的。
类似地,我们可以通过扩展 out_specs
来修复示例 1,以便它们可以表示求和
# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
in_specs=P('i'), out_specs=P())
# What if we could write an output sum like this?
f1_better = shmap(g, in_specs=P('i'), out_specs=P(sum='i')) # sum='i' means sum over that axis
# Then it could transpose like this:
t(f1_better) = shmap(t(g), in_specs=P(), out_specs=P('i'))
t(t(f1_better)) = shmap(t(t(g)), in_specs=P('i'), P(sum='i'))
因此,在 out_specs
中内置 psum
可以解决示例 1 的转置问题。但这并不能完全解决示例 3 中被诅咒的恒等转置
# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())
# How it would transpose with the P-sum partial solution:
t(cursed_identity) = shmap(lambda x: x / 8, P(), P(sum='i'))
t(t(cursed_identity)) = shmap(lambda x: x / 8, P(), P(sum='i'))
这是一个改进,因为程序不会随着我们不断转置而变得更大,但我们仍然在进行浪费的通信。
完整解决方案:静态跟踪设备可变与设备不变的中间值,加上新的原语#
此解决方案包含两个组件
跟踪值何时保证在特定网格轴上是设备不变的,何时是设备可变的,以及
将
psum
分解为两步过程,引入新的pbroadcast
原语,并为all_gather
及其转置引入新的原语。
从道理上讲,设备不变与设备可变信息的跟踪是类型级别的考虑因素。但为了我们第一次实现的方便,我们不需要真正将信息添加到抽象值或 jaxpr 类型中。在讨论实现之前,我们将首先使用类型来介绍这个想法。
接下来还将讨论使用户 API 方便且向后兼容。但为了首先介绍这个想法,我们将忽略便利性,而是编写尽可能显式的代码。
在 avals 中跟踪设备不变性(又名 avals-with-names,复活)#
有时我们可以仅从静态信息中判断出,shmap
主体中某些中间变量的值保证沿网格轴是不变的,这意味着沿网格轴的函数实例(及其相应的设备)都必须使用相同的值进行计算。我们将此类值称为设备不变值。对于非设备不变的值,我们称它们为设备可变值,尽管实际上我们指的是从类型系统的角度来看可能设备可变的值。
为了在类型中编码设备方差,我们将扩展数组类型的语法。我们将编写诸如 x:f32[3,4]{i}
之类的东西来指示 x
沿网格轴 i
是(可能)设备可变的(并且在 shmap
的任何其他网格轴上是设备不变的)。更一般地,我们将说数组类型语法的语法如下所示
shaped_array ::= <dtype>[<int_literal>, ...]<device_variance_type>
device_variance_type ::= {<axis_name>, ...}
我们还将更新类型规则以处理设备方差类型
对于集合运算以外的一阶原语
对于多arity 原语,操作数设备方差类型必须在形状必须相等的地方相等,例如
mul x:f32[s1]{r1} y:f32[s2][r2]
除了s1 == s2
之外,还需要r1 == r2
输出设备方差类型必须与操作数相同
对于高阶原语
我们只需实例化任何类型变量,包括设备方差类型(并在执行类型相等性检查时检查其设备方差类型是否相等)
(在执行类型推断时,例如对于
cond
的分支,我们取设备方差类型中轴名称集合的并集)
对于一阶集合运算
集合运算可以接受设备可变或设备不变的输入(沿对应于其轴名称参数的网格轴);将设备不变的操作数传递给接受设备可变操作数的集合运算,反之亦然,都是错误
集合运算可以生成设备可变或设备不变的输出
请参阅下表。作为附带好处,无论什么逻辑实现此类型检查,都可以取代
shmap
的“静态分析”检查,以检查shmap
主体函数是否与任何未映射的out_specs
兼容。
下表总结了集合运算原语的设备方差类型
名称 |
设备方差类型 |
示例 |
降低为 HLO |
转置 |
---|---|---|---|---|
|
|
|
|
|
|
|
|
无操作(无通信) |
|
|
|
|
|
|
|
|
|
|
n/a |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
这里有一些令人惊讶的事情!
我们引入了几个新的原语,包括
pbroadcast
,有趣的是,它降低为无操作all_gather_invariant
,它降低为与all_gather
相同的东西,但具有不同的设备方差类型(本质上all_gather
具有融合在其中的pbroadcast
,而all_gather_invariant
则没有)pscatter
,它是all_gather_invariant
的对偶(转置)
all_gather 具有设备可变的结果
直观地说,引入 pbroadcast
的原因(除了使类型规则起作用之外)是为了使 psum
可以转置为物理无操作。我们需要 all_gather
具有设备可变结果的原因是为了我们可以将其转置为 psum_scatter
;如果我们将其保留为设备不变的结果,我们可能需要下游的 pbroadcast
,并且该组合将转置为低效的 psum
,然后是切片 / pscatter
。因此,我们有一个“融合到” all_gather
中的 pbroadcast
,从而允许高效地转置为 psum_scatter
。我们提供 all_gather_invariant
及其转置 pscatter
主要是为了完整性;用户不太可能需要它(它对应于示例 4 中的情况,这种情况很容易使用 out_specs
以不同的方式编写)。
有趣的是,psum
和 pbroadcast
转置对对应于用户在使用 pmap
训练 LLM 时引入的 psum_idrev
和 id_psumrev
。
此系统如何解决低效的转置示例#
再次考虑简化的动机示例
# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
in_specs=P('i'), out_specs=P())
# Example 1 with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1(x: f32[3,4]{i}):
w:f32[]{i} = g(x)
y:f32[]{} = psum(w, 'i')
return y
使用这些新规则,转置为
# Example 1 transpose using device variance types (go ahead and transpose this again!)
t(f1) = shmap(lambda ybar: t(g)(pbroadcast(ybar, 'i')),
in_specs=P(), out_specs=P('i'))
# Example 1 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1_transpose(ybar: f32[]):
wbar:f32[]{i} = pbroadcast(ybar, 'i')
xbar:f32[3,4]{i} = transpose(g)(wbar)
return xbar
其中,评估 pbroadcast
应用根本不涉及通信或 FLOP;它是一个空操作。请注意,如果我们不断转置,主体的大小不会增长;实际上 t(t(f1)) == f1
。效率提升了!
只要我们在需要类型检查的地方使用 pbroadcast
,我们也不会搞砸其他示例
# Example 2 rewritten with explicit pbroadcast
f2 = shmap(lambda x, y: pbroadcast(psum(g(x), 'i'), 'i') * y,
in_specs=(P('i'), P('i')), out_specs=P('i'))
# Example 2 transpose using device variance types
t(f2, 0) = shmap(lambda y, zbar: t(g)(pbroadcast(psum(zbar * y, 'i'), 'i')),
in_specs=(P('i'), P('i')), out_specs=P('i'))
# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())
# Notice here the body is `f32[...] -> f32[...]`, i.e. no device varying type.
# Example 3 transpose using device variance types
t(cursed_identity) = shmap(lambda x: x, P(), P())
t(t(cursed_identity)) = shmap(lambda x: x, P(), P())
直观地说,在示例 1 中,我们现在只有“原始 psum 的一半”,而在示例 2 中,我们得到了“两半”。对于示例 3,我们根本不需要主体中的任何操作。
对于 all_gather
示例,示例 4 将需要使用 all_reduce_invariant
来进行高效的转置(尽管最好使用 out_specs
而不是主体中的集合通信)
# Example 4 rewritten with explicit all_reduce_invariant
f4 = shmap(lambda x: all_gather_invariant(x, 'i'), P('i'), P())
# Example 4 with intermediate device variance types annotated
@partial(shmap, P('i'), P())
def f4(x:f32[1]{i}):
y:f32[8]{} = all_gather_invariant(x, 'i')
return y
# Example 4 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P(), out_specs=P('i'))
def f4_transpose(ybar:f32[8]):
xbar:f32[1]{i} = pscatter(ybar, 'i')
return xbar
对于示例 5,使用设备相关的 all_gather
可以按我们期望的方式工作
# Example 5 with intermediate device variance types annotated
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5(x:f32[1]{i}, y:f32[8]{i}):
z:f32[8]{i} = all_gather(x, 'i')
w:f32[8]{i} = z * y
return w
# Transpose with respect to first argument
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5_transpose(y:f32[8]{i}, wbar:f32[8]{i}):
zbar:f32[8]{i} = wbar * y
xbar:f32[1]{i} = psum_scatter(zbar, 'i')
return xbar
如何使 API 对用户友好(并向后兼容)#
但是哪个用户想编写 pbroadcast
?哪个开发者想破坏大量现有的用户代码,这些代码涉及未馈送到未映射输出的 psum
?反正不是我!
相反,我们可以自动插入 pbroadcast
。这有点类似于我们在 jax.numpy
层进行自动秩提升的方式,插入广播以避免二元运算符中的秩不匹配错误。但这要简单得多,因为我们不需要处理形状元组。典型的规则是:每当我们看到一个多参数操作,其中操作数的设备方差类型不一致时,取操作数设备方差类型的轴名称集的并集,并插入 pbroadcast
以将每个操作数提升到结果设备方差类型。
在需要 pbroadcast
之前自动插入它们可能意味着我们对同一个操作数多次应用相同的 pbroadcast
,从而创建公共子表达式。当我们转置时,这些可能会变成 psum
的和,而不是和的 psum
。我们将依靠编译器来酌情清理它。如果这是一个问题,那么我们可以向 pbroadcast
插入过程添加一些简单的 memoization。
对于 all_gather
的用户 API 将默认意味着 all_gather_p
(而不是 all_gather_invariant_p
),涵盖常见情况,意味着不必插入 pbroadcast
。
我们可以在 shmap
上提供一个选项来禁用自动插入 pbroadcast
,在这种情况下,将由用户来确保类型正确性。对于一些想要明确 psum
在反向传播中出现位置的人来说,这个显式选项可能很有吸引力。
如何实现解决方案#
使实现轻量级的关键是,**我们不会将这些类型添加到 avals 或 jaxprs 中**。至少,一开始不会。这可能会很昂贵,因为它需要更新 JAX 的其余部分,例如,avals 和 jaxprs 的所有使用者可能都需要处理新类型。我们不会再犯同样的错误了!
相反,我们将把这些扩展类型保留为 shmap
内部的元数据,就像当前“out_specs
的复制检查”机制是 shmap
内部的一样。实际上,这个解决方案相当于对现有机制的一个相对小的扩展:它已经在跟踪相同的信息;现在我们只是添加了 pbroadcast
。
我们至少有两种选择来执行 pbroadcast
插入
就在转置之前,在转置规则中,我们有一个要转置的计算的 jaxpr;
在每个
shmap
主体中,无论是急切执行还是分段输出,都像当前的“out_specs
的复制检查”机制一样。前者可能最终更容易,因为我们只需要处理 jaxpr 情况,并且只需要线性原语。但我们将首先尝试后者,因此这里的实现是对现有复制检查逻辑的严格修订/扩展。
附录:定义和激励具有未映射输入和输出的映射#
为了具体起见,我们将主要关注 shmap
,尽管这些相同的想法也适用于例如 pmap
,可能也适用于 xmap
。
当 in_specs
的对应条目未提及该网格轴的名称时,参数/输入沿着网格轴是未映射的。逻辑上,这意味着沿着该网格轴的每个函数实例都获得相同的参数值。对于调用者,每个操作数都根据操作数映射到的网格轴进行切片,而对于操作数未映射到的网格轴则不进行切片。
当 out_specs
的对应条目未提及该网格轴的名称时,输出沿着网格轴是未映射的。逻辑上,这意味着沿着该网格轴的每个函数实例都必须返回相同的值。对于调用者,shmap
的每个结果都是通过连接每个函数实例的返回值形成的,这些函数实例沿着输出映射到的网格轴,而对于输出未映射到的网格轴,仅使用该值的一个副本。
有关未映射输入和输出的示例,请参阅 shmap
JEP。为了进行比较,在 vmap
中,未映射的输入/输出通过使用 in_axes
/ out_axes
的 None
(而不是 int
)来指示。
以下是我们喜欢 shmap
的未映射输入和输出的原因
与
pjit
相同的表达能力。pjit
可以做的任何事情,shmap
的逃生舱口也应该能够做到。否则我们将有一个不足的逃生舱口!如果我们在shmap
中没有未映射的输出,那么我们就无法表达与pjit
相同的批处理并行损失函数计算。闭包输入。 闭包输入本质上对应于未映射的输入,并且…
转置下的闭包。 一旦我们有了未映射的输入,那么能够转置到未映射的输出就很自然了。
因此,未映射的输出既是规范的,也是有用的!