Pallas 异步操作#

背景 + 动机#

我们希望在 Pallas 中公开 API,以显式地重叠跨多个内核的计算和通信。

XLA 异步分解#

作为动机,请考虑以下 JAX 伪代码

def f(x):
  y = ppermute(x)
  z = x + 1
  return y, z

在这个函数中,我们可以并行执行 ppermutex + 1。这是 XLA 自动执行的优化,通过

  1. ppermute 分解为 ppermute_startppermute_done 操作,它们通过 future 连接。

  2. ppermute_startppermute_done 之间调度 x + 1

从而产生以下程序

def f(x):
  fut = ppermute_start(x)
  z = x + 1  # happens at the same time as ppermute
  y = ppermute_done(fut)
  return y, z

内核内的异步操作#

现在想象一下,我们没有使用 XLA 的 ppermute,而是有我们自己的自定义 Pallas ppermute

def ppermute_kernel(x_ref, y_ref, send_sem, recv_sem):
  right_neighbor = ...
  descriptor = pltpu.make_async_remote_copy(x_ref, y_ref, send_sem, recv_sem, device_id=right_neighbor)
  descriptor.start()
  descriptor.wait_send()
  descriptor.wait_recv()

def ppermute(x):
  return pl.pallas_call(ppermute_kernel, out_shape=x, ...)(x)

目前,我们无法像 XLA 那样将 ppermute 分解为 start/done 对,因此我们显式地融合 x + 1 到内核中。

def add_one(x_ref, z_ref):
  z_ref[...] = x_ref[...] + 1

def ppermute_add_one_kernel(x_ref, y_ref, z_ref, send_sem, recv_sem):
  right_neighbor = ...
  descriptor = pltpu.make_async_remote_copy(x_ref, y_ref, send_sem, recv_sem, device_id=right_neighbor)
  descriptor.start()

  # Explicitly schedule inner kernel between start/wait
  pltpu.emit_pipeline(add_one)(x_ref, z_ref)

  descriptor.wait_send()
  descriptor.wait_recv()

def ppermute_and_add_one(x):
  return pl.pallas_call(ppermute_add_one_kernel, out_shape=(x, x), ...)(x)

目标是能够为启动 ppermute 和等待其完成编写单独的内核,以便我们可以在两者之间使用常规的 x + 1(或任何我们想要的计算)。这使代码更具可读性、可维护性,并且不易出错。

我们如何在 TPU 上实现分解的 Pallas 异步操作?#

在 Pallas 中实现分解的异步操作时,主要需要弄清楚的是在它们之间传递的 future 包含什么。具体来说,它必须包含有关后台操作的一些重要状态。

如果我们查看 Pallas 代码,我们可以看到我们需要一个“描述符”来启动和等待远程复制。我们可以将这个描述符从 Pallas 内核中导出,然后将其传递到另一个内核中吗?嗯,有点可以。底层 TPU 硬件通过一对信号量跟踪异步操作进度:send_sem 使我们能够等待设备何时完成向其邻居发送数据,而 recv_sem 跟踪从其邻居发送到设备的数据传输。如果我们想象编写一个 start 内核和一个 done 内核,那么我们需要从 start 传递到 done 的只是信号量以及有关在这些信号量上等待多少的信息。

我们可以通过扩展 Pallas 以支持从内核返回信号量来实现这一点。

def ppermute_start_kernel(
    in_ref, send_sem, recv_sem, out_ref, *, axis_name,
):
  axis_size = jax.lax.psum(1, axis_name)
  left_neighbor = jax.lax.rem(
      jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size
  )
  right_neighbor = jax.lax.rem(jax.lax.axis_index(axis_name) + 1, axis_size)
  barrier_sem = pltpu.get_barrier_semaphore()
  pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor)
  pltpu.semaphore_wait(barrier_sem, 1)
  pltpu.make_async_remote_copy(
      in_ref, out_ref, send_sem, recv_sem, device_id=right_neighbor
  ).start()

