即时编译#

在本节中,我们将进一步探讨 JAX 的工作原理,以及如何使其高性能。我们将讨论 jax.jit() 转换,它将对 JAX Python 函数执行即时 (JIT) 编译,以便可以在 XLA 中高效执行。

JAX 转换如何工作#

在上一节中,我们讨论了 JAX 允许我们转换 Python 函数。JAX 通过将每个函数简化为 原语 操作序列来实现这一点,每个操作代表一个基本的计算单元。

查看函数背后原语序列的一种方法是使用 jax.make_jaxpr()

import jax
import jax.numpy as jnp

global_list = []

def log2(x):
  global_list.append(x)
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2

print(jax.make_jaxpr(log2)(3.0))
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) }

文档的 JAX 内部原理:jaxpr 语言 部分提供了关于上述输出含义的更多信息。

重要的是,请注意 jaxpr 没有捕获函数中存在的副作用:其中没有任何内容对应于 global_list.append(x)。这是一个特性,而不是一个错误:JAX 转换旨在理解无副作用(又名函数式纯粹)的代码。如果纯函数副作用是不熟悉的术语,这在 🔪 JAX - 尖锐之处 🔪:纯函数 中有更详细的解释。

不纯函数是危险的,因为在 JAX 转换下,它们很可能无法按预期运行;它们可能会静默失败,或产生令人惊讶的下游错误,例如泄漏的 Tracer。此外,JAX 通常无法检测到何时存在副作用。(如果您想要调试打印,请使用 jax.debug.print()。要以性能为代价表达一般副作用,请参阅 jax.experimental.io_callback()。要以性能为代价检查 tracer 泄漏,请使用 jax.check_tracer_leaks())。

在跟踪时,JAX 将每个参数包装成一个 tracer 对象。然后,这些 tracer 记录在函数调用期间对它们执行的所有 JAX 操作(这发生在常规 Python 中)。然后,JAX 使用 tracer 记录来重建整个函数。该重建的输出是 jaxpr。由于 tracer 不记录 Python 副作用,因此它们不会出现在 jaxpr 中。但是,副作用仍然发生在跟踪本身期间。

注意:Python print() 函数不是纯函数:文本输出是函数的副作用。因此,任何 print() 调用都只会在跟踪期间发生,而不会出现在 jaxpr 中

def log2_with_print(x):
  print("printed x:", x)
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2

print(jax.make_jaxpr(log2_with_print)(3.))
printed x: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) }

看到打印的 x 是一个 Traced 对象了吗?那是 JAX 内部机制在工作。

Python 代码至少运行一次的事实严格来说是一个实现细节,因此不应依赖它。但是,理解它很有用,因为您可以在调试时使用它来打印计算的中间值。

要理解的关键是 jaxpr 捕获了在给定参数上执行的函数。例如,如果我们有一个 Python 条件,jaxpr 将只知道我们采取的分支

def log2_if_rank_2(x):
  if x.ndim == 2:
    ln_x = jnp.log(x)
    ln_2 = jnp.log(2.0)
    return ln_x / ln_2
  else:
    return x

print(jax.make_jaxpr(log2_if_rank_2)(jax.numpy.array([1, 2, 3])))
{ lambda ; a:i32[3]. let  in (a,) }

JIT 编译函数#

如前所述,JAX 使操作能够在 CPU/GPU/TPU 上使用相同的代码执行。让我们看一个计算缩放指数线性单元 (SELU) 的示例,这是一种常用于深度学习的操作

import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(1000000)
%timeit selu(x).block_until_ready()
1.81 ms ± 138 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

上面的代码一次向加速器发送一个操作。这限制了 XLA 编译器优化我们函数的能力。

自然地,我们想要做的是给 XLA 编译器尽可能多的代码,以便它可以完全优化它。为此,JAX 提供了 jax.jit() 转换,它将 JIT 编译一个 JAX 兼容的函数。下面的示例展示了如何使用 JIT 来加速之前的函数。

selu_jit = jax.jit(selu)

# Pre-compile the function before timing...
selu_jit(x).block_until_ready()

%timeit selu_jit(x).block_until_ready()
273 μs ± 1.64 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

以下是刚刚发生的事情

  1. 我们将 selu_jit 定义为 selu 的编译版本。

  2. 我们在 x 上调用了 selu_jit 一次。这是 JAX 进行跟踪的地方——毕竟它需要一些输入来包装在 tracer 中。然后使用 XLA 将 jaxpr 编译成针对您的 GPU 或 TPU 优化的非常高效的代码。最后,执行编译后的代码以满足调用。后续对 selu_jit 的调用将直接使用编译后的代码,完全跳过 Python 实现。(如果我们不单独包含预热调用,一切仍然会正常工作,但编译时间将包含在基准测试中。它仍然会更快,因为我们在基准测试中运行了很多循环,但这将不是一个公平的比较。)

  3. 我们计时了编译版本的执行速度。(请注意 block_until_ready() 的使用,这是由于 JAX 的 异步调度 所必需的)。

为什么我们不能 JIT 编译一切?#

