Pallas 异步操作#

背景 + 动机#

我们希望在 Pallas 中公开 API,以实现在多个内核之间明确地重叠计算和通信。

XLA 异步分解#

作为动机,考虑以下 JAX 伪代码

def f(x):
  y = ppermute(x)
  z = x + 1
  return y, z

在这个函数中,我们可以同时执行 ppermutex + 1。这是 XLA 通过以下方式自动进行的优化:

  1. ppermute 分解为 ppermute_startppermute_done 操作,它们通过一个 Future 连接。

  2. ppermute_startppermute_done 之间调度 x + 1

得到以下程序:

def f(x):
  fut = ppermute_start(x)
  z = x + 1  # happens at the same time as ppermute
  y = ppermute_done(fut)
  return y, z

内核中的异步操作#

现在想象我们不使用 XLA 的 ppermute,而是使用我们自己的自定义 Pallas ppermute

def ppermute_kernel(x_ref, y_ref, send_sem, recv_sem):
  right_neighbor = ...
  descriptor = pltpu.make_async_remote_copy(x_ref, y_ref, send_sem, recv_sem, device_id=right_neighbor)
  descriptor.start()
  descriptor.wait_send()
  descriptor.wait_recv()

def ppermute(x):
  return pl.pallas_call(ppermute_kernel, out_shape=x, ...)(x)

目前,我们无法像 XLA 那样将 ppermute 分解为 start/done 对,因此我们明确地将 x + 1 融合到内核中。

def add_one(x_ref, z_ref):
  z_ref[...] = x_ref[...] + 1

def ppermute_add_one_kernel(x_ref, y_ref, z_ref, send_sem, recv_sem):
  right_neighbor = ...
  descriptor = pltpu.make_async_remote_copy(x_ref, y_ref, send_sem, recv_sem, device_id=right_neighbor)
  descriptor.start()

  # Explicitly schedule inner kernel between start/wait
  pltpu.emit_pipeline(add_one)(x_ref, z_ref)

  descriptor.wait_send()
  descriptor.wait_recv()

def ppermute_and_add_one(x):
  return pl.pallas_call(ppermute_add_one_kernel, out_shape=(x, x), ...)(x)

目标是能够为启动 ppermute 和等待其完成编写单独的内核,这样我们就可以在两者之间使用常规的 x + 1(或我们想要的任何计算)。这使得代码更具可读性、可维护性,并且减少了出错的可能性。

我们如何在 TPU 上实现分解的 Pallas 异步操作?#

在 Pallas 中实现分解的异步操作时,主要需要弄清楚的是它们之间传递的 future 包含什么。具体来说,它必须包含关于后台操作的一些重要状态。

如果我们查看 Pallas 代码,可以看到我们需要一个“描述符”来启动并等待远程复制。我们能否将此描述符从 Pallas 内核中引出,然后将其传递给另一个内核?某种程度上可以。底层 TPU 硬件通过一对信号量跟踪异步操作的进度:send_sem 让我们能够等待设备完成向其邻居发送数据,而 recv_sem 则跟踪从邻居发送到设备的数据传输。如果我们设想编写一个启动内核和一个完成内核,那么从启动到完成所需传递的就只有信号量以及关于需要等待这些信号量多久的一些信息。

我们可以通过扩展 Pallas 来支持从内核返回信号量来实现这一点。

def ppermute_start_kernel(
    in_ref, send_sem, recv_sem, out_ref, *, axis_name,
):
  axis_size = jax.lax.psum(1, axis_name)
  left_neighbor = jax.lax.rem(
      jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size
  )
  right_neighbor = jax.lax.rem(jax.lax.axis_index(axis_name) + 1, axis_size)
  barrier_sem = pltpu.get_barrier_semaphore()
  pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor)
  pltpu.semaphore_wait(barrier_sem, 1)
  pltpu.make_async_remote_copy(
      in_ref, out_ref, send_sem, recv_sem, device_id=right_neighbor
  ).start()

