导出和序列化已分阶段计算#

“JIT” 编译(提前进行降低和编译)API 生成的对象可用于调试,或在同一进程中进行编译和执行。有时您希望序列化一个已降低的 JAX 函数,以便在单独的进程中,或许稍后进行编译和执行。这将允许您

  • 在另一个进程或机器上编译和执行该函数,而无需访问 JAX 程序,也无需重复分阶段和降低过程,例如在推理系统中。

  • 在没有目标加速器访问权限的机器上跟踪和降低一个函数,而您稍后将在此机器上编译和执行该函数。

  • 归档 JAX 函数的快照,例如,以便以后能够重现您的结果。注意: 查看此用例的兼容性保证

有关更多详细信息,请参阅 jax.export API 参考。

下面是一个示例

>>> import re
>>> import numpy as np
>>> import jax
>>> from jax import export

>>> def f(x): return 2 * x * x


>>> exported: export.Exported = export.export(jax.jit(f))(
...    jax.ShapeDtypeStruct((), np.float32))

>>> # You can inspect the Exported object
>>> exported.fun_name
'f'

>>> exported.in_avals
(ShapedArray(float32[]),)

>>> print(re.search(r".*@main.*", exported.mlir_module()).group(0))
  func.func public @main(%arg0: tensor<f32> loc("x")) -> (tensor<f32> {jax.result_info = "result"}) {

>>> # And you can serialize the Exported to a bytearray.
>>> serialized: bytearray = exported.serialize()

>>> # The serialized function can later be rehydrated and called from
>>> # another JAX computation, possibly in another process.
>>> rehydrated_exp: export.Exported = export.deserialize(serialized)
>>> rehydrated_exp.in_avals
(ShapedArray(float32[]),)

>>> def callee(y):
...  return 3. * rehydrated_exp.call(y * 4.)

>>> callee(1.)
Array(96., dtype=float32)

序列化分为两个阶段

  1. 导出以生成一个 jax.export.Exported 对象,该对象包含已降低函数的 StableHLO 以及从另一个 JAX 函数调用它所需的元数据。我们计划添加代码以从 TensorFlow 生成 Exported 对象,并从 TensorFlow 和 PyTorch 使用 Exported 对象。

  2. 使用 flatbuffers 格式将实际序列化为字节数组。有关序列化到 TensorFlow 图以实现与 TensorFlow 互操作的替代方案,请参阅与 TensorFlow 的互操作

支持反向模式 AD#

序列化可以选择支持高阶反向模式 AD。这是通过将原始函数的 jax.vjp() 与原始函数一起序列化来完成的,直到用户指定的阶数(默认为 0,表示反序列化的函数无法微分)。

>>> import jax
>>> from jax import export
>>> from typing import Callable

>>> def f(x): return 7 * x * x * x

>>> # Serialize 3 levels of VJP along with the primal function
>>> blob: bytearray = export.export(jax.jit(f))(1.).serialize(vjp_order=3)
>>> rehydrated_f: Callable = export.deserialize(blob).call

>>> rehydrated_f(0.1)  # 7 * 0.1^3
Array(0.007, dtype=float32)

>>> jax.grad(rehydrated_f)(0.1)  # 7*3 * 0.1^2
Array(0.21000001, dtype=float32)

>>> jax.grad(jax.grad(rehydrated_f))(0.1)  # 7*3*2 * 0.1
Array(4.2, dtype=float32)

>>> jax.grad(jax.grad(jax.grad(rehydrated_f)))(0.1)  # 7*3*2
Array(42., dtype=float32)

>>> jax.grad(jax.grad(jax.grad(jax.grad(rehydrated_f))))(0.1)  
Traceback (most recent call last):
ValueError: No VJP is available

请注意,VJP 函数是在序列化过程中惰性计算的,此时 JAX 程序仍然可用。这意味着它尊重 JAX VJP 的所有功能,例如 jax.custom_vjp()jax.remat()

请注意,反序列化的函数不支持任何其他转换,例如前向模式 AD (jvp) 或 jax.vmap()

兼容性保证#

您不应将仅从降低获得的原始 StableHLO(jax.jit(f).lower(1.).compiler_ir())用于归档和在另一个进程中编译,原因如下。

首先,编译可能使用不同版本的编译器,支持不同版本的 StableHLO。 jax.export 模块通过使用 StableHLO 的可移植工件功能来处理此问题,以应对 StableHLO 操作集可能发生的演变。

自定义调用的兼容性保证#

其次,原始 StableHLO 可能包含引用 C++ 函数的自定义调用。JAX 使用自定义调用来降低少量原始函数,例如线性代数原始函数、分片注解或 Pallas 内核。这些不在 StableHLO 的兼容性保证范围内。这些函数的 C++ 实现很少更改,但可能会更改。

jax.export 做出以下导出兼容性保证:JAX 导出工件可以被一个比用于导出的 JAX 版本*****更新** 6 个月**的编译器和 JAX 运行时系统编译和执行(我们称 JAX 导出提供 **6 个月后向兼容性**)。如果您想将导出的工件归档以供以后编译和执行,这很有用。

  • **更新** 6 个月**的编译器和 JAX 运行时系统编译和执行(我们称 JAX 导出提供 **6 个月后向兼容性**)。如果您想将导出的工件归档以供以后编译和执行,这很有用。

  • **旧** 3 个星期**的 JAX 版本(我们称 JAX 导出提供 **3 周前向兼容性**)。如果您想使用比导出更早构建和部署的消费者编译和运行导出的工件,例如在导出时已部署的推理系统,这很有用。

(特定的兼容性窗口长度与 JAX 为 jax2tf承诺的相同,并且基于TensorFlow 兼容性。术语“后向兼容性”来自消费者(例如推理系统)的角度。)

重要的是 **导出和消耗组件的构建时间**,而不是导出和编译发生的时间。对于外部 JAX 用户,可以运行不同版本的 JAX 和 jaxlib;重要的是 jaxlib 释放的构建时间。

为了降低不兼容的可能性,内部 JAX 用户应

  • 尽可能频繁地重新构建和重新部署消费者系统.

外部用户应

  • 尽可能使用**相同版本的 jaxlib** 运行导出和消费者系统,并且

  • 为归档**使用最新发布的 jaxlib 版本**进行导出。

如果您绕过 jax.export API 来获取 StableHLO 代码,则兼容性保证不适用。

为了确保前向兼容性,当我们更改 JAX 降低规则以使用新的自定义调用目标时,JAX 将在 3 周内不使用新目标。要使用最新的降低规则,您可以传递 --jax_export_ignore_forward_compatibility=1 配置标志或 JAX_EXPORT_IGNORE_FORWARD_COMPATIBILITY=1 环境变量。

只有一部分自定义调用被保证是稳定的并具有兼容性保证(查看列表)。我们不断地将更多自定义调用目标添加到允许列表中,并进行后向兼容性测试。如果您尝试序列化调用其他自定义调用目标的代码,则在导出时会收到错误。

如果您想为特定自定义调用(例如,目标为 my_target)禁用此安全检查,您可以将 export.DisabledSafetyCheck.custom_call("my_target") 添加到 export 方法的 disabled_checks 参数中,如下例所示

>>> import jax
>>> from jax import export
>>> from jax import lax
>>> from jax._src import core
>>> from jax._src.interpreters import mlir
>>> # Define a new primitive backed by a custom call
>>> new_prim = core.Primitive("new_prim")
>>> _ = new_prim.def_abstract_eval(lambda x: x)
>>> _ = mlir.register_lowering(new_prim, lambda ctx, o: mlir.custom_call("my_new_prim", operands=[o], result_types=[o.type]).results)
>>> print(jax.jit(new_prim.bind).lower(1.).compiler_ir())
module @jit_bind attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<f32>) -> (tensor<f32> {jax.result_info = "result"}) {
    %0 = stablehlo.custom_call @my_new_prim(%arg0) {api_version = 2 : i32, backend_config = ""} : (tensor<f32>) -> tensor<f32>
    return %0 : tensor<f32>
  }
}

