迁移到新的 jax.pmap#

发生了什么?#

从 JAX 0.8.0 开始,jax.pmap 的默认实现将基于 jax.jitjax.shard_map。新实现**不**是原始实现的完美替代品,本文档为遇到麻烦的用户提供了指导。

此更改使 jax.pmap 与 JAX 分片(shardings)良好集成,并简化了实现。

救命!立即修复!#

重要提示:此选项不是永久性修复。在 2026 年 1 月 15 日之前,可以通过以下任一方式暂时使用旧版本的 jax.pmap

  • 将 shell 环境变量 JAX_PMAP_SHMAP_MERGE 设置为类似 false 的值(例如,0);

  • 如果你的代码使用 absl-py 解析标志,则将布尔标志 --jax_pmap_shmap_merge 设置为类似 false 的值。

  • 在主文件或调用 jax.pmap 之前的任何位置使用此语句

    import jax
    jax.config.update("jax_pmap_shmap_merge", False)
    

注意:请提交一个带有可复现示例的 bug 报告,并 tag @danielsuo,以便我们能在新的 jax.pmap 下尽快解决问题。

如何修复我的代码以适应新的 jax.pmap#

以下是我们正在收集的常见错误和修复建议。这比设置 jax_pmap_shmap_merge=False 要麻烦,但更具长期解决方案的性质。但是,我们仍然建议将新代码或重要代码迁移到 jax.shard_map

ValueError: Received incompatible devices ...#

示例#

ValueError: Received incompatible devices for jitted computation. Got argument a
of allclose with shape float32[100] and device ids [0] on platform TPU and
argument b of allclose with shape float32[100] and device ids [0, 1] on platform
TPU

这可能如何发生#

  • 正如 jax.jitjax.shard_map 的行为那样,jax.pmap 不再静默地重分片(reshard)输入。因此,如果输入的重分片方式与你的 jax.pmap 预期不符,将会引发错误。

如何修复#

  • 使用合适的 jax.NamedSharding 调用 jax.device_put 来显式重分片任何有问题的输入。

  • 或者,使用合适的 in_axesbackend 和/或 devices 关键字重新定义你的 jax.pmap,以确保 jax.pmap 的 mesh 和预期的输入分片与你的操作数匹配。

ValueError: The context mesh ... should match the mesh passed to shard_map#

示例#

ValueError: The context mesh AbstractMesh('x': 1, axis_types=(Manual,),
device_kind=TPU v3, num_cores=1) should match the mesh passed to shard_map
Mesh('y': 4, axis_types=(Auto,))

这可能如何发生#

  • 当嵌套多个 jax.pmap 时,可能会出现此错误。由于 jax.pmap API 对内部的 jax.pmap 调用一无所知,因此也不知道内部 mesh 轴,所以这种行为不再受支持。

如何修复#

  • 迁移到 jax.shard_map。单个 jax.shard_map 可以沿着输入的多个轴并行化,并且每个轴都分配给设备 mesh 的相关轴。

  • 或者,你可以嵌套 jax.shard_map 调用,或者使用 jax.smap,它使得一次一个 mesh 轴地进入手动并行模式变得更容易。这种方法极大地简化了嵌套并行。

性能影响#

主机本地数组到全局数组的往返转换#

在多进程 JAX 程序中(即 jax.process_count() > 1),数组可能不是完全可寻址的(即“主机本地”),因此新的 jax.pmap 会将主机本地数组重分片为全局数组,然后再传递给 jax.jitjax.shard_map,并在返回给用户代码时再次转换回主机本地数组。

这种往返转换无法避免,因此如果性能损失过大,我们建议将代码迁移到 jax.shard_map

int 数组索引#

使用 int 为分片数组建立索引(例如 arr[0])现在可能会执行秩(rank)归约计算。根据你的用例,可能存在变通方法。

  1. 在典型的训练循环中,我们可能会使用一个 jax.pmaped 的更新函数来操作/携带训练状态,并从第一个 jax.pmaped 设备获取结果指标用于日志记录。在这种情况下,可以使用 None 作为传递给 jax.pmap 的相关 in_axesout_axes。这允许 jax.pmap 处理复制,并返回一个形状合适的、看起来像是来自单个设备的结果,用于记录指标等。

  2. 更一般地说,你可以通过 arr[0:1]arr.addressable_shards[0].data 来获取数据的第一个分片,而无需重塑。请注意,这将有一个前导的 (1,) 维度,你的代码需要进行处理。

迁移到 jax.shard_map#

在许多情况下,用户可以通过在输入上调用 jax.make_array_from_process_local_data 并将其传递给 jax.jit(jax.shard_map) 来从 jax.pmap 迁移到 jax.jit(jax.shard_map)。虽然从全局数组转换的性能损失仍然存在,但它不再像 jax.pmapjax.shard_map 实现那样在分派路径中,并且通常可以与计算重叠,或者可以不频繁调用(即,在训练循环之前,以及偶尔获取指标时)。