Pallas TPU 分布式计算#

在本教程中,我们将介绍 Pallas 在 TPU 上进行分布式计算的基础知识。我们将了解 TPU 拓扑结构、使用远程 DMA 原始操作进行通信,以及如何使用 jax.shard_map 从 JAX 调用分布式内核。我们还将介绍一些更高级的内核编写技术,例如双缓冲、双向带宽优化和嵌套流水线。作为教学示例,我们将学习如何实现 JAX 的各种集体操作,例如 lax.ppermutelax.all_gatherlax.psumlax.psum_scatter

一些推荐的预读材料

import functools
import jax
from jax import lax
from jax import numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

P = jax.sharding.PartitionSpec

num_devices = jax.local_device_count()
assert num_devices > 1, "Please run this notebook with more than one device."
assert "TPU" in jax.devices()[0].device_kind, "Please run this notebook with TPU devices."
print(f"Running with {num_devices} {jax.devices()[0].device_kind} devices.")
Running with 4 TPU v4 devices.

TPU 拓扑结构#

TPU 通常部署在由高速芯片间互连 (ICI) 连接的多个设备组成的集群中,用于集群内部的通信,其速度远超普通网络连接。例如,TPU v5p 的规格表显示,每颗芯片的 ICI 带宽为 4.8Tb/s(作为参考,TPU v5p 还拥有 21Tb/s 的*本地* HBM 带宽)。ICI 使我们能够实现需要集群内部高带宽通信的快速高效的分布式内核,并使用数据中心网络进行带宽要求较低的操作的并行化,例如批处理维上的数据并行。

TPU 集群通常排列成 ND 环形拓扑结构。下图展示了不同大小配置的几个示例。

tpu_topologies

将环形拓扑结构展平为图,可以可视化如下。每条边(橙色或黑色)是两个设备之间的双向连接。在讨论设备拓扑结构时,您经常会听到“环形”这个词——环形的一个关键特征是,当沿着集群的某个轴进行切片时,例如节点 [(0,1), (1, 1), (2, 1), (3, 1)][(0, 1), (1, 1)],我们得到一个设备环。这是我们可以用来简化集群内通信模式的一个特性。

tpu_torus

远程直接内存访问 (RDMA) 模型#

TPU 通过一种称为远程直接内存访问 (RDMA) 的仅推送模型进行通信。TPU 可以发出复制指令,将数据从本地缓冲区推送到同一集群中另一个设备上的任何缓冲区,此操作与主程序线程异步执行。但是,TPU 只能读取存储在本地的数据。这与更传统的*多核编程*不同,后者可以同时读写共享内存中的值。

异步远程复制操作#

函数 pltpu.make_async_remote_copy 用于创建远程 DMA 描述符对象,该对象参数化了“发送”操作和“接收”操作。以下是其签名:

 def make_async_remote_copy(
     src_ref: Ref,
     dst_ref: Ref,
     send_sem: Ref[SemaphoreType],
     recv_sem: Ref[SemaphoreType],
     device_id: int | tuple[int, ...],
     device_id_type: DeviceIdType
 ) -> AsyncCopyDescriptor:
  • src_ref 是包含您要发送到 dst_ref(位于另一设备上)的数据的本地 Ref(在任何内存空间中)。

  • dst_ref 是目标设备上数据将被复制到的远程 Ref(在任何内存空间中)。

  • send_sem 是一个 DMA 信号量,用于阻止程序执行,直到所有数据已从 src_ref 发送完毕。

  • recv_sem 是一个 DMA 信号量,用于阻止程序执行,直到在 dst_ref 处接收到预期数量的字节。DMA 发送方将写入接收方的 recv_sem

  • device_id 是要发送到的目标设备的设备 ID。

  • device_id_type 指定 device_id 的格式,可以是 LOGICAL 格式(整数设备 ID)或 MESH 格式(逻辑设备网格的 ND 元组索引)。默认模式为 MESH。

make_async_remote_copy 返回一个描述符对象,您可以使用其 .start() 方法来启动 DMA,使用 .wait_send() 来阻塞 send_sem,使用 .wait_recv() 来阻塞 recv_sem(或使用 .wait() 来同时阻塞两者)。如果设备只需要发送数据,只需调用 .start().wait_send() 即可;同样,如果设备只需要接收数据,只需调用 .wait_recv() 即可。如果使用 SPMD 模式,所有设备都执行 DMA,那么每个设备通常会同时调用 .start().wait()

dma_descriptor = make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id)
dma_descriptor.start() # Initiate the DMA (non-blocking).
# ... do other work
dma_descriptor.wait_send() # Block until all data has been sent.
dma_descriptor.wait_recv() # Block until all data has been received.

例如,让我们可视化一个 DMA 操作,考虑 4 个设备(索引为 0、1、2、3)。我们考虑一个方案,其中设备 0 复制到设备 1,设备 2 和 3 相互复制。实际上,我们可以通过使用 @pl.when 根据设备 ID 进行分支来创建这种不对称通信模式。

(1) 每个设备创建 DMA 描述符。设备 0、2 和 3 调用 .start() 来启动从 src_ref 的 DMA。设备 1 跳过 .start() 并且什么也不做,例如通过使用 pl.when

rdma_start

(2) 由于 .start() 是非阻塞的,每个设备都可以在 DMA 进行期间自由地执行其他计算。设备 0、2 和 3 调用 .wait_send() 来等待 send_sem,该信号量会阻塞直到所有数据都已发送。

rdma_send

(3) 最后,设备 1、2 和 3 将调用 .wait_recv() 来等待 recv_sem,直到所有数据都已到达 dst_ref

rdma_recv

上述通信模式可以这样编写:

def example_kernel(input_ref, output_ref, send_sem, recv_sem):
    device_id = lax.axis_index('x')
    copy_0_to_1 = pltpu.make_async_remote_copy(
        src_ref=input_ref,
        dst_ref=output_ref,
        send_sem=send_sem,
        recv_sem=recv_sem,
        device_id=1,
    )
    copy_2_to_3 = pltpu.make_async_remote_copy(
        src_ref=input_ref,
        dst_ref=output_ref,
        send_sem=send_sem,
        recv_sem=recv_sem,
        device_id=3,
    )
    copy_3_to_2 = pltpu.make_async_remote_copy(
        src_ref=input_ref,
        dst_ref=output_ref,
        send_sem=send_sem,
        recv_sem=recv_sem,
        device_id=2,
    )
    @pl.when(device_id == 0)
    def _():
      copy_0_to_1.start()
      copy_0_to_1.wait_send()
    @pl.when(device_id == 1)
    def _():
      copy_0_to_1.wait_recv()
    @pl.when(device_id == 2)
    def _():
      copy_2_to_3.start()
      copy_2_to_3.wait_send()
      copy_3_to_2.wait_recv()
    @pl.when(device_id == 3)
    def _():
      copy_3_to_2.start()
      copy_3_to_2.wait_send()
      copy_2_to_3.wait_recv()

DMA 信号量#

send_semrecv_sem 是特种信号量实例,专供 DMA 使用。在为 pallas_call 指定输入规范时,必须使用 tpu.SemaphoreType.DMA 类型进行分配。

内部而言,DMA 信号量可以被看作是整数值的进度跟踪器。在 DMA 开始时,本地设备将开始异步地增加 send_sem 和接收方 recv_sem 的值。等待信号量将阻塞,直到信号量的值达到已发送/接收数据的总字节数;当达到该值时,等待的线程将被释放,信号量的值将减少相同数量。这意味着要么所有数据都已发送(对于 send_sem),要么所有数据都已接收(对于 recv_sem)。信号量的值可以通过 pl.semaphore_read 读取,但请注意,值的底层语义可能在不同硬件代系之间发生变化(例如,该值可能不精确表示发送的字节数,尽管这是理解信号量行为的有用心智模型)。