def ppermute_start(x, *, axis_name) -> tuple[Semaphore, Semaphore, Array]:
  send_sem, recv_sem, out = pl.pallas_call(
      functools.partial(ppermute_start_kernel, axis_name=axis_name),
      out_shape=(
          pltpu.SemaphoreType.DMA(()),
          pltpu.SemaphoreType.DMA(()),
          jax.ShapeDtypeStruct(
              x.shape,
              dtype=x.dtype,
          ),
      ),
      in_specs=[
          pl.BlockSpec(memory_space=pltpu.ANY),
      ],
      out_specs=(
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          pl.BlockSpec(memory_space=pltpu.ANY),
      ),
  )(x)
  return send_sem, recv_sem, out

请注意,这里发生了一些微妙的事情。Pallas 正在告诉 XLA,它希望某些输出是信号量(也称为同步标志),XLA 会将它们视为“保留的”(例如,当它们在 XLA 程序中存活时,这些同步标志不能被其他内核分配)。它们的行为类似于屏障信号量,后者是 XLA 管理的保留信号量。

另一个需要注意的地方是,我们从启动内核返回输出缓冲区 out 同时它正在被主动复制

现在我们编写执行阻塞操作的 done 内核。我们将 out 传递到内核中,以计算阻塞信号量所需的形状。

def ppermute_done_kernel(ref, send_sem, recv_sem, _):
  pltpu.make_async_copy(ref, ref, send_sem).wait()
  pltpu.make_async_copy(ref, ref, recv_sem).wait()

def ppermute_done(send_sem, recv_sem, out) ->Array:
  out = pl.pallas_call(
      ppermute_done_kernel,
      out_shape=(
          jax.ShapeDtypeStruct(
              out.shape,
              dtype=out.dtype,
          ),
      ),
      in_specs=[
          pl.BlockSpec(memory_space=pltpu.ANY),
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
      ],
      out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
      input_output_aliases={0:0}
  )(out, send_sem, recv_sem)
  return out

注意:我们在这里对输出缓冲区进行了 I/O 别名,以确保消费者位于 ppermute_done 的下游。

我们现在可以实现分解的集体置换。

def f(x):
  fut = ppermute_start(x)
  z = x + 1  # happens at the same time as ppermute
  y = ppermute_done(fut)
  return y, z

真的可以吗?

为什么这不起作用#

这个问题还有三个方面,每个方面在某种程度上都存在于 Pallas 之外。以下是它们的高层概述。

  1. 调度 - 仅仅因为我们编写了 ppermute_start,然后是 x + 1,再然后是 ppermute_done,并不能保证它们会按照这个顺序发生。XLA 负责调度,所以当我们编写 JAX 程序时,我们建立的是 XLA 会尊重的数据依赖关系,但 XLA 不会尊重在 JAX 中编写的特定操作顺序

  2. 生命周期 - XLA 假定一旦一个值在依赖图中超出作用域,它的内存就可以被释放供其他值使用。如果一个操作异步地将 x 复制到 y,我们需要确保 x 在复制完成之前一直处于活跃状态,否则我们将从垃圾内存中进行复制。

  3. 防御性复制 - XLA 保留创建值副本的权利。我们需要确保不引入不必要的副本,以 (a) 避免不必要的运行时开销,以及 (b) 确保正确性。

我们将逐一介绍这些问题并提出解决方案。

调度#

我们如何在 JAX 中明确强制操作按照特定顺序发生?请注意,这不是 Pallas 特有的问题,即使我们使用替代方法实现了异步操作,仍然会遇到这个问题。

一种方法是在 XLA 程序中引入优化屏障。优化屏障将阻止 XLA 移动其周围的操作。

这是我们最初的代码:

def f(x):
  fut = ppermute_start(x)
  z = x + 1
  y = ppermute_done(fut)
  return y, z

