并行编程简介#

本教程介绍 JAX 中单程序多数据 (SPMD) 代码的设备并行性。SPMD 是一种并行技术,其中相同的计算(例如神经网络的前向传递)可以在不同设备(例如多个 GPU 或 Google TPU)上并行地在不同的输入数据(例如,批次中的不同输入)上运行。

本教程涵盖三种并行计算模式

  • 通过 jax.jit() 进行自动分片:编译器选择最佳计算策略(又名“编译器掌握方向盘”)。

  • 显式分片 (*new*) 类似于自动分片,因为您正在编写全局视图程序。不同之处在于,每个数组的分片是数组 JAX 级别类型的一部分,使其成为编程模型的显式部分。这些分片在 JAX 级别传播,并且在跟踪时可查询。将整个数组程序转换为每设备程序(例如,将 jnp.sum 转换为 psum)仍然是编译器的责任,但编译器受到用户提供的分片的严格约束。

  • 使用 jax.experimental.shard_map.shard_map() 进行完全手动分片和手动控制shard_map 启用每设备代码和显式通信集合

摘要表

模式

显式分片?

显式集合?

自动

显式(新)

手动

使用这些 SPMD 思维模式,您可以将为一台设备编写的函数转换为可以在多台设备上并行运行的函数。

import jax

jax.config.update('jax_num_cpu_devices', 8)
jax.devices()
[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7)]

关键概念:数据分片#

以下所有分布式计算方法的关键是数据分片的概念,它描述了数据如何在可用设备上布局。

JAX 如何理解数据如何在设备之间布局?JAX 的数据类型 jax.Array 不可变数组数据结构表示跨越一个或多个设备的物理存储的数组,并有助于使并行性成为 JAX 的核心功能。jax.Array 对象在设计时考虑了分布式数据和计算。每个 jax.Array 都有一个关联的 jax.sharding.Sharding 对象,它描述了每个全局设备所需的全局数据的哪个分片。当您从头开始创建 jax.Array 时,您还需要创建其 Sharding

在最简单的情况下,数组在单个设备上分片,如下所示

import numpy as np
import jax.numpy as jnp

arr = jnp.arange(32.0).reshape(4, 8)
arr.devices()
{CpuDevice(id=0)}
arr.sharding
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)

为了更直观地表示存储布局,jax.debug 模块提供了一些帮助程序来可视化数组的分片。例如,jax.debug.visualize_array_sharding() 显示数组如何在单个设备的内存中存储

jax.debug.visualize_array_sharding(arr)
                                                  
                                                  
                                                  
                                                  
                                                  
                      CPU 0                       
                                                  
                                                  
                                                  
                                                  
                                                  

要创建具有非平凡分片的数组,您可以为数组定义 jax.sharding 规范,并将其传递给 jax.device_put()

在这里,定义一个 NamedSharding,它指定具有命名轴的 N 维设备网格,其中 jax.sharding.Mesh 允许精确的设备放置

from jax.sharding import PartitionSpec as P

mesh = jax.make_mesh((2, 4), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
print(sharding)
NamedSharding(mesh=Mesh('x': 2, 'y': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('x', 'y'), memory_kind=unpinned_host)

将此 Sharding 对象传递给 jax.device_put(),您可以获得一个分片数组

arr_sharded = jax.device_put(arr, sharding)

print(arr_sharded)
jax.debug.visualize_array_sharding(arr_sharded)
[[ 0.  1.  2.  3.  4.  5.  6.  7.]
 [ 8.  9. 10. 11. 12. 13. 14. 15.]
 [16. 17. 18. 19. 20. 21. 22. 23.]
 [24. 25. 26. 27. 28. 29. 30. 31.]]
                                                
                                                
   CPU 0       CPU 1       CPU 2       CPU 3    
                                                
                                                
                                                
                                                
                                                
   CPU 4       CPU 5       CPU 6       CPU 7    
                                                
                                                
                                                

1. 通过 jit 自动并行#

一旦您拥有分片数据,进行并行计算的最简单方法是简单地将数据传递给 jax.jit() 编译的函数!在 JAX 中,您只需要指定您希望如何对代码的输入和输出进行分区,编译器将确定如何:1) 对内部的所有内容进行分区;以及 2) 编译设备间通信。

jit 背后的 XLA 编译器包含用于优化跨多个设备的计算的启发式方法。在最简单的情况下,这些启发式方法归结为计算跟随数据

为了演示自动并行化在 JAX 中如何工作,下面是一个示例,该示例使用 jax.jit() 装饰的 staged-out 函数:它是一个简单的逐元素函数,其中每个分片的计算将在与该分片关联的设备上执行,并且输出以相同的方式分片

