Pallas 中用于 TPU 的分布式计算#

在本教程中,我们将介绍 Pallas 在 TPU 上进行分布式计算的基础知识。我们将学习 TPU 拓扑、使用远程 DMA 原语进行通信,以及如何使用 jax.shard_map 从 JAX 调用分布式内核。我们还将介绍一些更高级的内核编写技术,例如双缓冲、双向带宽优化和嵌套流水线。作为教学示例,我们将学习如何实现 JAX 中的各种集合操作原语,例如 lax.ppermutelax.all_gatherlax.psumlax.psum_scatter

建议提前阅读

import functools
import jax
from jax import lax
from jax import numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

P = jax.sharding.PartitionSpec

num_devices = jax.local_device_count()
assert num_devices > 1, "Please run this notebook with more than one device."
assert "TPU" in jax.devices()[0].device_kind, "Please run this notebook with TPU devices."
print(f"Running with {num_devices} {jax.devices()[0].device_kind} devices.")
Running with 4 TPU v4 devices.

TPU 拓扑#

TPU 通常以包含多个设备的 Pod 形式部署,这些设备通过高带宽的片间互连 (ICI) 连接,以实现 Pod 内的通信,其速度远快于典型的网络连接。例如,TPU v5p 的规格表显示,每个芯片的 ICI 带宽为 4.8Tb/s(作为参考,TPU v5p 还拥有 21Tb/s 的本地 HBM 带宽)。ICI 使我们能够实现需要 Pod 内高带宽通信的快速高效的分布式内核,并利用数据中心网络进行并行化,以处理带宽要求较低的操作,例如批处理维度上的数据并行。

TPU Pod 通常以 ND 环面拓扑结构排列。下图展示了几种不同大小的配置示例。

tpu_topologies

展平为图时,环面可以如下可视化。每条边(橙色或黑色)是两个设备之间的双向连接。您会经常听到与设备拓扑讨论相关的环——环面的一个关键特征是,当沿 Pod 的某个轴切片时,例如节点 [(0,1), (1, 1), (2, 1), (3, 1)][(0, 1), (1, 1)],我们有一个设备环。这是我们可以用来简化 Pod 内通信模式的特性。

tpu_torus

远程直接内存访问 (RDMA) 模型#

TPU 通过一种称为远程直接内存访问 (RDMA) 的推入式模型进行通信。TPU 可以发出复制指令,将数据从本地缓冲区推送到同一 Pod 内另一个设备上的任何缓冲区,该操作与主程序线程异步执行。然而,TPU 只能读取本地存储的数据。这与更传统的多核编程不同,后者可以读写共享内存中的值。

异步远程复制操作#

pltpu.make_async_remote_copy 函数用于创建远程 DMA 描述符对象,该对象参数化“发送”操作和“接收”操作。其签名如下:

 def make_async_remote_copy(
     src_ref: Ref,
     dst_ref: Ref,
     send_sem: Ref[SemaphoreType],
     recv_sem: Ref[SemaphoreType],
     device_id: int | tuple[int, ...],
     device_id_type: DeviceIdType
 ) -> AsyncCopyDescriptor:
  • src_ref 是包含您希望发送到另一个设备上的 dst_ref 的数据的本地 Ref(在任何内存空间中)。

  • dst_ref 是目标设备上将复制数据的远程 Ref(在任何内存空间中)。

  • send_sem 是一个 DMA 信号量,用于阻塞直到所有数据都已从 src_ref 发送完毕。

  • recv_sem 是一个 DMA 信号量,用于阻塞直到在 dst_ref 接收到预期数量的字节。DMA 的发送方将写入接收方的 recv_sem

  • device_id 是目标发送设备的设备 ID。

  • device_id_type 指定 device_id 的格式,可以是 LOGICAL 格式(整数设备 ID),也可以是 MESH 格式(逻辑设备网格的 ND 元组索引)。默认模式为 MESH。

make_async_remote_copy 返回一个描述符对象,您可以使用其 .start() 方法启动 DMA,并使用 .wait_send() 阻塞 send_sem,以及使用 .wait_recv() 阻塞 recv_sem(或使用 .wait() 阻塞两者)。如果设备仅预期发送数据,则仅调用 .start().wait_send() 即可;同样,如果设备仅接收数据,则仅调用 .wait_recv() 即可。如果使用 SPMD 模式,其中所有设备执行 DMA,则每个设备通常会同时调用 .start().wait()

dma_descriptor = make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id)
dma_descriptor.start() # Initiate the DMA (non-blocking).
# ... do other work
dma_descriptor.wait_send() # Block until all data has been sent.
dma_descriptor.wait_recv() # Block until all data has been received.

例如,让我们可视化一个 DMA,其中我们考虑 4 个设备(索引为 0、1、2、3)。我们考虑一种方案,设备 0 复制到设备 1,设备 2 和 3 相互复制。实际上,我们可以通过使用 @pl.when 根据设备 ID 进行分支来创建这种不对称的通信模式。

(1) 每个设备创建 DMA 描述符。设备 0、2 和 3 调用 .start() 以从 src_ref 启动 DMA。设备 1 跳过 .start() 并且不执行任何操作,例如通过使用 pl.when

rdma_start

(2) 由于 .start() 是非阻塞的,因此每个设备都可以在 DMA 传输过程中自由地执行其他计算。设备 0、2 和 3 调用 .wait_send() 来等待 send_sem,该调用将阻塞直到所有数据都已发送完毕。

rdma_send

(3) 最后,设备 1、2 和 3 将调用 .wait_recv() 来等待 recv_sem,直到所有数据都已到达 dst_ref

rdma_recv

上述通信模式可以写成如下形式:

def example_kernel(input_ref, output_ref, send_sem, recv_sem):
    device_id = lax.axis_index('x')
    copy_0_to_1 = pltpu.make_async_remote_copy(
        src_ref=input_ref,
        dst_ref=output_ref,
        send_sem=send_sem,
        recv_sem=recv_sem,
        device_id=1,
    )
    copy_2_to_3 = pltpu.make_async_remote_copy(
        src_ref=input_ref,
        dst_ref=output_ref,
        send_sem=send_sem,
        recv_sem=recv_sem,
        device_id=3,
    )
    copy_3_to_2 = pltpu.make_async_remote_copy(
        src_ref=input_ref,
        dst_ref=output_ref,
        send_sem=send_sem,
        recv_sem=recv_sem,
        device_id=2,
    )
    @pl.when(device_id == 0)
    def _():
      copy_0_to_1.start()
      copy_0_to_1.wait_send()
    @pl.when(device_id == 1)
    def _():
      copy_0_to_1.wait_recv()
    @pl.when(device_id == 2)
    def _():
      copy_2_to_3.start()
      copy_2_to_3.wait_send()
      copy_3_to_2.wait_recv()
    @pl.when(device_id == 3)
    def _():
      copy_3_to_2.start()
      copy_3_to_2.wait_send()
      copy_2_to_3.wait_recv()

DMA 信号量#

send_semrecv_sem 是专用于 DMA 的特殊类型信号量的实例。在将输入规格指定给 pallas_call 时,它们必须使用 tpu.SemaphoreType.DMA 类型进行分配。

在内部,DMA 信号量可以被视为整数值的进度跟踪器。在 DMA 启动时,本地设备将开始异步递增 send_sem 的值和接收方的 recv_sem。等待信号量将阻塞,直到信号量的值达到发送/接收的数据总字节数;当达到该值时,等待线程被释放,并且信号量的值将按相同数量递减。这意味着所有数据都已发送(对于 send_sem)或所有数据都已接收(对于 recv_sem)。信号量的值可以通过 pl.semaphore_read 读取,但请注意,值的底层语义可能会在不同硬件代之间发生变化(例如,该值可能不完全代表发送的字节数,尽管这是推理信号量行为时一个有用的心智模型)。

