并行编程入门#

本教程旨在介绍 JAX 中单程序多数据(SPMD)代码的设备并行性。SPMD 是一种并行技术,其中相同的计算(例如神经网络的前向传播)可以在不同设备(例如多个 GPU 或 Google TPU)上并行运行,处理不同的输入数据(例如批次中的不同输入)。

本教程涵盖三种并行计算模式:

  • 通过 jax.jit() 实现自动分片:编译器选择最优计算策略(也称为“编译器掌舵”)。

  • 显式分片 (*新*) 与自动分片类似,因为您编写的是全局视图程序。区别在于,每个数组的分片是数组的 JAX 级别类型的一部分,使其成为编程模型中的显式部分。这些分片在 JAX 级别传播,并在跟踪时可查询。编译器仍负责将整个数组程序转换为每个设备的程序(例如,将 jnp.sum 转换为 psum),但编译器受到用户提供的分片的严格约束。

  • 使用 jax.shard_map() 完全手动分片和手动控制shard_map 支持每个设备的代码和显式通信集合。

总结表

模式

视图?

显式分片?

显式集合?

自动

全局

显式

全局

手动

每个设备

利用这些 SPMD 的思想,您可以将为单个设备编写的函数转换为可在多个设备上并行运行的函数。

import jax

jax.config.update('jax_num_cpu_devices', 8)
jax.devices()
[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7)]

核心概念:数据分片#

下面所有分布式计算方法的核心是 *数据分片* 概念,它描述了数据如何在可用设备上布局。

JAX 如何理解数据在设备之间的布局?JAX 的数据类型,即 jax.Array 不可变数组数据结构,表示跨越一个或多个设备的物理存储的数组,并有助于使并行性成为 JAX 的核心功能。jax.Array 对象在设计时就考虑了分布式数据和计算。每个 jax.Array 都有一个关联的 jax.sharding.Sharding 对象,该对象描述了每个全局设备所需的全局数据的哪个分片。当您从头开始创建 jax.Array 时,还需要创建其 Sharding

在最简单的情况下,数组在单个设备上分片,如下所示:

import numpy as np
import jax.numpy as jnp

arr = jnp.arange(32.0).reshape(4, 8)
arr.devices()
{CpuDevice(id=0)}
arr.sharding
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)

为了更直观地表示存储布局,jax.debug 模块提供了一些辅助工具来可视化数组的分片。例如,jax.debug.visualize_array_sharding() 显示数组如何在单个设备的内存中存储。

jax.debug.visualize_array_sharding(arr)
                                                  
                                                  
                                                  
                                                  
                                                  
                      CPU 0                       
                                                  
                                                  
                                                  
                                                  
                                                  

要创建具有非平凡分片的数组,您可以为数组定义一个 jax.sharding 规范,并将其传递给 jax.device_put()

在此,我们定义了一个 NamedSharding,它指定了一个具有命名轴的 N 维设备网格,其中 jax.sharding.Mesh 允许精确的设备放置。

from jax.sharding import PartitionSpec as P

mesh = jax.make_mesh((2, 4), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
print(sharding)
NamedSharding(mesh=Mesh('x': 2, 'y': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('x', 'y'), memory_kind=device)

将此 Sharding 对象传递给 jax.device_put(),您可以获得一个分片数组。

arr_sharded = jax.device_put(arr, sharding)

print(arr_sharded)
jax.debug.visualize_array_sharding(arr_sharded)
[[ 0.  1.  2.  3.  4.  5.  6.  7.]
 [ 8.  9. 10. 11. 12. 13. 14. 15.]
 [16. 17. 18. 19. 20. 21. 22. 23.]
 [24. 25. 26. 27. 28. 29. 30. 31.]]
                                                
                                                
   CPU 0       CPU 1       CPU 2       CPU 3    
                                                
                                                
                                                
                                                
                                                
   CPU 4       CPU 5       CPU 6       CPU 7    
                                                
                                                
                                                

1. 通过 jit 实现自动并行#

一旦有了分片数据,进行并行计算的最简单方法就是将数据传递给经过 jax.jit() 编译的函数!在 JAX 中,您只需指定代码的输入和输出的分区方式,编译器就会找出:1) 如何分区内部的所有内容;以及 2) 如何编译设备间的通信。

jit 背后的 XLA 编译器包含跨多个设备优化计算的启发式方法。在最简单的情况下,这些启发式方法归结为*计算跟随数据*。

为了演示 JAX 中的自动并行化是如何工作的,下面是一个使用 jax.jit() 装饰的阶段化函数的示例:这是一个简单的逐元素函数,其中每个分片的计算将在与该分片关联的设备上执行,输出的分片方式也相同。

