多控制器 JAX (又名多进程/多主机 JAX) 简介#

通过阅读本教程,您将学会如何将 JAX 计算扩展到单个主机机器无法容纳的更多设备,例如在 GPU 集群、Cloud TPU pod 或多个纯 CPU 机器上运行时。

主要思想

  • 运行多个 Python 进程,我们有时称之为“控制器”。我们可以在每台主机机器上运行一个(或多个)进程。

  • 使用 jax.distributed.initialize() 初始化集群。.

  • 一个 jax.Array 可以跨越所有进程,如果每个进程对其应用相同的 JAX 函数,则就像编程一个大型设备一样。

  • 使用与单控制器 JAX 相同的 统一分片机制 来控制数据的分布和计算的并行化。XLA 会自动利用可用时的高速网络连接,例如 TPU ICI 或主机之间的 NVLink,否则会使用可用的主机网络(例如以太网、InfiniBand)。

  • 所有进程(通常)运行相同的 Python 脚本。您编写的代码几乎与单进程版本完全相同——只需运行多个实例,JAX 会处理其余的事情。换句话说,除了数组创建之外,您可以像有一个连接了所有设备的巨大机器一样编写您的 JAX 代码。

本教程假定您已阅读 分布式数组和自动并行化,该文档介绍的是单控制器 JAX。

Illustration of a multi-host TPU pod. Each host in the pod is attached via PCI to a board of four TPU chips. The TPUs chips themselves are connected via high-speed inter-chip interconnects.

多主机 TPU pod 的说明。pod 中的每个主机(绿色)通过 PCI 连接到一块四芯片 TPU 板(蓝色)。TPU 芯片本身通过高速芯片间互连 (ICI) 连接。JAX Python 代码运行在每个主机上,例如通过 ssh。每个主机上的 JAX 进程相互之间是已知的,这使您可以协调整个 pod 芯片的计算。对于支持 JAX 的 GPU、CPU 和其他平台,原理相同!#

玩具示例#

在定义术语并详细介绍之前,这里有一个玩具示例:创建一个跨进程的 jax.Array 并对其应用 jax.numpy 函数。

# call this file toy.py, to be run in each process simultaneously

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P
import numpy as np

# in this example, get multi-process parameters from sys.argv
import sys
proc_id = int(sys.argv[1])
num_procs = int(sys.argv[2])

# initialize the distributed system
jax.distributed.initialize('localhost:10000', num_procs, proc_id)

# this example assumes 8 devices total
assert jax.device_count() == 8

# make a 2D mesh that refers to devices from all processes
mesh = jax.make_mesh((4, 2), ('i', 'j'))

# create some toy data
global_data = np.arange(32).reshape((4, 8))

# make a process- and device-spanning array from our toy data
sharding = NamedSharding(mesh, P('i', 'j'))
global_array = jax.device_put(global_data, sharding)
assert global_array.shape == global_data.shape

# each process has different shards of the global array
for shard in global_array.addressable_shards:
  print(f"device {shard.device} has local data {shard.data}")

# apply a simple computation, automatically partitioned
global_result = jnp.sum(jnp.sin(global_array))
print(f'process={proc_id} got result: {global_result}')

在此,mesh 包含来自所有进程的设备。我们使用它来创建 global_array,逻辑上它是一个单一的共享数组,分布式存储在所有进程的设备上。

每个进程必须以相同的顺序对 global_array 应用相同的操作。XLA 会自动分区这些计算,例如插入通信集合来计算整个数组的 jnp.sum。我们可以打印最终结果,因为它的值在进程之间是复制的。

我们可以在本地 CPU 上运行此代码,例如使用 4 个进程,每个进程 2 个 CPU 设备。

export JAX_NUM_CPU_DEVICES=2
num_processes=4

range=$(seq 0 $(($num_processes - 1)))

for i in $range; do
  python toy.py $i $num_processes > /tmp/toy_$i.out &
done

wait

for i in $range; do
  echo "=================== process $i output ==================="
  cat /tmp/toy_$i.out
  echo
done

输出

=================== process 0 output ===================
device TFRT_CPU_0 has local data [[0 1 2 3]]
device TFRT_CPU_1 has local data [[4 5 6 7]]
process=0 got result: -0.12398731708526611

