Pallas 异步操作#

背景 + 动机#

我们希望在 Pallas 中公开 API,以显式地重叠多个内核之间的计算和通信。

XLA 异步分解#

作为动机,请考虑以下 JAX 伪代码

def f(x):
  y = ppermute(x)
  z = x + 1
  return y, z

在此函数中,我们可以同时执行 ppermutex + 1。这是 XLA 通过以下方式自动执行的优化:

  1. ppermute 分解为 ppermute_startppermute_done 操作,这两个操作通过 future 连接。

  2. ppermute_startppermute_done 之间调度 x + 1

生成以下程序

def f(x):
  fut = ppermute_start(x)
  z = x + 1  # happens at the same time as ppermute
  y = ppermute_done(fut)
  return y, z

内核内的异步操作#

现在假设我们没有使用 XLA 的 ppermute,而是使用我们自己的自定义 Pallas ppermute

def ppermute_kernel(x_ref, y_ref, send_sem, recv_sem):
  right_neighbor = ...
  descriptor = pltpu.make_async_remote_copy(x_ref, y_ref, send_sem, recv_sem, device_id=right_neighbor)
  descriptor.start()
  descriptor.wait_send()
  descriptor.wait_recv()

def ppermute(x):
  return pl.pallas_call(ppermute_kernel, out_shape=x, ...)(x)

目前,我们无法像 XLA 那样将 ppermute 分解为 start/done 对,因此我们显式地将 x + 1 融合到内核中。

def add_one(x_ref, z_ref):
  z_ref[...] = x_ref[...] + 1

def ppermute_add_one_kernel(x_ref, y_ref, z_ref, send_sem, recv_sem):
  right_neighbor = ...
  descriptor = pltpu.make_async_remote_copy(x_ref, y_ref, send_sem, recv_sem, device_id=right_neighbor)
  descriptor.start()

  # Explicitly schedule inner kernel between start/wait
  pltpu.emit_pipeline(add_one)(x_ref, z_ref)

  descriptor.wait_send()
  descriptor.wait_recv()

def ppermute_and_add_one(x):
  return pl.pallas_call(ppermute_add_one_kernel, out_shape=(x, x), ...)(x)

我们的目标是能够为启动 ppermute 和等待其完成编写单独的内核,这样我们就可以在两者之间使用普通的 x + 1(或我们想要的任何计算)。这使代码更具可读性、可维护性且更不易出错。

我们如何在(TPU 上)实现分解的 Pallas 异步操作?#

在 Pallas 中实现分解的异步操作时,主要要弄清楚的是在它们之间传递的 future 包含什么。具体来说,它必须包含有关在后台发生的运算的一些重要状态。

如果我们查看 Pallas 代码,我们可以看到我们需要一个“描述符”来启动和等待远程复制。我们是否可以将这个描述符从 Pallas 内核中提取出来,然后将其传递到另一个内核中?嗯,有点。底层的 TPU 硬件通过一对信号量跟踪异步操作的进度:send_sem 使我们能够等待设备何时完成将数据发送到其邻居,而 recv_sem 跟踪从邻居发送到设备的数据传输。如果我们想象编写一个启动内核和一个完成内核,那么我们需要从启动传递到完成的只是信号量和一些关于等待这些信号量的多少的信息。

我们可以通过扩展 Pallas 来支持从内核返回信号量来实现这一点。

def ppermute_start_kernel(
    in_ref, send_sem, recv_sem, out_ref, *, axis_name,
):
  axis_size = jax.lax.psum(1, axis_name)
  left_neighbor = jax.lax.rem(
      jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size
  )
  right_neighbor = jax.lax.rem(jax.lax.axis_index(axis_name) + 1, axis_size)
  barrier_sem = pltpu.get_barrier_semaphore()
  pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor)
  pltpu.semaphore_wait(barrier_sem, 1)
  pltpu.make_async_remote_copy(
      in_ref, out_ref, send_sem, recv_sem, device_id=right_neighbor
  ).start()

