复制诱导集体操作的有效转置#

mattjj@, dougalm@

2023 年 8 月

动机#

在自动转置包含某些集体操作的 `shmap` 时,我们遇到了效率问题。这个问题出现在 `psum` 和 `all_gather` 中,特别是当集体操作的输出作为未映射输出返回给调用者时。这并非边缘情况:例如,当对基于 `shmap` 的批数据并行神经网络损失函数应用 `grad`,并且该函数使用 `psum` 计算总损失时,就会出现此问题。

我们早就知道这个问题。`pmap` 也存在类似问题,尽管通过将 `grad` 保留在 `pmap` 内部而不是外部来解决。未完成的 avals-with-names 工作的一个主要目标是解决这个转置效率问题的一个版本。本文档借鉴了这些想法,并对其进行了扩展和修订,以处理更多情况并使其更容易实现。实际上,这里提出的解决方案只影响 `shmap` 的实现。系统的其他部分无需更改(暂不更改)。

本文档的主要目的是定义这个转置效率问题,并提出一个易于实现的解决方案。

本文档不涉及

  • 数组上的逻辑轴名称(这里唯一提及的轴名称与 `shmap` 和原始 `pmap` 中的相同);

  • 更改自动微分语义(所有数值和(非)错误都保持不变,我们只是在提高效率);

  • 允许用户代码反映任何新信息,或者实际影响用户代码。

问题:`psum` 或 `all_gather` 的高效转置取决于余切(cotangent)在设备之间是否不变#

考虑这个半真实示例,它旨在模拟一个复制参数的批数据并行损失函数。

devices = jax.devices()  # 8 devices

@partial(shmap, mesh=Mesh(devices, ('batch',)),
         in_specs=(P(None, None), P('batch', None)),
         out_specs=P())
def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  local_loss = jnp.mean(jnp.sum(predictions - targets, -1))
  global_loss = lax.pmean(local_loss, 'batch'))
  return global_loss

注意 `out_specs=P()`,它表示一个未映射的输出。如果您不熟悉未映射输出的概念,请参阅本文档底部的附录。

`loss` 示例中的大部分细节并不重要。对我们而言,重要的是我们在最后应用了 `psum`(或者更确切地说,是 `pmean = lambda x, name: psum(x, name) / psum(1, name)`)。因此,简化版本如下所示:

# Example 1: shmap involving psum and unmapped output with inefficient transpose
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())

我们甚至通过省略 `mesh` 参数来简化了表示法。在接下来的示例中,可以从上下文中推断出来。

转置是什么样子?用 `t` 表示函数转置,我们可以通过应用下面的函数 `¿f1_transpose?` 来高效地计算任何 `ybar` 的 `t(f1)(ybar)`:

# An efficient "transpose" of Example 1 (but don't transpose this again!)
¿f1_transpose? = shmap(t(g), in_specs=P(), out_specs=P('i'))

但这并非我们目前作为 t(f1) 得到的转置。

相反,当前的转置方法大致是交换 `in_specs` 和 `out_specs`,对未映射输出进行一些除法重缩放,然后转置函数体。由于 `psum` 是其自身的转置(作为全规约求和),我们最终会得到这个转置:

# The transpose we currently get for Example 1 (which is fine to transpose again)
t(f1) = shmap(lambda ybar: t(g)(psum(ybar / 8, 'i')),
              in_specs=P(), out_specs=P('i'))

这种转置得到的结果是正确的,但它很浪费。我们从转置的 `in_specs=P()` 中静态地知道,`ybar` 对于每个函数实例都具有相同的值,即对于沿着名为 `i` 的网格轴上的设备,其值是设备不变的,但我们仍然对其应用了 `psum`!这会使用昂贵的通信,仅仅是为了将每个设备上的值乘以 8。(这里的 8 指的是轴 i 的大小。除以 8 来自原始函数的 `out_specs=P()`;它和这个简单的 `psum` 基本相互抵消了。)

我们做错了什么?我们没有利用这样一个事实:与 `f1` 的未映射输出对应的余切 `ybar` 保证是设备不变的;相反,我们防御性地对它们进行了 `psum` 求和,就好像它们不是设备不变一样,因为 `psum` 的转置在仅有本地信息的情况下无法确定。有时 `psum` 是必要的,例如在转置 `f2` 相对于其第一个参数时:

# Example 2: shmap involving psum and *mapped* output with efficient transpose
f2 = shmap(lambda x, y: psum(g(x), 'i') * y,
          in_specs=(P('i'), P('i')), out_specs=P('i'))

# The transpose we currently get for Example 2 is efficient
t(f2, 0) = shmap(lambda y, zbar: t(g)(psum(zbar * y, 'i')),
                in_specs=(P('i'), P('i')), out_specs=P('i'))

直观上,如果我们的转置机制能够区分示例 1 和示例 2,我们就可以通过尽可能避免 `psum` 和除法来做得更好。

低效的例子甚至可以更小。考虑转置这个“被诅咒”的恒等函数:

# Example 3: cursed identity
cursed_identity = shmap(lambda x: x, P(), P())

# Currently we get these inefficient transposes
t(cursed_identity) = shmap(lambda x: psum(x / 8, 'i'), P(), P())
t(t(cursed_identity)) = shmap(lambda x: psum(psum(x / 8 / 8, 'i'), 'i')), P(), P())
...

转置越多,它就变得越大。多么尴尬!

`psum` 并非唯一的罪魁祸首。类似的情况也适用于 `all_gather`:

# Example 4: all_gather to an unmapped output
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())

# Currently we get this inefficient transpose
t(f4) = shmap(lambda ybar: psum_scatter(ybar / 8, 'i'), P(), P('i'))

这个程序有点人为。为什么要做一个 `all_gather` 并将结果输入到未映射的输出中,而不是跳过函数体中的 `all_gather` 并直接使用 `out_specs=P('i')` 来收集结果呢?但即使它是人为设计的,这个示例仍然展示了一种不必要地执行通信的转置(我们本可以只执行一个非通信的切片),类似于 `psum` 的示例 1。

与 `psum` 示例类似,在某些情况下,防御性的 `psum_scatter` 是必要的:

# Example 5: all_gather to a mapped output
f5 = shmap(lambda x, y: all_gather(x, 'i') * y,
           in_specs=(P('i'), P('i')), out_specs=P('i'))

# Currently we get this efficient transpose
t(f5, 0) = shmap(lambda y, zbar: psum_scatter(zbar * y, 'i'),
                 in_specs=(P('i'), P('i')), out_specs=P('i'))

那么我们如何避免这些低效的转置呢?

解决方案#

这里有两个解决方案的想法。它们并非相互排斥。但(剧透)第二个更好,而且它就是我们所需要的全部。

部分解决方案“P-sum”:构建在 `out_specs` 中表达 `psum` 的能力#

这个解决方案有点“稻草人”,因为它只提供了一种笨拙的编写程序的方式。而且它甚至无法解决所有问题!但它值得考虑,哪怕只是为了激发一个更完整的解决方案。

上述示例 4 是人为的,因为我们本可以在函数体中使用 `out_specs` 而不是 `all_gather`。

# Example 4 again
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())

# Why didn't we just write it like this?
f4_better = shmap(lambda x: x, P('i'), P('i'))

`f4_better` 版本没有任何转置问题,因为转置问题源于函数体中的集体操作。

类似地,我们可以通过扩展 `out_specs` 来修复示例 1,使其能够表达求和操作:

# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())

# What if we could write an output sum like this?
f1_better = shmap(g, in_specs=P('i'), out_specs=P(sum='i'))  # sum='i' means sum over that axis

# Then it could transpose like this:
t(f1_better) = shmap(t(g), in_specs=P(), out_specs=P('i'))
t(t(f1_better)) = shmap(t(t(g)), in_specs=P('i'), P(sum='i'))

因此,在 `out_specs` 中内置 `psum` 解决了示例 1 的转置问题。但它不能完全修复示例 3 中“被诅咒”的恒等转置:

# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())

# How it would transpose with the P-sum partial solution:
t(cursed_identity) = shmap(lambda x: x / 8, P(), P(sum='i'))
t(t(cursed_identity)) = shmap(lambda x: x / 8, P(), P(sum='i'))

这是一个改进,因为程序在持续转置时不会继续变大,但我们仍然在进行浪费的通信。

完整解决方案:静态跟踪设备可变与设备不变的中间值,并引入新的原语#

这个解决方案包含两个组成部分:

  1. 跟踪值在特定网格轴上何时保证是设备不变的,何时是设备可变的,以及

  2. 将 `psum` 分解为两步过程,引入新的 `pbroadcast` 原语,并为 `all_gather` 及其转置引入新的原语。

