已闭合常量的处理#

“已闭合常量”是指在 JAX 函数跟踪过程中遇到的非标量数组,并且它们与函数的任何参数都没有依赖关系。JAX 操作,如 jax.numpylax,会被阶段式移除(staged out)并且不会创建已闭合常量。在以下示例中,数组 a_jax_arraynp.full 是已闭合常量,而 jnp.full 则不是。下面我们将已闭合常量简称为常量。

import numpy as np
from jax import jit
from jax import numpy as jnp

a_jax_array = jnp.ones((16,), dtype=np.float32)

@jit
def f(x):
  return x + a_jax_array + np.full((16,), 42.) + jnp.full((16,), 142.)

我们将在下面描述常量的 **未来** 内部实现细节。截至 2025 年 7 月,这还不是默认实现;它由环境变量 JAX_USE_SIMPLIFIED_JAXPR_CONSTANTS=True 启用。有关先前实现的详细信息,包括其缺点,请参阅 下文

Tracing#

当 JAX 跟踪遇到作为 JAX 原始(primitive)参数或函数返回值的常量时,它会被表示为 core.Literal,并与使用它们的原始一同嵌入到 Jaxpr 中。core.is_literalable 函数决定哪些常量被转换为 core.Literal。所有标量常量以及非标量的 np.ndarrayjax.Array 都会被转换为 core.Literal

降低(Lowering)#

在将代码降低到 HLO 时,我们可以直接为 core.Literal 发出 stablehlo.constant 操作,但这会有几个缺点:

  • 如果常量是 jax.Array(例如上面的 a_jax_array),那么在降低过程中,它将从设备拉取到主机,并在降低的模块执行时重新物化到设备上。这会增加主机内存使用量,有时甚至会急剧增加。此外,如果常量分片(sharded)到多个设备上,这种分片信息会丢失。

  • 大型常量会增加 HLO 的大小,特别是当同一个常量被多次使用时。此外,XLA 编译器会尝试对其进行常量折叠(constant-fold),这会导致警告和编译缓慢。此外,我们观察到 XLA 常量折叠有时会产生与编译代码略有不同的数值结果。另请参阅 大型已闭合常量被内联到 HLO 代码中 #29684

相反,在降低过程中,我们使用 core.jaxpr_const_args 函数来扫描 Jaxpr 并返回其中包含的常量列表,通过它们的 id 进行去重。对于调用 core.jaxpr_const_args 的每个 Jaxpr 和子 Jaxpr,都会对其进行记忆化(memoized)。

所有降低后的 HLO 函数将为出现在其相应 Jaxpr 中的每个唯一常量添加一个额外的参数。这些参数称为 const_args,它们出现在维度变量参数之后、token 参数之后,并且在实际数组参数之前。在降低过程中,我们维护一个从常量 id 到相应 const_args 的 HLO 值的映射 const_lowering: dict[int, mlir.IrValues]。此映射存储在 mlir.LoweringRuleContext 中,并由 mlir.ir_constant 使用:当遇到常量时,我们仅重用 const_lowering 中已有的降低项,而不是发出 stablehlo.constant

当我们降低 HLO 内部函数(即非 main 函数)时,我们会再次调用 core.jaxpr_const_args 来获取相应 Jaxpr 中的实际常量。这些常量应该包含在我们拥有 const_lowering 的常量之中。内部函数将获得自己的一组较小的 const_args 和自己的 const_lowering 映射,供在降低主体时使用。例如,mlir.lower_jaxpr_as_fun 函数就是其中一部分处理逻辑发生的地方。

mlir.jaxpr_subcomp 函数不创建新的 HLO 函数,而是创建当前函数内的一个块。它使用封闭函数的 const_lowering

还要注意,降低后的代码中仍会存在 stablehlo.constant,在以下三种情况下:

  • 当常量是标量时;我们希望这些常量可供 XLA 进行常量折叠。

  • 当常量未出现在跟踪的程序中,因此不在 Jaxpr 中时。这可能会发生在降低过程中产生的常量,例如,某些 PRNG 函数的降低包含常量。

  • 当我们进行导出(export)时:目前,我们在导出时不会提升(hoist)常量参数,因为导出的序列化目前不支持对数组进行序列化。我们使用 mlir.LoweringParameters.hoist_constants_as_args 参数来控制这一点。