路由#

发送方可以将数据发送给同一 Pod 内的任何接收方,即使它们不共享直接连接(TPU v5e 是此规则的例外,其中设备只能路由到与其自身偏移量为 2 的幂次的设备)。TPU 具有内部路由机制,可以将数据沿路径传递到目标设备上的下一个设备。但是,不建议以这种方式进行通信,因为作为内核编写者,您无法控制网络争用。本教程中我们将介绍的示例通过仅将数据传输到相邻设备来最大程度地减少低效通信。

故障模式#

如果错误使用远程 DMA,您可能会遇到几种难以调试的故障模式。有缺陷的 DMA 使用的普遍症状是崩溃、挂起或静默数据损坏。

  • 如果信号量以无效的非零值退出程序,Pallas 将崩溃并退出程序。

  • 如果信号量被等待但接收到的字节数不足(即没有发送方,或者发送的数据小于接收设备上 dst_ref 的大小),程序可能会无限期地挂起,等待永远不会发送的字节。在这种情况下,程序需要重新启动。

  • 如果遇到竞态条件,如果发生两次同时写入或同时读写,可能会导致静默数据损坏。

上述问题的一些常见原因包括:

  • 如果设备调用 .wait_recv() 但没有其他设备向其发送数据,则内核可能会挂起。

  • 如果设备收到的字节数多于预期,也可能由于非零信号量状态而崩溃。如果收到的字节数少于预期,则可能会无限期挂起。

  • 如果 DMA 已启动但未等待信号量,则程序可能会因非零信号量状态而崩溃。

  • 如果两个设备复制到同一目的地,您可能会因竞态条件而遇到不确定的结果,或者因非零信号量状态而崩溃。

示例:右置换 (lax.ppermute)#

让我们深入了解一个非常基本的示例。我们将实现一个执行右置换的内核,其中每个设备将其数据切片发送给它的右侧邻居。

假设我们有一个包含 512 个元素的数组,我们将其分片为 4 个设备上每个 128 个元素的切片。每个设备将把它的切片传递给下一个设备,输出将包含相同的数据,但切片会旋转 1。这与 lax.ppermute 操作相同,其中置换设置为 (n, (n+1) % 4)

为了以分布式模式调用内核,我们将 pallas_call 包装在 shard_map 转换中。从那里,我们可以像编写普通单设备 Pallas 内核一样编写内核,只是现在我们可以访问远程 DMA 指令。JAX 集合操作原语,如 lax.axis_index,可以用于获取 device_id,该 ID 可用于通过引用传递给 shard_map 的相同命名轴名称来计算要复制到的目标设备。

partition = P(None, 'x')
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)

# Create an input array that shards the last dimension across
# all devices.
input_arr = jax.random.uniform(jax.random.key(0), (8, 128 * num_devices))
input_arr = jax.device_put(input_arr, sharding)


def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem):
  my_id = lax.axis_index('x')
  right_neighbor = lax.rem(my_id + 1, num_devices)
  remote_copy_op = pltpu.make_async_remote_copy(
      src_ref=input_ref,
      dst_ref=output_ref,
      send_sem=send_sem,
      recv_sem=recv_sem,
      device_id=(right_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )
  remote_copy_op.start()
  remote_copy_op.wait()


out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32)
grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=0,
    # MemorySpace.ANY will (usually) place the tensor in HBM.
    in_specs=[
        pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
    ],
    out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
    scratch_shapes=(
        # We allocate DMA semaphores in scratch memory.
        [pltpu.SemaphoreType.DMA] * 2
    ),
)
right_permute = pl.pallas_call(
    right_permute_kernel,
    out_shape=out_shape,
    grid_spec=grid_spec,
)
# Wrap the kernel within a shard_map to call.
pallas_result = jax.jit(
    jax.shard_map(
        right_permute,
        mesh=mesh,
        in_specs=partition,
        out_specs=partition,
        check_vma=False,
    )
)(input_arr)

# Compare Pallas result to XLA shard_map result.
perm = tuple((src, (src + 1) % num_devices) for src in range(num_devices))

xla_result = jax.jit(
    jax.shard_map(
        lambda x: lax.ppermute(x, 'x', perm),
        mesh=mesh, in_specs=partition, out_specs=partition)
)(input_arr)

print('Input = ', input_arr[0, ::128])
print('Pallas Result = ', pallas_result[0, ::128])
print('lax.ppermute Result = ', xla_result[0, ::128])
print(
    'Difference |Pallas - lax.ppermute| = ',
    jnp.mean(jnp.abs(pallas_result - xla_result)),
)
Input =  [0.9858954  0.11763906 0.9955574  0.775211  ]
Pallas Result =  [0.775211   0.9858954  0.11763906 0.9955574 ]
lax.ppermute Result =  [0.775211   0.9858954  0.11763906 0.9955574 ]
Difference |Pallas - lax.ppermute| =  0.0

示例:全收集 (lax.all_gather)#

在下一个示例中,我们将实现全收集集合操作,其在 JAX 中有等效的 lax.all_gather。与上面仅涉及一对源和目标邻居的右置换示例不同,全收集操作需要所有设备之间的通信,因此我们必须考虑数据如何在它们之间路由。我们如何实现这一点的具体细节由设备拓扑决定,我们假设它是一个环形。

环形通信模式#

我们将假设环形拓扑来编写内核。环形非常适合 TPU,因为沿着环面的任何维度切片都会生成一个环。在编写集合操作时,我们通常只需要一次考虑环面的一维切片,因为环面的不同维度保留用于不同类型的并行(例如,数据并行与模型并行)。

我们将采用的策略是编写一个循环内核,在每次迭代中,一个设备从其左侧邻居接收分片数组的一个切片,并将先前接收到的切片复制到其右侧邻居。经过 num_devices 次迭代后,每个设备将在其本地 HBM 中拥有整个数组的副本。

all_gather

我们可以重新利用 Pallas 的 grid 参数来实现循环。我们不再像之前的教程那样遍历数组的瓦片,而是将网格设置为 (num_devices,),表示我们要循环遍历设备数量,并使用 pl.program_id 在 Pallas 内核内部获取循环迭代。以下代码片段演示了如何实现这一点:

partition = P('x', None)
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)

# Create an input array that shards the first dimension across
# all devices.
input_arr = jax.random.uniform(jax.random.key(0), (8 * num_devices, 128))
input_arr = jax.device_put(input_arr, sharding)


def all_gather_kernel(input_ref,
                      output_ref,
                      local_copy_sem,
                      send_sem,
                      recv_sems):
  outer_step = pl.program_id(0)
  my_id = lax.axis_index('x')
  right_neighbor = lax.rem(my_id + 1, num_devices)
  copy_slot = my_id - outer_step
  copy_slot = lax.rem(copy_slot + num_devices, num_devices)

  @pl.when(outer_step == 0)
  def _():
    local_copy_op = pltpu.make_async_copy(
      src_ref=input_ref,
      dst_ref=output_ref.at[my_id],
      sem=local_copy_sem,
    )
    local_copy_op.start()
    local_copy_op.wait()

  # Copy to our right neighbor.
  # Note that we will also be receiving data from our left neighbor,
  # but at `copy_slot-1` rather than `copy_slot`! This makes use of the fact
  # that the indices do not need to be symmetric between remote DMAs.
  remote_copy_op = pltpu.make_async_remote_copy(
      src_ref=output_ref.at[copy_slot],
      dst_ref=output_ref.at[copy_slot],
      send_sem=send_sem,
      recv_sem=recv_sems.at[outer_step],
      device_id=(right_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )
  remote_copy_op.start()
  remote_copy_op.wait()

out_shape = jax.ShapeDtypeStruct((num_devices, 8, 128), jnp.float32)
grid_spec = pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=[
                # MemorySpace.ANY will (usually) place the tensor in HBM.
                pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
            ],
            out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
            scratch_shapes=(
              # DMA semaphores are allocated in scratch memory.
              # We allocated one semaphore for a local HBM-VMEM copy,
              # and one for the remote send semaphore.
              [pltpu.SemaphoreType.DMA] * 2
              # We additionally allocate one receive semaphore per device.
              # This is to avoid situations where we have multiple
              # DMAs in flight, as we do not want to share a receive
              # semaphore between the DMAs.
              + [pltpu.SemaphoreType.DMA((num_devices-1,))]

            ),
            grid=(num_devices-1,)
        )

