即时编译#

在本节中,我们将进一步探讨 JAX 的工作原理以及如何使其实现高性能。我们将讨论 jax.jit() 变换,该变换将对 JAX Python 函数执行即时(JIT)编译,使其能够在 XLA 中高效执行。

JAX 变换的工作原理#

在上一节中,我们讨论了 JAX 允许我们变换 Python 函数。JAX 通过将每个函数简化为一系列原始操作来完成此任务,每个操作代表一个基本计算单元。

查看函数背后原始操作序列的一种方法是使用jax.make_jaxpr()

import jax
import jax.numpy as jnp

global_list = []

def log2(x):
  global_list.append(x)
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2

print(jax.make_jaxpr(log2)(3.0))
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0:f32[]
    d:f32[] = div b c
  in (d,) }

文档的JAX 内部机制:jaxpr 语言部分提供了有关上述输出含义的更多信息。

重要的是,请注意 jaxpr 不会捕获函数中存在的副作用:其中没有任何对应于 global_list.append(x) 的内容。这是一个特性,而不是一个错误:JAX 变换旨在理解无副作用(又称功能纯)的代码。如果纯函数副作用是不熟悉的术语,这在🔪 JAX - 棘手之处 🔪:纯函数中有更详细的解释。

不纯函数是危险的,因为在 JAX 变换下它们可能不会按预期运行;它们可能会默默失败,或者产生令人惊讶的下游错误,例如 Tracer 泄露。此外,JAX 通常无法检测何时存在副作用。(如果您需要调试打印,请使用jax.debug.print()。要以牺牲性能为代价表达一般副作用,请参阅jax.experimental.io_callback()。要以牺牲性能为代价检查 Tracer 泄露,请与jax.check_tracer_leaks()一起使用)。

在追踪时,JAX 会将每个参数包装成一个tracer对象。这些 tracer 随后会记录在函数调用期间(在常规 Python 中发生)对它们执行的所有 JAX 操作。然后,JAX 使用 tracer 记录来重建整个函数。该重建的输出就是 jaxpr。由于 tracer 不记录 Python 副作用,因此它们不会出现在 jaxpr 中。但是,副作用仍然在追踪本身期间发生。

注意:Python print() 函数不是纯函数:文本输出是函数的副作用。因此,任何 print() 调用将只在追踪期间发生,而不会出现在 jaxpr 中。

def log2_with_print(x):
  print("printed x:", x)
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2

print(jax.make_jaxpr(log2_with_print)(3.))
printed x: JitTracer<~float32[]>
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0:f32[]
    d:f32[] = div b c
  in (d,) }

看到打印出的 x 是一个 Traced 对象了吗?这就是 JAX 内部机制的作用。

Python 代码至少运行一次这一事实严格来说是一个实现细节,因此不应依赖它。但是,了解它很有用,因为在调试时您可以使用它来打印计算的中间值。

要理解的一个关键点是 jaxpr 捕获了根据给定参数执行的函数。例如,如果我们有一个 Python 条件语句,jaxpr 将只知道我们所采用的分支

def log2_if_rank_2(x):
  if x.ndim == 2:
    ln_x = jnp.log(x)
    ln_2 = jnp.log(2.0)
    return ln_x / ln_2
  else:
    return x

print(jax.make_jaxpr(log2_if_rank_2)(jax.numpy.array([1, 2, 3])))
{ lambda ; a:i32[3]. let  in (a,) }

JIT 编译函数#

如前所述,JAX 允许使用相同的代码在 CPU/GPU/TPU 上执行操作。让我们看一个计算缩放指数线性单元SELU)的例子,这是深度学习中常用的操作

import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(1000000)
%timeit selu(x).block_until_ready()
4.87 ms ± 83.1 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

上面的代码一次将一个操作发送到加速器。这限制了 XLA 编译器优化我们函数的能力。

自然,我们想要做的是尽可能多地提供代码给 XLA 编译器,以便它可以充分优化它。为此,JAX 提供了 jax.jit() 变换,它将对 JAX 兼容函数进行 JIT 编译。下面的示例展示了如何使用 JIT 来加速之前的函数。

selu_jit = jax.jit(selu)

# Pre-compile the function before timing...
selu_jit(x).block_until_ready()

%timeit selu_jit(x).block_until_ready()
632 μs ± 2.83 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

这是刚刚发生的事情:

  1. 我们将 selu_jit 定义为 selu 的编译版本。

  2. 我们对 x 调用了一次 selu_jit。这是 JAX 进行追踪的地方——毕竟它需要一些输入来包装在 tracer 中。然后使用 XLA 将 jaxpr 编译成针对您的 GPU 或 TPU 优化的非常高效的代码。最后,执行编译后的代码以满足调用。对 selu_jit 的后续调用将直接使用编译后的代码,完全跳过 Python 实现。(如果我们不单独包含预热调用,一切仍然会正常工作,但编译时间将包含在基准测试中。它仍然会更快,因为我们在基准测试中运行了许多循环,但这不会是一个公平的比较)。

  3. 我们测量了编译版本的执行速度。(注意使用了block_until_ready(),这是由于 JAX 的异步分派所必需的)。