XLA 可以选择在以下三个位置中的任何一个执行 x + 1

def f(x):
  z = x + 1
  fut = ppermute_start(x)
  y = ppermute_done(fut)
  return y, z

# OR

def f(x):
  fut = ppermute_start(x)
  z = x + 1
  y = ppermute_done(fut)
  return y, z

# OR

def f(x):
  fut = ppermute_start(x)
  y = ppermute_done(fut)
  z = x + 1
  return y, z

为了强制 x + 1ppermute 操作之间发生,我们可以使用 optimization_barrier,它在语义上是恒等函数(即 lambda x: x),但会在值之间引入显式数据依赖。具体来说,如果我们让 x + 1 中使用的 x 依赖于 ppermute_start 返回的 fut,那么它必须在 ppermute_start 之后发生。

我们还引入了一个依赖关系,强制输出值 y 依赖于 z

def f(x):
  fut = ppermute_start(x)
  x, fut = optimization_barrier((x, fut))  # x now depends on fut
  z = x + 1
  z, fut = optimization_barrier((z, fut)) # fut now depends on z
  y = ppermute_done(fut)
  return y, z

optimization_barrier 对我们来说是一个足够好的工具,可以显式地编写调度。

生命周期#

让我们再次看看我们最初的代码,并假设操作以正确的顺序发生。

def f(x):
  fut = ppermute_start(x)
  z = x + 1
  y = ppermute_done(fut)
  return y, z

让我们看看在程序的哪个点,XLA 认为可以释放 x 的缓冲区。这将是 x 不再使用的点之后,特别是 z = x + 1 之后。

def f(x):
  fut = ppermute_start(x)
  z = x + 1
  # XLA can free x here!
  y = ppermute_done(fut)
  return y, z

如果 XLA 在 z = x + 1 完成后释放 x,我们就会遇到一个非常严重的问题。ppermute 可能仍在 z = x + 1 之后主动将 x 复制到邻居。这意味着如果 x 被释放,ppermute 将会从垃圾内存中读取!

我们如何将 x 的生命周期扩展到 ppermute_done?我们可以引入一个数据依赖!我们需要稍微修改一下我们的内核来实现这一点。

首先,我们重写 ppermute_start 以返回 x,通过内核对其进行别名。

def ppermute_start_kernel(
    in_ref, send_sem, recv_sem, out_ref, _, *, axis_name,
):
  axis_size = jax.lax.psum(1, axis_name)
  left_neighbor = jax.lax.rem(
      jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size
  )
  right_neighbor = jax.lax.rem(jax.lax.axis_index(axis_name) + 1, axis_size)
  barrier_sem = pltpu.get_barrier_semaphore()
  pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor)
  pltpu.semaphore_wait(barrier_sem, 1)
  pltpu.make_async_remote_copy(
      in_ref, out_ref, send_sem, recv_sem, device_id=right_neighbor
  ).start()

def ppermute_start(x, *, axis_name) -> tuple[Semaphore, Semaphore, Array, Array]:
  send_sem, recv_sem, x, out = pl.pallas_call(
      functools.partial(ppermute_start_kernel, axis_name=axis_name),
      out_shape=(
          pltpu.SemaphoreType.DMA(()),
          pltpu.SemaphoreType.DMA(()),
          jax.ShapeDtypeStruct(
              x.shape,
              dtype=x.dtype,
          ),
	   jax.ShapeDtypeStruct(
              x.shape,
              dtype=x.dtype,
          ),
      ),
      in_specs=[
          pl.BlockSpec(memory_space=pltpu.ANY),
      ],
      out_specs=(
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          pl.BlockSpec(memory_space=pltpu.ANY),
          pl.BlockSpec(memory_space=pltpu.ANY),
      ),
      input_output_aliases={0:2}
  )(x)
  return send_sem, recv_sem, x, out

然后我们让 ppermute_done 接收 x 但不对其进行任何操作。