>>> # If we try to export, we get an error
>>> export.export(jax.jit(new_prim.bind))(1.)  
Traceback (most recent call last):
ValueError: Cannot serialize code with custom calls whose targets have no compatibility guarantees: my_new_bind

>>> # We can avoid the error if we pass a `DisabledSafetyCheck.custom_call`
>>> exp = export.export(
...    jax.jit(new_prim.bind),
...    disabled_checks=[export.DisabledSafetyCheck.custom_call("my_new_prim")])(1.)

有关确保兼容性的开发者信息,请参阅 确保前向和后向兼容性

跨平台和多平台导出#

对于少数 JAX 原始函数,JAX 降低是平台特定的。默认情况下,代码会为导出机上的加速器进行降低和导出

>>> from jax import export
>>> export.default_export_platform()
'cpu'

有一个安全检查,当尝试在没有为代码导出的加速器的机器上编译 Exported 对象时会引发错误。

您可以显式指定代码应该导出到哪些平台。这允许您指定一个不同于导出时可用的加速器,甚至允许您指定多平台导出以获得一个可以跨多个平台编译和执行的 Exported 对象。

>>> import jax
>>> from jax import export
>>> from jax import lax

>>> # You can specify the export platform, e.g., `tpu`, `cpu`, `cuda`, `rocm`
>>> # even if the current machine does not have that accelerator.
>>> exp = export.export(jax.jit(lax.cos), platforms=['tpu'])(1.)