all_gather = pl.pallas_call(
      all_gather_kernel,
      out_shape=out_shape,
      grid_spec=grid_spec,
  )

# Wrap the kernel within a shard_map to call.
pallas_result = jax.jit(
      jax.shard_map(
          all_gather,
          mesh=mesh,
          in_specs=partition,
          out_specs=partition,
          check_vma=False
      )
)(input_arr)

# Compare Pallas result to XLA shard_map result.
xla_result = jax.jit(
    jax.shard_map(
        lambda x: lax.all_gather(x, 'x'),
        mesh=mesh, in_specs=partition, out_specs=partition
    )
)(input_arr)

print('Input: ', input_arr.shape, input_arr[::8, 0])
print('Pallas Result: ', pallas_result.shape, pallas_result[:, 0, 0])
print('lax.all_gather Result: ', xla_result.shape, xla_result[:, 0, 0])
print('Difference |Pallas - lax.all_gather| = ',
      jnp.mean(jnp.abs(pallas_result - xla_result)))
Input:  (32, 128) [0.9858954  0.54248166 0.9547038  0.954962  ]
Pallas Result:  (16, 8, 128) [0.9858954  0.54248166 0.9547038  0.954962   0.9858954  0.54248166
 0.9547038  0.954962   0.9858954  0.54248166 0.9547038  0.954962
 0.9858954  0.54248166 0.9547038  0.954962  ]
lax.all_gather Result:  (16, 8, 128) [0.9858954  0.54248166 0.9547038  0.954962   0.9858954  0.54248166
 0.9547038  0.954962   0.9858954  0.54248166 0.9547038  0.954962
 0.9858954  0.54248166 0.9547038  0.954962  ]
Difference |Pallas - lax.all_gather| =  0.0

这里值得一提的一个细节是使用了多个接收信号量。因为我们只在接收设备上阻塞,发送方仍然可能在接收方完成处理第一个 DMA 之前发送了多个正在进行的 DMA(有关竞态条件的详细讨论,请参阅下一节和规约求和示例)。在这种情况下,我们可能会遇到相同信号量同时用于多个 DMA 的情况。为了避免这种情况,我们分配了 num_devices-1 个信号量,因此没有重复使用的风险。虽然这种竞态条件在这种小型内核上不太可能发生,但在大型内核上,设备更容易不同步并可能导致静默故障。

高级技术#

现在我们已经了解了如何使用远程 DMA 操作编写几个基本内核,我们将介绍更高级的同步和高效内核编写技术。

同步:常规信号量和屏障信号量#

我们在基本教程中实现的示例不需要特殊处理同步,因为所有必要的通信都写入不相交的缓冲区。然而,其他操作可能需要更复杂的通信模式,需要额外的同步原语来避免竞态条件。Pallas 提供了两个额外的原语来帮助解决这个问题:常规信号量和屏障信号量。

常规信号量#

常规信号量是用于跨多个设备同步的标准工具。信号量本质上是计数器——任何设备都可以递增它们,之后设备可以阻塞直到信号量的值达到特定值(然后递减该值)。

常规信号量可用的三个主要操作是 signal、wait 和 read。

def semaphore_signal(
    sem: Ref[SemaphoreType],
    inc: int,
    device_id: int | tuple[int, ...],
    device_id_type: DeviceIdType
) -> None:
  ... # Increments the semaphore `sem` on the target device `device_id` by `inc`.
  
def semaphore_wait(
    semaphore: Ref[SemaphoreType],
    value: int,
) -> None:
  ... # Blocks until the locally allocated copy of `sem` reaches `value`, then decrement by `value` and proceed.
    
def semaphore_read(
    sem: Ref[SemaphoreType],
) -> jax.Array:
  ...  # Returns the current value of `sem` as an `int32[]`.

要使用常规信号量,它们的分配方式与 DMA 信号量相同,但需要指定 pltpu.SemaphoreType.REGULAR 而不是 pltpu.SemaphoreType.DMA

信号量在 Pallas 程序结束时必须为零才能成功完成。在以下两种错误情况下可能会发生这种情况:

  • 如果信号量过量发出信号,程序将以非零(>0)信号量结束。在这种情况下,程序将在完成时崩溃。这对于调试很有用,因为非零信号量通常意味着程序内部存在错误。

  • 如果信号量过度等待,程序将挂在阻塞的 semaphore_wait 调用上,因为它在等待信号量递增。在这种情况下,设备或程序将需要重新启动。

屏障信号量#

屏障信号量是全局分配的信号量,用于在整个程序中同步设备,并确保所有设备都已进入 Pallas 内核。

如果在更大的 XLA 程序上下文中执行 Pallas 内核,我们需要确保所有通信设备都已进入内核。然而,DMA 和常规信号量都是本地范围的——它们只被进入内核的其他设备理解。屏障信号量作为一种全局理解的信号量,无论设备当前在 XLA 程序的何处执行,都可以用于同步。

默认情况下,如果您不指定屏障信号量,Pallas 将在程序开始时自动插入一个屏障信号量。然而,编写自己的屏障信号量可能更有效率。屏障信号量与常规信号量类似,它们都是可以通过 semaphore_signal 递增并通过 semaphore_wait 递减的计数器。它们通过在内核中调用 get_barrier_semaphore() 来创建。通常,我们在内核开始时使用屏障来与所有通信设备同步。

from jax.experimental.pallas import tpu as pltpu

def example_kernel(...):
  # Use barrier semaphores at the beginning of a kernel.
  # is_start_of_kernel = ...
  # right_neighbor = ...
  # ...
  @pl.when(is_start_of_kernel)
  def _():
    barrier_sem = pltpu.get_barrier_semaphore()
    # Increment the semaphore of your right neighbor.
    pltpu.semaphore_signal(
          barrier_sem,
          device_id=right_neighbor,
          device_id_type=pltpu.DeviceIdType.LOGICAL,
    )
    # Wait until your left neighbor has incremented your semaphore
    pltpu.semaphore_wait(barrier_sem, 1)
  # ...

使用屏障信号量时,必须将 collective_id 编译器参数传递给 pallas_call 以指定正在使用的屏障信号量。TPU 具有少量固定的可用屏障信号量(通常在 20-30 个之间),因此应谨慎使用。为确保正确性,只有共享相同通信模式的内核才应使用相同的 collective_id。例如,如果两个内核仅在相同网格轴上的邻居之间同步,它们允许共享相同的 collective_id。但是,如果两个内核沿着不同的轴同步,它们必须具有不同的 collective_id。否则可能导致难以调试的竞态条件。

kernel = pl.pallas_call(
      example_kernel,
      ...,
      compiler_params=pltpu.CompilerParams(collective_id=0),
)

双缓冲#

为了避免从本地 Ref 读取时,另一个设备也正在写入该 Ref 从而产生竞态条件,一个有用的技术是“双缓冲”策略,即我们为每个目标值分配两个 Ref。在每次迭代中,一个 Ref 将被指定为“工作槽”,另一个将被指定为“接收槽”。设备可以自由使用工作槽进行计算,但只会将数据复制到其邻居的接收槽中。工作槽和接收槽每迭代交替一次,以便一旦复制完成,旧的接收槽成为新的工作槽,反之亦然。正确使用此方案,数据永远不会从同一缓冲区读取和写入。