在经历了上面的示例之后,您可能会想知道我们是否应该简单地将 jax.jit() 应用于每个函数。为了理解为什么不是这种情况,以及我们何时应该/不应该应用 jit,让我们首先检查一些 JIT 不起作用的情况。

# Condition on value of x.

def f(x):
  if x > 0:
    return x
  else:
    return 2 * x

jax.jit(f)(10)  # Raises an error
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /tmp/ipykernel_1157/2956679937.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError
# While loop conditioned on x and n.

def g(x, n):
  i = 0
  while i < n:
    i += 1
  return x + i

jax.jit(g)(10, 20)  # Raises an error
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function g at /tmp/ipykernel_1157/722961019.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument n.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError

这两种情况下的问题是,我们试图使用运行时值来调节程序的跟踪时流程。JIT 中的跟踪值,如这里的 xn,只能通过其静态属性(例如 shapedtype)而不是通过其值来影响控制流。有关 Python 控制流和 JAX 之间交互的更多详细信息,请参阅 使用 JIT 的控制流和逻辑运算符

解决此问题的一种方法是重写代码以避免值条件。另一种方法是使用特殊的 控制流运算符,如 jax.lax.cond()。但是,有时这是不可能或不切实际的。在这种情况下,您可以考虑仅 JIT 编译函数的一部分。例如,如果函数中最计算密集的部分在循环内部,我们可以仅 JIT 编译内部部分(尽管请务必查看下一节关于缓存的内容,以避免搬起石头砸自己的脚)

# While loop conditioned on x and n with a jitted body.

@jax.jit
def loop_body(prev_i):
  return prev_i + 1

def g_inner_jitted(x, n):
  i = 0
  while i < n:
    i = loop_body(i)
  return x + i

g_inner_jitted(10, 20)
Array(30, dtype=int32, weak_type=True)

将参数标记为静态#

如果我们真的需要 JIT 编译一个对输入值有条件的函数,我们可以通过指定 static_argnumsstatic_argnames 来告诉 JAX 为特定输入使用不太抽象的 tracer。这样做的代价是,生成的 jaxpr 和编译后的工件取决于传递的特定值,因此 JAX 将不得不为指定的静态输入的每个新值重新编译该函数。只有当函数保证看到有限的静态值集时,这才是好的策略。

f_jit_correct = jax.jit(f, static_argnums=0)
print(f_jit_correct(10))
10
g_jit_correct = jax.jit(g, static_argnames=['n'])
print(g_jit_correct(10, 20))
30

在使用 jit 作为装饰器时指定此类参数,常见的模式是使用 python 的 functools.partial()

from functools import partial

@partial(jax.jit, static_argnames=['n'])
def g_jit_decorated(x, n):
  i = 0
  while i < n:
    i += 1
  return x + i

print(g_jit_decorated(10, 20))
30

JIT 和缓存#

考虑到第一次 JIT 调用的编译开销,理解 jax.jit() 如何以及何时缓存之前的编译是有效使用它的关键。

假设我们定义 f = jax.jit(g)。当我们第一次调用 f 时,它将被编译,并且生成的 XLA 代码将被缓存。后续对 f 的调用将重用缓存的代码。这就是 jax.jit 如何弥补编译的前期成本。

如果我们指定 static_argnums,则缓存的代码将仅用于标记为静态的参数的相同值。如果其中任何一个发生变化,则会发生重新编译。如果有许多值,那么您的程序可能会花费更多时间编译,而不是按顺序执行操作。

避免在循环或其他 Python 作用域内定义的临时函数上调用 jax.jit()。在大多数情况下,JAX 将能够在后续调用 jax.jit() 时使用编译后的缓存函数。但是,由于缓存依赖于函数的哈希值,因此当等效函数被重新定义时,它会变得有问题。这将导致循环中每次都不必要的编译

from functools import partial

def unjitted_loop_body(prev_i):
  return prev_i + 1

def g_inner_jitted_partial(x, n):
  i = 0
  while i < n:
    # Don't do this! each time the partial returns
    # a function with different hash
    i = jax.jit(partial(unjitted_loop_body))(i)
  return x + i

def g_inner_jitted_lambda(x, n):
  i = 0
  while i < n:
    # Don't do this!, lambda will also return
    # a function with a different hash
    i = jax.jit(lambda x: unjitted_loop_body(x))(i)
  return x + i

def g_inner_jitted_normal(x, n):
  i = 0
  while i < n:
    # this is OK, since JAX can find the
    # cached, compiled function
    i = jax.jit(unjitted_loop_body)(i)
  return x + i

print("jit called in a loop with partials:")
%timeit g_inner_jitted_partial(10, 20).block_until_ready()

print("jit called in a loop with lambdas:")
%timeit g_inner_jitted_lambda(10, 20).block_until_ready()

print("jit called in a loop with caching:")
%timeit g_inner_jitted_normal(10, 20).block_until_ready()
jit called in a loop with partials:
173 ms ± 248 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
jit called in a loop with lambdas:
173 ms ± 108 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
jit called in a loop with caching:
1.29 ms ± 4.73 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)