提前编译降级与编译#

JAX 的 jax.jit 转换会返回一个函数,该函数在调用时会编译一个计算并在加速器(或 CPU)上运行它。正如 JIT 首字母缩写所示,所有编译都在执行时即时发生。

有些情况需要提前(AOT)编译。当您希望在执行时间之前完全编译,或者您希望控制编译过程的不同部分何时发生时,JAX 会提供一些选项。

首先,让我们回顾一下编译的阶段。假设 f 是由 jax.jit() 输出的函数/可调用对象,例如对于某个输入可调用对象 F,有 f = jax.jit(F)。当它与参数一起调用时,例如 f(x, y),其中 xy 是数组,JAX 按以下顺序执行:

  1. 将原始 Python 可调用对象 F 的专用版本阶段化输出为内部表示。此专用化反映了 F 针对从参数 xy 的属性(通常是它们的形状和元素类型)推断出的输入类型的限制。JAX 通过我们称之为追踪的过程执行此专用化。在追踪过程中,JAX 将 F 的专用化阶段化为一个 jaxpr,它是在 Jaxpr 中间语言中的函数。

  2. 将这个专用化、阶段化输出的计算降级到 XLA 编译器的输入语言 StableHLO。

  3. 编译降级后的 HLO 程序,为目标设备(CPU、GPU 或 TPU)生成优化后的可执行文件。

  4. 使用数组 xy 作为参数执行编译后的可执行文件。

JAX 的 AOT API 允许您直接控制这些步骤中的每一个,以及沿途的一些其他功能。一个例子:

>>> import jax

>>> def f(x, y): return 2 * x + y
>>> x, y = 3, 4

>>> traced = jax.jit(f).trace(x, y)

>>> # Print the specialized, staged-out representation (as Jaxpr IR)
>>> print(traced.jaxpr)
{ lambda ; a:i32[] b:i32[]. let
    c:i32[] = mul 2:i32[] a
    d:i32[] = add c b
  in (d,) }

>>> lowered = traced.lower()

>>> # Print lowered HLO
>>> print(lowered.as_text())
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32> {jax.result_info = "result"}) {
    %c = stablehlo.constant dense<2> : tensor<i32>
    %0 = stablehlo.multiply %c, %arg0 : tensor<i32>
    %1 = stablehlo.add %0, %arg1 : tensor<i32>
    return %1 : tensor<i32>
  }
}

>>> compiled = lowered.compile()

>>> # Query for cost analysis, print FLOP estimate
>>> compiled.cost_analysis()['flops']
2.0

>>> # Execute the compiled function!
>>> compiled(x, y)
Array(10, dtype=int32, weak_type=True)

请注意,降级后的对象只能在进行降级的同一进程中使用。对于导出用例,请参阅导出与序列化 API。

有关降级和编译函数提供哪些功能的更多详细信息,请参阅 jax.stages 文档。

jit 的所有可选参数(例如 static_argnums)在相应的追踪、降级、编译和执行中都受到尊重。

在上面的例子中,我们可以用任何具有 shapedtype 属性的对象替换 trace 的参数:

>>> i32_scalar = jax.ShapeDtypeStruct((), jnp.dtype('int32'))
>>> jax.jit(f).trace(i32_scalar, i32_scalar).lower().compile()(x, y)
Array(10, dtype=int32)

更一般地,trace 只需要其参数在结构上提供 JAX 进行专用化和降级所需的信息。对于像上面这样的典型数组参数,这意味着 shapedtype 字段。相比之下,对于静态参数,JAX 需要实际的数组值(更多信息请参阅下文)。

用与追踪不兼容的参数调用 AOT 编译函数会引发错误:

>>> x_1d = y_1d = jnp.arange(3)
>>> jax.jit(f).trace(i32_scalar, i32_scalar).lower().compile()(x_1d, y_1d)  
...
Traceback (most recent call last):
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with int32[3]
Argument 'y' compiled with int32[] and called with int32[3]

>>> x_f = y_f = jnp.float32(72.)
>>> jax.jit(f).trace(i32_scalar, i32_scalar).lower().compile()(x_f, y_f)  
...
Traceback (most recent call last):
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with float32[]
Argument 'y' compiled with int32[] and called with float32[]

与此相关的是,AOT 编译函数不能被 JAX 的即时转换(例如 jax.jitjax.grad()jax.vmap())转换。

带静态参数的追踪#