路由#

发送方可以将数据发送到同一集群内的任何接收方,即使它们没有直接连接(此规则的例外是 TPU v5e,其中设备只能路由到与其自身有 2 的幂次偏移的设备)。TPU 具有内部路由机制,可以将数据传递到通往目标路径上的下一个设备。然而,不建议通过这种方式进行通信,因为作为内核编写者,您无法控制网络争用。本教程中我们将涵盖的示例通过仅将数据传输到相邻设备来最小化低效通信。

故障模式#

如果错误地使用远程 DMA,您可能会遇到几种难以调试的故障模式。DMA 使用错误的通用症状是崩溃、挂起或静默数据损坏。

  • 如果信号量以无效的非零值退出程序,Pallas 将崩溃并退出程序。

  • 如果等待信号量但接收的字节数不足(即没有发送方,或者发送的数据小于接收设备上 dst_ref 的大小),程序可能会无限期地挂起,等待从未发送过的字节。在这种情况下,程序需要重新启动。

  • 如果遇到竞态条件,由于同时发生的两个写入或一个同时的读写操作,可能会导致静默数据损坏。

上述情况的一些常见原因包括:

  • 如果某个设备调用 .wait_recv() 但没有其他设备发送数据给它,内核可能会挂起。

  • 如果发送给某个设备的字节数超过了它预期的接收量,它可能会因非零信号量状态而崩溃。如果发送的字节数较少,它可能会无限期挂起。

  • 如果启动了 DMA 但没有等待信号量,程序可能会因非零信号量状态而崩溃。

  • 如果两个设备复制到同一个目标,您可能会遇到非确定性结果(由于竞态条件),或者因非零信号量状态而崩溃。

示例:右置换 (lax.ppermute)#

让我们深入到一个非常基础的例子。我们将实现一个执行右置换的内核,其中每个设备将其数据切片发送给右侧的邻居。

假设我们有一个包含 512 个元素的数组,我们将其分成 4 个设备上大小为 128 的切片。每个设备将其切片传递给下一个设备,输出将是相同的数据,但切片旋转了 1。这与 lax.ppermute 操作相同,其中置换设置为 (n, (n+1) % 4)

为了在分布式模式下调用内核,我们将 pallas_call 包装在 shard_map 转换中。从那里,我们可以像编写普通单设备 Pallas 内核一样编写内核,只是我们现在可以访问远程 DMA 指令。JAX 的集体操作(如 lax.axis_index)可用于获取 device_id,该 ID 可用于计算要复制到的目标设备,方法是引用传递给 shard_map 的相同的命名轴名称。

partition = P(None, 'x')
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)

# Create an input array that shards the last dimension across
# all devices.
input_arr = jax.random.uniform(jax.random.key(0), (8, 128 * num_devices))
input_arr = jax.device_put(input_arr, sharding)


def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem):
  my_id = lax.axis_index('x')
  right_neighbor = lax.rem(my_id + 1, num_devices)
  remote_copy_op = pltpu.make_async_remote_copy(
      src_ref=input_ref,
      dst_ref=output_ref,
      send_sem=send_sem,
      recv_sem=recv_sem,
      device_id=(right_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )
  remote_copy_op.start()
  remote_copy_op.wait()


out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32)
grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=0,
    # MemorySpace.ANY will (usually) place the tensor in HBM.
    in_specs=[
        pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
    ],
    out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
    scratch_shapes=(
        # We allocate DMA semaphores in scratch memory.
        [pltpu.SemaphoreType.DMA] * 2
    ),
)
right_permute = pl.pallas_call(
    right_permute_kernel,
    out_shape=out_shape,
    grid_spec=grid_spec,
)
# Wrap the kernel within a shard_map to call.
pallas_result = jax.jit(
    jax.shard_map(
        right_permute,
        mesh=mesh,
        in_specs=partition,
        out_specs=partition,
        check_vma=False,
    )
)(input_arr)

# Compare Pallas result to XLA shard_map result.
perm = tuple((src, (src + 1) % num_devices) for src in range(num_devices))

xla_result = jax.jit(
    jax.shard_map(
        lambda x: lax.ppermute(x, 'x', perm),
        mesh=mesh, in_specs=partition, out_specs=partition)
)(input_arr)

print('Input = ', input_arr[0, ::128])
print('Pallas Result = ', pallas_result[0, ::128])
print('lax.ppermute Result = ', xla_result[0, ::128])
print(
    'Difference |Pallas - lax.ppermute| = ',
    jnp.mean(jnp.abs(pallas_result - xla_result)),
)
Input =  [0.9858954  0.11763906 0.9955574  0.775211  ]
Pallas Result =  [0.775211   0.9858954  0.11763906 0.9955574 ]
lax.ppermute Result =  [0.775211   0.9858954  0.11763906 0.9955574 ]
Difference |Pallas - lax.ppermute| =  0.0

示例:All-gather (lax.all_gather)#

在下一个示例中,我们将实现 all-gather 集体操作,它在 JAX 中具有等效操作 lax.all_gather。与上面仅涉及一对源和目标邻居的右置换示例相比,all-gather 操作需要所有设备之间的通信,因此我们必须考虑数据如何在它们之间路由。我们如何实现这一点取决于设备拓扑结构,我们假设它是环形的。

环形通信模式#

我们将编写我们的内核,假设采用环形拓扑结构。环形结构非常适合 TPU,因为沿着环形拓扑结构的任何维度进行切片都会产生一个环形。在编写集体操作时,我们通常只需要一次考虑环形拓扑结构的 1D 切片,因为环形拓扑结构的不同维度保留用于不同类型的并行化(例如,数据并行化与模型并行化)。

我们将使用的策略是编写一个循环内核,在每次迭代中,设备从其左侧邻居接收分片数组的一个切片,并将之前接收的切片复制到右侧邻居。在 num_devices 次迭代后,每个设备都将在其本地 HBM 中拥有整个数组的副本。

all_gather

我们可以重新利用 Pallas 的 grid 参数来实现循环。与我们之前教程中迭代数组的瓦片不同,我们现在将 grid 设置为 (num_devices,),以指示我们要循环设备的数量,并在 Pallas 内核内部使用 pl.program_id 来获取循环迭代。以下代码片段演示了如何实现此目的:

partition = P('x', None)
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)

# Create an input array that shards the first dimension across
# all devices.
input_arr = jax.random.uniform(jax.random.key(0), (8 * num_devices, 128))
input_arr = jax.device_put(input_arr, sharding)


def all_gather_kernel(input_ref,
                      output_ref,
                      local_copy_sem,
                      send_sem,
                      recv_sems):
  outer_step = pl.program_id(0)
  my_id = lax.axis_index('x')
  right_neighbor = lax.rem(my_id + 1, num_devices)
  copy_slot = my_id - outer_step
  copy_slot = lax.rem(copy_slot + num_devices, num_devices)

  @pl.when(outer_step == 0)
  def _():
    local_copy_op = pltpu.make_async_copy(
      src_ref=input_ref,
      dst_ref=output_ref.at[my_id],
      sem=local_copy_sem,
    )
    local_copy_op.start()
    local_copy_op.wait()

  # Copy to our right neighbor.
  # Note that we will also be receiving data from our left neighbor,
  # but at `copy_slot-1` rather than `copy_slot`! This makes use of the fact
  # that the indices do not need to be symmetric between remote DMAs.
  remote_copy_op = pltpu.make_async_remote_copy(
      src_ref=output_ref.at[copy_slot],
      dst_ref=output_ref.at[copy_slot],
      send_sem=send_sem,
      recv_sem=recv_sems.at[outer_step],
      device_id=(right_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )
  remote_copy_op.start()
  remote_copy_op.wait()

out_shape = jax.ShapeDtypeStruct((num_devices, 8, 128), jnp.float32)
grid_spec = pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=[
                # MemorySpace.ANY will (usually) place the tensor in HBM.
                pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
            ],
            out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
            scratch_shapes=(
              # DMA semaphores are allocated in scratch memory.
              # We allocated one semaphore for a local HBM-VMEM copy,
              # and one for the remote send semaphore.
              [pltpu.SemaphoreType.DMA] * 2
              # We additionally allocate one receive semaphore per device.
              # This is to avoid situations where we have multiple
              # DMAs in flight, as we do not want to share a receive
              # semaphore between the DMAs.
              + [pltpu.SemaphoreType.DMA((num_devices-1,))]

            ),
            grid=(num_devices-1,)
        )