>>> # But you will get an error if you try to compile `exp`
>>> # on a machine that does not have TPUs.
>>> exp.call(1.)  
Traceback (most recent call last):
ValueError: Function 'cos' was lowered for platforms '('tpu',)' but it is used on '('cpu',)'.

>>> # We can avoid the error if we pass a `DisabledSafetyCheck.platform`
>>> # parameter to `export`, e.g., because you have reasons to believe
>>> # that the code lowered will run adequately on the current
>>> # compilation platform (which is the case for `cos` in this
>>> # example):
>>> exp_unsafe = export.export(jax.jit(lax.cos),
...    platforms=['tpu'],
...    disabled_checks=[export.DisabledSafetyCheck.platform()])(1.)

>>> exp_unsafe.call(1.)
Array(0.5403023, dtype=float32, weak_type=True)

# and similarly with multi-platform lowering
>>> exp_multi = export.export(jax.jit(lax.cos),
...    platforms=['tpu', 'cpu', 'cuda'])(1.)
>>> exp_multi.call(1.)
Array(0.5403023, dtype=float32, weak_type=True)

对于多平台导出,StableHLO 将包含多个降低结果,但仅针对需要它的原始函数,因此生成的模块大小应仅比具有单平台导出的模块大小略大。作为极端情况,当序列化不包含任何平台特定降低的原始函数的模块时,您将获得与单平台导出相同的 StableHLO。

>>> import jax
>>> from jax import export
>>> from jax import lax
>>> # A largish function
>>> def f(x):
...   for i in range(1000):
...     x = jnp.cos(x)
...   return x

>>> exp_single = export.export(jax.jit(f))(1.)
>>> len(exp_single.mlir_module_serialized)  
9220

>>> exp_multi = export.export(jax.jit(f),
...                           platforms=["cpu", "tpu", "cuda"])(1.)
>>> len(exp_multi.mlir_module_serialized)  
9282

形状多态导出#

在 JIT 模式下使用时,JAX 会为输入形状的每种组合分别跟踪和降低函数。导出时,在某些情况下可以使用维度变量来表示某些输入维度,以获得可用于多种输入形状组合的导出工件。

请参阅 形状多态 文档。

设备多态导出#

导出的工件可能包含输入、输出和某些中间结果的分片注解,但这些注解不直接引用导出时的实际物理设备。相反,分片注解引用逻辑设备。这意味着您可以在导出时使用的不同物理设备上编译和运行导出的工件。

实现设备多态导出的最清晰方法是使用由 jax.sharding.AbstractMesh 构建的分片,该分片仅包含网格形状和轴名称。但是,如果您使用为具有具体设备的网格构建的分片,也可以获得相同的结果,因为在跟踪和降低过程中会忽略网格中的实际设备。

>>> import jax
>>> from jax import export
>>> from jax.sharding import AbstractMesh, Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P
>>>
>>> # Use an AbstractMesh for exporting
>>> export_mesh = AbstractMesh((4,), ("a",))