def ppermute_start(x, *, axis_name) -> tuple[Semaphore, Semaphore, Array]:
  send_sem, recv_sem, out = pl.pallas_call(
      functools.partial(ppermute_start_kernel, axis_name=axis_name),
      out_shape=(
          pltpu.SemaphoreType.DMA(()),
          pltpu.SemaphoreType.DMA(()),
          jax.ShapeDtypeStruct(
              x.shape,
              dtype=x.dtype,
          ),
      ),
      in_specs=[
          pl.BlockSpec(memory_space=pltpu.ANY),
      ],
      out_specs=(
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          pl.BlockSpec(memory_space=pltpu.ANY),
      ),
  )(x)
  return send_sem, recv_sem, out

请注意,这里正在发生一些微妙的事情。Pallas 正在告诉 XLA 它希望某些输出是信号量(又名同步标志),XLA 会将它们视为“保留”(例如,当它们在 XLA 程序中存活时,这些同步标志不能被其他内核分配)。它们的行为类似于屏障信号量,屏障信号量是由 XLA 管理的保留信号量。

另一个需要注意的事情是,我们从 start 内核返回输出缓冲区 out同时它正在被主动复制到其中

现在我们编写执行阻塞操作的 done 内核。我们将 out 传递到内核中,以计算阻塞信号量所需的形状。

def ppermute_done_kernel(ref, send_sem, recv_sem, _):
  pltpu.make_async_copy(ref, ref, send_sem).wait()
  pltpu.make_async_copy(ref, ref, recv_sem).wait()

def ppermute_done(send_sem, recv_sem, out) ->Array:
  out = pl.pallas_call(
      ppermute_done_kernel,
      out_shape=(
          jax.ShapeDtypeStruct(
              out.shape,
              dtype=out.dtype,
          ),
      ),
      in_specs=[
          pl.BlockSpec(memory_space=pltpu.ANY),
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
      ],
      out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
      input_output_aliases={0:0}
  )(out, send_sem, recv_sem)
  return out

注意:我们在此处 i/o 别名输出缓冲区,以保证消费者是 ppermute_done 的下游。

我们现在可以实现分解的集体 permute。

def f(x):
  fut = ppermute_start(x)
  z = x + 1  # happens at the same time as ppermute
  y = ppermute_done(fut)
  return y, z

或者我们可以吗?

为什么起作用?#

这仍然存在三个问题,每个问题都在某种程度上存在于 Pallas 之外。以下是它们的高级概述。

  1. 调度 - 仅仅因为我们编写了 ppermute_start,然后 x + 1,然后 ppermute_done,并不能保证它们会按该顺序发生。XLA 负责调度,因此当我们编写 JAX 程序时,我们正在设置 XLA 将尊重的数据依赖关系,但 XLA 不会尊重 JAX 中编写的特定操作顺序。

  2. 生命周期 - XLA 假设一旦值超出依赖关系图的范围,它的内存就可以被释放以供其他值使用。如果我们有一个异步复制 x -> y 的操作,我们需要确保 x 在复制完成之前是存活的,否则我们将从垃圾内存中复制。

  3. 防御性复制 - XLA 保留创建值副本的权利。我们需要确保我们不会引入不必要的副本,a) 避免不必要的运行时开销,以及 b) 确保正确性。

我们将逐一讨论这些问题并提出修复方案。

调度#

我们如何在 JAX 中显式地强制操作以特定顺序发生?请注意,这不是 Pallas 特有的问题,如果我们使用替代方法实现了异步操作,我们仍然会遇到这个问题。

一种方法是在 XLA 程序中引入优化屏障。优化屏障将阻止 XLA 移动其周围的操作。

这是我们的原始代码

def f(x):
  fut = ppermute_start(x)
  z = x + 1
  y = ppermute_done(fut)
  return y, z

