JAX 中的副作用排序#

sharadmv@ 2022年5月9日

概述#

当我们编写 JAX 代码时,通常可以假装我们正在编写单线程、急切执行的 Python 代码,尽管在底层,JAX 及其运行时可能会在后台异步执行它。只要我们编写纯(无副作用)代码,这些性能优化通常对我们是不可见的,并且不会干扰我们的单线程心智模型。异步执行非常棒——我们无需考虑即可获得高性能的并行代码!

然而,在存在副作用的情况下,这种假象开始破裂,我们心智模型中的裂痕也开始显现。具体来说,当我们考虑副作用发生的*顺序*时,这些差异就会出现。

在本设计说明中,我们将探讨 JAX 的执行模型与副作用排序之间的相互作用。我们还将提供一种强制“单线程”效应排序的方法。

背景#

当我们编写以下 Python 代码时

def f():
  print("hello")
  return 2
def g():
  print("world")
  return 3
f()
g()

我们期望 "hello""world" 之前打印。这看起来可能很明显,但请考虑以下 JAX 代码

@partial(jax.jit, device=<device 0>)
def f():
  return 2

@partial(jax.jit, device=<device 1>)
def g():
  return 3
f()
g()

在许多情况下,JAX 会*并行*执行 fg,将计算分派到不同的线程——g 实际上可能在 f 之前执行。并行执行是一种很好的性能优化,特别是当设备之间的数据复制开销很大时(更多详细信息请参阅异步分派说明)。然而在实践中,我们通常不需要考虑异步分派,因为我们编写的是纯函数,只关心函数的输入和输出——我们会自然地阻塞未来的值。

然而,现在假设我们有一个在 JIT 编译的 JAX 函数内部工作的 jax.print 函数(host_callback.id_print 是一个例子)。让我们回到之前的例子,但这次加入了打印。

@partial(jax.jit, device=<device 0>)
def f():
  jax.print("hello")
  return 2

@partial(jax.jit, device=<device 1>)
def g():
  jax.print("world")
  return 3
f()
g()

得益于异步分派,我们实际上可能会看到 "world""hello" 之前打印。打印副作用的重新排序打破了单线程执行模型的假象。

副作用可能“揭示”乱序执行的另一个例子是当我们编译 JAX 程序时。考虑以下 JAX 代码

@jax.jit
def f(x):
  jax.print("hello")
  jax.print("world")
  return x

尽管在 Python 中,我们先写了打印 "hello",后写了打印 "world",但像 XLA 这样的编译器可以自由地重新排序它们,因为这些打印之间没有显式的数据依赖。

动机#

我们希望支持“有序”效应。当我们说有序时,我们的意思是效应发生的顺序与我们执行单线程 Python 程序时的顺序相同。这是我们的主要目标。在存在显式并行性(如 pmap 或用户线程)的情况下,我们不需要维持这种行为,但至少如果用户没有显式请求并行性,我们希望保留单线程排序。

在我们深入探讨之前,让我们首先退一步问自己:为了性能,重新排序效应是否可以接受?反过来说,我们是否需要强制执行效应的顺序?在某些情况下,我们不需要排序。也许某些副作用不应该对 JAX 程序的性能产生不利影响。然而,对于其他副作用,我们可能希望强制执行单线程程序顺序,以免用户遇到反直觉的行为。考虑一个日志记录效应。

@jax.jit
def f(x, y):
  log_value(x)
  log_value(y)
f(1, 2)

如果 log 正在修改一个全局列表,我们可能会期望在添加 y 之前先添加 x。对于更严格的效应,我们可能希望选择对效应进行排序。

强制有序效应#

我们强制计算顺序的主要工具是*数据依赖*。简而言之,如果函数 g 的一个输入是函数 f 的输出,那么 f 必须在 g 之前执行。

然而,我们可能会遇到像打印这样的副作用,它们根本没有输入,因此我们无法天真地对它们进行排序。因此,我们使用*令牌(token)*作为向计算中注入人工数据依赖的方法。

