提前降低和编译#

JAX 提供了几种转换,例如 jax.jitjax.pmap,它们返回一个已编译并在加速器或 CPU 上运行的函数。正如 JIT 的缩写所表明的那样,所有编译都发生在执行的即时

某些情况下需要提前 (AOT) 编译。 当您希望在执行时间之前完全编译,或者您希望控制编译过程的不同部分何时发生时,JAX 为您提供了一些选项。

首先,让我们回顾一下编译的各个阶段。 假设 fjax.jit() 输出的函数/可调用对象,例如对于某些输入可调用对象 Ff = jax.jit(F)。 当它使用参数调用时,例如 f(x, y),其中 xy 是数组,JAX 会按顺序执行以下操作

  1. 阶段化输出原始 Python 可调用对象 F 的专门版本为内部表示。 这种专门化反映了对 F 的限制,即限制为从参数 xy 的属性(通常是它们的形状和元素类型)推断出的输入类型。

  2. 降低这种专门化的、阶段性输出的计算到 XLA 编译器的输入语言 StableHLO。

  3. 编译降低后的 HLO 程序,为目标设备(CPU、GPU 或 TPU)生成优化的可执行文件。

  4. 执行使用数组 xy 作为参数的已编译可执行文件。

JAX 的 AOT API 使您可以直接控制步骤 #2、#3 和 #4(但不是 #1),以及沿途的其他一些功能。 一个例子

>>> import jax

>>> def f(x, y): return 2 * x + y
>>> x, y = 3, 4

>>> lowered = jax.jit(f).lower(x, y)

>>> # Print lowered HLO
>>> print(lowered.as_text())
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32> {jax.result_info = ""}) {
    %c = stablehlo.constant dense<2> : tensor<i32>
    %0 = stablehlo.multiply %c, %arg0 : tensor<i32>
    %1 = stablehlo.add %0, %arg1 : tensor<i32>
    return %1 : tensor<i32>
  }
}

>>> compiled = lowered.compile()

>>> # Query for cost analysis, print FLOP estimate
>>> compiled.cost_analysis()['flops']
2.0

>>> # Execute the compiled function!
>>> compiled(x, y)
Array(10, dtype=int32, weak_type=True)

请注意,降低后的对象只能在降低它们的同一进程中使用。 对于导出用例,请参阅导出和序列化 API。

有关降低和编译后的函数提供的更多功能细节,请参阅 jax.stages 文档。

jit 的所有可选参数(例如 static_argnums)在相应的降低、编译和执行中都得到尊重。

在上面的示例中,我们可以使用任何具有 shapedtype 属性的对象替换 lower 的参数

>>> i32_scalar = jax.ShapeDtypeStruct((), jnp.dtype('int32'))
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x, y)
Array(10, dtype=int32)

更一般地说,lower 只需要其参数在结构上提供 JAX 必须知道的用于专门化和降低的信息。 对于像上面这样的典型数组参数,这意味着 shapedtype 字段。 相比之下,对于静态参数,JAX 需要实际的数组值(更多信息请参见下文)。

使用与其降低不兼容的参数调用 AOT 编译的函数会引发错误

>>> x_1d = y_1d = jnp.arange(3)
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_1d, y_1d)  
...
Traceback (most recent call last):
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with int32[3]
Argument 'y' compiled with int32[] and called with int32[3]

>>> x_f = y_f = jnp.float32(72.)
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_f, y_f)  
...
Traceback (most recent call last):
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with float32[]
Argument 'y' compiled with int32[] and called with float32[]

相关地,AOT 编译的函数不能被 JAX 的即时转换(例如 jax.jitjax.grad()jax.vmap())转换。

使用静态参数降低#

使用静态参数降低突出了传递给 jax.jit 的选项、传递给 lower 的参数以及调用生成的编译函数所需的参数之间的交互。 继续我们上面的示例

>>> lowered_with_x = jax.jit(f, static_argnums=0).lower(7, 8)

>>> # Lowered HLO, specialized to the *value* of the first argument (7)
>>> print(lowered_with_x.as_text())
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32>) -> (tensor<i32> {jax.result_info = ""}) {
    %c = stablehlo.constant dense<14> : tensor<i32>
    %0 = stablehlo.add %c, %arg0 : tensor<i32>
    return %0 : tensor<i32>
  }
}