=================== process 1 output ===================
device TFRT_CPU_131072 has local data [[ 8  9 10 11]]
device TFRT_CPU_131073 has local data [[12 13 14 15]]
process=1 got result: -0.12398731708526611

=================== process 2 output ===================
device TFRT_CPU_262144 has local data [[16 17 18 19]]
device TFRT_CPU_262145 has local data [[20 21 22 23]]
process=2 got result: -0.12398731708526611

=================== process 3 output ===================
device TFRT_CPU_393216 has local data [[24 25 26 27]]
device TFRT_CPU_393217 has local data [[28 29 30 31]]
process=3 got result: -0.12398731708526611

这看起来可能与单控制器 JAX 代码没有太大区别,事实上,这正是您编写同一程序的单控制器版本的方式!(我们技术上不需要为单控制器调用 jax.distributed.initialize(),但调用它也没有坏处。)让我们从单个进程运行相同的代码。

JAX_NUM_CPU_DEVICES=8 python toy.py 0 1

输出

device TFRT_CPU_0 has local data [[0 1 2 3]]
device TFRT_CPU_1 has local data [[4 5 6 7]]
device TFRT_CPU_2 has local data [[ 8  9 10 11]]
device TFRT_CPU_3 has local data [[12 13 14 15]]
device TFRT_CPU_4 has local data [[16 17 18 19]]
device TFRT_CPU_5 has local data [[20 21 22 23]]
device TFRT_CPU_6 has local data [[24 25 26 27]]
device TFRT_CPU_7 has local data [[28 29 30 31]]
process=0 got result: -0.12398731708526611

数据被分片到单个进程上的八个设备上,而不是分片到四个进程上的八个设备上,但除此之外,我们对相同的数据运行相同的操作。

术语#

值得明确一些术语。

我们有时将运行 JAX 计算的每个 Python 进程称为 **控制器**,但这两个术语基本同义。

每个进程都有一个 **本地设备** 集,这意味着它可以与这些设备的内存进行数据传输,并在这些设备上运行计算,而无需涉及任何其他进程。本地设备通常物理连接到相应主机的进程,例如通过 PCI。一个设备只能是单个进程的本地设备;也就是说,本地设备集是不相交的。可以通过评估 jax.local_devices() 来查询进程的本地设备。我们有时使用 **可寻址** 来表示与本地相同的含义。

Illustration of how a process/controller and local devices fit into a larger multi-host cluster. The "global devices" are all devices in the cluster.

说明进程/控制器和本地设备如何融入更大的多主机集群。“全局设备”是集群中的所有设备。#

所有进程中的设备称为 **全局设备**。全局设备列表通过 jax.devices() 查询。所有设备列表都是通过在所有进程上运行 jax.distributed.initialize() 填充的,该函数设置了一个连接这些进程的简单分布式系统。

我们经常使用 **全局** 和 **本地** 这两个术语来描述进程跨越和进程本地的概念。例如,“本地数组”可能是仅对单个进程可见的 numpy 数组,而 JAX 的“全局数组”在概念上对所有进程都可见。

设置多个 JAX 进程#

在实践中,设置多个 JAX 进程看起来与玩具示例略有不同,玩具示例是从单个主机机器运行的。我们通常在单独的主机上启动每个进程,或者有多台主机,每台主机上有多个进程。我们可以直接使用 ssh 来做到这一点,或者使用 Slurm 或 Kubernetes 等集群管理器。无论如何,您必须在每台主机上手动运行您的 JAX 程序! JAX 不会从单个程序调用自动启动多个进程。

无论它们如何启动,Python 进程都需要运行 jax.distributed.initialize()。当使用 Slurm、Kubernetes 或任何 Cloud TPU 部署时,我们可以不带参数地运行 jax.distributed.initialize(),因为它们会自动填充。初始化系统意味着我们可以运行 jax.devices() 来报告所有进程的所有设备。

警告

