已闭合常量的处理#
“已闭合常量”是指在 JAX 函数跟踪过程中遇到的非标量数组,并且它们与函数的任何参数都没有依赖关系。JAX 操作,如 jax.numpy
和 lax
,会被阶段式移除(staged out)并且不会创建已闭合常量。在以下示例中,数组 a_jax_array
和 np.full
是已闭合常量,而 jnp.full
则不是。下面我们将已闭合常量简称为常量。
import numpy as np
from jax import jit
from jax import numpy as jnp
a_jax_array = jnp.ones((16,), dtype=np.float32)
@jit
def f(x):
return x + a_jax_array + np.full((16,), 42.) + jnp.full((16,), 142.)
我们将在下面描述常量的 **未来** 内部实现细节。截至 2025 年 7 月,这还不是默认实现;它由环境变量 JAX_USE_SIMPLIFIED_JAXPR_CONSTANTS=True
启用。有关先前实现的详细信息,包括其缺点,请参阅 下文。
Tracing#
当 JAX 跟踪遇到作为 JAX 原始(primitive)参数或函数返回值的常量时,它会被表示为 core.Literal
,并与使用它们的原始一同嵌入到 Jaxpr
中。core.is_literalable
函数决定哪些常量被转换为 core.Literal
。所有标量常量以及非标量的 np.ndarray
和 jax.Array
都会被转换为 core.Literal
。
降低(Lowering)#
在将代码降低到 HLO 时,我们可以直接为 core.Literal
发出 stablehlo.constant
操作,但这会有几个缺点:
如果常量是
jax.Array
(例如上面的a_jax_array
),那么在降低过程中,它将从设备拉取到主机,并在降低的模块执行时重新物化到设备上。这会增加主机内存使用量,有时甚至会急剧增加。此外,如果常量分片(sharded)到多个设备上,这种分片信息会丢失。大型常量会增加 HLO 的大小,特别是当同一个常量被多次使用时。此外,XLA 编译器会尝试对其进行常量折叠(constant-fold),这会导致警告和编译缓慢。此外,我们观察到 XLA 常量折叠有时会产生与编译代码略有不同的数值结果。另请参阅 大型已闭合常量被内联到 HLO 代码中 #29684。
相反,在降低过程中,我们使用 core.jaxpr_const_args
函数来扫描 Jaxpr
并返回其中包含的常量列表,通过它们的 id
进行去重。对于调用 core.jaxpr_const_args
的每个 Jaxpr
和子 Jaxpr
,都会对其进行记忆化(memoized)。
所有降低后的 HLO 函数将为出现在其相应 Jaxpr
中的每个唯一常量添加一个额外的参数。这些参数称为 const_args
,它们出现在维度变量参数之后、token 参数之后,并且在实际数组参数之前。在降低过程中,我们维护一个从常量 id
到相应 const_args
的 HLO 值的映射 const_lowering: dict[int, mlir.IrValues]
。此映射存储在 mlir.LoweringRuleContext
中,并由 mlir.ir_constant
使用:当遇到常量时,我们仅重用 const_lowering
中已有的降低项,而不是发出 stablehlo.constant
。
当我们降低 HLO 内部函数(即非 main
函数)时,我们会再次调用 core.jaxpr_const_args
来获取相应 Jaxpr
中的实际常量。这些常量应该包含在我们拥有 const_lowering
的常量之中。内部函数将获得自己的一组较小的 const_args
和自己的 const_lowering
映射,供在降低主体时使用。例如,mlir.lower_jaxpr_as_fun
函数就是其中一部分处理逻辑发生的地方。
mlir.jaxpr_subcomp
函数不创建新的 HLO 函数,而是创建当前函数内的一个块。它使用封闭函数的 const_lowering
。
还要注意,降低后的代码中仍会存在 stablehlo.constant
,在以下三种情况下:
当常量是标量时;我们希望这些常量可供 XLA 进行常量折叠。
当常量未出现在跟踪的程序中,因此不在
Jaxpr
中时。这可能会发生在降低过程中产生的常量,例如,某些 PRNG 函数的降低包含常量。当我们进行导出(export)时:目前,我们在导出时不会提升(hoist)常量参数,因为导出的序列化目前不支持对数组进行序列化。我们使用
mlir.LoweringParameters.hoist_constants_as_args
参数来控制这一点。
一个额外的复杂性是,一些内部降低函数需要获取参数的 avals
,有时还需要获取参数的 shardings
和 layouts
。此外,所有参数(包括 const_args
)的 avals
、shardings
和 layout
在降低之后也仍然需要使用。因此,在调用堆栈的较高层级计算这些(例如,在 pxla.lower_sharding_computations
中)并将它们向下传递会很方便。
例如,函数 mlir.lower_jaxpr_to_module
、pjit._pjit_cached_lower_jaxpr_to_fun
和 mlir.lower_jaxpr_to_fun
接受 in_avals
、in_shardings
和 in_layouts
,这些参数同时包含了 const_args
和常规参数(对应 Jaxpr.invars
)的 avals
。它们还接受一个 num_const_args
参数。
编译与执行#
降低后的 MLIR 模块包含 const_args
的参数,因此编译后的可执行文件需要传入 const_args
。选择正确的位置来预置(prepend)const_args
非常重要。例如,在以下代码中,对 jitted 函数 f
的第二次调用预计会命中 C++ jit 缓存,而无需执行任何 Python 代码。
const = jnp.array([42.])
f = jax.jit(lambda: const)
f()
f()
(待办:yashk2810 计划撰写关于 jit 缓存工作原理的描述。)这意味着 const
将必须在 C++ 中传递给可执行文件(因此存储在 pxla.MeshExecutableFastpathData
中),并且因此 C++ 缓存未命中函数(例如 pjit._cpp_pjit.cache_miss
或 pxla.MeshExecutable.create_cpp_call
中的 aot_cache_miss
)不会将 const_args
作为参数。相反,这些缓存未命中函数将不得不预置 const_args
。
C++ 快速路径(fast path)从 jaxlib 0.7.1 开始支持 const_args
。在之前的版本中,当存在 const_args
时,快速路径会被禁用。
为了实现此方案,我们在 stages.Lowering
、stages.Lowered
和 stages.CompiledCallParams
中保留 const_args
。
有趣的是,当我们序列化可执行文件(例如用于编译缓存)时,我们不需要序列化已闭合的常量。可执行文件本身不包含它们,并且需要将它们作为 const_args
传入。任何将要反序列化缓存的可执行文件的人都必须传入 const_args
。
在 AOT 模式下,降低和执行可能会使用不同的 jax_enable_x64
配置值。如果常量是 64 位 ndarray
,我们必须在降低和执行时使用相同的 jax_enable_x64
值。
先前实现#
这描述了截至 2025 年 7 月(只要 JAX_USE_SIMPLIFIED_CONSTANTS=False
)我们处理已闭合常量的方式。
当 JAX 将函数跟踪到 Jaxpr
时,它会将已闭合的值收集到一个常量集中,并向 Jaxpr 添加相应的 constvars
集(实际参数由 invars
表示)。大多数跟踪函数,例如 trace_to_jaxpr_dynamic
,会同时返回 Jaxpr
和常量。
在代码的许多地方,我们使用 core.ClosedJaxpr
类,它包含一个 Jaxpr
和对应于 Jaxpr.constvars
的 consts
。
ClosedJaxpr
有几个问题:
ClosedJaxpr
中consts
的降低会导致内联stablehlo.constant
,并带有上述所有问题。Jaxpr
和ClosedJaxpr
在 JAX 中被广泛使用,通常使用通用名称jaxpr
,因此很难区分我们拥有的是哪种Jaxpr
。我们已开始添加类型声明,但在某些地方的代码使用isinstance
条件语句来同时处理两者。由于 Jaxpr 和 ClosedJaxpr 有时用作缓存键,并且它们是通过
id
进行哈希计算的,因此我们希望记忆化它们的构造。例如,函数 pe.closed_jaxpr 记忆化了ClosedJaxpr
的构造,但仅限于consts
为空的情况。这是因为有时consts
是不可哈希的。处理 ClosedJaxpr 中的常量需要一些额外的注意。例如,在 Mosaic 降低的某些地方,我们尚未实现对带有非空常量的 ClosedJaxpr 的处理(例如 此处)。
当我们把已闭合常量转换为输入时,在转换过程中必须小心处理这些辅助输入。