def ppermute_start(x, *, axis_name) -> tuple[Semaphore, Semaphore, Array]:
  send_sem, recv_sem, out = pl.pallas_call(
      functools.partial(ppermute_start_kernel, axis_name=axis_name),
      out_shape=(
          pltpu.SemaphoreType.DMA(()),
          pltpu.SemaphoreType.DMA(()),
          jax.ShapeDtypeStruct(
              x.shape,
              dtype=x.dtype,
          ),
      ),
      in_specs=[
          pl.BlockSpec(memory_space=pltpu.ANY),
      ],
      out_specs=(
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          pl.BlockSpec(memory_space=pltpu.ANY),
      ),
  )(x)
  return send_sem, recv_sem, out

请注意,这里发生了一些微妙的事情。Pallas 正在告诉 XLA 它希望一些输出是信号量(也称为同步标志),XLA 会将它们视为“保留”(例如,当它们在 XLA 程序中处于活动状态时,其他内核无法分配这些同步标志)。它们的行为类似于屏障信号量,屏障信号量是 XLA 管理的保留信号量。

另一个需要注意的是,我们从启动内核返回输出缓冲区 out同时它正在被主动复制到其中

现在我们编写执行阻塞操作的 done 内核。我们将 out 传递到内核中,以计算阻塞信号量所需的形状。

def ppermute_done_kernel(ref, send_sem, recv_sem, _):
  pltpu.make_async_copy(ref, ref, send_sem).wait()
  pltpu.make_async_copy(ref, ref, recv_sem).wait()

def ppermute_done(send_sem, recv_sem, out) ->Array:
  out = pl.pallas_call(
      ppermute_done_kernel,
      out_shape=(
          jax.ShapeDtypeStruct(
              out.shape,
              dtype=out.dtype,
          ),
      ),
      in_specs=[
          pl.BlockSpec(memory_space=pltpu.ANY),
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
      ],
      out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
      input_output_aliases={0:0}
  )(out, send_sem, recv_sem)
  return out

注意:我们在此处对输出缓冲区进行 i/o 别名,以保证使用者在 ppermute_done 的下游。

现在我们可以实现分解的集合排列。

def f(x):
  fut = ppermute_start(x)
  z = x + 1  # happens at the same time as ppermute
  y = ppermute_done(fut)
  return y, z

或者我们可以吗?

为什么它不起作用#

这里还有三个问题,每个问题都在 Pallas 之外存在一定程度。以下是它们的高级概述。

  1. 调度 - 仅仅因为我们编写了 ppermute_start,然后是 x + 1,然后是 ppermute_done,并不能保证它们会按该顺序发生。XLA 负责调度,因此当我们编写 JAX 程序时,我们正在建立 XLA 将遵守的数据依赖关系,但 XLA 不会遵守 JAX 中编写的特定操作顺序。

  2. 生命周期 - XLA 假设一旦值超出依赖关系图中的范围,就可以释放其内存供其他值使用。如果我们有一个异步复制 x -> y 的操作,我们需要确保 x 在复制完成之前处于活动状态,否则我们将从垃圾内存中复制。

  3. 防御性复制 - XLA 保留创建值副本的权利。我们需要确保我们不会引入不必要的副本,以 a) 避免不必要的运行时开销,以及 b) 确保正确性。

我们将逐一讨论这些问题并提出修复方法。

调度#

我们如何在 JAX 中显式强制操作以特定顺序发生?请注意,这不是 Pallas 特有的问题,如果我们使用其他方法实现了异步操作,我们仍然会遇到这个问题。

一种方法是在 XLA 程序中引入优化屏障。优化屏障将阻止 XLA 移动其周围的操作。

这是我们的原始代码

def f(x):
  fut = ppermute_start(x)
  z = x + 1
  y = ppermute_done(fut)
  return y, z

XLA 可以选择在以下三个位置中的任何一个位置执行 x + 1

