使用 shard_map
的手动并行#
概述#
shard_map
是一种单程序多数据 (SPMD) 多设备并行 API,用于将函数映射到数据分片上。映射的函数应用或实例通过显式集合通信操作相互通信。
shard_map
是对内置于 jit
中的基于编译器的自动并行化的补充,并且可以与之组合。使用 jit
,您编写代码时就像是针对单个设备一样,并且编译器可以自动将计算划分为多个设备,在后台生成每个设备的代码和通信集合。使用 shard_map
,您可以掌控一切,编写自己的分区代码和显式集合。或者您可以两者兼顾:在设备组之间进行手动控制,同时将组内设备分区留给编译器。这两种方法可以根据需要混合、匹配和组合。
如果您熟悉 pmap
,可以将 shard_map
视为一种演进。它更具表现力、性能更高,并且可以与其他 JAX API 组合。它甚至可以急切地工作,以便于调试!(有关更多信息,请参阅与 pmap
的详细比较。)
通过阅读本教程,您将学习如何使用 shard_map
来完全控制您的多设备代码。您将详细了解它如何与 jax.jit
的自动并行化和 jax.grad
的自动微分相结合。我们还将提供一些神经网络并行化策略的基本示例。
我们假设本教程在具有八个设备的环境中运行
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' # Use 8 CPU devices
那么,让我们看看 shard_map
!#
事不宜迟,这是一个玩具示例
from functools import partial
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
mesh = jax.make_mesh((4, 2), ('x', 'y'))
a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 4.).reshape(16, 4)
@partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),
out_specs=P('x', None))
def matmul_basic(a_block, b_block):
# a_block: f32[2, 8]
# b_block: f32[8, 4]
c_partialsum = jnp.dot(a_block, b_block)
c_block = jax.lax.psum(c_partialsum, 'y')
# c_block: f32[2, 4]
return c_block
c = matmul_basic(a, b) # c: f32[8, 4]
此函数通过执行本地块矩阵乘法,然后执行集合求和运算来并行计算矩阵乘法。我们可以检查结果是否正确
from jax.tree_util import tree_map, tree_all
def allclose(a, b):
return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))
allclose(c, jnp.dot(a, b))
True
结果沿其行分片
jax.debug.visualize_array_sharding(c)
CPU 0,1 CPU 2,3 CPU 4,5 CPU 6,7
在高层次上,shard_map
有点像 vmap
或 pmap
,因为我们正在将函数映射到数组数据的片段上,但请注意
shard_map
将输入切片成块(输出通过串联结果块形成),保持秩不变,而vmap
会通过映射掉一个轴来降低秩;mesh
参数使我们可以精确控制计算和结果的设备放置;我们一次映射多个数据轴,并为集合运算设置多个轴名称(此处为
'x'
和'y'
);由于我们尚未使用
jax.jit
,因此一切都是急切求值的,我们甚至可以print
中间值以进行调试。
上面的代码执行的计算与此 jax.jit
自动并行化代码相同
from jax.sharding import NamedSharding
a = jax.device_put(a, NamedSharding(mesh, P('x', 'y')))
b = jax.device_put(b, NamedSharding(mesh, P('y', None)))
@jax.jit
def matmul_reference(a, b):
c = jnp.dot(a, b)
return jax.lax.with_sharding_constraint(c, NamedSharding(mesh, P('x', None)))
c_ref = matmul_reference(a, b)
allclose(c_ref, jnp.dot(a, b))
True
我们可以将 shard_map
视为根据其 mesh
和 in_specs
参数对其输入执行 device_put
或 with_sharding_constraint
,因此 matmul_basic
操作的块与 matmul_reference
中的块相同
print('a blocks:'); jax.debug.visualize_array_sharding(a)
print('b blocks:'); jax.debug.visualize_array_sharding(b)
print('c blocks:'); jax.debug.visualize_array_sharding(c)
a blocks:
b blocks:
c blocks:
CPU 0 CPU 1 CPU 2 CPU 3 CPU 4 CPU 5 CPU 6 CPU 7
CPU 0,2,4,6 CPU 1,3,5,7
CPU 0,1 CPU 2,3 CPU 4,5 CPU 6,7
慢一点,从基础开始!#
降秩与保秩映射#
我们可以将 vmap
和 pmap
视为沿轴解堆叠每个数组输入(例如,将 2D 矩阵解包成其 1D 行),将主体函数应用于每个片段,并将结果堆叠在一起,至少在不涉及集合运算时是这样
def check_vmap(f, xs):
ans = jax.vmap(f, in_axes=(0,), out_axes=0)(xs)
expected = jnp.stack([f(x) for x in xs]) # vmap reference semantics
print(allclose(ans, expected))
check_vmap(lambda x: x @ x, jnp.arange(12).reshape(4, 3))
True
例如,如果 xs
的形状为 f32[8,5]
,则每个 x
的形状为 f32[5]
,并且如果每个 f(x)
的形状为 f32[3,7]
,则最终堆叠的结果 vmap(f)(xs)
的形状为 f32[8,3,7]
。也就是说,主体函数 f
的每次应用都以轴数比 vmap(f)
的相应参数少一个的输入作为参数。我们可以说这些是降秩映射,输入/输出解堆叠/堆叠。
f
的逻辑应用数,或 f
的实例数,由要映射的输入轴的大小决定:例如,如果我们映射大小为 8 的输入轴,则语义上我们获得该函数的 8 个逻辑应用。
相比之下,shard_map
没有这种降秩行为。相反,我们可以将其视为沿输入轴切片(或“解串联”)成块,应用主体函数,并将结果串联在一起(同样在不涉及集合运算时)
import numpy as np
devices = np.array(jax.devices()[:4])
mesh = Mesh(devices, ('i',)) # mesh.shape['i'] = 4
def check_shmap(f, y):
ans = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))(y)
expected = jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, mesh.shape['i'])])
print(allclose(ans, expected))
check_shmap(lambda x: x.T @ x, jnp.arange(32).reshape(8, 4))
True
回想一下,jnp.split
将其输入切片成具有相同秩的等大小块,因此如果在上面的示例中 y
的形状为 f32[8,5]
,则每个 y_blk
的形状为 f32[2,5]
,并且如果每个 f(y_blk)
的形状为 f32[3,7]
,则最终串联的结果 shard_map(f, ...)(y)
的形状为 f32[12,7]
。因此,shard_map
映射到其输入的分片或块。我们可以说它是一个保秩映射,输入/输出解串联/串联。
f
的逻辑应用数由网格大小决定,而不是由任何输入轴大小决定:例如,如果我们有一个总大小为 4 的网格(即在 4 个设备上),那么在语义上我们获得该函数的 4 个逻辑应用,对应于物理计算它们的 4 个设备。
控制如何使用 in_specs
拆分(解串联)和平铺每个输入#
每个 in_specs
通过使用 PartitionSpec
按名称标识一些相应输入数组的轴与网格轴,表示如何将该输入拆分(或解串联)为应用主体函数的块。该标识确定分片大小;当输入轴与网格轴标识时,输入沿该逻辑轴拆分(解串联)为多个片段,片段数量等于相应的网格轴大小。(如果对应的网格轴大小不能均匀地划分输入数组轴大小,则会出错。)如果输入的 pspec 没有提及网格轴名称,则不会在该网格轴上进行拆分。例如
mesh = jax.make_mesh((4, 2), ('i', 'j'))
@partial(shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))
def f1(x_block):
print(x_block.shape) # prints (3, 12)
return x_block
x1 = jnp.arange(12 * 12).reshape(12, 12)
y = f1(x1)
(3, 12)
在此处,由于输入 pspec 未提及网格轴名称 'j'
,因此没有输入数组轴在该网格轴上拆分;类似地,由于输入数组的第二个轴未与任何网格轴标识(因此未在其上拆分),因此 f1
的应用获得了沿该轴的输入的完整视图。
当输入 pspec 中未提及网格轴时,我们始终可以重写为效率较低的程序,其中提及了所有网格轴,但调用者执行 jnp.tile
,例如
@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', 'j'))
def f2(x_block):
print(x_block.shape)
return x_block
x = jnp.arange(12 * 12).reshape(12, 12)
x_ = jnp.tile(x, (1, mesh.shape['j'])) # x_ has shape (12, 24)
y = f2(x_) # prints (3,12), and f1(x) == f2(x_)
(3, 12)
换句话说,由于每个输入 pspec 可以提及每个网格轴名称零次或一次,而不是必须精确地提及每个名称一次,因此我们可以说,除了内置于其输入中的 jnp.split
之外,shard_map
还具有内置于其输入中的 jnp.tile
,至少在逻辑上是这样(尽管平铺可能不需要物理执行,具体取决于参数的物理分片布局)。要使用的平铺不是唯一的;我们也可以沿第一个轴平铺,并使用 pspec P(('j', 'i'), None)
。
输入端可能发生物理数据移动,因为每个设备都需要拥有适当数据的副本。
控制如何使用 out_specs
通过串联、块转置和反平铺组装每个输出#
与输入端类似,每个 out_specs
通过名称标识一些相应输出数组的轴与网格轴,表示应如何将输出块(主体函数的每次应用一个,或等效地,每个物理设备一个)组装在一起以形成最终输出值。例如,在上面的 f1
和 f2
示例中,out_specs
指示我们应该通过沿两个轴串联块结果来形成最终输出,从而在两种情况下都得到形状为 (12, 24)
的数组 y
。(如果主体函数的输出形状(即输出块形状)的秩太小,无法进行相应输出 pspec 描述的串联,则会出错。)
当输出 pspec 中未提及网格轴名称时,它表示反平铺:当用户编写一个未提及其中一个网格轴名称的输出 pspec 时,他们承诺输出块沿该网格轴相等,因此输出中仅使用沿该轴的一个块(而不是沿该网格轴将所有块串联在一起)。例如,使用与上面相同的网格
x = jnp.array([[3.]])
z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))()
print(z) # prints the same as jnp.tile(x, (4, 2))
z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', None))()
print(z) # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,))
z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P(None, None))()
print(z) # prints the same as jnp.tile(x, (1, 1)), or just x
[[3. 3.]
[3. 3.]
[3. 3.]
[3. 3.]]
[[3.]
[3.]
[3.]
[3.]]
[[3.]]
主体函数闭包一个数组值等效于将其作为参数传递,并具有相应的 P(None, None) 输入 pspec。作为另一个示例,更紧密地遵循上面的其他示例
@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None))
def f3(x_block):
return jax.lax.psum(x_block, 'j')
x = jnp.arange(12 * 12).reshape(12, 12)
y3 = f3(x)
print(y3.shape)
(12, 6)
结果的第二个轴大小为 6,是输入第二个轴大小的一半。在这种情况下,通过在输出 pspec 中不提及网格轴名称 'j'
表示的反平铺是安全的,因为集合运算 psum
确保每个输出块沿相应的网格轴相等。以下是另外两个示例,我们在其中更改了输出 pspec 中提及的网格轴
@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f4(x_block):
return jax.lax.psum(x_block, 'i')
x = jnp.arange(12 * 12).reshape(12, 12)
y4 = f4(x)
print(y4.shape) # (3,12)
@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, None))
def f5(x_block):
return jax.lax.psum(x_block, ('i', 'j'))
y5 = f5(x)
print(y5.shape) # (3,6)
(3, 12)
(3, 6)
在物理方面,在输出 pspec 中不提及网格轴名称会从输出设备缓冲区组装一个 Array
,该缓冲区具有沿该网格轴复制的布局。
没有运行时检查来验证输出块是否真的沿要反平铺的网格轴相等,或者等效地,相应的物理缓冲区是否具有相等的值,因此可以解释为单个逻辑数组的复制布局。但是我们可以提供一种静态检查机制,该机制会在所有可能不正确的程序上引发错误。
由于 out_specs
可以提及网格轴名称零次或一次,并且由于可以以任何顺序提及它们,因此我们可以说,除了内置于其输出中的 jnp.concatenate
之外,shard_map
还具有内置于其输出中的反平铺和块转置。
无论输出 pspec 如何,输出端都不可能发生物理数据移动。相反,out_specs
只是编码如何将块输出组装成 Array
,或者物理上如何将跨设备的缓冲区解释为单个逻辑 Array
的物理布局。
跟踪值如何在手动网格轴上变化,以及 check_rep=True
#
在 shard_map
下,值可能在函数实例之间变化,也可能相同。例如,当我们使用 in_specs
将参数拆分到网格轴上时,沿该网格轴的每个函数实例都会获得不同的值
mesh = jax.make_mesh((2,), ('i',))
@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
def f(x):
print(x)
return 2 * x
x = jnp.arange(6.)
f(x)
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[0. 1. 2.]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[3. 4. 5.]
Array([ 0., 2., 4., 6., 8., 10.], dtype=float32)
如果 in_specs
没有将参数拆分到网格轴上,则沿该轴的每个函数实例的值都相同
@partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P())
def f(x):
print(x)
return 2 * x
x = jnp.arange(6.)
f(x)
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[0. 1. 2. 3. 4. 5.]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[0. 1. 2. 3. 4. 5.]
Array([ 0., 2., 4., 6., 8., 10.], dtype=float32)
集合运算的输出可能与其输入的方差不同。例如,应用 psum
会在沿轴的每个函数实例上产生相同的输出
@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P())
def f(x):
y = jax.lax.psum(x, 'i')
print(y)
return y
x = jnp.arange(6.)
f(x)
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3. 5. 7.]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[3. 5. 7.]
Array([3., 5., 7.], dtype=float32)
一般来说,shard_map
中的每个中间值在每个手动网格轴上可以是恒定不变的,也可以是可能变化的。该信息可以在 JAX 类型系统中跟踪,通过 shard_map
的 check_rep=True
参数启用
@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P())
def f(x):
print(jax.typeof(x)) # f32[3]{i}
y = jax.lax.psum(x, 'i')
print(jax.typeof(y)) # f32[3]
return y
x = jnp.arange(6.)
f(x)
ShapedArray(float32[3]{i})
ShapedArray(float32[3])
Array([3., 5., 7.], dtype=float32)
在此处,类型 f32[3]{i}
表示 x
的值在网格轴 'i'
上变化。y
的类型打印为 f32[3]
,表示它在所有网格轴上都是不变的;也就是说,不打印空集。我们将类型的这部分称为变化的的手动轴 (VMA),可以通过 jax.typeof(x).vma
访问它。
一般来说,值的 VMA 类型可以包括 shard_map
作用于其上的手动网格轴的任何子集
mesh = jax.make_mesh((4, 2), ('i', 'j'))
@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i'))
def f(x):
print(jax.typeof(x)) # f32[2,2]{i,j}
y = jax.lax.psum(x, 'j')
assert jax.typeof(y).vma == {'i'}
print(jax.typeof(y)) # f32[2,2]{i}
return y
x = jnp.arange(8 * 4.).reshape(8, 4)
f(x)
ShapedArray(float32[2,2]{j,i})
ShapedArray(float32[2,2]{i})
Array([[ 2., 4.],
[10., 12.],
[18., 20.],
[26., 28.],
[34., 36.],
[42., 44.],
[50., 52.],
[58., 60.]], dtype=float32)
跟踪变化的手动轴可能很有用
您的代码可以包含关于值是否在预期网格轴上变化的打印、断言或条件语句;
它可以实现高效的反向模式自动微分,而无需防御性的
psum
(请参阅 JEP);可以检查
out_specs
的正确性,从而排除下面潜在的错误示例。
例如,这个 out_specs
错误可以通过 check_rep=True
捕获,但在不使用它时则无法捕获
mesh = jax.make_mesh((2,), ('i',))
x = jnp.arange(6.)
try:
y = shard_map(lambda x: x, mesh, in_specs=P('i'), out_specs=P())(x)
except Exception as e:
print(e)
shard_map applied to the function '<lambda>' was given out_specs which require replication which can't be statically inferred given the mesh:
The mesh given has shape (2,) with corresponding axis names ('i',).
out_specs is PartitionSpec() which implies that the corresponding output value is replicated across mesh axis 'i', but could not infer replication over any axes
Check if these output values are meant to be replicated over those mesh axes. If not, consider revising the corresponding out_specs entries. If so, consider disabling the check by passing the check_rep=False argument to shard_map.
这里 out_specs
错误地承诺沿网格轴 'i'
的每个函数实例产生相同的值,因此我们可以只选择其中一个。使用 check_rep=True
(默认值) 会引发异常,而使用 check_rep=False
则不会引发异常,而是会得到静默的未定义行为。
有时我们希望将网格轴上不变的值视为在该网格轴上变化的值。这就是 jax.lax.pvary
的作用
@partial(shard_map, mesh=mesh, in_specs=P(), out_specs=None)
def f(x):
print(jax.typeof(x)) # f32[6]
y = jax.lax.pvary(x, 'i')
print(jax.typeof(y)) # f32[6]{i}
x = jnp.arange(6.)
f(x)
ShapedArray(float32[6])
ShapedArray(float32[6]{i})
可以将 jax.lax.pvary
视为应用类型转换:它在运行时是空操作,但在反向模式自动微分下,它转置为 jax.lax.psum
(参见 JEP)。这是有道理的,因为它们对 VMA 执行相反的操作:当 y: f32[3]{i} = jax.lax.pvary(x: f32[3], 'i')
时,我们相应地有 x_grad: f32[3] = jax.lax.psum(y_grad: f32[3]{i}, 'i')
。
JAX 在许多情况下隐式插入 jax.lax.pvary
调用,特别是对于二元运算
@partial(shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))
def f(x, y):
return x * y
x = jnp.arange(6.)
y = jnp.arange(3.)
print(jax.make_jaxpr(f)(x, y))
{ lambda ; a:f32[6] b:f32[3]. let
c:f32[6] = shard_map[
auto=frozenset()
check_rep=True
in_names=({0: ('i',)}, {})
jaxpr={ lambda ; d:f32[3]{i} e:f32[3]. let
f:f32[3]{i} = pvary[axes=('i',) axis_index_groups=None] e
g:f32[3]{i} = mul d f
in (g,) }
mesh=Mesh('i': 2, axis_types=(Auto,))
out_names=({0: ('i',)},)
] a b
in (c,) }
在 jaxpr 中,乘法运算要求其参数的 VMA 类型匹配,但为了方便起见,jax.numpy
和 jax.lax
API 会自动应用 jax.lax.pvary
以使参数 VMA 类型一致。
在某些情况下,例如使用 jax.lax.scan
,您可能需要自己应用 jax.lax.pvary
以确保 VMA 类型按要求匹配。例如,以下代码会引发错误
mesh = jax.make_mesh((2,), ('i',))
@partial(shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))
def f(x, y):
def body(carry, _):
c1, c2 = carry
return (c2, c1), () # swap the carry
(x_, y_), _ = jax.lax.scan(body, (x, y), (), length=2)
return x_, y_
x = jnp.arange(6.)
y = jnp.arange(3.)
try:
f(x, y)
except Exception as e:
print(e)
scan body function carry input and carry output must have equal types, but they differ:
* the input carry component carry[0] has type float32[3]{i} but the corresponding output carry component has type float32[3], so the varying manual axes do not match;
* the input carry component carry[1] has type float32[3] but the corresponding output carry component has type float32[3]{i}, so the varying manual axes do not match.
This might be fixed by applying `jax.lax.pvary(..., ('i',))` to the initial carry value corresponding to the input carry component carry[1].
See https://jax.net.cn/en/latest/notebooks/shard_map.html#scan-vma for more information.
Revise the function so that all output types match the corresponding input types.
为了使类型匹配,我们需要将 jax.lax.pvary
应用于 scan
的某些参数
mesh = jax.make_mesh((2,), ('i',))
@partial(shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))
def f(x, y):
def body(carry, _):
c1, c2 = carry
return (c2, c1), () # swap the carry
y = jax.lax.pvary(y, 'i') # apply pvary to fix the error
(x_, y_), _ = jax.lax.scan(body, (x, y), (), length=2)
return x_, y_
x = jnp.arange(6.)
y = jnp.arange(3.)
f(x, y)
(Array([0., 1., 2., 3., 4., 5.], dtype=float32),
Array([0., 1., 2., 0., 1., 2.], dtype=float32))
以下是关于 collective 原语及其如何影响 varying manual axis 类型的摘要
名称 |
设备方差类型 |
示例 |
降低到 HLO |
转置 |
---|---|---|---|---|
|
|
|
|
|
|
|
|
no-op (无通信) |
|
|
|
|
|
|
|
|
|
|
n/a |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
关于此表的几点说明
函数
jax.lax.psum
是psum_invariant
的便捷包装器。令人惊讶的是
all_gather
是Varying -> Varying
,但这是因为它实际上是psum_scatter
的转置,而psum_scatter
是Varying -> Varying
。pscatter
和all_gather_invariant
在撰写本文时都没有用户 API,但此处为完整性起见进行了描述。
API 规范#
from jax.sharding import Mesh
Specs = PyTree[PartitionSpec]
def shard_map(
f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs,
auto: collections.abc.Set[AxisName] = frozenset([]),
check_rep: bool = True,
) -> Callable:
...
其中
类似
psum
的通信 collective 可以提及mesh
的轴名称;mesh
编码排列在数组中并具有关联轴名称的设备,就像它对sharding.NamedSharding
所做的那样;in_specs
和out_specs
是PartitionSpec
,可以仿射地提及来自mesh
的轴名称,以分别表示输入和输出的切片/解串联和串联,未提及的名称分别对应于复制和 untiling (assert-replicated-so-give-me-one-copy);auto
是一个可选的轴名称集合,对应于mesh
的名称的子集,用于在主体中自动处理,就像在调用者中一样,而不是手动处理;check_rep
是一个可选的布尔值,指示是否静态检查out_specs
中的任何复制错误,以及是否启用相关的自动微分优化 (参见 JEP)。
传递给 f
的参数的形状与传递给 shard_map
-of-f
的参数的秩相同,并且 f
的参数的形状是从 shard_map
-of-f
的对应参数的形状 shape
和对应的 PartitionSpec
spec
计算出来的,大致为 tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))
。
Collectives 教程#
shard_map
不需要是纯映射:函数应用可以通过 collectives 进行相互通信,使用在 mesh
参数中定义的轴名称。
回想一下,shard_map
将函数映射到输入数据的分片或块上,因此这个
mesh = Mesh(jax.devices(), ('i',))
x = jnp.arange(16.)
f_shmapped = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))
y = f_shmapped(x)
计算相同的值,评估将 f
应用于相同参数值的应用程序,如下参考函数所示
def f_shmapped_ref(x):
x_blocks = jnp.array_split(x, mesh.shape['i'])
y_blocks = [f(x_blk) for x_blk in x_blocks]
return jnp.concatenate(y_blocks)
我们将 f
应用于不同参数分片的这些应用程序称为函数实例。每个函数实例都在不同的设备(或设备子集)上执行。
当 f
中没有通信 collective 时,这些参考语义有效。但是,如果我们希望函数实例进行通信,从而实现跨设备通信,该怎么办?也就是说,当 f
包含 collective 时,参考语义是什么?假设 f
只有一个 collective,并且形式如下
def f(x_blk):
z_blk = f_part1(x_blk)
u_blk = collective(z_blk, axis_name)
v_blk = f_part2(x_blk, z_blk, u_blk)
return v_blk
其中我们假设只有一个网格轴我们要映射,而 axis_name
是其对应的名称。那么参考语义看起来更像
def f_shmapped_ref(x):
x_blocks = jnp.array_split(x, mesh.shape[0])
z_blocks = [f_part1(x_blk) for x_blk in x_blocks]
u_blocks = [collective_ref(i, z_blocks) for i in range(len(z_blocks))]
v_blocks = [f_part2(x_blk, z_blk, u_blk) for x_blk, z_blk, u_blk
in zip(x_blocks, z_blocks, u_blocks)]
return jnp.concatenate(v_blocks)
请注意,collective_ref
可能取决于所有 z_blocks
。也就是说,虽然 f_part1
和 f_part2
是独立地映射到块上的,但 collective 引入了一些跨块依赖性。物理上,这意味着跨设备的通信。究竟发生什么通信,以及计算出什么值,取决于 collective。
psum
#
最简单的 collective 可能是 jax.lax.psum
,它沿设备网格轴(或多个轴)计算 all-reduce-sum。这是一个玩具示例
import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map
mesh1d = Mesh(jax.devices()[:4], ('i',))
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P(None))
def f1(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.psum(x_block, 'i')
print('AFTER:\n', y_block)
return y_block
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f1(x)
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22 20 12 17]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[22 20 12 17]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[22 20 12 17]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[22 20 12 17]
FINAL RESULT:
[22 20 12 17]
打印结果表明,每个函数应用程序都以其自己的参数值 x_block
块开始。在 psum
之后,每个函数应用程序都具有相同的 y_block
值,该值是通过将应用程序的 x_block
值加在一起计算得出的。
在计算中只有一个轴名称的情况下,我们可以说 psum
的 collective_ref
参考实现是
def psum_ref(_, x_blocks):
tot = sum(x_blocks)
return [tot] * len(x_blocks)
另请注意,由于 f1
返回 y_block
,即 psum
在 'i'
上的结果,我们可以使用 out_specs=P()
,以便调用者获得结果值的单个逻辑副本,而不是平铺的结果。
当存在多个网格轴时,我们可以分别对每个轴或同时对多个轴执行 psum
mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j'))
@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f2(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.psum(x_block, 'i')
print('AFTER:\n', y_block)
return y_block
y = f2(jnp.arange(16).reshape(4, 4))
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[0 1]
[4 5]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[2 3]
[6 7]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8 9]
[12 13]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[10 11]
[14 15]]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[ 8 10]
[16 18]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[12 14]
[20 22]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8 10]
[16 18]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[12 14]
[20 22]]
FINAL RESULT:
[[ 8 10 12 14]
[16 18 20 22]]
通过对网格轴 'i'
应用 psum
,我们得到沿轴 'i'
相等的 y_block
值,但不沿轴 'j'
相等。(因此,我们可以使用 out_specs=P(None, 'j')
来获取沿该轴的单个逻辑结果。)
如果我们对两个轴都应用 psum
,则 y_block
值在两个轴上都相等
@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None))
def f3(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.psum(x_block, ('i', 'j'))
print('AFTER:\n', y_block)
return y_block
y = f3(jnp.arange(16).reshape(4, 4))
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[0 1]
[4 5]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[2 3]
[6 7]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8 9]
[12 13]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[10 11]
[14 15]]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[20 24]
[36 40]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[20 24]
[36 40]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[20 24]
[36 40]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[20 24]
[36 40]]
FINAL RESULT:
[[20 24]
[36 40]]
在机器学习中,我们经常使用 psum
来计算总损失,或者当我们在 shard_map
映射的函数体内部有 grad
时,计算总梯度。
在续集中,我们将看到 psum
如何用其他原语来实现,这给出了一些关于其通信成本的直觉。
all_gather
#
另一个基本操作是沿轴收集数组分片,以便每个函数应用程序沿该轴具有数据的完整副本
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f4(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.all_gather(x_block, 'i', tiled=True)
print('AFTER:\n', y_block)
return y_block
x = jnp.array([3, 9, 5, 2])
y = f4(x)
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[9]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 9 5 2]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[3 9 5 2]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[3 9 5 2]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[3 9 5 2]
FINAL RESULT:
[3 9 5 2 3 9 5 2 3 9 5 2 3 9 5 2]
打印结果表明,每个函数应用程序再次以其自己的参数值 x_block
块开始。在 all_gather
之后,它们具有一个公共值,该值是通过连接 x_block
的值来计算的。
(请注意,我们实际上不能在此处设置 out_specs=P()
。出于与自动微分相关的技术原因,我们认为 all_gather
的输出不能保证跨设备不变。如果我们希望它保证不变,我们可以使用 jax.lax.all_gather_invariant
,或者在这种情况下,我们可以避免在函数体中执行 all_gather
,而只是使用 out_specs=P('i')
来执行连接。)
当 tiled=False
(默认值) 时,结果沿新轴堆叠而不是连接
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f5(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.all_gather(x_block, 'i', tiled=False)
print('AFTER:\n', y_block)
return y_block
y = f5(x)
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[9]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[[3]
[9]
[5]
[2]]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[[3]
[9]
[5]
[2]]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[[3]
[9]
[5]
[2]]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[[3]
[9]
[5]
[2]]
FINAL RESULT:
[[3]
[9]
[5]
[2]
[3]
[9]
[5]
[2]
[3]
[9]
[5]
[2]
[3]
[9]
[5]
[2]]
我们可以将 all_gather
的 collective_ref
参考语义函数编写为
def all_gather_ref(_, x_blocks, *, tiled=False):
combine = jnp.concatenate if tiled else jnp.stack
return [combine(x_blocks)] * len(x_blocks)
在深度学习中,我们可能会在完全分片数据并行 (FSDP) 中对参数使用 all_gather
。
psum_scatter
#
jax.lax.psum_scatter
collective 有点不太直观。它类似于 psum
,只是每个函数实例仅获得结果的一个分片
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f6(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True)
print('AFTER:\n', y_block)
return y_block
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f6(x)
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[20]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[12]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[17]
FINAL RESULT:
[22 20 12 17]
如打印结果所示,每个结果 y_block
的大小都小于参数 x_block
,这与 psum
不同。此外,与 psum
相比,这里的每个 y_block
仅表示跨函数实例的 x_block
总和的一个切片。(即使每个函数实例仅获得总和的一个分片,最终输出 y
也与 psum
示例中的相同,因为此处我们使用 out_specs=P('i')
来连接每个函数实例的输出。)
就计算的值而言,collective_ref
参考实现可能看起来像
def psum_scatter_ref(i, x_blocks, *, tiled=False):
axis_size = len(x_blocks)
tot = sum(x_blocks)
if tiled:
tot = tot.reshape(axis_size, -1, *tot.shape[1:]) # split leading axis
return [tot[i] for i in range(tot.shape[0])]
语义参考实现中未捕获的是,psum_scatter
非常有用,因为与完整的 psum
相比,这些结果可以更有效地计算,通信更少。实际上,思考 psum_scatter
的一种方式是将其视为“psum
的前半部分,在 all_gather
之前”。也就是说,实现 psum
的一种方法是
def psum(x, axis_name):
summed_chunk = jax.lax.psum_scatter(x, axis_name)
return jax.lax.all_gather(summed_chunk, axis_name)
实际上,此实现通常在 TPU 和 GPU 上都使用!
psum_scatter
可能需要大约一半的通信量(如完整的 psum
)的原因在 ppermute
部分中进行了说明。
另一种直觉是,我们可以使用 psum_scatter
来实现分布式矩阵乘法,其输入和输出在同一轴上分片。在机器学习中,psum_scatter
可用于张量并行矩阵乘法或完全分片数据并行梯度累积,如后续示例所示。
ppermute
#
jax.lax.ppermute
collective 为函数实例相互发送数据提供了最直接的方式。给定一个网格轴和一个表示沿该网格轴的索引的 (source_index, destination_index)
对列表,ppermute
将其参数值从每个源函数实例发送到每个目标
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f7(x_block):
sz = jax.lax.axis_size('i')
print('BEFORE:\n', x_block)
y_block = jax.lax.ppermute(x_block, 'i', [(i, (i + 1) % sz) for i in range(sz)])
print('AFTER:\n', y_block)
return y_block
y = f7(jnp.arange(8))
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[0 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[2 3]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[4 5]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[6 7]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[6 7]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[0 1]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[2 3]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[4 5]
FINAL RESULT:
[6 7 0 1 2 3 4 5]
在这种情况下,只有两个函数实例,每个实例的 y_block
值都是另一个实例的 x_block
值。
源索引和目标索引不能重复。如果索引未作为目标出现,则相应函数实例结果的值为零数组。
collective_ref
参考实现可能看起来像
def ppermute_ref(i, x_blocks, perm):
results = [jnp.zeros_like(x_blocks[0])] * len(x_blocks)
for src, dst in perm:
results[dst] = x_blocks[src]
return results
可以使用 ppermute
有效地实现其他 collectives,就总通信量而言,其中每个函数仅将数据传递给其邻居。例如,我们可以使用一系列 ppermute
和本地加法来实现 psum_scatter
,如下所示
或者,使用数值示例
直观上,在每次迭代中,每个函数实例都“向上”发送其在上一次迭代中接收到的值,并减少(添加)其在此次迭代中接收到的值。在代码中,它可能看起来像这样
def psum_scatter(x, axis_name, *, tiled=False):
size = jax.lax.axis_size(axis_name)
idx = jax.lax.axis_index(axis_name) # function instance index along axis_name
if tiled:
x = x.reshape(size, -1, *x.shape[1:]) # split leading axis
shift = partial(jax.lax.ppermute, axis_name=axis_name,
perm=[(i, (i - 1) % size) for i in range(size)])
for i in range(1, size):
update = shift(x[(idx + i) % size])
x = x.at[(idx + i + 1) % size].add(update)
return x[idx]
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f8(x_block):
print('BEFORE:\n', x_block)
y_block = psum_scatter(x_block, 'i', tiled=True)
print('AFTER:\n', y_block)
return y_block
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f8(x)
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[20]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[12]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[17]
FINAL RESULT:
[22 20 12 17]
在 TPU 上,有更高维度的此算法变体,以利用多个双向物理网格轴。
请注意,psum_scatter
是 all_gather
的转置。实际上,用 ppermute
实现 all_gather
的一种方法看起来像是上述过程的逆过程
在深度学习中,我们可能会在实现 SPMD 流水线并行时使用 ppermute
,其中我们将网络沿其深度划分为多个阶段,并并行评估阶段的应用。或者,我们可能会在并行化卷积层的评估时使用 ppermute
,其中我们跨空间轴进行分片,因此设备必须相互通信“光环”。或者,它可以在张量并行矩阵乘法中在后台使用。
all_to_all
#
最后一个 collective 是 all_to_all
,它本质上是沿一个位置轴和一个跨设备轴操作的块矩阵转置
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f9(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.all_to_all(x_block, 'i', split_axis=0, concat_axis=0,
tiled=True)
print('AFTER:\n', y_block)
return y_block
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f9(x)
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 5 5 9]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[1 9 3 7]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[4 2 5 1]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[1 6 8 2]
FINAL RESULT:
[3 5 5 9 1 9 3 7 4 2 5 1 1 6 8 2]
split_axis
参数指示应跨网格轴分片和分区的哪个位置轴。concat_axis
参数指示应沿哪个轴连接或堆叠通信结果。
当 tiled=False
(默认值) 时,split_axis
轴大小必须等于名为 axis_name
的网格轴的大小,并且在该大小的新轴在位置 concat_axis
处创建,用于堆叠结果。当 tiled=True
时,split_axis
轴大小仅需能被网格轴的大小均匀整除,并且结果沿现有轴 concat_axis
连接。
当 split_axis=0
且 concat_axis=0
时,collective_ref
参考语义可能看起来像
def all_to_all_ref(_, x_blocks, *, tiled=False):
axis_size = len(x_blocks)
if tiled:
splits = [jnp.array_split(x, axis_size) for x in x_blocks]
return [jnp.concatenate(s) for s in zip(*splits)]
else:
splits = [list(x) for x in x_blocks]
return [jnp.stack(s) for s in zip(*splits)]
在深度学习中,我们可能会在混合专家路由中使用 all_to_all
,其中我们首先根据示例应发送给哪个专家对本地批次的示例进行排序,然后应用 all_to_all
将示例重新分发给专家。
玩具示例#
我们如何在实践中使用 shard_map
和 collective 通信?这些示例虽然很简单,但给出了一些想法。
矩阵乘法#
并行化矩阵乘法是扩展深度学习模型的中心,无论是用于训练还是推理。当 jax.jit
自动并行化矩阵乘法时,它可以根据矩阵大小、硬件详细信息和其他因素使用几种不同的策略之一。我们如何使用 shard_map
更明确地编写其中一些并行化例程?我们如何优化它们以获得更好的计算/通信重叠,从而提高 FLOP 利用率?
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map
mesh = Mesh(jax.devices()[:4], ('i',))
def device_put(x, pspec):
return jax.device_put(x, NamedSharding(mesh, pspec))
示例 1:一侧的 all-gather
#
考虑执行矩阵乘法,其中我们在左侧参数 (可以认为是:参数) 的前导 (非收缩) 维度上对其进行分片
lhs_spec = P('i', None)
lhs = device_put(jax.random.normal(jax.random.key(0), (8, 8)), lhs_spec)
我们在右侧参数 (可以认为是:激活) 的收缩维度上对其进行分片,输出也进行类似的分片
rhs_spec = P('i', None)
rhs = device_put(jax.random.normal(jax.random.key(1), (8, 4)), rhs_spec)
为了执行此矩阵乘法,我们可以首先 all-gather 右侧,然后针对分片的左侧执行本地矩阵乘法
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
out_specs=rhs_spec)
def matmul_allgather(lhs_block, rhs_block):
rhs = jax.lax.all_gather(rhs_block, 'i', tiled=True)
return lhs_block @ rhs
out = matmul_allgather(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
这很棒,但我们没有获得任何计算/通信重叠:在我们开始 matmul 之前,我们需要 all_gather
完成。以下是使用相同代码但在更大示例形状 ((8192, 8192)
用于 lhs
,(8192, 1024)
用于 rhs
) 下的配置文件
如果我们将 all_gather
的上述实现基本上以内联方式 (根据 ppermute
) 放在内部,然后将 gather 置换步骤与本地矩阵乘法交错,则可以获得计算/通信重叠
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
out_specs=rhs_spec)
def matmul_allgather_overlapped(lhs_block, rhs_block):
size = jax.lax.axis_size('i')
idx = jax.lax.axis_index('i')
shift = partial(jax.lax.ppermute, axis_name='i',
perm=[(i, (i + 1) % size) for i in range(size)])
B = lhs_block.shape[1] // size
lhs_blocks = lambda i: lax.dynamic_slice_in_dim(lhs_block, i * B, B, 1)
out_block = lhs_blocks(idx) @ rhs_block
for i in range(1, size):
rhs_block = shift(rhs_block)
out_block += lhs_blocks((idx - i) % size) @ rhs_block
return out_block
out = matmul_allgather_overlapped(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
此实现允许通信和计算之间重叠,并且还避免了在每个设备上收集大型中间结果。但在 TPU 上,它仅使用一半的互连带宽,方法是仅沿环在一个方向上置换。为了双向置换,我们只需将块分成两半,然后在每个方向上发送一半
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
out_specs=rhs_spec)
def matmul_allgather_overlapped_bidi(lhs_block, rhs_block):
size = jax.lax.axis_size('i')
idx = jax.lax.axis_index('i')
shift_up = partial(jax.lax.ppermute, axis_name='i',
perm=[(i, (i + 1) % size) for i in range(size)])
shift_dn = partial(jax.lax.ppermute, axis_name='i',
perm=[(i, (i - 1) % size) for i in range(size)])
B = lhs_block.shape[1] // size // 2 # half-size blocks
lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 1)
rhs_block_lo, rhs_block_hi = jnp.split(rhs_block, 2, axis=0)
out_block = lhs_blocks(idx, 0) @ rhs_block_lo
out_block += lhs_blocks(idx, 1) @ rhs_block_hi
for i in range(1, size):
rhs_block_lo = shift_up(rhs_block_lo)
rhs_block_hi = shift_dn(rhs_block_hi)
out_block += lhs_blocks((idx - i) % size, 0) @ rhs_block_lo
out_block += lhs_blocks((idx + i) % size, 1) @ rhs_block_hi
return out_block
out = matmul_allgather_overlapped_bidi(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
在实践中,为了减少编译时间,我们可能会将其滚动到 jax.lax.fori_loop
中。我们可能还会涉及其他并行轴。
示例 2:psum_scatter
结果#
我们可能开始使用的另一种分片方式是在 lhs
和 rhs
的收缩维度上对其进行分片,输出再次像 rhs
一样分片
lhs_spec = P(None, 'i')
lhs = device_put(lhs, lhs_spec)
rhs_spec = P('i', None)
rhs = device_put(rhs, rhs_spec)
在这里,我们可以使用 reduce_scatter
来执行分片上的收缩求和
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
out_specs=rhs_spec)
def matmul_psumscatter(lhs_block, rhs_block):
out_summand = lhs_block @ rhs_block
return jax.lax.psum_scatter(out_summand, 'i', tiled=True)
out = matmul_psumscatter(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
但 scattering 通信必须等待整个本地矩阵乘法完成才能开始。为了获得通信/计算重叠,我们可以以内联方式 (根据 ppermute
) 实现 psum_scatter
,然后将通信步骤与本地矩阵乘法交错
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
out_specs=rhs_spec)
def matmul_psumscatter_overlapped(lhs_block, rhs_block):
size = jax.lax.axis_size('i')
idx = jax.lax.axis_index('i')
shift = partial(jax.lax.ppermute, axis_name='i',
perm=[(i, (i - 1) % size) for i in range(size)])
lhs_block = lhs_block.reshape(size, -1, lhs_block.shape[1]) # split 1st axis
out_summand = lhs_block[(idx + 1) % size] @ rhs_block
for i in range(1, size):
out_summand = shift(out_summand)
out_summand += lhs_block[(idx + i + 1) % size] @ rhs_block
return out_summand
out = matmul_psumscatter_overlapped(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
与上一个示例中一样,为了充分利用 TPU 上的互连,我们将运行双向版本
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
out_specs=rhs_spec)
def matmul_psumscatter_overlapped_bidi(lhs_block, rhs_block):
size = jax.lax.axis_size('i')
idx = jax.lax.axis_index('i')
shift_up = partial(jax.lax.ppermute, axis_name='i',
perm=[(i, (i + 1) % size) for i in range(size)])
shift_dn = partial(jax.lax.ppermute, axis_name='i',
perm=[(i, (i - 1) % size) for i in range(size)])
B = lhs_block.shape[0] // size // 2 # half-size blocks
lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 0)
out_summand_lo = lhs_blocks((idx - 1) % size, 0) @ rhs_block
out_summand_hi = lhs_blocks((idx + 1) % size, 1) @ rhs_block
for i in range(1, size):
out_summand_lo = shift_up(out_summand_lo)
out_summand_hi = shift_dn(out_summand_hi)
out_summand_lo += lhs_blocks((idx - i - 1) % size, 0) @ rhs_block
out_summand_hi += lhs_blocks((idx + i + 1) % size, 1) @ rhs_block
return jnp.concatenate([out_summand_lo, out_summand_hi])
out = matmul_psumscatter_overlapped_bidi(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
神经网络#
我们可以使用 shard_map
来并行化神经网络中的计算,可以单独使用,也可以与 jax.jit
中的自动分区结合使用。本节有一些基于此玩具神经网络和随机数据的示例
import jax
import jax.numpy as jnp
def predict(params, inputs):
for W, b in params:
outputs = jnp.dot(inputs, W) + b
inputs = jax.nn.relu(outputs)
return outputs
def loss(params, batch):
inputs, targets = batch
predictions = predict(params, inputs)
return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
def init_layer(key, n_in, n_out):
k1, k2 = jax.random.split(key)
W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
b = jax.random.normal(k2, (n_out,))
return W, b
def init(key, layer_sizes, batch_size):
key, *keys = jax.random.split(key, len(layer_sizes))
params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))
key, *keys = jax.random.split(key, 3)
inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))
return params, (inputs, targets)
layer_sizes = [784, 128, 128, 128, 128, 128, 8]
batch_size = 32
params, batch = init(jax.random.key(0), layer_sizes, batch_size)
将这些示例与 “分布式数组和自动并行化”文档中的纯 自动分区示例进行比较。虽然在那些自动分区示例中,我们不需要编辑模型函数即可使用不同的并行化策略,但在使用 shard_map
时,我们通常需要这样做。
8 路批数据并行#
最简单的多设备并行策略是在多个设备上分片输入和目标的批次,在这些设备上复制参数,并将模型并行应用于这些数据分片。为了评估总损失,设备只需要在最后与标量大小的 all-reduce-sum 进行通信。(为了评估损失的梯度,设备必须在反向传播中执行参数梯度的 all-reduce-sum。)
from functools import partial
from jax.sharding import NamedSharding, Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
mesh = jax.make_mesh((8,), ('batch',))
# replicate initial params on all devices, shard data batch over devices
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P()))
# adapt the loss function to sum the losses across devices
def loss_dp(params, batch):
@partial(shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P())
def loss_spmd(local_batch):
inputs, targets = local_batch
predictions = predict(params, inputs) # use reference 'predict`
local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
return jax.lax.pmean(local_loss, 'batch')
return loss_spmd(batch)
我们可以检查损失及其梯度是否与参考 (基础) 模型匹配
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_dp)(params, batch))
11.920298
11.920298
def allclose(a, b):
return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))
print(allclose(jax.jit(jax.grad(loss))(params, batch),
jax.jit(jax.grad(loss_dp))(params, batch)))
True
我们可以打印编译器 IR 以检查梯度计算,并验证 collective all-reduce-sum 操作是否发生在我们期望的位置:在正向传播结束时计算损失值,在反向传播中计算总参数梯度。
8 路完全分片数据并行 (FSDP)#
另一种策略是在设备上额外分片参数,在 jnp.dot
或偏差加法需要完整值时 all-gather 每个参数。由于我们一次本地设备内存中只有一个完整参数,而不是像前面的 DP 示例中那样在所有设备内存中保留所有参数,因此我们释放了大量内存,可以用于更大的模型或更大的批次大小。并且由于 XLA 将重叠计算和设备间通信,因此挂钟时间不会受到影响。
因此,现在我们需要两个地方的 collectives:模型预测函数 predict
需要在参数使用之前 all-gather 参数,并且与 DP 情况一样,损失函数需要对本地损失求和以计算总损失。
我们还需要另一个要素:我们不希望存储来自正向传播的完全收集的参数,以用于反向传播。相反,我们希望在反向传播中再次收集它们。我们可以通过将 jax.remat
与 自定义策略 (或 custom_vjp
) 一起使用来表达这一点,尽管 XLA 通常会自动执行该重物化。
这种通用的 FSDP 方法类似于 权重更新分片 (WUS) 和 ZeRO-3。
# shard data batch *and params* over devices
mesh = Mesh(devices, ('batch',))
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P('batch')))
# adapt the prediction function to gather weights just before their use,
# and to re-gather them on the backward pass (rather than saving them)
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict_fsdp(params_frag, inputs):
for W_frag, b_frag in params_frag:
W = jax.lax.all_gather(W_frag, 'batch', tiled=True)
b = jax.lax.all_gather(b_frag, 'batch', tiled=True)
outputs = jnp.dot(inputs, W) + b
inputs = jax.nn.relu(outputs)
return outputs
def loss_fsdp(params, batch):
@partial(shard_map, mesh=mesh, in_specs=P('batch'), out_specs=P())
def loss_spmd(local_params, local_batch):
inputs, targets = local_batch
predictions = predict_fsdp(local_params, inputs)
local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
return jax.lax.pmean(local_loss, 'batch')
return loss_spmd(params, batch)
再次,我们可以检查损失及其梯度是否与参考模型匹配
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_fsdp)(params, batch))
print(allclose(jax.jit(jax.grad(loss))(params, batch),
jax.jit(jax.grad(loss_fsdp))(params, batch)))
11.920298
11.920298
True
8路张量并行 (TP)#
通常我们不单独使用张量模型并行,但单独来看它是并行矩阵乘法的一个很好的预热。 这也是在库函数中使用 shard_map
的一个很好的例子,它在一个更大的基于 jit
的计算中被调用。
并行化的思想是我们将保持数据/激活在其特征轴上分片(而不是在其批次轴上),并且我们将类似地在权重矩阵的输入特征轴上(以及偏差在其特征轴上)分片。 然后为了执行并行矩阵乘法,我们将执行本地矩阵乘法,然后执行 psum_scatter
以对本地结果求和并有效地分散结果的分片。
mesh = jax.make_mesh((8,), ('feats',))
batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))
params = jax.device_put(params, NamedSharding(mesh, P('feats')))
def predict_tp(params, inputs):
for W, b in params:
outputs = gemm_tp(inputs, W, b)
inputs = jax.nn.relu(outputs)
return outputs
@partial(shard_map, mesh=mesh,
in_specs=(P(None, 'feats'), P('feats', None), P('feats')),
out_specs=P(None, 'feats'))
def gemm_tp(inputs, W, b):
block_result = jnp.dot(inputs, W)
return jax.lax.psum_scatter(block_result, 'feats',
scatter_dimension=1, tiled=True) + b
def loss_tp(params, batch):
inputs, targets = batch
predictions = predict_tp(params, inputs)
return jnp.mean(jnp.sum((predictions - targets) ** 2, axis=-1)) # NOTE psum!
FSDP + TP,在顶层使用 shard_map
#
我们可以将这些策略组合在一起,使用多个并行轴。
mesh = jax.make_mesh((4, 2), ('batch', 'feats'))
batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats')))
params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats'))))
# mostly same as previous predict_fsdp definition, except we call gemm_tp
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict_fsdp_tp(params_frag, inputs):
for W_frag, b_frag in params_frag:
W = jax.lax.all_gather(W_frag, 'batch', tiled=True)
b = jax.lax.all_gather(b_frag, 'batch', tiled=True)
block_result = jnp.dot(inputs, W)
outputs = jax.lax.psum_scatter(block_result, 'feats',
scatter_dimension=1, tiled=True) + b
inputs = jax.nn.relu(outputs)
return outputs
@partial(shard_map, mesh=mesh,
in_specs=(P(('feats', 'batch')), P('batch', 'feats')),
out_specs=P())
def loss_fsdp_tp(local_params, local_batch):
inputs, targets = local_batch
predictions = predict_fsdp_tp(local_params, inputs)
sq_err = jax.lax.psum(jnp.sum((predictions - targets)**2, axis=-1), 'feats')
return jax.lax.pmean(jnp.mean(sq_err), 'batch')
请注意,我们必须进行两次集合规约:一次在 'feats'
上,一次在 'batch'
上。 在纯 TP 示例中,我们没有显式地编写 'feats'
规约,因为我们只在 gemm_tp
中使用了 shard_map
;在调用者 loss_tp
中,编译器根据 predict_tp
返回的分片结果,自动将我们对 jnp.sum
的使用转换为执行所需的 psum
。
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_fsdp_tp)(params_, batch_))
print(allclose(jax.jit(jax.grad(loss))(params, batch),
jax.jit(jax.grad(loss_fsdp_tp))(params, batch)))
11.9203
11.920298
True
SPMD 流水线并行 (PP)#
通过流水线并行,我们的目标是并行化网络中不同深度的层的评估。 例如,一个设备可能计算第一层的应用,而另一个设备计算第二层的应用;当它们完成时,第一个设备将其结果传递给第二个设备,而第二个设备将其结果传递给负责第三层的设备,并且该过程重复进行。 通常,流水线阶段的数量可能与层的数量不同,因为每个阶段可能负责多个层。
通过 SPMD 流水线,我们利用了网络中大多数层都应用计算,只是参数值不同的事实。 特别是,我们可以将除第一层和最后一层之外的所有参数堆叠在一起,然后使用 shard_map
来映射这些层参数的块,其中每个参数块对应于一个流水线阶段。 然后,我们使用 jax.lax.ppermute
集体操作将数据向下移动到并行流水线中。
这种特定的流水线策略本质上是 GPipe 策略。 有几种变体以及完全不同的策略,哪种策略合适可能取决于阶段之间网络的速度和批次大小。 但在本教程中,我们将只关注一种策略。
首先,我们选择一些流水线参数
L = len(params) - 2 # num layers, excluding first and last
N = batch_size # batch size
F = params[0][0].shape[1] # num features
# choose some pipeline parameters
S = 2 # number of stages
B = 8 # size of each microbatch
assert L % S == 0, "S (number of stages) must divide L (number of inner layers)"
# compute some useful quantities
M, ragged = divmod(N, B) # M is number of microbatches
assert not ragged, "B (size of each microbatch) must divide total batch size"
K, ragged = divmod(M, S) # K is microbatches per stage
assert not ragged, "S (number of stages) must divide number of microbatches"
print(f'{S} stages, {L // S} layer(s) per stage, {L} pipelined layers total')
print(f'{B} examples per microbatch, {M} microbatches total')
2 stages, 2 layer(s) per stage, 4 pipelined layers total
8 examples per microbatch, 4 microbatches total
mesh = Mesh(jax.devices()[:S], ('stages',))
def predict_pp(params, inputs):
(W_first, b_first), inner_params, (W_last, b_last) = params
inputs = jax.nn.relu(jnp.dot(inputs, W_first) + b_first)
inputs = spmd_pipeline(lambda Wb, x: jax.nn.relu(x @ Wb[0] + Wb[1]),
inner_params, inputs)
outputs = jnp.dot(inputs, W_last) + b_last
return outputs
@partial(shard_map, mesh=mesh, in_specs=((P(), P('stages'), P()), P('stages')),
out_specs=P())
def loss_pp(params, batch):
inputs, targets = batch
predictions = predict_pp(params, inputs.reshape(K, B, -1)).reshape(K * B, -1)
local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
return jax.lax.pmean(local_loss, 'stages')
def spmd_pipeline(fn, stage_params, inputs):
stage = jax.lax.axis_index('stages')
outputs = jnp.zeros_like(inputs) * jnp.nan
state = jnp.zeros((L // S, B, F)) * jnp.nan
for i in range(M+L-1):
state = state.at[0].set(jnp.where(stage == 0, inputs[i % K], state[0]))
state = jax.vmap(fn)(stage_params, state)
outputs = outputs.at[(i-L+1) % K].set(jnp.where(stage == S-1, state[-1], outputs[(i-L+1) % K]))
state, inputs, outputs = shift(i, state, inputs, outputs)
outputs = jax.lax.ppermute(outputs, 'stages', [(i, (i+1) % S) for i in range(S)])
return outputs
def shift(i, state, inputs, outputs):
sh = lambda x, d: jax.lax.ppermute(x, 'stages', [(i, (i+d) % S) for i in range(S)])
state = jnp.roll(state, +1, axis=0).at[0].set(sh(state[-1], +1))
if (i % K) == (-1 % K):
inputs = sh(inputs, +1)
if ((i-L+1) % K) == (-1 % K):
outputs = sh(outputs, +1)
return state, inputs, outputs
first_params, *inner_params, last_params = params
Ws, bs = zip(*inner_params)
params_stacked = jnp.stack(Ws), jnp.stack(bs)
first_params = jax.device_put(first_params, NamedSharding(mesh, P()))
params_stacked = jax.device_put(params_stacked, NamedSharding(mesh, P('stages')))
last_params = jax.device_put(last_params, NamedSharding(mesh, P()))
params_ = first_params, params_stacked, last_params
batch_ = jax.device_put(batch, NamedSharding(mesh, P('stages')))
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_pp)(params_, batch_))
11.9203
11.920298
_ = jax.jit(jax.grad(loss_pp))(params_, batch_) # don't crash