all_gather = pl.pallas_call(
      all_gather_kernel,
      out_shape=out_shape,
      grid_spec=grid_spec,
  )

# Wrap the kernel within a shard_map to call.
pallas_result = jax.jit(
      jax.shard_map(
          all_gather,
          mesh=mesh,
          in_specs=partition,
          out_specs=partition,
          check_vma=False
      )
)(input_arr)

# Compare Pallas result to XLA shard_map result.
xla_result = jax.jit(
    jax.shard_map(
        lambda x: lax.all_gather(x, 'x'),
        mesh=mesh, in_specs=partition, out_specs=partition
    )
)(input_arr)

print('Input: ', input_arr.shape, input_arr[::8, 0])
print('Pallas Result: ', pallas_result.shape, pallas_result[:, 0, 0])
print('lax.all_gather Result: ', xla_result.shape, xla_result[:, 0, 0])
print('Difference |Pallas - lax.all_gather| = ',
      jnp.mean(jnp.abs(pallas_result - xla_result)))
Input:  (32, 128) [0.9858954  0.54248166 0.9547038  0.954962  ]
Pallas Result:  (16, 8, 128) [0.9858954  0.54248166 0.9547038  0.954962   0.9858954  0.54248166
 0.9547038  0.954962   0.9858954  0.54248166 0.9547038  0.954962
 0.9858954  0.54248166 0.9547038  0.954962  ]
lax.all_gather Result:  (16, 8, 128) [0.9858954  0.54248166 0.9547038  0.954962   0.9858954  0.54248166
 0.9547038  0.954962   0.9858954  0.54248166 0.9547038  0.954962
 0.9858954  0.54248166 0.9547038  0.954962  ]
Difference |Pallas - lax.all_gather| =  0.0

这里值得一提的一个细节是使用多个接收信号量。因为我们只在接收设备上阻塞,发送方仍然可能在接收方完成处理第一个 DMA 之前有多个 DMA 处于飞行状态(请参阅下一节和 reduce-sum 示例,其中更详细地讨论了竞态条件)。在这种情况下,我们可能会遇到同一个信号量用于多个同时发生的 DMA 的情况。为避免此问题,我们分配 num_devices-1 个信号量,以避免重用风险。虽然在如此小的内核上不太可能发生此竞态条件,但在较大的内核中,设备可能更容易不同步并可能导致静默失败。

高级技术#

现在我们已经了解了如何使用远程 DMA 操作编写几个基本内核,我们将介绍用于同步和编写高效内核的更高级技术。

同步:常规信号量和屏障信号量#

我们在基础教程中实现的示例不需要特殊同步处理,因为所有必要的通信都写入不相交的缓冲区。但是,其他操作可能需要更复杂的通信模式,需要额外的同步原语来避免竞态条件。Pallas 提供了两种额外的原语来帮助解决此问题:常规信号量和屏障信号量。

常规信号量#

常规信号量是用于跨多个设备进行同步的标准工具。信号量本质上是计数器——任何设备都可以对其进行递增,之后设备可以阻塞直到信号量的值达到特定值(然后递减该值)。

可以对常规信号量使用的三个主要操作是:信号、等待和读取。

def semaphore_signal(
    sem: Ref[SemaphoreType],
    inc: int,
    device_id: int | tuple[int, ...],
    device_id_type: DeviceIdType
) -> None:
  ... # Increments the semaphore `sem` on the target device `device_id` by `inc`.
  
def semaphore_wait(
    semaphore: Ref[SemaphoreType],
    value: int,
) -> None:
  ... # Blocks until the locally allocated copy of `sem` reaches `value`, then decrement by `value` and proceed.
    
def semaphore_read(
    sem: Ref[SemaphoreType],
) -> jax.Array:
  ...  # Returns the current value of `sem` as an `int32[]`.

为了使用常规信号量,它们可以像 DMA 信号量一样分配,但通过指定 pltpu.SemaphoreType.REGULAR 而不是 pltpu.SemaphoreType.DMA

信号量在 Pallas 程序结束时必须为零才能成功完成。有两个错误情况可能发生:

  • 如果信号量被过度信号,程序将在结束时出现非零(>0)信号量。在这种情况下,程序将在完成时崩溃。这对于调试很有用,因为非零信号量通常意味着程序内部存在错误。

  • 如果信号量被过度等待,程序将在阻塞的 semaphore_wait 调用中挂起,直到信号量被递增。在这种情况下,设备或程序需要重新启动。

屏障信号量#

屏障信号量是全局分配的信号量,用于同步整个程序中的设备,并确保所有设备都已进入 Pallas 内核。

如果 Pallas 内核在更大的 XLA 程序上下文中执行,我们需要确保所有通信的设备都已进入内核。但是,DMA 和常规信号量都是本地范围的——它们仅被已进入内核的其他设备理解。屏障信号量充当全局理解的信号量,可用于同步,无论设备当前在 XLA 程序中的哪个位置执行。

默认情况下,如果您不指定屏障信号量,Pallas 将在您的程序开始时自动插入一个屏障信号量。但是,编写自己的屏障信号量可能更有效。屏障信号量类似于常规信号量,因为它们是可以通过 semaphore_signal 递增,并通过 semaphore_wait 递减的计数器。它们是通过在内核中调用 get_barrier_semaphore() 创建的。通常,我们会在内核开始时使用一次屏障来与所有我们要通信的设备进行同步。

from jax.experimental.pallas import tpu as pltpu

def example_kernel(...):
  # Use barrier semaphores at the beginning of a kernel.
  # is_start_of_kernel = ...
  # right_neighbor = ...
  # ...
  @pl.when(is_start_of_kernel)
  def _():
    barrier_sem = pltpu.get_barrier_semaphore()
    # Increment the semaphore of your right neighbor.
    pltpu.semaphore_signal(
          barrier_sem,
          device_id=right_neighbor,
          device_id_type=pltpu.DeviceIdType.LOGICAL,
    )
    # Wait until your left neighbor has incremented your semaphore
    pltpu.semaphore_wait(barrier_sem, 1)
  # ...

使用屏障信号量时,必须将 collective_id 编译器参数传递给 pallas_call,以指定正在使用哪个屏障信号量。TPU 具有少量固定的屏障信号量(通常在 20-30 左右),因此应谨慎使用它们。为了确保正确性,只有具有相同通信模式的内核才应使用相同的 collective_id。例如,如果两个内核仅与同一网格轴上的邻居同步,则允许它们共享同一个 collective_id。但是,如果两个内核沿不同轴同步,则它们必须具有不同的 collective_id。否则可能导致难以调试的竞态条件。

kernel = pl.pallas_call(
      example_kernel,
      ...,
      compiler_params=pltpu.CompilerParams(collective_id=0),
)

双缓冲#

为了避免读取另一个设备正在写入的本地 Ref 并导致竞态条件,一种有用的技术是“双缓冲”策略,即我们为每个目标值分配两个 Ref。在每次迭代中,一个 Ref 将被指定为“工作”槽,另一个将被指定为“接收”槽。设备可以自由地使用工作槽进行计算,但只会将数据复制到其邻居的接收槽。工作槽和接收槽每迭代一次都会交替,因此一旦复制完成,旧的接收槽就成为新的工作槽,反之亦然。通过正确使用此方案,数据永远不会在同一个缓冲区中读取和写入。

