并行编程简介#

本教程旨在介绍 JAX 中单程序多数据 (SPMD) 代码的设备并行性。SPMD 是一种并行技术,其中相同的计算(例如神经网络的前向传播)可以在不同的输入数据(例如,批次中的不同输入)上并行运行,并在不同的设备(例如多个 GPU 或 Google TPU)上执行。

本教程涵盖了三种并行计算模式

  • 通过 jax.jit() 进行自动分片:编译器选择最优的计算策略(又称“编译器接管”)。

  • 显式分片(*新特性*)与自动分片类似,都是编写全局视图程序。不同之处在于,每个数组的分片是数组 JAX 级别类型的一部分,使其成为编程模型的显式组成部分。这些分片在 JAX 级别传播,并且可以在追踪时查询。编译器仍然负责将整个数组程序转换为每个设备的程序(例如将 jnp.sum 转换为 psum),但编译器受到用户提供的分片的严格约束。

  • 使用 jax.shard_map() 进行完全手动分片和手动控制shard_map 支持每个设备的程序和显式通信集合操作。

总结表

模式

视图?

显式分片?

显式集合操作?

自动

全局

显式

全局

手动

按设备

利用这些 SPMD 思路,你可以将为一个设备编写的函数转换为可以在多个设备上并行运行的函数。

import jax

jax.config.update('jax_num_cpu_devices', 8)
jax.devices()
[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7)]

关键概念:数据分片#

下面所有分布式计算方法的关键是数据分片的概念,它描述了数据如何在可用设备上布局。

JAX 如何理解数据在设备上的布局?JAX 的数据类型 jax.Array 不可变数组数据结构表示具有跨一个或多个设备的物理存储的数组,并有助于使并行成为 JAX 的核心特性。jax.Array 对象在设计时就考虑了分布式数据和计算。每个 jax.Array 都关联一个 jax.sharding.Sharding 对象,该对象描述了每个全局设备需要全局数据的哪个分片。从头开始创建 jax.Array 时,还需要创建其 Sharding

在最简单的情况下,数组在单个设备上分片,如下所示

import numpy as np
import jax.numpy as jnp

arr = jnp.arange(32.0).reshape(4, 8)
arr.devices()
{CpuDevice(id=0)}
arr.sharding
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)

为了更直观地表示存储布局,jax.debug 模块提供了一些帮助器来可视化数组的分片。例如,jax.debug.visualize_array_sharding() 显示了数组如何存储在单个设备的内存中

jax.debug.visualize_array_sharding(arr)
                                                  
                                                  
                                                  
                                                  
                                                  
                      CPU 0                       
                                                  
                                                  
                                                  
                                                  
                                                  

要创建具有非平凡分片的数组,你可以为数组定义一个 jax.sharding 规范,并将其传递给 jax.device_put()

在这里,定义一个 NamedSharding,它指定了一个具有命名轴的 N 维设备网格,其中 jax.sharding.Mesh 允许精确的设备放置。

from jax.sharding import PartitionSpec as P

mesh = jax.make_mesh((2, 4), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
print(sharding)
NamedSharding(mesh=Mesh('x': 2, 'y': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('x', 'y'), memory_kind=unpinned_host)

将此 Sharding 对象传递给 jax.device_put(),即可获得一个分片数组。

arr_sharded = jax.device_put(arr, sharding)

print(arr_sharded)
jax.debug.visualize_array_sharding(arr_sharded)
[[ 0.  1.  2.  3.  4.  5.  6.  7.]
 [ 8.  9. 10. 11. 12. 13. 14. 15.]
 [16. 17. 18. 19. 20. 21. 22. 23.]
 [24. 25. 26. 27. 28. 29. 30. 31.]]
                                                
                                                
   CPU 0       CPU 1       CPU 2       CPU 3    
                                                
                                                
                                                
                                                
                                                
   CPU 4       CPU 5       CPU 6       CPU 7    
                                                
                                                
                                                

1. 通过 jit 实现自动并行#

一旦你有了分片数据,进行并行计算最简单的方法就是将数据传递给一个由 jax.jit() 编译的函数!在 JAX 中,你只需要指定代码的输入和输出如何分区,编译器就会搞清楚如何:1) 分区内部的所有内容;以及 2) 编译设备间通信。

jit 背后的 XLA 编译器包含用于优化跨多个设备的计算的启发式方法。在最简单的情况下,这些启发式方法归结为计算跟随数据

为了演示 JAX 中自动并行化的工作原理,下面是一个使用 jax.jit() 修饰的 staged-out 函数的示例:这是一个简单的元素级函数,其中每个分片的计算将在与该分片关联的设备上执行,并且输出以相同的方式分片。

