jax.lax.with_sharding_constraint#

jax.lax.with_sharding_constraint(x, shardings)[source]#

在 jitted 计算中约束 Array 分片的机制

这是对 GSPMD 分片器的一个严格约束,而非提示。有关如何使用此函数的示例,请参阅分布式数组和自动并行化

在 jitted 计算内部,with_sharding_constraint 可以约束中间值为不均匀分片。然而,如果 jitted 计算输出这样的不均匀分片值,它将以完全复制的形式出现,无论给定的分片注解如何。

参数:
  • x – jax.Array 的 PyTree,其分片将被约束

  • shardings – 分片规范的 PyTree。有效值与jax.experimental.pjit()in_shardings参数相同。

返回:

具有指定分片约束的 jax.Array 的 PyTree。

返回类型:

x_with_shardings