jax.distributed.initialize() 必须在运行 jax.devices()jax.local_devices() 或在设备上运行任何计算(例如使用 jax.numpy)之前调用。否则,JAX 进程将无法感知任何非本地设备。(使用 jax.config() 或其他不访问设备的功能是可以的。)如果您不小心在访问任何设备后调用 jax.distributed.initialize(),它将引发错误。

GPU 示例#

我们可以在 GPU 机器 集群上运行多控制器 JAX。例如,在 Google Cloud 上创建四个具有两块 GPU 的虚拟机后,我们可以在每台虚拟机上运行以下 JAX 程序。在此示例中,我们显式地为 jax.distributed.initialize() 提供了参数。协调器地址、进程 ID 和进程数从命令行读取。

# In file gpu_example.py...

import jax
import sys

# Get the coordinator_address, process_id, and num_processes from the command line.
coord_addr = sys.argv[1]
proc_id = int(sys.argv[2])
num_procs = int(sys.argv[3])

# Initialize the GPU machines.
jax.distributed.initialize(coordinator_address=coord_addr,
                           num_processes=num_procs,
                           process_id=proc_id)
print("process id =", jax.process_index())
print("global devices =", jax.devices())
print("local devices =", jax.local_devices())

例如,如果第一台虚拟机地址为 192.168.0.1,则您将在第一台虚拟机上运行 python3 gpu_example.py 192.168.0.1:8000 0 4,在第二台虚拟机上运行 python3 gpu_example.py 192.168.0.1:8000 1 4,依此类推。在所有四台虚拟机上运行 JAX 程序后,第一个进程将打印以下内容。

process id = 0
global devices = [CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3), CudaDevice(id=4), CudaDevice(id=5), CudaDevice(id=6), CudaDevice(id=7)]
local devices = [CudaDevice(id=0), CudaDevice(id=1)]

该进程成功地看到了所有八块 GPU 作为全局设备,以及其两个本地设备。同样,第二个进程打印以下内容。

process id = 1
global devices = [CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3), CudaDevice(id=4), CudaDevice(id=5), CudaDevice(id=6), CudaDevice(id=7)]
local devices = [CudaDevice(id=2), CudaDevice(id=3)]

这台虚拟机看到了相同的全局设备,但拥有不同的本地设备集。

TPU 示例#

作为另一个示例,我们可以在 Cloud TPU 上运行。创建 v5litepod-16(具有 4 台主机)后,我们可能希望测试能否连接进程并列出所有设备。

$ TPU_NAME=jax-demo
$ EXTERNAL_IPS=$(gcloud compute tpus tpu-vm describe $TPU_NAME --zone 'us-central1-a' \
                 | grep externalIp | cut -d: -f2)
$ cat << EOF > demo.py
import jax
jax.distributed.initialize()
if jax.process_index() == 0:
  print(jax.devices())
EOF
$ echo $EXTERNAL_IPS | xargs -n 1 -P 0 bash -c '
scp demo.py $0:
ssh $0 "pip -q install -U jax[tpu]"
ssh $0 "python demo.py" '

这里我们使用 xargs 并行运行多个 ssh 命令,每个命令在一台 TPU 主机上运行相同的 Python 程序。在 Python 代码中,我们使用 jax.process_index() 来仅在一个进程上打印。以下是其打印的内容。

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=2, process_index=1, coords=(2,0,0), core_on_chip=0), TpuDevice(id=3, process_index=1, coords=(3,0,0), core_on_chip=0), TpuDevice(id=6, process_index=1, coords=(2,1,0), core_on_chip=0), TpuDevice(id=7, process_index=1, coords=(3,1,0), core_on_chip=0), TpuDevice(id=8, process_index=2, coords=(0,2,0), core_on_chip=0), TpuDevice(id=9, process_index=2, coords=(1,2,0), core_on_chip=0), TpuDevice(id=12, process_index=2, coords=(0,3,0), core_on_chip=0), TpuDevice(id=13, process_index=2, coords=(1,3,0), core_on_chip=0), TpuDevice(id=10, process_index=3, coords=(2,2,0), core_on_chip=0), TpuDevice(id=11, process_index=3, coords=(3,2,0), core_on_chip=0), TpuDevice(id=14, process_index=3, coords=(2,3,0), core_on_chip=0), TpuDevice(id=15, process_index=3, coords=(3,3,0), core_on_chip=0)]