以下代码骨架演示了如何使用双缓冲。我们保留变量 iteration 中的运行迭代计数器,并且 working_slotreceiving_slot 每迭代在 0 和 1 之间交替。dst_ref 被分配为双缓冲区,其大小为 [2, ...]。在每次迭代中,我们使用 dst_ref.at[working_slot, ...] 从工作槽读取数据,并使用该值执行计算。同时,我们复制到邻居的 dst_ref.at[receiving_slot] 以避免覆盖其 working_slot 值。通过以这种方式构建通信,可以重叠远程 DMA 的通信延迟与本地计算,同时最大程度地降低竞态条件的风险。

def kernel(...):
  # ...
  iteration = pl.program_id(0)
  working_slot = lax.rem(iteration, 2)
  receiving_slot = 1 - working_slot
  # ...

  local_copy_op = pltpu.make_async_copy(
    src_ref=dst_ref.at[working_slot, ...],
    dst_ref=local_scratch_ref,
    sem=local_copy_sem,
  )
  local_copy_op.start()
  remote_copy_op = pltpu.make_async_remote_copy(
    src_ref=src_ref,
    dst_ref=dst_ref.at[receiving_slot, ...],
    send_sem=send_sem,
    recv_sem=recv_sem,
    device_id=target_device,
    device_id_type=pltpu.DeviceIdType.MESH,
  )
  remote_copy_op.start()
  
  local_copy_op.wait()
  # ... do work on local_scratch while waiting for async_copy_op to finish.
  remote_copy_op.wait()

在同步方面,如果所有设备都在同一迭代上执行,则双缓冲结构可以工作。如果发送方设法比其接收方提前一个迭代,则其 working_slotreceiving_slot 索引将与接收方相比翻转,这意味着它可能在接收方读取 working_slot 值的同时写入它。为了避免这种情况,可能需要使用信号量来同步发送方与接收方,或添加额外的缓冲槽(“三缓冲”、“四缓冲”或 N 缓冲)以允许额外的超前执行,但代价是消耗更多内存。在我们之前的 all_gather 示例中,请注意内核包含一个带有 N 个槽的接收缓冲区,这完全避免了竞态条件。在我们的下一个内核中,我们将改为通过一个使用双缓冲和显式同步的示例。

示例:全规约求和 (lax.psum)#

我们现在将使用双缓冲和信号量进行同步来实现一个全规约求和内核。对于熟悉 JAX 中集合操作的人来说,等效操作是 lax.psum。全规约是一种标准集合操作,其目标是沿着数组的一个轴进行规约,但数组是跨多个设备分片的。

reduce_sum_1

在上面的示例中,我们将数组 [5, 2, 1, 3] 分片到 4 个设备上。全规约求和操作将对所有值求和并将结果复制到每个设备上,从而在所有 4 个设备上得到结果 [11, 11, 11, 11] 的分片。

全规约的朴素实现是将所有所需值收集到每个设备上,然后进行规约。然而,我们可以通过交错通信和计算来提高此实现的性能。单向交错的全规约可以可视化如下。在每次迭代中,我们从左邻居接收一个输入值,同时将输入传递给下一个邻居,并用我们的本地累加器递增它。经过 N-1 次迭代后,每个设备将在其内存中拥有完整和的副本。

reduce_sum_2

综合应用#

以下内核演示了如何将这些原理组合成一个功能性内核。

序言(当 outer_step==0 时执行)首先与两个邻居启动一个屏障,以确保它们也已进入内核。它还处理所有 Ref 的初始化,并处理到右邻居“工作”槽的第一次远程复制。

主体假定一个值已复制到我们的本地工作槽中,无论是来自上一次迭代还是来自序言。一个复杂因素是我们的目标缓冲区位于 HBM 中,但我们需要在执行算术运算之前将值加载到 VMEM 中。因此,我们同时将工作槽值复制到我们的 VMEM (receive_scratch) 中,并将其传递到我们右邻居的接收槽中。一旦该值已复制到我们的 VMEM 中,我们就可以将其累加到我们的结果(包含在 o_ref 中)。

如果一个设备比其右侧邻居超前一个循环,则可能会发生一个微妙的竞态条件。在这种情况下,它可能会在接收方读取其 working_slot 的同时复制到接收方的 working_slot 中。为了避免这种情况,每个设备在复制到右侧邻居的 dst_ref 之前,将阻塞在 REGULAR 信号量上,直到它已发出信号表明已完成从其 working_slot 读取。对于像本示例这样的小型内核,这种竞态条件很少被触发,但如果例如使用 pltpu.delay 指令来人为地挂起设备,则可以显式触发它。

请注意,这并非最优或完全通用的内核,因为块大小必须完全适合 VMEM,并且我们可以更好地交错通信和累积。我们将在后面的章节中讨论这些优化。

partition = P(None, 'x')
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)

input_arr = jax.random.uniform(jax.random.key(0), shape=(8, 128 * num_devices))
input_arr = jax.device_put(input_arr, sharding)


