高效转置复制诱导的集合算子#

mattjj@, dougalm@

2023年8月

动机#

在自动转置包含特定集合算子(collectives)的 shmap 时,我们遇到了效率问题。该问题出现在 psumall_gather 中,特别是在集合算子的输出作为未映射输出返回给调用者时。这并非边缘情况:例如,当对基于 shmap 的批处理数据并行神经网络损失函数应用 grad 时,该损失函数使用 psum 来计算总损失,就会出现此问题。

我们早就知道这个问题。在 pmap 中也存在类似的问题,尽管可以通过将 grad 保留在 pmap 内部而不是外部来绕过它。之前未完成的“带名称的 aval”工作的首要目标就是解决这个转置效率问题的一个版本。本文档借鉴了这些想法,同时进行了扩展和修订,以处理更多情况并使其更容易落地。事实上,本文提出的解决方案仅影响 shmap 的实现,系统的其余部分(目前)无需更改。

本文档的主要目的是定义这一转置效率问题,并提出一个易于落地的解决方案。

本文档不涉及

  • 数组上的逻辑轴名称(此处的轴名称仅与 shmap 和原版 pmap 中相同);

  • 改变自动微分语义(所有数值和(非)错误保持不变,我们只是在提高效率);

  • 允许用户代码反映任何新信息,或者实际上影响用户代码。

问题:psumall_gather 的高效转置取决于余切在设备间是否保持不变#

考虑这个半现实的例子,旨在模拟一个复制参数的批处理数据并行损失函数

devices = jax.devices()  # 8 devices

@partial(shmap, mesh=Mesh(devices, ('batch',)),
         in_specs=(P(None, None), P('batch', None)),
         out_specs=P())
def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  local_loss = jnp.mean(jnp.sum(predictions - targets, -1))
  global_loss = lax.pmean(local_loss, 'batch'))
  return global_loss

注意 out_specs=P(),这表示一个未映射的输出。如果您不熟悉未映射输出的概念,请参阅本文档底部的附录。

loss 示例中的大部分细节并不重要。对我们而言,重要的是我们在最后应用了 psum(或者更准确地说是 pmean = lambda x, name: psum(x, name) / psum(1, name))。因此,一个精简版本如下所示

# Example 1: shmap involving psum and unmapped output with inefficient transpose
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())

我们甚至通过省略 mesh 参数简化了符号表示。在随后的示例中,它可以根据上下文推断出来。

转置看起来是什么样的?写 t 表示函数转置,我们可以通过应用下方的 ¿f1_transpose? 函数来高效地计算任意 ybart(f1)(ybar)

# An efficient "transpose" of Example 1 (but don't transpose this again!)
¿f1_transpose? = shmap(t(g), in_specs=P(), out_specs=P('i'))

但那不是我们目前作为 t(f1) 得到的转置。

相反,目前的转置方案大致是交换 in_specsout_specs,对未映射输出进行一些除法重缩放,并转置主体。由于 psum 是它自身的转置(作为一种全归约求和),我们最终产生了这样的转置

# The transpose we currently get for Example 1 (which is fine to transpose again)
t(f1) = shmap(lambda ybar: t(g)(psum(ybar / 8, 'i')),
              in_specs=P(), out_specs=P('i'))

这个转置计算出了正确的数值,但很浪费。我们从转置的 in_specs=P() 中静态地知道 ybar 对于每个函数实例具有相同的值,即其值对于名为 i 的 mesh 轴上的设备是设备不变的,然而我们却对它应用了 psum!这使用了昂贵的通信,仅仅是为了将每个设备上的值乘以 8。(这里 8 指的是轴 i 的大小。除以 8 来自原始函数的 out_specs=P();它和平凡的 psum 基本相互抵消了。)

我们做错了什么?我们没有利用 f1 未映射输出对应的余切 ybar 保证是设备不变的这一事实;相反,我们防御性地对它们进行了 psum,就好像它们不是不变的一样,因为 psum 的转置无法根据其拥有的局部信息确定这一点。有时 psum 是必要的,例如在相对于其第一个参数转置 f2

# Example 2: shmap involving psum and *mapped* output with efficient transpose
f2 = shmap(lambda x, y: psum(g(x), 'i') * y,
          in_specs=(P('i'), P('i')), out_specs=P('i'))