太棒了,看这些 TPU 核心!

Kubernetes 示例#

在 Kubernetes 集群上运行多控制器 JAX 的思想与上述 GPU 和 TPU 示例几乎相同:每个 pod 运行相同的 Python 程序,JAX 会发现其对等节点,集群就像一台巨大的机器一样运行。

  1. 容器镜像 - 从启用 JAX 的镜像开始,例如 Google Artifact Registry 上的公共 JAX AI 镜像(TPU / GPU)或 NVIDIA(NGC / JAX-Toolbox)。

  2. 工作负载类型 - 使用 JobSet索引 Job。每个副本对应一个 JAX 进程。

  3. 服务帐户 - JAX 需要权限来列出属于该作业的 pod,以便进程发现其对等节点。在 examples/k8s/svc-acct.yaml 中提供了一个最小的 RBAC 设置。

下面是一个 最小 JobSet,它启动了两个副本。将占位符(镜像、GPU 数量和任何私有注册表密钥)替换为与您的环境匹配的值。

apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: jaxjob
spec:
  replicatedJobs:
  - name: workers
    template:
      spec:
        parallelism: 2
        completions: 2
        backoffLimit: 0
        template:
          spec:
            serviceAccountName: jax-job-sa  # kubectl apply -f svc-acct.yaml
            restartPolicy: Never
            imagePullSecrets:
              # https://k8s.io/docs/tasks/configure-pod-container/pull-image-private-registry/
            - name: null
            containers:
            - name: main
              image: null  # e.g. ghcr.io/nvidia/jax:jax
              imagePullPolicy: Always
              resources:
                limits:
                  cpu: 1
                  # https://k8s.io/docs/tasks/manage-gpus/scheduling-gpus/
                  nvidia.com/gpu: null
              command: 
                - python
              args:
                - -c
                - |
                  import jax
                  jax.distributed.initialize()
                  print(jax.devices())
                  print(jax.local_devices())
                  assert jax.process_count() > 1
                  assert len(jax.devices()) > len(jax.local_devices())

应用清单并观察 pod 完成。

$ kubectl apply -f example.yaml
$ kubectl get pods -l jobset.sigs.k8s.io/jobset-name=jaxjob
NAME                       READY   STATUS      RESTARTS   AGE
jaxjob-workers-0-0-xpx8l   0/1     Completed   0          8m32s
jaxjob-workers-0-1-ddkq8   0/1     Completed   0          8m32s

作业完成后,检查日志以确认每个进程都看到了所有加速器。

$ kubectl logs -l jobset.sigs.k8s.io/jobset-name=jaxjob
[CudaDevice(id=0), CudaDevice(id=1)]
[CudaDevice(id=0)]
[CudaDevice(id=0), CudaDevice(id=1)]
[CudaDevice(id=1)]

每个 pod 都应该具有相同的全局设备集和不同的本地设备集。此时,您可以将内联脚本替换为您的实际 JAX 程序。

一旦进程设置好,我们就可以开始构建全局 jax.Array 并运行计算。本教程中其余的 Python 代码示例旨在同时在所有进程上运行,在运行 jax.distributed.initialize() 之后。

Mesh、分片和计算可以跨越进程和主机#

从 JAX 编程多个进程通常看起来就像编程单个进程一样,只是有更多的设备!主要的例外是关于进出 JAX 的数据,例如从外部数据源加载时。我们首先在这里介绍多进程计算的基础知识,这些基础知识在很大程度上与其单进程对应物看起来相同。下一节将介绍一些数据加载的基础知识,即如何从非 JAX 源创建 JAX 数组。

回想一下,jax.sharding.Meshjax.Device 数组与一组名称配对,每个数组轴对应一个名称。通过使用来自多个进程的设备创建 Mesh,然后在 jax.sharding.Sharding 中使用该 mesh,我们可以构建跨越多个进程的设备的 jax.Array

这是一个直接使用 jax.devices() 获取所有进程设备来构造 Mesh 的示例。

from jax.sharding import Mesh
mesh = Mesh(jax.devices(), ('a',))