XLA 可以选择在以下三个位置中的任何一个执行 x + 1

def f(x):
  z = x + 1
  fut = ppermute_start(x)
  y = ppermute_done(fut)
  return y, z

# OR

def f(x):
  fut = ppermute_start(x)
  z = x + 1
  y = ppermute_done(fut)
  return y, z

# OR

def f(x):
  fut = ppermute_start(x)
  y = ppermute_done(fut)
  z = x + 1
  return y, z

为了强制 x + 1ppermute 操作之间发生,我们可以使用 optimization_barrier,它在语义上是恒等函数(即 lambda x: x),但在值之间引入了显式的数据依赖关系。具体来说,如果我们使 x + 1 中使用的 x 依赖于 ppermute_start 返回的 fut,则它必须在 ppermute_start 之后发生。

我们还引入了一个依赖关系,强制输出值 y 依赖于 z

def f(x):
  fut = ppermute_start(x)
  x, fut = optimization_barrier((x, fut))  # x now depends on fut
  z = x + 1
  z, fut = optimization_barrier((z, fut)) # fut now depends on z
  y = ppermute_done(fut)
  return y, z

optimization_barrier 对于我们显式写出调度来说是一个足够好的工具。

生命周期#

让我们再次查看我们的原始代码,并假设操作以正确的顺序发生。

def f(x):
  fut = ppermute_start(x)
  z = x + 1
  y = ppermute_done(fut)
  return y, z

让我们看看 XLA 认为可以释放 x 的缓冲区的程序中的哪个点。它将是 x 不再使用之后的点,具体来说是在 z = x + 1 之后。

def f(x):
  fut = ppermute_start(x)
  z = x + 1
  # XLA can free x here!
  y = ppermute_done(fut)
  return y, z

如果 XLA 在 z = x + 1 完成后释放 x,我们将遇到一个非常糟糕的问题。ppermute 可能仍然在 z = x + 1 之后主动将 x 复制到邻居,这意味着如果 x 被释放,ppermute 将从垃圾内存中读取!

我们如何将 x 的生命周期延长到 ppermute_done?嗯,我们可以引入数据依赖关系!我们需要稍微修改我们的内核才能实现这一点。

首先,我们重写 ppermute_start 以返回 x,通过内核对其进行别名。

def ppermute_start_kernel(
    in_ref, send_sem, recv_sem, out_ref, _, *, axis_name,
):
  axis_size = jax.lax.psum(1, axis_name)
  left_neighbor = jax.lax.rem(
      jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size
  )
  right_neighbor = jax.lax.rem(jax.lax.axis_index(axis_name) + 1, axis_size)
  barrier_sem = pltpu.get_barrier_semaphore()
  pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor)
  pltpu.semaphore_wait(barrier_sem, 1)
  pltpu.make_async_remote_copy(
      in_ref, out_ref, send_sem, recv_sem, device_id=right_neighbor
  ).start()

def ppermute_start(x, *, axis_name) -> tuple[Semaphore, Semaphore, Array, Array]:
  send_sem, recv_sem, x, out = pl.pallas_call(
      functools.partial(ppermute_start_kernel, axis_name=axis_name),
      out_shape=(
          pltpu.SemaphoreType.DMA(()),
          pltpu.SemaphoreType.DMA(()),
          jax.ShapeDtypeStruct(
              x.shape,
              dtype=x.dtype,
          ),
	   jax.ShapeDtypeStruct(
              x.shape,
              dtype=x.dtype,
          ),
      ),
      in_specs=[
          pl.BlockSpec(memory_space=pltpu.ANY),
      ],
      out_specs=(
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          pl.BlockSpec(memory_space=pltpu.ANY),
          pl.BlockSpec(memory_space=pltpu.ANY),
      ),
      input_output_aliases={0:2}
  )(x)
  return send_sem, recv_sem, x, out

然后我们让 ppermute_done 接收 x 并且不对其执行任何操作。