def f(x):
  z = x + 1
  fut = ppermute_start(x)
  y = ppermute_done(fut)
  return y, z

# OR

def f(x):
  fut = ppermute_start(x)
  z = x + 1
  y = ppermute_done(fut)
  return y, z

# OR

def f(x):
  fut = ppermute_start(x)
  y = ppermute_done(fut)
  z = x + 1
  return y, z

为了强制 x + 1ppermute 操作之间发生,我们可以使用 optimization_barrier,它在语义上是恒等函数(即 lambda x: x),但在值之间引入了显式的数据依赖关系。具体来说,如果我们使 x + 1 中使用的 x 依赖于 ppermute_start 返回的 fut,则它必须在 ppermute_start 之后发生。

我们还引入了一个依赖关系,该依赖关系强制输出值 y 依赖于 z

def f(x):
  fut = ppermute_start(x)
  x, fut = optimization_barrier((x, fut))  # x now depends on fut
  z = x + 1
  z, fut = optimization_barrier((z, fut)) # fut now depends on z
  y = ppermute_done(fut)
  return y, z

optimization_barrier 对我们来说是一个足够好的工具,可以显式地编写出调度。

生命周期#

让我们再次查看我们的原始代码,并假设操作以正确的顺序发生。

def f(x):
  fut = ppermute_start(x)
  z = x + 1
  y = ppermute_done(fut)
  return y, z

让我们看看 XLA 认为可以释放 x 缓冲区的程序中的哪个点。它将在 x 不再使用之后,具体来说是在 z = x + 1 之后。

def f(x):
  fut = ppermute_start(x)
  z = x + 1
  # XLA can free x here!
  y = ppermute_done(fut)
  return y, z

如果 XLA 在 z = x + 1 完成后释放 x,我们会遇到一个非常糟糕的问题。ppermute 可能仍在 z = x + 1 之后主动将 x 复制到邻居,这意味着如果释放 x,则 ppermute 将从垃圾内存中读取!

我们如何将 x 的生命周期延长到 ppermute_done?好吧,我们可以引入数据依赖关系!我们需要稍微修改一下我们的内核来实现这一点。

首先,我们重写 ppermute_start 以返回 x,通过内核对其进行别名。

def ppermute_start_kernel(
    in_ref, send_sem, recv_sem, out_ref, _, *, axis_name,
):
  axis_size = jax.lax.psum(1, axis_name)
  left_neighbor = jax.lax.rem(
      jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size
  )
  right_neighbor = jax.lax.rem(jax.lax.axis_index(axis_name) + 1, axis_size)
  barrier_sem = pltpu.get_barrier_semaphore()
  pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor)
  pltpu.semaphore_wait(barrier_sem, 1)
  pltpu.make_async_remote_copy(
      in_ref, out_ref, send_sem, recv_sem, device_id=right_neighbor
  ).start()

def ppermute_start(x, *, axis_name) -> tuple[Semaphore, Semaphore, Array, Array]:
  send_sem, recv_sem, x, out = pl.pallas_call(
      functools.partial(ppermute_start_kernel, axis_name=axis_name),
      out_shape=(
          pltpu.SemaphoreType.DMA(()),
          pltpu.SemaphoreType.DMA(()),
          jax.ShapeDtypeStruct(
              x.shape,
              dtype=x.dtype,
          ),
	   jax.ShapeDtypeStruct(
              x.shape,
              dtype=x.dtype,
          ),
      ),
      in_specs=[
          pl.BlockSpec(memory_space=pltpu.ANY),
      ],
      out_specs=(
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          pl.BlockSpec(memory_space=pltpu.ANY),
          pl.BlockSpec(memory_space=pltpu.ANY),
      ),
      input_output_aliases={0:2}
  )(x)
  return send_sem, recv_sem, x, out

然后,我们让 ppermute_done 接收 x 并且不执行任何操作。

