Shardy JAX 迁移#

要点总结#

发生了什么?#

Shardy 是一个新的分区系统,由 GDM 模型缩放(PartIR 的作者)和 XLA/CoreML 团队(GSPMD 的作者)共同开发。Shardy 旨在为用户提供更好的可用性和控制,并将逐步取代 GSPMD 和 PartIR。

在 2026 年 3 月完成迁移后,Shardy 将成为 JAX 中唯一的分区器。

在此之前,作为解决任何问题的临时方案,可以禁用 Shardy。如果您遇到任何问题,请提交一个 JAX issue

如何知道 Shardy 是否破坏了我的代码?#

判断 Shardy 是否导致任何问题的最简单方法是禁用 Shardy,看看问题是否消失。请参阅下面的 启用 Shardy 后会出现什么问题?部分。

您可以通过查找 Using Shardy for XLA SPMD propagation in the logs 日志来判断是否启用了 Shardy。

如何暂时禁用 Shardy?#

在 2026 年 3 月之前,可以通过以下方式临时禁用 Shardy:

  • 将 shell 环境变量 JAX_USE_SHARDY_PARTITIONER 设置为类似 false 的值(例如,0);

  • 如果您的代码使用 absl 解析标志,则将布尔标志 jax_use_shardy_partitioner 设置为类似 false 的值;

  • 在您的主文件中或调用 jax.jit 之前的任何位置使用此语句

    import jax
    jax.config.update('jax_use_shardy_partitioner', False)
    

要在启用 Shardy 的情况下调试分区,您可以按如下方式启用 MLIR 转储

--xla_dump_hlo_pass_re=shardy --xla_dump_to=<some_directory>

注意:请尽可能禁用未按预期工作的特定用例,并提交一个包含重现步骤的 bug,以便我们尽快解决它并重新启用 Shardy。

JAX 导出向后兼容性#

默认情况下在 JAX 中启用 Shardy 正在维护 6 个月的向后兼容性保证。这意味着您将能够加载在禁用 Shardy 的情况下导出的模型,至少在 Shardy 为您的模型启用后的 6 个月内。旧的 checkpointed 模型将与 GSPMD 一起运行,只有在重新导出模型时,它才会开始与 Shardy 一起运行。

但是,如果您仍然遇到加载旧检查点的问题,请联系我们或提交一个 bug

注意:不支持启用 Shardy 导出模型,然后禁用 Shardy 加载它,并且将会失败。

我该如何准备 Shardy 在 2026 年 3 月永久启用?#

由于我们将为任何 JAX 导出检查点回退到 GSPMD 6 个月,为了帮助找到任何潜在问题,请使用启用的 Shardy 重新导出您拥有的任何模型。然后您可以看看它是否运行良好,或者是否存在我们需要修复的任何错误。

启用 Shardy 后会出现什么问题?#

性能下降或 OOM(内存溢出)#

虽然 Shardy 改进了现有的分片传播系统(GSPMD 和 PartIR),但由于不同的传播顺序或冲突解决启发式,它有时会输出略有不同的结果。

这不一定意味着 Shardy 做了错误的事情,而是可能程序中没有足够的分片约束,因此传播顺序的微小变化会影响最终结果。它还可以暗示现有的分片约束过度拟合到 GSPMD,并且需要使用 Shardy 进行微调。

因此,启用 Shardy 可能会导致某些模型出现性能下降或 OOM(特别是如果模型已经接近内存容量)。但是,我们已经迁移了 Alphabet 中的许多用例,并且观察到与 GSPMD 相当或更好的性能。

要解决此类问题,用户可以:

  1. 暂时禁用 Shardy 并打开一个包含重现步骤的 bug

  2. 添加额外的分片约束以确保 Shardy 执行所需的操作。

编译失败#

我们已经在许多 JAX 模型上进行了广泛的测试。但是,可能存在我们不支持/处理的某些边缘情况或情况(因为我们不知道我们需要这样做)。

这意味着,虽然很少见,但您可能会遇到编译失败,例如段错误、硬检查、python 值错误等。

在这种情况下,请暂时禁用 Shardy 并打开一个包含重现步骤的 bug

use Shardy 标志的不一致值#

如果 Shardy 在您的代码中的某个位置被禁用,但仍然有路径使用 JAX 标志的默认值,这可能会导致问题。例如,启用 Shardy 导出模型,然后禁用 Shardy 加载它是不受支持的,并且将会失败(另一种方式支持 向后兼容性)。

此类问题的症状可能是 JAX 或 XLA/Shardy 中的错误,或者只是未定义的行为。您可以尝试在 JAX config 中全局禁用 Shardy,看看问题是否消失。

注意:如果需要,请确保始终如一地禁用 Shardy,或者删除对该标志的任何显式修改,以便在整个过程中应用默认值。

使用 JAX jax.experimental.custom_partitioning API 的新方法#

如果您使用此 API,您可能会看到错误

Shardy is used, but sharding propagation callbacks instead of sharding_rule are
provided. Need to provide sharding_rule to migrate to Shardy.

定义 infer_sharding_from_operandspropagate_user_sharding 回调,定义一个 jax.experimental.SdyShardingRule,它指定传播期间维度之间的类似 einsum 的关系。有关如何定义分片规则的更多信息,请参阅 custom_partitioning doc

jax.export 要求所有输入和输出都具有相同的 mesh#

作为 Shardy 迁移的一部分,jax.export 现在要求所有输入/输出分片都位于同一 mesh 上 - 相同的轴名称和大小。