什么是令牌?令牌只是一个可以传入传出计算的虚拟值。通过在多个计算中传入传出同一个令牌,我们强制它们以特定的顺序发生。让我们以上面的打印示例为例,看看它在混合了令牌之后会是什么样子

@jax.jit
def f(token, x):
  token = jax.print(token, "hello")
  token = jax.print(token, "world")
  return token, x

如果我们将 jax.print 重写为接受并返回一个令牌,那么我们现在已经对这两个打印进行了排序,因为第二个打印的输入依赖于第一个打印的输出。实际上,token 的实际值可以是任何东西,但我们将在实践中看到令牌对用户是不可见的。

运行时令牌与编译器令牌#

在这里,我们将开始讨论实现细节。在实践中,我们需要两种不同类型的令牌来对效应进行排序:每种用于上述重排序来源之一。我们将需要*运行时令牌*来对异步分派的副作用计算进行排序,我们将需要*编译器令牌*来对计算内部的效应进行排序。

实际上,我们的计算将被重写为如下所示

@jax.jit
def f(runtime_token, x):
  compiler_token = new_compiler_token()
  compiler_token = jax.print(compiler_token, "hello")
  compiler_token = jax.print(compiler_token, "world")
  return runtime_token, x

请注意,运行时令牌仅在 JIT 边界处使用,而编译器令牌仅在编译后的代码内部使用。编译器令牌是在“降低”(我们将 Python 代码转换为 HLO 或 StableHLO 等较低级别表示)期间创建的,但运行时令牌需要在 Python 中管理,因为它们在 JIT 编译的函数中被传入传出。

此外,请注意运行时令牌与编译器令牌是“断开连接”的,这意味着它们之间没有数据依赖。这可能很危险,因为我们可能会失去两个分派函数调用主体之间的数据依赖。然而,如果我们假设“严格执行”——即分派函数仅在其所有输入都准备就绪时才开始执行,并且其所有输出将同时准备就绪——那么我们可以安全地创建一个新的编译器令牌并返回一个不依赖于输出的运行时令牌。

管理运行时令牌#

为了代表用户管理运行时令牌,我们需要接入 JAX 的分派机制。每当我们调用一个 JIT 编译的函数时,最终都会落入一个像这样的函数中

def _execute(compiled_computation, *args):
  outputs = compiled_computation.execute(*args)
  return outputs

此时,我们需要将运行时令牌“注入”到计算中,并从计算的输出中“提取”它们。

def _execute(compiled_computation, *args):
  runtime_token = get_runtime_token() # Grab global token
  runtime_token, *outputs = compiled_computation.execute(runtime_token, *args)
  update_runtime_token(runtime_token) # Update global token
  return outputs

那么 runtime_token 到底是什么?我们需要能够将其传递给 compiled_computation,这意味着它需要是某种数组(目前,因为编译后的 JAX 代码内部和外部没有共享的令牌表示)。实际上,我们可以使用一个 (0,) 形状的数组来最小化开销。

我们还需要考虑多设备使用情况,例如第一个例子中,我们首先在设备 0 上调用一个 JIT 编译的函数,然后又在设备 1 上调用一个。在这种情况下,我们还需要将第一个计算返回的运行时令牌(它位于设备 0 上)*复制*到设备 1,以便我们可以将其传递给第二个计算。如果两个后续计算共享同一设备,则无需进行此复制。

添加编译器令牌#

当我们将 Python 代码降低到 HLO 或 StableHLO 时,我们需要在计算开始时创建一个令牌,并确保在需要排序的副作用计算中可以使用这些令牌。副作用计算将接受令牌作为输入并将其作为输出返回。

这种令牌传递的实现涉及升级 JAX 的降低机制以自动进行此记账。主要挑战包括处理像调用原语和控制流原语这样的高阶原语。我们不会在本设计说明中详细介绍如何处理这些问题。