def ppermute_done_kernel(_, ref, send_sem, recv_sem, _):
  pltpu.make_async_copy(ref, ref, send_sem).wait()
  pltpu.make_async_copy(ref, ref, recv_sem).wait()

def ppermute_done(send_sem, recv_sem, x, out) ->Array:
  out = pl.pallas_call(
      ppermute_done_kernel,
      out_shape=(
          jax.ShapeDtypeStruct(
              out.shape,
              dtype=out.dtype,
          ),
      ),
      in_specs=[
          pl.BlockSpec(memory_space=pltpu.ANY),
          pl.BlockSpec(memory_space=pltpu.ANY),
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
      ],
      out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
      input_output_aliases={1:0}
  )(x, out, send_sem, recv_sem)
  return out

现在当我们编写

def f(x):
  *sems, x ,out = ppermute_start(x)
  z = x + 1
  y = ppermute_done(*sems, x, out)
  return y, z

XLA 不能再释放 x,因为它 是 ppermute_done 的输入!这意味着 x 的生命周期与 ppermute 相关联,并且此代码现在是正确的。

防御性复制#

XLA 在其缓冲区分配过程中,分析哪些缓冲区彼此别名,并在别名其输入之一的操作不是该输入的最终使用者时插入副本。

背景#

这是一个简单的示例。假设我们有一个操作 add_one_inplace,它接收一个数组并加一,但承诺就地执行。

以下代码将是合法的。

def f():
  x = jnp.arange(...)
  y = add_one_inplace(x)  return y

但是,如果 x 也有单独的消费者,则程序可能无法正确执行。

def f():
  x = jnp.arange(...)
  y = add_one_inplace(x)
  return y, x * 2 # another x consumer!

这是因为 x * 2 操作的是原始的 x,而 add_one_inplace 则会覆盖 x 中的值。 x * 2 需要确保读取的是 x 的原始值,而不是我们将其递增 1 后的值。XLA 注意到了这一点,并插入了一个 copy 操作(在语义上是恒等式,但输入和输出缓冲区将不同)。

def f(x):
  x2 = copy(x)
  y = add_one_inplace(x2)
  return y, x * 2

XLA 中的这一遍处理确保了在存在执行原地更新操作的情况下,通过强制它们使用 copy 操作有效地变成非原地更新来保证正确性。

带有下游操作的复制#

让我们重新审视在执行 ppermute 操作的同时加 1 的示例。

def f(x):
  fut = ppermute_start(x)
  z = x + 1
  y = ppermute_done(fut)
  return y, z

如果我们把 future 拆分成其组件,我们将看到别名模式

def f(x):
  *sems, x2, y = ppermute_start(x)
  z = x + 1
  y = ppermute_done((*sems, x2, y))
  return y, z

我们知道 xppermute_start 操作后保持不变(也就是说,xx2 相同),但 XLA 并不知道。实际上,它看起来就像是我们的 add_one_inplace 示例在 XLA 中的情况,它保守地假设 ppermute_start 修改了 x,并且 x2 是新的别名结果。因此,当我们执行 z = x + 1 时,我们遇到了原始缓冲区的消费者。因此,XLA 引入了一个复制操作!

def f(x):
  x2 = copy(x)
  *sems, x2, y = ppermute_start(x2)
  z = x + 1
  y = ppermute_done((*sems, x2, y))
  return y, z

这个复制操作是不必要的,因为我们知道 x2x 相比没有改变。为了移除这个复制操作,我们需要某种机制来通知 XLA 我们只是转发一个值。然而,在没有这种机制的情况下,我们可以稍微重写我们的程序,显式地使用 x2 而不是 x

def f(x):
  *sems, x2, y = ppermute_start(x)
  z = x2 + 1
  y = ppermute_done((*sems, x2, y))
  return y, z

现在,XLA 没有看到 x 的单独消费者,因此不再引入复制操作。然而,这样做的一个主要缺点是,它迫使我们解包从 ppermute_start 返回的 future。它将生命周期问题与复制问题联系起来。

