shmap (shard_map) for simple per-device code#

sholto@, sharadmv@, jekbradbury@, zhangqiaorjc@, mattjj@

January 2023

这是提出 shard_map 的设计文档。您可能更希望查看最新用户文档

动机#

JAX 支持两种多设备编程理念

  1. 编译器,接手吧! 让编译器自动将批量数组函数划分到设备上。

  2. 就让我写出我想要表达的,该死! 给我每设备代码和显式通信集合操作。

我们需要两种理念的优秀 API,它们不应相互排斥,而应能够相互组合。

通过 pjit(现在只是 jit),我们为第一种理念提供了下一代 API。但我们尚未完全提升第二种理念的水平。pmap 遵循第二种理念,但随着时间的推移,我们发现它存在致命缺陷xmap 解决了这些缺陷,但它并不能完全提供每设备形状,并且它还包含了其他一些重要概念。与此同时,对每设备显式集合编程的新需求不断涌现,例如高效扩展 Transformer 推理中所示。

我们可以通过 shmap 提升第二种理念的水平。shmap 是:

  • 一个简单的多设备并行 API,它允许我们编写带有显式集合操作的每设备代码,其中逻辑形状与每设备物理缓冲区形状匹配,并且集合操作与跨设备通信精确对应;

  • 一个 xmap 的特化版本,功能有所缩减并进行了一些调整;

  • XLA SPMD 分区器“手动”模式的相当直接的体现;

  • 一个读起来很有趣的苏斯博士风格的名字,它可以代表 shard_mapshpecialized_xmapsholto_mapsharad_map

对于 pjit 用户shmap 是一个补充工具。它可以在 pjit 计算内部使用,暂时进入“手动集合操作”模式,就像从编译器的自动分区中“逃逸”一样。这样,用户可以享受 pjit 带来的大部分代码的便利性和熟悉的 NumPy 编程模型,同时在需要时能够通过 shmap 手动优化集合通信。这是两全其美!

对于 pmap 用户shmap 是一个严格的升级。它更具表达力、性能更优、与其他 JAX API 更具可组合性,同时不会使基本的批数据并行变得更难。

有关实际使用的更多信息,您可以跳转到何时使用 shmap,何时使用 pjit。如果您想知道为什么我们需要一个新事物,或者 pmap 的问题是什么,请跳转到为什么 pmapxmap 尚未解决此问题?。或者继续阅读下一节,查看一些 shmap 示例和 API 规范。

那么,让我们看看 shmap#

简短示例(附详细解释)#

如此巧妙

from functools import partial

import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = jax.make_mesh((4, 2), ('i', 'j'))

a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 32.).reshape(16, 32)

@partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P('j', None)),
         out_specs=P('i', None))
def matmul_basic(a_block, b_block):
  # a_block: f32[2, 8]
  # b_block: f32[8, 32]
  z_partialsum = jnp.dot(a_block, b_block)
  z_block = jax.lax.psum(z_partialsum, 'j')
  return z_block

c = matmul_basic(a, b)  # c: f32[8, 32]

请注意

  • 对于多轴并行,无需嵌套(或 axis_index_groups),这与 pmap 不同;

  • 调用者无需重塑,这与 pmap 和硬 xmap 不同,并且逻辑形状与每设备物理形状相对应,这与(非硬)xmap 不同;

  • 通过使用 mesh 精确控制设备放置,这与 pmap 不同;

  • 逻辑轴和物理轴只有一套轴名称,这与 xmap 不同;

  • 结果是一个 jax.Array,可以高效地传递给 pjit,这与 pmap 不同;

  • 相同的代码可以在 pjit/jit 内部高效运行,这与 pmap 不同;

  • 此代码以即时模式运行,因此我们可以在中间使用 pdb 并打印值,这与 xmap 的当前实现不同(尽管根据设计,没有顺序调度的 xmap 原则上也可以以即时模式运行)。

这里是另一个具有完全分片结果的矩阵乘法变体

@partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P('j', None)),
         out_specs=P('i', 'j'))
def matmul_reduce_scatter(a_block, b_block):
  # c_partialsum: f32[8/X, 32]
  c_partialsum = jnp.matmul(a_block, b_block)
  # c_block: f32[8/X, 32/Y]
  c_block = jax.lax.psum_scatter(c_partialsum, 'j', scatter_dimension=1, tiled=True)
  return c_block

c = matmul_reduce_scatter(a, b)

慢下来,从基础开始!#

数组轴上的降维映射与保维映射#

我们可以将 pmap(以及 vmapxmap)视为沿着某个轴解堆叠每个数组输入(例如,将 2D 矩阵解包为 1D 行),将其主体函数应用于每个部分,然后将结果重新堆叠在一起,至少在不涉及集合操作时是这样

pmap(f, in_axes=[0], out_axes=0)(xs) == jnp.stack([f(x) for x in xs])

例如,如果 xs 的形状为 f32[8,5],那么每个 x 的形状为 f32[5],并且如果每个 f(x) 的形状为 f32[3,7],那么最终堆叠结果 pmap(f)(xs) 的形状为 f32[8,3,7]。也就是说,主体函数 f 的每次应用都接受比 pmap(f) 相应参数少一个轴的输入作为参数。我们可以说这些是 *降维映射*,伴随着输入/输出的解堆叠/堆叠。

`f` 的逻辑应用数量由被映射的输入轴大小决定:例如,如果我们映射一个大小为 8 的输入轴,语义上我们会得到 8 个函数的逻辑应用,对于 pmap 来说,这始终对应于 8 个设备物理执行它们。

相比之下,shmap 没有这种降维行为。相反,我们可以将其视为沿着输入轴切片(或“非拼接”)成块,应用主体函数,然后将结果重新拼接在一起(同样,在不涉及集合操作时)

devices = np.array(jax.devices()[:4])
m = Mesh(devices, ('i',))  # mesh.shape['i'] = 4

shard_map(f, m, in_specs=P('i'), out_specs=P('i'))(y)
==
jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, 4)])

回想一下,jnp.split 将其输入切片成相同大小且具有相同秩的块,因此如果上述示例中 y 的形状为 f32[8,5],那么每个 y_blk 的形状为 f32[2,5],并且如果每个 f(y_blk) 的形状为 f32[3,7],那么最终拼接结果 shard_map(f, ...)(y) 的形状为 f32[12,7]。因此 shmapshard_map)映射其输入的分片或块。我们可以说它是一个 *保维映射*,伴随着输入/输出的非拼接/拼接。

`f` 的逻辑应用数量由网格大小决定,而不是由任何输入轴大小决定:例如,如果我们有一个总大小为 4 的网格(即在 4 个设备上),那么语义上我们会得到 4 个函数的逻辑应用,对应于 4 个设备物理执行它们。

使用 in_specs 控制每个输入如何拆分(非拼接)和平铺#

每个 in_specs 通过使用 PartitionSpecs 按名称将相应输入数组的某些轴与网格轴进行标识,表示如何将该输入拆分(或非拼接)成将应用主体函数的块。该标识决定了分片大小;当输入轴与网格轴标识时,输入会沿该逻辑轴拆分(非拼接)成与相应网格轴大小相等的片段数量。(如果相应网格轴大小不能均匀地除尽输入数组轴大小,则会报错。)如果输入的 pspec 未提及网格轴名称,则不会在该网格轴上进行拆分。例如

devices = np.array(jax.devices())
m = Mesh(devices.reshape(4, 2), ('i', 'j'))

@partial(shard_map, mesh=m, in_specs=P('i', None), out_specs=P('i', 'j'))
def f1(x_block):
  print(x_block.shape)
  return x_block

x1 = np.arange(12 * 12).reshape(12, 12)
y = f1(x1)  # prints (3,12)

