多主机和多进程环境#

简介#

本指南解释了如何在 GPU 集群和 Cloud TPU Pod 等环境中,跨多个 CPU 主机或 JAX 进程分布加速器时使用 JAX。我们将这些环境称为“多进程”环境。

本指南特别关注如何在多进程设置中使用集体通信操作(例如 jax.lax.psum() ),尽管其他通信方法也可能根据您的用例而有用(例如 RPC,mpi4jax)。如果您还不熟悉 JAX 的集体操作,我们建议从并行编程入门部分开始。 JAX 中多进程环境的一个重要要求是加速器之间的直接通信链路,例如 Cloud TPU 的高速互连或 GPU 的 NCCL。这些链接允许集体操作在高效率下跨多个进程的加速器运行。

多进程编程模型#

核心概念

  • 每个主机必须运行至少一个 JAX 进程。

  • 您应该使用 jax.distributed.initialize() 初始化集群。

  • 每个进程都有一组不同的本地设备可以寻址。全局设备是所有进程中所有设备的集合。

  • 使用标准的 JAX 并行 API,如 jit()(参见 并行编程入门 教程)和 shard_map()jax.jit 仅接受全局形状的数组。shard_map 允许您降低到每个设备的形状。

  • 确保所有进程以相同的顺序运行相同的并行计算。

  • 确保所有进程具有相同数量的本地设备。

  • 确保所有设备都相同(例如,全部为 V100,或全部为 H100)。

启动 JAX 进程#

与其他分布式系统(单个控制器节点管理多个工作节点)不同,JAX 使用“多控制器”编程模型,其中每个 JAX Python 进程独立运行,有时称为 单程序多数据 (SPMD) 模型。通常,相同的 JAX Python 程序在每个进程中运行,每个进程的执行之间只有细微的差异(例如,不同的进程将加载不同的输入数据)。此外,您必须手动在每个主机上运行您的 JAX 程序! JAX 不会自动从单个程序调用启动多个进程。

(需要多个进程是本指南没有作为 notebook 提供的的原因——我们目前没有一个好的方法从单个 notebook 管理多个 Python 进程。)

初始化集群#

要初始化集群,您应该在每个进程的开始处调用 jax.distributed.initialize()jax.distributed.initialize() 必须在程序早期调用,在执行任何 JAX 计算之前。

API jax.distributed.initialize() 接受几个参数,即

  • coordinator_address:集群中进程 0 的 IP 地址,以及该进程上可用的端口。进程 0 将启动一个通过该 IP 地址和端口公开的 JAX 服务,集群中的其他进程将连接到该服务。

  • coordinator_bind_address:集群中进程 0 上的 JAX 服务将绑定到的 IP 地址和端口。默认情况下,它将使用与 coordinator_address 相同的端口绑定到所有可用的接口。

  • num_processes:集群中的进程数

  • process_id:此进程的 ID 号,范围为 [0 .. num_processes)

  • local_device_ids:将当前进程的可见设备限制为 local_device_ids

例如,在 GPU 上,典型的用法是

import jax

jax.distributed.initialize(coordinator_address="192.168.0.1:1234",
                           num_processes=2,
                           process_id=0)

在 Cloud TPU、Slurm 和 Open MPI 环境中,您可以直接调用不带参数的 jax.distributed.initialize()。参数的默认值将自动选择。当在带有 Slurm 和 Open MPI 的 GPU 上运行时,假定每个 GPU 启动一个进程,即每个进程将仅分配一个可见的本地设备。否则,假定每个主机启动一个进程,即每个进程将被分配所有本地设备。仅当 JAX 进程通过 mpirun/mpiexec 启动时,才使用 Open MPI 自动初始化。

import jax

jax.distributed.initialize()

在目前的 TPU 上,调用 jax.distributed.initialize() 是可选的,但建议使用,因为它启用了额外的检查点和健康检查功能。

本地设备与全局设备#

在我们开始从您的程序运行多进程计算之前,重要的是要理解本地设备和全局设备之间的区别。

进程的本地设备是它可以直接寻址并在其上启动计算的设备。 例如,在 GPU 集群上,每个主机只能在其直接连接的 GPU 上启动计算。在 Cloud TPU Pod 上,每个主机只能在其直接连接的 8 个 TPU 核心上启动计算(有关更多详细信息,请参阅 Cloud TPU 系统架构 文档)。您可以通过 jax.local_devices() 查看进程的本地设备。

全局设备是跨所有进程的设备。 计算可以跨进程的设备,并通过设备之间的直接通信链路执行集体操作,只要每个进程在其本地设备上启动计算即可。您可以通过 jax.devices() 查看所有可用的全局设备。进程的本地设备始终是全局设备的子集。

运行多进程计算#

那么,您实际上如何运行涉及跨进程通信的计算? 使用与单进程中相同的并行评估 API!

例如,shard_map() 可用于跨多个进程运行并行计算。(如果您还不熟悉如何使用 shard_map 在单个进程中的多个设备上运行,请查看并行编程入门教程。)从概念上讲,这可以被认为是跨主机分片的单个数组上运行 pmap,其中每个主机“看到”仅其本地输入和输出分片。

这是一个多进程 pmap 运行示例

# The following is run in parallel on each host on a GPU cluster or TPU pod slice.
>>> import jax
>>> jax.distributed.initialize()  # On GPU, see above for the necessary arguments.
>>> jax.device_count()  # total number of accelerator devices in the cluster
32
>>> jax.local_device_count()  # number of accelerator devices attached to this host
8
# The psum is performed over all mapped devices across the pod slice
>>> xs = jax.numpy.ones(jax.local_device_count())
>>> jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
ShardedDeviceArray([32., 32., 32., 32., 32., 32., 32., 32.], dtype=float32)

所有进程以相同的顺序运行相同的跨进程计算非常重要。 在每个进程中运行相同的 JAX Python 程序通常就足够了。以下是一些需要注意的常见陷阱,这些陷阱可能会导致尽管运行相同的程序,但计算顺序不同

  • 进程将不同形状的输入传递给相同的并行函数可能会导致挂起或错误的返回值。只要不同形状的输入导致跨进程的每个设备数据分片形状相同,不同形状的输入就是安全的;例如,为了在每个进程不同数量的本地设备上运行,传入不同的前导批大小是可以的,但是让每个进程将其批次填充到不同的最大示例长度则不行。

  • “最后一批”问题,其中并行函数在(训练)循环中被调用,并且一个或多个进程比其余进程更早退出循环。这将导致其余进程挂起,等待已完成的进程开始计算。

  • 基于集合的非确定性排序的条件可能会导致代码进程挂起。例如,即使具有相同的插入顺序,在当前 Python 版本上迭代 set 或在 Python 3.7 之前 迭代 dict 也可能导致不同进程上的排序不同。