循环别名#

让我们考虑一个稍微高级的例子。让我们实现一个使用 while_loopppermute 来在环中传递值的函数。

def f(x):
  def body(i, x):
    fut = ppermute_start(x)
    y = ppermute_done(fut)
    return y
  return fori_loop(0, 8, body, x)

fori_loop 的一个实现细节是,输入和输出缓冲区会自动相互别名。请注意,我们在 ppermute_startppermute_done 操作中设置了一些额外的别名。让我们通过为程序中的每个值着色来运行我们自己的“缓冲区分配”,以确定我们需要多少个唯一的缓冲区。

首先,我们将解包具有别名 xout 缓冲区的 fut 元组。

def f(x):
  def body(i, x):
    *sems, x, y = ppermute_start(x)
    y = ppermute_done(*sems, x, y)
    return y
  return fori_loop(0, 8, body, x)

现在,让我们根据分配给它们的唯一缓冲区为每个值着色。我们有来自 fori_loop 的输入/输出别名,来自 ppermute_startx 别名和来自 ppermute_doney 别名。

def f(x):
  def body(i, x):
    *sems, x, y = ppermute_start(x)
    y = ppermute_done((*sems, x, y))
    return y
  return fori_loop(0, 8, body, x)

如果你运行别名分析,你会发现所有缓冲区都被着色为相同颜色!直观地说,这是有问题的,因为如果我们正在执行 ppermute 循环,我们不能写入我们正在发送到的同一个缓冲区。我们通常需要一个额外的(即“双”)缓冲区来接收,然后通常我们会在下一次迭代中切换发送/接收缓冲区。XLA 实际上会做的是,它会观察缓冲区重用并防御性地插入一个副本。

def f(x):
  def body(i, x):
    x = copy(x)
    *sems, x, y = ppermute_start(x)
    y = ppermute_done((*sems, x, y))
    return y
  return fori_loop(0, 8, body, x)

这个副本意味着 xy 不再相互别名,程序将是正确的。但是,我们需要这个副本吗?我们如何引入双缓冲区以避免每次迭代都进行昂贵的复制?答案是展开循环!

我们将手动展开我们的代码。

def f(x):
  def body(i, x):
    *sems, x, x2 = ppermute_start(x)
    x2 = ppermute_done((*sems, x, x2))
    
    *sems, x2, y = ppermute_start(x2)
    y = ppermute_done((*sems, x2, y))
    return y
  return fori_loop(0, 4, body, x)

现在,如果我们运行相同的别名分析,我们会发现缓冲区不再相互别名,并且我们不需要插入防御性副本来保证正确性。

因此,删除这些副本的简单解决方案是使用 fori_loop 并设置 unroll >= 2

def f(x):
  def body(i, x):
    fut = ppermute_start(x)
    y = ppermute_done(fut)
    return y
  return fori_loop(0, 8, body, x, unroll=2)

这足以在不进行额外复制的情况下实现此循环!

在循环边界之间传递 future#

现在让我们看一个更高级的例子。我们将实现与之前相同的程序,但错开循环,我们在循环之前的序言中开始 ppermute,并在循环开始时等待 ppermute

def f(x):
  fut = ppermute_start(x)
  def body(i, fut):
    x = ppermute_done(fut)
    fut = ppermute_start(x)
    return fut
  fut = fori_loop(0, 7, body, fut)
  return ppermute_done(fut)

在这个例子中,我们不是将值 x 从一个循环传递到另一个循环,而是传递一个 future 值。

让我们再次解包 future 以查看发生了什么。

def f(x):
  fut = ppermute_start(x)
  def body(i, fut):
    *sems, x, out = fut
    x = ppermute_done((*sems, x, out))
    (*sems, x, out) = ppermute_start(x)
    return (*sems, x, out)
  (*sems, x, out) = fori_loop(0, 7, body, x)
  return ppermute_done((*sems, x, out))