阻塞输出令牌#

为副作用计算添加运行时和编译器令牌支持对于排序很重要,但令牌还有另一个微妙的用例,即阻塞副作用计算。即使我们不希望副作用计算被*排序*,我们可能仍然希望等待它完成。目前我们有 jax.block_until_ready,它会等待直到未来值的结果准备就绪。然而,对于副作用计算,我们可能有一些函数没有返回值但仍在执行副作用。以这个简单的例子为例

@jax.jit
def f():
  jax.print("hello world")
  return
f() # Executed asynchronously

这个编译后的计算没有显式输入,也没有显式输出。如果它是一个有序的打印效应,我们可以阻塞返回的运行时令牌。然而,当这是一个无序计算时,我们不进行任何令牌传递。当没有输出值可以调用 block_until_ready 时,我们如何等待 f() 执行完成?嗯,我们可以应用相同的令牌策略,只是我们只返回运行时令牌而不将其作为输入。这将为我们提供一个可以阻塞的值,该值只会在 f() 执行完成后才准备就绪。我们将这些令牌称为*输出令牌*。最终我们得到一个看起来像这样的函数

@jax.jit
def f():
  jax.print("hello world")
  return new_runtime_token()
f() # Executed asynchronously

在底层,我们将以管理运行时令牌的相同方式管理输出令牌,但会提供一种方法供用户阻塞当前的一组输出令牌。与运行时令牌不同,输出令牌需要是*设备特定的*。考虑一个单设备用例

@jax.jit
def f():
  jax.print("hello")

@jax.jit
def g():
  jax.print("world")

f()
g()

由于 f()g() 在同一设备上执行,阻塞 g() 的输出令牌实际上也阻塞了 f(),因为(截至目前!)JAX 运行时不会交错在同一设备上执行的计算。当然,如果这种情况发生变化,我们将不得不修改整个设计。

然而,考虑双设备用例

@partial(jax.jit, device=<device 0>)
def f():
  jax.print("hello")

@partial(jax.jit, device=<device 1>)
def g():
  jax.print("world")

f()
g()

在这里,我们不想显式地对 f()g() 进行排序,但希望等待它们都完成。我们将需要一个用于 f() 的输出令牌和一个用于 g() 的输出令牌,并且我们将阻塞这两个令牌。

@partial(jax.jit, device=<device 0>)
def f():
  jax.print("hello")
  return new_runtime_token()

@partial(jax.jit, device=<device 1>)
def g():
  jax.print("world")
  return new_runtime_token()

t0 = f()
t1 = g()
block_until_ready((t0, t1))

因此,我们将需要一个每个设备的输出令牌,这样我们就可以避免对不同设备上的计算进行排序,同时提供阻塞副作用计算的能力。我们最终对 JAX 分派机制进行了以下(近似)更改

def _execute(compiled_computation, *args):
  output_token, *outputs = compiled_computation.execute(runtime_token, *args)
  update_output_token(output_token, compiled_computation.device)
  return outputs

我们还需要公开一个函数,用于阻塞输出令牌。

def effects_barrier():
  output_token.block_until_ready()

请注意,阻塞输出令牌可能不是非常常见,因为大多数 JAX 计算都会返回一个值来阻塞。然而,输出令牌对于测试和性能分析很有帮助,并且值得支持,以便我们拥有一个一致且有凝聚力的效应系统。

更多细节#

  • 所有上述令牌管理基础设施都将是*线程局部*的。这意味着每个用户线程都将拥有自己独立的运行时令牌流。排序仅在用户线程级别上得到保证。

  • 在实践中,我们每个效应都有一个运行时令牌。该效应的不同实例将被排序。这是为了避免对可能彼此没有关系的效应计算进行排序。从技术上讲,这确实违背了我们最初强制单线程 Python 程序顺序的目标,但这是一个权衡,可以通过同时拥有“效应”特定的令牌和“全局”令牌来调节。