一个额外的复杂性是,一些内部降低函数需要获取参数的 avals,有时还需要获取参数的 shardingslayouts。此外,所有参数(包括 const_args)的 avalsshardingslayout 在降低之后也仍然需要使用。因此,在调用堆栈的较高层级计算这些(例如,在 pxla.lower_sharding_computations 中)并将它们向下传递会很方便。

例如,函数 mlir.lower_jaxpr_to_modulepjit._pjit_cached_lower_jaxpr_to_funmlir.lower_jaxpr_to_fun 接受 in_avalsin_shardingsin_layouts,这些参数同时包含了 const_args 和常规参数(对应 Jaxpr.invars)的 avals。它们还接受一个 num_const_args 参数。

编译与执行#

降低后的 MLIR 模块包含 const_args 的参数,因此编译后的可执行文件需要传入 const_args。选择正确的位置来预置(prepend)const_args 非常重要。例如,在以下代码中,对 jitted 函数 f 的第二次调用预计会命中 C++ jit 缓存,而无需执行任何 Python 代码。

const = jnp.array([42.])
f = jax.jit(lambda: const)

f()
f()

(待办:yashk2810 计划撰写关于 jit 缓存工作原理的描述。)这意味着 const 将必须在 C++ 中传递给可执行文件(因此存储在 pxla.MeshExecutableFastpathData 中),并且因此 C++ 缓存未命中函数(例如 pjit._cpp_pjit.cache_misspxla.MeshExecutable.create_cpp_call 中的 aot_cache_miss)不会将 const_args 作为参数。相反,这些缓存未命中函数将不得不预置 const_args

C++ 快速路径(fast path)从 jaxlib 0.7.1 开始支持 const_args。在之前的版本中,当存在 const_args 时,快速路径会被禁用。

为了实现此方案,我们在 stages.Loweringstages.Loweredstages.CompiledCallParams 中保留 const_args

有趣的是,当我们序列化可执行文件(例如用于编译缓存)时,我们不需要序列化已闭合的常量。可执行文件本身不包含它们,并且需要将它们作为 const_args 传入。任何将要反序列化缓存的可执行文件的人都必须传入 const_args

在 AOT 模式下,降低和执行可能会使用不同的 jax_enable_x64 配置值。如果常量是 64 位 ndarray,我们必须在降低和执行时使用相同的 jax_enable_x64 值。

先前实现#

这描述了截至 2025 年 7 月(只要 JAX_USE_SIMPLIFIED_CONSTANTS=False)我们处理已闭合常量的方式。

当 JAX 将函数跟踪到 Jaxpr 时,它会将已闭合的值收集到一个常量集中,并向 Jaxpr 添加相应的 constvars 集(实际参数由 invars 表示)。大多数跟踪函数,例如 trace_to_jaxpr_dynamic,会同时返回 Jaxpr 和常量。

在代码的许多地方,我们使用 core.ClosedJaxpr 类,它包含一个 Jaxpr 和对应于 Jaxpr.constvarsconsts

ClosedJaxpr 有几个问题:

  • ClosedJaxprconsts 的降低会导致内联 stablehlo.constant,并带有上述所有问题。

  • JaxprClosedJaxpr 在 JAX 中被广泛使用,通常使用通用名称 jaxpr,因此很难区分我们拥有的是哪种 Jaxpr。我们已开始添加类型声明,但在某些地方的代码使用 isinstance 条件语句来同时处理两者。

  • 由于 Jaxpr 和 ClosedJaxpr 有时用作缓存键,并且它们是通过 id 进行哈希计算的,因此我们希望记忆化它们的构造。例如,函数 pe.closed_jaxpr 记忆化了 ClosedJaxpr 的构造,但仅限于 consts 为空的情况。这是因为有时 consts 是不可哈希的。

  • 处理 ClosedJaxpr 中的常量需要一些额外的注意。例如,在 Mosaic 降低的某些地方,我们尚未实现对带有非空常量的 ClosedJaxpr 的处理(例如 此处)。

  • 当我们把已闭合常量转换为输入时,在转换过程中必须小心处理这些辅助输入。