# The transpose we currently get for Example 2 is efficient
t(f2, 0) = shmap(lambda y, zbar: t(g)(psum(zbar * y, 'i')),
                in_specs=(P('i'), P('i')), out_specs=P('i'))

直观地讲,如果我们的转置机制能够区分示例 1 和示例 2,我们就可以通过尽可能避免 psum 和除法来做得更好。

低效示例甚至可以更小。考虑转置这个“被诅咒的”恒等函数

# Example 3: cursed identity
cursed_identity = shmap(lambda x: x, P(), P())

# Currently we get these inefficient transposes
t(cursed_identity) = shmap(lambda x: psum(x / 8, 'i'), P(), P())
t(t(cursed_identity)) = shmap(lambda x: psum(psum(x / 8 / 8, 'i'), 'i')), P(), P())
...

我们转置得越多,它就变得越大。真尴尬!

而且 psum 并不是唯一的罪魁祸首。all_gather 也存在类似的情况

# Example 4: all_gather to an unmapped output
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())

# Currently we get this inefficient transpose
t(f4) = shmap(lambda ybar: psum_scatter(ybar / 8, 'i'), P(), P('i'))

这个程序有点做作。为什么要执行 all_gather 并将结果输入到未映射输出中,而不是跳过主体中的 all_gather 并直接使用 out_specs=P('i') 来收集结果?尽管它是人为设计的,但这个示例展示了一个不必要地执行通信的转置(我们本可以只执行非通信的切片),类似于 psum 的示例 1。

同样类似于 psum 的示例,防御性的 psum_scatter 在某些情况下是必要的

# Example 5: all_gather to a mapped output
f5 = shmap(lambda x, y: all_gather(x, 'i') * y,
           in_specs=(P('i'), P('i')), out_specs=P('i'))

# Currently we get this efficient transpose
t(f5, 0) = shmap(lambda y, zbar: psum_scatter(zbar * y, 'i'),
                 in_specs=(P('i'), P('i')), out_specs=P('i'))

那么我们如何避免这些低效的转置呢?

解决方案#

这里有两个解决方案思路。它们并非互斥。但(剧透一下)第二个方案更好,而且这就是我们所需要的。

部分解决方案“P-sum”:构建将 psum 表示到 out_specs 中的能力#

这个方案有点像稻草人,因为它只会提供一种笨拙的编程方式。而且它甚至不能解决所有问题!但值得考虑,即使只是为了激励一个更完整的解决方案。

上面的示例 4 是人为设计的,因为我们本可以在主体中使用 out_specs 而不是 all_gather

# Example 4 again
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())

# Why didn't we just write it like this?
f4_better = shmap(lambda x: x, P('i'), P('i'))

f4_better 版本没有任何转置问题,因为转置问题源于主体中的集合算子。

类似地,我们可以通过扩展 out_specs 使其能够表达求和来修复示例 1

# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())

# What if we could write an output sum like this?
f1_better = shmap(g, in_specs=P('i'), out_specs=P(sum='i'))  # sum='i' means sum over that axis

# Then it could transpose like this:
t(f1_better) = shmap(t(g), in_specs=P(), out_specs=P('i'))
t(t(f1_better)) = shmap(t(t(g)), in_specs=P('i'), P(sum='i'))

因此,提供内置于 out_specs 中的 psum 修复了示例 1 的转置问题。但它并没有完全修复示例 3 中被诅咒的恒等转置

# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())

# How it would transpose with the P-sum partial solution:
t(cursed_identity) = shmap(lambda x: x / 8, P(), P(sum='i'))
t(t(cursed_identity)) = shmap(lambda x: x / 8, P(), P(sum='i'))

这是一种改进,因为程序不会随着我们不断转置而继续变大,但我们仍然在进行浪费的通信。

完整解决方案:静态追踪设备间变化与设备间不变的中间量,并引入新的原语#

该解决方案有两个组成部分

  1. 追踪数值在特定 mesh 轴上何时保证是设备不变的与设备变化的,以及

  2. psum 分解为两步过程,引入一个新的 pbroadcast 原语,并为 all_gather 及其转置引入新的原语。