@jax.jit
def f_elementwise(x):
  return 2 * jnp.sin(x) + 1

result = f_elementwise(arr_sharded)

print("shardings match:", result.sharding == arr_sharded.sharding)
shardings match: True

随着计算变得更加复杂,编译器会决定如何最好地传播数据的分片。

在这里,你沿着 x 的主轴求和,并可视化结果值如何存储在多个设备上(使用 jax.debug.visualize_array_sharding())。

@jax.jit
def f_contract(x):
  return x.sum(axis=0)

result = f_contract(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)
 CPU 0,4  CPU 1,5  CPU 2,6  CPU 3,7 
                                    
[48. 52. 56. 60. 64. 68. 72. 76.]

结果是部分复制的:也就是说,数组的前两个元素在设备 04 上复制,第二个在 15 上复制,依此类推。

2. 显式分片#

显式分片(又称“类型内分片”)的主要思想是,值的 JAX 级别类型包含对该值如何分片的描述。我们可以使用 jax.typeof 查询任何 JAX 值(或 Numpy 数组,或 Python 标量)的 JAX 级别类型。

some_array = np.arange(8)
print(f"JAX-level type of some_array: {jax.typeof(some_array)}")
JAX-level type of some_array: int32[8]

重要的是,即使在 jit 下进行追踪时,我们也可以查询类型(JAX 级别类型几乎可以定义为“我们在 jit 下可以访问的有关值的信息”)。

@jax.jit
def foo(x):
  print(f"JAX-level type of x during tracing: {jax.typeof(x)}")
  return x + x

foo(some_array)
JAX-level type of x during tracing: int32[8]
Array([ 0,  2,  4,  6,  8, 10, 12, 14], dtype=int32)

为了开始在类型中看到分片,我们需要设置一个显式分片网格。

from jax.sharding import AxisType

mesh = jax.make_mesh((2, 4), ("X", "Y"),
                     axis_types=(AxisType.Explicit, AxisType.Explicit))

现在我们可以创建一些分片数组了

replicated_array = np.arange(8).reshape(4, 2)
sharded_array = jax.device_put(replicated_array, jax.NamedSharding(mesh, P("X", None)))

print(f"replicated_array type: {jax.typeof(replicated_array)}")
print(f"sharded_array type: {jax.typeof(sharded_array)}")
replicated_array type: int32[4,2]
sharded_array type: int32[4@X,2]

我们应该将类型 int32[4@X, 2] 理解为“一个 4x2 的 32 位整数数组,其第一维沿网格轴 'X' 分片。该数组沿所有其他网格轴复制”。

这些与 JAX 级别类型关联的分片会通过操作传播。例如

arg0 = jax.device_put(np.arange(4).reshape(4, 1),
                      jax.NamedSharding(mesh, P("X", None)))
arg1 = jax.device_put(np.arange(8).reshape(1, 8),
                      jax.NamedSharding(mesh, P(None, "Y")))

@jax.jit
def add_arrays(x, y):
  ans = x + y
  print(f"x sharding: {jax.typeof(x)}")
  print(f"y sharding: {jax.typeof(y)}")
  print(f"ans sharding: {jax.typeof(ans)}")
  return ans

with jax.sharding.use_mesh(mesh):
  add_arrays(arg0, arg1)
x sharding: int32[4@X,1]
y sharding: int32[1,8@Y]
ans sharding: int32[4@X,8@Y]

这就是其要旨。分片在追踪时确定性地传播,我们可以在追踪时查询它们。

3. 使用 shard_map 进行手动并行#

在上面探讨的自动并行方法中,你可以编写一个函数,就好像它操作整个数据集一样,然后 jit 会将该计算拆分到多个设备上。相比之下,使用 jax.shard_map() 时,你编写的函数将处理单个数据分片,而 shard_map 将构建完整的函数。

shard_map 通过在特定设备网格上映射函数来工作(shard_map 映射分片)。在下面的示例中

  • 与之前一样,jax.sharding.Mesh 允许精确的设备放置,并带有用于逻辑和物理轴名称的轴名称参数。

  • 参数 in_specs 决定分片大小。参数 out_specs 标识如何将这些块重新组合在一起。

注意:如果需要,jax.shard_map() 代码可以在 jax.jit() 内部工作。

mesh = jax.make_mesh((8,), ('x',))

f_elementwise_sharded = jax.shard_map(
    f_elementwise,
    mesh=mesh,
    in_specs=P('x'),
    out_specs=P('x'))