>>> def f(x):
...   return x.T

>>> exp = export.export(jax.jit(f))(
...    jax.ShapeDtypeStruct((32,), dtype=np.int32,
...                         sharding=NamedSharding(export_mesh, P("a"))))

>>> # `exp` knows for how many devices it was exported.
>>> exp.nr_devices
4

>>> # and it knows the shardings for the inputs. These will be applied
>>> # when the exported is called.
>>> exp.in_shardings_hlo
({devices=[4]<=[4]},)

>>> # You can also use a concrete set of devices for exporting
>>> concrete_devices = jax.local_devices()[:4]
>>> concrete_mesh = Mesh(concrete_devices, ("a",))
>>> exp2 = export.export(jax.jit(f))(
...    jax.ShapeDtypeStruct((32,), dtype=np.int32,
...                         sharding=NamedSharding(concrete_mesh, P("a"))))

>>> # You can expect the same results
>>> assert exp.in_shardings_hlo == exp2.in_shardings_hlo

>>> # When you call an Exported, you must use a concrete set of devices
>>> arg = jnp.arange(8 * 4)
>>> res1 = exp.call(jax.device_put(arg,
...                                NamedSharding(concrete_mesh, P("a"))))

>>> # Check out the first 2 shards of the result
>>> [f"device={s.device} index={s.index}" for s in res1.addressable_shards[:2]]
['device=TFRT_CPU_0 index=(slice(0, 8, None),)',
 'device=TFRT_CPU_1 index=(slice(8, 16, None),)']

>>> # We can call `exp` with some other 4 devices and another
>>> # mesh with a different shape, as long as the number of devices is
>>> # the same.
>>> other_mesh = Mesh(np.array(jax.local_devices()[2:6]).reshape((2, 2)), ("b", "c"))
>>> res2 = exp.call(jax.device_put(arg,
...                                NamedSharding(other_mesh, P("b"))))

>>> # Check out the first 2 shards of the result. Notice that the output is
>>> # sharded similarly; this means that the input was resharded according to the
>>> # exp.in_shardings.
>>> [f"device={s.device} index={s.index}" for s in res2.addressable_shards[:2]]
['device=TFRT_CPU_2 index=(slice(0, 8, None),)',
 'device=TFRT_CPU_3 index=(slice(8, 16, None),)']

尝试使用与导出时不同的设备数量调用导出的工件是错误的。

>>> import jax
>>> from jax import export
>>> from jax.sharding import Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P

>>> export_devices = jax.local_devices()
>>> export_mesh = Mesh(np.array(export_devices), ("a",))
>>> def f(x):
...   return x.T

>>> exp = export.export(jax.jit(f))(
...    jax.ShapeDtypeStruct((4 * len(export_devices),), dtype=np.int32,
...                         sharding=NamedSharding(export_mesh, P("a"))))

>>> arg = jnp.arange(4 * len(export_devices))
>>> exp.call(arg)  
Traceback (most recent call last):
ValueError: Exported module f was lowered for 8 devices and is called in a context with 1 devices. This is disallowed because: the module was lowered for more than 1 device.

有一些辅助函数可以在调用站点使用新网格为调用导出的工件分片。

>>> import jax
>>> from jax import export
>>> from jax.sharding import Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P

>>> export_devices = jax.local_devices()
>>> export_mesh = Mesh(np.array(export_devices), ("a",))
>>> def f(x):
...   return x.T


>>> exp = export.export(jax.jit(f))(
...    jax.ShapeDtypeStruct((4 * len(export_devices),), dtype=np.int32,
...                         sharding=NamedSharding(export_mesh, P("a"))))

>>> # Prepare the mesh for calling `exp`.
>>> calling_mesh = Mesh(np.array(export_devices[::-1]), ("b",))

>>> # Shard the arg according to what `exp` expects.
>>> arg = jnp.arange(4 * len(export_devices))
>>> sharded_arg = jax.device_put(arg, exp.in_shardings_jax(calling_mesh)[0])
>>> res = exp.call(sharded_arg)