从概念上讲,跟踪设备不变与设备可变信息是一个类型层面的考量。但为了我们首次实现的速度,我们不需要字面上将这些信息添加到抽象值或 jaxpr 类型中。在我们开始实现之前,我们将首先使用类型来介绍这个想法。

接下来还将讨论如何让用户 API 方便易用并向后兼容。但为了首先介绍这个想法,我们将忽略便利性,转而编写尽可能明确的代码。

在 avals 中跟踪设备不变性(又称 avlas-with-names,已复活)#

有时我们仅从静态信息就可以判断,`shmap` 函数体中某些中间变量的值在沿特定网格轴上是保证不变的,这意味着沿该网格轴的函数实例(及其相应的设备)都必须使用相同的值进行计算。我们将这些值称为设备不变的。对于非设备不变的值,我们称它们为设备可变的,尽管从类型系统的角度来看,我们实际指的是潜在的设备可变。

为了在类型中编码设备方差,我们将扩展数组的类型语法。我们将编写诸如 `x:f32[3,4]{i}` 之类的形式,以指示 `x` 沿网格轴 `i` (潜在地)是设备可变的(而在 `shmap` 的任何其他网格轴上是设备不变的)。更一般地说,我们将数组类型语法的语法定义为类似:

shaped_array ::= <dtype>[<int_literal>, ...]<device_variance_type>
device_variance_type ::= {<axis_name>, ...}

我们还将更新类型规则以处理设备方差类型:

  • 对于非集体操作的一阶原语

    • 对于多参数原语,操作数设备方差类型必须在形状相等的地方相等,例如 `mul x:f32[s1]{r1} y:f32[s2][r2]` 要求除了 `s1 == s2` 之外,还要求 `r1 == r2`。

    • 输出设备方差类型必须与操作数相同

  • 对于高阶原语

    • 我们只需实例化任何类型变量,包括设备方差类型(并检查类型相等性以确认其设备方差类型是否相等)

    • (在执行类型推断时,例如对于 `cond` 的分支,我们取设备方差类型中轴名称集合的并集)

  • 对于一阶集体操作

    • 一个集体操作可以接受设备可变或设备不变的输入(沿与其轴名称参数对应的网格轴);将设备不变的操作数传递给接受设备可变操作数的集体操作,反之亦然,都是错误的。

    • 一个集体操作可以产生设备可变或设备不变的输出

    • 见下表。作为附带好处,实现此类型检查的任何逻辑都可以包含 `shmap` 的“静态分析”检查,即检查 `shmap` 函数体是否与任何未映射的 `out_specs` 兼容。

下表总结了集体原语的设备方差类型:

名称

设备方差类型

示例

转换为 HLO

转置

psum2

可变 -> 不变

y:f32[3]{j} = psum(x:f32[3]{i,j}, axis='i')

AllReduceSum (通信)

pbroadcast

pbroadcast

不变 -> 可变

y:f32[3]{i} = pbroadcast(x:f32[3], 'i')

无操作(无通信)

psum

all_to_all

可变 -> 可变

y:f32[16]{i} = all_to_all(x:f32[16]{i}, 'i', 0, 0) AllToAll (通信)

all_to_all

axis_index

() -> 可变

idx:i32[]{i} = axis_index('i')

ReplicaId 和一些算术运算(无通信)

不适用

psum_scatter

可变 -> 可变

y:f32[2]{i} = psum_scatter(x:f32[16]{i}, 'i')

ReduceScatterSum (通信)

all_gather

all_gather

可变 -> 可变

y:f32[16]{i} = all_gather(x:f32[2]{i}, 'i')

AllGather (通信)

psum_scatter

pscatter

不变 -> 可变

y:f32[2]{i} = pscatter(x:f32[16], 'i')

lambda x: x[axis_index('i'), None] (无通信)

all_gather_invariant

all_gather_invariant

可变 -> 不变

y:f32[16] = all_gather_invariant(x:f32[2]{i}, 'i')

AllGather (通信)

pscatter

这里有一些令人惊讶的地方!

  • 我们引入了几个新的原语,包括:

    • `pbroadcast`,有趣的是,它会转换为一个无操作

    • `all_gather_invariant`,它会转换为与 `all_gather` 相同的东西,但具有不同的设备方差类型(本质上,`all_gather` 融合了 `pbroadcast`,而 `all_gather_invariant` 则没有)

    • `pscatter`,它是 `all_gather_invariant` 的对偶(转置)

  • `all_gather` 产生一个设备可变的结果