因此,我们显式地将信号量、输入缓冲区和目标输出缓冲区作为循环携带线程化。如果我们现在运行别名分析会发生什么?好吧,我们将遇到与上一节中相同的别名问题,其中 xout 将相互别名。XLA 将引入一个副本。

def f(x):
  fut = ppermute_start(x)
  def body(i, fut):
    *sems, x, out = fut
    out = copy(out)
    x = ppermute_done((*sems, x, out))
    (*sems, x, out) = ppermute_start(x)
    return (*sems, x, out)
  (*sems, x, out) = fori_loop(0, 7, body, x)
  return ppermute_done((*sems, x, out))

在这种情况下,我们在 out 上插入了一个副本。然而,这是一个非常糟糕的情况,因为 out 正在被主动复制到其中!即使我们在 x 上插入一个副本,我们也会遇到问题,因为 x 的生命周期不会扩展到 ppermute_done。这非常非常糟糕!我们不仅会得到副本,还会得到不正确的结果!

正如我们之前观察到的,解决方案是通过展开来避免对所有缓冲区进行别名来避免副本。所以,如果我们这样做

def f(x):
  fut = ppermute_start(x)
  def body(i, fut):
    x = ppermute_done(fut)
    fut = ppermute_start(x)
    return fut
  fut = fori_loop(0, 7, body, x, unroll=2)
  return ppermute_done(fut)

我们的程序现在应该是正确的。

将它们放在一起#

所以我们总结了一些经验法则

  1. 如果我们有依赖于 ppermute 的输入值的操作,请解包 future 以使用别名值而不是原始值。

  2. 在循环体中执行 ppermute 时,请使用 unroll >= 2

让我们将所有内容组合到一个函数中,该函数在循环中执行 ppermute 并累积结果。

def f(x):
  out = jnp.zeros_like(x)
  fut = (*sems, x, out) = ppermute_start(x)
  out = out + x
  def body(i, carry):
    out, fut = carry
    x = ppermute_done(fut)
    fut = (*sems, x, out) = ppermute_start(x)
    out = out + x
    return out, fut
  out, fut = fori_loop(0, 7, body, (out, fut), unroll=2)
  return out, ppermute_done(fut)

请注意,在此示例中,我们不需要 optimization_barrier,因为循环边界充当调度屏障,分割 startdone

就是这样,我们完成了!这将是 Pallas 中进行异步操作的官方 API。谢谢大家!任务完成!

或者真的是这样吗?

状态的反击#

虽然看起来我们通过使用一些巧妙的技巧解决了副本和不正确的问题,但我们仍然处于尴尬的境地。此 API 功能强大,但存在许多陷阱和注意事项。可能还有更多我们需要处理的边缘情况,甚至需要深入了解 XLA 才能预测或理解。我们应该发布这样的 API 吗?还是有其他选择?

好吧,答案可能一直摆在我们面前。

让我们再次完成整个练习,除了,让我们编写有状态版本。这意味着我们每个自定义的异步操作现在都对 Ref 而不是值进行操作。

def ppermute_start_stateful(x_ref, y_ref) -> tuple[Semaphore, Semaphore]:
  ...

def ppermute_done_stateful(send_sem, recv_sem, x_ref, y_ref) -> None:
  ...

让我们假设我们可以在 Pallas 中实现这些操作,并查看我们的新程序是什么样的。让我们从一个基本的集体置换开始

def f(x):
  x_ref = make_ref(x)
  y_ref = make_ref(zeros_like(x))
  fut = ppermute_start_stateful(x_ref, y_ref)
  ppermute_done_stateful(*fut, x_ref, y_ref)
  return y_ref[...]

它比我们原来的基于值的版本稍微冗长一些,但它有一些关键的区别。第一个区别是,我们创建了一个“空”Ref 来接收 ppermute 的结果,这与基于值的版本不同,后者会为我们创建一个值。一个巧妙的事情是 x_ref 的生命周期在这里很清楚:它一直存在到 ppermute_done_stateful。我们不需要像之前那样将 x 值“偷偷地”放到操作中。