def ppermute_done_kernel(_, ref, send_sem, recv_sem, _):
  pltpu.make_async_copy(ref, ref, send_sem).wait()
  pltpu.make_async_copy(ref, ref, recv_sem).wait()

def ppermute_done(send_sem, recv_sem, x, out) ->Array:
  out = pl.pallas_call(
      ppermute_done_kernel,
      out_shape=(
          jax.ShapeDtypeStruct(
              out.shape,
              dtype=out.dtype,
          ),
      ),
      in_specs=[
          pl.BlockSpec(memory_space=pltpu.ANY),
          pl.BlockSpec(memory_space=pltpu.ANY),
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
      ],
      out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
      input_output_aliases={1:0}
  )(x, out, send_sem, recv_sem)
  return out

现在当我们写

def f(x):
  *sems, x ,out = ppermute_start(x)
  z = x + 1
  y = ppermute_done(*sems, x, out)
  return y, z

XLA 不能再释放 x,因为它是 ppermute_done 的输入!这意味着 x 的生命周期与 ppermute 绑定,并且这段代码现在是正确的。

防御性复制#

XLA 在其缓冲区分配过程中,分析哪些缓冲区相互别名,并在操作别名其输入之一但不是该输入的最终消费者时插入副本。

背景#

这是一个简单的示例。假设我们有一个操作 add_one_inplace,它接收一个数组并加一,但承诺就地执行。

以下代码是合法的。

def f():
  x = jnp.arange(...)
  y = add_one_inplace(x)  return y

但是,如果 x 也有一个单独的消费者,则程序可能无法正确执行。

def f():
  x = jnp.arange(...)
  y = add_one_inplace(x)
  return y, x * 2 # another x consumer!

这是因为 x * 2 对原始 x 进行操作,但 add_one_inplace 会覆盖 x 中的值。x * 2 需要确保读取 x 的原始值,而不是在我们将其递增 1 之后的值。XLA 注意到这一点并插入一个 copy 操作(它在语义上是恒等,但输入和输出缓冲区将不同)。

def f(x):
  x2 = copy(x)
  y = add_one_inplace(x2)
  return y, x * 2

XLA 中的此过程通过强制它们有效地与 copy 操作一起成为异地操作,从而确保了在存在执行就地更新的操作时程序的正确性。

具有下游操作的复制#

让我们重新审视我们在 ppermute 时加 1 的示例。

def f(x):
  fut = ppermute_start(x)
  z = x + 1
  y = ppermute_done(fut)
  return y, z

如果我们解包 future 到其组件,我们将看到别名模式

def f(x):
  *sems, x2, y = ppermute_start(x)
  z = x + 1
  y = ppermute_done((*sems, x2, y))
  return y, z

我们知道 xppermute_start 中保持不变(也就是说,xx2 相同),但 XLA 并不知道。实际上,它看起来像我们的 add_one_inplace 示例对于 XLA,XLA 保守地假设 ppermute_start 突变了 x,而 x2 是新的别名结果。因此,当我们执行 z = x + 1 时,我们遇到了原始缓冲区的消费者。因此,XLA 引入了副本!

def f(x):
  x2 = copy(x)
  *sems, x2, y = ppermute_start(x2)
  z = x + 1
  y = ppermute_done((*sems, x2, y))
  return y, z

这个副本是不必要的,因为我们知道 x2x 相比没有改变。为了删除此副本,我们需要一些机制来告知 XLA 我们只是转发一个值。但是,在没有该机制的情况下,我们可以稍微重写我们的程序以显式使用 x2 而不是 x

def f(x):
  *sems, x2, y = ppermute_start(x)
  z = x2 + 1
  y = ppermute_done((*sems, x2, y))
  return y, z

现在,XLA 没有看到 x 的单独消费者,因此不再引入副本。但是,这有一个主要的缺点,因为它迫使我们解包来自 ppermute_start 的 future。它将生命周期问题与复制问题联系起来。