这里,因为输入 pspec 没有提及网格轴名称 `'j'`,所以没有输入数组轴在该网格轴上拆分;同样,因为输入数组的第二个轴未与任何网格轴标识(因此也未在其上拆分),所以 `f1` 的应用会沿该轴获得输入的完整视图。

当输入 pspec 中未提及网格轴时,我们总是可以重写为效率较低的程序,其中提及所有网格轴,但调用者执行 jnp.tile 操作,例如

@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P('i', 'j'))
def f2(x_block):
  print(x_block.shape)
  return x_block

x = np.arange(12 * 12).reshape(12, 12)
x_ = jnp.tile(x, (1, mesh.axis_size['j']))  # x_ has shape (12, 24)
y = f2(x_)  # prints (3,12), and f1(x) == f2(x_)

换句话说,由于每个输入 pspec 可以提及每个网格轴名称零次或一次,而不是必须精确提及每个名称一次,我们可以说,除了内置到其输入中的 jnp.split 之外,shard_map 的输入中也内置了 jnp.tile,至少在逻辑上是这样(尽管平铺可能不需要物理执行,具体取决于参数的物理分片布局)。要使用的平铺方式并非唯一;我们也可以沿着第一个轴进行平铺,并使用 pspec P(('j', 'i'), None)

输入上可能发生物理数据移动,因为每个设备都需要一份相应数据的副本。

使用 out_specs 控制每个输出如何通过拼接、块转置和去平铺来组装#

类似于输入端,每个 out_specs 通过名称将相应输出数组的某些轴与网格轴进行标识,表示输出块(主体函数每次应用一个,或等效于每个物理设备一个)应如何组装在一起以形成最终输出值。例如,在上述 f1f2 示例中,out_specs 都表示我们应该通过沿两个轴拼接块结果来形成最终输出,这两种情况下都得到形状为 (12,24) 的数组 y。(如果主体函数的输出形状,即输出块形状,其秩太小,无法满足相应输出 pspec 所描述的拼接,则会报错。)

当输出 pspec 中未提及网格轴名称时,它表示 *去平铺*:当用户编写的输出 pspec 未提及其中一个网格轴名称时,他们承诺输出块在该网格轴上是相等的,因此输出中只使用该轴上的一个块(而不是将所有块沿该网格轴拼接在一起)。例如,使用与上述相同的网格

x = jnp.array([[3.]])

z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P('i', 'j'))()
print(z)  # prints the same as jnp.tile(x, (4, 2))

z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P('i', None))()
print(z)  # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,))

z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P(None, None))()
print(z)  # prints the same as jnp.tile(x, (1, 1)), or just x

请注意,主体函数关闭数组值等效于将其作为参数传递,其相应的输入 pspec 为 P(None, None)。作为另一个示例,更接近于上述其他示例

@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P('i', None))
def f3(x_block):
  return jax.lax.psum(x_block, 'j')

x = np.arange(12 * 12).reshape(12, 12)
y3 = f3(x)
print(y3.shape)  # (12,6)

请注意,结果的第二个轴大小为 6,是输入第二个轴大小的一半。在这种情况下,输出 pspec 中未提及网格轴名称 `'j'` 所表达的去平铺是安全的,因为集合操作 psum 确保了每个输出块沿相应的网格轴是相等的。以下是另外两个示例,我们改变了输出 pspec 中提及的网格轴

@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f4(x_block):
  return jax.lax.psum(x_block, 'i')

x = np.arange(12 * 12).reshape(12, 12)
y4 = f4(x)
print(y4.shape)  # (3,12)


@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P(None, None))
def f5(x_block):
  return jax.lax.psum(x_block, ('i', 'j'))

y5 = f5(x)
print(y5.shape)  # (3,6)

在物理方面,在输出 pspec 中不提及网格轴名称,会从输出设备缓冲区组装一个 Array,该 Array 沿该网格轴具有复制布局。

