提前编译#

JAX 的 jax.jit 转换返回一个函数,该函数在被调用时会编译计算并将其运行在加速器(或 CPU)上。正如 JIT 首字母缩略词所示,所有编译都发生在执行的即时(just-in-time)。

在某些情况下,需要提前(ahead-of-time, AOT)编译。当您希望在执行时间之前完全编译,或者希望控制编译过程不同部分的执行时间时,JAX 为您提供了一些选项。

首先,让我们回顾一下编译的阶段。假设 f 是由 jax.jit() 输出的函数/可调用对象,例如,对于某个输入的可调用对象 F,有 f = jax.jit(F)。当使用参数调用它时,例如 f(x, y),其中 xy 是数组,JAX 按顺序执行以下操作:

  1. 将原始 Python 可调用对象 F 的特化版本暂存到一个内部表示中。特化反映了 F 对根据参数 xy 的属性(通常是它们的形状和元素类型)推断出的输入类型的限制。JAX 通过一个称为追踪(tracing)的过程来实现此特化。在追踪期间,JAX 将 F 的特化暂存到 jaxpr,这是一个 Jaxpr 中间语言中的函数。

  2. 将此特化、暂存的计算降低(lower)到 XLA 编译器的输入语言 StableHLO。

  3. 编译(compile)降低的 HLO 程序,以生成目标设备(CPU、GPU 或 TPU)的优化可执行文件。

  4. 使用数组 xy 作为参数执行(execute)编译好的可执行文件。

JAX 的 AOT API 让您可以直接控制这些步骤中的每一个,以及沿途的其他一些功能。举个例子

>>> import jax

>>> def f(x, y): return 2 * x + y
>>> x, y = 3, 4

>>> traced = jax.jit(f).trace(x, y)

>>> # Print the specialized, staged-out representation (as Jaxpr IR)
>>> print(traced.jaxpr)
{ lambda ; a:i32[] b:i32[]. let
    c:i32[] = mul 2:i32[] a
    d:i32[] = add c b
  in (d,) }

>>> lowered = traced.lower()

>>> # Print lowered HLO
>>> print(lowered.as_text())
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32> {jax.result_info = "result"}) {
    %c = stablehlo.constant dense<2> : tensor<i32>
    %0 = stablehlo.multiply %c, %arg0 : tensor<i32>
    %1 = stablehlo.add %0, %arg1 : tensor<i32>
    return %1 : tensor<i32>
  }
}

>>> compiled = lowered.compile()

>>> # Query for cost analysis, print FLOP estimate
>>> compiled.cost_analysis()['flops']
2.0

>>> # Execute the compiled function!
>>> compiled(x, y)
Array(10, dtype=int32, weak_type=True)

请注意,降低的对象只能在其中被降低的同一进程中使用。有关导出用例,请参阅 导出和序列化 API。

有关降低(lowering)和编译(compiled)函数提供功能的更多详细信息,请参阅 jax.stages 文档。

jit 的所有可选参数 — 例如 static_argnums — 在相应的追踪、降低、编译和执行中都会得到尊重。

在上面的示例中,我们可以用任何具有 shapedtype 属性的对象替换传递给 trace 的参数。

>>> i32_scalar = jax.ShapeDtypeStruct((), jnp.dtype('int32'))
>>> jax.jit(f).trace(i32_scalar, i32_scalar).lower().compile()(x, y)
Array(10, dtype=int32)

更普遍地说,trace 只需要其参数在结构上提供 JAX 为特化和降低所必需的信息。对于典型的数组参数(如上述参数),这意味着 shapedtype 字段。相比之下,对于静态参数,JAX 需要实际的数组值(关于这一点,请参见 下方)。

使用与 AOT 编译函数不兼容的参数调用 AOT 编译的函数会引发错误。

>>> x_1d = y_1d = jnp.arange(3)
>>> jax.jit(f).trace(i32_scalar, i32_scalar).lower().compile()(x_1d, y_1d)  
...
Traceback (most recent call last):
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with int32[3]
Argument 'y' compiled with int32[] and called with int32[3]

>>> x_f = y_f = jnp.float32(72.)
>>> jax.jit(f).trace(i32_scalar, i32_scalar).lower().compile()(x_f, y_f)  
...
Traceback (most recent call last):
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with float32[]
Argument 'y' compiled with int32[] and called with float32[]

同样,AOT 编译的函数无法被 JAX 的即时转换(just-in-time transformations)转换,例如 jax.jitjax.grad()jax.vmap()