循环别名#

让我们考虑一个稍微高级的示例。让我们实现一个使用 while_loopppermute 在环中发送值的函数。

def f(x):
  def body(i, x):
    fut = ppermute_start(x)
    y = ppermute_done(fut)
    return y
  return fori_loop(0, 8, body, x)

fori_loop 的一个实现细节是输入和输出缓冲区自动相互别名。请注意,我们在 ppermute_startppermute_done 操作中设置了一些额外的别名。让我们通过为程序中的每个值着色来运行我们自己的“缓冲区分配”,以确定我们需要多少个唯一的缓冲区。

首先,我们将解包具有别名 xout 缓冲区的 fut 元组。

def f(x):
  def body(i, x):
    *sems, x, y = ppermute_start(x)
    y = ppermute_done(*sems, x, y)
    return y
  return fori_loop(0, 8, body, x)

现在让我们根据分配给它们的唯一缓冲区为每个值着色。我们有来自 fori_loop 的输入/输出别名、来自 ppermute_startx 别名和来自 ppermute_doney 别名。

def f(x):
  def body(i, x):
    *sems, x, y = ppermute_start(x)
    y = ppermute_done((*sems, x, y))
    return y
  return fori_loop(0, 8, body, x)

如果您运行别名分析,您会发现所有缓冲区都已着色相同!直观地说,这是有问题的,因为如果我们正在执行 ppermute 的循环,我们不能写入我们正在发送到的同一缓冲区。我们通常需要一个额外的(即“双”)缓冲区来接收,然后通常我们将在下一次迭代中切换发送/接收缓冲区。XLA 在实践中会做的是,它会观察到缓冲区重用并防御性地插入副本。

def f(x):
  def body(i, x):
    x = copy(x)
    *sems, x, y = ppermute_start(x)
    y = ppermute_done((*sems, x, y))
    return y
  return fori_loop(0, 8, body, x)

此副本意味着 xy 不再相互别名,并且程序将是正确的。但是,我们需要这个副本吗?我们如何引入双缓冲区以避免每次迭代都进行昂贵的复制?答案是展开!

我们将手动展开我们的代码。

def f(x):
  def body(i, x):
    *sems, x, x2 = ppermute_start(x)
    x2 = ppermute_done((*sems, x, x2))
    
    *sems, x2, y = ppermute_start(x2)
    y = ppermute_done((*sems, x2, y))
    return y
  return fori_loop(0, 4, body, x)

现在,如果我们运行相同的别名分析,我们将发现缓冲区都不再相互别名,并且我们不需要插入防御性副本来保证正确性。

因此,删除这些副本的简单解决方案是使用 fori_loopunroll >= 2

def f(x):
  def body(i, x):
    fut = ppermute_start(x)
    y = ppermute_done(fut)
    return y
  return fori_loop(0, 8, body, x, unroll=2)

这足以实现此循环而无需额外的副本!

跨循环边界传递 futures#

现在让我们看一个更高级的示例。我们将实现与之前相同的程序,但会交错循环,我们在循环之前的序言中开始 ppermute,并在循环开始时等待 ppermute

def f(x):
  fut = ppermute_start(x)
  def body(i, fut):
    x = ppermute_done(fut)
    fut = ppermute_start(x)
    return fut
  fut = fori_loop(0, 7, body, fut)
  return ppermute_done(fut)

在此示例中,我们不是将值 x 从一个循环传递到另一个循环,而是传递一个 future 值。

让我们再次解包 future 以查看发生了什么。

def f(x):
  fut = ppermute_start(x)
  def body(i, fut):
    *sems, x, out = fut
    x = ppermute_done((*sems, x, out))
    (*sems, x, out) = ppermute_start(x)
    return (*sems, x, out)
  (*sems, x, out) = fori_loop(0, 7, body, x)
  return ppermute_done((*sems, x, out))