在逻辑上,设备不变与设备变化信息的追踪属于类型层面的考虑。但为了我们首次实现的权宜之计,我们不需要字面上将这些信息添加到抽象值或 jaxpr 类型中。在我们进入实现之前,我们将首先使用类型引入这个想法。

接下来还将讨论如何使用户 API 方便且向后兼容。但为了首先引入这个想法,我们将忽略方便性,转而编写尽可能显式的代码。

在 aval(抽象值)中追踪设备不变性(即“带名称的 aval”,已复活)#

我们有时仅凭静态信息就能判断 shmap 主体中某些中间变量的值被保证沿某个 mesh 轴是不变的,即该 mesh 轴上的函数实例(及其对应的设备)必须都在使用相同的值进行计算。我们将此类值称为设备不变的。对于非设备不变的值,我们称其为设备变化的,尽管实际上在类型系统的观点看来,我们指的是潜在地设备变化的。

为了在类型中编码设备方差,我们将扩展数组类型的语法。我们将编写类似于 x:f32[3,4]{i} 的内容,以表示 x 沿 mesh 轴 i 是(潜在)设备变化的(并且在 shmap 的任何其他 mesh 轴上是设备不变的)。更一般地,我们将数组类型语法的语法规则描述为类似

shaped_array ::= <dtype>[<int_literal>, ...]<device_variance_type>
device_variance_type ::= {<axis_name>, ...}

我们还将更新类型规则以处理设备方差类型

  • 对于除集合算子之外的一阶原语

    • 对于多参数原语,操作数的设备方差类型必须在形状必须相等的地方相等,例如 mul x:f32[s1]{r1} y:f32[s2][r2] 除了 s1 == s2 外还要求 r1 == r2

    • 输出的设备方差类型必须与操作数相同

  • 对于高阶原语

    • 我们只是实例化任何类型变量,包括设备方差类型(检查类型相等性时会检查其设备方差类型是否相等)

    • (执行类型推断时,例如对于 cond 的分支,我们取设备方差类型中轴名称集的并集)

  • 对于一阶集合算子

    • 集合算子可以接受设备变化的或设备不变的输入(沿其轴名称参数对应的 mesh 轴);将设备不变的操作数传递给接受设备变化操作数的集合算子(反之亦然)是错误的

    • 集合算子可以产生设备变化的或设备不变的输出

    • 参见下表。作为一个附带的好处,实现此类型检查的任何逻辑都可以包含 shmap 关于 shmap 主体函数是否与任何未映射 out_specs 兼容的“静态分析”检查。

下表总结了集合算子原语的设备方差类型

名称

设备方差类型

示例

Lower 至 HLO

转置

psum2

Varying -> Invariant

y:f32[3]{j} = psum(x:f32[3]{i,j}, axis='i')

AllReduceSum (通信)

pbroadcast

pbroadcast

Invariant -> Varying

y:f32[3]{i} = pbroadcast(x:f32[3], 'i')

no-op (无通信)

psum

all_to_all

Varying -> Varying

y:f32[16]{i} = all_to_all(x:f32[16]{i}, 'i', 0, 0) AllToAll (通信)

all_to_all

axis_index

() -> Varying

idx:i32[]{i} = axis_index('i')

ReplicaId 和一些算术 (无通信)

不适用

psum_scatter

Varying -> Varying

y:f32[2]{i} = psum_scatter(x:f32[16]{i}, 'i')

ReduceScatterSum (通信)

all_gather

all_gather

Varying -> Varying

y:f32[16]{i} = all_gather(x:f32[2]{i}, 'i')

AllGather (通信)

psum_scatter

pscatter

Invariant -> Varying

y:f32[2]{i} = pscatter(x:f32[16], 'i')

lambda x: x[axis_index('i'), None] (无通信)

all_gather_invariant

all_gather_invariant

Varying -> Invariant

y:f32[16] = all_gather_invariant(x:f32[2]{i}, 'i')

AllGather (通信)

pscatter

这里有一些令人惊讶的事情!

  • 我们引入了几个新的原语,包括

    • pbroadcast,有趣的是它 lower 到一个 no-op

    • all_gather_invariant,它 lower 到与 all_gather 相同的东西,但具有不同的设备方差类型(本质上 all_gather 有一个融合进其中的 pbroadcast,而 all_gather_invariant 没有)

    • pscatter,它是 all_gather_invariant 的对偶(转置)

  • all_gather 有一个设备变化的结果