为什么不能将所有代码都 JIT 编译?#

看完上面的例子,你可能会想我们是否应该简单地将jax.jit()应用于每个函数。要理解为什么不是这样,以及何时应该/不应该应用jit,我们首先来看一些 JIT 不起作用的情况。

# Condition on value of x.

def f(x):
  if x > 0:
    return x
  else:
    return 2 * x

jax.jit(f)(10)  # Raises an error
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /tmp/ipykernel_1862/2956679937.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError
# While loop conditioned on x and n.

def g(x, n):
  i = 0
  while i < n:
    i += 1
  return x + i

jax.jit(g)(10, 20)  # Raises an error
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function g at /tmp/ipykernel_1862/722961019.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument n.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError

这两种情况的问题在于,我们试图使用运行时值来控制程序追踪时的流程。JIT 中的追踪值,例如这里的 xn,只能通过它们的静态属性(例如 shapedtype)影响控制流,而不能通过它们的值。有关 Python 控制流和 JAX 之间交互的更多详细信息,请参阅JIT 的控制流和逻辑运算符

解决此问题的一种方法是重写代码以避免基于值的条件。另一种方法是使用特殊的控制流操作符,如jax.lax.cond()。然而,有时这不可能或不切实际。在这种情况下,您可以考虑仅 JIT 编译函数的一部分。例如,如果函数中最耗计算的部分在循环内部,我们可以只 JIT 编译该内部部分(但请务必查看下一节关于缓存的内容,以避免适得其反)

# While loop conditioned on x and n with a jitted body.

@jax.jit
def loop_body(prev_i):
  return prev_i + 1

def g_inner_jitted(x, n):
  i = 0
  while i < n:
    i = loop_body(i)
  return x + i

g_inner_jitted(10, 20)
Array(30, dtype=int32, weak_type=True)

将参数标记为静态#

如果确实需要 JIT 编译一个基于输入值有条件的函数,我们可以通过指定 static_argnumsstatic_argnames 来告诉 JAX 为特定输入使用一个抽象程度较低的 tracer。这样做的代价是,生成的 jaxpr 和编译后的构件将依赖于传入的特定值,因此 JAX 将不得不为每个指定静态输入的新值重新编译函数。只有当函数保证只看到有限的静态值集合时,这才是好的策略。

f_jit_correct = jax.jit(f, static_argnums=0)
print(f_jit_correct(10))
10
g_jit_correct = jax.jit(g, static_argnames=['n'])
print(g_jit_correct(10, 20))
30

在使用 jit 作为装饰器时指定此类参数,一种常见模式是使用 Python 的functools.partial()

from functools import partial

@partial(jax.jit, static_argnames=['n'])
def g_jit_decorated(x, n):
  i = 0
  while i < n:
    i += 1
  return x + i

print(g_jit_decorated(10, 20))
30

JIT 和缓存#

考虑到首次 JIT 调用时的编译开销,理解 jax.jit() 如何以及何时缓存之前的编译结果是有效使用它的关键。

假设我们定义了 f = jax.jit(g)。当我们第一次调用 f 时,它将被编译,并且生成的 XLA 代码将被缓存。后续对 f 的调用将重用缓存的代码。这就是 jax.jit 弥补前期编译成本的方式。

如果我们指定了 static_argnums,则缓存的代码将仅用于标记为静态的参数的相同值。如果其中任何一个发生变化,则会重新编译。如果值很多,那么您的程序可能会花费更多时间在编译上,而不是逐个执行操作。

避免在循环或其他 Python 作用域内定义的临时函数上调用 jax.jit()。在大多数情况下,JAX 将能够在后续调用 jax.jit() 时使用已编译、已缓存的函数。但是,由于缓存依赖于函数的哈希值,当等效函数被重新定义时,它就会出现问题。这将导致循环中每次都进行不必要的编译。

from functools import partial

def unjitted_loop_body(prev_i):
  return prev_i + 1

def g_inner_jitted_partial(x, n):
  i = 0
  while i < n:
    # Don't do this! each time the partial returns
    # a function with different hash
    i = jax.jit(partial(unjitted_loop_body))(i)
  return x + i

def g_inner_jitted_lambda(x, n):
  i = 0
  while i < n:
    # Don't do this!, lambda will also return
    # a function with a different hash
    i = jax.jit(lambda x: unjitted_loop_body(x))(i)
  return x + i

def g_inner_jitted_normal(x, n):
  i = 0
  while i < n:
    # this is OK, since JAX can find the
    # cached, compiled function
    i = jax.jit(unjitted_loop_body)(i)
  return x + i

print("jit called in a loop with partials:")
%timeit g_inner_jitted_partial(10, 20).block_until_ready()

print("jit called in a loop with lambdas:")
%timeit g_inner_jitted_lambda(10, 20).block_until_ready()

print("jit called in a loop with caching:")
%timeit g_inner_jitted_normal(10, 20).block_until_ready()
jit called in a loop with partials:
239 ms ± 6.22 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with lambdas:
245 ms ± 11.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with caching:
2.51 ms ± 8.23 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)