即时(Just-in-time)编译#
在本节中,我们将进一步探索 JAX 的工作原理以及如何使其具备高性能。我们将讨论 jax.jit() 变换,它将对 JAX Python 函数执行即时(JIT)编译,以便能够在 XLA 中高效执行。
JAX 变换的工作原理#
在上一节中,我们讨论了 JAX 允许我们变换 Python 函数。JAX 通过将每个函数简化为一系列 primitive(原语)操作来实现这一点,每个原语代表一个基本的计算单元。
查看函数背后原语序列的一种方法是使用 jax.make_jaxpr()
import jax
import jax.numpy as jnp
global_list = []
def log2(x):
global_list.append(x)
ln_x = jnp.log(x)
ln_2 = jnp.log(2.0)
return ln_x / ln_2
print(jax.make_jaxpr(log2)(3.0))
{ lambda ; a:f32[]. let
b:f32[] = log a
c:f32[] = log 2.0:f32[]
d:f32[] = div b c
in (d,) }
文档的 JAX 内部机制:jaxpr 语言 一节提供了有关上述输出含义的更多信息。
重要的是,请注意 jaxpr 并未捕获函数中存在的副作用:其中没有任何内容对应 global_list.append(x)。这是一个特性,而非缺陷:JAX 变换旨在理解无副作用(即函数式纯)的代码。如果纯函数和副作用这些术语还不熟悉,可以在 🔪 JAX 的陷阱 🔪:纯函数 中了解更多详细信息。
非纯函数很危险,因为在 JAX 变换下,它们可能不会按预期运行;它们可能会静默失败,或产生令人惊讶的后续错误,例如泄漏的 Tracer(追踪器)。此外,JAX 通常无法检测到副作用何时出现。(如果您需要调试打印,请使用 jax.debug.print()。如需以性能为代价来表达通用副作用,请参阅 jax.experimental.io_callback()。如需以性能为代价检查追踪器泄漏,请使用 jax.check_tracer_leaks())。
在追踪时,JAX 会用 tracer(追踪器)对象包装每个参数。这些追踪器记录在函数调用期间(发生在常规 Python 中)对它们执行的所有 JAX 操作。然后,JAX 使用追踪器记录来重建整个函数。该重建的输出就是 jaxpr。由于追踪器不会记录 Python 副作用,因此它们不会出现在 jaxpr 中。但是,副作用在追踪期间仍然会发生。
注意:Python 的 print() 函数不是纯函数:文本输出是该函数的副作用。因此,任何 print() 调用仅在追踪期间发生,不会出现在 jaxpr 中。
def log2_with_print(x):
print("printed x:", x)
ln_x = jnp.log(x)
ln_2 = jnp.log(2.0)
return ln_x / ln_2
print(jax.make_jaxpr(log2_with_print)(3.))
printed x: JitTracer(~float32[])
{ lambda ; a:f32[]. let
b:f32[] = log a
c:f32[] = log 2.0:f32[]
d:f32[] = div b c
in (d,) }
看到打印出的 x 是一个 Traced 对象了吗?这就是 JAX 内部机制在起作用。
Python 代码至少运行一次的事实完全是一个实现细节,因此不应依赖它。不过,理解这一点很有用,因为您可以在调试时利用它来打印计算的中间值。
理解的关键点在于,jaxpr 捕获的是函数在给定参数下执行的情况。例如,如果我们有一个 Python 条件判断,jaxpr 只会了解我们所采取的分支。
def log2_if_rank_2(x):
if x.ndim == 2:
ln_x = jnp.log(x)
ln_2 = jnp.log(2.0)
return ln_x / ln_2
else:
return x
print(jax.make_jaxpr(log2_if_rank_2)(jax.numpy.array([1, 2, 3])))
{ lambda ; a:i32[3]. let in (a,) }
函数 JIT 编译#
如前所述,JAX 使得操作能够使用相同的代码在 CPU/GPU/TPU 上执行。让我们看一个计算缩放指数线性单元(SELU)的示例,这是深度学习中常用的一种操作。
import jax
import jax.numpy as jnp
def selu(x, alpha=1.67, lambda_=1.05):
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
x = jnp.arange(1000000)
%timeit selu(x).block_until_ready()
3.8 ms ± 151 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
上面的代码是一次向加速器发送一个操作。这限制了 XLA 编译器优化我们函数的能力。
当然,我们想要做的是尽可能多地向 XLA 编译器提供代码,以便它可以对其进行全面优化。为此,JAX 提供了 jax.jit() 变换,它将对兼容 JAX 的函数进行 JIT 编译。下面的示例展示了如何使用 JIT 来加速之前的函数。
selu_jit = jax.jit(selu)
# Pre-compile the function before timing...
selu_jit(x).block_until_ready()
%timeit selu_jit(x).block_until_ready()
265 μs ± 874 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
刚刚发生了什么:
我们将
selu_jit定义为selu的编译版本。我们在
x上调用了一次selu_jit。这就是 JAX 进行追踪的地方——毕竟它需要一些输入来包装在追踪器中。随后,jaxpr 使用 XLA 编译成针对您的 GPU 或 TPU 优化的非常高效的代码。最后,执行编译后的代码以满足调用需求。后续对selu_jit的调用将直接使用编译后的代码,完全跳过 Python 实现。(如果我们没有单独包含预热调用,一切仍然有效,但编译时间会被计入基准测试中。它仍然会更快,因为我们在基准测试中运行许多循环,但这不会是一个公平的比较。)我们测定了编译版本的执行速度。(注意使用了
block_until_ready(),这是由于 JAX 的 异步调度 所必需的)。
为什么我们不能对一切进行 JIT 编译?#
看完上面的示例,您可能会想,我们是否应该简单地将 jax.jit() 应用于每个函数。为了理解事实并非如此,以及我们应该/不应该使用 jit 的情况,让我们先看看 JIT 不起作用的一些情况。
# Condition on value of x.
def f(x):
if x > 0:
return x
else:
return 2 * x
jax.jit(f)(10) # Raises an error
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /tmp/ipykernel_2133/2956679937.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError
# While loop conditioned on x and n.
def g(x, n):
i = 0
while i < n:
i += 1
return x + i
jax.jit(g)(10, 20) # Raises an error
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function g at /tmp/ipykernel_2133/722961019.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument n.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError
这两个问题的症结在于,我们试图使用运行时值来调节程序追踪时的流程。JIT 中的追踪值(如这里的 x 和 n)只能通过其静态属性(如 shape 或 dtype)来影响控制流,而不能通过其值来影响。有关 Python 控制流与 JAX 之间交互的更多详细信息,请参阅 JIT 下的控制流与逻辑运算符。
处理此问题的一种方法是重写代码以避免基于值的条件判断。另一种方法是使用特殊的 控制流运算符,例如 jax.lax.cond()。但是,有时这并不可能或不切实际。在这种情况下,您可以考虑仅对函数的一部分进行 JIT 编译。例如,如果函数中计算最昂贵的部分在循环内部,我们可以只 JIT 编译内部部分(不过请务必查看下一节关于缓存的内容,以避免弄巧成拙)。
# While loop conditioned on x and n with a jitted body.
@jax.jit
def loop_body(prev_i):
return prev_i + 1
def g_inner_jitted(x, n):
i = 0
while i < n:
i = loop_body(i)
return x + i
g_inner_jitted(10, 20)
Array(30, dtype=int32, weak_type=True)
将参数标记为静态#
如果我们确实需要对一个基于输入值进行条件判断的函数进行 JIT 编译,我们可以通过指定 static_argnums 或 static_argnames 来告诉 JAX 为特定输入使用非抽象的追踪器。这样做代价是生成的 jaxpr 和编译产物取决于所传递的特定值,因此 JAX 必须为每个新的静态输入值重新编译函数。只有当函数确保只处理有限的一组静态值时,这才是好策略。
f_jit_correct = jax.jit(f, static_argnums=0)
print(f_jit_correct(10))
10
g_jit_correct = jax.jit(g, static_argnames=['n'])
print(g_jit_correct(10, 20))
30
在使用 jit 作为装饰器时要指定此类参数,一种常见的模式是使用 Python 的 functools.partial()。
from functools import partial
@partial(jax.jit, static_argnames=['n'])
def g_jit_decorated(x, n):
i = 0
while i < n:
i += 1
return x + i
print(g_jit_decorated(10, 20))
30
JIT 与缓存#
鉴于首次 JIT 调用存在编译开销,理解 jax.jit() 如何以及何时缓存以前的编译结果,是有效使用它的关键。
假设我们定义了 f = jax.jit(g)。当我们第一次调用 f 时,它会被编译,产生的 XLA 代码会被缓存起来。后续对 f 的调用将复用缓存的代码。这就是 jax.jit 弥补前期编译成本的方式。
如果我们指定了 static_argnums,那么只有当标记为静态的参数值相同时,才会使用缓存的代码。如果其中任何一个发生变化,就会发生重新编译。如果数值很多,那么您的程序在编译上花费的时间可能会超过逐个执行操作所需的时间。
避免在循环或其他 Python 作用域内定义的临时函数上调用 jax.jit()。在大多数情况下,JAX 将能够利用在后续调用 jax.jit() 时编译并缓存的函数。然而,由于缓存依赖于函数的哈希值,当等效函数被重新定义时就会出现问题。这将导致每次在循环中都进行不必要的编译。
from functools import partial
def unjitted_loop_body(prev_i):
return prev_i + 1
def g_inner_jitted_partial(x, n):
i = 0
while i < n:
# Don't do this! each time the partial returns
# a function with different hash
i = jax.jit(partial(unjitted_loop_body))(i)
return x + i
def g_inner_jitted_lambda(x, n):
i = 0
while i < n:
# Don't do this!, lambda will also return
# a function with a different hash
i = jax.jit(lambda x: unjitted_loop_body(x))(i)
return x + i
def g_inner_jitted_normal(x, n):
i = 0
while i < n:
# this is OK, since JAX can find the
# cached, compiled function
i = jax.jit(unjitted_loop_body)(i)
return x + i
print("jit called in a loop with partials:")
%timeit g_inner_jitted_partial(10, 20).block_until_ready()
print("jit called in a loop with lambdas:")
%timeit g_inner_jitted_lambda(10, 20).block_until_ready()
print("jit called in a loop with caching:")
%timeit g_inner_jitted_normal(10, 20).block_until_ready()
jit called in a loop with partials:
373 ms ± 5.99 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with lambdas:
377 ms ± 6.56 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with caching:
1.56 ms ± 6.1 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)