jax.experimental.mesh_utils.create_device_mesh#

jax.experimental.mesh_utils.create_device_mesh(mesh_shape, devices=None, *, contiguous_submeshes=False, allow_split_physical_axes=False)[源代码]#

为 jax.sharding.Mesh 创建高性能设备网格。

参数:
  • mesh_shape (Sequence[int]) – 逻辑网格的形状,按网络密集程度递增排序,例如 [replica, data, mdl],其中 mdl 具有最多的网络通信需求。

  • devices (Sequence[Any] | None) – (可选) 用于构建网格的设备。 默认为 jax.devices()。

  • contiguous_submeshes (bool) – 如果为 True,此函数将尝试创建一个网格,其中每个进程的本地设备形成一个连续的子网格。 如果此函数无法生成合适的网格,将引发 ValueError。 此设置在引入 jax.Array 之前有时是必要的,以确保非锯齿状的本地数组; 如果使用 jax.Arrays,最好将其设置为 False。

  • allow_split_physical_axes (bool) – 如果为 True,如有必要,我们将拆分物理轴以生成所需的设备网格。

引发:

ValueError – 如果设备数量不等于 mesh_shape 的乘积。

返回:

一个 JAX 设备的 np.ndarray,其形状为 mesh_shape,可以馈送到 jax.sharding.Mesh 中,从而获得良好的集体性能。

返回类型:

np.ndarray