jax.distributed.initialize#

jax.distributed.initialize(coordinator_address=None, num_processes=None, process_id=None, local_device_ids=None, cluster_detection_method=None, initialization_timeout=300, heartbeat_timeout_seconds=100, coordinator_bind_address=None, slice_index=None)[源代码]#

初始化 JAX 分布式系统。

调用 initialize() 会为 JAX 在多主机 GPU 和 Cloud TPU 上的执行做准备。initialize() 必须在执行任何 JAX 计算之前调用。

JAX 分布式系统有以下作用:

  • 它允许 JAX 进程相互发现并共享拓扑信息,

  • 它执行健康检查,确保如果任何进程死亡,所有进程都会关闭,并且

  • 它用于分布式检查点。

如果您正在使用 TPU、Slurm 或 Open MPI,所有参数都是可选的:如果省略,它们将自动选择。

可以使用 cluster_detection_method 选择检测这些分布式参数的特定方法。您可以将任何自动 spec_detect_methods 传递给此参数,但在 TPU、Slurm 或 Open MPI 的情况下并非必要。对于其他 MPI 安装,如果您已安装可用的 mpi4py,则可以传递 cluster_detection_method="mpi4py" 来引导所需的参数。

否则,您必须向 initialize() 提供 coordinator_addressnum_processesprocess_idlocal_device_ids 参数。当提供所有这四个参数时,将跳过集群环境的自动检测。

请注意:在某些系统上,特别是那些仅通过代理变量(如 HTTP_PROXY、HTTPS_PROXY 等)访问外部网络的 HPC 集群上,调用 initialize() 可能会超时。您可能需要在应用程序启动之前取消设置这些变量。

参数:
  • coordinator_address (str | None) – 进程 0 的 IP 地址和该进程应启动协调器服务的端口。只要该端口在协调器上可用且所有进程都同意使用该端口,端口的选择就无关紧要。仅在支持的环境下才可以为 None,在这种情况下它将自动选择。请注意,像 localhost127.0.0.1 这样的特殊地址通常意味着程序将绑定到本地接口,不适合在多主机环境中运行。

  • num_processes (int | None) – 进程数量。仅在支持的环境下才可以为 None,在这种情况下它将自动选择。

  • process_id (int | None) – 当前进程的 ID 号。集群中的 process_id 值必须是连续的范围:01、…、num_processes - 1。仅在支持的环境下才可以为 None;如果为 None,它将自动选择。

  • local_device_ids (int | Sequence[int] | None) – 将当前进程可见的设备限制为 local_device_ids。如果为 None,则默认所有本地设备都对进程可见,除非进程是通过 Slurm 和 Open MPI 在 GPU 上启动的。在这种情况下,它将默认为每个进程一个设备。

  • cluster_detection_method (str | None) – 用于尝试自动检测分布式运行配置的可选字符串。请注意,“mpi4py” 方法要求您的环境中安装了可用的 mpi4py,并使用与 MPI 兼容的作业启动器(例如 mpiexecmpirun)启动应用程序。旧版自动检测选项“ompi”(OMPI)和“slurm”(Slurm)仍然启用。“deactivate”禁用自动集群检测。

  • initialization_timeout (int) – 连接重试的时间段(秒)。如果初始化时间超过指定的超时,初始化将报错。默认为 300 秒,即 5 分钟。

  • heartbeat_timeout_seconds (int) – 进程在多长时间(秒)未成功发送心跳后被视为死亡。默认为 100 秒。

  • coordinator_bind_address (str | None) – 进程 0 上的协调器服务应绑定的地址和端口。如果未指定,默认是绑定到与 coordinator_address 相同的端口上的所有可用地址。在每个节点有多个网络接口的系统上,仅让协调器服务在一个地址/接口上监听可能不足。

  • slice_index (int | None) – 分配给此进程本地设备的切片索引。如果任何进程设置了 slice_index,则所有进程都必须这样做。如果为 None,则切片索引将自动选择。

引发:

RuntimeError – 如果 initialize() 被多次调用,或者在后端已经初始化之后调用。

示例

假设有两个 GPU 进程,进程 0 是指定的协调器,地址为 10.0.0.1:1234。要初始化 GPU 集群,请在执行其他任何操作之前运行以下命令。

在进程 0 上

>>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=0)  

在进程 1 上

>>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=1)