使用静态参数进行追踪#

使用静态参数进行追踪突显了传递给 jax.jit 的选项、传递给 trace 的参数以及调用结果编译函数所需的参数之间的交互。继续我们上面的例子:

>>> lowered_with_x = jax.jit(f, static_argnums=0).trace(7, 8).lower()

>>> # Lowered HLO, specialized to the *value* of the first argument (7)
>>> print(lowered_with_x.as_text())
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32>) -> (tensor<i32> {jax.result_info = "result"}) {
    %c = stablehlo.constant dense<14> : tensor<i32>
    %0 = stablehlo.add %c, %arg0 : tensor<i32>
    return %0 : tensor<i32>
  }
}

>>> lowered_with_x.compile()(5)
Array(19, dtype=int32, weak_type=True)

请注意,这里的 trace 像往常一样接受两个参数,但后续的编译函数只接受剩余的非静态第二个参数。静态第一个参数(值为 7)在降低时被视为一个常量,并被构建到降低的计算中,在那里它可能与其他常量合并。在这种情况下,它乘以 2 的运算被简化,结果是常量 14。

尽管上面 trace 的第二个参数可以用一个空的形状/dtype 结构替换,但静态第一个参数必须是一个具体值。否则,追踪会出错。

>>> jax.jit(f, static_argnums=0).trace(i32_scalar, i32_scalar)  
Traceback (most recent call last):
TypeError: unsupported operand type(s) for *: 'int' and 'ShapeDtypeStruct'

>>> jax.jit(f, static_argnums=0).trace(10, i32_scalar).lower().compile()(5)
Array(25, dtype=int32)

tracelower 的结果不安全直接序列化以在不同进程中使用。有关此目的的其他 API,请参阅 导出和序列化

AOT 编译的函数无法进行转换#

编译的函数被特化为特定的参数“类型”集,例如我们示例中的具有特定形状和元素类型的数组。从 JAX 的内部观点来看,jax.vmap() 等转换会以一种使编译的类型签名无效的方式改变函数的类型签名。作为一项策略,JAX 根本不允许编译的函数参与转换。例如:

>>> def g(x):
...   assert x.shape == (3, 2)
...   return x @ jnp.ones(2)

>>> def make_z(*shape):
...   return jnp.arange(np.prod(shape)).reshape(shape)

>>> z, zs = make_z(3, 2), make_z(4, 3, 2)

>>> g_jit = jax.jit(g)
>>> g_aot = jax.jit(g).trace(z).lower().compile()

>>> jax.vmap(g_jit)(zs)
Array([[ 1.,  5.,  9.],
       [13., 17., 21.],
       [25., 29., 33.],
       [37., 41., 45.]], dtype=float32)

>>> jax.vmap(g_aot)(zs)  
Traceback (most recent call last):
TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type <class 'jax._src.interpreters.batching.BatchTracer'>

g_aot 参与自动微分(例如 jax.grad())时,会引发类似的错误。为了保持一致性,jax.jit 的转换也被禁止,尽管 jit 并不会有意义地修改其参数的类型签名。

调试信息和分析(如果可用)#

除了主要的 AOT 功能(独立的显式降低、编译和执行)之外,JAX 的各种 AOT 阶段还提供了一些附加功能,以帮助进行调试和收集编译器反馈。

例如,如上面第一个示例所示,降低的函数通常提供文本表示。编译的函数也提供文本表示,并且还提供来自编译器的成本和内存分析。所有这些都通过 jax.stages.Loweredjax.stages.Compiled 对象上的方法提供(例如,上面示例中的 lowered.as_text()compiled.cost_analysis())。您可以通过使用 lowered.as_text()debug_info 参数来获取更多调试信息,例如源位置。

这些方法旨在帮助手动检查和调试,而不是作为可靠的可编程 API。它们的可用性和输出因编译器、平台和运行时而异。这带来了两个重要的注意事项:

  1. 如果 JAX 当前后端上的某些功能不可用,那么对应的方法将返回一些琐碎的内容(且为 False 类的返回值)。例如,如果 JAX 底层的编译器不提供成本分析,那么 compiled.cost_analysis() 将为 None

  2. 即使某些功能可用,对应方法提供的内容仍然有非常有限的保证。返回值不必在 JAX 配置、后端/平台、版本甚至方法调用之间保持一致(在类型、结构或值方面)。JAX 不能保证 compiled.cost_analysis() 的输出在某一天与第二天相同。

如有疑问,请参阅 jax.stages 的包 API 文档。