迁移到新的 jax.pmap
#
发生了什么?#
从 JAX 0.8.0 开始,jax.pmap
的默认实现将基于 jax.jit
和 jax.shard_map
。新实现**不**是原始实现的完美替代品,本文档为遇到麻烦的用户提供了指导。
此更改使 jax.pmap
与 JAX 分片(shardings)良好集成,并简化了实现。
救命!立即修复!#
重要提示:此选项不是永久性修复。在 2026 年 1 月 15 日之前,可以通过以下任一方式暂时使用旧版本的 jax.pmap
:
将 shell 环境变量
JAX_PMAP_SHMAP_MERGE
设置为类似 false 的值(例如,0);如果你的代码使用
absl-py
解析标志,则将布尔标志--jax_pmap_shmap_merge
设置为类似 false 的值。在主文件或调用
jax.pmap
之前的任何位置使用此语句import jax jax.config.update("jax_pmap_shmap_merge", False)
注意:请提交一个带有可复现示例的 bug 报告,并 tag @danielsuo,以便我们能在新的 jax.pmap
下尽快解决问题。
如何修复我的代码以适应新的 jax.pmap
?#
以下是我们正在收集的常见错误和修复建议。这比设置 jax_pmap_shmap_merge=False
要麻烦,但更具长期解决方案的性质。但是,我们仍然建议将新代码或重要代码迁移到 jax.shard_map
。
ValueError: Received incompatible devices ...
#
示例#
ValueError: Received incompatible devices for jitted computation. Got argument a
of allclose with shape float32[100] and device ids [0] on platform TPU and
argument b of allclose with shape float32[100] and device ids [0, 1] on platform
TPU
这可能如何发生#
正如
jax.jit
和jax.shard_map
的行为那样,jax.pmap
不再静默地重分片(reshard)输入。因此,如果输入的重分片方式与你的jax.pmap
预期不符,将会引发错误。
如何修复#
使用合适的
jax.NamedSharding
调用jax.device_put
来显式重分片任何有问题的输入。或者,使用合适的
in_axes
、backend
和/或devices
关键字重新定义你的jax.pmap
,以确保jax.pmap
的 mesh 和预期的输入分片与你的操作数匹配。
ValueError: The context mesh ... should match the mesh passed to shard_map
#
示例#
ValueError: The context mesh AbstractMesh('x': 1, axis_types=(Manual,),
device_kind=TPU v3, num_cores=1) should match the mesh passed to shard_map
Mesh('y': 4, axis_types=(Auto,))
这可能如何发生#
当嵌套多个
jax.pmap
时,可能会出现此错误。由于jax.pmap
API 对内部的jax.pmap
调用一无所知,因此也不知道内部 mesh 轴,所以这种行为不再受支持。
如何修复#
迁移到
jax.shard_map
。单个jax.shard_map
可以沿着输入的多个轴并行化,并且每个轴都分配给设备 mesh 的相关轴。或者,你可以嵌套
jax.shard_map
调用,或者使用jax.smap
,它使得一次一个 mesh 轴地进入手动并行模式变得更容易。这种方法极大地简化了嵌套并行。
性能影响#
主机本地数组到全局数组的往返转换#
在多进程 JAX 程序中(即 jax.process_count() > 1
),数组可能不是完全可寻址的(即“主机本地”),因此新的 jax.pmap
会将主机本地数组重分片为全局数组,然后再传递给 jax.jit
的 jax.shard_map
,并在返回给用户代码时再次转换回主机本地数组。
这种往返转换无法避免,因此如果性能损失过大,我们建议将代码迁移到 jax.shard_map
。
int
数组索引#
使用 int 为分片数组建立索引(例如 arr[0]
)现在可能会执行秩(rank)归约计算。根据你的用例,可能存在变通方法。
在典型的训练循环中,我们可能会使用一个
jax.pmap
ed 的更新函数来操作/携带训练状态,并从第一个jax.pmap
ed 设备获取结果指标用于日志记录。在这种情况下,可以使用None
作为传递给jax.pmap
的相关in_axes
和out_axes
。这允许jax.pmap
处理复制,并返回一个形状合适的、看起来像是来自单个设备的结果,用于记录指标等。更一般地说,你可以通过
arr[0:1]
或arr.addressable_shards[0].data
来获取数据的第一个分片,而无需重塑。请注意,这将有一个前导的(1,)
维度,你的代码需要进行处理。
迁移到 jax.shard_map
#
在许多情况下,用户可以通过在输入上调用 jax.make_array_from_process_local_data
并将其传递给 jax.jit(jax.shard_map)
来从 jax.pmap
迁移到 jax.jit(jax.shard_map)
。虽然从全局数组转换的性能损失仍然存在,但它不再像 jax.pmap
的 jax.shard_map
实现那样在分派路径中,并且通常可以与计算重叠,或者可以不频繁调用(即,在训练循环之前,以及偶尔获取指标时)。