Pallas 异步操作#
背景 + 动机#
我们希望在 Pallas 中公开 API,以便在多个内核之间显式地重叠计算和通信。
XLA 异步分解#
作为动机,考虑以下 JAX 伪代码
def f(x):
y = ppermute(x)
z = x + 1
return y, z
在此函数中,我们可以同时执行 ppermute
和 x + 1
。这是 XLA 通过
将
ppermute
分解为ppermute_start
和ppermute_done
操作来实现的,这两个操作通过一个 Future 连接。将
x + 1
调度在ppermute_start
和ppermute_done
之间,
得到如下程序
def f(x):
fut = ppermute_start(x)
z = x + 1 # happens at the same time as ppermute
y = ppermute_done(fut)
return y, z
内核内的异步操作#
现在想象一下,我们没有使用 XLA 的 ppermute
,而是拥有自己的自定义 Pallas ppermute
。
def ppermute_kernel(x_ref, y_ref, send_sem, recv_sem):
right_neighbor = ...
descriptor = pltpu.make_async_remote_copy(x_ref, y_ref, send_sem, recv_sem, device_id=right_neighbor)
descriptor.start()
descriptor.wait_send()
descriptor.wait_recv()
def ppermute(x):
return pl.pallas_call(ppermute_kernel, out_shape=x, ...)(x)
目前,我们无法像 XLA 那样将 ppermute
分解为 start/done
对,因此我们明确地将 x + 1
**融合**到内核中。
def add_one(x_ref, z_ref):
z_ref[...] = x_ref[...] + 1
def ppermute_add_one_kernel(x_ref, y_ref, z_ref, send_sem, recv_sem):
right_neighbor = ...
descriptor = pltpu.make_async_remote_copy(x_ref, y_ref, send_sem, recv_sem, device_id=right_neighbor)
descriptor.start()
# Explicitly schedule inner kernel between start/wait
pltpu.emit_pipeline(add_one)(x_ref, z_ref)
descriptor.wait_send()
descriptor.wait_recv()
def ppermute_and_add_one(x):
return pl.pallas_call(ppermute_add_one_kernel, out_shape=(x, x), ...)(x)
目标是能够为启动 ppermute
和等待其完成编写单独的内核,以便我们可以在两者之间插入一个常规的 x + 1
(或任何我们想要的计算)。这使得代码更具可读性、可维护性,并且不易出错。
我们如何实现已分解的 Pallas 异步操作(在 TPU 上)?#
在 Pallas 中实现已分解的异步操作时,需要解决的主要问题是它们之间传递的 future
包含什么。具体来说,它必须包含有关后台操作的一些重要状态。
如果我们查看 Pallas 代码,我们可以看到我们需要一个“描述符”来启动和等待远程复制。我们可以将此描述符引出 Pallas 内核,然后将其传递到另一个内核吗?嗯,差不多。底层的 TPU 硬件通过一对信号量跟踪异步操作进度:send_sem
使我们能够等待设备完成向其邻居发送数据,而 recv_sem
跟踪从邻居发送到设备的数据传输。如果我们想象编写一个启动内核和一个完成内核,我们只需要从启动内核传递到完成内核的就是信号量以及有关在信号量上等待多少的信息。
我们可以通过扩展 Pallas 以支持从内核返回信号量来实现这一点。
def ppermute_start_kernel(
in_ref, send_sem, recv_sem, out_ref, *, axis_name,
):
axis_size = jax.lax.psum(1, axis_name)
left_neighbor = jax.lax.rem(
jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size
)
right_neighbor = jax.lax.rem(jax.lax.axis_index(axis_name) + 1, axis_size)
barrier_sem = pltpu.get_barrier_semaphore()
pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor)
pltpu.semaphore_wait(barrier_sem, 1)
pltpu.make_async_remote_copy(
in_ref, out_ref, send_sem, recv_sem, device_id=right_neighbor
).start()
def ppermute_start(x, *, axis_name) -> tuple[Semaphore, Semaphore, Array]:
send_sem, recv_sem, out = pl.pallas_call(
functools.partial(ppermute_start_kernel, axis_name=axis_name),
out_shape=(
pltpu.SemaphoreType.DMA(()),
pltpu.SemaphoreType.DMA(()),
jax.ShapeDtypeStruct(
x.shape,
dtype=x.dtype,
),
),
in_specs=[
pl.BlockSpec(memory_space=pltpu.ANY),
],
out_specs=(
pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
pl.BlockSpec(memory_space=pltpu.ANY),
),
)(x)
return send_sem, recv_sem, out
请注意,这里发生了一些微妙的事情。Pallas 告诉 XLA 它希望某些输出是信号量(也称为同步标志),XLA 会将它们视为“保留”(例如,只要它们在 XLA 程序中处于活动状态,其他内核就无法分配这些同步标志)。它们类似于屏障信号量,屏障信号量是由 XLA 管理的保留信号量。
另一件值得注意的事情是,我们从启动内核返回输出缓冲区 out
,此时它正在被积极复制到其中。
现在我们编写执行阻塞操作的 done
内核。我们将 out
传递到内核中,以计算阻塞信号量所需的形状。
def ppermute_done_kernel(ref, send_sem, recv_sem, _):
pltpu.make_async_copy(ref, ref, send_sem).wait()
pltpu.make_async_copy(ref, ref, recv_sem).wait()
def ppermute_done(send_sem, recv_sem, out) ->Array:
out = pl.pallas_call(
ppermute_done_kernel,
out_shape=(
jax.ShapeDtypeStruct(
out.shape,
dtype=out.dtype,
),
),
in_specs=[
pl.BlockSpec(memory_space=pltpu.ANY),
pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
],
out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
input_output_aliases={0:0}
)(out, send_sem, recv_sem)
return out
注意:我们在此处对输出缓冲区进行 i/o 别名,以保证消费者在 ppermute_done
的下游。
现在我们可以实现已分解的集体置换。
def f(x):
fut = ppermute_start(x)
z = x + 1 # happens at the same time as ppermute
y = ppermute_done(fut)
return y, z
或者我们可以吗?
为什么不起作用?#
这还有三个遗留问题,其中每个问题在一定程度上都存在于 Pallas 之外。下面是它们的高层概述。
调度 - 仅仅因为我们编写了
ppermute_start
,然后是x + 1
,然后是ppermute_done
,并不能保证它们会按此顺序执行。XLA 负责调度,所以当我们编写 JAX 程序时,我们正在设置 XLA 将遵循的数据依赖关系,但 XLA 不会遵循 JAX 中编写的操作的特定顺序。生命周期 - XLA 假设一旦一个值在依赖图中超出范围,其内存就可以被释放以供其他值使用。如果我们有一个操作异步复制 x -> y,我们需要确保 x 在复制完成之前是活动的,否则我们将从垃圾内存中复制。
防御性复制 - XLA 保留创建值副本的权利。我们需要确保不引入不必要的复制,以 a) 避免不必要的运行时开销,以及 b) 确保正确性。
我们将逐一讨论这些问题并提出解决方案。
调度#
我们如何显式地强制 JAX 中的操作按特定顺序执行?请注意,这不是 Pallas 特有的问题,如果我们使用替代方法实现异步操作,仍然会遇到此问题。
一种方法是在 XLA 程序中引入优化屏障。优化屏障将阻止 XLA 移动其周围的操作。
这是我们的原始代码
def f(x):
fut = ppermute_start(x)
z = x + 1
y = ppermute_done(fut)
return y, z
XLA 可以选择在三个位置中的任何一个执行 x + 1
def f(x):
z = x + 1
fut = ppermute_start(x)
y = ppermute_done(fut)
return y, z
# OR
def f(x):
fut = ppermute_start(x)
z = x + 1
y = ppermute_done(fut)
return y, z
# OR
def f(x):
fut = ppermute_start(x)
y = ppermute_done(fut)
z = x + 1
return y, z
为了强制 x + 1
发生在 ppermute
操作之间,我们可以使用 optimization_barrier
,它在语义上是身份函数(即 lambda x: x
),但会引入值之间的显式数据依赖关系。具体来说,如果我们使 x + 1
中使用的 x
依赖于 ppermute_start
返回的 fut
,则它必须发生在 ppermute_start
之后。
我们还引入了一个依赖关系,该依赖关系强制输出值 y
依赖于 z
。
def f(x):
fut = ppermute_start(x)
x, fut = optimization_barrier((x, fut)) # x now depends on fut
z = x + 1
z, fut = optimization_barrier((z, fut)) # fut now depends on z
y = ppermute_done(fut)
return y, z
optimization_barrier
对于我们显式写出调度来说是一个足够好的工具。
生命周期#
让我们再次查看我们的原始代码,并假设操作按正确的顺序执行。
def f(x):
fut = ppermute_start(x)
z = x + 1
y = ppermute_done(fut)
return y, z
让我们看看在程序的哪个点 XLA 认为可以释放 x
的缓冲区。它将是 x
不再使用的点之后,特别是 z = x + 1
之后。
def f(x):
fut = ppermute_start(x)
z = x + 1
# XLA can free x here!
y = ppermute_done(fut)
return y, z
如果 XLA 在 z = x + 1
完成后释放 x
,我们会遇到一个非常糟糕的问题。该 ppermute
可能仍在 z = x + 1
之后积极地将 x
复制到邻居,这意味着如果 x
被释放,ppermute
将从垃圾内存中读取!
我们如何将 x
的生命周期延长到 ppermute_done
?嗯,我们可以引入一个数据依赖!我们需要稍微修改我们的内核来实现这一点。
首先,我们重写 ppermute_start
以返回 x
,并将其通过内核进行别名。
def ppermute_start_kernel(
in_ref, send_sem, recv_sem, out_ref, _, *, axis_name,
):
axis_size = jax.lax.psum(1, axis_name)
left_neighbor = jax.lax.rem(
jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size
)
right_neighbor = jax.lax.rem(jax.lax.axis_index(axis_name) + 1, axis_size)
barrier_sem = pltpu.get_barrier_semaphore()
pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor)
pltpu.semaphore_wait(barrier_sem, 1)
pltpu.make_async_remote_copy(
in_ref, out_ref, send_sem, recv_sem, device_id=right_neighbor
).start()
def ppermute_start(x, *, axis_name) -> tuple[Semaphore, Semaphore, Array, Array]:
send_sem, recv_sem, x, out = pl.pallas_call(
functools.partial(ppermute_start_kernel, axis_name=axis_name),
out_shape=(
pltpu.SemaphoreType.DMA(()),
pltpu.SemaphoreType.DMA(()),
jax.ShapeDtypeStruct(
x.shape,
dtype=x.dtype,
),
jax.ShapeDtypeStruct(
x.shape,
dtype=x.dtype,
),
),
in_specs=[
pl.BlockSpec(memory_space=pltpu.ANY),
],
out_specs=(
pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
pl.BlockSpec(memory_space=pltpu.ANY),
pl.BlockSpec(memory_space=pltpu.ANY),
),
input_output_aliases={0:2}
)(x)
return send_sem, recv_sem, x, out
然后我们让 ppermute_done
接受 x
并对其不做任何操作。
def ppermute_done_kernel(_, ref, send_sem, recv_sem, _):
pltpu.make_async_copy(ref, ref, send_sem).wait()
pltpu.make_async_copy(ref, ref, recv_sem).wait()
def ppermute_done(send_sem, recv_sem, x, out) ->Array:
out = pl.pallas_call(
ppermute_done_kernel,
out_shape=(
jax.ShapeDtypeStruct(
out.shape,
dtype=out.dtype,
),
),
in_specs=[
pl.BlockSpec(memory_space=pltpu.ANY),
pl.BlockSpec(memory_space=pltpu.ANY),
pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
],
out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
input_output_aliases={1:0}
)(x, out, send_sem, recv_sem)
return out
现在当我们写
def f(x):
*sems, x ,out = ppermute_start(x)
z = x + 1
y = ppermute_done(*sems, x, out)
return y, z
XLA 不再能释放 x
,因为它是 ppermute_done
的输入!这意味着 x
的生命周期与 ppermute
相关联,并且此代码现在是正确的。
防御性复制#
XLA 在其缓冲区分配传递中,分析哪些缓冲区相互别名,并在操作别名其中一个输入但不是该输入的最终使用者时插入复制。
背景#
这是一个简单的例子。假设我们有一个 add_one_inplace
操作,它接收一个数组并加一,但承诺就地执行。
以下代码将是合法的。
def f():
x = jnp.arange(...)
y = add_one_inplace(x) return y
但是,如果 x
也有一个单独的消费者,程序可能无法正确执行。
def f():
x = jnp.arange(...)
y = add_one_inplace(x)
return y, x * 2 # another x consumer!
这是因为 x * 2
操作的是原始 x
,而 add_one_inplace
会覆盖 x
中的值。x * 2
需要确保读取 x
的原始值,而不是我们对其加一后的值。XLA 注意到这一点并插入一个 copy
操作(在语义上是身份操作,但输入和输出缓冲区将不同)。
def f(x):
x2 = copy(x)
y = add_one_inplace(x2)
return y, x * 2
XLA 中的此传递通过强制执行原地更新操作实际上成为异地操作(通过 copy
操作)来确保正确性。
带下游操作的复制#
让我们回顾一下我们在 ppermute
时加一的例子。
def f(x):
fut = ppermute_start(x)
z = x + 1
y = ppermute_done(fut)
return y, z
如果我们解开 Future 到它的组成部分,我们会看到别名模式
def f(x):
*sems, x2, y = ppermute_start(x)
z = x + 1
y = ppermute_done((*sems, x2, y))
return y, z
我们知道 x
不会被 ppermute_start
更改(也就是说,x
与 x2
相同),但 XLA 并不知道。实际上,它看起来像我们的 add_one_inplace
示例对 XLA 而言,它保守地假设 ppermute_start
更改了 x
,而 x2
是新的别名结果。因此,当我们执行 z = x + 1
时,我们会遇到原始缓冲区的消费者。因此,XLA 会插入一个副本!
def f(x):
x2 = copy(x)
*sems, x2, y = ppermute_start(x2)
z = x + 1
y = ppermute_done((*sems, x2, y))
return y, z
这个副本是不必要的,因为我们知道 x2
与 x
相比没有改变。为了移除这个副本,我们需要一种机制来告知 XLA 我们只是转发一个值。然而,在没有这种机制的情况下,我们可以稍微重写我们的程序,明确使用 x2
而不是 x
。
def f(x):
*sems, x2, y = ppermute_start(x)
z = x2 + 1
y = ppermute_done((*sems, x2, y))
return y, z
现在,XLA 看不到 x
的单独消费者,因此不再引入复制。然而,这有一个重大的缺点,即它迫使我们解开来自 ppermute_start
的 Future。它将生命周期问题与复制问题耦合起来。
循环别名#
让我们考虑一个稍微高级的例子。我们将实现一个使用 while_loop
和 ppermute
将值发送到环形结构的函数。
def f(x):
def body(i, x):
fut = ppermute_start(x)
y = ppermute_done(fut)
return y
return fori_loop(0, 8, body, x)
fori_loop
的一个实现细节是输入和输出缓冲区会自动别名。请注意,我们在 ppermute_start
和 ppermute_done
操作中设置了一些额外的别名。让我们进行自己的“缓冲区分配”,为程序中的每个值着色,以确定我们需要多少个唯一的缓冲区。
首先,我们将解开具有别名 x
和 out
缓冲区的 fut
元组。
def f(x):
def body(i, x):
*sems, x, y = ppermute_start(x)
y = ppermute_done(*sems, x, y)
return y
return fori_loop(0, 8, body, x)
现在让我们根据分配的唯一缓冲区为每个值着色。我们有来自 fori_loop
的输入/输出别名,来自 ppermute_start
的 x
别名,以及来自 ppermute_done
的 y
别名。
def f(x):
def body(i, x):
*sems, x, y = ppermute_start(x)
y = ppermute_done((*sems, x, y))
return y
return fori_loop(0, 8, body, x)
如果你运行别名分析,你会发现所有缓冲区都被着色为相同!直观地说,这是个问题,因为如果我们进行 ppermute
的循环,我们就不能写入我们正在发送数据的同一缓冲区。我们通常需要一个额外的(即“双”)缓冲区来接收,然后在下一个迭代中切换发送/接收缓冲区。XLA 在实践中会观察到缓冲区重用并防御性地插入一个副本。
def f(x):
def body(i, x):
x = copy(x)
*sems, x, y = ppermute_start(x)
y = ppermute_done((*sems, x, y))
return y
return fori_loop(0, 8, body, x)
这个副本意味着 x
和 y
不再相互别名,程序将是正确的。但是,我们是否需要这个副本?我们如何引入双缓冲区以避免每次迭代都产生昂贵的复制?答案是展开!
我们将手动展开我们的代码。
def f(x):
def body(i, x):
*sems, x, x2 = ppermute_start(x)
x2 = ppermute_done((*sems, x, x2))
*sems, x2, y = ppermute_start(x2)
y = ppermute_done((*sems, x2, y))
return y
return fori_loop(0, 4, body, x)
现在如果我们进行相同的别名分析,我们会发现缓冲区都不再别名,并且我们不需要插入防御性副本即可保证正确性。
因此,移除这些副本的简单方法是使用 fori_loop
和 unroll >= 2
。
def f(x):
def body(i, x):
fut = ppermute_start(x)
y = ppermute_done(fut)
return y
return fori_loop(0, 8, body, x, unroll=2)
这就足以在没有额外副本的情况下实现此循环了!
跨循环边界传递 Future#
现在让我们看一个更高级的例子。我们将实现与之前相同的程序,但要错开循环,我们在循环之前的序言中启动 ppermute
,并在循环开始时等待 ppermute
。
def f(x):
fut = ppermute_start(x)
def body(i, fut):
x = ppermute_done(fut)
fut = ppermute_start(x)
return fut
fut = fori_loop(0, 7, body, fut)
return ppermute_done(fut)
在此示例中,我们传递的是 Future 值,而不是将值 x
从一个循环传递到另一个循环。
让我们再次解开 Future,看看发生了什么。
def f(x):
fut = ppermute_start(x)
def body(i, fut):
*sems, x, out = fut
x = ppermute_done((*sems, x, out))
(*sems, x, out) = ppermute_start(x)
return (*sems, x, out)
(*sems, x, out) = fori_loop(0, 7, body, x)
return ppermute_done((*sems, x, out))
所以我们正在显式地将信号量、输入缓冲区和目标输出缓冲区作为循环携带者进行传递。如果我们现在运行别名分析会怎样?嗯,我们会遇到与上一节相同的别名问题,即 x
和 out
将相互别名。XLA 将插入一个副本。
def f(x):
fut = ppermute_start(x)
def body(i, fut):
*sems, x, out = fut
out = copy(out)
x = ppermute_done((*sems, x, out))
(*sems, x, out) = ppermute_start(x)
return (*sems, x, out)
(*sems, x, out) = fori_loop(0, 7, body, x)
return ppermute_done((*sems, x, out))
在这种情况下,我们在 out
上插入了一个副本。但是,这是一个非常糟糕的情况,因为 out
正在被积极地复制到其中!即使我们在 x
上插入一个副本,我们也会遇到问题,因为那时 x
的生命周期将不会延伸到 ppermute_done
。这非常非常糟糕!我们不仅会得到副本,还会得到错误的结果!
如前所述,解决方案是通过展开来避免缓冲区别名,从而避免复制。所以,如果我们这样做
def f(x):
fut = ppermute_start(x)
def body(i, fut):
x = ppermute_done(fut)
fut = ppermute_start(x)
return fut
fut = fori_loop(0, 7, body, x, unroll=2)
return ppermute_done(fut)
我们的程序现在应该是正确的。
整合#
所以我们已经得出了一些经验法则
如果我们有依赖于
ppermute
输入值的操作,请解开 Future 以使用别名值而不是原始值。在循环体中进行
ppermute
时,使用unroll >= 2
。
让我们将所有内容组合成一个执行循环中 ppermute
并累积结果的函数。
def f(x):
out = jnp.zeros_like(x)
fut = (*sems, x, out) = ppermute_start(x)
out = out + x
def body(i, carry):
out, fut = carry
x = ppermute_done(fut)
fut = (*sems, x, out) = ppermute_start(x)
out = out + x
return out, fut
out, fut = fori_loop(0, 7, body, (out, fut), unroll=2)
return out, ppermute_done(fut)
请注意,在此示例中,我们不需要 optimization_barrier
,因为循环边界充当调度屏障,分隔 start
和 done
。
就是这样,我们完成了!这将是 Pallas 中异步操作的官方 API。感谢大家!任务完成!
或者真的完成了吗?
状态的复仇#
虽然我们似乎通过一些巧妙的技巧解决了复制和不正确性问题,但我们仍然处于尴尬的境地。这个 API 很强大,但有很多陷阱和注意事项。很可能还需要处理更多边缘情况,甚至需要深入了解 XLA 才能预测或理解。我们应该发布这样的 API 吗?还是有替代方案?
嗯,答案可能一直都在我们眼前。
让我们再来一次,但是,让我们编写有状态的版本。这意味着我们的每个自定义异步操作现在都操作 Ref
而不是值。
def ppermute_start_stateful(x_ref, y_ref) -> tuple[Semaphore, Semaphore]:
...
def ppermute_done_stateful(send_sem, recv_sem, x_ref, y_ref) -> None:
...
让我们假设我们可以在 Pallas 中实现这些,看看我们的新程序会是什么样子。让我们从一个基本的集体置换开始
def f(x):
x_ref = make_ref(x)
y_ref = make_ref(zeros_like(x))
fut = ppermute_start_stateful(x_ref, y_ref)
ppermute_done_stateful(*fut, x_ref, y_ref)
return y_ref[...]
与我们最初基于值的版本相比,它有点冗长,但有几个关键区别。第一个是,我们创建一个“空”的 Ref
来接收 ppermute
的结果,而不是基于值的版本,它会为我们创建一个值。一个巧妙之处在于 x_ref
的生命周期在这里很清楚:它一直存在直到 ppermute_done_stateful
。我们不需要像以前那样将 x
值“偷偷”放入操作中。
当我们尝试在 start/done
之间添加操作时,另一个区别变得更加清晰。
def f(x):
x_ref = make_ref(x)
y_ref = make_ref(zeros_like(x))
fut = ppermute_start_stateful(x_ref, y_ref)
x_ref[...] += 1
ppermute_done_stateful(*fut, x_ref, y_ref)
return y_ref[...]
以前,我们遇到了调度歧义,XLA 可以相对于 ppermute
重新排序加法。通过有状态语义,我们实际上增加了一个排序约束!x_ref[...] += 1
会改变 x_ref
,因此它不能相对于 ppermute_done_stateful
移动。JAX 可以将这些调度约束作为 HLO 降低的一部分注入。
当我们在循环示例中尝试时,最终的关键区别就显而易见了。
def f(x):
x_ref = make_ref(x)
y_ref = make_ref(zeros_like(x))
def body(i, _):
fut = ppermute_start_stateful(x_ref, y_ref)
ppermute_done_stateful(*fut, x_ref, y_ref)
# Now switch to y_ref -> x_ref
fut = ppermute_start_stateful(y_ref, x_ref)
ppermute_done_stateful(*fut, y_ref, x_ref)
fori_loop(0, 8 // 2, body, None)
return x_ref[...]
由于需要一个单独的缓冲区来接收 ppermute
,我们被迫以展开的方式编写代码!没有办法编写需要复制的 XLA 版本,因为那将涉及一个从 Ref
发送到自身的 ppermute
,这实际上没有意义。
为了在没有手动展开的情况下处理此问题,我们将创建一个带有前导 2
维度的暂存缓冲区,该缓冲区充当跨迭代的发送/接收目标,并在每个迭代中切换。这与我们在编写手动重叠内核时在 Pallas 内核中使用的模式相同。
这里的认识是,有状态性迫使我们更早地处理价值语义带来的许多问题。我们通过定义它们来规避它们!
调度 - 将
Ref
作为输入的有状态操作会强制我们的程序进行排序。请注意,这将按相对于彼此的顺序调度同一Ref
上的操作。我们可能还需要一个opt_barrier_stateful
来强制执行更多排序约束。生命周期 -
Ref
的生命周期可以通过run_state
来限定范围,或者可以作为有状态操作的输入。防御性复制 - 使用
Ref
迫使我们“手动”处理缓冲区分配,并且降低可以确保别名正确,从而避免任何复制。
另一个重要的基本限制是,我们最终会阶段化一个 HLO 程序,其中活动缓冲区和信号量表示为数组值类型。XLA 不保证这些中间值的缓冲区生命周期或它们所在的内存空间。因此,XLA 可能会复制数组值,即使它们正在被 Pallas 内核积极复制。 这在 HLO 中很容易验证,但它是使用自定义调用来表示 HLO 中的异步操作的一个尖锐边缘。
结论#
我们已经讨论了 Pallas 和 JAX 中异步操作的一些棘手挑战。Ref
s 似乎是表示这些操作的有前途的方式,可以绕过价值语义出现的一些问题。然而,缺点是它将有状态的 JAX 置于前台,而我们目前在 Pallas 之外还没有这样做。值得考虑我们是否应该教育用户了解有状态操作,还是提供一个更危险的 API。我们也不知道我们想要做的一切是否都可以通过 Ref
s 来表达。我们还应该集思广益,探索状态的替代方案,以充实设计空间。例如,如果 XLA 提供了一流的 Future API,能够尊重生命周期,并且可以自动执行诸如双缓冲带有 Future 的循环等操作?这可能是一个可行的替代方案,但权衡将是给予编译器更多的控制权,而不是用户显式控制。