以下代码骨架演示了如何使用双缓冲。我们将迭代计数器保留在变量 iteration 中,并且 working_slotreceiving_slot 在每次迭代时都在 0 和 1 之间交替。 dst_ref 被分配为双缓冲,大小为 [2, ...]。在每次迭代中,我们从工作槽读取 dst_ref.at[working_slot, ...] 并使用该值进行计算。同时,我们将其复制到我们邻居的 dst_ref.at[receiving_slot],以避免覆盖其 working_slot 值。通过以这种方式构建我们的通信,可以重叠远程 DMA 的通信延迟与本地计算,同时最小化竞态条件的风险。

def kernel(...):
  # ...
  iteration = pl.program_id(0)
  working_slot = lax.rem(iteration, 2)
  receiving_slot = 1 - working_slot
  # ...

  local_copy_op = pltpu.make_async_copy(
    src_ref=dst_ref.at[working_slot, ...],
    dst_ref=local_scratch_ref,
    sem=local_copy_sem,
  )
  local_copy_op.start()
  remote_copy_op = pltpu.make_async_remote_copy(
    src_ref=src_ref,
    dst_ref=dst_ref.at[receiving_slot, ...],
    send_sem=send_sem,
    recv_sem=recv_sem,
    device_id=target_device,
    device_id_type=pltpu.DeviceIdType.MESH,
  )
  remote_copy_op.start()
  
  local_copy_op.wait()
  # ... do work on local_scratch while waiting for async_copy_op to finish.
  remote_copy_op.wait()

在同步方面,双缓冲结构在所有设备都执行相同迭代时有效。如果发送方比接收方提前一个迭代,那么它的 working_slotreceiving_slot 索引将与接收方相反,这意味着它可能正在写入 working_slot 的同时接收方正在从中读取。为避免这种情况,可能需要使用信号量来同步发送方和接收方,或者添加额外的缓冲槽(“三缓冲”、“四缓冲”或 N 缓冲)来允许额外的超前执行,但会增加内存消耗。在我们之前的 all_gather 示例中,请注意该内核包含一个具有 N 个槽的接收缓冲区,这完全避免了竞态条件。在我们下一个内核中,我们将改为通过一个使用双缓冲并带有显式同步的示例。

示例:All-Reduce Sum (lax.psum)#

我们将使用双缓冲和信号量进行同步来实现一个 all-reduce sum 内核。对于熟悉 JAX 中集体操作的人来说,等效操作是 lax.psum。All-reduce 是一种标准的集体操作,其目标是沿数组的某个轴进行规约,但数组被分片到多个设备上。

reduce_sum_1

在上面的示例中,我们有一个数组 [5, 2, 1, 3] 分片到 4 个设备上。all-reduce sum 操作将对所有值求和,并将结果复制到每个设备上,最终结果是 [11, 11, 11, 11] 分片到所有 4 个设备上。

All-reduce 的朴素实现是将所有必需的值收集到每个设备上,然后进行规约。但是,我们可以通过将通信与计算交织在一起来提高此实现的性能。可以这样可视化交织的、单向的 all-reduce。在每次迭代中,我们从左侧邻居接收输入值,同时将输入传递给下一个邻居,并将其与本地累加器相加。经过 N-1 次迭代后,每个设备将在其内存中拥有完整总和的副本。

reduce_sum_2

综合应用#

以下内核演示了如何将这些原理组合成一个功能性内核。

序言(在 outer_step==0 时执行)首先与两个邻居发起屏障,以确保它们也已进入内核。它还处理所有 Ref 的初始化,并处理到右侧邻居“工作”槽的第一次远程复制。

主体假定一个值已复制到我们的本地工作槽中,该值要么来自前一次迭代,要么来自序言。一个复杂因素是我们的目标缓冲区位于 HBM 中,但在执行算术之前,我们需要将数据加载到 VMEM。因此,我们将工作槽值同时复制到我们的 VMEM(receive_scratch)并将其传递给邻居的接收槽。一旦值被复制到我们的 VMEM,我们就可以将其累加到我们的结果中(包含在 o_ref 中)。

如果一个设备比其右侧邻居提前一个循环,则可能会发生细微的竞态条件。在这种情况下,它可能在接收方从其 working_slot 读取的同时写入接收方的 working_slot。为避免此问题,每个设备在复制到右侧邻居的 dst_ref 之前,都会阻塞 REGULAR 信号量,直到它发出信号表明已完成对其 working_slot 的读取。对于像这个例子这样小的内核,很少会触发此竞态条件,但如果使用 pltpu.delay 指令人为地挂起设备,则可以明确触发它。

请注意,这不是一个最优或完全通用的内核,因为块大小必须完全适合 VMEM,而且我们可以更好地交织通信和累加。我们将在后面的部分讨论这些优化。

partition = P(None, 'x')
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)

input_arr = jax.random.uniform(jax.random.key(0), shape=(8, 128 * num_devices))
input_arr = jax.device_put(input_arr, sharding)