def local_barrier(left_neighbor, right_neighbor, double_barrier=True):
  """Performs a barrier with neighbors on the global barrier semaphore.

  Optionally performs a second barrier, which prevents a potential race
  when reusing the same collective_id across kernel invocations.
  """
  barrier_sem = pltpu.get_barrier_semaphore()
  for neighbor in [left_neighbor, right_neighbor]:
    pltpu.semaphore_signal(
      barrier_sem,
      inc=1,
      device_id=(neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
    )
  pltpu.semaphore_wait(barrier_sem, 2)
  if double_barrier:
    # The double-barrier prevents a race condition where one neighbor can
    # re-enter the kernel again on a subsequent call and increment the
    # barrier semaphore a second time. This would unblock the current device
    # even if the other neighbor is not ready yet.
    # To implement a double-barrier, we stack-allocate a second REGULAR
    # semaphore using run_scoped.
    @functools.partial(pl.run_scoped,
                       second_barrier=pltpu.SemaphoreType.REGULAR)
    def _(second_barrier):
      for neighbor in [left_neighbor, right_neighbor]:
        pltpu.semaphore_signal(
          second_barrier,
          inc=1,
          device_id=(neighbor,),
          device_id_type=pltpu.DeviceIdType.MESH,
        )
      pltpu.semaphore_wait(second_barrier, 2)


def all_reduce_kernel(
    x_ref,
    o_ref,
    hbm_scratch,
    copy_sem,
    remote_recv_sem,
    remote_send_sem,
    capacity_sem,
    receive_scratch,
):
  outer_step = pl.program_id(0)
  working_slot = lax.rem(outer_step, 2)
  receiving_slot = 1 - working_slot

  my_id = lax.axis_index('x')
  right_neighbor = lax.rem(my_id + 1, num_devices)
  left_neighbor = lax.rem(my_id - 1 + num_devices, num_devices)

  @pl.when(outer_step == 0)
  def _():
    # Barrier with both neighbors at the start, since we will be
    # communicating with both.
    local_barrier(left_neighbor, right_neighbor)

    # Initialize o_ref, acc_scratch, and hbm_scratch.
    o_ref[...] = jnp.zeros_like(o_ref)
    receive_scratch[...] = jnp.zeros_like(receive_scratch)
    initial_copy = pltpu.make_async_remote_copy(
        src_ref=x_ref,
        dst_ref=hbm_scratch.at[working_slot],
        send_sem=remote_send_sem,
        recv_sem=remote_recv_sem,
        device_id=(right_neighbor,),
        device_id_type=pltpu.DeviceIdType.MESH,
    )
    initial_copy.start()
    initial_copy.wait()

  # Signal to our left neighbor that we are ready to receive.
  # Without this signal, our left neighbor can be >=1 iteration ahead,
  # meaning it could write into our working slot.
  pltpu.semaphore_signal(
      capacity_sem,
      inc=1,
      device_id=(left_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  # Copy the partial result our left neighbor sent to us into VMEM for
  # computation.
  local_copy = pltpu.make_async_copy(
      src_ref=hbm_scratch.at[working_slot],
      dst_ref=receive_scratch,
      sem=copy_sem,
  )
  local_copy.start()

  # Block until our right neighbor is ready to receive.
  pltpu.semaphore_wait(capacity_sem, 1)
  # Pass the value to our right neighbor.
  remote_copy = pltpu.make_async_remote_copy(
      src_ref=hbm_scratch.at[working_slot],
      dst_ref=hbm_scratch.at[receiving_slot],
      send_sem=remote_send_sem,
      recv_sem=remote_recv_sem,
      device_id=(right_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )
  remote_copy.start()
  # Finish local copy and accumulate while remote_copy is happening.
  local_copy.wait()
  o_ref[...] += receive_scratch[...]
  # Block until remote copy finishes.
  remote_copy.wait()


out_shape = (
    jax.ShapeDtypeStruct((8, 128), jnp.float32),
    # We allocate the double-buffer as a Pallas output so that it is
    # resident in HBM.
    jax.ShapeDtypeStruct((2, 8, 128), jnp.float32),  # hbm_scratch
)

grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=0,
    in_specs=[
        # Our input lives in VMEM
        pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),
    ],
    out_specs=[
        # Our output lives in VMEM
        pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),
        # Our double-buffer lives in HBM
        pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
    ],
    grid=(num_devices,),
    scratch_shapes=(
        [pltpu.SemaphoreType.DMA] * 3
        + [pltpu.SemaphoreType.REGULAR]  # capacity_sem
        + [pltpu.VMEM((8, 128), jnp.float32)]  # receive_scratch
    ),
)

kernel = pl.pallas_call(
    all_reduce_kernel,
    out_shape=out_shape,
    grid_spec=grid_spec,
    compiler_params=pltpu.CompilerParams(collective_id=0),
)

pallas_result = jax.jit(
    jax.shard_map(
        kernel,
        mesh=mesh,
        in_specs=partition,
        out_specs=partition,
        check_vma=False,
    )
)(input_arr)
pallas_result = jax.block_until_ready(pallas_result)[0]


def lax_sum(x):
  return lax.psum(x, 'x')


xla_result = jax.jit(
    jax.shard_map(
        lax_sum, mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x')
    )
)(input_arr)

print('Input = ', input_arr[0, ::128])
print('Pallas result = ', pallas_result[0, ::128])
print('lax.psum result = ', xla_result[0, ::128])
difference = jnp.mean(jnp.abs(pallas_result - xla_result))
print('Difference |Pallas - lax.psum| = ', difference)
Input =  [0.9858954  0.11763906 0.9955574  0.775211  ]
Pallas result =  [2.8743029 2.8743029 2.8743029 2.8743029]
lax.psum result =  [2.8743029 2.8743029 2.8743029 2.8743029]
Difference |Pallas - lax.psum| =  1.0535587e-08

超前执行和竞态条件#

通常来说,为了最大化性能,我们希望设备在不牺牲程序正确性的前提下,尽可能多地在没有同步的情况下超前于其他设备执行。虽然我们可以在每次迭代开始时强制所有设备进行屏障同步,但这会使程序的性能受限于每次循环中最慢的设备。通过放松同步并允许适度的超前执行,我们可以更好地适应迭代和设备之间的延迟差异,因为在一个迭代中慢的设备可以在下一个迭代中赶上来。

在我们之前编写的全规约内核中,我们允许设备超前执行,但与其邻居相比,超前量小于一个迭代(然而,非邻居设备之间的迭代数可能相距超过 1)。为了理解为什么信号量同步是必要的,考虑一个设备(比如设备 2)挂起并落后于其他设备的情况。RDMA 没有“握手”——只有接收方在等待数据到达时才会被阻塞。因此,每个设备在被阻塞等待下一个 RDMA 到达之前,可以超前执行最多一个迭代。如果设备数量为 N,这意味着最后一个设备可以比第一个设备超前最多 N 个迭代。

race_condition

如果不添加另一个方向的同步(强制发送方阻塞),设备 1 可能会超前设备 2 最多 N 次迭代(N = num_devices),在此过程中发送多个写入并覆盖值。为了解决我们之前编写的 all_reduce 内核中的这个问题,我们实现了一个“握手”协议,其中接收方向发送方发出信号表明它已准备好接收,然后发送方才开始发出下一个 RDMA。

双向通信#

在我们之前的内核中,我们沿着环形结构以从左到右的单一方向进行通信。然而,由于 ICI 连接是双向的,通过不以从右到左的相反方向发送值,我们实际上浪费了一半的总带宽。在下一个内核中,我们将演示一个示例,该示例以双向通信来最大化 ICI 带宽。

示例:双向规约分散 (lax.psum_scatter)#

规约分散操作是全规约和分散的组合。或者,全规约是规约分散和全收集的组合。

下图描述了此操作的语义。我们假设每个设备都从一组部分和(由字母+数字表示,例如 A0)开始。目标是沿着一个轴(数字)进行规约,同时沿着另一个轴(字母)进行分片。

reduce_scatter_1

为了实现双向通信策略,我们将每个输入块一分为二,并为每一半指定一个方向。每个块的上半部分将从右向左传递,下半部分将从左向右传递。与我们之前所有规约和全收集内核的通信模式的第二个不同之处在于,我们还将传递累加器或部分和,并保持输入对每个设备来说是本地的。这与之前的示例形成了对比,在那些示例中我们传递输入但将累加器保留在设备本地。传递累加器更适合这个问题,因为与全规约不同,输入中的大部分数据不属于将存储在设备本地的输出。(例如,图中的 B0C0D0 在最后不会存储在持有 A 的设备上)。

下图展示了这种通信模式,其中彩色框代表累加器(而非输入!)。最初,累加器就是输入中包含的值。在算法的每次迭代中,我们将从每个方向的邻居那里收到一个部分和。然后,我们计算输入中正确的切片以累加到部分缓冲区中,然后将新的部分和传递给下一个邻居。经过 N 次迭代后,累加器将通过每个设备,这意味着它最终将持有完整的和。

reduce_scatter_2

在内核的构建方面,我们在 Pallas 网格中引入了一个额外的 phase 维度,表示我们当前正在计算哪个累加器(左或右)。我们让 phase=0 表示累加器向左移动,phase=1 表示累加器向右移动。然后我们对这两个阶段进行流水线处理,以便在计算一个阶段的结果时,我们将之前计算的值沿相反方向传输,为下一个阶段做准备。例如,当我们在 phase=0(左)时,我们首先开始一个 DMA 将我们在上一次迭代中计算的结果传输到我们的右邻居(右 DMA)。然后,我们累加到左缓冲区并将结果保存到 HBM。然后我们等待右 DMA 完成,以便它为 phase=1(右)做好准备。

partition = P(None, 'x')
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)

# We need a block size of (16, 128) to ensure that a half-slice is at least
# of size (8, 128), which is the size of a VREG. This makes tiling easier
# for the compiler.
block_size = (16, 128)
input_arr = jax.random.uniform(
    jax.random.key(0),
    shape=(block_size[0] * num_devices, block_size[1] * num_devices),
)
input_arr = jax.device_put(input_arr, sharding)

LEFT = 0
RIGHT = 1


def mod(x, n):
  return lax.rem(x + n, n)