@jax.jit
def f_elementwise(x):
  return 2 * jnp.sin(x) + 1

result = f_elementwise(arr_sharded)

print("shardings match:", result.sharding == arr_sharded.sharding)
shardings match: True

随着计算变得更加复杂,编译器会决定如何最好地传播数据的分片。

在这里,您沿着 x 的前导轴求和,并可视化结果值如何在多个设备上存储(使用 jax.debug.visualize_array_sharding()

@jax.jit
def f_contract(x):
  return x.sum(axis=0)

result = f_contract(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)
 CPU 0,4  CPU 1,5  CPU 2,6  CPU 3,7 
                                    
[48. 52. 56. 60. 64. 68. 72. 76.]

结果是部分复制的:也就是说,数组的前两个元素在设备 06 上复制,第二个元素在 17 上复制,依此类推。

2. 显式分片#

显式分片(又名类型内分片)背后的主要思想是,JAX 级别的值的类型包括如何分片值的描述。我们可以使用 jax.typeof 查询任何 JAX 值(或 Numpy 数组或 Python 标量)的 JAX 级别类型

some_array = np.arange(8)
print(f"JAX-level type of some_array: {jax.typeof(some_array)}")
JAX-level type of some_array: ShapedArray(int32[8])

重要的是,即使在 jit 下跟踪时,我们也可以查询类型(JAX 级别类型几乎定义为“我们在 jit 下可以访问的关于值的信息)。

@jax.jit
def foo(x):
  print(f"JAX-level type of x during tracing: {jax.typeof(x)}")
  return x + x

foo(some_array)
JAX-level type of x during tracing: ShapedArray(int32[8])
Array([ 0,  2,  4,  6,  8, 10, 12, 14], dtype=int32)

要开始在类型中看到分片,我们需要设置一个显式分片网格。

from jax.sharding import AxisType

mesh = jax.make_mesh((2, 4), ("X", "Y"),
                     axis_types=(AxisType.Explicit, AxisType.Explicit))

现在我们可以创建一些分片数组

replicated_array = np.arange(8).reshape(4, 2)
sharded_array = jax.device_put(replicated_array, jax.NamedSharding(mesh, P("X", None)))

print(f"replicated_array type: {jax.typeof(replicated_array)}")
print(f"sharded_array type: {jax.typeof(sharded_array)}")
replicated_array type: ShapedArray(int32[4,2])
sharded_array type: ShapedArray(int32[4@X,2])

我们应该将类型 f32[4@X, 2] 读取为“一个 4x2 的 32 位浮点数数组,其第一个维度沿网格轴 ‘X’ 分片。该数组沿所有其他网格轴复制”

与 JAX 级别类型关联的这些分片通过操作传播。例如

arg0 = jax.device_put(np.arange(4).reshape(4, 1),
                      jax.NamedSharding(mesh, P("X", None)))
arg1 = jax.device_put(np.arange(8).reshape(1, 8),
                      jax.NamedSharding(mesh, P(None, "Y")))

@jax.jit
def add_arrays(x, y):
  ans = x + y
  print(f"x sharding: {jax.typeof(x)}")
  print(f"y sharding: {jax.typeof(y)}")
  print(f"ans sharding: {jax.typeof(ans)}")
  return ans

with jax.sharding.use_mesh(mesh):
  add_arrays(arg0, arg1)
x sharding: ShapedArray(int32[4@X,1])
y sharding: ShapedArray(int32[1,8@Y])
ans sharding: ShapedArray(int32[4@X,8@Y])

这就是它的要点。分片在跟踪时确定性地传播,我们可以在跟踪时查询它们。

3. 使用 shard_map 的手动并行#

在上面探索的自动并行化方法中,您可以编写一个函数,就好像您正在对完整数据集进行操作一样,并且 jit 会将该计算拆分到多个设备上。相比之下,使用 jax.experimental.shard_map.shard_map() 您编写将处理单个数据分片的函数,并且 shard_map 将构造完整函数。

shard_map 的工作原理是将函数映射到特定的设备网格上(shard_map 映射到分片上)。在下面的示例中

  • 和以前一样,jax.sharding.Mesh 允许精确的设备放置,具有用于逻辑和物理轴名称的轴名称参数。

  • in_specs 参数确定分片大小。out_specs 参数标识如何将块组装在一起。

注意: 如果您需要,jax.experimental.shard_map.shard_map() 代码可以在 jax.jit() 内部工作。

from jax.experimental.shard_map import shard_map
mesh = jax.make_mesh((8,), ('x',))

f_elementwise_sharded = shard_map(
    f_elementwise,
    mesh=mesh,
    in_specs=P('x'),
    out_specs=P('x'))

arr = jnp.arange(32)
f_elementwise_sharded(arr)
Array([ 1.        ,  2.682942  ,  2.818595  ,  1.28224   , -0.513605  ,
       -0.9178486 ,  0.44116902,  2.3139732 ,  2.9787164 ,  1.824237  ,
       -0.08804226, -0.99998045, -0.07314587,  1.840334  ,  2.9812148 ,
        2.3005757 ,  0.42419338, -0.92279494, -0.50197446,  1.2997544 ,
        2.8258905 ,  2.6733112 ,  0.98229736, -0.69244087, -0.81115675,
        0.7352965 ,  2.525117  ,  2.912752  ,  1.5418116 , -0.32726777,
       -0.97606325,  0.19192469], dtype=float32)

您编写的函数仅“看到”数据的单个批次,您可以通过打印设备本地形状来检查

x = jnp.arange(32)
print(f"global shape: {x.shape=}")

def f(x):
  print(f"device local shape: {x.shape=}")
  return x * 2

y = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
global shape: x.shape=(32,)
device local shape: x.shape=(4,)

因为您的每个函数仅“看到”数据的设备本地部分,这意味着类似聚合的函数需要一些额外的思考。

例如,这是一个 jax.numpy.sum()shard_map 的外观

def f(x):
  return jnp.sum(x, keepdims=True)

shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
Array([  6,  22,  38,  54,  70,  86, 102, 118], dtype=int32)

您的函数 f 在每个分片上单独操作,并且得到的求和反映了这一点。

如果您想跨分片求和,您需要使用像 jax.lax.psum() 这样的集合操作显式请求它

def f(x):
  sum_in_shard = x.sum()
  return jax.lax.psum(sum_in_shard, 'x')

shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x)
Array(496, dtype=int32)

