使用 shard_map
进行手动并行#
概述#
shard_map
是一种单程序多数据 (SPMD) 多设备并行 API,用于将函数映射到数据的分片上。映射的函数应用程序或实例通过显式集合通信操作相互通信。
shard_map
是 jit
中内置的基于编译器的自动并行化的补充,并且可以与其组合。使用 jit
,您可以编写代码,就好像它是针对单个设备一样,并且编译器可以自动将计算分区到多个设备上,在后台生成每个设备的代码和通信集合。使用 shard_map
,您可以控制,编写自己的分区代码和显式集合。或者您可以两者都做一点:跨设备组进行手动控制,同时将组内设备分区留给编译器。这两种方法可以根据需要混合、匹配和组合。
如果您熟悉 pmap
,请将 shard_map
视为一种演变。它更具表现力、性能更好,并且可以与其他 JAX API 组合。它甚至可以急切地工作,以便于调试!(有关更多信息,请参见与 pmap
的详细比较。)
通过阅读本教程,您将学习如何使用 shard_map
来完全控制您的多设备代码。您将详细了解它如何与 jax.jit
的自动并行化和 jax.grad
的自动微分组合。我们还将提供一些神经网络并行化策略的基本示例。
我们将假设本教程在具有八个设备的环境中运行
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' # Use 8 CPU devices
那么,让我们看看 shard_map
吧!#
事不宜迟,这是一个玩具示例
from functools import partial
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
mesh = jax.make_mesh((4, 2), ('x', 'y'))
a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 4.).reshape(16, 4)
@partial(jax.shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),
out_specs=P('x', None))
def matmul_basic(a_block, b_block):
# a_block: f32[2, 8]
# b_block: f32[8, 4]
c_partialsum = jnp.dot(a_block, b_block)
c_block = jax.lax.psum(c_partialsum, 'y')
# c_block: f32[2, 4]
return c_block
c = matmul_basic(a, b) # c: f32[8, 4]
此函数通过执行本地块矩阵乘法,然后执行集合求和运算来并行计算矩阵乘法。我们可以检查结果是否正确
from jax.tree_util import tree_map, tree_all
def allclose(a, b):
return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))
allclose(c, jnp.dot(a, b))
True
结果沿其行分片
jax.debug.visualize_array_sharding(c)
CPU 0,1 CPU 2,3 CPU 4,5 CPU 6,7
在高层,shard_map
有点像 vmap
或 pmap
,因为我们将函数映射到数组数据的各个部分,但请注意
shard_map
将输入切片成块(并且输出是通过串联结果块形成的),保持秩不变,而vmap
会通过映射掉一个轴来降低秩;mesh
参数使我们可以精确控制计算和结果的设备放置;我们一次映射多个数据轴,并为集合设置多个轴名称(此处为
'x'
和'y'
);由于我们还没有使用
jax.jit
,因此所有内容都会被急切地评估,我们甚至可以print
中间值进行调试。
上面的代码执行的计算与此 jax.jit
自动并行化代码相同
from jax.sharding import NamedSharding
a = jax.device_put(a, NamedSharding(mesh, P('x', 'y')))
b = jax.device_put(b, NamedSharding(mesh, P('y', None)))
@jax.jit
def matmul_reference(a, b):
c = jnp.dot(a, b)
return jax.lax.with_sharding_constraint(c, NamedSharding(mesh, P('x', None)))
c_ref = matmul_reference(a, b)
allclose(c_ref, jnp.dot(a, b))
True
我们可以将 shard_map
视为根据其 mesh
和 in_specs
参数在其输入上执行 device_put
或 with_sharding_constraint
,因此 matmul_basic
运行的块与 matmul_reference
中的块相同
print('a blocks:'); jax.debug.visualize_array_sharding(a)
print('b blocks:'); jax.debug.visualize_array_sharding(b)
print('c blocks:'); jax.debug.visualize_array_sharding(c)
a blocks:
b blocks:
c blocks:
CPU 0 CPU 1 CPU 2 CPU 3 CPU 4 CPU 5 CPU 6 CPU 7
CPU 0,2,4,6 CPU 1,3,5,7
CPU 0,1 CPU 2,3 CPU 4,5 CPU 6,7
慢一点,从基础开始!#
降秩映射与保秩映射#
我们可以将 vmap
和 pmap
视为沿轴解堆叠每个数组输入(例如,将二维矩阵解包为其一维行),将其主体函数应用于每个部分,并将结果堆叠在一起,至少在不涉及集合时
def check_vmap(f, xs):
ans = jax.vmap(f, in_axes=(0,), out_axes=0)(xs)
expected = jnp.stack([f(x) for x in xs]) # vmap reference semantics
print(allclose(ans, expected))
check_vmap(lambda x: x @ x, jnp.arange(12).reshape(4, 3))
True
例如,如果 xs
的形状为 f32[8,5]
,则每个 x
的形状为 f32[5]
,并且如果每个 f(x)
的形状为 f32[3,7]
,则最终堆叠的结果 vmap(f)(xs)
的形状为 f32[8,3,7]
。也就是说,主体函数 f
的每个应用程序都以比 vmap(f)
的相应参数少一个轴的输入作为参数。我们可以说这些是降秩映射,具有输入/输出的解堆叠/堆叠。
f
的逻辑应用程序的数量,或 f
的实例,由映射的输入轴的大小确定:例如,如果我们映射大小为 8 的输入轴,则在语义上我们获得该函数的 8 个逻辑应用程序。
相比之下,shard_map
没有这种降秩行为。相反,我们可以将其视为沿输入轴切片(或“非串联”)成块,应用主体函数,然后将结果串联在一起(同样在不涉及集合时)
import numpy as np
devices = np.array(jax.devices()[:4])
mesh = Mesh(devices, ('i',)) # mesh.shape['i'] = 4
def check_shmap(f, y):
ans = jax.shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i'))(y)
expected = jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, mesh.shape['i'])])
print(allclose(ans, expected))
check_shmap(lambda x: x.T @ x, jnp.arange(32).reshape(8, 4))
True
回想一下,jnp.split
将其输入切片成具有相同秩的等大小块,因此如果在上面的示例中 y
的形状为 f32[8,5]
,则每个 y_blk
的形状为 f32[2,5]
,并且如果每个 f(y_blk)
的形状为 f32[3,7]
,则最终串联的结果 shard_map(f, ...)(y)
的形状为 f32[12,7]
。因此 shard_map
映射其输入的分片或块。我们可以说它是一种保秩映射,具有其输入的非串联/串联。
f
的逻辑应用程序的数量由网格大小确定,而不是由任何输入轴大小确定:例如,如果我们有一个总大小为 4 的网格(即,在 4 个设备上),则在语义上我们获得该函数的 4 个逻辑应用程序,对应于实际计算它们的 4 个设备。
使用 in_specs
控制每个输入如何分割(非串联)和切片#
每个 in_specs
通过使用 PartitionSpec
按名称识别一些相应输入数组的轴与网格轴,表示如何将该输入分割(或非串联)到主体函数应用到的块中。该识别确定分片大小;当输入轴被识别为网格轴时,输入沿该逻辑轴分割(非串联)成等于相应网格轴大小的块数。(如果相应的网格轴大小不能均匀地划分输入数组轴大小,则会出错。)如果输入的 pspec 未提及网格轴名称,则不会在该网格轴上进行分割。例如
mesh = jax.make_mesh((4, 2), ('i', 'j'))
@partial(jax.shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))
def f1(x_block):
print(x_block.shape) # prints (3, 12)
return x_block
x1 = jnp.arange(12 * 12).reshape(12, 12)
y = f1(x1)
(3, 12)
在这里,因为输入 pspec 没有提及网格轴名称 'j'
,所以没有输入数组轴在该网格轴上分割;类似地,因为输入数组的第二个轴未被识别为(因此未在该轴上分割)任何网格轴,所以 f1
的应用程序沿该轴获取输入的完整视图。
当输入 pspec 中未提及网格轴时,我们可以始终重写为效率较低的程序,其中提及了所有网格轴,但调用方执行 jnp.tile
,例如
@partial(jax.shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', 'j'))
def f2(x_block):
print(x_block.shape)
return x_block
x = jnp.arange(12 * 12).reshape(12, 12)
x_ = jnp.tile(x, (1, mesh.shape['j'])) # x_ has shape (12, 24)
y = f2(x_) # prints (3,12), and f1(x) == f2(x_)
(3, 12)
换句话说,因为每个输入 pspec 可以零次或一次提及每个网格轴名称,而不是必须准确提及每个名称一次,所以我们可以说,除了输入中内置的 jnp.split
之外,shard_map
在其输入中还内置了一个 jnp.tile
,至少在逻辑上是这样(尽管根据参数的物理分片布局,可能不需要物理执行切片)。要使用的切片不是唯一的;我们也可以沿第一个轴进行切片,并使用 pspec P(('j', 'i'), None)
。
物理数据移动可能在输入上进行,因为每个设备都需要具有适当的数据副本。
使用 out_specs
控制每个输出如何通过串联、块转置和反切片来组装#
与输入端类似,每个 out_specs
通过名称识别一些相应输出数组的轴与网格轴,表示应如何将输出块(每个主体函数的应用程序一个,或等效地每个物理设备一个)组装在一起以形成最终输出值。例如,在上面的 f1
和 f2
示例中,out_specs
都指示我们应通过沿两个轴串联块结果来形成最终输出,从而在两种情况下都产生形状为 (12, 24)
的数组 y
。(如果主体函数的输出形状(即输出块形状)的秩对于相应输出 pspec 描述的串联而言太小,则会出错。)
当输出 pspec 中未提及网格轴名称时,它表示反切片:当用户编写未提及网格轴名称之一的输出 pspec 时,他们承诺输出块沿该网格轴相等,因此输出中仅使用沿该轴的一个块(而不是沿该网格轴将所有块串联在一起)。例如,使用与上面相同的网格
x = jnp.array([[3.]])
z = jax.shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))()
print(z) # prints the same as jnp.tile(x, (4, 2))
z = jax.shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', None))()
print(z) # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,))
z = jax.shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P(None, None))()
print(z) # prints the same as jnp.tile(x, (1, 1)), or just x
[[3. 3.]
[3. 3.]
[3. 3.]
[3. 3.]]
[[3.]
[3.]
[3.]
[3.]]
[[3.]]
主体函数关闭数组值等效于将其作为具有 P(None, None) 的相应输入 pspec 的参数传递。作为另一个示例,更紧密地遵循上面的其他示例
@partial(jax.shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None))
def f3(x_block):
return jax.lax.psum(x_block, 'j')
x = jnp.arange(12 * 12).reshape(12, 12)
y3 = f3(x)
print(y3.shape)
(12, 6)
结果的第二个轴大小为 6,是输入的第二个轴大小的一半。在这种情况下,由于集合 psum
,通过未在输出 pspec 中提及网格轴名称 'j'
表达的反切片是安全的,这确保每个输出块沿相应的网格轴相等。以下是另外两个示例,我们改变了在输出 pspec 中提及的网格轴
@partial(jax.shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f4(x_block):
return jax.lax.psum(x_block, 'i')
x = jnp.arange(12 * 12).reshape(12, 12)
y4 = f4(x)
print(y4.shape) # (3,12)
@partial(jax.shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, None))
def f5(x_block):
return jax.lax.psum(x_block, ('i', 'j'))
y5 = f5(x)
print(y5.shape) # (3,6)
(3, 12)
(3, 6)
在物理端,未在输出 pspec 中提及网格轴名称会从设备缓冲区组装一个 Array
,并且沿该网格轴复制了布局。
没有运行时检查输出块实际上是否沿要反切片的网格轴相等,或者等效地,相应的物理缓冲区是否具有相等的值,因此可以解释为单个逻辑数组的复制布局。但是我们可以提供一种静态检查机制,该机制会在所有可能不正确的程序上引发错误。
因为 out_specs
可以零次或一次提及网格轴名称,并且因为可以以任何顺序提及它们,所以我们可以说,除了输出中内置的 jnp.concatenate
之外,shard_map
在其输出中还内置了反切片和块转置。
物理数据移动在输出上是不可能的,无论输出 pspec 如何。相反,out_specs
仅编码如何将块输出组装成 Array
,或者在物理上如何将设备上的缓冲区解释为单个逻辑 Array
的物理布局。
跟踪值在手动网格轴上的变化方式,以及 check_vma=True
#
在 shard_map
下,值可以在函数实例之间变化,也可以相同。例如,当我们使用 in_specs
在网格轴上分割参数时,沿该网格轴的每个函数实例都会获得不同的值
mesh = jax.make_mesh((2,), ('i',))
@partial(jax.shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
def f(x):
print(x)
return 2 * x
x = jnp.arange(6.)
f(x)
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[0. 1. 2.]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[3. 4. 5.]
Array([ 0., 2., 4., 6., 8., 10.], dtype=float32)
如果 in_specs
不在网格轴上分割参数,则该值对于沿该轴的每个函数实例都是相同的
@partial(jax.shard_map, mesh=mesh, in_specs=P(), out_specs=P())
def f(x):
print(x)
return 2 * x
x = jnp.arange(6.)
f(x)
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[0. 1. 2. 3. 4. 5.]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[0. 1. 2. 3. 4. 5.]
Array([ 0., 2., 4., 6., 8., 10.], dtype=float32)
集合的输出可能具有与其输入不同的方差。例如,应用 psum
会在沿轴的每个函数实例上产生相同的输出
@partial(jax.shard_map, mesh=mesh, in_specs=P('i'), out_specs=P())
def f(x):
y = jax.lax.psum(x, 'i')
print(y)
return y
x = jnp.arange(6.)
f(x)
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3. 5. 7.]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[3. 5. 7.]
Array([3., 5., 7.], dtype=float32)
通常,shard_map
中的每个中间值可以是跨每个手动网格轴不变的,也可以是可能变化的。该信息可以在 JAX 类型系统中跟踪,由 shard_map
的 check_vma=True
参数启用
@partial(jax.shard_map, mesh=mesh, in_specs=P('i'), out_specs=P())
def f(x):
print(jax.typeof(x)) # f32[3]{i}
y = jax.lax.psum(x, 'i')
print(jax.typeof(y)) # f32[3]
return y
x = jnp.arange(6.)
f(x)
float32[3]{i}
float32[3]
Array([3., 5., 7.], dtype=float32)
此处,类型 f32[3]{i}
表示 x
的值在网格轴 'i'
上变化。打印为 f32[3]
的 y
的类型指示它在所有网格轴上都不变;也就是说,不打印空集。我们将类型的这部分称为可变手动轴 (VMA),可以通过 jax.typeof(x).vma
访问它。
通常,值的 VMA 类型可以包括 shard_map
运行在其上的手动网格轴的任何子集
mesh = jax.make_mesh((4, 2), ('i', 'j'))
@partial(jax.shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i'))
def f(x):
print(jax.typeof(x)) # f32[2,2]{i,j}
y = jax.lax.psum(x, 'j')
assert jax.typeof(y).vma == {'i'}
print(jax.typeof(y)) # f32[2,2]{i}
return y
x = jnp.arange(8 * 4.).reshape(8, 4)
f(x)
float32[2,2]{i,j}
float32[2,2]{i}
Array([[ 2., 4.],
[10., 12.],
[18., 20.],
[26., 28.],
[34., 36.],
[42., 44.],
[50., 52.],
[58., 60.]], dtype=float32)
跟踪可变手动轴可能很有用
您的代码可以包括关于值是否在预期网格轴上变化的打印、断言或条件;
它启用了不需要防御性
psum
的高效反向模式自动微分(请参见JEP);可以检查
out_specs
的正确性,从而排除以下潜在的错误示例。
例如,此 out_specs
错误通过 check_vma=True
捕获,但在没有它的情况下未捕获
mesh = jax.make_mesh((2,), ('i',))
x = jnp.arange(6.)
try:
y = jax.shard_map(lambda x: x, mesh=mesh, in_specs=P('i'), out_specs=P())(x)
except Exception as e:
print(e)
shard_map applied to the function '<lambda>' was given out_specs which require replication which can't be statically inferred given the mesh:
The mesh given has shape (2,) with corresponding axis names ('i',).
out_specs is PartitionSpec() which implies that the corresponding output value is replicated across mesh axis 'i', but could not infer replication over any axes
Check if these output values are meant to be replicated over those mesh axes. If not, consider revising the corresponding out_specs entries. If so, consider disabling the check by passing the check_vma=False argument to `jax.shard_map`.
此处,out_specs
错误地承诺沿网格轴 'i'
的每个函数实例产生相同的值,因此我们可以仅选择其中一个。使用 check_vma=True
(默认值)会引发异常,而使用 check_vma=False
则不会引发异常,而是获得静默的未定义行为。
有时我们想将网格轴上不变的值视为在该网格轴上变化。这就是 jax.lax.pvary
所做的
@partial(jax.shard_map, mesh=mesh, in_specs=P(), out_specs=None)
def f(x):
print(jax.typeof(x)) # f32[6]
y = jax.lax.pvary(x, 'i')
print(jax.typeof(y)) # f32[6]{i}
x = jnp.arange(6.)
f(x)
float32[6]
float32[6]{i}
将 jax.lax.pvary
视为应用类型转换:它在运行时是一个空操作,但在反向模式自动微分下,它会转置为 jax.lax.psum
(请参见JEP)。这是有道理的,因为它们对 VMA 做的事情相反:如果 y: f32[3]{i} = jax.lax.pvary(x: f32[3], 'i')
,则我们相应地有 x_grad: f32[3] = jax.lax.psum(y_grad: f32[3]{i}, 'i')
。
JAX 在许多情况下隐式插入 jax.lax.pvary
调用,尤其是对于二进制运算
@partial(jax.shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))
def f(x, y):
return x * y
x = jnp.arange(6.)
y = jnp.arange(3.)
print(jax.make_jaxpr(f)(x, y))
{ lambda ; a:f32[6] b:f32[3]. let
c:f32[6] = shard_map[
check_vma=True
in_specs=(PartitionSpec('i',), PartitionSpec())
jaxpr={ lambda ; d:f32[3]{i} e:f32[3]. let
f:f32[3]{i} = pvary[axes=('i',) axis_index_groups=None] e
g:f32[3]{i} = mul d f
in (g,) }
manual_axes=frozenset({'i'})
mesh=Mesh('i': 2, axis_types=(Auto,))
out_specs=(PartitionSpec('i',),)
] a b
in (c,) }
在 jaxpr 中,乘法运算要求其参数的 VMA 类型匹配,但为方便起见,jax.numpy
和 jax.lax
API 会自动应用 jax.lax.pvary
以使参数 VMA 类型一致。
在某些情况下,例如使用 jax.lax.scan
时,您可能需要自己应用 jax.lax.pvary
以确保 VMA 类型按要求匹配。例如,此代码引发错误
mesh = jax.make_mesh((2,), ('i',))
@partial(jax.shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))
def f(x, y):
def body(carry, _):
c1, c2 = carry
return (c2, c1), () # swap the carry
(x_, y_), _ = jax.lax.scan(body, (x, y), (), length=2)
return x_, y_
x = jnp.arange(6.)
y = jnp.arange(3.)
try:
f(x, y)
except Exception as e:
print(e)
scan body function carry input and carry output must have equal types, but they differ:
* the input carry component carry[0] has type float32[3]{i} but the corresponding output carry component has type float32[3], so the varying manual axes do not match;
* the input carry component carry[1] has type float32[3] but the corresponding output carry component has type float32[3]{i}, so the varying manual axes do not match.
This might be fixed by applying `jax.lax.pvary(..., ('i',))` to the initial carry value corresponding to the input carry component carry[1].
See https://jax.net.cn/en/latest/notebooks/shard_map.html#scan-vma for more information.
Revise the function so that all output types match the corresponding input types.
为了使类型匹配,我们需要将 jax.lax.pvary
应用于 scan
的某些参数
mesh = jax.make_mesh((2,), ('i',))
@partial(jax.shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))
def f(x, y):
def body(carry, _):
c1, c2 = carry
return (c2, c1), () # swap the carry
y = jax.lax.pvary(y, 'i') # apply pvary to fix the error
(x_, y_), _ = jax.lax.scan(body, (x, y), (), length=2)
return x_, y_
x = jnp.arange(6.)
y = jnp.arange(3.)
f(x, y)
(Array([0., 1., 2., 3., 4., 5.], dtype=float32),
Array([0., 1., 2., 0., 1., 2.], dtype=float32))
以下是集合原语及其如何影响可变手动轴类型的摘要
名称 |
设备方差类型 |
示例 |
降低为 HLO |
转置 |
---|---|---|---|---|
|
|
|
|
|
|
|
|
no-op(无通信) |
|
|
|
|
|
|
|
|
|
|
不适用 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
关于该表的一些说明
函数
jax.lax.psum
是psum_invariant
的便捷包装器。令人惊讶的是
all_gather
是变化 -> 变化
,但那是因为它实际上是psum_scatter
的转置,后者是变化 -> 变化
。在撰写本文时,
pscatter
和all_gather_invariant
都没有用户 API,但此处为完整起见进行了描述。
API 规范#
from jax.sharding import Mesh
Specs = PyTree[PartitionSpec]
def shard_map(
f: Callable, /, *, out_specs: Specs, mesh: Mesh | None = None,
in_specs: Specs | None = None,
axis_names: collections.abc.Set[AxisName] = set(),
check_vma: bool = True,
) -> Callable:
...
其中
诸如
psum
之类的通信集合在f
的主体中可以提及mesh
的轴名称;mesh
编码排列成数组并具有相关轴名称的设备,就像sharding.NamedSharding
一样;如果为 None,则网格将从上下文中推断,该上下文可以通过jax.sharding.use_mesh
上下文管理器设置。in_specs
是PartitionSpec
,它可以零次或一次提及来自mesh
的轴名称,以分别表达输入的切片/非连接,未提及的名称对应于复制和非平铺(断言已复制,所以给我一份副本)。如果为 None,则所有网格轴必须是Explicit
类型,在这种情况下,in_specs 从参数类型推断;out_specs
是PartitionSpec
,它可以零次或一次提及来自mesh
的轴名称,以表达输出的连接,未提及的名称对应于复制和非平铺(断言已复制,所以给我一份副本);axis_names
是一个可选的轴名称集合,对应于mesh
名称的子集,用于处理主体中的 manual。如果为空,则f
是对网格所有轴的 manual。check_vma
是一个可选的布尔值,指示是否静态检查out_specs
中的任何复制错误,以及是否启用相关的自动微分优化(参见 JEP)。
传递给 f
的参数的形状与传递给 shard_map
-of-f
的参数的秩相同,并且 f
的参数的形状是根据 shard_map
-of-f
的对应参数的形状 shape
和对应的 PartitionSpec
spec
大致计算为 tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))
。
集合通信教程#
shard_map
不一定是纯映射:函数应用程序可以使用 集合通信 进行相互通信,使用在 mesh
参数中定义的轴名称。
回想一下,shard_map
将函数映射到输入数据的分片或块上,因此这
mesh = Mesh(jax.devices(), ('i',))
x = jnp.arange(16.)
f_shmapped = jax.shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
y = f_shmapped(x)
计算相同的值,评估 f
对相同参数值的应用,作为此引用函数
def f_shmapped_ref(x):
x_blocks = jnp.array_split(x, mesh.shape['i'])
y_blocks = [f(x_blk) for x_blk in x_blocks]
return jnp.concatenate(y_blocks)
我们将 f
对不同参数分片的这些应用称为 函数实例。每个函数实例都在不同的设备(或设备的子集)上执行。
当 f
中没有通信集合通信时,这些引用语义有效。但是,如果我们希望函数实例进行通信,对应于具有跨设备通信的情况呢?也就是说,当 f
包含集合通信时,引用语义是什么?假设 f
只有一个集合通信,并且形式为
def f(x_blk):
z_blk = f_part1(x_blk)
u_blk = collective(z_blk, axis_name)
v_blk = f_part2(x_blk, z_blk, u_blk)
return v_blk
我们假设只有一个网格轴我们正在映射,并且 axis_name
是它的对应名称。那么引用语义看起来更像是
def f_shmapped_ref(x):
x_blocks = jnp.array_split(x, mesh.shape[0])
z_blocks = [f_part1(x_blk) for x_blk in x_blocks]
u_blocks = [collective_ref(i, z_blocks) for i in range(len(z_blocks))]
v_blocks = [f_part2(x_blk, z_blk, u_blk) for x_blk, z_blk, u_blk
in zip(x_blocks, z_blocks, u_blocks)]
return jnp.concatenate(v_blocks)
请注意,collective_ref
可能依赖于所有 z_blocks
。也就是说,虽然 f_part1
和 f_part2
独立地映射在块上,但集合通信引入了一定数量的跨块依赖性。物理上,这意味着跨设备的通信。确切发生的通信以及计算出的值取决于集合通信。
psum
#
最简单的集合通信可能是 jax.lax.psum
,它计算沿设备网格轴(或多个轴)的全归约求和。这是一个玩具示例
import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
mesh1d = Mesh(jax.devices()[:4], ('i',))
@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P(None))
def f1(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.psum(x_block, 'i')
print('AFTER:\n', y_block)
return y_block
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f1(x)
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22 20 12 17]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[22 20 12 17]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[22 20 12 17]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[22 20 12 17]
FINAL RESULT:
[22 20 12 17]
打印输出显示每个函数应用程序都以其自己的参数值 x_block
块开始。在 psum
之后,每个函数应用程序都具有相同的 y_block
值,该值是通过将应用程序的 x_block
值相加来计算的。
如果计算中只有一个轴名称,我们可以说 psum
的 collective_ref
引用实现是
def psum_ref(_, x_blocks):
tot = sum(x_blocks)
return [tot] * len(x_blocks)
另请注意,由于 f1
返回 y_block
,即 'i'
上的 psum
的结果,因此我们可以使用 out_specs=P()
,以便调用方获得结果值的单个逻辑副本,而不是平铺的结果。
当有多个网格轴时,我们可以分别在每个网格轴上执行 psum
,或者一次在多个轴上执行
mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j'))
@partial(jax.shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f2(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.psum(x_block, 'i')
print('AFTER:\n', y_block)
return y_block
y = f2(jnp.arange(16).reshape(4, 4))
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[0 1]
[4 5]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[2 3]
[6 7]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8 9]
[12 13]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[10 11]
[14 15]]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[ 8 10]
[16 18]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[12 14]
[20 22]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8 10]
[16 18]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[12 14]
[20 22]]
FINAL RESULT:
[[ 8 10 12 14]
[16 18 20 22]]
通过在网格轴 'i'
上应用 psum
,我们得到沿轴 'i'
相等但沿轴 'j'
不等的 y_block
值。(因此我们可以使用 out_specs=P(None, 'j')
沿该轴获得单个逻辑结果。)
如果我们对两个轴都应用 psum
,则 y_block
值在两个轴上都相等
@partial(jax.shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None))
def f3(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.psum(x_block, ('i', 'j'))
print('AFTER:\n', y_block)
return y_block
y = f3(jnp.arange(16).reshape(4, 4))
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[0 1]
[4 5]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[2 3]
[6 7]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8 9]
[12 13]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[10 11]
[14 15]]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[20 24]
[36 40]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[20 24]
[36 40]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[20 24]
[36 40]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[20 24]
[36 40]]
FINAL RESULT:
[[20 24]
[36 40]]
在机器学习中,我们经常使用 psum
来计算总损失,或者,当我们在 shard_map
映射的函数体内部有 grad
时,计算总梯度。
在续集中,我们将看到如何根据其他原语实现 psum
,这提供了一些关于其通信成本的直觉。
all_gather
#
另一个基本操作是沿轴收集数组分片,以便每个函数应用程序都具有沿该轴的完整数据副本
@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f4(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.all_gather(x_block, 'i', tiled=True)
print('AFTER:\n', y_block)
return y_block
x = jnp.array([3, 9, 5, 2])
y = f4(x)
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[9]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 9 5 2]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[3 9 5 2]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[3 9 5 2]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[3 9 5 2]
FINAL RESULT:
[3 9 5 2 3 9 5 2 3 9 5 2 3 9 5 2]
打印输出显示每个函数应用程序再次以其自己的参数值 x_block
块开始。在 all_gather
之后,它们具有一个公共值,该值是通过连接 x_block
的值来计算的。
(请注意,我们实际上不能在这里设置 out_specs=P()
。出于与自动微分相关的技术原因,我们认为 all_gather
的输出不能保证跨设备不变。如果我们希望保证它是不变的,我们可以使用 jax.lax.all_gather_invariant
,或者在这种情况下,我们可以避免在函数体中进行 all_gather
,而只需使用 out_specs=P('i')
来执行连接。)
当 tiled=False
(默认值)时,结果沿新轴堆叠而不是连接
@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f5(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.all_gather(x_block, 'i', tiled=False)
print('AFTER:\n', y_block)
return y_block
y = f5(x)
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[9]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[[3]
[9]
[5]
[2]]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[[3]
[9]
[5]
[2]]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[[3]
[9]
[5]
[2]]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[[3]
[9]
[5]
[2]]
FINAL RESULT:
[[3]
[9]
[5]
[2]
[3]
[9]
[5]
[2]
[3]
[9]
[5]
[2]
[3]
[9]
[5]
[2]]
我们可以将 all_gather
的 collective_ref
引用语义函数写为
def all_gather_ref(_, x_blocks, *, tiled=False):
combine = jnp.concatenate if tiled else jnp.stack
return [combine(x_blocks)] * len(x_blocks)
在深度学习中,我们可能会在完全分片数据并行性 (FSDP) 中对参数使用 all_gather
。
psum_scatter
#
jax.lax.psum_scatter
集合通信有点不太直观。它类似于 psum
,只是每个函数实例只获得结果的一个分片
@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f6(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True)
print('AFTER:\n', y_block)
return y_block
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f6(x)
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[20]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[12]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[17]
FINAL RESULT:
[22 20 12 17]
如打印输出所示,与 psum
不同,每个结果 y_block
的大小都小于参数 x_block
。此外,与 psum
相比,这里的每个 y_block
仅表示函数实例之间 x_block
的总和的切片。(即使每个函数实例只获得总和的一个分片,最终输出 y
与 psum
示例中的相同,因为这里我们使用 out_specs=P('i')
来连接每个函数实例的输出。)
就计算的值而言,collective_ref
引用实现可能如下所示
def psum_scatter_ref(i, x_blocks, *, tiled=False):
axis_size = len(x_blocks)
tot = sum(x_blocks)
if tiled:
tot = tot.reshape(axis_size, -1, *tot.shape[1:]) # split leading axis
return [tot[i] for i in range(tot.shape[0])]
语义引用实现中未捕获,但 psum_scatter
很有用,因为与完全 psum
相比,这些结果可以更有效地计算,并且通信量更少。实际上,考虑 psum_scatter
的一种方法是将其视为“在 all_gather
之前,psum
的前半部分”。也就是说,实现 psum
的一种方法是
def psum(x, axis_name):
summed_chunk = jax.lax.psum_scatter(x, axis_name)
return jax.lax.all_gather(summed_chunk, axis_name)
实际上,此实现通常在 TPU 和 GPU 上使用!
在 ppermute
部分中说明了 psum_scatter
可能需要大约一半通信量作为完全 psum
的原因。
另一种直觉是,我们可以使用 psum_scatter
来实现分布式矩阵乘法,其中输入和输出在同一轴上分片。在机器学习中,psum_scatter
可用于张量并行矩阵乘法或完全分片的数据并行梯度累积,如以下示例所示。
ppermute
#
jax.lax.ppermute
集合通信为函数实例提供了相互发送数据的最直接方式。给定一个网格轴和一个表示沿该网格轴的索引的 (source_index, destination_index)
对列表,ppermute
将其参数值从每个源函数实例发送到每个目标
@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f7(x_block):
sz = jax.lax.axis_size('i')
print('BEFORE:\n', x_block)
y_block = jax.lax.ppermute(x_block, 'i', [(i, (i + 1) % sz) for i in range(sz)])
print('AFTER:\n', y_block)
return y_block
y = f7(jnp.arange(8))
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[0 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[2 3]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[4 5]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[6 7]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[6 7]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[0 1]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[2 3]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[4 5]
FINAL RESULT:
[6 7 0 1 2 3 4 5]
在这种情况下,只有两个函数实例,每个实例的 y_block
值是另一个实例的 x_block
值。
源索引和目标索引不能重复。如果索引未显示为目标,则相应函数实例结果的值为零数组。
collective_ref
引用实现可能如下所示
def ppermute_ref(i, x_blocks, perm):
results = [jnp.zeros_like(x_blocks[0])] * len(x_blocks)
for src, dst in perm:
results[dst] = x_blocks[src]
return results
可以使用 ppermute
有效地实现其他集合通信(就总通信量而言),其中每个函数仅将数据传递给其相邻函数。例如,我们可以使用一系列 ppermute
和本地加法来以这种方式实现 psum_scatter
或者,使用数值示例
直观地说,在每次迭代中,每个函数实例都“向上”发送其在前一次迭代中收到的值,并减少(添加)其在此迭代中收到的值。在代码中,它可能看起来像这样
def psum_scatter(x, axis_name, *, tiled=False):
size = jax.lax.axis_size(axis_name)
idx = jax.lax.axis_index(axis_name) # function instance index along axis_name
if tiled:
x = x.reshape(size, -1, *x.shape[1:]) # split leading axis
shift = partial(jax.lax.ppermute, axis_name=axis_name,
perm=[(i, (i - 1) % size) for i in range(size)])
for i in range(1, size):
update = shift(x[(idx + i) % size])
x = x.at[(idx + i + 1) % size].add(update)
return x[idx]
@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f8(x_block):
print('BEFORE:\n', x_block)
y_block = psum_scatter(x_block, 'i', tiled=True)
print('AFTER:\n', y_block)
return y_block
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f8(x)
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[20]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[12]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[17]
FINAL RESULT:
[22 20 12 17]
在 TPU 上,有此算法的更高维变体,可以利用多个双向物理网格轴。
请注意,psum_scatter
是 all_gather
的转置。实际上,根据 ppermute
实现 all_gather
的一种方法看起来与上述过程相反
在深度学习中,当我们实现 SPMD 管道并行性时,可能会使用 ppermute
,其中我们将网络沿其深度划分为多个阶段,并并行评估阶段的应用。或者,我们可能会使用 ppermute
来并行化卷积层的评估,其中我们分片空间轴,因此设备必须相互通信“光环”。或者,它可能在张量并行矩阵乘法中在后台使用。
all_to_all
#
最后一个集合通信是 all_to_all
,它本质上是在一个位置轴和一个跨设备轴上运行的块矩阵转置
@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f9(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.all_to_all(x_block, 'i', split_axis=0, concat_axis=0,
tiled=True)
print('AFTER:\n', y_block)
return y_block
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f9(x)
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 5 5 9]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[1 9 3 7]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[4 2 5 1]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[1 6 8 2]
FINAL RESULT:
[3 5 5 9 1 9 3 7 4 2 5 1 1 6 8 2]
split_axis
参数指示应分片哪个位置轴并在网格轴上分区。concat_axis
参数指示应沿哪个轴连接或堆叠通信的结果。
当 tiled=False
(默认值)时,split_axis
轴大小必须等于名为 axis_name
的网格轴的大小,并且在该大小的新轴将在 concat_axis
位置创建以用于堆叠结果。当 tiled=True
时,split_axis
轴大小只需要可以被网格轴大小整除,并且结果将沿现有轴 concat_axis
连接。
当 split_axis=0
和 concat_axis=0
时,collective_ref
引用语义可能如下所示
def all_to_all_ref(_, x_blocks, *, tiled=False):
axis_size = len(x_blocks)
if tiled:
splits = [jnp.array_split(x, axis_size) for x in x_blocks]
return [jnp.concatenate(s) for s in zip(*splits)]
else:
splits = [list(x) for x in x_blocks]
return [jnp.stack(s) for s in zip(*splits)]
在深度学习中,我们可能会在专家混合路由中使用 all_to_all
,其中我们首先根据应将示例发送到哪个专家对本地示例批次进行排序,然后应用 all_to_all
将示例重新分发给专家。
玩具示例#
我们如何在实践中使用 shard_map
和集合通信?这些示例虽然简单,但提供了一些想法。
矩阵乘法#
并行化矩阵乘法对于扩展深度学习模型(无论是训练还是推理)至关重要。当 jax.jit
自动并行化矩阵乘法时,它可以使用几种不同的策略之一,具体取决于矩阵大小、硬件详细信息和其他因素。我们如何使用 shard_map
更明确地编写其中一些并行化例程?我们如何优化它们以获得更好的计算/通信重叠,从而提高 FLOP 利用率?
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
mesh = Mesh(jax.devices()[:4], ('i',))
def device_put(x, pspec):
return jax.device_put(x, NamedSharding(mesh, pspec))
示例 1:一侧的 all-gather
#
考虑执行矩阵乘法,其中我们在左侧参数(可以认为是参数)的前导(非收缩)维度上对其进行分片
lhs_spec = P('i', None)
lhs = device_put(jax.random.normal(jax.random.key(0), (8, 8)), lhs_spec)
我们在右侧参数(可以认为是激活)的收缩维度上对其进行分片,输出的片与此类似
rhs_spec = P('i', None)
rhs = device_put(jax.random.normal(jax.random.key(1), (8, 4)), rhs_spec)
要执行此矩阵乘法,我们可以首先 all-gather 右侧,然后对分片的左侧执行局部矩阵乘法
@jax.jit
@partial(jax.shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
out_specs=rhs_spec)
def matmul_allgather(lhs_block, rhs_block):
rhs = jax.lax.all_gather(rhs_block, 'i', tiled=True)
return lhs_block @ rhs
out = matmul_allgather(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
这很好,但是我们在这里没有获得任何计算/通信重叠:在我们开始 matmul 之前,我们需要 all_gather
完成。以下是在较大示例形状(lhs
为 (8192, 8192)
,rhs
为 (8192, 1024)
)上使用相同代码的配置文件
如果我们在 ppermute
中内联我们上述 all_gather
的实现,而不是调用 all_gather
,然后将 gather 排列的步骤与局部矩阵乘法交错,则可以获得计算/通信重叠
@jax.jit
@partial(jax.shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
out_specs=rhs_spec)
def matmul_allgather_overlapped(lhs_block, rhs_block):
size = jax.lax.axis_size('i')
idx = jax.lax.axis_index('i')
shift = partial(jax.lax.ppermute, axis_name='i',
perm=[(i, (i + 1) % size) for i in range(size)])
B = lhs_block.shape[1] // size
lhs_blocks = lambda i: lax.dynamic_slice_in_dim(lhs_block, i * B, B, 1)
out_block = lhs_blocks(idx) @ rhs_block
for i in range(1, size):
rhs_block = shift(rhs_block)
out_block += lhs_blocks((idx - i) % size) @ rhs_block
return out_block
out = matmul_allgather_overlapped(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
此实现允许通信和计算之间的重叠,并且还可以避免在每个设备上收集大的中间结果。但是在 TPU 上,它仅使用一半的互连带宽,因为仅沿环在一个方向上排列。要双向排列,我们只需将块分成两半,并在每个方向上发送一半
@jax.jit
@partial(jax.shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
out_specs=rhs_spec)
def matmul_allgather_overlapped_bidi(lhs_block, rhs_block):
size = jax.lax.axis_size('i')
idx = jax.lax.axis_index('i')
shift_up = partial(jax.lax.ppermute, axis_name='i',
perm=[(i, (i + 1) % size) for i in range(size)])
shift_dn = partial(jax.lax.ppermute, axis_name='i',
perm=[(i, (i - 1) % size) for i in range(size)])
B = lhs_block.shape[1] // size // 2 # half-size blocks
lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 1)
rhs_block_lo, rhs_block_hi = jnp.split(rhs_block, 2, axis=0)
out_block = lhs_blocks(idx, 0) @ rhs_block_lo
out_block += lhs_blocks(idx, 1) @ rhs_block_hi
for i in range(1, size):
rhs_block_lo = shift_up(rhs_block_lo)
rhs_block_hi = shift_dn(rhs_block_hi)
out_block += lhs_blocks((idx - i) % size, 0) @ rhs_block_lo
out_block += lhs_blocks((idx + i) % size, 1) @ rhs_block_hi
return out_block
out = matmul_allgather_overlapped_bidi(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
实际上,为了减少编译时间,我们可能会将其滚动到 jax.lax.fori_loop
中。我们可能还会涉及额外的并行轴。
示例 2:psum_scatter
结果#
我们可能开始的另一个分片是在收缩维度上对 lhs
和 rhs
进行分片,输出再次像 rhs
一样分片
lhs_spec = P(None, 'i')
lhs = device_put(lhs, lhs_spec)
rhs_spec = P('i', None)
rhs = device_put(rhs, rhs_spec)
在这里,我们可以使用 reduce_scatter
来执行分片上的收缩求和
@partial(jax.shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
out_specs=rhs_spec)
def matmul_psumscatter(lhs_block, rhs_block):
out_summand = lhs_block @ rhs_block
return jax.lax.psum_scatter(out_summand, 'i', tiled=True)
out = matmul_psumscatter(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
但是在开始之前,散射通信必须等待整个局部矩阵乘法完成。要获得通信/计算重叠,我们可以内联 ppermute
中 psum_scatter
的实现,然后将通信步骤与局部矩阵乘法交错
@partial(jax.shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
out_specs=rhs_spec)
def matmul_psumscatter_overlapped(lhs_block, rhs_block):
size = jax.lax.axis_size('i')
idx = jax.lax.axis_index('i')
shift = partial(jax.lax.ppermute, axis_name='i',
perm=[(i, (i - 1) % size) for i in range(size)])
lhs_block = lhs_block.reshape(size, -1, lhs_block.shape[1]) # split 1st axis
out_summand = lhs_block[(idx + 1) % size] @ rhs_block
for i in range(1, size):
out_summand = shift(out_summand)
out_summand += lhs_block[(idx + i + 1) % size] @ rhs_block
return out_summand
out = matmul_psumscatter_overlapped(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
与前面的示例一样,为了充分利用 TPU 上的互连,我们将运行双向版本
@partial(jax.shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
out_specs=rhs_spec)
def matmul_psumscatter_overlapped_bidi(lhs_block, rhs_block):
size = jax.lax.axis_size('i')
idx = jax.lax.axis_index('i')
shift_up = partial(jax.lax.ppermute, axis_name='i',
perm=[(i, (i + 1) % size) for i in range(size)])
shift_dn = partial(jax.lax.ppermute, axis_name='i',
perm=[(i, (i - 1) % size) for i in range(size)])
B = lhs_block.shape[0] // size // 2 # half-size blocks
lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 0)
out_summand_lo = lhs_blocks((idx - 1) % size, 0) @ rhs_block
out_summand_hi = lhs_blocks((idx + 1) % size, 1) @ rhs_block
for i in range(1, size):
out_summand_lo = shift_up(out_summand_lo)
out_summand_hi = shift_dn(out_summand_hi)
out_summand_lo += lhs_blocks((idx - i - 1) % size, 0) @ rhs_block
out_summand_hi += lhs_blocks((idx + i + 1) % size, 1) @ rhs_block
return jnp.concatenate([out_summand_lo, out_summand_hi])
out = matmul_psumscatter_overlapped_bidi(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
神经网络#
我们可以使用 shard_map
来并行化神经网络中的计算,无论是单独使用还是与 jax.jit
中的自动分区结合使用。本节有一些基于此玩具神经网络和随机数据的示例
import jax
import jax.numpy as jnp
def predict(params, inputs):
for W, b in params:
outputs = jnp.dot(inputs, W) + b
inputs = jax.nn.relu(outputs)
return outputs
def loss(params, batch):
inputs, targets = batch
predictions = predict(params, inputs)
return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
def init_layer(key, n_in, n_out):
k1, k2 = jax.random.split(key)
W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
b = jax.random.normal(k2, (n_out,))
return W, b
def init(key, layer_sizes, batch_size):
key, *keys = jax.random.split(key, len(layer_sizes))
params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))
key, *keys = jax.random.split(key, 3)
inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))
return params, (inputs, targets)
layer_sizes = [784, 128, 128, 128, 128, 128, 8]
batch_size = 32
params, batch = init(jax.random.key(0), layer_sizes, batch_size)
将这些示例与纯 “分布式数组和自动分区”文档中的自动分区示例进行比较。虽然在这些自动分区示例中,我们不需要编辑模型函数即可使用不同的并行化策略,但对于 shard_map
,我们通常需要这样做。
8 向批数据并行性#
最简单的多设备并行性策略是在多个设备上分片输入和目标的批次,在这些设备上复制参数,并并行地将模型应用于这些数据分片。为了评估总损失,设备只需要在末尾使用标量大小的全归约求和进行通信。(要评估损失的梯度,设备必须在后向传递中执行参数梯度的全归约求和。)
from functools import partial
from jax.sharding import NamedSharding, Mesh, PartitionSpec as P
mesh = jax.make_mesh((8,), ('batch',))
# replicate initial params on all devices, shard data batch over devices
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P()))
# adapt the loss function to sum the losses across devices
def loss_dp(params, batch):
@partial(jax.shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P())
def loss_spmd(local_batch):
inputs, targets = local_batch
predictions = predict(params, inputs) # use reference 'predict`
local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
return jax.lax.pmean(local_loss, 'batch')
return loss_spmd(batch)
我们可以检查损失及其梯度是否与参考(基本)模型匹配
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_dp)(params, batch))
11.9203
11.9203
def allclose(a, b):
return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))
print(allclose(jax.jit(jax.grad(loss))(params, batch),
jax.jit(jax.grad(loss_dp))(params, batch)))
True
我们可以打印编译器 IR 以检查梯度计算,并验证集合通信全归约求和运算是否在我们期望的位置发生:在前向传递的末尾计算损失值,在后向传递中计算总参数梯度。
8 向完全分片数据并行性 (FSDP)#
另一种策略是在设备上额外分片参数,在需要完整值以用于 jnp.dot
或偏差加法时,all-gathering 每个参数。由于我们一次只有一个完整参数在本地设备内存中,而不是像前面的 DP 示例中那样将所有参数保留在所有设备内存中,因此我们释放了大量内存,可以将其用于更大的模型或更大的批次大小。并且由于 XLA 将重叠计算和设备间通信,因此挂钟时间不会受到影响。
因此,现在我们需要两个地方的集合通信:模型预测函数 predict
需要 all-gather 参数,然后才能使用它们,并且与 DP 情况一样,损失函数需要对局部损失求和才能计算总损失。
我们需要另一个要素:我们不想存储来自前向传递的完全收集的参数,以供后向传递使用。相反,我们希望在后向传递中再次收集它们。我们可以通过使用具有自定义策略(或 custom_vjp
)的 jax.remat
来表达,尽管 XLA 通常会自动执行该重新物化。
这种通用的 FSDP 方法类似于 权重更新分片 (WUS) 和 ZeRO-3。
# shard data batch *and params* over devices
mesh = Mesh(devices, ('batch',))
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P('batch')))
# adapt the prediction function to gather weights just before their use,
# and to re-gather them on the backward pass (rather than saving them)
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict_fsdp(params_frag, inputs):
for W_frag, b_frag in params_frag:
W = jax.lax.all_gather(W_frag, 'batch', tiled=True)
b = jax.lax.all_gather(b_frag, 'batch', tiled=True)
outputs = jnp.dot(inputs, W) + b
inputs = jax.nn.relu(outputs)
return outputs
def loss_fsdp(params, batch):
@partial(jax.shard_map, mesh=mesh, in_specs=P('batch'), out_specs=P())
def loss_spmd(local_params, local_batch):
inputs, targets = local_batch
predictions = predict_fsdp(local_params, inputs)
local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
return jax.lax.pmean(local_loss, 'batch')
return loss_spmd(params, batch)
我们可以再次检查损失及其梯度是否与参考模型匹配
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_fsdp)(params, batch))
print(allclose(jax.jit(jax.grad(loss))(params, batch),
jax.jit(jax.grad(loss_fsdp))(params, batch)))
11.920298
11.920298
True
8 向张量并行性 (TP)#
通常我们不单独使用张量模型并行性,但是隔离地看到它是在并行矩阵乘法方面的一个很好的热身。这也是在较大的基于 jit
的计算中调用的库函数中使用 shard_map
的一个很好的示例。
并行化思想是我们将在其特征轴(而不是其批次轴)上保留数据/激活分片,并且类似地,我们将在其输入特征轴上分片权重矩阵(并在其特征轴上分片偏差)。然后,为了执行并行矩阵乘法,我们将执行局部矩阵乘法,然后执行 psum_scatter
以对局部结果求和并有效地散射结果的分片。
mesh = jax.make_mesh((8,), ('feats',))
batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))
params = jax.device_put(params, NamedSharding(mesh, P('feats')))
def predict_tp(params, inputs):
for W, b in params:
outputs = gemm_tp(inputs, W, b)
inputs = jax.nn.relu(outputs)
return outputs
@partial(jax.shard_map, mesh=mesh,
in_specs=(P(None, 'feats'), P('feats', None), P('feats')),
out_specs=P(None, 'feats'))
def gemm_tp(inputs, W, b):
block_result = jnp.dot(inputs, W)
return jax.lax.psum_scatter(block_result, 'feats',
scatter_dimension=1, tiled=True) + b
def loss_tp(params, batch):
inputs, targets = batch
predictions = predict_tp(params, inputs)
return jnp.mean(jnp.sum((predictions - targets) ** 2, axis=-1)) # NOTE psum!
FSDP + TP,在顶层使用 shard_map
#
我们可以将这些策略组合在一起,使用多个并行轴。
mesh = jax.make_mesh((4, 2), ('batch', 'feats'))
batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats')))
params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats'))))
# mostly same as previous predict_fsdp definition, except we call gemm_tp
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict_fsdp_tp(params_frag, inputs):
for W_frag, b_frag in params_frag:
W = jax.lax.all_gather(W_frag, 'batch', tiled=True)
b = jax.lax.all_gather(b_frag, 'batch', tiled=True)
block_result = jnp.dot(inputs, W)
outputs = jax.lax.psum_scatter(block_result, 'feats',
scatter_dimension=1, tiled=True) + b
inputs = jax.nn.relu(outputs)
return outputs
@partial(jax.shard_map, mesh=mesh,
in_specs=(P(('feats', 'batch')), P('batch', 'feats')),
out_specs=P())
def loss_fsdp_tp(local_params, local_batch):
inputs, targets = local_batch
predictions = predict_fsdp_tp(local_params, inputs)
sq_err = jax.lax.psum(jnp.sum((predictions - targets)**2, axis=-1), 'feats')
return jax.lax.pmean(jnp.mean(sq_err), 'batch')
请注意我们必须执行两个集合通信归约:一个在 'feats'
上,另一个在 'batch'
上。在纯 TP 示例中,我们没有显式编写 'feats'
归约,因为我们仅在 gemm_tp
中使用了 shard_map
;在调用方 loss_tp
中,编译器根据 predict_tp
返回的分片结果的需要,自动将我们对 jnp.sum
的使用转换为执行 psum
。
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_fsdp_tp)(params_, batch_))
print(allclose(jax.jit(jax.grad(loss))(params, batch),
jax.jit(jax.grad(loss_fsdp_tp))(params, batch)))
11.9203
11.920298
True
SPMD 管道并行性 (PP)#
通过流水线并行,我们的目标是并行化网络中不同深度的层的评估。例如,一个设备可能计算第一层的应用,而另一个设备计算第二层的应用;当它们完成后,第一个设备将其结果传递给第二个设备,而第二个设备将其结果传递给负责第三层的设备,并且该过程重复进行。一般来说,流水线阶段的数量可能与层的数量不同,因为每个阶段可能负责多个层。
通过 SPMD 流水线,我们利用了网络中大多数层都应用相同的计算,只是参数值不同这一事实。特别地,我们可以将除第一层和最后一层之外的所有参数堆叠在一起,然后使用 shard_map
来映射这些层参数的块,其中每个参数块对应于一个流水线阶段。然后我们使用 jax.lax.ppermute
集体操作来将数据向下移动并行流水线。
这种特殊的流水线策略本质上是 GPipe 策略。 有几种变体以及完全不同的策略,哪种策略合适取决于阶段之间的网络速度和批处理大小。 但是在本教程中,我们将只关注一种策略。
首先,我们选择一些流水线参数
L = len(params) - 2 # num layers, excluding first and last
N = batch_size # batch size
F = params[0][0].shape[1] # num features
# choose some pipeline parameters
S = 2 # number of stages
B = 8 # size of each microbatch
assert L % S == 0, "S (number of stages) must divide L (number of inner layers)"
# compute some useful quantities
M, ragged = divmod(N, B) # M is number of microbatches
assert not ragged, "B (size of each microbatch) must divide total batch size"
K, ragged = divmod(M, S) # K is microbatches per stage
assert not ragged, "S (number of stages) must divide number of microbatches"
print(f'{S} stages, {L // S} layer(s) per stage, {L} pipelined layers total')
print(f'{B} examples per microbatch, {M} microbatches total')
2 stages, 2 layer(s) per stage, 4 pipelined layers total
8 examples per microbatch, 4 microbatches total
mesh = Mesh(jax.devices()[:S], ('stages',))
def predict_pp(params, inputs):
(W_first, b_first), inner_params, (W_last, b_last) = params
inputs = jax.nn.relu(jnp.dot(inputs, W_first) + b_first)
inputs = spmd_pipeline(lambda Wb, x: jax.nn.relu(x @ Wb[0] + Wb[1]),
inner_params, inputs)
outputs = jnp.dot(inputs, W_last) + b_last
return outputs
@partial(jax.shard_map, mesh=mesh, in_specs=((P(), P('stages'), P()), P('stages')),
out_specs=P())
def loss_pp(params, batch):
inputs, targets = batch
predictions = predict_pp(params, inputs.reshape(K, B, -1)).reshape(K * B, -1)
local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
return jax.lax.pmean(local_loss, 'stages')
def spmd_pipeline(fn, stage_params, inputs):
stage = jax.lax.axis_index('stages')
outputs = jnp.zeros_like(inputs) * jnp.nan
state = jnp.zeros((L // S, B, F)) * jnp.nan
for i in range(M+L-1):
state = state.at[0].set(jnp.where(stage == 0, inputs[i % K], state[0]))
state = jax.vmap(fn)(stage_params, state)
outputs = outputs.at[(i-L+1) % K].set(jnp.where(stage == S-1, state[-1], outputs[(i-L+1) % K]))
state, inputs, outputs = shift(i, state, inputs, outputs)
outputs = jax.lax.ppermute(outputs, 'stages', [(i, (i+1) % S) for i in range(S)])
return outputs
def shift(i, state, inputs, outputs):
sh = lambda x, d: jax.lax.ppermute(x, 'stages', [(i, (i+d) % S) for i in range(S)])
state = jnp.roll(state, +1, axis=0).at[0].set(sh(state[-1], +1))
if (i % K) == (-1 % K):
inputs = sh(inputs, +1)
if ((i-L+1) % K) == (-1 % K):
outputs = sh(outputs, +1)
return state, inputs, outputs
first_params, *inner_params, last_params = params
Ws, bs = zip(*inner_params)
params_stacked = jnp.stack(Ws), jnp.stack(bs)
first_params = jax.device_put(first_params, NamedSharding(mesh, P()))
params_stacked = jax.device_put(params_stacked, NamedSharding(mesh, P('stages')))
last_params = jax.device_put(last_params, NamedSharding(mesh, P()))
params_ = first_params, params_stacked, last_params
batch_ = jax.device_put(batch, NamedSharding(mesh, P('stages')))
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_pp)(params_, batch_))
11.9203
11.920299
_ = jax.jit(jax.grad(loss_pp))(params_, batch_) # don't crash