def ppermute_done_kernel(_, ref, send_sem, recv_sem, _):
  pltpu.make_async_copy(ref, ref, send_sem).wait()
  pltpu.make_async_copy(ref, ref, recv_sem).wait()

def ppermute_done(send_sem, recv_sem, x, out) ->Array:
  out = pl.pallas_call(
      ppermute_done_kernel,
      out_shape=(
          jax.ShapeDtypeStruct(
              out.shape,
              dtype=out.dtype,
          ),
      ),
      in_specs=[
          pl.BlockSpec(memory_space=pltpu.ANY),
          pl.BlockSpec(memory_space=pltpu.ANY),
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
      ],
      out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
      input_output_aliases={1:0}
  )(x, out, send_sem, recv_sem)
  return out

现在当我们编写时:

def f(x):
  *sems, x ,out = ppermute_start(x)
  z = x + 1
  y = ppermute_done(*sems, x, out)
  return y, z

XLA 不能再释放 x,因为它是 ppermute_done 的输入!这意味着 x 的生命周期与 ppermute 绑定,这段代码现在是正确的。

防御性复制#

XLA 在其缓冲区分配过程中,分析哪些缓冲区相互别名,并在别名其某个输入的某个操作不是该输入的最终消费者时插入复制。

背景#

这是一个简单的例子。假设我们有一个操作 add_one_inplace,它接收一个数组并加一,但承诺原地执行。

以下代码是合法的。

def f():
  x = jnp.arange(...)
  y = add_one_inplace(x)  return y

然而,如果 x 也有一个单独的消费者,程序可能无法正确执行。

def f():
  x = jnp.arange(...)
  y = add_one_inplace(x)
  return y, x * 2 # another x consumer!

这是因为 x * 2 对原始的 x 进行操作,而 add_one_inplace 会覆盖 x 中的值。x * 2 需要确保读取 x 的原始值,而不是我们将其加 1 之后的值。XLA 注意到这一点并插入了一个 copy 操作(它在语义上是恒等操作,但输入和输出缓冲区将不同)。

def f(x):
  x2 = copy(x)
  y = add_one_inplace(x2)
  return y, x * 2

XLA 中的这一过程通过强制执行原地更新的操作实际上是异地(out-of-place)的 copy 操作来确保正确性。

带有下游操作的复制#

让我们重新审视在 ppermute 时加 1 的例子。

def f(x):
  fut = ppermute_start(x)
  z = x + 1
  y = ppermute_done(fut)
  return y, z

如果我们将 Future 分解为其组成部分,我们将看到别名模式:

def f(x):
  *sems, x2, y = ppermute_start(x)
  z = x + 1
  y = ppermute_done((*sems, x2, y))
  return y, z

我们知道 ppermute_start 不会改变 x(也就是说,xx2 相同),但 XLA 不知道。事实上,在 XLA 看来,它就像我们的 add_one_inplace 示例,它保守地假定 ppermute_start 改变了 x,而 x2 是新的别名结果。因此,当我们执行 z = x + 1 时,我们遇到了原始缓冲区的消费者。XLA 因此引入了一个复制!

def f(x):
  x2 = copy(x)
  *sems, x2, y = ppermute_start(x2)
  z = x + 1
  y = ppermute_done((*sems, x2, y))
  return y, z

这个复制是不必要的,因为我们知道 x2x 相比没有改变。为了消除这个复制,我们需要某种机制来通知 XLA 我们只是在转发一个值。然而,在没有这种机制的情况下,我们可以稍微重写我们的程序,明确使用 x2 而不是 x

def f(x):
  *sems, x2, y = ppermute_start(x)
  z = x2 + 1
  y = ppermute_done((*sems, x2, y))
  return y, z

现在,XLA 看不到 x 的独立消费者,因此不再引入复制。然而,这带来了一个主要缺点,因为它迫使我们解包来自 ppermute_start 的 Future。它将生命周期问题与复制问题耦合在一起。