>>> lowered_with_x.compile()(5)
Array(19, dtype=int32, weak_type=True)

lower 的结果不安全,不能直接序列化以在不同的进程中使用。 有关此目的的其他 API,请参阅 导出和序列化

请注意,这里的 lower 像往常一样采用两个参数,但随后的编译函数仅接受剩余的非静态第二个参数。 静态第一个参数(值 7)在降低时被视为常量,并构建到降低后的计算中,其中可能会与其他常量折叠在一起。 在这种情况下,它乘以 2 被简化,导致常量 14。

尽管上面的 lower 的第二个参数可以替换为空心的形状/dtype 结构,但静态第一个参数必须是具体值。 否则,降低会出错

>>> jax.jit(f, static_argnums=0).lower(i32_scalar, i32_scalar)  
Traceback (most recent call last):
TypeError: unsupported operand type(s) for *: 'int' and 'ShapeDtypeStruct'

>>> jax.jit(f, static_argnums=0).lower(10, i32_scalar).compile()(5)
Array(25, dtype=int32)

AOT 编译的函数不能被转换#

编译后的函数专门用于特定的一组参数“类型”,例如我们正在运行的示例中具有特定形状和元素类型的数组。 从 JAX 的内部角度来看,诸如 jax.vmap() 之类的转换会以使编译后的类型签名失效的方式更改函数的类型签名。 作为一项策略,JAX 只是不允许编译后的函数参与转换。 例子

>>> def g(x):
...   assert x.shape == (3, 2)
...   return x @ jnp.ones(2)

>>> def make_z(*shape):
...   return jnp.arange(np.prod(shape)).reshape(shape)

>>> z, zs = make_z(3, 2), make_z(4, 3, 2)

>>> g_jit = jax.jit(g)
>>> g_aot = jax.jit(g).lower(z).compile()

>>> jax.vmap(g_jit)(zs)
Array([[ 1.,  5.,  9.],
       [13., 17., 21.],
       [25., 29., 33.],
       [37., 41., 45.]], dtype=float32)

>>> jax.vmap(g_aot)(zs)  
Traceback (most recent call last):
TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type <class 'jax._src.interpreters.batching.BatchTracer'>

g_aot 参与自动微分(例如 jax.grad())时,会引发类似的错误。 为了保持一致性,也禁止使用 jax.jit 进行转换,即使 jit 不会明显修改其参数的类型签名。

调试信息和分析(如果可用)#

除了主要的 AOT 功能(单独且显式的降低、编译和执行)之外,JAX 的各种 AOT 阶段还提供了一些额外的功能,以帮助进行调试和收集编译器反馈。

例如,正如上面的初始示例所示,降低后的函数通常提供文本表示形式。 编译后的函数也这样做,并且还提供来自编译器的成本和内存分析。 所有这些都是通过 jax.stages.Loweredjax.stages.Compiled 对象上的方法提供的(例如,上面的 lowered.as_text()compiled.cost_analysis())。 您可以使用 debug_info 参数到 lowered.as_text() 来获取更多调试信息,例如源代码位置。

这些方法旨在作为手动检查和调试的辅助工具,而不是可靠的可编程 API。 它们的可用性和输出因编译器、平台和运行时而异。 这带来了两个重要的注意事项

  1. 如果 JAX 的当前后端上某些功能不可用,那么它的方法会返回一些微不足道的东西(并且类似于 False)。 例如,如果 JAX 的底层编译器不提供成本分析,则 compiled.cost_analysis() 将为 None

  2. 如果某些功能可用,那么对于相应的方法提供的内容仍然只有非常有限的保证。 并不要求返回值在 JAX 配置、后端/平台、版本甚至方法的调用之间在类型、结构或值上保持一致。 JAX 不能保证在某一天 compiled.cost_analysis() 的输出在第二天保持不变。

如有疑问,请参阅 jax.stages 的包 API 文档。

检查阶段性计算#

此注释顶部的列表中阶段 #1 提到了在降低之前的专门化和阶段化。 JAX 内部对专门化为其参数类型的函数的概念并不总是内存中的具体化数据结构。 要显式构造 JAX 函数在内部 Jaxpr 中间语言中专门化的视图,请参阅 jax.make_jaxpr()