直观上,引入 `pbroadcast` 的原因(除了让类型规则生效)是为了让 `psum` 能够转置为一个物理上的无操作。我们需要 `all_gather` 产生设备可变结果的原因是,这样我们就可以将其转置为 `psum_scatter`;如果我们改为让它产生设备不变的结果,我们可能需要一个下游的 `pbroadcast`,而这种组合会转置为一个低效的 `psum`,然后是切片 / `pscatter`。因此,我们不是这样做,而是将 `pbroadcast` “融合”到 `all_gather` 中,从而允许高效地转置为 `psum_scatter`。我们提供 `all_gather_invariant` 及其转置 `pscatter` 主要是为了完整性;用户不太可能需要它(它对应于示例 4 中的情况,该情况很容易使用 `out_specs` 以不同方式编写)。

有趣的是,`psum` 和 `pbroadcast` 转置对与用户在使用 `pmap` 训练 LLM 时引入的 `psum_idrev` 和 `id_psumrev` 相对应。

该系统如何解决低效转置示例#

再次考虑简化的动机示例:

# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())

# Example 1 with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1(x: f32[3,4]{i}):
  w:f32[]{i} = g(x)
  y:f32[]{} = psum(w, 'i')
  return y

有了这些新规则,转置是:

# Example 1 transpose using device variance types (go ahead and transpose this again!)
t(f1) = shmap(lambda ybar: t(g)(pbroadcast(ybar, 'i')),
              in_specs=P(), out_specs=P('i'))

# Example 1 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1_transpose(ybar: f32[]):
  wbar:f32[]{i} = pbroadcast(ybar, 'i')
  xbar:f32[3,4]{i} = transpose(g)(wbar)
  return xbar

其中,评估 `pbroadcast` 应用不涉及任何通信或浮点运算;它是一个无操作。请注意,如果我们持续转置,函数体的大小不会增加;事实上,`t(t(f1)) == f1`。效率已实现!

而且我们也不会搞砸其他示例,只要我们在需要的地方使用 `pbroadcast` 来进行类型检查。

# Example 2 rewritten with explicit pbroadcast
f2 = shmap(lambda x, y: pbroadcast(psum(g(x), 'i'), 'i') * y,
           in_specs=(P('i'), P('i')), out_specs=P('i'))

# Example 2 transpose using device variance types
t(f2, 0) = shmap(lambda y, zbar: t(g)(pbroadcast(psum(zbar * y, 'i'), 'i')),
                 in_specs=(P('i'), P('i')), out_specs=P('i'))


# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())
# Notice here the body is `f32[...] -> f32[...]`, i.e. no device varying type.

# Example 3 transpose using device variance types
t(cursed_identity) = shmap(lambda x: x, P(), P())
t(t(cursed_identity)) = shmap(lambda x: x, P(), P())

直观上,在示例 1 中,我们现在只有“原始 `psum` 的一半”,而在示例 2 中,我们得到了两个“一半”。对于示例 3,函数体中根本不需要任何操作。

对于 `all_gather` 示例,示例 4 将需要使用 `all_reduce_invariant` 才能进行高效转置(尽管最好是使用 `out_specs` 而不是函数体中的集体操作)

# Example 4 rewritten with explicit all_reduce_invariant
f4 = shmap(lambda x: all_gather_invariant(x, 'i'), P('i'), P())

# Example 4 with intermediate device variance types annotated
@partial(shmap, P('i'), P())
def f4(x:f32[1]{i}):
  y:f32[8]{} = all_gather_invariant(x, 'i')
  return y

# Example 4 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P(), out_specs=P('i'))
def f4_transpose(ybar:f32[8]):
  xbar:f32[1]{i} = pscatter(ybar, 'i')
  return xbar

对于示例 5,使用设备可变的 `all_gather` 达到了我们想要的效果:

# Example 5 with intermediate device variance types annotated
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5(x:f32[1]{i}, y:f32[8]{i}):
  z:f32[8]{i} = all_gather(x, 'i')
  w:f32[8]{i} = z * y
  return w

# Transpose with respect to first argument
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5_transpose(y:f32[8]{i}, wbar:f32[8]{i}):
  zbar:f32[8]{i} = wbar * y
  xbar:f32[1]{i} = psum_scatter(zbar, 'i')
  return xbar

如何让 API 对用户友好(并向后兼容)#