带静态参数的追踪强调了传递给 jax.jit 的选项、传递给 trace 的参数以及调用生成的编译函数所需的参数之间的交互。继续上面的例子:

>>> lowered_with_x = jax.jit(f, static_argnums=0).trace(7, 8).lower()

>>> # Lowered HLO, specialized to the *value* of the first argument (7)
>>> print(lowered_with_x.as_text())
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32>) -> (tensor<i32> {jax.result_info = "result"}) {
    %c = stablehlo.constant dense<14> : tensor<i32>
    %0 = stablehlo.add %c, %arg0 : tensor<i32>
    return %0 : tensor<i32>
  }
}

>>> lowered_with_x.compile()(5)
Array(19, dtype=int32, weak_type=True)

请注意,这里的 trace 像往常一样接受两个参数,但随后的编译函数只接受剩余的非静态第二个参数。静态的第一个参数(值 7)在降级时被视为常量并构建到降级计算中,在那里它可能与其他常量一起折叠。在这种情况下,它与 2 的乘法被简化,得到常量 14。

尽管上面 trace 的第二个参数可以替换为空洞的 shape/dtype 结构,但静态的第一个参数必须是一个具体的值。否则,追踪会出错:

>>> jax.jit(f, static_argnums=0).trace(i32_scalar, i32_scalar)  
Traceback (most recent call last):
TypeError: unsupported operand type(s) for *: 'int' and 'ShapeDtypeStruct'

>>> jax.jit(f, static_argnums=0).trace(10, i32_scalar).lower().compile()(5)
Array(25, dtype=int32)

tracelower 的结果不能直接安全地序列化以用于不同的进程。有关此目的的其他 API,请参阅导出与序列化

AOT 编译函数无法进行转换#

编译函数专门针对一组特定的参数“类型”进行优化,例如我们当前示例中具有特定形状和元素类型的数组。从 JAX 的内部角度来看,像 jax.vmap() 这样的转换会改变函数的类型签名,从而使编译过的类型签名失效。作为一项策略,JAX 简单地禁止编译函数参与转换。示例:

>>> def g(x):
...   assert x.shape == (3, 2)
...   return x @ jnp.ones(2)

>>> def make_z(*shape):
...   return jnp.arange(np.prod(shape)).reshape(shape)

>>> z, zs = make_z(3, 2), make_z(4, 3, 2)

>>> g_jit = jax.jit(g)
>>> g_aot = jax.jit(g).trace(z).lower().compile()

>>> jax.vmap(g_jit)(zs)
Array([[ 1.,  5.,  9.],
       [13., 17., 21.],
       [25., 29., 33.],
       [37., 41., 45.]], dtype=float32)

>>> jax.vmap(g_aot)(zs)  
Traceback (most recent call last):
TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type <class 'jax._src.interpreters.batching.BatchTracer'>

g_aot 参与自动微分(例如 jax.grad())时,也会引发类似的错误。为保持一致性,即使 jit 没有有意义地修改其参数的类型签名,也禁止通过 jax.jit 进行转换。

调试信息和分析(如果可用)#

除了主要的 AOT 功能(独立且显式的降级、编译和执行)之外,JAX 的各种 AOT 阶段还提供了一些额外功能,以帮助调试和收集编译器反馈。

例如,如上面的初始示例所示,降级函数通常提供文本表示。编译函数也提供相同的功能,并且还提供来自编译器的成本和内存分析。所有这些都通过 jax.stages.Loweredjax.stages.Compiled 对象上的方法提供(例如,上面的 lowered.as_text()compiled.cost_analysis())。您可以通过使用 debug_info 参数调用 lowered.as_text() 来获取更多调试信息,例如源代码位置。

这些方法旨在帮助手动检查和调试,而不是作为可靠可编程的 API。它们的可用性和输出因编译器、平台和运行时而异。这带来了两个重要的注意事项:

  1. 如果 JAX 当前后端上某些功能不可用,则其对应的方法将返回一些微不足道(且类似于 False)的值。例如,如果 JAX 底层编译器不提供成本分析,则 compiled.cost_analysis() 将为 None

  2. 如果某些功能可用,对于相应方法提供的内容仍然只有非常有限的保证。返回值不要求在 JAX 配置、后端/平台、版本甚至方法调用之间保持一致(类型、结构或值)。JAX 不能保证 compiled.cost_analysis() 在某天输出的结果在第二天仍然相同。

如有疑问,请参阅 jax.stages 的包 API 文档。