arr = jnp.arange(32)
f_elementwise_sharded(arr)
Array([ 1.        ,  2.682942  ,  2.818595  ,  1.28224   , -0.513605  ,
       -0.9178486 ,  0.44116902,  2.3139732 ,  2.9787164 ,  1.824237  ,
       -0.08804226, -0.99998045, -0.07314587,  1.840334  ,  2.9812148 ,
        2.3005757 ,  0.42419338, -0.92279494, -0.50197446,  1.2997544 ,
        2.8258905 ,  2.6733112 ,  0.98229736, -0.69244087, -0.81115675,
        0.7352965 ,  2.525117  ,  2.912752  ,  1.5418116 , -0.32726777,
       -0.97606325,  0.19192469], dtype=float32)

你编写的函数只“看到”单个批次的数据,你可以通过打印设备局部形状来验证。

x = jnp.arange(32)
print(f"global shape: {x.shape=}")

def f(x):
  print(f"device local shape: {x.shape=}")
  return x * 2

y = jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
global shape: x.shape=(32,)
device local shape: x.shape=(4,)

因为你的每个函数只“看到”数据的设备局部部分,这意味着聚合类函数需要一些额外的思考。

例如,以下是 jax.numpy.sum()shard_map 示例

def f(x):
  return jnp.sum(x, keepdims=True)

jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
Array([  6,  22,  38,  54,  70,  86, 102, 118], dtype=int32)

你的函数 f 在每个分片上独立操作,结果求和反映了这一点。

如果你想跨分片求和,你需要使用集合操作(例如 jax.lax.psum())显式请求。

def f(x):
  sum_in_shard = x.sum()
  return jax.lax.psum(sum_in_shard, 'x')

jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x)
Array(496, dtype=int32)

由于输出不再具有分片维度,请设置 out_specs=P()(回想一下,out_specs 参数标识了在 shard_map 中如何将这些块重新组合在一起)。

比较这三种方法#

牢记这些概念,让我们比较一下简单神经网络层的三种方法。

首先定义你的规范函数,如下所示

@jax.jit
def layer(x, weights, bias):
  return jax.nn.sigmoid(x @ weights + bias)
import numpy as np
rng = np.random.default_rng(0)

x = rng.normal(size=(32,))
weights = rng.normal(size=(32, 4))
bias = rng.normal(size=(4,))

layer(x, weights, bias)
Array([0.02138916, 0.8931118 , 0.5989196 , 0.9774251 ], dtype=float32)

你可以使用 jax.jit() 并传递适当分片的数据,以分布式方式自动运行此函数。

如果你对 x 的主轴进行分片,并使 weights 完全复制,那么矩阵乘法将自动并行发生。

mesh = jax.make_mesh((8,), ('x',))
x_sharded = jax.device_put(x, jax.NamedSharding(mesh, P('x')))
weights_sharded = jax.device_put(weights, jax.NamedSharding(mesh, P()))

layer(x_sharded, weights_sharded, bias)
Array([0.02138916, 0.8931118 , 0.5989196 , 0.9774251 ], dtype=float32)

或者,你也可以使用显式分片模式

explicit_mesh = jax.make_mesh((8,), ('X',), axis_types=(AxisType.Explicit,))

x_sharded = jax.device_put(x, jax.NamedSharding(explicit_mesh, P('X')))
weights_sharded = jax.device_put(weights, jax.NamedSharding(explicit_mesh, P()))

@jax.jit
def layer_auto(x, weights, bias):
  print(f"x sharding: {jax.typeof(x)}")
  print(f"weights sharding: {jax.typeof(weights)}")
  print(f"bias sharding: {jax.typeof(bias)}")
  out = layer(x, weights, bias)
  print(f"out sharding: {jax.typeof(out)}")
  return out

with jax.sharding.use_mesh(explicit_mesh):
  layer_auto(x_sharded, weights_sharded, bias)
x sharding: float32[32@X]
weights sharding: float32[32,4]
bias sharding: float32[4]
out sharding: float32[4]

最后,你可以使用 shard_map 完成相同的事情,使用 jax.lax.psum() 来指示矩阵乘法所需的跨分片集合操作。

from functools import partial

@jax.jit
@partial(jax.shard_map, mesh=mesh,
         in_specs=(P('x'), P('x', None), P(None)),
         out_specs=P(None))
def layer_sharded(x, weights, bias):
  return jax.nn.sigmoid(jax.lax.psum(x @ weights, 'x') + bias)

layer_sharded(x, weights, bias)
Array([0.02138916, 0.8931118 , 0.5989196 , 0.9774251 ], dtype=float32)

下一步#

本教程简要介绍了 JAX 中的分片和并行计算。

要深入了解每种 SPMD 方法,请查阅以下文档