循环别名#

让我们考虑一个稍微高级一点的例子。我们将实现一个使用 while_loopppermute 在环中发送值的函数。

def f(x):
  def body(i, x):
    fut = ppermute_start(x)
    y = ppermute_done(fut)
    return y
  return fori_loop(0, 8, body, x)

fori_loop 的一个实现细节是输入和输出缓冲区会自动相互别名。请注意,我们在 ppermute_startppermute_done 操作中设置了一些额外的别名。让我们通过对程序中的每个值进行着色来运行我们自己的“缓冲区分配”,以确定我们需要多少个唯一的缓冲区。

首先,我们将解包包含别名 xout 缓冲区的 fut 元组。

def f(x):
  def body(i, x):
    *sems, x, y = ppermute_start(x)
    y = ppermute_done(*sems, x, y)
    return y
  return fori_loop(0, 8, body, x)

现在让我们根据每个值被分配的唯一缓冲区进行着色。我们有来自 fori_loop 的输入/输出别名,来自 ppermute_startx 别名,以及来自 ppermute_doney 别名。

def f(x):
  def body(i, x):
    *sems, x, y = ppermute_start(x)
    y = ppermute_done((*sems, x, y))
    return y
  return fori_loop(0, 8, body, x)

如果你运行别名分析,你会发现所有的缓冲区都被着色为相同的颜色!直观地说,这是有问题的,因为如果我们正在进行 ppermute 循环,我们不能写入我们正在发送到的同一个缓冲区。我们通常需要一个额外的(即“双重”)缓冲区来接收,然后通常会在下一次迭代中切换发送/接收缓冲区。XLA 在实践中会做的是,它会观察到缓冲区的重用,并防御性地插入一个复制。

def f(x):
  def body(i, x):
    x = copy(x)
    *sems, x, y = ppermute_start(x)
    y = ppermute_done((*sems, x, y))
    return y
  return fori_loop(0, 8, body, x)

这个复制意味着 xy 不再相互别名,程序将是正确的。然而,我们需要这个复制吗?我们如何引入双缓冲区以避免每次迭代的昂贵复制?答案是展开!

我们将手动展开我们的代码。

def f(x):
  def body(i, x):
    *sems, x, x2 = ppermute_start(x)
    x2 = ppermute_done((*sems, x, x2))
    
    *sems, x2, y = ppermute_start(x2)
    y = ppermute_done((*sems, x2, y))
    return y
  return fori_loop(0, 4, body, x)

现在,如果我们运行相同的别名分析,我们会发现所有缓冲区不再相互别名,并且我们不需要插入防御性复制也能保持正确。

因此,消除这些复制的简单解决方案是使用 fori_loop,并设置 unroll >= 2

def f(x):
  def body(i, x):
    fut = ppermute_start(x)
    y = ppermute_done(fut)
    return y
  return fori_loop(0, 8, body, x, unroll=2)

这足以实现在没有额外复制的情况下实现这个循环!

跨循环边界传递 Future#

现在我们来看一个更高级的例子。我们将实现与之前相同的程序,但错开循环,在循环之前的前言中开始 ppermute,并在循环开始时等待 ppermute 完成。

def f(x):
  fut = ppermute_start(x)
  def body(i, fut):
    x = ppermute_done(fut)
    fut = ppermute_start(x)
    return fut
  fut = fori_loop(0, 7, body, fut)
  return ppermute_done(fut)

在这个例子中,我们不是将值 x 从一个循环传递到另一个循环,而是传递一个 Future 值。

让我们再次解包 Future,看看发生了什么。

def f(x):
  fut = ppermute_start(x)
  def body(i, fut):
    *sems, x, out = fut
    x = ppermute_done((*sems, x, out))
    (*sems, x, out) = ppermute_start(x)
    return (*sems, x, out)
  (*sems, x, out) = fori_loop(0, 7, body, x)
  return ppermute_done((*sems, x, out))