def signal(left_or_right, semaphore):
  my_id = lax.axis_index('x')
  if left_or_right == LEFT:
    neighbor = mod(my_id - 1, num_devices)
  else:
    neighbor = mod(my_id + 1, num_devices)
  pltpu.semaphore_signal(
      semaphore,
      inc=1,
      device_id=(neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )


def reduce_scatter_kernel(
    x_ref,
    o_ref,
    hbm_scratch,
    local_copy_sem,
    left_recv_sem,
    left_send_sem,
    right_recv_sem,
    right_send_sem,
    left_capacity_sem,
    right_capacity_sem,
    accum_scratch,
):
  outer_step = pl.program_id(0)
  phase = pl.program_id(1)
  is_start = jnp.logical_and(outer_step == 0, phase == 0)
  last_iteration = outer_step == pl.num_programs(0) - 1

  working_slot = lax.rem(outer_step, 2)
  receiving_slot = 1 - working_slot
  my_id = lax.axis_index('x')
  right_neighbor = mod(my_id + 1, num_devices)
  left_neighbor = mod(my_id - 1, num_devices)

  left_copy_device = mod(my_id + outer_step + 1, num_devices)
  right_copy_device = mod(my_id - outer_step - 1, num_devices)
  # Slices can be specified using pl.ds(start, size)
  left_copy_slice = pl.ds(0, block_size[0] // 2)
  right_copy_slice = pl.ds(block_size[0] // 2, block_size[0] // 2)
  current_phase_slice = pl.ds(phase * (block_size[0] // 2), block_size[0] // 2)

  initial_left_copy = pltpu.make_async_remote_copy(
      src_ref=x_ref.at[my_id, left_copy_slice],
      dst_ref=hbm_scratch.at[working_slot, left_copy_slice],
      send_sem=left_send_sem,
      recv_sem=left_recv_sem,
      device_id=(left_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  initial_right_copy = pltpu.make_async_remote_copy(
      src_ref=x_ref.at[my_id, right_copy_slice],
      dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
      send_sem=right_send_sem,
      recv_sem=right_recv_sem,
      device_id=(right_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  left_copy = pltpu.make_async_remote_copy(
      src_ref=hbm_scratch.at[working_slot, left_copy_slice],
      dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],
      send_sem=left_send_sem,
      recv_sem=left_recv_sem,
      device_id=(left_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )
  right_copy = pltpu.make_async_remote_copy(
      # Note: Right copy is flipped with regards to slots since we are copying
      # to the next outer_step iteration.
      src_ref=hbm_scratch.at[receiving_slot, right_copy_slice],
      dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
      send_sem=right_send_sem,
      recv_sem=right_recv_sem,
      device_id=(right_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  # --- Prologue ---
  @pl.when(is_start)
  def _():
    # Barrier with both neighbors at the start, since we will be
    # communicating with both.
    local_barrier(left_neighbor, right_neighbor)

    # Initialize o_ref, acc_scratch, and hbm_scratch with initial copies.
    o_ref[...] = jnp.zeros_like(o_ref[...])
    accum_scratch[...] = jnp.zeros_like(accum_scratch[...])

    initial_left_copy.start()
    initial_left_copy.wait()
    initial_right_copy.start()

    # We tell our left neighbor that it is allowed to send to the right.
    # (and vice versa for right neighbor)
    signal(LEFT, right_capacity_sem)
    signal(RIGHT, left_capacity_sem)

  # --- Body ---
  # At the beginning of our kernel body, we start a DMA which copies
  # the result we computed in the previous phase to our neighbor.
  # This allows us to overlap the communication of sending our previous phase
  # with the computation for the current phase.
  @pl.when(~is_start)
  def _():
    @pl.when(phase == LEFT)
    def _():
      # We block here until our right neighbor tells use we can send to
      # the right.
      pltpu.semaphore_wait(right_capacity_sem, 1)
      right_copy.start()

    @pl.when(phase == RIGHT)
    def _():
      # We block here until our left neighbor tells use we can send to
      # the left.
      pltpu.semaphore_wait(left_capacity_sem, 1)
      left_copy.start()

  local_copy = pltpu.make_async_copy(
      src_ref=hbm_scratch.at[working_slot, current_phase_slice],
      dst_ref=accum_scratch,
      sem=local_copy_sem,
  )
  local_copy.start()
  local_copy.wait()

  @pl.when(~last_iteration)
  def _():
    @pl.when(phase == LEFT)
    def _():
      accum_scratch[...] += x_ref[left_copy_device, left_copy_slice]

    @pl.when(phase == RIGHT)
    def _():
      accum_scratch[...] += x_ref[right_copy_device, right_copy_slice]

  local_copy = pltpu.make_async_copy(
      src_ref=accum_scratch,
      dst_ref=hbm_scratch.at[working_slot, current_phase_slice],
      sem=local_copy_sem,
  )
  local_copy.start()
  local_copy.wait()

  @pl.when(is_start)
  def _():
    initial_right_copy.wait()

  # At the end of our kernel body, we wait on the DMA of the previous phase
  # to make sure the results are ready for the next phase.
  @pl.when(~is_start)
  def _():
    @pl.when(phase == LEFT)
    def _():
      right_copy.wait()
      signal(LEFT, right_capacity_sem)

    @pl.when(phase == RIGHT)
    def _():
      left_copy.wait()
      signal(RIGHT, left_capacity_sem)

  # --- Epilogue ---
  # Store result on last iteration.
  @pl.when(last_iteration)
  def _():
    # Clean up semaphores so that they exit with a value of 0.
    @pl.when(phase == LEFT)
    def _():
      o_ref[left_copy_slice, ...] = accum_scratch[...]
      pltpu.semaphore_wait(right_capacity_sem, 1)

    @pl.when(phase == RIGHT)
    def _():
      o_ref[right_copy_slice, ...] = accum_scratch[...]
      pltpu.semaphore_wait(left_capacity_sem, 1)


out_shape = (
    jax.ShapeDtypeStruct((block_size[0], block_size[1]), jnp.float32),  # output
    # Shape: [working/recv, block[0], block[1]]
    jax.ShapeDtypeStruct(
        (2, block_size[0], block_size[1]), jnp.float32
    ),  # hbm_scratch
)

grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=0,
    in_specs=[
        pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),
    ],
    out_specs=[
        pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),
        pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
    ],
    grid=(num_devices, 2),
    scratch_shapes=(
        [pltpu.SemaphoreType.DMA] * 5
        + [pltpu.SemaphoreType.REGULAR] * 2  # Capacity semaphores
        + [
            pltpu.VMEM((block_size[0] // 2, block_size[1]), jnp.float32)
        ]  # accum_scratch
    ),
)


def pallas_reduce_scatter(input_arr):
  input_arr = input_arr.reshape(num_devices, block_size[0], block_size[1])
  return pl.pallas_call(
      reduce_scatter_kernel,
      out_shape=out_shape,
      grid_spec=grid_spec,
      compiler_params=pltpu.CompilerParams(collective_id=0),
  )(input_arr)[0]


pallas_result = jax.jit(
    jax.shard_map(
        pallas_reduce_scatter,
        mesh=mesh,
        in_specs=P(None, 'x'),
        out_specs=P('x', None),
        check_vma=False,
    )
)(input_arr)

pallas_result = jax.block_until_ready(pallas_result)
# Compare our result to XLA.
def lax_reduce_sum_scatter(x):
  x = x.reshape(num_devices, block_size[0], block_size[1])
  return lax.psum_scatter(x, 'x')


xla_result = jax.jit(
    jax.shard_map(
        lax_reduce_sum_scatter,
        mesh=mesh,
        in_specs=P(None, 'x'),
        out_specs=P('x', None),
    )
)(input_arr)

print('Input:', input_arr.shape, input_arr[::4, 0])
print('Pallas Result:', pallas_result.shape, pallas_result[::4, 0])
print('lax.psum_scatter Result:', xla_result.shape, xla_result[::4, 0])
print(
    'Difference |Pallas - lax.psum_scatter|:',
    jnp.max(jnp.abs(pallas_result - xla_result)),
)
Input: (64, 512) [0.78051674 0.3524047  0.59993696 0.9714314  0.24692321 0.01347649
 0.01857424 0.24841607 0.86097646 0.8261659  0.9753758  0.6902338
 0.4431417  0.963323   0.3158517  0.535548  ]
Pallas Result: (64, 128) [1.3593563 1.6274805 1.0979297 3.082869  1.4194957 1.4163033 1.2401303
 1.1892898 2.6545286 2.221559  2.7995253 2.08431   2.2509837 3.0726733
 2.4662397 1.9542246]
lax.psum_scatter Result: (64, 128) [1.3593563 1.6274805 1.0979297 3.082869  1.4194957 1.4163033 1.2401303
 1.1892898 2.6545286 2.221559  2.7995253 2.08431   2.2509837 3.0726733
 2.4662397 1.9542246]
Difference |Pallas - lax.psum_scatter|: 2.3841858e-07

嵌套的远程和本地 DMA 流水线#

我们之前编写的全规约和规约分散内核的一个限制是,我们通过远程 DMA 复制的块必须足够小,才能容纳在我们用于累积的工作 VMEM 中。对于某些内核,使用更大的块大小可能更有利于更好地利用 TPU。例如,矩阵乘法需要大约 \(O(N^3)\) 次计算操作,但只需 \(O(N^2)\) 次内存传输。因此,我们希望在设备之间传输的每个工作块都足够大,以便操作变为计算密集型,并且我们可以使用流水线隐藏通信成本。作为参考,TPU(v4/v5 代)的 VMEM 通常在 10-100MB 范围内,而 HBM 范围从 10-100GB。

为了解决这个问题,我们需要能够编写一个“内部内核”来处理本地 HBM-VMEM 流水线,而“外部内核”处理设备之间更大的 HBM-HBM 传输。Pallas 提供了一个 API,用于使用 emit_pipeline 函数构建嵌套流水线。emit_pipeline 的基本调用签名遵循标准 pallas_call 的形式,通过为输入和输出指定 gridBlockSpec

def emit_pipeline(
    kernel: Callable,
    grid: tuple[int],
    in_specs: PyTree[BlockSpec] = None,
    out_specs: PyTree[BlockSpec] = None,
    should_accumulate_out: bool = False,
    dimension_semantics: tuple[GridDimensionSemantics] = None,
) -> Callable:
  ... # Returns a custom pipeline given an inner kernel and BlockSpecs.

事实上,可以将 pallas_call 本身视为 emit_pipeline 的简单封装。由于我们的外部内核仅涉及远程 HBM-HBM 传输,我们没有使用 pallas_call 为 HBM-VMEM 传输提供的任何内置流水线。以下代码骨架演示了使用此模式的典型程序结构:


def outer_kernel(...):
  # ... do work to pipeline remote HBM-HBM transfers (outer kernel)

  def inner_kernel(...):
    # ... do work (inner kernel)
  pltpu.emit_pipeline(
          inner_kernel,
          grid=inner_grid,
          in_specs=...,
          out_specs=...,
  )(inner_kernel_args)
  # ... do more work (outer kernel)

pl.pallas_call(
  outer_kernel,
  grid=outer_grid,
  in_specs=...
  out_specs=...
  scratch=inner_kernel_allocs
)

示例:带有大 HBM 块的规约分散#

在下一个示例中,我们将修改之前的规约分散示例,以利用嵌套的内部流水线。请注意,reduce_scatter 的通信和计算成本都与输入大小线性相关,因此我们不一定期望随着块大小的增加而使操作变为计算密集型。本示例纯粹用于演示如何使用流水线发射器。

我们将增加外部内核的块大小,使其不适合放置在 VMEM 中,并将所有输入和输出分配在 HBM 中 (memory_space=MemorySpace.ANY)。与我们之前的内核相比,唯一的主要变化是内核的主体,即累积完成的地方。我们不再手动从 HBM 复制到 VMEM,累积,然后复制回 HBM,而是使用 emit_pipeline 为我们处理内存传输。累积在内部内核中完成,其块大小更小,更适合 VMEM。

在之前的内核中,我们有以下内核主体,用于将数据从 HBM 复制到 VMEM 累加器,递增,然后将结果复制回 HBM:

local_copy = pltpu.make_async_copy(
    src_ref=hbm_scratch.at[working_slot, current_phase_slice],
    dst_ref=accum_scratch,
    sem=local_copy_sem,
)
local_copy.start()
local_copy.wait()
@pl.when(~last_iteration)
def _():
  @pl.when(phase == LEFT)
  def _():
    accum_scratch[...] += x_ref[left_copy_device, left_copy_slice]
  @pl.when(phase == RIGHT)
  def _():
    accum_scratch[...] += x_ref[right_copy_device, right_copy_slice]
local_copy = pltpu.make_async_copy(
    src_ref=accum_scratch,
    dst_ref=hbm_scratch.at[working_slot, current_phase_slice],
    sem=local_copy_sem,
)
local_copy.start()
local_copy.wait()

我们的新内核将其替换为以下 emit_pipeline 调用:

def inner_kernel(input_ref, accum_ref):
  accum_ref[...] = input_ref[...]
accum_pipeline = pltpu.emit_pipeline(inner_kernel,
                                     in_specs=[inner_block_spec],
                                     out_specs=inner_block_spec,
                                     should_accumulate_out=True,
                                     grid=inner_grid)
@pl.when(~last_iteration)
def _():
  @pl.when(phase == LEFT)
  def _():
    accum_pipeline(x_ref.at[left_copy_device, left_copy_slice],
                   hbm_scratch.at[working_slot, left_copy_slice],
    )
  @pl.when(phase == RIGHT)
  def _():
    accum_pipeline(x_ref.at[right_copy_device, right_copy_slice],
                   hbm_scratch.at[working_slot, right_copy_slice],
    )

完整的内核如下:

partition = P(None, 'x')
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)

# We pick a large outer kernel block size that we do not want to place
# in VMEM. For pedagogical purposes we use (4096, 4096), although in
# principle this can be much larger.
outer_block_size = (4096, 4096)
# We pick a smaller VMEM block size for the inner kernel.
inner_block_size = (128, 128)
input_arr = jax.random.uniform(
    jax.random.key(0),
    shape=(
        outer_block_size[0] * num_devices,
        outer_block_size[1] * num_devices,
    ),
)
input_arr = jax.device_put(input_arr, sharding)


inner_grid = (
    outer_block_size[0] // inner_block_size[0] // 2,
    outer_block_size[1] // inner_block_size[1],
)
inner_block_spec = pl.BlockSpec(
    index_map=lambda i, j: (i, j),
    block_shape=inner_block_size,
    memory_space=pltpu.MemorySpace.ANY,
)


def reduce_scatter_kernel(
    x_ref,
    o_ref,
    hbm_scratch,
    left_recv_sem,
    left_send_sem,
    copy_sem,
    right_recv_sem,
    right_send_sem,
    left_capacity_sem,
    right_capacity_sem,
):
  outer_step = pl.program_id(0)
  phase = pl.program_id(1)
  is_start = jnp.logical_and(outer_step == 0, phase == 0)
  last_iteration = outer_step == pl.num_programs(0) - 1

  working_slot = lax.rem(outer_step, 2)
  receiving_slot = 1 - working_slot
  my_id = lax.axis_index('x')
  right_neighbor = mod(my_id + 1, num_devices)
  left_neighbor = mod(my_id - 1, num_devices)

  left_copy_device = mod(my_id + outer_step + 1, num_devices)
  right_copy_device = mod(my_id - outer_step - 1, num_devices)
  left_copy_slice = pl.ds(0, outer_block_size[0] // 2)
  right_copy_slice = pl.ds(outer_block_size[0] // 2, outer_block_size[0] // 2)
  current_phase_slice = pl.ds(
      phase * (outer_block_size[0] // 2), outer_block_size[0] // 2
  )

  initial_left_copy = pltpu.make_async_remote_copy(
      src_ref=x_ref.at[my_id, left_copy_slice],
      dst_ref=hbm_scratch.at[working_slot, left_copy_slice],
      send_sem=left_send_sem,
      recv_sem=left_recv_sem,
      device_id=(left_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  initial_right_copy = pltpu.make_async_remote_copy(
      src_ref=x_ref.at[my_id, right_copy_slice],
      dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
      send_sem=right_send_sem,
      recv_sem=right_recv_sem,
      device_id=(right_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  left_copy = pltpu.make_async_remote_copy(
      src_ref=hbm_scratch.at[working_slot, left_copy_slice],
      dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],
      send_sem=left_send_sem,
      recv_sem=left_recv_sem,
      device_id=(left_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )
  right_copy = pltpu.make_async_remote_copy(
      src_ref=hbm_scratch.at[receiving_slot, right_copy_slice],
      dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
      send_sem=right_send_sem,
      recv_sem=right_recv_sem,
      device_id=(right_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  # --- Prologue ---
  @pl.when(is_start)
  def _():
    # Barrier with both neighbors at the start, since we will be
    # communicating with both.
    local_barrier(left_neighbor, right_neighbor)

    initial_left_copy.start()
    initial_left_copy.wait()
    initial_right_copy.start()

    # We tell our left neighbor that it is allowed to send to the right.
    # (and vice versa for right neighbor)
    signal(LEFT, right_capacity_sem)
    signal(RIGHT, left_capacity_sem)

  @pl.when(~is_start)
  def _():
    @pl.when(phase == LEFT)
    def _():
      # We block here until our right neighbor tells use we can send to
      # the right.
      pltpu.semaphore_wait(right_capacity_sem, 1)
      right_copy.start()

    @pl.when(phase == RIGHT)
    def _():
      # We block here until our left neighbor tells use we can send to
      # the left.
      pltpu.semaphore_wait(left_capacity_sem, 1)
      left_copy.start()

  # --- Body ---
  def inner_kernel(input_ref, accum_ref):
    # We do not explicitly use += because we set should_accumulate_out=True.
    accum_ref[...] = input_ref[...]

  accum_pipeline = pltpu.emit_pipeline(
      inner_kernel,
      in_specs=[inner_block_spec],
      out_specs=inner_block_spec,
      should_accumulate_out=True,
      grid=inner_grid,
  )

  @pl.when(~last_iteration)
  def _():
    @pl.when(phase == LEFT)
    def _():
      accum_pipeline(
          x_ref.at[left_copy_device, left_copy_slice],
          hbm_scratch.at[working_slot, left_copy_slice],
      )

    @pl.when(phase == RIGHT)
    def _():
      accum_pipeline(
          x_ref.at[right_copy_device, right_copy_slice],
          hbm_scratch.at[working_slot, right_copy_slice],
      )

  # --- Epilogue ---
  @pl.when(is_start)
  def _():
    initial_right_copy.wait()

  @pl.when(~is_start)
  def _():
    @pl.when(phase == LEFT)
    def _():
      right_copy.wait()
      signal(LEFT, right_capacity_sem)

    @pl.when(phase == RIGHT)
    def _():
      left_copy.wait()
      signal(RIGHT, left_capacity_sem)

  # Store result on last iteration.
  @pl.when(last_iteration)
  def _():
    output_copy = pltpu.make_async_copy(
        src_ref=hbm_scratch.at[working_slot, current_phase_slice],
        dst_ref=o_ref.at[current_phase_slice],
        sem=copy_sem,
    )
    output_copy.start()
    output_copy.wait()

    # Clean up semaphores so that they exit with a value of 0.
    @pl.when(phase == LEFT)
    def _():
      pltpu.semaphore_wait(right_capacity_sem, 1)

    @pl.when(phase == RIGHT)
    def _():
      pltpu.semaphore_wait(left_capacity_sem, 1)


out_shape = (
    jax.ShapeDtypeStruct(
        (outer_block_size[0], outer_block_size[1]), jnp.float32
    ),
    # Shape: [working/recv, block[0], block[1]]
    jax.ShapeDtypeStruct(
        (2, outer_block_size[0], outer_block_size[1]), jnp.float32
    ),  # hbm_scratch
)

grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=0,
    in_specs=[
        pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
    ],
    out_specs=[
        pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
        pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
    ],
    grid=(num_devices, 2),
    scratch_shapes=(
        [pltpu.SemaphoreType.DMA] * 5
        + [pltpu.SemaphoreType.REGULAR] * 2  # Capacity semaphores
    ),
)


def pallas_reduce_scatter(input_arr):
  input_arr = input_arr.reshape(
      num_devices, outer_block_size[0], outer_block_size[1]
  )
  return pl.pallas_call(
      reduce_scatter_kernel,
      out_shape=out_shape,
      grid_spec=grid_spec,
      compiler_params=pltpu.CompilerParams(collective_id=0),
  )(input_arr)[0]


pallas_result = jax.jit(
    jax.shard_map(
        pallas_reduce_scatter,
        mesh=mesh,
        in_specs=P(None, 'x'),
        out_specs=P('x', None),
        check_vma=False,
    )
)(input_arr)

pallas_result = jax.block_until_ready(pallas_result)
# Now we compare our result to XLA.
def lax_reduce_sum_scatter(x):
  x = x.reshape(num_devices, outer_block_size[0], outer_block_size[1])
  return lax.psum_scatter(x, 'x')


xla_result = jax.jit(
    jax.shard_map(
        lax_reduce_sum_scatter,
        mesh=mesh,
        in_specs=P(None, 'x'),
        out_specs=P('x', None),
    )
)(input_arr)

print('Input:', input_arr.shape, input_arr[::4, 0])
print('Pallas Result:', pallas_result.shape, pallas_result[::4, 0])
print('lax.psum_scatter Result:', xla_result.shape, xla_result[::4, 0])
print(
    'Difference |Pallas - lax.psum_scatter|:',
    jnp.max(jnp.abs(pallas_result - xla_result)),
)
Input: (16384, 16384) [0.74162567 0.0242182  0.27751946 ... 0.05213022 0.36088037 0.04494429]
Pallas Result: (16384, 4096) [2.0648427 1.674587  1.9148926 ... 1.3371865 1.3296283 1.2887063]
lax.psum_scatter Result: (16384, 4096) [2.0648427 1.674587  1.9148926 ... 1.3371865 1.3296283 1.2887063]
Difference |Pallas - lax.psum_scatter|: 2.3841858e-07

最终说明#

Megacore#

某些 TPU 包含多个核心,采用 Megacore 配置。在此配置中,我们的一般建议是仅从单个核心启动 DMA,并且只执行 HBM-HBM 传输。为此,将其中一个网格轴设置为核心数(可通过 jax.devices()[0].num_cores 获取),并将 dimension_semantics 设置为 "parallel"。然后,您可以使用 core_index = pl.program_id(axis) 获取沿该轴的核心索引,并使用 @pl.when(core_index==i) 执行特定于该核心的代码。

与 XLA 的交互#

在本教程中,我们介绍了几个内核示例,它们复制了 JAX 中集合操作的功能,例如 lax.all_gatherlax.psumlax.psum_scatter。需要注意的一个重要告诫是,Pallas 内核对 XLA 编译器来说是某种程度上不透明的,这可能导致它错过一些通常会执行的优化。例如,XLA 可以异步调度集合操作,以便在不编写自定义内核的情况下交错通信和计算。当涉及 Pallas 内核时,不能保证会发生这种情况,因此对程序进行性能分析以查看是否存在此问题非常重要。另一个示例是,本教程中我们用于生成嵌套流水线的 emit_pipeline 函数对 XLA 编译器不可见,因此无法与相邻操作融合。

后续步骤#

对读者而言,优秀的后续练习可能包括实现分布式矩阵乘法、实现 lax.all_to_all 以及放宽同步以允许额外的超前执行。