jax.make_mesh#
- jax.make_mesh(axis_shapes, axis_names, *, devices=None, axis_types=None)[源代码]#
创建具有指定形状和轴名称的高效网格。
此函数尝试自动计算从一组逻辑轴到物理网格的良好映射。例如,在具有 8 个设备的 TPU v3 上
>>> mesh = jax.make_mesh((8,), ('x')) >>> [d.id for d in mesh.devices.flat] [0, 1, 2, 3, 6, 7, 4, 5]
上面的排序考虑了 TPU v3 的物理拓扑结构。它将设备排序成环形,这在 TPU v3 上产生高效的 all-reduce 操作。
现在,让我们看另一个具有 16 个 TPU v3 设备的示例
>>> mesh = jax.make_mesh((2, 8), ('x', 'y')) >>> [d.id for d in mesh.devices.flat] [0, 1, 2, 3, 6, 7, 4, 5, 8, 9, 10, 11, 14, 15, 12, 13] >>> mesh = jax.make_mesh((4, 4), ('x', 'y')) >>> [d.id for d in mesh.devices.flat] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
如您所见,逻辑轴 (axis_shapes) 影响设备的排序。
如果您想使用它提供的额外参数,如 contiguous_submeshes 和 allow_split_physical_axes,您可以使用 jax.experimental.mesh_utils.create_device_mesh。
- 参数:
- 返回:
一个 jax.sharding.Mesh 对象。
- 返回类型:
mesh_lib.Mesh