因此,我们明确地将信号量、输入缓冲区和目标输出缓冲区作为循环传递。如果现在运行别名分析会怎样?嗯,我们会遇到与前一节中相同的别名问题,其中 xout 将相互别名。XLA 将引入一个复制。

def f(x):
  fut = ppermute_start(x)
  def body(i, fut):
    *sems, x, out = fut
    out = copy(out)
    x = ppermute_done((*sems, x, out))
    (*sems, x, out) = ppermute_start(x)
    return (*sems, x, out)
  (*sems, x, out) = fori_loop(0, 7, body, x)
  return ppermute_done((*sems, x, out))

在这种情况下,我们在 out 上插入了一个复制。然而,这是一个非常糟糕的场景,因为 out 正在被主动复制!即使我们在 x 上插入一个复制,我们也会遇到问题,因为那时 x 的生命周期不会延伸到 ppermute_done。这非常非常糟糕!我们不仅会得到复制,还会得到不正确的结果!

解决方案,如我们之前所观察到的,是通过展开避免所有缓冲区的别名来避免复制。因此,如果我们这样做:

def f(x):
  fut = ppermute_start(x)
  def body(i, fut):
    x = ppermute_done(fut)
    fut = ppermute_start(x)
    return fut
  fut = fori_loop(0, 7, body, x, unroll=2)
  return ppermute_done(fut)

我们的程序现在应该正确。

综合起来#

所以我们总结了一些经验法则:

  1. 如果我们有依赖于 ppermute 输入值的操作,解包 Future 以使用别名值而不是原始值。

  2. 在循环体中进行 ppermute 操作时,使用 unroll >= 2

让我们将所有内容组合到一个函数中,该函数在循环中执行 ppermute 并累积结果。

def f(x):
  out = jnp.zeros_like(x)
  fut = (*sems, x, out) = ppermute_start(x)
  out = out + x
  def body(i, carry):
    out, fut = carry
    x = ppermute_done(fut)
    fut = (*sems, x, out) = ppermute_start(x)
    out = out + x
    return out, fut
  out, fut = fori_loop(0, 7, body, (out, fut), unroll=2)
  return out, ppermute_done(fut)

请注意,在此示例中,我们不需要 optimization_barrier,因为循环边界充当调度屏障,将 startdone 分开。

就这样,我们完成了!这将是 Pallas 中进行异步操作的官方 API。谢谢大家!任务完成!

真的吗?

状态的复仇#

尽管我们似乎通过一些巧妙的技巧解决了复制和不正确的问题,但我们仍然处于一个尴尬的境地。这个 API 功能强大,但有许多许多陷阱和注意事项。可能还有更多需要处理的边缘情况,甚至需要对 XLA 有深入的了解才能预测或理解。我们应该发布这样的 API 吗?或者有没有替代方案?

嗯,答案可能一直就在我们眼前。

让我们再完整地演练一遍,除了,这次我们编写有状态的版本。这意味着我们的每个自定义异步操作现在都作用于 Ref 而不是值。

def ppermute_start_stateful(x_ref, y_ref) -> tuple[Semaphore, Semaphore]:
  ...

def ppermute_done_stateful(send_sem, recv_sem, x_ref, y_ref) -> None:
  ...

让我们假设我们可以在 Pallas 中实现这些,看看我们的新程序会是什么样子。我们从一个基本的集体置换开始:

def f(x):
  x_ref = make_ref(x)
  y_ref = make_ref(zeros_like(x))
  fut = ppermute_start_stateful(x_ref, y_ref)
  ppermute_done_stateful(*fut, x_ref, y_ref)
  return y_ref[...]

它比我们最初基于值的版本更冗长一些,但它有一些关键区别。首先,我们创建一个“空” Ref 来接收 ppermute 的结果,这与基于值的版本不同,后者会为我们创建一个值。一个巧妙之处在于 x_ref 的生命周期在这里很明确:它会一直存在直到 ppermute_done_stateful。我们不再需要像以前那样将 x 值“偷偷”地传入操作中。

