jax.experimental.mesh_utils 模块#

用于构建设备网格的实用工具。

API#

create_device_mesh(mesh_shape[, devices, ...])

为 jax.sharding.Mesh 创建高性能设备网格。

create_hybrid_device_mesh(mesh_shape, ...[, ...])

为混合(例如,ICI 和 DCN)并行创建设备网格。