但有哪个用户愿意编写 `pbroadcast` 呢?又有哪个开发者愿意破坏大量涉及未输入到未映射输出的 `psum` 的现有用户代码呢?反正我不想!

相反,我们可以自动插入 `pbroadcast`。这有点类似于我们在 `jax.numpy` 层进行自动秩提升的方式,即插入广播以避免二元运算符中的秩不匹配错误。但这要简单得多,因为我们不需要处理形状元组。典型规则是:每当我们看到一个多元操作,其操作数在设备方差类型上不一致时,就取操作数设备方差类型轴名称集合的并集,并插入 `pbroadcast` 将每个操作数提升到结果设备方差类型。

在需要 `pbroadcast` 之前自动插入它们,可能意味着我们多次对相同的操作数应用相同的 `pbroadcast`,从而创建了公共子表达式。当我们进行转置时,这些可能会变成“`psum` 之和”,而不是“和之 `psum`”。我们将依赖编译器进行适当清理。如果这是一个问题,那么我们可以在 `pbroadcast` 插入阶段添加一些简单的记忆化功能。

`all_gather` 的用户 API 默认将意味着 `all_gather_p`(而非 `all_gather_invariant_p`),涵盖了常见情况,意味着无需插入 `pbroadcast`。

我们可以在 `shmap` 上提供一个选项来禁用 `pbroadcast` 的这种自动插入,在这种情况下,将由用户来确保类型正确性。对于那些希望明确 `psum` 在反向传播中出现位置的用户来说,这个明确的选项可能很有吸引力。

如何实现该解决方案#

实现轻量化的关键在于,我们不会将这些类型添加到 avals 或 jaxprs 中。至少,一开始不会。这可能代价高昂,因为它需要更新 JAX 的其他部分,例如所有 avals 和 jaxprs 的消费者可能需要处理新类型。我们不会再上当了!

相反,我们将把这些扩展类型作为 `shmap` 内部的元数据保留,就像当前 `shmap` 内部的“`out_specs` 的复制检查”机制一样。事实上,这个解决方案只是对现有机制的一个相对较小的扩展:它已经在跟踪相同的信息;现在我们只是添加了 `pbroadcast`。

我们至少有两种选择来执行 `pbroadcast` 插入:

  1. 就在转置之前,在转置规则中,我们有一个要转置的计算的 jaxpr;

  2. 在每个 `shmap` 函数体中,无论是即时执行还是阶段性执行,都像当前“`out_specs` 的复制检查”机制一样。前者可能最终会更容易,因为我们只需要处理 jaxpr 的情况,并且只处理线性原语。但我们将首先尝试后者,这样这里的实现是对现有复制检查逻辑的严格修订/扩展。

附录:定义和阐述具有未映射输入和输出的映射#

为了具体起见,我们主要关注 `shmap`,尽管这些相同的思想也适用于 `pmap` 和 `xmap` 等。

当 `in_specs` 的相应条目未提及网格轴名称时,沿着该网格轴的参数/输入是未映射的。逻辑上,这意味着沿该网格轴的每个函数实例都获取参数的相同值。对于调用者来说,每个操作数都根据其映射的网格轴进行切片,而对于未映射的网格轴则不进行切片。

当 `out_specs` 的相应条目未提及网格轴名称时,沿着该网格轴的输出是未映射的。逻辑上,这意味着沿该网格轴的每个函数实例都必须返回相同的值。对于调用者来说,`shmap` 的每个结果都是通过连接所有输出映射的函数实例的返回值形成的,而对于输出未映射的网格轴,则只使用值的一个副本。

有关未映射输入和输出的示例,请参阅 `shmap` JEP。作为比较,在 `vmap` 中,未映射的输入/输出通过使用 `in_axes` / `out_axes` 为 `None`(而不是整数)来指示。

以下是我们喜欢 `shmap` 的未映射输入和输出的原因:

  • 与 `pjit` 具有相同的表达能力。 任何 `pjit` 能做的事情,`shmap` 的“紧急出口”也应该能做到。否则我们的“紧急出口”就不完善了!如果 `shmap` 中没有未映射的输出,那么我们就无法表达与 `pjit` 相同的批并行损失函数计算。

  • 闭包输入。 闭包输入本质上对应于未映射输入,并且…

  • 转置下的闭包。 一旦我们有了未映射输入,就自然而然地能够转置为未映射输出。

因此,未映射输出既规范又实用!