当我们尝试在 start/done 之间添加一个操作时,另一个区别变得更加明显。

def f(x):
  x_ref = make_ref(x)
  y_ref = make_ref(zeros_like(x))
  fut = ppermute_start_stateful(x_ref, y_ref)
  x_ref[...] += 1
  ppermute_done_stateful(*fut, x_ref, y_ref)
  return y_ref[...]

之前,我们遇到了调度歧义,XLA 可能会相对于 ppermute 重新排序加法操作。而有了有状态的语义,我们实际上增加了排序约束!x_ref[...] += 1 会修改 x_ref,因此它不能相对于 ppermute_done_stateful 移动。JAX 可以将这些调度约束作为降低到 HLO 的一部分注入。

最后一个关键区别在尝试我们的循环示例时很明显。

def f(x):
  x_ref = make_ref(x)
  y_ref = make_ref(zeros_like(x))
  def body(i, _):
    fut = ppermute_start_stateful(x_ref, y_ref)
    ppermute_done_stateful(*fut, x_ref, y_ref)
    # Now switch to y_ref -> x_ref
    fut = ppermute_start_stateful(y_ref, x_ref)
    ppermute_done_stateful(*fut, y_ref, x_ref)
  fori_loop(0, 8 // 2, body, None)
  return x_ref[...]

由于我们需要一个单独的缓冲区来接收 ppermute 的结果,我们被迫以展开的方式编写代码!无法在 XLA 中编写需要复制的版本,因为那将涉及一个 ppermute,它从 Ref 发送到自身,这并没有真正的意义。

要在不手动展开的情况下处理这个问题,我们将创建一个带有前导 2 维度的临时缓冲区,该缓冲区在迭代中充当发送/接收目标,并在每次迭代中进行切换。这与我们在 Pallas 内核内部编写手动重叠内核时使用的模式相同。

这里的认识是,有状态迫使我们更早地处理许多值语义中出现的问题。我们通过定义来消除它们!

  1. 调度 - 以 Ref 作为输入的有状态操作强制程序排序。请注意,这将安排相同 Ref 上的操作相互关联。我们可能还需要一个 opt_barrier_stateful 来强制执行更多排序约束。

  2. 生命周期 - Ref 的生命周期可以通过 run_state 进行作用域,或者作为有状态操作的输入。

  3. 防御性复制 - 使用 Ref 迫使我们“手动”处理缓冲区分配,并且降低过程可以确保别名正确,从而避免任何复制。

另一个重要的基本限制是,我们最终会生成一个 HLO 程序,其中活跃缓冲区和信号量被表示为数组值类型。XLA 不保证这些中间值的缓冲区生命周期或它们所在的内存空间。因此,即使 Pallas 内核正在主动复制数组值,XLA 也有可能复制这些数组值。这在 HLO 中很容易验证,但它是使用自定义调用在 HLO 中表示异步操作的一个尖锐问题。

结论#

我们已经探讨了 Pallas 和 JAX 中异步操作的一些棘手挑战。Ref 似乎是一种很有前途的方式来表示这些操作,它规避了值语义带来的一些问题。然而,一个缺点是它将有状态的 JAX 推到了核心位置,而我们在 Pallas 之外尚未这样做。值得思考我们是否应该向用户介绍有状态操作,或者提供一个更危险的 API。我们也不知道我们想做的一切是否都可以通过 Ref 来表达。我们还应该集思广益,寻找状态的替代方案,以充实设计空间。例如,如果 XLA 提供一个尊重生命周期的一等 Future API,并且可以自动处理像带有 Future 的双缓冲循环这样的事情会怎样?这可能是一个可行的替代方案,但权衡是给予编译器更多控制权,而不是用户显式控制。