# in this case, the same as
mesh = jax.make_mesh((jax.device_count(),), ('a',))  # use this in practice

在实践中,您可能应该使用 jax.make_mesh() 辅助函数,这不仅因为它更简单,而且因为它还可以自动选择更高效的设备排序,但我们在这里将其拼写出来。默认情况下,它包含所有进程的设备,就像 jax.devices() 一样。

一旦我们有了 mesh,我们就可以对其进行分片数组。有几种有效构建跨进程数组的方法,将在下一节详细介绍,但现在我们为了简单起见,将坚持使用 jax.device_put

arr = jax.device_put(jnp.ones((32, 32)), NamedSharding(mesh, P('a')))
if jax.process_index() == 0:
  jax.debug.visualize_array_sharding(arr)

在进程 0 上,打印如下:

┌───────────────────────┐
│         TPU 0         │
├───────────────────────┤
│         TPU 1         │
├───────────────────────┤
│         TPU 4         │
├───────────────────────┤
│         TPU 5         │
├───────────────────────┤
│         TPU 2         │
├───────────────────────┤
│         TPU 3         │
├───────────────────────┤
│         TPU 6         │
├───────────────────────┤
│         TPU 7         │
├───────────────────────┤
│         TPU 8         │
├───────────────────────┤
│         TPU 9         │
├───────────────────────┤
│        TPU 12         │
├───────────────────────┤
│        TPU 13         │
├───────────────────────┤
│        TPU 10         │
├───────────────────────┤
│        TPU 11         │
├───────────────────────┤
│        TPU 14         │
├───────────────────────┤
│        TPU 15         │
└───────────────────────┘

让我们尝试一个更有趣的计算!

