jax.distributed.initialize#
- jax.distributed.initialize(coordinator_address=None, num_processes=None, process_id=None, local_device_ids=None, cluster_detection_method=None, initialization_timeout=300, heartbeat_timeout_seconds=100, coordinator_bind_address=None, slice_index=None, partition_index=None)[源代码]#
初始化 JAX 分布式系统。
调用
initialize()为 JAX 在多主机 GPU 和 Cloud TPU 上执行做准备。initialize()必须在执行任何 JAX 计算之前调用。JAX 分布式系统服务于多种角色:
它允许 JAX 进程互相发现并共享拓扑信息,
它执行健康检查,确保如果任何一个进程崩溃,所有进程都会关闭,并且
它用于分布式 checkpointing。
如果您正在使用 TPU、Slurm 或 Open MPI,所有参数都是可选的:如果省略,它们将被自动选择。
可以使用
cluster_detection_method参数选择一种特定的方法来检测这些分布式参数。您可以将任何自动化的spec_detect_methods传递给此参数,尽管在 TPU、Slurm 或 Open MPI 的情况下并非必需。对于其他 MPI 安装,如果您安装了功能正常的mpi4py,您可以传递cluster_detection_method="mpi4py"来引导所需的参数。否则,您必须向
initialize()提供coordinator_address、num_processes、process_id和local_device_ids参数。当所有四个参数都提供时,将跳过集群环境的自动检测。请注意:在某些系统上,尤其是只有通过 HTTP_PROXY、HTTPS_PROXY 等代理变量才能访问外部网络的高性能计算集群上,调用
initialize()可能会超时。您可能需要在应用程序启动前取消设置这些变量。- 参数:
coordinator_address (str | None) – 进程 0 的 IP 地址以及该进程应在其上启动协调器服务的端口。端口的选择并不重要,只要该端口在协调器上可用并且所有进程都就该端口达成一致即可。仅在受支持的环境中可以为
None,在这种情况下它将被自动选择。请注意,像localhost或127.0.0.1这样的特殊地址通常意味着程序将绑定到本地接口,在多主机环境中运行不适合。num_processes (int | None) – 进程数。仅在受支持的环境中可以为
None,在这种情况下它将被自动选择。process_id (int | None) – 当前进程的 ID 号。集群中的
process_id值必须是一个紧凑的范围0、1、…、num_processes - 1。仅在受支持的环境中可以为None;如果为None,它将被自动选择。local_device_ids (int | Sequence[int] | None) – 将当前进程可见的设备限制为
local_device_ids。如果为None,则默认为进程可见的所有本地设备,除非进程通过 Slurm 和 Open MPI 在 GPU 上启动。在这种情况下,它将默认为每个进程使用一个设备。cluster_detection_method (str | None) – 一个可选字符串,用于尝试自动检测分布式运行的配置。请注意,“mpi4py”方法要求您的环境中安装了可用的
mpi4py,并使用兼容 MPI 的作业启动器(如mpiexec或mpirun)启动应用程序。旧的自动检测选项“ompi”(OMPI)和“slurm”(Slurm)仍然启用。“deactivate”会绕过自动集群检测。initialization_timeout (int) – 连接将重试的时间段(以秒为单位)。如果初始化花费的时间超过指定的超时时间,初始化将出错。默认为 300 秒,即 5 分钟。
heartbeat_timeout_seconds (int) – 如果一个进程在指定时间内没有成功发送任何心跳信号,则认为该进程已死的时间(以秒为单位)。默认为 100 秒。
coordinator_bind_address (str | None) – 进程 0 上的协调器服务应绑定到的地址和端口。如果未指定此项,则默认为绑定到
coordinator_address相同端口上的所有可用地址。在具有多个网络接口的系统上,仅让协调器服务监听一个地址/接口可能不足够。slice_index (int | None) – 已弃用:请改用
partition_index。partition_index (int | None) – 分配给该进程本地设备的分区索引。如果任何进程设置了
partition_index,则所有进程都必须设置。如果为None,则分区索引将被自动选择。
- 引发:
RuntimeError – 如果
initialize()被调用超过一次,或者在后端已初始化后被调用。
示例
假设有两个 GPU 进程,进程 0 是指定的协调器,地址为
10.0.0.1:1234。要初始化 GPU 集群,请在执行任何其他操作之前运行以下命令。在进程 0 上
>>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=0)
在进程 1 上
>>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=1)