作为一项特殊功能,如果一个函数为 1 个设备导出,并且不包含任何分片注解,那么它可以被具有相同形状但分片到多个设备的参数调用,并且编译器将适当地分片该函数。

```python
>>> import jax
>>> from jax import export
>>> from jax.sharding import Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P

>>> def f(x):
...   return jnp.cos(x)

>>> arg = jnp.arange(4)
>>> exp = export.export(jax.jit(f))(arg)
>>> exp.in_avals
(ShapedArray(int32[4]),)

>>> exp.nr_devices
1

>>> # Prepare the mesh for calling `exp`.
>>> calling_mesh = Mesh(jax.local_devices()[:4], ("b",))

>>> # Shard the arg according to what `exp` expects.
>>> sharded_arg = jax.device_put(arg,
...                              NamedSharding(calling_mesh, P("b")))
>>> res = exp.call(sharded_arg)

调用约定版本#

JAX 导出支持随着时间的推移而演进,例如为了支持效应。为了支持兼容性(请参阅兼容性保证),我们为每个 Exported 对象维护一个调用约定版本。截至 2024 年 6 月,所有使用版本 9(最新版本,请参阅所有调用约定版本)导出的函数

>>> from jax import export
>>> exp: export.Exported = export.export(jnp.cos)(1.)
>>> exp.calling_convention_version
10

在任何给定时间,导出 API 可能支持一系列调用约定版本。您可以使用 --jax_export_calling_convention_version 标志或 JAX_EXPORT_CALLING_CONVENTION_VERSION 环境变量来控制使用哪个调用约定版本。

>>> from jax import export
>>> (export.minimum_supported_calling_convention_version, export.maximum_supported_calling_convention_version)
(9, 10)

>>> from jax._src import config
>>> with config.jax_export_calling_convention_version(10):
...  exp = export.export(jnp.cos)(1.)
...  exp.calling_convention_version
10

我们保留删除支持生成或消耗 6 个月以上调用约定版本的权利。

模块调用约定#

Exported.mlir_module 有一个 main 函数,如果模块支持多个平台(len(platforms) > 1),它会接受一个可选的第一个平台索引参数,然后是与有序效应对应的 token 参数,然后是保留的数组参数(对应于 module_kept_var_idxin_avals)。平台索引是一个 i32 或 i64 标量,用于编码当前编译平台在 platforms 序列中的索引。

内部函数使用不同的调用约定:一个可选的平台索引参数,可选的维度变量参数(i32 或 i64 类型的标量张量),然后是可选的 token 参数(当存在有序效应时),然后是常规数组参数。维度参数对应于 args_avals 中出现的维度变量,按其名称的排序顺序。

考虑一个具有类型为 f32[w, 2 * h] 的数组参数的函数的降低,其中 wh 是两个维度变量。假设我们使用多平台降低,并且有一个有序效应。 main 函数将如下所示

      func public main(
            platform_index: i32 {jax.global_constant="_platform_index"},
            token_in: token,
            arg: f32[?, ?]) {
         arg_w = hlo.get_dimension_size(arg, 0)
         dim1 = hlo.get_dimension_size(arg, 1)
         arg_h = hlo.floordiv(dim1, 2)
         call _check_shape_assertions(arg)  # See below
         token = new_token()
         token_out, res = call _wrapped_jax_export_main(platform_index,
                                                        arg_h,
                                                        arg_w,
                                                        token_in,
                                                        arg)
         return token_out, res
      }

实际计算在 _wrapped_jax_export_main 中,它还接受维度变量 hw 的值。

_wrapped_jax_export_main 的签名是

      func private _wrapped_jax_export_main(
          platform_index: i32 {jax.global_constant="_platform_index"},
          arg_h: i32 {jax.global_constant="h"},
          arg_w: i32 {jax.global_constant="w"},
          arg_token: stablehlo.token {jax.token=True},
          arg: f32[?, ?]) -> (stablehlo.token, ...)

在调用约定版本 9 之前,效应的调用约定有所不同:main 函数不接受或返回 token。相反,该函数创建类型为 i1[0] 的虚拟 token 并将其传递给 _wrapped_jax_export_main_wrapped_jax_export_main 接受类型为 i1[0] 的虚拟 token,并将在内部创建实际 token 以传递给内部函数。内部函数使用实际 token(在调用约定版本 9 之前和之后)。

同样,从调用约定版本 9 开始,包含平台索引或维度变量值的函数参数具有一个 jax.config.use_shardy_partitioner 字符串属性,其值是全局常量的名称,即 _platform_index 或维度变量名称。如果全局常量名称未知,则可以为空。某些全局常量计算使用内部函数,例如 floor_divide。这些函数的参数具有jax.global_constant 属性,意味着函数的返回值也是一个全局常量。

请注意,main 包含对 _check_shape_assertions 的调用。JAX 跟踪假设 arg.shape[1] 是偶数,并且 wh 的值都大于等于 1。我们在调用模块时必须检查这些约束。我们使用一个特殊的自定义调用 @shape_assertion,它接受一个布尔值作为第一个操作数,一个可能包含格式说明符 {0}{1}、... 的字符串 error_message 属性,以及一个与格式说明符对应的变长整数标量操作数。

       func private _check_shape_assertions(arg: f32[?, ?]) {
         # Check that w is >= 1
         arg_w = hlo.get_dimension_size(arg, 0)
         custom_call @shape_assertion(arg_w >= 1, arg_w,
            error_message="Dimension variable 'w' must have integer value >= 1. Found {0}")
         # Check that dim1 is even
         dim1 = hlo.get_dimension_size(arg, 1)
         custom_call @shape_assertion(dim1 % 2 == 0, dim1 % 2,
            error_message="Division had remainder {0} when computing the value of 'h')
         # Check that h >= 1
         arg_h = hlo.floordiv(dim1, 2)
         custom_call @shape_assertion(arg_h >= 1, arg_h,
            error_message=""Dimension variable 'h' must have integer value >= 1. Found {0}")

调用约定版本#

我们在此列出调用约定版本号的历史记录

  • 版本 1 使用 MHLO & CHLO 序列化代码,不再支持。

  • 版本 2 支持 StableHLO & CHLO。使用自 2022 年 10 月起。不再支持。

  • 版本 3 支持平台检查和多平台。使用自 2023 年 2 月起。不再支持。

  • 版本 4 支持具有兼容性保证的 StableHLO。这是 JAX 原生序列化启动时的最早版本。在 JAX 中使用自 2023 年 3 月 15 日(cl/516885716)。从 2023 年 3 月 28 日开始,我们停止使用 dim_args_spec(cl/520033493)。此版本支持于 2023 年 10 月 17 日(cl/573858283)被删除。

  • 版本 5 增加了对 call_tf_graph 的支持。这目前用于一些专门的用例。在 JAX 中使用自 2023 年 5 月 3 日(cl/529106145)。

  • 版本 6 增加了对 disabled_checks 属性的支持。此版本要求 platforms 属性非空。自 2023 年 6 月 7 日起受 XlaCallModule 支持,并自 2023 年 6 月 13 日起(JAX 0.4.13)在 JAX 中可用。

  • 版本 7 支持 stablehlo.shape_assertion 操作以及 disabled_checks 中指定的 shape_assertions。请参阅形状多态性下的错误。自 2023 年 7 月 12 日起受 XlaCallModule 支持(cl/547482522),自 2023 年 7 月 20 日起(JAX 0.4.14)在 JAX 序列化中可用,并且自 2023 年 8 月 12 日起(JAX 0.4.15)成为默认值。

  • 版本 8 增加了对 jax.uses_shape_polymorphism 模块属性的支持,并且仅当该属性存在时才启用形状精炼传递。自 2023 年 7 月 21 日起受 XlaCallModule 支持(cl/549973693),自 2023 年 7 月 26 日起(JAX 0.4.14)在 JAX 中可用,并且自 2023 年 10 月 21 日起(JAX 0.4.20)成为默认值。

  • 版本 9 增加了对效应的支持。有关精确的调用约定,请参阅 export.Exported 的文档字符串。在此调用约定版本中,我们还使用 jax.global_constant 属性标记平台索引和维度变量参数。自 2023 年 10 月 27 日起受 XlaCallModule 支持,自 2023 年 10 月 20 日起(JAX 0.4.20)在 JAX 中可用,并且自 2024 年 2 月 1 日起(JAX 0.4.24)成为默认值。截至 2024 年 3 月 27 日,这是唯一支持的版本。

  • 版本 10 将 jax.config.use_shardy_partitioner 的值传播到 XlaCallModule。自 2025 年 5 月 20 日起受 XlaCallModule 支持,并自 2025 年 7 月 14 日起(JAX 0.7.0)成为 JAX 中的默认值。

开发者文档#

调试#

您可以使用 OSS 和 Google 中略有不同的标志来记录导出的模块。在 OSS 中,您可以这样做

# Log from python
python tests/export_test.py JaxExportTest.test_basic -v=3
# Or, log from pytest to /tmp/mylog.txt
pytest tests/export_test.py -k test_basic --log-level=3 --log-file=/tmp/mylog.txt

您将看到类似以下的日志行

I0619 10:54:18.978733 8299482112 _export.py:606] Exported JAX function: fun_name=sin version=9 lowering_platforms=('cpu',) disabled_checks=()
I0619 10:54:18.978767 8299482112 _export.py:607] Define JAX_DUMP_IR_TO to dump the module.

如果将环境变量 JAX_DUMP_IR_TO 设置为目录,导出的(以及 JIT 编译的)HLO 模块将保存在那里。

JAX_DUMP_IR_TO=/tmp/export.dumps pytest tests/export_test.py -k test_basic --log-level=3 --log-file=/tmp/mylog.txt
INFO     absl:_export.py:606 Exported JAX function: fun_name=sin version=9 lowering_platforms=('cpu',) disabled_checks=()
INFO     absl:_export.py:607 The module was dumped to jax_ir0_jit_sin_export.mlir.

您将看到导出的模块(命名为 ..._export.mlir)和 JIT 编译的模块(命名为 ..._compile.mlir)。

$ ls -l /tmp/export.dumps/
total 32
-rw-rw-r--@ 1 necula  wheel  2316 Jun 19 11:04 jax_ir0_jit_sin_export.mlir
-rw-rw-r--@ 1 necula  wheel  2279 Jun 19 11:04 jax_ir1_jit_sin_compile.mlir
-rw-rw-r--@ 1 necula  wheel  3377 Jun 19 11:04 jax_ir2_jit_call_exported_compile.mlir
-rw-rw-r--@ 1 necula  wheel  2333 Jun 19 11:04 jax_ir3_jit_my_fun_export.mlir

设置 JAX_DEBUG_LOG_MODULES=jax._src.export 以启用额外的调试日志记录。

确保前向和后向兼容性#

本节讨论 JAX 开发者应如何确保兼容性保证

一个复杂之处在于,外部用户单独安装 JAX 和 jaxlib 包,并且用户经常最终使用比 JAX 更旧版本的 jaxlib。我们观察到自定义调用位于 jaxlib 中,并且只有 jaxlib 对导出工件的消费者是相关的。为了简化流程,我们向外部用户传达的期望是,兼容性窗口是根据 jaxlib 版本定义的,并且由他们负责确保使用新 jaxlib 进行导出,即使 JAX 能够与旧版本一起工作。

因此,我们只关心 jaxlib 版本。当我们发布 jaxlib 版本时,我们可以启动后向兼容性弃用时钟,即使我们不强制它成为最低允许版本。

假设我们需要添加、删除或更改 JAX 降低规则使用的自定义调用目标 T 的语义。以下是可能的时序(针对更改位于 jaxlib 中的自定义调用目标)。

  1. “D - 1” 天,在更改之前。假设活动内部 JAX 版本为 0.4.31(下一个 JAX 和 jaxlib 版本的版本)。JAX 降低规则使用自定义调用 T

  2. “D” 天,我们添加新的自定义调用目标 T_NEW。我们应该创建一个新的自定义调用目标,并在大约 6 个月后清理旧目标,而不是就地更新 T

    • 请参阅实现以下步骤的示例PR #20997

    • 我们添加自定义调用目标 T_NEW

    • 我们将之前使用 T 的 JAX 降低规则更改为使用 T_NEW,条件如下:

    from jax._src import config
    from jax._src.lib import version as jaxlib_version
    
    def my_lowering_rule(ctx: LoweringRuleContext, ...):
      if ctx.is_forward_compat() or jaxlib_version < (0, 4, 31):
        # this is the old lowering, using target T, while we
        # are in forward compatibility mode for T, or we
        # are in OSS and are using an old jaxlib.
        return hlo.custom_call("T", ...)
      else:
        # This is the new lowering, using target T_NEW, for
        # when we use a jaxlib with version `>= (0, 4, 31)`
        # (or when this is internal usage), and also we are
        # in JIT mode.
        return hlo.custom_call("T_NEW", ...)
    
    • 请注意,前向兼容性模式在 JIT 模式下始终为 false,或者如果用户传递了 --jax_export_ignore_forward_compatibility=true

    • 请注意,此时导出仍不会使用 T_NEW

  3. 这可以在上一步之后的任何时间,以及在下一步之前完成:为 T_NEW 添加后向兼容性测试,并将 T_NEW 添加到 _export.py 中的 _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE 列表。

    • 添加后向兼容性测试的说明位于 export_back_compat_test_util.py 的顶部。

    • 示例在PR #29488中。

    • 请注意,如果您在下一步之前执行此操作,导出仍然不会使用 T_NEW 降低,并且您必须在 self.run_one_test 调用周围添加 with config.export_ignore_forward_compatibility(True):。这可以在您实际执行到步骤 4 时删除。

    • 您可能还需要仅为新版本的 jaxlib 启用测试。

  4. “D + 21” 天(前向兼容性窗口结束;可以晚于 21 天):我们从降低代码中删除 forward_compat_mode,因此现在导出将开始使用新的自定义调用目标 T_NEW,只要我们使用的是新 jaxlib

  5. “RELEASE > D” 天(在 D 之后的第一个 JAX 发布日期,当我们发布版本 0.4.31 时):我们为 6 个月的后向兼容性启动时钟。请注意,这仅在 T 是我们已保证稳定性的自定义调用目标之一时才相关,即列在_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE 中。

    • 如果 RELEASE 在前向兼容性窗口 [D, D + 21] 内,并且如果我们使 RELEASE 成为最低允许的 jaxlib 版本,那么我们可以删除 JIT 分支中的 jaxlib_version < (0, 4, 31) 条件。

  6. “RELEASE + 180” 天(后向兼容性窗口结束,可以晚于 180 天):此时,我们必须将最低 jaxlib 提升,以便降低条件 jaxlib_version < (0, 4, 31) 已经被删除,JAX 降低无法生成到 T 的自定义调用。

    • 我们删除旧自定义调用目标 T 的 C++ 实现。

    • 我们还删除 T 的后向兼容性测试。

从 jax.experimental.export 迁移指南#

2024 年 6 月 18 日(JAX 版本 0.4.30),我们弃用了 jax.experimental.export API,转而使用 jax.export API。发生了一些小的更改。

  • jax.experimental.export.export:

    • 旧函数曾允许任何 Python 可调用对象,或 jax.jit 的结果。现在只接受后者。您必须在调用 export 之前手动应用 jax.jit 到要导出的函数。

    • 旧的 lowering_parameters keyword 参数现在命名为 platforms

  • jax.experimental.export.default_lowering_platform() 现在位于 jax.export.default_export_platform()

  • jax.experimental.export.call 现在是 jax.export.Exported 对象的 a 方法。您应该使用 exp.call 而不是 export.call(exp)

  • jax.experimental.export.serialize 现在是 jax.export.Exported 对象的 a 方法。您应该使用 exp.serialize() 而不是 export.serialize(exp)

  • 配置标志 --jax-serialization-version 已弃用。使用 --jax-export-calling-convention-version

  • jax.experimental.export.minimum_supported_serialization_version 的值现在是 jax.export.minimum_supported_calling_convention_version

  • 的字段jax.export.Exported 已重命名

    • uses_shape_polymorphism 现在是 uses_global_constants

    • mlir_module_serialization_version 现在是 calling_convention_version

    • lowering_platforms 现在是 platforms