mesh = jax.make_mesh((jax.device_count() // 2, 2), ('a', 'b'))

def device_put(x, spec):
  return jax.device_put(x, NamedSharding(mesh, spec))

# construct global arrays by sharding over the global mesh
x = device_put(jnp.ones((4096, 2048)), P('a', 'b'))
y = device_put(jnp.ones((2048, 4096)), P('b', None))

# run a distributed matmul
z = jax.nn.relu(x @ y)

# inspect the sharding of the result
if jax.process_index() == 0:
  jax.debug.visualize_array_sharding(z)
  print()
  print(z.sharding)

在进程 0 上,打印如下:

┌───────────────────────┐
│        TPU 0,1        │
├───────────────────────┤
│        TPU 4,5        │
├───────────────────────┤
│        TPU 8,9        │
├───────────────────────┤
│       TPU 12,13       │
├───────────────────────┤
│        TPU 2,3        │
├───────────────────────┤
│        TPU 6,7        │
├───────────────────────┤
│       TPU 10,11       │
├───────────────────────┤
│       TPU 14,15       │
└───────────────────────┘

NamedSharding(mesh=Mesh('a': 8, 'b': 2), spec=PartitionSpec('a',), memory_kind=device)

在这里,仅仅通过在所有进程上评估 x @ y,XLA 就会自动生成并运行分布式矩阵乘法。结果被分片到 mesh 上,如 P('a', None),因为在这种情况下,matmul 包含了关于 'b' 轴的 psum

警告

当对跨进程数组应用 JAX 计算时,为了避免死锁和挂起,所有具有参与设备的进程以相同的顺序运行相同的计算至关重要。这是因为计算可能涉及集体通信屏障。如果分片数组的设备没有参与集体通信,因为其控制器没有发出相同的计算,那么其他设备就会被挂起。例如,如果只有前三个进程评估了 x @ y,而最后一个进程评估了 y @ x,计算很可能会无限期挂起。这个假设,即在跨进程数组上的计算以相同的顺序在所有参与的进程上运行,基本上是没有经过检查的。

因此,在多进程 JAX 中避免死锁的最简单方法是在每个进程上运行相同的 Python 代码,并注意任何依赖于 jax.process_index() 并包含通信的控制流。

如果一个跨进程数组被分片到不同进程的设备上,那么对需要该数据在进程本地可用的数组执行操作(如打印)是错误的。例如,如果我们打印前面的示例中的 print(z),我们会看到:

RuntimeError: Fetching value for `jax.Array` that spans non-addressable (non process local) devices is not possible. You can use `jax.experimental.multihost_utils.process_allgather` to print the global array or use `.addressable_shards` method of jax.Array to inspect the addressable (process local) shards.

要打印完整的数组值,我们必须首先确保它已复制到所有进程(但不一定复制到每个进程的本地设备),例如使用 jax.device_put。在上面的示例中,我们可以在末尾写:

w = device_put(z, P(None, None))
if jax.process_index() == 0:
  print(w)

请注意不要在 if process_index() == 0 下写入 jax.device_put(),因为这会导致死锁,因为只有进程 0 发起了集体通信并无限期等待其他进程。jax.experimental.multihost_utils 模块包含一些函数,可以更轻松地处理全局 jax.Array(例如,jax.experimental.multihost_utils.process_allgather())。

或者,要仅在进程本地数据上打印或执行其他 Python 操作,我们可以访问 z.addressable_shards。访问该属性不需要任何通信,因此任何子集进程都可以执行它而无需其他进程。该属性在 jax.jit() 下不可用。

从外部数据创建跨进程数组#

从外部数据源(例如数据加载器的 numpy 数组)创建跨进程 jax.Array 有三种主要方法:

  1. 在所有进程上创建或加载完整数组,然后使用 jax.device_put() 分片到设备;

  2. 在每个进程上创建或加载一个数组,该数组仅代表将在本地分片并存储在该进程设备上的数据,然后使用 jax.make_array_from_process_local_data() 分片到设备;

  3. 在每个进程的设备上创建或加载单独的数组,每个数组代表将在该设备上存储的数据,然后使用 jax.make_array_from_single_device_arrays() 在没有任何数据移动的情况下将它们组装起来。

后两种方法在实践中最为常用,因为将全部全局数据物化到每个进程中通常成本过高。

上面的玩具示例使用了 jax.device_put()

jax.make_array_from_process_local_data() 通常用于分布式数据加载。它不像 jax.make_array_from_single_device_arrays() 那样通用,因为它不直接指定进程本地数据的哪个切片会放在每个本地设备上。这在加载数据并行批次时很方便,因为哪个微批次放在哪个设备上并不重要。例如:

# target (micro)batch size across the whole cluster
batch_size = 1024
# how many examples each process should load per batch
per_process_batch_size = batch_size // jax.process_count()
# how many examples each device will process per batch
per_device_batch_size = batch_size // jax.device_count()

# make a data-parallel mesh and sharding
mesh = jax.make_mesh((jax.device_count(),), ('batch'))
sharding = NamedSharding(mesh, P('batch'))

# our "data loader". each process loads a different set of "examples".
process_batch = np.random.rand(per_process_batch_size, 2048, 42)

# assemble a global array containing the per-process batches from all processes
global_batch = jax.make_array_from_process_local_data(sharding, process_batch)

# sanity check that everything got sharded correctly
assert global_batch.shape[0] == batch_size
assert process_batch.shape[0] == per_process_batch_size
assert global_batch.addressable_shards[0].data.shape[0] == per_device_batch_size

jax.make_array_from_single_device_arrays() 是构建跨进程数组的最通用方法。它通常在执行 jax.device_put() 以将每个设备所需数据发送给它之后使用。这是最低级别的选项,因为所有数据移动都是手动执行的(例如通过 jax.device_put())。这是一个例子:

shape = (jax.process_count(), jax.local_device_count())
mesh = jax.make_mesh(shape, ('i', 'j'))
sharding = NamedSharding(mesh, P('i', 'j'))

# manually create per-device data equivalent to np.arange(jax.device_count())
# i.e. each device will get a single scalar value from 0..N
local_arrays = [
    jax.device_put(
        jnp.array([[jax.process_index() * jax.local_device_count() + i]]),
        device)
    for i, device in enumerate(jax.local_devices())
]

# assemble a global array from the local_arrays across all processes
global_array = jax.make_array_from_single_device_arrays(
    shape=shape,
    sharding=sharding,
    arrays=local_arrays)

# sanity check
assert (np.all(
    jax.experimental.multihost_utils.process_allgather(global_array) ==
    np.arange(jax.device_count()).reshape(global_array.shape)))