jax.experimental.mesh_utils.create_hybrid_device_mesh#

jax.experimental.mesh_utils.create_hybrid_device_mesh(mesh_shape, dcn_mesh_shape, devices=None, *, process_is_granule=False, should_sort_granules_by_key=True, allow_split_physical_axes=False)[source]#

为混合(例如 ICI 和 DCN)并行创建设备网格。

参数:
  • mesh_shape (Sequence[int]) – 用于更快/内部网络的逻辑网格的形状,按网络强度递增排序,例如 [replica, data, mdl],其中 mdl 具有最多的网络通信需求。

  • dcn_mesh_shape (Sequence[int]) – 用于较慢/外部网络的逻辑网格的形状,顺序与 mesh_shape 相同。

  • devices (Sequence[Any] | None | None) – (可选)要为其构建网格的设备。默认为 jax.devices()。

  • process_is_granule (bool) – 如果为 True,此函数将把进程视为较慢/外部网络的单元。否则,它将查找设备上的 slice_index 属性并将切片用作单元。启用此功能旨在作为不设置 slice_index 的平台的后备方案。

  • should_sort_granules_by_key (bool) – 设备粒度是否应按粒度键(切片或进程索引,取决于 process_is_granule)排序。

  • allow_split_physical_axes (bool) – 如果为 True,我们将在必要时拆分物理轴以生成所需的设备网格。

Raises:

ValueError – 如果 devices 所属的切片数量不等于 dcn_mesh_shape 的乘积,或者任何单个切片所属的设备数量不等于 mesh_shape 的乘积。

Returns:

一个 np.ndarray 类型的 JAX 设备,其形状为 mesh_shape * dcn_mesh_shape,可以馈送到 jax.sharding.Mesh 以实现混合并行。

Return type:

np.ndarray