因为输出不再具有分片维度,所以设置 out_specs=P()(回想一下,out_specs 参数标识如何在 shard_map 中将块组装在一起)。

比较三种方法#

有了这些概念,让我们比较一下简单神经网络层的三种方法。

首先像这样定义您的规范函数

@jax.jit
def layer(x, weights, bias):
  return jax.nn.sigmoid(x @ weights + bias)
import numpy as np
rng = np.random.default_rng(0)

x = rng.normal(size=(32,))
weights = rng.normal(size=(32, 4))
bias = rng.normal(size=(4,))

layer(x, weights, bias)
Array([0.02138916, 0.8931118 , 0.5989196 , 0.9774251 ], dtype=float32)

您可以使用 jax.jit() 并传递适当分片的数据来自动以分布式方式运行它。

如果您对 x 的前导轴进行分片并使 weights 完全复制,则矩阵乘法将自动并行发生

mesh = jax.make_mesh((8,), ('x',))
x_sharded = jax.device_put(x, jax.NamedSharding(mesh, P('x')))
weights_sharded = jax.device_put(weights, jax.NamedSharding(mesh, P()))

layer(x_sharded, weights_sharded, bias)
Array([0.02138916, 0.8931118 , 0.5989196 , 0.9774251 ], dtype=float32)

或者,您也可以使用显式分片模式

explicit_mesh = jax.make_mesh((8,), ('X',), axis_types=(AxisType.Explicit,))

x_sharded = jax.device_put(x, jax.NamedSharding(explicit_mesh, P('X')))
weights_sharded = jax.device_put(weights, jax.NamedSharding(explicit_mesh, P()))

@jax.jit
def layer_auto(x, weights, bias):
  print(f"x sharding: {jax.typeof(x)}")
  print(f"weights sharding: {jax.typeof(weights)}")
  print(f"bias sharding: {jax.typeof(bias)}")
  out = layer(x, weights, bias)
  print(f"out sharding: {jax.typeof(out)}")
  return out

with jax.sharding.use_mesh(explicit_mesh):
  layer_auto(x_sharded, weights_sharded, bias)
x sharding: ShapedArray(float32[32@X])
weights sharding: ShapedArray(float32[32,4])
bias sharding: ShapedArray(float32[4])
out sharding: ShapedArray(float32[4])

最后,您可以使用 shard_map 做同样的事情,使用 jax.lax.psum() 来指示矩阵乘积所需的跨分片集合

from functools import partial

@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=(P('x'), P('x', None), P(None)),
         out_specs=P(None))
def layer_sharded(x, weights, bias):
  return jax.nn.sigmoid(jax.lax.psum(x @ weights, 'x') + bias)

layer_sharded(x, weights, bias)
Array([0.02138916, 0.8931118 , 0.5989196 , 0.9774251 ], dtype=float32)

下一步#

本教程简要介绍了 JAX 中的分片和并行计算。

要深入了解每种 SPMD 方法,请查看以下文档