JAX 中的序列化副作用#
sharadmv@ 2022 年 5 月 9 日
概述#
当我们编写 JAX 代码时,通常可以假装我们正在编写单线程、即时执行的 Python 代码,尽管在底层,JAX 及其运行时可能会在后台异步执行它。只要我们编写的是纯粹的(无副作用)代码,这些性能优化通常对我们来说是不可见的,也不会干扰我们的单线程心智模型。异步执行很棒——我们无需费心就能获得高性能的并行代码!
然而,在存在副作用的情况下,这种“幻觉”开始破裂,我们的心智模型中的裂痕也开始显现。具体来说,这些差异在我们考虑副作用发生的*顺序*时会显现出来。
在本设计说明中,我们将探讨 JAX 的执行模型与副作用的排序之间的交互。我们还将提供一种强制执行“单线程”效果排序的方法。
背景#
当我们编写以下 Python 代码时
def f():
print("hello")
return 2
def g():
print("world")
return 3
f()
g()
我们期望 "hello" 在 "world" 之前打印。这似乎很明显,但请考虑以下 JAX 代码
@partial(jax.jit, device=<device 0>)
def f():
return 2
@partial(jax.jit, device=<device 1>)
def g():
return 3
f()
g()
在许多情况下,JAX 会*并行*执行 f 和 g,将计算分派到不同的线程——g 实际上可能在 f 之前执行。并行执行是一种很好的性能优化,尤其是当设备之间的复制成本很高时(有关更多详细信息,请参阅 异步分派说明)。然而,在实践中,我们通常不需要考虑异步分派,因为我们编写的是纯函数,只关心函数的输入和输出——我们会自然地阻塞在未来的值上。
但是,现在想象一下我们有一个 jax.print 函数,它在 JIT 编译的 JAX 函数中工作(host_callback.id_print 就是一个例子)。让我们回到之前的例子,只是在其中混合了打印语句。
@partial(jax.jit, device=<device 0>)
def f():
jax.print("hello")
return 2
@partial(jax.jit, device=<device 1>)
def g():
jax.print("world")
return 3
f()
g()
由于异步分派,我们实际上可能会看到 "world" 在 "hello" 之前打印。打印副作用的重新排序破坏了单线程执行模型的“幻觉”。
当编译 JAX 程序时,副作用会“暴露”乱序执行的另一个例子。考虑以下 JAX 代码
@jax.jit
def f(x):
jax.print("hello")
jax.print("world")
return x
尽管我们在 Python 中先写了 "hello" 打印,后写 "world" 打印,但像 XLA 这样的编译器可以自由地重新排序它们,因为打印之间没有显式的数据依赖关系。
动机#
我们希望支持“有序”副作用。当我们说有序时,我们的意思是副作用的发生顺序与我们在执行单线程 Python 程序时的顺序相同。这是我们的主要需求。在显式并行(如 pmap 或用户线程)存在的情况下,我们不需要维持这种行为,但至少如果用户没有明确请求并行,我们希望保留单线程排序。
在我们深入研究之前,让我们先退一步,问问自己,是否可以为了性能而重新排序副作用,反之,我们是否需要对副作用强制执行排序?在某些情况下,我们不需要排序。也许一些副作用不应该对 JAX 程序的性能产生不利影响。然而,对于其他副作用,我们可能希望强制执行单线程程序顺序,这样用户就不会遇到反直觉的行为。考虑一个日志记录副作用。
@jax.jit
def f(x, y):
log_value(x)
log_value(y)
f(1, 2)
如果 log 在修改一个全局列表,我们可能期望 x 在 y 之前被添加。对于更严格的副作用,我们可能希望可以选择对副作用进行排序。
强制执行有序副作用#
我们用来强制执行计算顺序的主要工具是*数据依赖*。简而言之,如果函数 g 的一个输入是函数 f 的输出,那么 f 必须在 g 之前执行。
然而,我们可能有像打印这样的副作用,它们没有任何输入,所以我们不能简单地对它们进行排序。因此,我们使用*Token*作为注入计算的*人工数据依赖*的手段。
什么是 Token?Token 只是一个可以传入和传出计算的虚拟值。通过将同一个 Token 传入和传出多个计算,我们可以强制它们按特定顺序发生。让我们以前面的打印示例为例,看看在混合 Token 时会是什么样子
@jax.jit
def f(token, x):
token = jax.print(token, "hello")
token = jax.print(token, "world")
return token, x
如果我们重写 jax.print 以接受和返回一个 Token,我们就已经对这两个打印语句进行了排序,因为第二个打印语句的输入依赖于第一个打印语句的输出。Token 的实际值可以是任何东西,但在实践中,Token 对用户来说是不可见的。
运行时 Token 与编译器 Token#
这里我们将开始讨论实现细节。在实践中,我们需要两种不同的 Token 来序列化副作用:一种用于上述的两种重新排序来源。我们需要*运行时 Token*来序列化异步分派的副作用计算,并且我们需要*编译器 Token*来序列化计算内的副作用。
在实践中,我们的计算将被重写成这样
@jax.jit
def f(runtime_token, x):
compiler_token = new_compiler_token()
compiler_token = jax.print(compiler_token, "hello")
compiler_token = jax.print(compiler_token, "world")
return runtime_token, x
请注意,运行时 Token 只在 JIT 边界使用,而编译器 Token 只在编译后的代码内部使用。编译器 Token 在“降低”(我们将 Python 代码转换为较低级别的表示,如 HLO 或 StableHLO)过程中创建,但运行时 Token 需要在 Python 中进行管理,因为它们被传入和传出 JIT 编译的函数。
此外,请注意,运行时 Token 与编译器 Token 是“断开连接”的,这意味着它们之间没有数据依赖。这可能很危险,因为如果我们丢失了两个分派函数调用体之间的数据依赖。然而,如果我们假设“严格执行”——即一个分派函数将在其所有输入都准备好后才开始执行,并且其所有输出将在同一时间准备好——那么我们可以创建一个新的编译器 Token 并返回一个不依赖于输出的运行时 Token。
管理运行时 Token#
为了代表用户管理运行时 Token,我们需要挂钩 JAX 的分派机制。每当我们调用一个 JIT 编译的函数时,最终都会在一个如下的函数中结束
def _execute(compiled_computation, *args):
outputs = compiled_computation.execute(*args)
return outputs
此时,我们需要将运行时 Token“注入”到计算中,并从计算的输出中“提取”它们
def _execute(compiled_computation, *args):
runtime_token = get_runtime_token() # Grab global token
runtime_token, *outputs = compiled_computation.execute(runtime_token, *args)
update_runtime_token(runtime_token) # Update global token
return outputs
那么 runtime_token 究竟是什么?我们需要能够将其传递给一个 compiled_computation,这意味着它需要某种形式的数组(目前如此,因为在已编译的 JAX 代码内外没有共享的 Token 表示)。在实践中,我们可以使用一个形状为 (0,) 的数组来最大限度地减少开销。
我们还需要考虑多设备用例,例如第一个例子,我们首先在设备 0 上调用一个 JIT 编译的函数,然后在设备 1 上调用另一个。在这种情况下,我们还需要将第一个计算返回的运行时 Token(位于设备 0 上)*复制*到设备 1,以便我们可以将其传递给第二个计算。如果两个连续的计算共享同一个设备,则不需要此复制。
添加编译器 Token#
当我们把 Python 代码降低到 HLO 或 StableHLO 时,我们需要在计算开始时创建一个 Token,并确保在需要排序的副作用计算时它们可用。副作用计算将把 Token 作为输入,并将其作为输出返回。
这种 Token 传递的实现涉及升级 JAX 降低机制以自动执行此簿记。主要的挑战在于处理高阶原语,如调用原语和控制流原语。我们不会在本设计说明中详细介绍如何处理这些。
阻塞输出 Token#
为副作用计算添加对运行时和编译器 Token 的支持对于序列化很重要,但还有一个微妙的用例是阻塞副作用计算。即使我们不希望副作用计算被*排序*,我们也可能仍想等待其完成。目前我们有 jax.block_until_ready,它会等待直到未来值准备好其结果。然而,对于副作用计算,我们可能有不返回值的函数,但它们仍在执行副作用。看这里简单的例子
@jax.jit
def f():
jax.print("hello world")
return
f() # Executed asynchronously
这个编译后的计算不接受显式输入,也没有显式输出。如果它是一个有序的打印效果,我们可以阻塞在返回的运行时 Token 上。然而,当这是一个无序计算时,我们不进行任何 Token 传递。当我们没有任何输出值可以调用 block_until_ready 时,如何等待 f() 完成执行?嗯,我们可以应用相同的 Token 策略,只是我们只返回运行时 Token 并且不将它们作为输入。这将为我们提供一个可以阻塞的值,该值将在 f() 执行完毕后才准备好。我们将这些 Token 称为*输出 Token*。最终我们会得到一个如下的函数
@jax.jit
def f():
jax.print("hello world")
return new_runtime_token()
f() # Executed asynchronously
在底层,我们将以与管理运行时 Token 相同的方式管理输出 Token,但提供一个用户阻塞在当前输出 Token 集上的方法。与运行时 Token 不同,输出 Token 需要是*设备特定的*。考虑单设备用例
@jax.jit
def f():
jax.print("hello")
@jax.jit
def g():
jax.print("world")
f()
g()
由于 f() 和 g() 在同一设备上执行,阻塞在 g() 的输出 Token 上实际上会阻塞 f(),因为(目前!)JAX 运行时不会交错在同一设备上执行的计算。如果这种情况发生变化,我们当然需要修改整个设计。
然而,考虑双设备用例
@partial(jax.jit, device=<device 0>)
def f():
jax.print("hello")
@partial(jax.jit, device=<device 1>)
def g():
jax.print("world")
f()
g()
这里我们不想显式排序 f() 和 g(),但想等待它们都完成。我们将需要一个用于 f() 的输出 Token 和一个用于 g() 的输出 Token,然后我们将阻塞在这两个 Token 上
@partial(jax.jit, device=<device 0>)
def f():
jax.print("hello")
return new_runtime_token()
@partial(jax.jit, device=<device 1>)
def g():
jax.print("world")
return new_runtime_token()
t0 = f()
t1 = g()
block_until_ready((t0, t1))
因此,我们需要一个每设备输出 Token,以便我们可以避免对不同设备上的计算进行排序,同时提供阻塞副作用计算的能力。最终我们会对 JAX 分派机制进行以下(近似)更改
def _execute(compiled_computation, *args):
output_token, *outputs = compiled_computation.execute(runtime_token, *args)
update_output_token(output_token, compiled_computation.device)
return outputs
我们还需要公开一个函数来阻塞输出 Token
def effects_barrier():
output_token.block_until_ready()
请注意,阻塞输出 Token 可能不那么常见,因为大多数 JAX 计算都会返回一个值来阻塞。然而,输出 Token 对于测试和剖析很有用,并且很好地支持,以便我们拥有一个一致且凝聚的副作用系统。
更多细节#
所有上述 Token 管理基础设施都将是*线程本地*的。这意味着每个用户线程将拥有自己独立的运行时 Token 流。排序仅在用户线程级别得到保证。
在实践中,我们为每个副作用有一个运行时 Token。该副作用的不同实例将被排序。这是为了避免对可能彼此无关的副作用计算进行排序。严格来说,这违背了我们最初强制执行单线程 Python 程序排序的目标,但这是一个可以调整的权衡,可以通过同时拥有“副作用特定” Token 和“全局” Token 来实现。