@jax.jit
def f_elementwise(x):
  return 2 * jnp.sin(x) + 1

result = f_elementwise(arr_sharded)

print("shardings match:", result.sharding == arr_sharded.sharding)
shardings match: True

随着计算变得更加复杂,编译器会决定如何最好地传播数据的分片。

在这里,您沿着 x 的前导轴进行求和,并可视化结果值如何在多个设备之间存储(使用 jax.debug.visualize_array_sharding())。

@jax.jit
def f_contract(x):
  return x.sum(axis=0)

result = f_contract(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)
 CPU 0,4  CPU 1,5  CPU 2,6  CPU 3,7 
                                    
[48. 52. 56. 60. 64. 68. 72. 76.]

结果是部分复制的:也就是说,数组的前两个元素复制在设备 04 上,第二个元素复制在 15 上,依此类推。

2. 显式分片#

显式分片(也称为“类型内分片”)的主要思想是,JAX 级别的*类型*包含关于值如何分片的信息。我们可以使用 jax.typeof 查询任何 JAX 值(或 Numpy 数组、Python 标量)的 JAX 级别类型。

some_array = np.arange(8)
print(f"JAX-level type of some_array: {jax.typeof(some_array)}")
JAX-level type of some_array: int32[8]

重要的是,我们可以在 jit 下跟踪时查询类型(JAX 级别类型几乎*定义*为“在 jit 下我们可以访问的值的信息”)。

@jax.jit
def foo(x):
  print(f"JAX-level type of x during tracing: {jax.typeof(x)}")
  return x + x

foo(some_array)
JAX-level type of x during tracing: int32[8]
Array([ 0,  2,  4,  6,  8, 10, 12, 14], dtype=int32)

为了开始在类型中看到分片,我们需要设置一个显式分片网格。

from jax.sharding import AxisType

mesh = jax.make_mesh((2, 4), ("X", "Y"),
                     axis_types=(AxisType.Explicit, AxisType.Explicit))

现在我们可以创建一些分片数组。

replicated_array = np.arange(8).reshape(4, 2)
sharded_array = jax.device_put(replicated_array, jax.NamedSharding(mesh, P("X", None)))

print(f"replicated_array type: {jax.typeof(replicated_array)}")
print(f"sharded_array type: {jax.typeof(sharded_array)}")
replicated_array type: int32[4,2]
sharded_array type: int32[4@X,2]

我们将类型 int32[4@X, 2] 读取为“一个 4x2 的 32 位整数数组,其第一维沿网格轴 'X' 分片。数组在所有其他网格轴上复制。”

与 JAX 级别类型相关联的分片会通过操作进行传播。例如:

arg0 = jax.device_put(np.arange(4).reshape(4, 1),
                      jax.NamedSharding(mesh, P("X", None)))
arg1 = jax.device_put(np.arange(8).reshape(1, 8),
                      jax.NamedSharding(mesh, P(None, "Y")))

@jax.jit
def add_arrays(x, y):
  ans = x + y
  print(f"x sharding: {jax.typeof(x)}")
  print(f"y sharding: {jax.typeof(y)}")
  print(f"ans sharding: {jax.typeof(ans)}")
  return ans

with jax.set_mesh(mesh):
  add_arrays(arg0, arg1)
x sharding: int32[4@X,1]
y sharding: int32[1,8@Y]
ans sharding: int32[4@X,8@Y]

这就是它的要点。分片在跟踪时确定性地传播,并且我们可以在跟踪时查询它们。

3. 使用 shard_map 手动并行#

在上面探索的自动并行方法中,您可以像操作整个数据集一样编写函数,jit 会将该计算分配到多个设备上。相比之下,使用 jax.shard_map(),您编写将处理单个数据分片的函数,而 shard_map 将构造完整的函数。

shard_map 的工作原理是跨特定的设备*网格*映射一个函数(shard_map 映射分片)。在下面的示例中:

  • 与之前一样,jax.sharding.Mesh 允许精确的设备放置,其中 axis names 参数用于逻辑和物理轴名称。

  • in_specs 参数决定了分片的大小。out_specs 参数标识了块如何重新组合。

注意: 如果需要,jax.shard_map() 代码可以在 jax.jit() 内部工作。

mesh = jax.make_mesh((8,), ('x',))

f_elementwise_sharded = jax.shard_map(
    f_elementwise,
    mesh=mesh,
    in_specs=P('x'),
    out_specs=P('x'))