def local_barrier(left_neighbor, right_neighbor, double_barrier=True):
  """Performs a barrier with neighbors on the global barrier semaphore.

  Optionally performs a second barrier, which prevents a potential race
  when reusing the same collective_id across kernel invocations.
  """
  barrier_sem = pltpu.get_barrier_semaphore()
  for neighbor in [left_neighbor, right_neighbor]:
    pltpu.semaphore_signal(
      barrier_sem,
      inc=1,
      device_id=(neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
    )
  pltpu.semaphore_wait(barrier_sem, 2)
  if double_barrier:
    # The double-barrier prevents a race condition where one neighbor can
    # re-enter the kernel again on a subsequent call and increment the
    # barrier semaphore a second time. This would unblock the current device
    # even if the other neighbor is not ready yet.
    # To implement a double-barrier, we stack-allocate a second REGULAR
    # semaphore using run_scoped.
    @functools.partial(pl.run_scoped,
                       second_barrier=pltpu.SemaphoreType.REGULAR)
    def _(second_barrier):
      for neighbor in [left_neighbor, right_neighbor]:
        pltpu.semaphore_signal(
          second_barrier,
          inc=1,
          device_id=(neighbor,),
          device_id_type=pltpu.DeviceIdType.MESH,
        )
      pltpu.semaphore_wait(second_barrier, 2)


def all_reduce_kernel(
    x_ref,
    o_ref,
    hbm_scratch,
    copy_sem,
    remote_recv_sem,
    remote_send_sem,
    capacity_sem,
    receive_scratch,
):
  outer_step = pl.program_id(0)
  working_slot = lax.rem(outer_step, 2)
  receiving_slot = 1 - working_slot

  my_id = lax.axis_index('x')
  right_neighbor = lax.rem(my_id + 1, num_devices)
  left_neighbor = lax.rem(my_id - 1 + num_devices, num_devices)

  @pl.when(outer_step == 0)
  def _():
    # Barrier with both neighbors at the start, since we will be
    # communicating with both.
    local_barrier(left_neighbor, right_neighbor)

    # Initialize o_ref, acc_scratch, and hbm_scratch.
    o_ref[...] = jnp.zeros_like(o_ref)
    receive_scratch[...] = jnp.zeros_like(receive_scratch)
    initial_copy = pltpu.make_async_remote_copy(
        src_ref=x_ref,
        dst_ref=hbm_scratch.at[working_slot],
        send_sem=remote_send_sem,
        recv_sem=remote_recv_sem,
        device_id=(right_neighbor,),
        device_id_type=pltpu.DeviceIdType.MESH,
    )
    initial_copy.start()
    initial_copy.wait()

  # Signal to our left neighbor that we are ready to receive.
  # Without this signal, our left neighbor can be >=1 iteration ahead,
  # meaning it could write into our working slot.
  pltpu.semaphore_signal(
      capacity_sem,
      inc=1,
      device_id=(left_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  # Copy the partial result our left neighbor sent to us into VMEM for
  # computation.
  local_copy = pltpu.make_async_copy(
      src_ref=hbm_scratch.at[working_slot],
      dst_ref=receive_scratch,
      sem=copy_sem,
  )
  local_copy.start()

  # Block until our right neighbor is ready to receive.
  pltpu.semaphore_wait(capacity_sem, 1)
  # Pass the value to our right neighbor.
  remote_copy = pltpu.make_async_remote_copy(
      src_ref=hbm_scratch.at[working_slot],
      dst_ref=hbm_scratch.at[receiving_slot],
      send_sem=remote_send_sem,
      recv_sem=remote_recv_sem,
      device_id=(right_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )
  remote_copy.start()
  # Finish local copy and accumulate while remote_copy is happening.
  local_copy.wait()
  o_ref[...] += receive_scratch[...]
  # Block until remote copy finishes.
  remote_copy.wait()


out_shape = (
    jax.ShapeDtypeStruct((8, 128), jnp.float32),
    # We allocate the double-buffer as a Pallas output so that it is
    # resident in HBM.
    jax.ShapeDtypeStruct((2, 8, 128), jnp.float32),  # hbm_scratch
)

grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=0,
    in_specs=[
        # Our input lives in VMEM
        pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),
    ],
    out_specs=[
        # Our output lives in VMEM
        pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),
        # Our double-buffer lives in HBM
        pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
    ],
    grid=(num_devices,),
    scratch_shapes=(
        [pltpu.SemaphoreType.DMA] * 3
        + [pltpu.SemaphoreType.REGULAR]  # capacity_sem
        + [pltpu.VMEM((8, 128), jnp.float32)]  # receive_scratch
    ),
)

kernel = pl.pallas_call(
    all_reduce_kernel,
    out_shape=out_shape,
    grid_spec=grid_spec,
    compiler_params=pltpu.CompilerParams(collective_id=0),
)

pallas_result = jax.jit(
    jax.shard_map(
        kernel,
        mesh=mesh,
        in_specs=partition,
        out_specs=partition,
        check_vma=False,
    )
)(input_arr)
pallas_result = jax.block_until_ready(pallas_result)[0]


def lax_sum(x):
  return lax.psum(x, 'x')


xla_result = jax.jit(
    jax.shard_map(
        lax_sum, mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x')
    )
)(input_arr)

print('Input = ', input_arr[0, ::128])
print('Pallas result = ', pallas_result[0, ::128])
print('lax.psum result = ', xla_result[0, ::128])
difference = jnp.mean(jnp.abs(pallas_result - xla_result))
print('Difference |Pallas - lax.psum| = ', difference)
Input =  [0.9858954  0.11763906 0.9955574  0.775211  ]
Pallas result =  [2.8743029 2.8743029 2.8743029 2.8743029]
lax.psum result =  [2.8743029 2.8743029 2.8743029 2.8743029]
Difference |Pallas - lax.psum| =  1.0535587e-08

超前执行与竞态条件#

作为一项经验法则,为了最大限度地提高性能,我们希望允许设备在不牺牲程序正确性的前提下,尽可能多地在没有同步的情况下超前于其他设备运行。虽然我们可以在每次迭代开始时强制所有设备进行屏障同步,但这会将程序的性能瓶颈限制在每个循环中最慢的设备上。通过放宽同步并允许适度的超前执行,我们可以更好地适应每次迭代和设备之间延迟的差异,因为在一个迭代中速度较慢的设备可以在下一个迭代中赶上来。

在我们之前编写的 all-reduce 内核中,我们允许设备超前,但比其邻居少一个迭代(但是,非邻居设备可能相差一个多迭代)。要了解为什么需要信号量同步,请考虑一个设备(例如设备 2)挂起并落后于其他设备的情况。RDMA 没有“握手”——只有接收方在等待数据到达时被阻塞。因此,每个设备可以超前最多一个迭代,然后再阻塞等待下一个 RDMA 到达。如果我们有 N 个设备,这意味着最后一个设备可以比第一个设备超前 N 个迭代。

race_condition

如果不向另一个方向添加同步(强制发送方阻塞),设备 1 可能会超前设备 2 最多 N 个迭代(N = num_devices),从而在过程中写入多个值并覆盖现有数据。为解决此问题,在我们之前编写的 all_reduce 内核中,我们实现了一个“握手”协议,其中接收方向发送方发出信号表明它已准备好接收,然后发送方才开始发出下一个 RDMA。

双向通信#

在我们之前的内核中,我们沿着环形从左到右单向通信。然而,由于 ICI 连接是双向的,我们通过不沿相反方向(从右到左)发送值,实际上浪费了一半的总带宽。在下一个内核中,我们将演示一个双向通信以最大化 ICI 带宽的示例。

示例:双向 Reduce-Scatter (lax.psum_scatter)#

reduce-scatter 操作是 all-reduce 和 scatter 的组合。或者,all-reduce 是 reduce-scatter 和 all-gather 的组合。

下图描绘了此操作的语义。我们假设每个设备都从一个部分和集合开始(用字母+数字表示,例如 A0)。目标是沿一个轴(数字)进行规约,同时沿另一个轴(字母)进行分片。

reduce_scatter_1

为了实现双向通信策略,我们将每个输入块分成两半,并为每一半指定一个方向。每个块的上半部分将从右到左传递,下半部分将从左到右传递。与我们之前的 all-reduce 和 all-gather 内核通信模式的第二个不同之处在于,我们还将传递累加器或部分和,并将输入保留在每个设备的本地。这与之前的示例不同,在那些示例中我们传递输入但将累加器保留在本地设备上。传递累加器更适合此问题,因为与 all-reduce 不同,输入中的大部分数据不是设备本地将要存储的输出的一部分。(例如,上面图中的 B0C0D0 将不会存储在持有 A 的设备上)。

下图说明了这种通信模式,其中彩色框代表累加器(不是输入!)。最初,累加器只是输入中包含的值。在算法的每次迭代中,我们将从左右邻居接收一个部分和。然后,我们计算输入中需要累加到部分缓冲区中的正确切片,然后将新的部分和传递给下一个邻居。经过 N 次迭代后,累加器将经过每个设备,这意味着它最终将包含完整的总和。

reduce_scatter_2

在内核构建方面,我们引入了一个额外的 phase 维度到 Pallas 网格中,该维度表示我们当前正在计算哪个累加器(左或右)。我们将 phase=0 表示累加器向左移动,phase=1 表示累加器向右移动。然后我们对这两个阶段进行流水线化,这样在计算一个阶段的结果时,我们正在将先前计算的值在相反方向上传输,为下一个阶段做准备。例如,当我们在 phase=0(左)时,我们首先开始一个 DMA 将我们在前一次迭代中计算的结果传输到我们的右邻居(右 DMA)。然后,我们累加到左缓冲区并将结果保存到 HBM。然后我们等待右 DMA 完成,以便为 phase=1(右)做好准备。

partition = P(None, 'x')
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)

# We need a block size of (16, 128) to ensure that a half-slice is at least
# of size (8, 128), which is the size of a VREG. This makes tiling easier
# for the compiler.
block_size = (16, 128)
input_arr = jax.random.uniform(
    jax.random.key(0),
    shape=(block_size[0] * num_devices, block_size[1] * num_devices),
)
input_arr = jax.device_put(input_arr, sharding)

LEFT = 0
RIGHT = 1


def mod(x, n):
  return lax.rem(x + n, n)