所以我们显式地将信号量、输入缓冲区和目标输出缓冲区作为循环携带传递。如果我们现在运行别名分析会发生什么?嗯,我们将遇到与上一节中相同的别名问题,其中 xout 将相互别名。XLA 将引入一个副本。

def f(x):
  fut = ppermute_start(x)
  def body(i, fut):
    *sems, x, out = fut
    out = copy(out)
    x = ppermute_done((*sems, x, out))
    (*sems, x, out) = ppermute_start(x)
    return (*sems, x, out)
  (*sems, x, out) = fori_loop(0, 7, body, x)
  return ppermute_done((*sems, x, out))

在这种情况下,我们在 out 上插入了一个副本。但是,这是一个非常糟糕的情况,因为 out 正在被主动复制到其中!即使我们在 x 上插入副本,我们也会遇到问题,因为那时 x 的生命周期不会扩展到 ppermute_done。这非常非常糟糕!我们不仅会得到副本,而且还会得到不正确的结果!

正如我们之前观察到的,解决方案是通过展开来避免别名所有缓冲区,从而避免副本。所以,如果我们这样做

def f(x):
  fut = ppermute_start(x)
  def body(i, fut):
    x = ppermute_done(fut)
    fut = ppermute_start(x)
    return fut
  fut = fori_loop(0, 7, body, x, unroll=2)
  return ppermute_done(fut)

我们的程序现在应该是正确的。

整合在一起#

所以我们已经总结出一些经验法则

  1. 如果我们的操作依赖于 ppermute 的输入值,请解包 future 以使用别名值而不是原始值。

  2. 使用 unroll >= 2 在循环体中执行 ppermute 时。

让我们将所有内容组合到一个函数中,该函数在循环中执行 ppermute 并累积结果。

def f(x):
  out = jnp.zeros_like(x)
  fut = (*sems, x, out) = ppermute_start(x)
  out = out + x
  def body(i, carry):
    out, fut = carry
    x = ppermute_done(fut)
    fut = (*sems, x, out) = ppermute_start(x)
    out = out + x
    return out, fut
  out, fut = fori_loop(0, 7, body, (out, fut), unroll=2)
  return out, ppermute_done(fut)

请注意,在此示例中,我们不需要 optimization_barrier,因为循环边界充当调度屏障,分隔了 startdone

就这样,我们完成了!这将是 Pallas 中执行异步操作的官方 API。感谢大家!任务完成!

真的是这样吗?

状态的反击#

虽然看起来我们通过使用一些巧妙的技巧解决了复制和不正确的问题,但我们仍然处于尴尬的境地。这个 API 功能强大,但有很多陷阱和注意事项。可能还有更多我们需要处理的边缘情况,这些情况甚至需要深入了解 XLA 才能预测或理解。我们应该发布这样的 API 吗?还是有其他替代方案?

好吧,答案可能一直就在我们眼前。

让我们再运行一遍整个练习,*除了*,让我们编写有状态版本。这意味着我们每个自定义异步操作现在都在 Ref 而不是值上运行。

def ppermute_start_stateful(x_ref, y_ref) -> tuple[Semaphore, Semaphore]:
  ...

def ppermute_done_stateful(send_sem, recv_sem, x_ref, y_ref) -> None:
  ...

让我们假设我们可以在 Pallas 中实现这些,看看我们的新程序会是什么样子。让我们从基本的集体 permute 开始

def f(x):
  x_ref = make_ref(x)
  y_ref = make_ref(zeros_like(x))
  fut = ppermute_start_stateful(x_ref, y_ref)
  ppermute_done_stateful(*fut, x_ref, y_ref)
  return y_ref[...]

它比我们最初的基于值的版本稍微冗长一些,但它有一些关键区别。第一个区别是我们创建了一个“空的” Ref 来接收 ppermute 的结果,这与基于值的版本不同,后者为我们创建一个值。一个巧妙之处是 x_ref 的生命周期在这里很明确:它一直存在到 ppermute_done_stateful。我们不需要像以前那样“偷偷地”将 x 值放入操作中。

