多控制器 JAX (又名多进程/多主机 JAX) 简介#
通过阅读本教程,您将学会如何将 JAX 计算扩展到单个主机机器无法容纳的更多设备,例如在 GPU 集群、Cloud TPU pod 或多个纯 CPU 机器上运行时。
主要思想
运行多个 Python 进程,我们有时称之为“控制器”。我们可以在每台主机机器上运行一个(或多个)进程。
使用
jax.distributed.initialize()初始化集群。.一个
jax.Array可以跨越所有进程,如果每个进程对其应用相同的 JAX 函数,则就像编程一个大型设备一样。使用与单控制器 JAX 相同的 统一分片机制 来控制数据的分布和计算的并行化。XLA 会自动利用可用时的高速网络连接,例如 TPU ICI 或主机之间的 NVLink,否则会使用可用的主机网络(例如以太网、InfiniBand)。
所有进程(通常)运行相同的 Python 脚本。您编写的代码几乎与单进程版本完全相同——只需运行多个实例,JAX 会处理其余的事情。换句话说,除了数组创建之外,您可以像有一个连接了所有设备的巨大机器一样编写您的 JAX 代码。
本教程假定您已阅读 分布式数组和自动并行化,该文档介绍的是单控制器 JAX。
多主机 TPU pod 的说明。pod 中的每个主机(绿色)通过 PCI 连接到一块四芯片 TPU 板(蓝色)。TPU 芯片本身通过高速芯片间互连 (ICI) 连接。JAX Python 代码运行在每个主机上,例如通过 ssh。每个主机上的 JAX 进程相互之间是已知的,这使您可以协调整个 pod 芯片的计算。对于支持 JAX 的 GPU、CPU 和其他平台,原理相同!#
玩具示例#
在定义术语并详细介绍之前,这里有一个玩具示例:创建一个跨进程的 jax.Array 并对其应用 jax.numpy 函数。
# call this file toy.py, to be run in each process simultaneously
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P
import numpy as np
# in this example, get multi-process parameters from sys.argv
import sys
proc_id = int(sys.argv[1])
num_procs = int(sys.argv[2])
# initialize the distributed system
jax.distributed.initialize('localhost:10000', num_procs, proc_id)
# this example assumes 8 devices total
assert jax.device_count() == 8
# make a 2D mesh that refers to devices from all processes
mesh = jax.make_mesh((4, 2), ('i', 'j'))
# create some toy data
global_data = np.arange(32).reshape((4, 8))
# make a process- and device-spanning array from our toy data
sharding = NamedSharding(mesh, P('i', 'j'))
global_array = jax.device_put(global_data, sharding)
assert global_array.shape == global_data.shape
# each process has different shards of the global array
for shard in global_array.addressable_shards:
print(f"device {shard.device} has local data {shard.data}")
# apply a simple computation, automatically partitioned
global_result = jnp.sum(jnp.sin(global_array))
print(f'process={proc_id} got result: {global_result}')
在此,mesh 包含来自所有进程的设备。我们使用它来创建 global_array,逻辑上它是一个单一的共享数组,分布式存储在所有进程的设备上。
每个进程必须以相同的顺序对 global_array 应用相同的操作。XLA 会自动分区这些计算,例如插入通信集合来计算整个数组的 jnp.sum。我们可以打印最终结果,因为它的值在进程之间是复制的。
我们可以在本地 CPU 上运行此代码,例如使用 4 个进程,每个进程 2 个 CPU 设备。
export JAX_NUM_CPU_DEVICES=2
num_processes=4
range=$(seq 0 $(($num_processes - 1)))
for i in $range; do
python toy.py $i $num_processes > /tmp/toy_$i.out &
done
wait
for i in $range; do
echo "=================== process $i output ==================="
cat /tmp/toy_$i.out
echo
done
输出
=================== process 0 output ===================
device TFRT_CPU_0 has local data [[0 1 2 3]]
device TFRT_CPU_1 has local data [[4 5 6 7]]
process=0 got result: -0.12398731708526611
=================== process 1 output ===================
device TFRT_CPU_131072 has local data [[ 8 9 10 11]]
device TFRT_CPU_131073 has local data [[12 13 14 15]]
process=1 got result: -0.12398731708526611
=================== process 2 output ===================
device TFRT_CPU_262144 has local data [[16 17 18 19]]
device TFRT_CPU_262145 has local data [[20 21 22 23]]
process=2 got result: -0.12398731708526611
=================== process 3 output ===================
device TFRT_CPU_393216 has local data [[24 25 26 27]]
device TFRT_CPU_393217 has local data [[28 29 30 31]]
process=3 got result: -0.12398731708526611
这看起来可能与单控制器 JAX 代码没有太大区别,事实上,这正是您编写同一程序的单控制器版本的方式!(我们技术上不需要为单控制器调用 jax.distributed.initialize(),但调用它也没有坏处。)让我们从单个进程运行相同的代码。
JAX_NUM_CPU_DEVICES=8 python toy.py 0 1
输出
device TFRT_CPU_0 has local data [[0 1 2 3]]
device TFRT_CPU_1 has local data [[4 5 6 7]]
device TFRT_CPU_2 has local data [[ 8 9 10 11]]
device TFRT_CPU_3 has local data [[12 13 14 15]]
device TFRT_CPU_4 has local data [[16 17 18 19]]
device TFRT_CPU_5 has local data [[20 21 22 23]]
device TFRT_CPU_6 has local data [[24 25 26 27]]
device TFRT_CPU_7 has local data [[28 29 30 31]]
process=0 got result: -0.12398731708526611
数据被分片到单个进程上的八个设备上,而不是分片到四个进程上的八个设备上,但除此之外,我们对相同的数据运行相同的操作。
术语#
值得明确一些术语。
我们有时将运行 JAX 计算的每个 Python 进程称为 **控制器**,但这两个术语基本同义。
每个进程都有一个 **本地设备** 集,这意味着它可以与这些设备的内存进行数据传输,并在这些设备上运行计算,而无需涉及任何其他进程。本地设备通常物理连接到相应主机的进程,例如通过 PCI。一个设备只能是单个进程的本地设备;也就是说,本地设备集是不相交的。可以通过评估 jax.local_devices() 来查询进程的本地设备。我们有时使用 **可寻址** 来表示与本地相同的含义。
说明进程/控制器和本地设备如何融入更大的多主机集群。“全局设备”是集群中的所有设备。#
所有进程中的设备称为 **全局设备**。全局设备列表通过 jax.devices() 查询。所有设备列表都是通过在所有进程上运行 jax.distributed.initialize() 填充的,该函数设置了一个连接这些进程的简单分布式系统。
我们经常使用 **全局** 和 **本地** 这两个术语来描述进程跨越和进程本地的概念。例如,“本地数组”可能是仅对单个进程可见的 numpy 数组,而 JAX 的“全局数组”在概念上对所有进程都可见。
设置多个 JAX 进程#
在实践中,设置多个 JAX 进程看起来与玩具示例略有不同,玩具示例是从单个主机机器运行的。我们通常在单独的主机上启动每个进程,或者有多台主机,每台主机上有多个进程。我们可以直接使用 ssh 来做到这一点,或者使用 Slurm 或 Kubernetes 等集群管理器。无论如何,您必须在每台主机上手动运行您的 JAX 程序! JAX 不会从单个程序调用自动启动多个进程。
无论它们如何启动,Python 进程都需要运行 jax.distributed.initialize()。当使用 Slurm、Kubernetes 或任何 Cloud TPU 部署时,我们可以不带参数地运行 jax.distributed.initialize(),因为它们会自动填充。初始化系统意味着我们可以运行 jax.devices() 来报告所有进程的所有设备。
警告
jax.distributed.initialize() 必须在运行 jax.devices()、jax.local_devices() 或在设备上运行任何计算(例如使用 jax.numpy)之前调用。否则,JAX 进程将无法感知任何非本地设备。(使用 jax.config() 或其他不访问设备的功能是可以的。)如果您不小心在访问任何设备后调用 jax.distributed.initialize(),它将引发错误。
GPU 示例#
我们可以在 GPU 机器 集群上运行多控制器 JAX。例如,在 Google Cloud 上创建四个具有两块 GPU 的虚拟机后,我们可以在每台虚拟机上运行以下 JAX 程序。在此示例中,我们显式地为 jax.distributed.initialize() 提供了参数。协调器地址、进程 ID 和进程数从命令行读取。
# In file gpu_example.py...
import jax
import sys
# Get the coordinator_address, process_id, and num_processes from the command line.
coord_addr = sys.argv[1]
proc_id = int(sys.argv[2])
num_procs = int(sys.argv[3])
# Initialize the GPU machines.
jax.distributed.initialize(coordinator_address=coord_addr,
num_processes=num_procs,
process_id=proc_id)
print("process id =", jax.process_index())
print("global devices =", jax.devices())
print("local devices =", jax.local_devices())
例如,如果第一台虚拟机地址为 192.168.0.1,则您将在第一台虚拟机上运行 python3 gpu_example.py 192.168.0.1:8000 0 4,在第二台虚拟机上运行 python3 gpu_example.py 192.168.0.1:8000 1 4,依此类推。在所有四台虚拟机上运行 JAX 程序后,第一个进程将打印以下内容。
process id = 0
global devices = [CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3), CudaDevice(id=4), CudaDevice(id=5), CudaDevice(id=6), CudaDevice(id=7)]
local devices = [CudaDevice(id=0), CudaDevice(id=1)]
该进程成功地看到了所有八块 GPU 作为全局设备,以及其两个本地设备。同样,第二个进程打印以下内容。
process id = 1
global devices = [CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3), CudaDevice(id=4), CudaDevice(id=5), CudaDevice(id=6), CudaDevice(id=7)]
local devices = [CudaDevice(id=2), CudaDevice(id=3)]
这台虚拟机看到了相同的全局设备,但拥有不同的本地设备集。
TPU 示例#
作为另一个示例,我们可以在 Cloud TPU 上运行。创建 v5litepod-16(具有 4 台主机)后,我们可能希望测试能否连接进程并列出所有设备。
$ TPU_NAME=jax-demo
$ EXTERNAL_IPS=$(gcloud compute tpus tpu-vm describe $TPU_NAME --zone 'us-central1-a' \
| grep externalIp | cut -d: -f2)
$ cat << EOF > demo.py
import jax
jax.distributed.initialize()
if jax.process_index() == 0:
print(jax.devices())
EOF
$ echo $EXTERNAL_IPS | xargs -n 1 -P 0 bash -c '
scp demo.py $0:
ssh $0 "pip -q install -U jax[tpu]"
ssh $0 "python demo.py" '
这里我们使用 xargs 并行运行多个 ssh 命令,每个命令在一台 TPU 主机上运行相同的 Python 程序。在 Python 代码中,我们使用 jax.process_index() 来仅在一个进程上打印。以下是其打印的内容。
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=2, process_index=1, coords=(2,0,0), core_on_chip=0), TpuDevice(id=3, process_index=1, coords=(3,0,0), core_on_chip=0), TpuDevice(id=6, process_index=1, coords=(2,1,0), core_on_chip=0), TpuDevice(id=7, process_index=1, coords=(3,1,0), core_on_chip=0), TpuDevice(id=8, process_index=2, coords=(0,2,0), core_on_chip=0), TpuDevice(id=9, process_index=2, coords=(1,2,0), core_on_chip=0), TpuDevice(id=12, process_index=2, coords=(0,3,0), core_on_chip=0), TpuDevice(id=13, process_index=2, coords=(1,3,0), core_on_chip=0), TpuDevice(id=10, process_index=3, coords=(2,2,0), core_on_chip=0), TpuDevice(id=11, process_index=3, coords=(3,2,0), core_on_chip=0), TpuDevice(id=14, process_index=3, coords=(2,3,0), core_on_chip=0), TpuDevice(id=15, process_index=3, coords=(3,3,0), core_on_chip=0)]
太棒了,看这些 TPU 核心!
Kubernetes 示例#
在 Kubernetes 集群上运行多控制器 JAX 的思想与上述 GPU 和 TPU 示例几乎相同:每个 pod 运行相同的 Python 程序,JAX 会发现其对等节点,集群就像一台巨大的机器一样运行。
容器镜像 - 从启用 JAX 的镜像开始,例如 Google Artifact Registry 上的公共 JAX AI 镜像(TPU / GPU)或 NVIDIA(NGC / JAX-Toolbox)。
服务帐户 - JAX 需要权限来列出属于该作业的 pod,以便进程发现其对等节点。在 examples/k8s/svc-acct.yaml 中提供了一个最小的 RBAC 设置。
下面是一个 最小 JobSet,它启动了两个副本。将占位符(镜像、GPU 数量和任何私有注册表密钥)替换为与您的环境匹配的值。
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
name: jaxjob
spec:
replicatedJobs:
- name: workers
template:
spec:
parallelism: 2
completions: 2
backoffLimit: 0
template:
spec:
serviceAccountName: jax-job-sa # kubectl apply -f svc-acct.yaml
restartPolicy: Never
imagePullSecrets:
# https://k8s.io/docs/tasks/configure-pod-container/pull-image-private-registry/
- name: null
containers:
- name: main
image: null # e.g. ghcr.io/nvidia/jax:jax
imagePullPolicy: Always
resources:
limits:
cpu: 1
# https://k8s.io/docs/tasks/manage-gpus/scheduling-gpus/
nvidia.com/gpu: null
command:
- python
args:
- -c
- |
import jax
jax.distributed.initialize()
print(jax.devices())
print(jax.local_devices())
assert jax.process_count() > 1
assert len(jax.devices()) > len(jax.local_devices())
应用清单并观察 pod 完成。
$ kubectl apply -f example.yaml
$ kubectl get pods -l jobset.sigs.k8s.io/jobset-name=jaxjob
NAME READY STATUS RESTARTS AGE
jaxjob-workers-0-0-xpx8l 0/1 Completed 0 8m32s
jaxjob-workers-0-1-ddkq8 0/1 Completed 0 8m32s
作业完成后,检查日志以确认每个进程都看到了所有加速器。
$ kubectl logs -l jobset.sigs.k8s.io/jobset-name=jaxjob
[CudaDevice(id=0), CudaDevice(id=1)]
[CudaDevice(id=0)]
[CudaDevice(id=0), CudaDevice(id=1)]
[CudaDevice(id=1)]
每个 pod 都应该具有相同的全局设备集和不同的本地设备集。此时,您可以将内联脚本替换为您的实际 JAX 程序。
一旦进程设置好,我们就可以开始构建全局 jax.Array 并运行计算。本教程中其余的 Python 代码示例旨在同时在所有进程上运行,在运行 jax.distributed.initialize() 之后。
Mesh、分片和计算可以跨越进程和主机#
从 JAX 编程多个进程通常看起来就像编程单个进程一样,只是有更多的设备!主要的例外是关于进出 JAX 的数据,例如从外部数据源加载时。我们首先在这里介绍多进程计算的基础知识,这些基础知识在很大程度上与其单进程对应物看起来相同。下一节将介绍一些数据加载的基础知识,即如何从非 JAX 源创建 JAX 数组。
回想一下,jax.sharding.Mesh 将 jax.Device 数组与一组名称配对,每个数组轴对应一个名称。通过使用来自多个进程的设备创建 Mesh,然后在 jax.sharding.Sharding 中使用该 mesh,我们可以构建跨越多个进程的设备的 jax.Array。
这是一个直接使用 jax.devices() 获取所有进程设备来构造 Mesh 的示例。
from jax.sharding import Mesh
mesh = Mesh(jax.devices(), ('a',))
# in this case, the same as
mesh = jax.make_mesh((jax.device_count(),), ('a',)) # use this in practice
在实践中,您可能应该使用 jax.make_mesh() 辅助函数,这不仅因为它更简单,而且因为它还可以自动选择更高效的设备排序,但我们在这里将其拼写出来。默认情况下,它包含所有进程的设备,就像 jax.devices() 一样。
一旦我们有了 mesh,我们就可以对其进行分片数组。有几种有效构建跨进程数组的方法,将在下一节详细介绍,但现在我们为了简单起见,将坚持使用 jax.device_put。
arr = jax.device_put(jnp.ones((32, 32)), NamedSharding(mesh, P('a')))
if jax.process_index() == 0:
jax.debug.visualize_array_sharding(arr)
在进程 0 上,打印如下:
┌───────────────────────┐
│ TPU 0 │
├───────────────────────┤
│ TPU 1 │
├───────────────────────┤
│ TPU 4 │
├───────────────────────┤
│ TPU 5 │
├───────────────────────┤
│ TPU 2 │
├───────────────────────┤
│ TPU 3 │
├───────────────────────┤
│ TPU 6 │
├───────────────────────┤
│ TPU 7 │
├───────────────────────┤
│ TPU 8 │
├───────────────────────┤
│ TPU 9 │
├───────────────────────┤
│ TPU 12 │
├───────────────────────┤
│ TPU 13 │
├───────────────────────┤
│ TPU 10 │
├───────────────────────┤
│ TPU 11 │
├───────────────────────┤
│ TPU 14 │
├───────────────────────┤
│ TPU 15 │
└───────────────────────┘
让我们尝试一个更有趣的计算!
mesh = jax.make_mesh((jax.device_count() // 2, 2), ('a', 'b'))
def device_put(x, spec):
return jax.device_put(x, NamedSharding(mesh, spec))
# construct global arrays by sharding over the global mesh
x = device_put(jnp.ones((4096, 2048)), P('a', 'b'))
y = device_put(jnp.ones((2048, 4096)), P('b', None))
# run a distributed matmul
z = jax.nn.relu(x @ y)
# inspect the sharding of the result
if jax.process_index() == 0:
jax.debug.visualize_array_sharding(z)
print()
print(z.sharding)
在进程 0 上,打印如下:
┌───────────────────────┐
│ TPU 0,1 │
├───────────────────────┤
│ TPU 4,5 │
├───────────────────────┤
│ TPU 8,9 │
├───────────────────────┤
│ TPU 12,13 │
├───────────────────────┤
│ TPU 2,3 │
├───────────────────────┤
│ TPU 6,7 │
├───────────────────────┤
│ TPU 10,11 │
├───────────────────────┤
│ TPU 14,15 │
└───────────────────────┘
NamedSharding(mesh=Mesh('a': 8, 'b': 2), spec=PartitionSpec('a',), memory_kind=device)
在这里,仅仅通过在所有进程上评估 x @ y,XLA 就会自动生成并运行分布式矩阵乘法。结果被分片到 mesh 上,如 P('a', None),因为在这种情况下,matmul 包含了关于 'b' 轴的 psum。
警告
当对跨进程数组应用 JAX 计算时,为了避免死锁和挂起,所有具有参与设备的进程以相同的顺序运行相同的计算至关重要。这是因为计算可能涉及集体通信屏障。如果分片数组的设备没有参与集体通信,因为其控制器没有发出相同的计算,那么其他设备就会被挂起。例如,如果只有前三个进程评估了 x @ y,而最后一个进程评估了 y @ x,计算很可能会无限期挂起。这个假设,即在跨进程数组上的计算以相同的顺序在所有参与的进程上运行,基本上是没有经过检查的。
因此,在多进程 JAX 中避免死锁的最简单方法是在每个进程上运行相同的 Python 代码,并注意任何依赖于 jax.process_index() 并包含通信的控制流。
如果一个跨进程数组被分片到不同进程的设备上,那么对需要该数据在进程本地可用的数组执行操作(如打印)是错误的。例如,如果我们打印前面的示例中的 print(z),我们会看到:
RuntimeError: Fetching value for `jax.Array` that spans non-addressable (non process local) devices is not possible. You can use `jax.experimental.multihost_utils.process_allgather` to print the global array or use `.addressable_shards` method of jax.Array to inspect the addressable (process local) shards.
要打印完整的数组值,我们必须首先确保它已复制到所有进程(但不一定复制到每个进程的本地设备),例如使用 jax.device_put。在上面的示例中,我们可以在末尾写:
w = device_put(z, P(None, None))
if jax.process_index() == 0:
print(w)
请注意不要在 if process_index() == 0 下写入 jax.device_put(),因为这会导致死锁,因为只有进程 0 发起了集体通信并无限期等待其他进程。jax.experimental.multihost_utils 模块包含一些函数,可以更轻松地处理全局 jax.Array(例如,jax.experimental.multihost_utils.process_allgather())。
或者,要仅在进程本地数据上打印或执行其他 Python 操作,我们可以访问 z.addressable_shards。访问该属性不需要任何通信,因此任何子集进程都可以执行它而无需其他进程。该属性在 jax.jit() 下不可用。
从外部数据创建跨进程数组#
从外部数据源(例如数据加载器的 numpy 数组)创建跨进程 jax.Array 有三种主要方法:
在所有进程上创建或加载完整数组,然后使用
jax.device_put()分片到设备;在每个进程上创建或加载一个数组,该数组仅代表将在本地分片并存储在该进程设备上的数据,然后使用
jax.make_array_from_process_local_data()分片到设备;在每个进程的设备上创建或加载单独的数组,每个数组代表将在该设备上存储的数据,然后使用
jax.make_array_from_single_device_arrays()在没有任何数据移动的情况下将它们组装起来。
后两种方法在实践中最为常用,因为将全部全局数据物化到每个进程中通常成本过高。
上面的玩具示例使用了 jax.device_put()。
jax.make_array_from_process_local_data() 通常用于分布式数据加载。它不像 jax.make_array_from_single_device_arrays() 那样通用,因为它不直接指定进程本地数据的哪个切片会放在每个本地设备上。这在加载数据并行批次时很方便,因为哪个微批次放在哪个设备上并不重要。例如:
# target (micro)batch size across the whole cluster
batch_size = 1024
# how many examples each process should load per batch
per_process_batch_size = batch_size // jax.process_count()
# how many examples each device will process per batch
per_device_batch_size = batch_size // jax.device_count()
# make a data-parallel mesh and sharding
mesh = jax.make_mesh((jax.device_count(),), ('batch'))
sharding = NamedSharding(mesh, P('batch'))
# our "data loader". each process loads a different set of "examples".
process_batch = np.random.rand(per_process_batch_size, 2048, 42)
# assemble a global array containing the per-process batches from all processes
global_batch = jax.make_array_from_process_local_data(sharding, process_batch)
# sanity check that everything got sharded correctly
assert global_batch.shape[0] == batch_size
assert process_batch.shape[0] == per_process_batch_size
assert global_batch.addressable_shards[0].data.shape[0] == per_device_batch_size
jax.make_array_from_single_device_arrays() 是构建跨进程数组的最通用方法。它通常在执行 jax.device_put() 以将每个设备所需数据发送给它之后使用。这是最低级别的选项,因为所有数据移动都是手动执行的(例如通过 jax.device_put())。这是一个例子:
shape = (jax.process_count(), jax.local_device_count())
mesh = jax.make_mesh(shape, ('i', 'j'))
sharding = NamedSharding(mesh, P('i', 'j'))
# manually create per-device data equivalent to np.arange(jax.device_count())
# i.e. each device will get a single scalar value from 0..N
local_arrays = [
jax.device_put(
jnp.array([[jax.process_index() * jax.local_device_count() + i]]),
device)
for i, device in enumerate(jax.local_devices())
]
# assemble a global array from the local_arrays across all processes
global_array = jax.make_array_from_single_device_arrays(
shape=shape,
sharding=sharding,
arrays=local_arrays)
# sanity check
assert (np.all(
jax.experimental.multihost_utils.process_allgather(global_array) ==
np.arange(jax.device_count()).reshape(global_array.shape)))