没有运行时检查以验证输出块在要进行去平铺的网格轴上是否实际相等,或者等效地,相应物理缓冲区是否具有相同的值,从而可以被解释为单个逻辑数组的复制布局。但我们可以提供一个静态检查机制,对所有潜在不正确的程序引发错误。

因为 out_specs 可以提及网格轴名称零次或一次,并且它们可以以任意顺序提及,所以我们可以说,除了内置到其输出中的 jnp.concatenate 之外,shard_map 的输出中还内置了去平铺和块转置。

输出上不可能发生物理数据移动,无论输出 pspec 是什么。相反,out_specs 只是编码如何将块输出组装成 Array,或者物理上如何将跨设备的缓冲区解释为单个逻辑 Array 的物理布局。

API 规范#

from jax.sharding import Mesh
Specs = PyTree[PartitionSpec]

def shard_map(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs
          ) -> Callable:
  ...

其中

  • mesh 编码按数组排列并具有关联轴名称的设备,就像 xmapsharding.NamedSharding 那样;

  • in_specsout_specsPartitionSpecs,它们可以仿射地提及来自 mesh 的轴名称(不像 xmap 中那样使用单独的逻辑名称),分别表达输入和输出的切片/非拼接和拼接(不像 pmapxmap 那样进行解堆叠和堆叠),其中未提及的名称分别对应于复制和去平铺(断言已复制,所以给我一份副本);

  • 传递给 f 的参数形状与其传递给 shard_map-of-f 的参数具有相同的秩(与 pmapxmap 秩被降低不同),并且 f 的参数形状是根据 shard_map-of-f 相应参数的形状 shape 和相应的 PartitionSpec 规范大致计算得出,形式为 tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))

  • `f` 的主体可以使用 mesh 中的名称应用集合操作。

shmap 默认是即时模式的,这意味着我们逐个原始操作调度计算,以便用户可以在完全复制的值上使用 Python 控制流,并进行交互式 pdb 调试以打印任何值。要对 shmap 函数进行阶段性优化和端到端编译,只需在其外部添加一个 jit。结果是 shmap 没有像 xmappmap 当前那样拥有自己的调度和编译路径;它只使用 jit 路径。

当它被例如外部的 jit 阶段性优化时,shmap 到 StableHLO 的降低是微不足道的:它只涉及在输入上切换到“手动 SPMD 模式”,然后在输出上切换回来。(我们目前不计划支持部分手动部分自动模式。)

与效果的交互与 pmap 相同。

与自动微分的交互也与 pmap 类似(而不是尝试 xmap 所做的新语义,即拥有未映射的中间值,从而 gradreduce_axes 以及将 psum 转置为 pbroadcast 而不是 psum)。但它因此继承了 pmap 的一个未解决问题:在某些情况下,与其将 psum 转置为 psum,从而执行与前向传播 psum 对应的反向传播 psum,不如将反向传播 psum 移动到反向传播中的其他位置,从而利用线性性。许多高级 pmap 用户通过使用 custom_vjp 实现 psum_idrevid_psumrev 函数来解决此挑战,但由于很容易意外地使它们不平衡,该技术存在缺陷。我们有一些关于如何以更安全的方式提供此功能的想法。

何时使用 shmap,何时使用 pjit#

一种理念是:用 jit==pjit 编写程序几乎总是更简单——但如果程序的某个部分编译器优化得不如预期,则可以转而使用 shmap

一个实际示例#

下面是 shmap 在具有 2D 权重聚集模式的 Transformer 层传递中的示例(论文,第 5 页第 3.2.3 节)

def matmul_2D_wg_manual(xnorm, q_wi, layer):
  '''Calls a custom manual implementation of matmul_reducescatter'''
  # [batch, maxlen, embed.X] @ [heads.YZ, embed.X, q_wi_per_head]
  # -> (matmul)
  # -> [batch, maxlen, heads.YZ, q_wi_per_head]{x unreduced}
  # -> (reducescatter over x into X heads, B batches)
  # -> [batch, maxlen, heads.YZX, q_wi_per_head]
  with jax.named_scope('q_wi'):
    xnorm = intermediate_dtype(xnorm)
    q_wi = matmul_reducescatter(
        'bte,hed->bthd',
        xnorm,
        params.q_wi,
        scatter_dimension=(0, 2),
        axis_name='i',
        layer=layer)
   return q_wi


import partitioning.logical_to_physical as l2phys

def pjit_transformer_layer(
    hparams: HParams, layer: int, params: weights.Layer, sin: jnp.ndarray,
    cos: jnp.ndarray, kv_caches: Sequence[attention.KVCache],
    x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
  """Forward pass through a single layer, returning output, K, V."""

  def my_layer(t, axis=0):
    """Gets the parameters corresponding to a given layer."""
    return lax.dynamic_index_in_dim(t, layer, axis=axis, keepdims=False)

  # 2D: [batch.Z, time, embed.XY]
  x = _with_sharding_constraint(
      x, ('residual_batch', 'residual_time', 'residual_embed'))
  xnorm = _layernorm(x)
  # 2D: [batch, time, embed.X]
  xnorm = _with_sharding_constraint(
      xnorm, ('post_norm_batch', 'time', 'post_norm_embed'))
  # jump into manual mode where you want to optimise
  if manual:
    q_wi = shard_map(matmul_2D_wg_manual, mesh
                in_specs=(l2phys('post_norm_batch', 'time', 'post_norm_embed'),
                          l2phys('layers', 'heads', 'embed', 'q_wi_per_head')),
                out_specs=l2phys('post_norm_batch', 'time', 'heads', 'q_wi_per_head'))(xnorm, q_wi, layer)
  else:
    q_wi = jnp.einsum('bte,hed->bthd', xnorm, my_layer(params.q_wi))
    # 2D: [batch, time, heads.YZX, None]
    q_wi = _with_sharding_constraint(q_wi,
                                   ('post_norm_batch', 'time', 'heads', 'qkv'))
  q = q_wi[:, :, :, :hparams.qkv]
  q = _rope(sin, cos, q)
  # unlike in https://arxiv.org/pdf/2002.05202.pdf, PaLM implements
  # swiGLU with full d_ff dimension, rather than 2/3 scaled
  wi0 = q_wi[:, :, :, hparams.qkv:hparams.qkv + (hparams.ff // hparams.heads)]
  wi1 = q_wi[:, :, :, hparams.qkv + (hparams.ff // hparams.heads):]
  kv = jnp.einsum('bte,ezd->btzd', xnorm, my_layer(params.kv))
  k = kv[:, :, 0, :hparams.qkv]
  v = kv[:, :, 0, hparams.qkv:]
  k = _rope(sin, cos, k)

  y_att = jnp.bfloat16(attention.attend(q, k, v, kv_caches, layer))

  y_mlp = special2.swish2(wi0) * wi1
  # 2D: [batch, time, heads.YZX, None]
  y_mlp = _with_sharding_constraint(y_mlp,
                                    ('post_norm_batch', 'time', 'heads', None))

  y_fused = jnp.concatenate([y_att, y_mlp], axis=-1)
  # do the second half of the mlp and the self-attn projection in parallel
  y_out = jnp.einsum('bthd,hde->bte', y_fused, my_layer(params.o_wo))
  # 2D: [batch.Z, time, embed.XY]
  y_out = _with_sharding_constraint(
      y_out, ('residual_batch', 'residual_time', 'residual_embed'))
  z = y_out + x
  z = _with_sharding_constraint(
      z, ('residual_batch', 'residual_time', 'residual_embed'))
  return z, k, v

在下面的性能分析中,第一个和第二个矩阵乘法都被手动降低的版本替换,其中计算(融合)与通信(ppermute)完全重叠!一个有趣的线索表明我们正在使用延迟优化的变体是 ppmerute 像素是抖动的——因为同时有两个重叠的 ppermute 使用相反的 ICI 轴!

全互连(All-to-all)更难重叠,因此被搁置了。

image

为什么 pmapxmap 尚未解决此问题?#

pmap 是我们第一个多设备并行 API。它遵循“每设备代码和显式集合操作”的理念。但它存在主要缺点,使其不适用于当今的程序

  • 映射多个轴需要嵌套的 pmap 嵌套的 pmap 不仅编写繁琐,而且难以控制(甚至预测)数据和计算的设备放置,也难以保留数据分片(参见接下来的两点)。当今的程序需要多轴并行。

  • 控制设备放置是不可能的。 特别是在多轴并行的情况下,程序员需要控制这些轴如何与硬件资源及其通信拓扑对齐。但是(嵌套的)pmap 不提供对映射程序实例如何在硬件上放置的控制;只有一个用户无法控制的自动设备顺序。(Gopher 使用 axis_index_groups 和单个非嵌套 pmap 本质上是一种通过将多个并行轴扁平化为一个来解决此问题的“黑科技”。)

  • jit/pjit 可组合性。 jit 包裹 pmap 是一个性能陷阱,嵌套 pmap 也是如此,例如 scan 包裹 pmap 也是,因为从内部 pmap 返回时分片不会被保留。为了保留分片,我们需要对 jaxprs 进行模式匹配,以确保我们正在使用完美嵌套的 `pmap`,或者 `pmap` 仅在 jit 内部。此外,pjit 在这里也无济于事,因为 pmap 针对 XLA 副本,而 pjit 针对 XLA SPMD 分区器,将这两者组合起来很困难。

  • jax.Array 兼容性(以及 pjit 兼容性)。 由于 pmap 的输出分片无法表示为 Shardings / OpShardings,这是因为 pmap 采用堆叠而非拼接语义,因此 pmap 计算的输出目前无法直接传递给 pjit 计算,除非通过主机中转(或调度一个重塑计算)。

  • 多控制器语义(以及 pjit 兼容性)。 多控制器 pmap 会在控制器之间拼接值,这种方式运行良好,但与单控制器 pmap 的堆叠语义不同。更实际的是,它排除了使用非完全可寻址的 jax.Array 输入和输出,而这在多控制器 pjit 中是允许的。

  • 即时模式。 我们没有将 pmap 设计为优先支持即时模式,尽管我们最终(4 年多后!)通过 disable_jit() 添加了即时操作,但 pmap 内部融合了 jit,这意味着它有自己的编译和调度路径(实际上是两个调度路径:Python 中用于处理 Tracers,C++ 中用于原始 Array 输入的性能!),这带来了沉重的实现负担。

  • 调用者需要进行重塑。 在 8 个设备上使用 pmap 的典型用例可能包括:从大小为 128 的批处理轴开始,将其重塑为两个大小分别为 (8, 16) 的轴,然后对第一个轴进行 pmap 操作。这些重塑操作很笨拙,并且编译器通常将其解释为复制而不是视图——这会增加内存和时间的使用。

这些缺点在仅进行批数据并行时并不算太糟。但当涉及更多并行时,pmap 就力不从心了!

xmap 作为 pmap 的下一代演进,铺平了道路,并解决了(几乎)所有这些问题。shmap 沿着 xmap 的足迹,以基本相同的方式解决了这些问题;实际上,shmap 就像是 xmap 的一个特殊子集(有些人称之为“硬 xmap”子集),并进行了一些调整。

对于初始原型,我们选择将 shmap 作为独立于 xmap 的原语来实现,因为限制其支持的功能集更容易专注于核心功能。例如,shmap 不允许未映射的中间值,这使得无需担心命名轴和自动微分之间的交互变得更容易。此外,无需考虑所有功能对之间的交互,使得添加超出 xmap 当前实现的能力变得更容易,例如对即时模式的支持。

shmapxmap 共享大量的降低代码。我们将来可以考虑合并两者,甚至完全专注于 shmap,具体取决于使用情况的演变。