arr = jnp.arange(32)
f_elementwise_sharded(arr)
Array([ 1.        ,  2.682942  ,  2.818595  ,  1.28224   , -0.513605  ,
       -0.9178486 ,  0.44116902,  2.3139732 ,  2.9787164 ,  1.824237  ,
       -0.08804226, -0.99998045, -0.07314587,  1.840334  ,  2.9812148 ,
        2.3005757 ,  0.42419338, -0.92279494, -0.50197446,  1.2997544 ,
        2.8258905 ,  2.6733112 ,  0.98229736, -0.69244087, -0.81115675,
        0.7352965 ,  2.525117  ,  2.912752  ,  1.5418116 , -0.32726777,
       -0.97606325,  0.19192469], dtype=float32)

您编写的函数只“看到”数据的一个分片,您可以通过打印设备局部形状来检查这一点。

x = jnp.arange(32)
print(f"global shape: {x.shape=}")

def f(x):
  print(f"device local shape: {x.shape=}")
  return x * 2

y = jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
global shape: x.shape=(32,)
device local shape: x.shape=(4,)

因为您的每个函数只“看到”数据中设备局部的部分,这意味着像聚合这样的函数需要额外考虑。

例如,这里是 shard_mapjax.numpy.sum() 的样子:

def f(x):
  return jnp.sum(x, keepdims=True)

jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
Array([  6,  22,  38,  54,  70,  86, 102, 118], dtype=int32)

您的函数 f 在每个分片上独立操作,最终的求和反映了这一点。

如果您想跨分片求和,您需要使用 jax.lax.psum() 等集合操作显式请求它。

def f(x):
  sum_in_shard = x.sum()
  return jax.lax.psum(sum_in_shard, 'x')

jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x)
Array(496, dtype=int32)

由于输出不再具有分片维度,因此将 out_specs=P()(回想一下 out_specs 参数标识了在 shard_map 中块如何重新组合)。

三种方法的比较#

带着这些概念,让我们通过一个简单的神经网络层来比较这三种方法。

首先,像这样定义您的规范函数:

@jax.jit
def layer(x, weights, bias):
  return jax.nn.sigmoid(x @ weights + bias)
import numpy as np
rng = np.random.default_rng(0)

x = rng.normal(size=(32,))
weights = rng.normal(size=(32, 4))
bias = rng.normal(size=(4,))

layer(x, weights, bias)
Array([0.02138916, 0.8931118 , 0.5989196 , 0.9774251 ], dtype=float32)

您可以使用 jax.jit() 并传递适当分片的数据,以自动以分布式方式运行此函数。

如果您对 x 的前导轴进行分片,并使 weights 完全复制,那么矩阵乘法将自动并行进行。

mesh = jax.make_mesh((8,), ('x',))
x_sharded = jax.device_put(x, jax.NamedSharding(mesh, P('x')))
weights_sharded = jax.device_put(weights, jax.NamedSharding(mesh, P()))

layer(x_sharded, weights_sharded, bias)
Array([0.02138916, 0.8931118 , 0.5989196 , 0.9774251 ], dtype=float32)

或者,您也可以使用显式分片模式:

explicit_mesh = jax.make_mesh((8,), ('X',), axis_types=(AxisType.Explicit,))

x_sharded = jax.device_put(x, jax.NamedSharding(explicit_mesh, P('X')))
weights_sharded = jax.device_put(weights, jax.NamedSharding(explicit_mesh, P()))

@jax.jit
def layer_auto(x, weights, bias):
  print(f"x sharding: {jax.typeof(x)}")
  print(f"weights sharding: {jax.typeof(weights)}")
  print(f"bias sharding: {jax.typeof(bias)}")
  out = layer(x, weights, bias)
  print(f"out sharding: {jax.typeof(out)}")
  return out

with jax.set_mesh(explicit_mesh):
  layer_auto(x_sharded, weights_sharded, bias)
x sharding: float32[32@X]
weights sharding: float32[32,4]
bias sharding: float32[4]
out sharding: float32[4]

最后,您也可以使用 shard_map 来完成同样的事情,使用 jax.lax.psum() 来指示矩阵乘积所需的跨分片集合。

from functools import partial

@jax.jit
@partial(jax.shard_map, mesh=mesh,
         in_specs=(P('x'), P('x', None), P(None)),
         out_specs=P(None))
def layer_sharded(x, weights, bias):
  return jax.nn.sigmoid(jax.lax.psum(x @ weights, 'x') + bias)

layer_sharded(x, weights, bias)
Array([0.02138916, 0.8931118 , 0.5989196 , 0.9774251 ], dtype=float32)

下一步#

本教程简要介绍了 JAX 中的分片和并行计算。

要深入了解每种 SPMD 方法,请参阅以下文档: