jax.make_mesh#

jax.make_mesh(axis_shapes, axis_names, axis_types=None, *, devices=None)[源码]#

创建一个具有指定形状和轴名称的高效 Mesh。

此函数尝试自动计算从一组逻辑轴到物理 Mesh 的良好映射。例如,在具有 8 个设备的 TPU v3 上

>>> mesh = jax.make_mesh((8,), ('x'))  
>>> [d.id for d in mesh.devices.flat]  
[0, 1, 2, 3, 6, 7, 4, 5]

上述排序考虑了 TPU v3 的物理拓扑。它将设备排序成一个环,这会在 TPU v3 上产生高效的 all-reduces。

现在,让我们来看一个具有 16 个 TPU v3 设备的另一个示例

>>> mesh = jax.make_mesh((2, 8), ('x', 'y'))  
>>> [d.id for d in mesh.devices.flat]  
[0, 1, 2, 3, 6, 7, 4, 5, 8, 9, 10, 11, 14, 15, 12, 13]
>>> mesh = jax.make_mesh((4, 4), ('x', 'y'))  
>>> [d.id for d in mesh.devices.flat]  
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

正如您所见,逻辑轴(axis_shapes)会影响设备的排序。

如果您想使用 jax.experimental.mesh_utils.create_device_mesh 提供的额外参数,例如 contiguous_submeshesallow_split_physical_axes,可以使用它。

参数:
  • axis_shapes (Sequence[int]) – Mesh 的形状。例如,axis_shape=(4, 2)

  • axis_names (Sequence[str]) – Mesh 轴的名称。例如,axis_names=(‘x’, ‘y’)

  • axis_types (tuple[mesh_lib.AxisType, ...] | None) – 与 axis_names 对应的 jax.sharding.AxisType 条目的可选元组。有关更多信息,请参阅 显式分片

  • devices (Sequence[xc.Device] | None) – 可选的仅关键字参数,允许您指定要创建 Mesh 的设备。

返回:

一个 jax.sharding.Mesh 对象。

返回类型:

mesh_lib.Mesh