def signal(left_or_right, semaphore):
  my_id = lax.axis_index('x')
  if left_or_right == LEFT:
    neighbor = mod(my_id - 1, num_devices)
  else:
    neighbor = mod(my_id + 1, num_devices)
  pltpu.semaphore_signal(
      semaphore,
      inc=1,
      device_id=(neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )


def reduce_scatter_kernel(
    x_ref,
    o_ref,
    hbm_scratch,
    local_copy_sem,
    left_recv_sem,
    left_send_sem,
    right_recv_sem,
    right_send_sem,
    left_capacity_sem,
    right_capacity_sem,
    accum_scratch,
):
  outer_step = pl.program_id(0)
  phase = pl.program_id(1)
  is_start = jnp.logical_and(outer_step == 0, phase == 0)
  last_iteration = outer_step == pl.num_programs(0) - 1

  working_slot = lax.rem(outer_step, 2)
  receiving_slot = 1 - working_slot
  my_id = lax.axis_index('x')
  right_neighbor = mod(my_id + 1, num_devices)
  left_neighbor = mod(my_id - 1, num_devices)

  left_copy_device = mod(my_id + outer_step + 1, num_devices)
  right_copy_device = mod(my_id - outer_step - 1, num_devices)
  # Slices can be specified using pl.ds(start, size)
  left_copy_slice = pl.ds(0, block_size[0] // 2)
  right_copy_slice = pl.ds(block_size[0] // 2, block_size[0] // 2)
  current_phase_slice = pl.ds(phase * (block_size[0] // 2), block_size[0] // 2)

  initial_left_copy = pltpu.make_async_remote_copy(
      src_ref=x_ref.at[my_id, left_copy_slice],
      dst_ref=hbm_scratch.at[working_slot, left_copy_slice],
      send_sem=left_send_sem,
      recv_sem=left_recv_sem,
      device_id=(left_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  initial_right_copy = pltpu.make_async_remote_copy(
      src_ref=x_ref.at[my_id, right_copy_slice],
      dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
      send_sem=right_send_sem,
      recv_sem=right_recv_sem,
      device_id=(right_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  left_copy = pltpu.make_async_remote_copy(
      src_ref=hbm_scratch.at[working_slot, left_copy_slice],
      dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],
      send_sem=left_send_sem,
      recv_sem=left_recv_sem,
      device_id=(left_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )
  right_copy = pltpu.make_async_remote_copy(
      # Note: Right copy is flipped with regards to slots since we are copying
      # to the next outer_step iteration.
      src_ref=hbm_scratch.at[receiving_slot, right_copy_slice],
      dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
      send_sem=right_send_sem,
      recv_sem=right_recv_sem,
      device_id=(right_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  # --- Prologue ---
  @pl.when(is_start)
  def _():
    # Barrier with both neighbors at the start, since we will be
    # communicating with both.
    local_barrier(left_neighbor, right_neighbor)

    # Initialize o_ref, acc_scratch, and hbm_scratch with initial copies.
    o_ref[...] = jnp.zeros_like(o_ref[...])
    accum_scratch[...] = jnp.zeros_like(accum_scratch[...])

    initial_left_copy.start()
    initial_left_copy.wait()
    initial_right_copy.start()

    # We tell our left neighbor that it is allowed to send to the right.
    # (and vice versa for right neighbor)
    signal(LEFT, right_capacity_sem)
    signal(RIGHT, left_capacity_sem)

  # --- Body ---
  # At the beginning of our kernel body, we start a DMA which copies
  # the result we computed in the previous phase to our neighbor.
  # This allows us to overlap the communication of sending our previous phase
  # with the computation for the current phase.
  @pl.when(~is_start)
  def _():
    @pl.when(phase == LEFT)
    def _():
      # We block here until our right neighbor tells use we can send to
      # the right.
      pltpu.semaphore_wait(right_capacity_sem, 1)
      right_copy.start()

    @pl.when(phase == RIGHT)
    def _():
      # We block here until our left neighbor tells use we can send to
      # the left.
      pltpu.semaphore_wait(left_capacity_sem, 1)
      left_copy.start()

  local_copy = pltpu.make_async_copy(
      src_ref=hbm_scratch.at[working_slot, current_phase_slice],
      dst_ref=accum_scratch,
      sem=local_copy_sem,
  )
  local_copy.start()
  local_copy.wait()

  @pl.when(~last_iteration)
  def _():
    @pl.when(phase == LEFT)
    def _():
      accum_scratch[...] += x_ref[left_copy_device, left_copy_slice]

    @pl.when(phase == RIGHT)
    def _():
      accum_scratch[...] += x_ref[right_copy_device, right_copy_slice]

  local_copy = pltpu.make_async_copy(
      src_ref=accum_scratch,
      dst_ref=hbm_scratch.at[working_slot, current_phase_slice],
      sem=local_copy_sem,
  )
  local_copy.start()
  local_copy.wait()

  @pl.when(is_start)
  def _():
    initial_right_copy.wait()

  # At the end of our kernel body, we wait on the DMA of the previous phase
  # to make sure the results are ready for the next phase.
  @pl.when(~is_start)
  def _():
    @pl.when(phase == LEFT)
    def _():
      right_copy.wait()
      signal(LEFT, right_capacity_sem)

    @pl.when(phase == RIGHT)
    def _():
      left_copy.wait()
      signal(RIGHT, left_capacity_sem)

  # --- Epilogue ---
  # Store result on last iteration.
  @pl.when(last_iteration)
  def _():
    # Clean up semaphores so that they exit with a value of 0.
    @pl.when(phase == LEFT)
    def _():
      o_ref[left_copy_slice, ...] = accum_scratch[...]
      pltpu.semaphore_wait(right_capacity_sem, 1)

    @pl.when(phase == RIGHT)
    def _():
      o_ref[right_copy_slice, ...] = accum_scratch[...]
      pltpu.semaphore_wait(left_capacity_sem, 1)


out_shape = (
    jax.ShapeDtypeStruct((block_size[0], block_size[1]), jnp.float32),  # output
    # Shape: [working/recv, block[0], block[1]]
    jax.ShapeDtypeStruct(
        (2, block_size[0], block_size[1]), jnp.float32
    ),  # hbm_scratch
)

grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=0,
    in_specs=[
        pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),
    ],
    out_specs=[
        pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),
        pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
    ],
    grid=(num_devices, 2),
    scratch_shapes=(
        [pltpu.SemaphoreType.DMA] * 5
        + [pltpu.SemaphoreType.REGULAR] * 2  # Capacity semaphores
        + [
            pltpu.VMEM((block_size[0] // 2, block_size[1]), jnp.float32)
        ]  # accum_scratch
    ),
)


def pallas_reduce_scatter(input_arr):
  input_arr = input_arr.reshape(num_devices, block_size[0], block_size[1])
  return pl.pallas_call(
      reduce_scatter_kernel,
      out_shape=out_shape,
      grid_spec=grid_spec,
      compiler_params=pltpu.CompilerParams(collective_id=0),
  )(input_arr)[0]


pallas_result = jax.jit(
    jax.shard_map(
        pallas_reduce_scatter,
        mesh=mesh,
        in_specs=P(None, 'x'),
        out_specs=P('x', None),
        check_vma=False,
    )
)(input_arr)

pallas_result = jax.block_until_ready(pallas_result)
# Compare our result to XLA.
def lax_reduce_sum_scatter(x):
  x = x.reshape(num_devices, block_size[0], block_size[1])
  return lax.psum_scatter(x, 'x')


xla_result = jax.jit(
    jax.shard_map(
        lax_reduce_sum_scatter,
        mesh=mesh,
        in_specs=P(None, 'x'),
        out_specs=P('x', None),
    )
)(input_arr)

print('Input:', input_arr.shape, input_arr[::4, 0])
print('Pallas Result:', pallas_result.shape, pallas_result[::4, 0])
print('lax.psum_scatter Result:', xla_result.shape, xla_result[::4, 0])
print(
    'Difference |Pallas - lax.psum_scatter|:',
    jnp.max(jnp.abs(pallas_result - xla_result)),
)
Input: (64, 512) [0.78051674 0.3524047  0.59993696 0.9714314  0.24692321 0.01347649
 0.01857424 0.24841607 0.86097646 0.8261659  0.9753758  0.6902338
 0.4431417  0.963323   0.3158517  0.535548  ]
Pallas Result: (64, 128) [1.3593563 1.6274805 1.0979297 3.082869  1.4194957 1.4163033 1.2401303
 1.1892898 2.6545286 2.221559  2.7995253 2.08431   2.2509837 3.0726733
 2.4662397 1.9542246]
lax.psum_scatter Result: (64, 128) [1.3593563 1.6274805 1.0979297 3.082869  1.4194957 1.4163033 1.2401303
 1.1892898 2.6545286 2.221559  2.7995253 2.08431   2.2509837 3.0726733
 2.4662397 1.9542246]
Difference |Pallas - lax.psum_scatter|: 2.3841858e-07

嵌套远程和本地 DMA 流水线#

我们之前编写的 all-reduce 和 reduce-scatter 内核的一个限制是,通过远程 DMA 复制的块必须足够小,能够放入我们用于累加的工作 VMEM 中。对于某些内核,使用更大的块大小以更好地利用 TPU 可能是有利的。例如,矩阵乘法需要约 \(O(N^3)\) 的计算操作,但仅 \(O(N^2)\) 的内存传输。因此,我们希望设备之间传输的每个工作块都足够大,以便操作成为计算密集型,并且我们可以通过流水线隐藏通信成本。作为参考,TPU 的 VMEM(对于 v4/v5 代)通常在 10-100MB 范围内,而 HBM 范围在 10-100GB 范围内。

为了解决这个问题,我们需要能够编写一个“内部内核”来处理远程 HBM-HBM 传输的“外部内核”中的本地 HBM-VMEM 流水线。Pallas 提供了用于使用 emit_pipeline 函数构建嵌套流水线的 API。有关 emit_pipeline 的通用概述,请参阅 TPU 流水线 指南。由于我们的外部内核仅涉及远程 HBM-HBM 传输,因此我们没有使用 pallas_call 为 HBM-VMEM 传输提供的任何内置流水线。以下代码骨架演示了使用此模式的典型程序结构:


def outer_kernel(...):
  # ... do work to pipeline remote HBM-HBM transfers (outer kernel)

  def inner_kernel(...):
    # ... do work (inner kernel)
  pltpu.emit_pipeline(
          inner_kernel,
          grid=inner_grid,
          in_specs=...,
          out_specs=...,
  )(inner_kernel_args)
  # ... do more work (outer kernel)

pl.pallas_call(
  outer_kernel,
  grid=outer_grid,
  in_specs=...
  out_specs=...
  scratch=inner_kernel_allocs
)

示例:大 HBM 块的 Reduce-Scatter#

在下一个示例中,我们将修改我们之前的 reduce-scatter 示例以利用嵌套的内部流水线。请注意,reduce_scatter 的通信和计算成本都随输入大小线性增长,因此我们不一定期望操作在较大的块大小下变得计算密集。此示例纯粹是为了演示如何使用流水线发射器。

我们将增加外部内核的块大小,使其不适合放入 VMEM,并将所有输入和输出分配到 HBM(memory_space=MemorySpace.ANY)。与我们之前的内核相比,唯一的重大变化是内核的主体,其中进行了累加。我们不手动从 HBM 复制到 VMEM、累加和复制回 HBM,而是使用 emit_pipeline 来处理内存传输。累加是在一个具有更小、适合 VMEM 的块大小的内部内核中完成的。

在我们之前的内核中,我们有以下内核主体,用于将数据从 HBM 复制到 VMEM 累加器、增加值,然后将结果复制回 HBM:

local_copy = pltpu.make_async_copy(
    src_ref=hbm_scratch.at[working_slot, current_phase_slice],
    dst_ref=accum_scratch,
    sem=local_copy_sem,
)
local_copy.start()
local_copy.wait()
@pl.when(~last_iteration)
def _():
  @pl.when(phase == LEFT)
  def _():
    accum_scratch[...] += x_ref[left_copy_device, left_copy_slice]
  @pl.when(phase == RIGHT)
  def _():
    accum_scratch[...] += x_ref[right_copy_device, right_copy_slice]
local_copy = pltpu.make_async_copy(
    src_ref=accum_scratch,
    dst_ref=hbm_scratch.at[working_slot, current_phase_slice],
    sem=local_copy_sem,
)
local_copy.start()
local_copy.wait()

我们的新内核用以下 emit_pipeline 调用替换了它:

def inner_kernel(input_ref, accum_ref):
  accum_ref[...] = input_ref[...]
accum_pipeline = pltpu.emit_pipeline(inner_kernel,
                                     in_specs=[inner_block_spec],
                                     out_specs=inner_block_spec,
                                     should_accumulate_out=True,
                                     grid=inner_grid)
@pl.when(~last_iteration)
def _():
  @pl.when(phase == LEFT)
  def _():
    accum_pipeline(x_ref.at[left_copy_device, left_copy_slice],
                   hbm_scratch.at[working_slot, left_copy_slice],
    )
  @pl.when(phase == RIGHT)
  def _():
    accum_pipeline(x_ref.at[right_copy_device, right_copy_slice],
                   hbm_scratch.at[working_slot, right_copy_slice],
    )

完整的内核如下:

partition = P(None, 'x')
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)

# We pick a large outer kernel block size that we do not want to place
# in VMEM. For pedagogical purposes we use (4096, 4096), although in
# principle this can be much larger.
outer_block_size = (4096, 4096)
# We pick a smaller VMEM block size for the inner kernel.
inner_block_size = (128, 128)
input_arr = jax.random.uniform(
    jax.random.key(0),
    shape=(
        outer_block_size[0] * num_devices,
        outer_block_size[1] * num_devices,
    ),
)
input_arr = jax.device_put(input_arr, sharding)


inner_grid = (
    outer_block_size[0] // inner_block_size[0] // 2,
    outer_block_size[1] // inner_block_size[1],
)
inner_block_spec = pl.BlockSpec(
    index_map=lambda i, j: (i, j),
    block_shape=inner_block_size,
    memory_space=pltpu.MemorySpace.ANY,
)


def reduce_scatter_kernel(
    x_ref,
    o_ref,
    hbm_scratch,
    left_recv_sem,
    left_send_sem,
    copy_sem,
    right_recv_sem,
    right_send_sem,
    left_capacity_sem,
    right_capacity_sem,
):
  outer_step = pl.program_id(0)
  phase = pl.program_id(1)
  is_start = jnp.logical_and(outer_step == 0, phase == 0)
  last_iteration = outer_step == pl.num_programs(0) - 1

  working_slot = lax.rem(outer_step, 2)
  receiving_slot = 1 - working_slot
  my_id = lax.axis_index('x')
  right_neighbor = mod(my_id + 1, num_devices)
  left_neighbor = mod(my_id - 1, num_devices)

  left_copy_device = mod(my_id + outer_step + 1, num_devices)
  right_copy_device = mod(my_id - outer_step - 1, num_devices)
  left_copy_slice = pl.ds(0, outer_block_size[0] // 2)
  right_copy_slice = pl.ds(outer_block_size[0] // 2, outer_block_size[0] // 2)
  current_phase_slice = pl.ds(
      phase * (outer_block_size[0] // 2), outer_block_size[0] // 2
  )

  initial_left_copy = pltpu.make_async_remote_copy(
      src_ref=x_ref.at[my_id, left_copy_slice],
      dst_ref=hbm_scratch.at[working_slot, left_copy_slice],
      send_sem=left_send_sem,
      recv_sem=left_recv_sem,
      device_id=(left_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  initial_right_copy = pltpu.make_async_remote_copy(
      src_ref=x_ref.at[my_id, right_copy_slice],
      dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
      send_sem=right_send_sem,
      recv_sem=right_recv_sem,
      device_id=(right_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  left_copy = pltpu.make_async_remote_copy(
      src_ref=hbm_scratch.at[working_slot, left_copy_slice],
      dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],
      send_sem=left_send_sem,
      recv_sem=left_recv_sem,
      device_id=(left_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )
  right_copy = pltpu.make_async_remote_copy(
      src_ref=hbm_scratch.at[receiving_slot, right_copy_slice],
      dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
      send_sem=right_send_sem,
      recv_sem=right_recv_sem,
      device_id=(right_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  # --- Prologue ---
  @pl.when(is_start)
  def _():
    # Barrier with both neighbors at the start, since we will be
    # communicating with both.
    local_barrier(left_neighbor, right_neighbor)

    initial_left_copy.start()
    initial_left_copy.wait()
    initial_right_copy.start()

    # We tell our left neighbor that it is allowed to send to the right.
    # (and vice versa for right neighbor)
    signal(LEFT, right_capacity_sem)
    signal(RIGHT, left_capacity_sem)

  @pl.when(~is_start)
  def _():
    @pl.when(phase == LEFT)
    def _():
      # We block here until our right neighbor tells use we can send to
      # the right.
      pltpu.semaphore_wait(right_capacity_sem, 1)
      right_copy.start()

    @pl.when(phase == RIGHT)
    def _():
      # We block here until our left neighbor tells use we can send to
      # the left.
      pltpu.semaphore_wait(left_capacity_sem, 1)
      left_copy.start()

  # --- Body ---
  def inner_kernel(input_ref, accum_ref):
    # We do not explicitly use += because we set should_accumulate_out=True.
    accum_ref[...] = input_ref[...]

  accum_pipeline = pltpu.emit_pipeline(
      inner_kernel,
      in_specs=[inner_block_spec],
      out_specs=inner_block_spec,
      should_accumulate_out=True,
      grid=inner_grid,
  )

  @pl.when(~last_iteration)
  def _():
    @pl.when(phase == LEFT)
    def _():
      accum_pipeline(
          x_ref.at[left_copy_device, left_copy_slice],
          hbm_scratch.at[working_slot, left_copy_slice],
      )

    @pl.when(phase == RIGHT)
    def _():
      accum_pipeline(
          x_ref.at[right_copy_device, right_copy_slice],
          hbm_scratch.at[working_slot, right_copy_slice],
      )

  # --- Epilogue ---
  @pl.when(is_start)
  def _():
    initial_right_copy.wait()

  @pl.when(~is_start)
  def _():
    @pl.when(phase == LEFT)
    def _():
      right_copy.wait()
      signal(LEFT, right_capacity_sem)

    @pl.when(phase == RIGHT)
    def _():
      left_copy.wait()
      signal(RIGHT, left_capacity_sem)

  # Store result on last iteration.
  @pl.when(last_iteration)
  def _():
    output_copy = pltpu.make_async_copy(
        src_ref=hbm_scratch.at[working_slot, current_phase_slice],
        dst_ref=o_ref.at[current_phase_slice],
        sem=copy_sem,
    )
    output_copy.start()
    output_copy.wait()

    # Clean up semaphores so that they exit with a value of 0.
    @pl.when(phase == LEFT)
    def _():
      pltpu.semaphore_wait(right_capacity_sem, 1)

    @pl.when(phase == RIGHT)
    def _():
      pltpu.semaphore_wait(left_capacity_sem, 1)


out_shape = (
    jax.ShapeDtypeStruct(
        (outer_block_size[0], outer_block_size[1]), jnp.float32
    ),
    # Shape: [working/recv, block[0], block[1]]
    jax.ShapeDtypeStruct(
        (2, outer_block_size[0], outer_block_size[1]), jnp.float32
    ),  # hbm_scratch
)

grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=0,
    in_specs=[
        pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
    ],
    out_specs=[
        pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
        pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
    ],
    grid=(num_devices, 2),
    scratch_shapes=(
        [pltpu.SemaphoreType.DMA] * 5
        + [pltpu.SemaphoreType.REGULAR] * 2  # Capacity semaphores
    ),
)


def pallas_reduce_scatter(input_arr):
  input_arr = input_arr.reshape(
      num_devices, outer_block_size[0], outer_block_size[1]
  )
  return pl.pallas_call(
      reduce_scatter_kernel,
      out_shape=out_shape,
      grid_spec=grid_spec,
      compiler_params=pltpu.CompilerParams(collective_id=0),
  )(input_arr)[0]


pallas_result = jax.jit(
    jax.shard_map(
        pallas_reduce_scatter,
        mesh=mesh,
        in_specs=P(None, 'x'),
        out_specs=P('x', None),
        check_vma=False,
    )
)(input_arr)

pallas_result = jax.block_until_ready(pallas_result)
# Now we compare our result to XLA.
def lax_reduce_sum_scatter(x):
  x = x.reshape(num_devices, outer_block_size[0], outer_block_size[1])
  return lax.psum_scatter(x, 'x')


xla_result = jax.jit(
    jax.shard_map(
        lax_reduce_sum_scatter,
        mesh=mesh,
        in_specs=P(None, 'x'),
        out_specs=P('x', None),
    )
)(input_arr)

print('Input:', input_arr.shape, input_arr[::4, 0])
print('Pallas Result:', pallas_result.shape, pallas_result[::4, 0])
print('lax.psum_scatter Result:', xla_result.shape, xla_result[::4, 0])
print(
    'Difference |Pallas - lax.psum_scatter|:',
    jnp.max(jnp.abs(pallas_result - xla_result)),
)
Input: (16384, 16384) [0.74162567 0.0242182  0.27751946 ... 0.05213022 0.36088037 0.04494429]
Pallas Result: (16384, 4096) [2.0648427 1.674587  1.9148926 ... 1.3371865 1.3296283 1.2887063]
lax.psum_scatter Result: (16384, 4096) [2.0648427 1.674587  1.9148926 ... 1.3371865 1.3296283 1.2887063]
Difference |Pallas - lax.psum_scatter|: 2.3841858e-07

最终说明#

Megacore#

某些 TPU 包含多核的 Megacore 配置。在此配置中,我们的通用建议是仅从一个核心启动 DMA,并且仅执行 HBM-HBM 传输。为此,请将一个网格轴设置为核心数量(可以通过 jax.devices()[0].num_cores 获取),并将 dimension_semantics 设置为 "parallel"。然后,您可以使用 core_index = pl.program_id(axis) 获取沿该轴的核心索引,并使用 @pl.when(core_index==i) 执行特定于该核心的代码。

与 XLA 的交互#

在本教程中,我们介绍了几个内核示例,它们复制了 JAX 中的集体操作的功能,例如 lax.all_gatherlax.psumlax.psum_scatter。一个重要的注意事项是,Pallas 内核对 XLA 编译器来说是有些不透明的,可能会导致它错过一些通常会执行的优化。例如,XLA 可以异步分派集体操作,以便在不编写自定义内核的情况下交织通信和计算。当 Pallas 内核涉及时,这不一定会被执行,因此对程序进行性能分析以确定这是否是一个问题很重要。另一个例子是,在本教程中用于生成嵌套流水线的 emit_pipeline 函数对 XLA 编译器是不可见的,因此无法与相邻的操作进行融合。

后续步骤#

对读者来说,优秀的后续练习可能包括实现分布式矩阵乘法、实现 lax.all_to_all,以及放宽同步以允许额外的超前执行。