当我们尝试在 start/done 之间添加操作时,另一个区别变得更加明显。

def f(x):
  x_ref = make_ref(x)
  y_ref = make_ref(zeros_like(x))
  fut = ppermute_start_stateful(x_ref, y_ref)
  x_ref[...] += 1
  ppermute_done_stateful(*fut, x_ref, y_ref)
  return y_ref[...]

之前,我们遇到了调度歧义,其中 XLA 可能会相对于 ppermute 重新排序 add。使用有状态语义,我们实际上添加了排序约束!x_ref[...] += 1 会改变 x_ref,因此它不能相对于 ppermute_done_stateful 移动。JAX 可以将这些调度约束作为降低到 HLO 的一部分注入。

当我们尝试循环示例时,最终的关键区别显而易见。

def f(x):
  x_ref = make_ref(x)
  y_ref = make_ref(zeros_like(x))
  def body(i, _):
    fut = ppermute_start_stateful(x_ref, y_ref)
    ppermute_done_stateful(*fut, x_ref, y_ref)
    # Now switch to y_ref -> x_ref
    fut = ppermute_start_stateful(y_ref, x_ref)
    ppermute_done_stateful(*fut, y_ref, x_ref)
  fori_loop(0, 8 // 2, body, None)
  return x_ref[...]

由于我们需要准备一个单独的缓冲区来接收 ppermute,因此我们被迫以一种展开循环的方式编写代码!在 XLA 中无法编写需要复制的版本,因为这将涉及从 Ref 发送到自身的 ppermute,这实际上没有意义。

为了在没有手动展开的情况下处理这个问题,我们将创建一个前导 2 维度的暂存缓冲区,该缓冲区充当跨迭代的发送/接收目标,每次迭代都切换。这与我们在编写手动重叠内核时在 Pallas 内核内部使用的模式相同。

这里的认识是,有状态迫使我们更早地处理许多在使用值语义时出现的问题。我们定义它们消失了!

  1. 调度 - 具有 Ref 作为输入的有状态操作强制我们程序的排序。请注意,这将调度相对于彼此的同一 Ref 上的操作。我们可能还需要一个 opt_barrier_stateful 来强制执行更多排序约束。

  2. 生命周期 - Ref 生命周期可以通过 run_state 作用域限定,也可以作为有状态操作的输入。

  3. 防御性复制 - 使用 Ref 迫使我们“手动”处理缓冲区分配,并且降低可以确保别名工作正常,以避免任何复制。

另一个重要的基本限制是,我们最终会分阶段输出一个 HLO 程序,其中实时缓冲区和信号量表示为数组值类型。XLA 不保证这些中间值的缓冲区生命周期或它们所在的内存空间。因此,即使 Pallas 内核正在主动复制数组值,XLA 也可能复制它们。 这在 HLO 中很容易验证,但这是使用自定义调用来表示 HLO 中的异步操作的一个尖锐边缘。

结论#

我们已经讨论了在 Pallas 和 JAX 中使用异步操作时遇到的一些棘手挑战。Ref 似乎是表示这些操作的一种有希望的方式,它可以规避使用值语义时出现的一些问题。但是,一个缺点是它将有状态的 JAX 放在了首要位置,这是我们在 Pallas 之外尚未做过的。值得思考我们是否应该教育用户了解有状态操作,还是提供更危险的 API。我们也不知道我们想要做的一切是否都可以通过 Ref 来表达。我们还应该集思广益,寻找状态的替代方案,以充实设计空间。例如,如果 XLA 提供了一个尊重生命周期的头等 futures API,并且它可以自动执行诸如使用 futures 对循环进行双缓冲之类的操作,会怎么样?这可能是一个可行的替代方案,但权衡将是给予编译器更多控制权,而不是用户明确的控制权。