直观地讲,引入 pbroadcast 的原因(除了让类型规则起作用之外)是使得 psum 可以转置为一个物理上的 no-op。我们需要 all_gather 具有设备变化结果的原因是这样我们可以将其转置为 psum_scatter;如果我们改为让它具有设备不变的结果,我们可能需要一个下游的 pbroadcast,而该组合会转置为一个低效的 psum 后跟切片 / pscatter。因此,我们让一个 pbroadcast “融合进” all_gather 中,从而允许高效转置为 psum_scatter。我们主要为了完整性提供了 all_gather_invariant 及其转置 pscatter;用户不太可能需要它(它对应于示例 4 中的情况,使用 out_specs 可以更容易地以不同方式编写)。

有趣的是,psumpbroadcast 转置对对应于用户在使用 pmap 训练 LLM 时引入的 psum_idrevid_psumrev

该系统如何解决低效转置示例#

再次考虑简化的动机示例

# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())

# Example 1 with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1(x: f32[3,4]{i}):
  w:f32[]{i} = g(x)
  y:f32[]{} = psum(w, 'i')
  return y

有了这些新规则,转置是

# Example 1 transpose using device variance types (go ahead and transpose this again!)
t(f1) = shmap(lambda ybar: t(g)(pbroadcast(ybar, 'i')),
              in_specs=P(), out_specs=P('i'))

# Example 1 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1_transpose(ybar: f32[]):
  wbar:f32[]{i} = pbroadcast(ybar, 'i')
  xbar:f32[3,4]{i} = transpose(g)(wbar)
  return xbar

其中评估 pbroadcast 应用不涉及任何通信或 FLOPs;它是一个 no-op。请注意,如果我们继续转置,主体大小不会增长;事实上 t(t(f1)) == f1。实现了效率!

而且我们也不会弄乱其他示例,只要我们通过 pbroadcast 使类型在需要的地方检查通过即可

# Example 2 rewritten with explicit pbroadcast
f2 = shmap(lambda x, y: pbroadcast(psum(g(x), 'i'), 'i') * y,
           in_specs=(P('i'), P('i')), out_specs=P('i'))

# Example 2 transpose using device variance types
t(f2, 0) = shmap(lambda y, zbar: t(g)(pbroadcast(psum(zbar * y, 'i'), 'i')),
                 in_specs=(P('i'), P('i')), out_specs=P('i'))


# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())
# Notice here the body is `f32[...] -> f32[...]`, i.e. no device varying type.

# Example 3 transpose using device variance types
t(cursed_identity) = shmap(lambda x: x, P(), P())
t(t(cursed_identity)) = shmap(lambda x: x, P(), P())

直观地讲,在示例 1 中,我们现在只有“原始 psum 的一半”,而在示例 2 中,我们得到了两个“一半”。对于示例 3,我们根本不需要主体中的任何操作。

对于 all_gather 示例,示例 4 需要使用 all_reduce_invariant 才能具有高效的转置(尽管最好使用 out_specs 代替主体中的集合算子)

# Example 4 rewritten with explicit all_reduce_invariant
f4 = shmap(lambda x: all_gather_invariant(x, 'i'), P('i'), P())

# Example 4 with intermediate device variance types annotated
@partial(shmap, P('i'), P())
def f4(x:f32[1]{i}):
  y:f32[8]{} = all_gather_invariant(x, 'i')
  return y

# Example 4 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P(), out_specs=P('i'))
def f4_transpose(ybar:f32[8]):
  xbar:f32[1]{i} = pscatter(ybar, 'i')
  return xbar

对于示例 5,使用设备变化的 all_gather 可以按我们期望的方式工作

# Example 5 with intermediate device variance types annotated
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5(x:f32[1]{i}, y:f32[8]{i}):
  z:f32[8]{i} = all_gather(x, 'i')
  w:f32[8]{i} = z * y
  return w

# Transpose with respect to first argument
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5_transpose(y:f32[8]{i}, wbar:f32[8]{i}):
  zbar:f32[8]{i} = wbar * y
  xbar:f32[1]{i} = psum_scatter(zbar, 'i')
  return xbar

如何使 API 对用户更方便(且向后兼容)#

