高效转置诱导复制的集合通信#
mattjj@, dougalm@
2023 年 8 月
动机#
我们在自动转置包含某些集合通信的 shmap 时遇到了效率问题。该问题出现在 psum 和 all_gather 中,特别是当集合通信的输出作为未映射输出返回给调用者时。而且这不是一个边缘情况:例如,当将 grad 应用于基于 shmap 的批数据并行神经网络损失函数时,该函数使用 psum 来计算总损失时就会出现此问题。
我们已经知道这个问题一段时间了。 pmap 中存在类似的问题,尽管通过将 grad 保留在 pmap 内部而不是外部来解决了。不完整的带名称 avals 工作的主要目标之一是解决此转置效率问题的某个版本。本文档借鉴了这些想法,同时对其进行了扩展和修订,以处理更多情况并使其更容易合并。事实上,这里提出的解决方案仅影响 shmap 的实现。系统的其余部分无需更改(目前)。
本文档的主要目的是定义此转置效率问题并提出一个易于合并的解决方案。
本文档不讨论
数组上的逻辑轴名称(此处唯一的轴名称就像
shmap和原始pmap中的一样);更改自动微分语义(所有数字和(非)错误都保持不变,我们只是使事情更高效);
允许用户代码反映任何新信息,或者实际上对用户代码产生任何影响。
问题:高效转置 psum 或 all_gather 取决于协边是否跨设备不变#
考虑这个半真实的示例,旨在模仿一个复制参数的批数据并行损失函数
devices = jax.devices() # 8 devices
@partial(shmap, mesh=Mesh(devices, ('batch',)),
in_specs=(P(None, None), P('batch', None)),
out_specs=P())
def loss(params, batch):
inputs, targets = batch
predictions = predict(params, inputs)
local_loss = jnp.mean(jnp.sum(predictions - targets, -1))
global_loss = lax.pmean(local_loss, 'batch'))
return global_loss
注意 out_specs=P(),它表示未映射的输出。如果您不熟悉未映射输出的概念,请参阅本文档底部的附录。
loss 示例中的大多数细节并不重要。对我们来说,重要的是我们在最后应用了 psum(更准确地说,是 pmean = lambda x, name: psum(x, name) / psum(1, name))。因此,简化的版本如下所示:
# Example 1: shmap involving psum and unmapped output with inefficient transpose
f1 = shmap(lambda x: psum(g(x), 'i'),
in_specs=P('i'), out_specs=P())
我们甚至通过抑制 mesh 参数来简化了符号。在接下来的示例中,可以从上下文中推断出它。
转置是什么样的?将 t 表示函数转置,我们可以通过应用下面的 ¿f1_transpose? 函数来高效地计算 t(f1)(ybar)
# An efficient "transpose" of Example 1 (but don't transpose this again!)
¿f1_transpose? = shmap(t(g), in_specs=P(), out_specs=P('i'))
但这并不是我们当前通过 t(f1) 获得的转置。
相反,当前的转置方法大致是切换 in_specs 和 out_specs,对未映射的输出进行一些除法缩放,然后转置主体。因为 psum 本身就是其转置(作为 all-reduce sum),我们最终会得到这个转置
# The transpose we currently get for Example 1 (which is fine to transpose again)
t(f1) = shmap(lambda ybar: t(g)(psum(ybar / 8, 'i')),
in_specs=P(), out_specs=P('i'))
此转置结果在数字上是正确的,但效率低下。我们从转置的 in_specs=P() 中静态地知道 ybar 对于每个函数实例都具有相同的值,即它的值在名为 i 的网格轴上是设备不变的,但我们却对它应用了 psum!这使用了昂贵的通信,只是为了将每个设备上的值乘以 8。(这里 8 指的是轴 i 的大小。除以 8 来自原始函数的 out_specs=P();它和平凡的 psum 基本相互抵消。)
我们哪里做错了?我们没有利用与 f1 的未映射输出对应的协边 ybar 保证是设备不变的这一事实;相反,我们防御性地对其进行 psum 操作,就好像它们不是一样,因为 psum 的转置在给定它拥有的本地信息的情况下无法确定。有时 psum 是必要的,就像转置 f2 相对于其第一个参数一样
# Example 2: shmap involving psum and *mapped* output with efficient transpose
f2 = shmap(lambda x, y: psum(g(x), 'i') * y,
in_specs=(P('i'), P('i')), out_specs=P('i'))
# The transpose we currently get for Example 2 is efficient
t(f2, 0) = shmap(lambda y, zbar: t(g)(psum(zbar * y, 'i')),
in_specs=(P('i'), P('i')), out_specs=P('i'))
直观地说,如果我们的转置机制能够区分示例 1 和示例 2,我们就可以通过避免 psum 和除法(如果可能)来做得更好。
低效的示例甚至可以更小。考虑转置这个诅咒的同一性函数
# Example 3: cursed identity
cursed_identity = shmap(lambda x: x, P(), P())
# Currently we get these inefficient transposes
t(cursed_identity) = shmap(lambda x: psum(x / 8, 'i'), P(), P())
t(t(cursed_identity)) = shmap(lambda x: psum(psum(x / 8 / 8, 'i'), 'i')), P(), P())
...
我们每次转置它,它都会变得更大。真丢人!
psum 并不是唯一的罪魁祸首。 all_gather 也存在类似的情况
# Example 4: all_gather to an unmapped output
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())
# Currently we get this inefficient transpose
t(f4) = shmap(lambda ybar: psum_scatter(ybar / 8, 'i'), P(), P('i'))
这个程序有点人工。为什么要做一个 all_gather 并将结果馈送到未映射的输出,而不是在主体中跳过 all_gather 并直接使用 out_specs=P('i') 来收集结果?但是,尽管它是人为设计的,这个示例仍然显示了一个不必要地执行通信的转置(我们可以只执行一个不通信的切片),这与 psum 的示例 1 类似。
与 psum 示例类似,防御性 psum_scatter 在某些情况下是必要的
# Example 5: all_gather to a mapped output
f5 = shmap(lambda x, y: all_gather(x, 'i') * y,
in_specs=(P('i'), P('i')), out_specs=P('i'))
# Currently we get this efficient transpose
t(f5, 0) = shmap(lambda y, zbar: psum_scatter(zbar * y, 'i'),
in_specs=(P('i'), P('i')), out_specs=P('i'))
那么我们如何避免这些低效的转置呢?
解决方案#
这里有两个解决方案。它们并非互斥。但是(剧透)第二个更好,而且是我们需要的一切。
部分解决方案“P-sum”:将 psum 的能力构建到 out_specs 中#
这个解决方案有点像稻草人,因为它只提供了一种写程序的笨拙方式。而且它甚至不能解决所有问题!但值得考虑,即使只是为了激发一个更完整的解决方案。
示例 4 是人为的,因为我们可以直接在主体中使用 out_specs 而不是 all_gather
# Example 4 again
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())
# Why didn't we just write it like this?
f4_better = shmap(lambda x: x, P('i'), P('i'))
f4_better 版本没有任何转置问题,因为转置问题源于主体中的集合通信。
类似地,我们可以通过扩展 out_specs 来使其能够表示求和来修复示例 1
# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
in_specs=P('i'), out_specs=P())
# What if we could write an output sum like this?
f1_better = shmap(g, in_specs=P('i'), out_specs=P(sum='i')) # sum='i' means sum over that axis
# Then it could transpose like this:
t(f1_better) = shmap(t(g), in_specs=P(), out_specs=P('i'))
t(t(f1_better)) = shmap(t(t(g)), in_specs=P('i'), P(sum='i'))
因此,在 out_specs 中内置 psum 解决了示例 1 的转置问题。但它并没有完全解决示例 3 中的诅咒同一性转置
# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())
# How it would transpose with the P-sum partial solution:
t(cursed_identity) = shmap(lambda x: x / 8, P(), P(sum='i'))
t(t(cursed_identity)) = shmap(lambda x: x / 8, P(), P(sum='i'))
这是一个改进,因为程序不会随着我们不断转置而继续变大,但我们仍然在进行低效的通信。
完整解决方案:静态跟踪设备可变与设备不变的中间值,以及新的原始操作#
此解决方案有两个组成部分
跟踪值在特定网格轴上保证是设备不变还是设备可变的,以及
将
psum分解为两步过程,引入新的pbroadcast原始操作,并为all_gather及其转置引入新的原始操作。
从道德上讲,跟踪设备不变性与设备可变性的信息是一种类型级别的考虑。但为了我们首次实现的便利性,我们无需将此信息实际添加到抽象值或 jaxpr 类型中。在讨论实现之前,我们将首先使用类型来介绍这个概念。
接下来还将讨论使 API 对用户友好且向后兼容。但首先要介绍这个概念,我们将忽略便利性,而是编写尽可能明确的代码。
在 avals 中跟踪设备不变性(又名,带名称的 avals,复活)#
我们有时可以仅从静态信息中得知 shmap 主体中某些中间变量的值在网格轴上是保证不变的,即沿着网格轴的函数实例(及其对应的设备)必须使用相同的值进行计算。我们将这些值称为设备不变。对于不是设备不变的值,我们将称它们为设备可变的,尽管我们实际上是从类型系统的角度来表示潜在的设备可变。
为了在类型中编码设备可变性,我们将扩展数组类型的语法。我们将编写诸如 x:f32[3,4]{i} 之类的内容,以表示 x 沿着网格轴 i 是(潜在地)设备可变的(并且对于 shmap 的任何其他网格轴都是设备不变的)。更一般地说,我们将数组类型语法的语法定义为
shaped_array ::= <dtype>[<int_literal>, ...]<device_variance_type>
device_variance_type ::= {<axis_name>, ...}
我们还将更新类型规则以处理设备可变类型
对于集合通信以外的一阶原始操作
对于多参数原始操作,操作数设备可变类型必须在形状必须相等的任何地方相等,例如
mul x:f32[s1]{r1} y:f32[s2][r2]需要r1 == r2加上s1 == s2输出设备可变类型必须与操作数相同
对于高阶原始操作
我们只是实例化任何类型变量,包括设备可变类型(并且检查类型是否相等会检查它们的设备可变类型是否相等)
(在执行类型推断时,例如,对于
cond的分支,我们取设备可变类型轴名称集合的并集)
对于一阶集合通信
集合通信可以接受设备可变或设备不变的输入(沿着与其轴名称参数对应的网格轴);将设备不变的操作数传递给接受设备可变操作数的集合通信,反之亦然,是错误的
集合通信可以产生设备可变或设备不变的输出
请参阅下表。作为一个附带的好处,实现此类型检查的任何逻辑都可以包含
shmap的“静态分析”检查,以确定shmap主体函数是否与任何未映射的out_specs兼容。
这是一个总结集合通信原始操作设备可变类型 的表
名称 |
设备可变类型 |
示例 |
降低到 HLO |
转置 |
|---|---|---|---|---|
|
|
|
|
|
|
|
|
无操作(无通信) |
|
|
|
|
|
|
|
|
|
|
不适用 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
这里有一些令人惊讶的事情!
我们引入了几个新的原始操作,包括
pbroadcast,有趣的是它降低为无操作all_gather_invariant,它降低为与all_gather相同,但具有不同的设备可变类型(本质上all_gather融合了pbroadcast,而all_gather_invariant则没有)pscatter,它是all_gather_invariant的对偶(转置)
all_gather 的结果是设备可变的
直观地说,引入 pbroadcast(除了使类型规则正常工作之外)的原因是让 psum 可以转置为一个物理上的无操作。我们需要 all_gather 具有设备可变结果的原因是为了能够将其转置为 psum_scatter;如果我们而是将其保留为设备不变结果,我们可能需要一个下游的 pbroadcast,而该组合将转置为一个低效的 psum 后跟切片 / pscatter。因此,我们有一个 pbroadcast “融合到” all_gather 中,从而允许高效地转置为 psum_scatter。我们提供 all_gather_invariant 及其转置 pscatter 主要是为了完整性;用户不太可能需要它(它对应于示例 4 中的情况,该情况很容易通过在主体中使用 out_specs 来另写)。
有趣的是,psum 和 pbroadcast 转置对对应于用户在训练使用 pmap 的 LLM 时引入的 psum_idrev 和 id_psumrev。
该系统如何解决低效转置的示例#
再次考虑简化的激励示例
# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
in_specs=P('i'), out_specs=P())
# Example 1 with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1(x: f32[3,4]{i}):
w:f32[]{i} = g(x)
y:f32[]{} = psum(w, 'i')
return y
有了这些新规则,转置是
# Example 1 transpose using device variance types (go ahead and transpose this again!)
t(f1) = shmap(lambda ybar: t(g)(pbroadcast(ybar, 'i')),
in_specs=P(), out_specs=P('i'))
# Example 1 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1_transpose(ybar: f32[]):
wbar:f32[]{i} = pbroadcast(ybar, 'i')
xbar:f32[3,4]{i} = transpose(g)(wbar)
return xbar
其中评估 pbroadcast 应用不涉及任何通信或 FLOPs;它是一个无操作。请注意,如果我们继续转置,主体的大小不会增加;事实上 t(t(f1)) == f1。效率达成!
而且只要我们 pbroadcast 使类型检查在需要的地方通过,我们就不会弄乱其他示例
# Example 2 rewritten with explicit pbroadcast
f2 = shmap(lambda x, y: pbroadcast(psum(g(x), 'i'), 'i') * y,
in_specs=(P('i'), P('i')), out_specs=P('i'))
# Example 2 transpose using device variance types
t(f2, 0) = shmap(lambda y, zbar: t(g)(pbroadcast(psum(zbar * y, 'i'), 'i')),
in_specs=(P('i'), P('i')), out_specs=P('i'))
# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())
# Notice here the body is `f32[...] -> f32[...]`, i.e. no device varying type.
# Example 3 transpose using device variance types
t(cursed_identity) = shmap(lambda x: x, P(), P())
t(t(cursed_identity)) = shmap(lambda x: x, P(), P())
直观地说,在示例 1 中,我们现在只有“原始 psum 的一半”,而在示例 2 中,我们得到“两半”。对于示例 3,我们根本不需要在主体中进行任何操作。
对于 all_gather 示例,示例 4 需要使用 all_reduce_invariant 才能获得高效的转置(尽管最好是在主体中使用 out_specs 而不是集合通信)
# Example 4 rewritten with explicit all_reduce_invariant
f4 = shmap(lambda x: all_gather_invariant(x, 'i'), P('i'), P())
# Example 4 with intermediate device variance types annotated
@partial(shmap, P('i'), P())
def f4(x:f32[1]{i}):
y:f32[8]{} = all_gather_invariant(x, 'i')
return y
# Example 4 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P(), out_specs=P('i'))
def f4_transpose(ybar:f32[8]):
xbar:f32[1]{i} = pscatter(ybar, 'i')
return xbar
对于示例 5,使用设备可变的 all_gather 按我们期望的那样工作
# Example 5 with intermediate device variance types annotated
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5(x:f32[1]{i}, y:f32[8]{i}):
z:f32[8]{i} = all_gather(x, 'i')
w:f32[8]{i} = z * y
return w
# Transpose with respect to first argument
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5_transpose(y:f32[8]{i}, wbar:f32[8]{i}):
zbar:f32[8]{i} = wbar * y
xbar:f32[1]{i} = psum_scatter(zbar, 'i')
return xbar
如何使 API 对用户友好(并向后兼容)#
但是,哪个用户想编写 pbroadcast?哪个开发者想破坏大量涉及 psum 的现有用户代码,而这些 psum 没有被馈送到未映射的输出?反正不是我!
相反,我们可以自动插入 pbroadcast。这有点类似于我们在 jax.numpy 层进行自动秩提升,插入广播以避免二元运算符中的秩不匹配错误。但这要简单得多,因为我们不需要处理形状元组。典型的规则是:每当我们看到一个多参数操作,其操作数的设备可变类型不匹配时,就取操作数设备可变类型轴名称集的并集,并插入 pbroadcast 以将每个操作数提升到结果设备可变类型。
在需要之前自动插入 pbroadcast 可能意味着我们会对同一个操作数应用同一个 pbroadcast 多次,从而创建公共子表达式。转置时,这些可能会变成“psum 的和”而不是“psum 的和”。我们将依赖编译器进行适当的清理。如果这是个问题,我们可以为 pbroadcast 插入过程添加一些简单的记忆化。
all_gather 的用户 API 默认将是 all_gather_p(而不是 all_gather_invariant_p),涵盖了常见情况,并且意味着无需插入 pbroadcast。
我们可以为 shmap 提供一个选项来禁用此自动插入 pbroadcast,在这种情况下,用户需要自行确保类型正确性。这个显式选项可能会吸引一些希望明确反向传播中 psum 的位置的用户。
如何实现解决方案#
使实现轻量级的关键在于**我们不会将这些类型添加到 avals 或 jaxprs**。至少,一开始不会。这可能会很昂贵,因为它需要更新 JAX 的其余部分,例如所有 avals 和 jaxprs 的消费者可能需要处理新类型。我们不会再犯这种错误了!
相反,我们将这些扩展类型保留为 shmap 内部的元数据,就像当前的“ out_specs 的复制检查”机制是 shmap 内部的一样。事实上,这个解决方案只是对现有机制的一个相对较小的扩展:它已经在跟踪相同的信息;现在我们只是添加了 pbroadcast。
我们至少有两个选项可以在哪里执行 pbroadcast 插入
在转置之前,在转置规则中,其中我们有一个要转置的计算的 jaxpr;
在每个
shmap主体中,无论是即时执行还是暂存,就像当前的“out_specs的复制检查”机制一样。前者可能更容易,因为我们只需要处理 jaxpr 情况,并且只处理线性原始操作。但我们将首先尝试后者,以便这里的实现是现有复制检查逻辑的严格修订/扩展。
附录:定义和说明带有未映射输入和输出的映射#
为了具体起见,我们将主要关注 shmap,尽管这些相同的想法也适用于例如 pmap 和可能 xmap。
当 in_specs 的相应条目未提及某个网格轴的名称时,该参数/输入在该网格轴上是*未映射*的。从逻辑上讲,这意味着沿着该网格轴的每个函数实例都为该参数获取相同的值。对于调用者来说,每个操作数都根据操作数映射到的网格轴进行切片,而对于操作数未映射的网格轴则没有切片。
当 out_specs 的相应条目未提及某个网格轴的名称时,该输出在该网格轴上是*未映射*的。从逻辑上讲,这意味着沿着该网格轴的每个函数实例都必须返回相同的值。对于调用者来说,shmap 的每个结果是通过连接输出映射到的网格轴上的所有函数实例的返回值形成的,而对于输出未映射的网格轴,则只使用该值的一个副本。
请参阅 shmap JEP 以获取未映射输入和输出的示例。作为比较,在 vmap 中,未映射输入/输出是通过使用 None 的 in_axes / out_axes 来指示的(而不是 int)。
以下是我们喜欢 shmap 的未映射输入和输出的原因
与
pjit相同的表达能力。pjit可以做的任何事情,shmap逃生舱口也应该能够做到。否则,我们将有一个不完善的逃生舱口!如果我们不在shmap中使用未映射的输出,那么我们就无法像pjit那样表达相同的批处理并行损失函数计算。封闭输入。 封闭输入本质上对应于未映射输入,并且……
转置下的闭包。 一旦有了未映射的输入,自然就可以转置为未映射的输出。
因此,未映射的输出既是标准的又是非常有用的!