当我们尝试在 start/done 之间添加一个操作时,另一个区别变得更加明显。

def f(x):
  x_ref = make_ref(x)
  y_ref = make_ref(zeros_like(x))
  fut = ppermute_start_stateful(x_ref, y_ref)
  x_ref[...] += 1
  ppermute_done_stateful(*fut, x_ref, y_ref)
  return y_ref[...]

之前,我们遇到了调度歧义,其中 XLA 可以相对于 ppermute 重新排序加法操作。使用有状态的语义,我们实际上添加了一个排序约束! x_ref[...] += 1 修改了 x_ref,因此它不能相对于 ppermute_done_stateful 移动。JAX 可以将这些调度约束作为降低到 HLO 的一部分注入。

最后一个关键区别在我们尝试循环示例时显而易见。

def f(x):
  x_ref = make_ref(x)
  y_ref = make_ref(zeros_like(x))
  def body(i, _):
    fut = ppermute_start_stateful(x_ref, y_ref)
    ppermute_done_stateful(*fut, x_ref, y_ref)
    # Now switch to y_ref -> x_ref
    fut = ppermute_start_stateful(y_ref, x_ref)
    ppermute_done_stateful(*fut, y_ref, x_ref)
  fori_loop(0, 8 // 2, body, None)
  return x_ref[...]

由于我们需要准备一个单独的缓冲区来接收 ppermute,我们被迫以展开的方式编写代码!在 XLA 中,无法编写需要复制的版本,因为这将涉及一个从 Ref 发送到自身的 ppermute,这实际上没有意义。

为了在不进行手动展开的情况下处理这种情况,我们将创建一个前导维度为 2 的临时缓冲区,该缓冲区在迭代中充当发送/接收目标,每次都切换。这与我们在编写手动重叠内核时在 Pallas 内核内部使用的模式相同。

这里的认识是,有状态迫使我们更早地处理值语义出现的问题。我们通过定义来消除它们!

  1. 调度 - 将 Ref 作为输入的有状态操作强制执行我们程序的顺序。请注意,这将调度同一 Ref 上的操作彼此相关。我们可能还需要一个 opt_barrier_stateful 来强制执行更多的排序约束。

  2. 生命周期 - Ref 的生命周期可以通过 run_state 来限定范围,也可以作为有状态操作的输入。

  3. 防御性复制 - 使用 Ref 迫使我们“手动”处理缓冲区分配,并且降级可以确保别名工作正常以避免任何复制。

另一个重要的基本限制是,我们最终会阶段性地输出一个 HLO 程序,其中活动缓冲区和信号量表示为数组值类型。XLA 不保证这些中间值的缓冲区生命周期或它们所在的内存空间。因此,即使 Pallas 内核正在主动将数据复制到数组值中,XLA 仍然可以复制它们。 这在 HLO 中很容易验证,但这是使用自定义调用来表示 HLO 中异步操作的一个尖锐边缘。

结论#

我们已经讨论了在 Pallas 和 JAX 中处理异步操作时遇到的一些棘手挑战。Ref 似乎是一种很有前景的方式来表示这些操作,它可以规避值语义带来的一些问题。但是,缺点是它将有状态的 JAX 置于核心位置,而我们之前在 Pallas 之外还没有这样做过。值得思考的是,我们是否应该教育用户了解有状态的操作,或者提供一个更危险的 API。我们也不知道我们想要做的一切是否都可以通过 Ref 来表达。我们还应该集思广益,寻找状态的替代方案,以充实设计空间。例如,如果 XLA 提供了一个尊重生命周期的一流期货 API,并且它可以自动执行诸如在其中带有期货的循环的双缓冲之类的操作,该怎么办?这可能是一个可行的替代方案,但权衡将是给予编译器更多的控制权,而不是用户明确的控制权。