但是,哪个用户想编写 pbroadcast?又有哪个开发者想破坏大量现有涉及 psum(且没有输入到未映射输出中)的用户代码?我可不想!

相反,我们可以自动插入 pbroadcast。这有点类似于我们在 jax.numpy 层执行自动秩提升的方式,插入广播以避免二元运算符中的秩不匹配错误。但它简单得多,因为我们不需要处理形状元组。典型的规则是:每当我们看到多参数操作且操作数的设备方差类型不一致时,取操作数设备方差类型轴名称集的并集,并插入 pbroadcast 以将每个操作数提升到最终的设备方差类型。

在需要的地方之前自动插入 pbroadcast 可能意味着我们将相同的 pbroadcast 应用于相同的操作数多次,从而创建公共子表达式。当我们转置时,这些可能会变成 psum 之和,而不是求和之 psum。我们将依靠编译器适当地清理它。如果这是一个问题,我们可以向 pbroadcast 插入过程中添加一些简单的记忆化。

all_gather 的用户 API 默认将意味着 all_gather_p(而非 all_gather_invariant_p),覆盖常见情况,这意味着无需插入 pbroadcast

我们可以在 shmap 上提供一个选项来禁用这种自动插入 pbroadcast 的行为,在这种情况下,用户需自行确保类型正确。对于那些希望明确 psum 在反向传播中发生位置的用户,此显式选项可能很有吸引力。

如何实现该方案#

使实现轻量级的关键是我们不会将这些类型添加到 aval 或 jaxpr 中。至少最初不会。这可能会很昂贵,因为它需要更新 JAX 的其余部分,例如所有 aval 和 jaxpr 的消费者可能都需要处理新类型。我们不会再上当了!

相反,我们将保持这些扩展类型作为 shmap 内部的元数据,就像当前“out_specs 的复制检查”机制是 shmap 内部的一样。事实上,此解决方案是对现有机制的相对较小的扩展:它已经在追踪相同的信息;现在我们只是添加了 pbroadcast

关于在哪里执行 pbroadcast 插入,我们至少有两个选择

  1. 在转置规则中,就在转置之前,此时我们拥有要转置的计算的 jaxpr;

  2. 在每个 shmap 主体中,无论是急切执行还是分阶段执行,就像当前的“out_specs 的复制检查”机制一样。前者可能更容易,因为我们只需要处理 jaxpr 情况,并且只需要线性原语。但我们将从尝试后者开始,以便此处的实现是对现有复制检查逻辑的严格修订/扩展。

附录:定义并论证带有未映射输入和输出的映射#

为了具体起见,我们主要关注 shmap,尽管这些相同的想法也适用于例如 pmap 和可能的 xmap

in_specs 的相应条目没有提及该 mesh 轴的名称时,参数/输入沿该 mesh 轴是未映射的。从逻辑上讲,这意味着沿该 mesh 轴的每个函数实例都会获得该参数的相同值。对于调用者而言,每个操作数都根据其映射到的 mesh 轴进行切片,而对于操作数未映射的 mesh 轴,则没有切片。

out_specs 的相应条目没有提及该 mesh 轴的名称时,输出沿该 mesh 轴是未映射的。从逻辑上讲,这意味着沿该 mesh 轴的每个函数实例必须返回相同的值。对于调用者而言,shmap 的每个结果都是通过拼接输出映射到的每个函数实例的返回值形成的,而对于输出未映射的 mesh 轴,只使用该值的一个副本。

请参阅 shmap JEP 获取未映射输入和输出的示例。作为比较,在 vmap 中,未映射的输入/输出通过使用 None(而不是 int)的 in_axes / out_axes 来指示。

以下是我们喜欢 shmap 的未映射输入和输出的原因

  • pjit 具有相同的表达能力。 pjit 能做的任何事情,shmap 逃生舱(escape hatch)也应该能够做到。否则我们就会有一个功能不足的逃生舱!如果 shmap 中没有未映射输出,我们就无法表达与 pjit 相同的批处理并行损失函数计算。

  • 闭包捕获的输入。 闭包捕获的输入本质上对应于未映射的输入,并且……

  • 转置下的闭包。 一旦我们有了未映射的输入,能够转置为未映射的输出就很自然了。

所以未映射的输出既是规范的又是有用的!