提前降低和编译#
JAX 的 jax.jit
转换返回一个函数,当调用该函数时,它会编译计算并在加速器(或 CPU)上运行。正如 JIT 首字母缩写所表明的那样,所有编译都及时发生在执行之前。
某些情况下需要提前 (AOT) 编译。当您希望在执行时间之前完全编译,或者您希望控制编译过程的不同部分何时发生时,JAX 为您提供了一些选项。
首先,让我们回顾一下编译的阶段。假设 f
是由 jax.jit()
输出的函数/可调用对象,例如对于某些输入可调用对象 F
,f = jax.jit(F)
。当使用参数调用它时,例如 f(x, y)
,其中 x
和 y
是数组,JAX 按顺序执行以下操作
分段输出 原始 Python 可调用对象
F
的专门版本到内部表示。这种专门化反映了将F
限制为从参数x
和y
的属性(通常是它们的形状和元素类型)推断出的输入类型。JAX 通过我们称之为追踪的过程执行这种专门化。在追踪期间,JAX 将F
的专门化分段到 jaxpr,这是 Jaxpr 中间语言中的一个函数。降低 这种专门化的、分段输出的计算到 XLA 编译器的输入语言 StableHLO。
编译 降低的 HLO 程序,为目标设备(CPU、GPU 或 TPU)生成优化的可执行文件。
执行 使用数组
x
和y
作为参数执行编译后的可执行文件。
JAX 的 AOT API 使您可以直接控制这些步骤中的每一个,以及沿途的其他一些功能。例如
>>> import jax
>>> def f(x, y): return 2 * x + y
>>> x, y = 3, 4
>>> traced = jax.jit(f).trace(x, y)
>>> # Print the specialized, staged-out representation (as Jaxpr IR)
>>> print(traced.jaxpr)
{ lambda ; a:i32[] b:i32[]. let c:i32[] = mul 2 a; d:i32[] = add c b in (d,) }
>>> lowered = traced.lower()
>>> # Print lowered HLO
>>> print(lowered.as_text())
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32> {jax.result_info = "result"}) {
%c = stablehlo.constant dense<2> : tensor<i32>
%0 = stablehlo.multiply %c, %arg0 : tensor<i32>
%1 = stablehlo.add %0, %arg1 : tensor<i32>
return %1 : tensor<i32>
}
}
>>> compiled = lowered.compile()
>>> # Query for cost analysis, print FLOP estimate
>>> compiled.cost_analysis()['flops']
2.0
>>> # Execute the compiled function!
>>> compiled(x, y)
Array(10, dtype=int32, weak_type=True)
请注意,降低的对象只能在降低它们的同一进程中使用。对于导出用例,请参阅 导出和序列化 API。
有关降低和编译函数提供的更多功能细节,请参阅 jax.stages
文档。
jit
的所有可选参数(例如 static_argnums
)在相应的追踪、降低、编译和执行中都得到尊重。
在上面的示例中,我们可以用任何具有 shape
和 dtype
属性的对象替换 trace
的参数
>>> i32_scalar = jax.ShapeDtypeStruct((), jnp.dtype('int32'))
>>> jax.jit(f).trace(i32_scalar, i32_scalar).lower().compile()(x, y)
Array(10, dtype=int32)
更一般地,trace
只需要其参数在结构上提供 JAX 必须知道的用于专门化和降低的信息。对于像上面那样的典型数组参数,这意味着 shape
和 dtype
字段。相比之下,对于静态参数,JAX 需要实际的数组值(更多内容请参见下面)。
使用与其追踪不兼容的参数调用 AOT 编译的函数会引发错误
>>> x_1d = y_1d = jnp.arange(3)
>>> jax.jit(f).trace(i32_scalar, i32_scalar).lower().compile()(x_1d, y_1d)
...
Traceback (most recent call last):
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with int32[3]
Argument 'y' compiled with int32[] and called with int32[3]
>>> x_f = y_f = jnp.float32(72.)
>>> jax.jit(f).trace(i32_scalar, i32_scalar).lower().compile()(x_f, y_f)
...
Traceback (most recent call last):
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with float32[]
Argument 'y' compiled with int32[] and called with float32[]
类似地,AOT 编译的函数不能被 JAX 的即时转换(如 jax.jit
、jax.grad()
和 jax.vmap()
)转换。
使用静态参数进行追踪#
使用静态参数进行追踪强调了传递给 jax.jit
的选项、传递给 trace
的参数以及调用生成的编译函数所需的参数之间的交互。继续上面的示例
>>> lowered_with_x = jax.jit(f, static_argnums=0).trace(7, 8).lower()
>>> # Lowered HLO, specialized to the *value* of the first argument (7)
>>> print(lowered_with_x.as_text())
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<i32>) -> (tensor<i32> {jax.result_info = "result"}) {
%c = stablehlo.constant dense<14> : tensor<i32>
%0 = stablehlo.add %c, %arg0 : tensor<i32>
return %0 : tensor<i32>
}
}
>>> lowered_with_x.compile()(5)
Array(19, dtype=int32, weak_type=True)
请注意,此处的 trace
像往常一样接受两个参数,但随后的编译函数仅接受剩余的非静态第二个参数。静态的第一个参数(值 7)在降低时被视为常量,并构建到降低的计算中,在那里它可能与其他常量折叠在一起。在这种情况下,它与 2 的乘法被简化,结果为常量 14。
虽然上面 trace
的第二个参数可以替换为空心的 shape/dtype 结构,但静态的第一个参数必须是具体的值。否则,追踪会出错
>>> jax.jit(f, static_argnums=0).trace(i32_scalar, i32_scalar)
Traceback (most recent call last):
TypeError: unsupported operand type(s) for *: 'int' and 'ShapeDtypeStruct'
>>> jax.jit(f, static_argnums=0).trace(10, i32_scalar).lower().compile()(5)
Array(25, dtype=int32)
trace
和 lower
的结果直接序列化以在不同的进程中使用是不安全的。有关此目的的其他 API,请参阅 导出和序列化。
AOT 编译的函数不能被转换#
编译函数专门用于一组特定的参数“类型”,例如在我们正在运行的示例中具有特定形状和元素类型的数组。从 JAX 的内部观点来看,诸如 jax.vmap()
之类的转换会以使编译类型的签名无效的方式更改函数的类型签名。作为一项策略,JAX 只是不允许编译函数参与转换。示例
>>> def g(x):
... assert x.shape == (3, 2)
... return x @ jnp.ones(2)
>>> def make_z(*shape):
... return jnp.arange(np.prod(shape)).reshape(shape)
>>> z, zs = make_z(3, 2), make_z(4, 3, 2)
>>> g_jit = jax.jit(g)
>>> g_aot = jax.jit(g).trace(z).lower().compile()
>>> jax.vmap(g_jit)(zs)
Array([[ 1., 5., 9.],
[13., 17., 21.],
[25., 29., 33.],
[37., 41., 45.]], dtype=float32)
>>> jax.vmap(g_aot)(zs)
Traceback (most recent call last):
TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type <class 'jax._src.interpreters.batching.BatchTracer'>
当 g_aot
参与自动微分时(例如 jax.grad()
),也会引发类似的错误。为了保持一致性,也不允许通过 jax.jit
进行转换,即使 jit
不会显着修改其参数的类型签名。
调试信息和分析(如果可用)#
除了主要的 AOT 功能(单独且显式的降低、编译和执行)之外,JAX 的各种 AOT 阶段还提供了一些额外的功能,以帮助进行调试和收集编译器反馈。
例如,正如上面的初始示例所示,降低的函数通常提供文本表示。编译后的函数也这样做,并且还提供来自编译器的成本和内存分析。所有这些都通过 jax.stages.Lowered
和 jax.stages.Compiled
对象上的方法提供(例如,上面的 lowered.as_text()
和 compiled.cost_analysis()
)。您可以通过使用 debug_info
参数调用 lowered.as_text()
来获得更多调试信息,例如,源代码位置。
这些方法旨在作为手动检查和调试的辅助工具,而不是作为可靠的可编程 API。它们的可用性和输出因编译器、平台和运行时而异。这导致了两个重要的注意事项
如果某些功能在 JAX 当前的后端不可用,则该功能的方法会返回一些微不足道的东西(并且类似
False
)。例如,如果 JAX 的底层编译器不提供成本分析,则compiled.cost_analysis()
将为None
。如果某些功能可用,那么对于相应方法提供的内容仍然只有非常有限的保证。不要求返回值在 JAX 配置、后端/平台、版本甚至方法的调用之间在类型、结构或值上保持一致。JAX 无法保证
compiled.cost_analysis()
在一天的输出在第二天仍然相同。
如有疑问,请参阅 jax.stages
的包 API 文档。