分布式数组与自动并行化#
JAX 拥有三种多设备分布式并行处理风格,它们可以混合和组合使用。它们的区别在于编译器自动决策的程度,以及程序中显式控制的程度。
基于编译器的自动分片:在这种模式下,您可以像使用单台“全局视图”机器一样进行编程,编译器会选择如何分片数据(通过
with_sharding_constraint提供一些用户约束),以及如何通过集合通信将计算划分为针对每个设备的程序。显式分片与自动分区:在这种模式下,您仍然拥有全局视图,但数据分片在 JAX 类型中是显式的,可以使用
jax.typeof进行检查。计算的分区工作仍然由编译器完成。手动逐设备编程:在这种模式下,您拥有针对每个设备的数据和计算视图,并编写显式的通信集合操作,如
jax.lax.psum。
模式 |
视图? |
显式分片? |
显式集合操作? |
|---|---|---|---|
自动 (Auto) |
全局 |
❌ |
❌ |
显式 |
全局 |
✅ |
❌ |
手动 (Manual) |
逐设备 |
✅ |
✅ |
在深入细节之前,这里有一个使用显式模式的简单示例。首先,我们创建一个跨多个设备分片的 jax.Array
from __future__ import annotations
import enum
import jax
import jax.numpy as jnp
jax.config.update('jax_num_cpu_devices', 8)
jax.set_mesh(jax.make_mesh((4, 2), ('X', 'Y'))) # explicit mode by default
x = jnp.arange(8 * 4.).reshape(8, 4)
x = jax.device_put(x, jax.P('X', 'Y'))
print(jax.typeof(x)) # f32[8@X, 4@Y]
float32[8@X,4@Y]
jax.debug.visualize_array_sharding(x)
CPU 0 CPU 1 CPU 2 CPU 3 CPU 4 CPU 5 CPU 6 CPU 7
接下来,我们将对它进行计算,并观察到结果值也存储在多个设备上
y = jnp.sin(x).T
print(jax.typeof(y)) # f32[4@Y, 8@X]
float32[4@Y,8@X]
jnp.sin 和转置计算被自动并行化分布在存储输入值(和输出值)的设备上。
为了理解这些模式以及如何在它们之间切换,我们首先需要理解网格(meshes)。
Mesh 是一个具有命名轴的设备网格#
为了描述数据和计算如何跨设备分布,我们首先将设备组织成一个称为 Mesh 的多维网格。由于通信沿着网格轴发生,网格形状和设备顺序会决定通信性能。网格应该反映设备之间的物理连接拓扑。
我们区分具体 (concrete) 网格和抽象 (abstract) 网格。抽象网格仅包含形状、轴名称以及反映每个轴模式的轴类型。
class AbstractMesh:
axis_sizes: tuple[int, ...]
axis_names: tuple[str, ...]
axis_types: tuple[AxisType, ...]
class AxisType(enum.Enum):
Auto = enum.auto()
Explicit = enum.auto()
Manual = enum.auto()
# A concrete mesh additionally includes physical device objects with e.g.
# precise coordinates:
import numpy as np
class Mesh:
devices: np.ndarray[jax.Device]
axis_names: tuple[str, ...]
axis_types: tuple[AxisType, ...]
@property
def axis_sizes(self) -> tuple[int, ...]:
return self.devices.shape
在程序的顶层(即不在 jit 内部),我们可以使用类构造函数直接创建具体的 Mesh,这允许我们指定确切的设备顺序;或者使用 jax.make_mesh 辅助函数,它会通过考虑底层硬件拓扑自动选择设备顺序。
mesh = jax.make_mesh((4, 2), ('X', 'Y'))
print(mesh)
Mesh('X': 4, 'Y': 2, axis_types=(Explicit, Explicit))
默认情况下,所有网格轴类型均为 AxisType.Explicit。
为了避免在整个程序中传递 mesh,请使用 jax.set_mesh 全局设置一个具体网格。
jax.set_mesh(mesh)
<jax._src.sharding_impls.set_mesh at 0x7d4f1cca7040>
您也可以使用 with jax.set_mesh(mesh): ... 作为上下文管理器。仅在顶层,可以使用 jax.get_mesh() -> jax.sharding.Mesh 查询具体网格。
在 jit 内部,只能查询和更改抽象网格。使用 jax.sharding.get_abstract_mesh() -> jax.sharding.AbstractMesh 查询当前的抽象网格,并使用 with jax.sharding.use_abstract_mesh(m: AbstractMesh): ... 在上下文中更改抽象网格。轴大小、轴名称和轴类型可以更改,但网格的总大小(即轴大小的乘积)不得更改。
我们尚未解释分片,但这里有一个在 jax.jit 内部更改抽象网格的玩具示例。
@jax.jit
def f(x):
abstract_mesh = jax.sharding.AbstractMesh((8,), ('A',), (jax.sharding.AxisType.Explicit,))
with jax.sharding.use_abstract_mesh(abstract_mesh):
y = jax.reshard(x, jax.P('A', None))
return y * 2
z = f(x)
print(jax.typeof(z)) # f32[8@A, 4]
float32[8@A,4]
Sharding 描述了数组值如何在 Mesh 上分布#
jax.sharding.Sharding 描述了分布式内存布局。也就是说,它描述了数组的条目如何存储在不同设备的物理内存中,即它如何跨设备分片。
在顶层,每个 jax.Array 都有一个关联的 Sharding,它由一个具体的 Mesh 和一个 jax.sharding.PartitionSpec(别名为 jax.P)组成。
print(x.sharding)
jax.debug.visualize_array_sharding(x)
NamedSharding(mesh=Mesh('X': 4, 'Y': 2, axis_types=(Explicit, Explicit)), spec=P('X', 'Y'), memory_kind=device)
CPU 0 CPU 1 CPU 2 CPU 3 CPU 4 CPU 5 CPU 6 CPU 7
这里,PartitionSpec('X', 'Y') 表示数组 x 的第一轴和第二轴分别在网格轴 ‘X’ 和 ‘Y’ 上进行分片。我们可以使用 addressable_shards 查看这如何转换为物理存储。
for s in x.addressable_shards:
print(s.device, s.data, sep='\n', end='\n\n')
cpu:0
[[0. 1.]
[4. 5.]]
cpu:1
[[2. 3.]
[6. 7.]]
cpu:2
[[ 8. 9.]
[12. 13.]]
cpu:3
[[10. 11.]
[14. 15.]]
cpu:4
[[16. 17.]
[20. 21.]]
cpu:5
[[18. 19.]
[22. 23.]]
cpu:6
[[24. 25.]
[28. 29.]]
cpu:7
[[26. 27.]
[30. 31.]]
我们可以使用 jax.device_put(或 jax.reshard)来生成一个新的数组,该数组在相同的设备网格上分片,但具有由 jax.P 指定的不同布局。(jax.device_put 是一个运行时级 API,功能比 jax.reshard 更丰富。)由于我们通过上面的 jax.set_mesh 设置了网格上下文,我们可以将 jax.P 实例直接传递给 jax.device_put。
y = jax.device_put(x, jax.P('Y', 'X'))
print(y.sharding)
jax.debug.visualize_array_sharding(y)
NamedSharding(mesh=Mesh('X': 4, 'Y': 2, axis_types=(Explicit, Explicit)), spec=P('Y', 'X'), memory_kind=device)
CPU 0 CPU 2 CPU 4 CPU 6 CPU 1 CPU 3 CPU 5 CPU 7
y = jax.device_put(x, jax.P('X', None))
print(y.sharding)
jax.debug.visualize_array_sharding(y)
NamedSharding(mesh=Mesh('X': 4, 'Y': 2, axis_types=(Explicit, Explicit)), spec=P('X', None), memory_kind=device)
CPU 0,1 CPU 2,3 CPU 4,5 CPU 6,7
这里,因为网格轴名称 ‘Y’ 没有在 jax.P('X', None) 中提及,所以数组在网格轴 ‘Y’ 上进行了复制。(作为简写,末尾的 None 占位符可以省略,因此这里的 P(‘X’, None) 与 P(‘X’) 含义相同。但显式写出也无妨!)
for s in y.addressable_shards:
print(s.device, s.data, sep='\n', end='\n\n')
cpu:0
[[0. 1. 2. 3.]
[4. 5. 6. 7.]]
cpu:1
[[0. 1. 2. 3.]
[4. 5. 6. 7.]]
cpu:2
[[ 8. 9. 10. 11.]
[12. 13. 14. 15.]]
cpu:3
[[ 8. 9. 10. 11.]
[12. 13. 14. 15.]]
cpu:4
[[16. 17. 18. 19.]
[20. 21. 22. 23.]]
cpu:5
[[16. 17. 18. 19.]
[20. 21. 22. 23.]]
cpu:6
[[24. 25. 26. 27.]
[28. 29. 30. 31.]]
cpu:7
[[24. 25. 26. 27.]
[28. 29. 30. 31.]]
通过在 PartitionSpec 中使用轴名称元组,我们可以将一个数组轴在多个网格轴上进行分片。
y = jax.device_put(x, jax.P(('X', 'Y')))
print(y.sharding)
jax.debug.visualize_array_sharding(y)
NamedSharding(mesh=Mesh('X': 4, 'Y': 2, axis_types=(Explicit, Explicit)), spec=P(('X', 'Y'),), memory_kind=device)
CPU 0 CPU 1 CPU 2 CPU 3 CPU 4 CPU 5 CPU 6 CPU 7
因此,数组数据可以在网格轴上进行复制,或者数组的一个轴可以在该网格轴上进行分片,但还有另一种可能性:数组可以在网格轴上保持未归约 (unreduced) 状态。
y = jax.device_put(x, jax.P('X', None, unreduced={'Y'}))
print(y.sharding)
NamedSharding(mesh=Mesh('X': 4, 'Y': 2, axis_types=(Explicit, Explicit)), spec=P('X', None, unreduced={'Y'}), memory_kind=device)
未归约意味着逻辑值等于物理分片值在该轴上的分布式求和。
for s in y.addressable_shards:
print(s.device, s.data, sep='\n', end='\n\n')
cpu:0
[[0. 1. 0. 0.]
[4. 5. 0. 0.]]
cpu:1
[[0. 0. 2. 3.]
[0. 0. 6. 7.]]
cpu:2
[[ 8. 9. 0. 0.]
[12. 13. 0. 0.]]
cpu:3
[[ 0. 0. 10. 11.]
[ 0. 0. 14. 15.]]
cpu:4
[[16. 17. 0. 0.]
[20. 21. 0. 0.]]
cpu:5
[[ 0. 0. 18. 19.]
[ 0. 0. 22. 23.]]
cpu:6
[[24. 25. 0. 0.]
[28. 29. 0. 0.]]
cpu:7
[[ 0. 0. 26. 27.]
[ 0. 0. 30. 31.]]
未归约对于延迟分布式归约(reduction)很有用,特别是在自动微分的上下文中。稍后会详细介绍。
请注意,由于每个数组都有自己的 Sharding 实例,并且每个 Sharding 实例都有自己的 Mesh 实例,作用域内的数组可以关联到不同的网格。为了说明这一点,我们可以使用带有完整 jax.NamedSharding 实例参数的 jax.device_put,而不是使用上下文中的网格。
mesh2 = jax.make_mesh((8,), ('A',))
z = jax.device_put(x, jax.NamedSharding(mesh2, jax.P('A', None)))
print(z.sharding)
print(y.sharding)
NamedSharding(mesh=Mesh('A': 8, axis_types=(Explicit,)), spec=P('A', None), memory_kind=device)
NamedSharding(mesh=Mesh('X': 4, 'Y': 2, axis_types=(Explicit, Explicit)), spec=P('X', None, unreduced={'Y'}), memory_kind=device)
现在我们了解了顶层的网格形状、轴名称和分片,接下来我们可以深入探讨网格轴类型,以及显式模式和自动模式的区别。
显式分片模式使分片在追踪期间可查询#
在显式分片模式下,分片始终可以通过 jax.typeof 查询,即使是在 jax.jit 内部也是如此。
print(jax.typeof(x).sharding)
NamedSharding(mesh=AbstractMesh('X': 4, 'Y': 2, axis_types=(Explicit, Explicit), device_kind=cpu, num_cores=None), spec=P('X', 'Y'))
jax.jit(lambda x: print(jax.typeof(x).sharding))(x)
NamedSharding(mesh=AbstractMesh('X': 4, 'Y': 2, axis_types=(Explicit, Explicit), device_kind=cpu, num_cores=None), spec=P('X', 'Y'))
我们也称这种模式为“类型中的分片”。
就打印表示而言,类型语言大致如下:
<array_type> ::= <dtype>[<size_and_sharding>, ...]
<size_and_sharding> ::= <size> | <size>@<MeshAxisName>
其中:
作用域内的 MeshAxisName 是来自
jax.typeof(x).sharding.mesh的那些名称。每个 MeshAxisName 必须是 Explicit 轴类型。
每个 MeshAxisName 在数组类型中最多只能提及一次。
这些与 JAX 级类型关联的分片会通过操作进行传播。例如:
arg0 = jax.device_put(np.arange(4).reshape(4, 1), jax.P("X", None))
arg1 = jax.device_put(np.arange(8).reshape(1, 8), jax.P(None, "Y"))
result = arg0 + arg1
print(f"{jax.typeof(arg0)=!s}")
print(f"{jax.typeof(arg1)=!s}")
print(f"{jax.typeof(result)=!s}")
jax.typeof(arg0)=int32[4@X,1]
jax.typeof(arg1)=int32[1,8@Y]
jax.typeof(result)=int32[4@X,8@Y]
我们可以在 jit 内部进行相同的类型查询。
@jax.jit
def add_arrays(x, y):
ans = x + y
print(f"{jax.typeof(arg0)=!s}")
print(f"{jax.typeof(arg1)=!s}")
print(f"{jax.typeof(result)=!s}")
return ans
add_arrays(arg0, arg1)
jax.typeof(arg0)=int32[4@X,1]
jax.typeof(arg1)=int32[1,8@Y]
jax.typeof(result)=int32[4@X,8@Y]
Array([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 1, 2, 3, 4, 5, 6, 7, 8],
[ 2, 3, 4, 5, 6, 7, 8, 9],
[ 3, 4, 5, 6, 7, 8, 9, 10]], dtype=int32)
给定输入和输出分片,计算本身会自动在设备上进行分区。编译器会根据需要插入通信操作。例如:
x = jax.random.normal(jax.random.key(0), (8, 4),
out_sharding=jax.P('X', 'Y'))
print(jax.typeof(x))
float32[8@X,4@Y]
y = x.sum(0)
print(jax.typeof(y))
float32[4@Y]
在这里,当对计算进行分区时,编译器会自动插入通信集合操作以执行归约。
compile_txt = jax.jit(lambda x: x.sum(0)).lower(x).compile().as_text()
print('all-reduce(' in compile_txt)
True
结果分片遵循简单规则,否则会报错并要求进行标注#
每个原始操作都有一个分片传播规则,用于根据输入分片确定结果的分片。如果没有明显输出分片,则会引发错误。其目标是让重要的并行化决策显现出来,而不是隐藏它们,以免您意外错过。换句话说,分片传播规则倾向于报错并要求标注,而不是回退到随意选择的默认值。
每个操作都能实现自己的分片传播规则,但通常的模式是:
对于每个输出数组轴,将其与零个或多个对应的输入数组轴进行标识。
如果所有这些输入轴的分片方式相同,则以相同方式分片输出轴;否则,报错(并要求显式的
out_sharding参数)。在确定所有输出数组轴后,如果输出数组分片多次提及同一个网格轴,则报错(并要求显式的
out_sharding)。
以下是一些示例规则:
空元操作,如
jnp.zeros、jnp.arange:这些操作凭空创建数组,因此没有可传播的输入分片。除非被out_sharding关键字参数覆盖,否则其输出默认是不分片的。一元逐元素操作,如
sin、exp:输出的分片方式与输入相同。二元操作(
+、-、*等):被“压缩”维度的轴分片必须匹配(或为 None)。“外积”维度(仅出现在一个参数中的维度)的分片方式与其在输入中相同。如果结果最终提及同一个网格轴多次,则为错误。
收缩操作(如 jnp.dot 和 jnp.einsum)也有一些有趣的情况。例如,jnp.dot(x: f32[8,4@X], y:f32[4@X,16]) 的结果(其中共享的收缩轴分片方式相同)合理的结果可能是:
f32[8,16](执行 all-reduce)f32[8@X,16](在第一轴上进行 reduce-scatter)f32[8,16@X](在第二轴上进行 reduce-scatter)f32[8,16]{U:X}(无通信)
JAX 在这种情况下不会自动选择一个,而是报错并要求提供 out_sharding,例如 jnp.dot(x, y, out_sharding=jax.P('X', None))。
x = jax.device_put(jnp.arange(8 * 4.).reshape(8, 4), jax.P(None, 'X'))
y = jax.device_put(jnp.arange(4 * 16.).reshape(4, 16), jax.P('X', None))
try:
jnp.dot(x, y)
except Exception as e:
print("ERROR!")
print(e)
ERROR!
Contracting dimensions are sharded and it is ambiguous how the output should be sharded. Please specify the output sharding via the `out_sharding` parameter. Got lhs_contracting_spec=('X',) and rhs_contracting_spec=('X',)
z = jnp.dot(x, y, out_sharding=jax.P('X', None))
print(jax.typeof(z))
float32[8@X,16]
但也存在其他导致通信的 jnp.dot 情况,JAX 会自动执行这些通信,例如 jnp.dot(x:f32[8,4], y:f32[4@x,16]) 会导致 f32[8,16],这很可能是通过在 y 上执行 all-gather(类似于 FSDP)来实现的。
使用 @auto_axes,编译器会在被装饰的函数内选择分片#
如果您不想指定某些中间变量的分片,而是希望编译器自动选择,可以使用 @auto_axes 装饰器。在此装饰器下,分片无法使用 jax.typeof 进行查询。更具体地说,auto_axes 将部分或全部网格轴类型切换为 Auto,而 Auto 网格轴不能出现在数组类型中。
用 @auto_axes 装饰函数会向函数的签名中添加一个 out_sharding 参数,因此最终的输出分片可以由调用者设置。或者,使用 @auto_axes(out_sharding=...) 在函数定义处指定最终输出分片。
例如,当我们的网格轴为 Explicit 时,我们无法将两个分片方式不同的数组相加。
from jax.sharding import auto_axes, explicit_axes
x = jax.device_put(np.arange(16).reshape(4, 4), jax.P("X", None))
y = jax.device_put(np.arange(16).reshape(4, 4), jax.P(None, "X"))
try:
x + y
except Exception as e:
print("ERROR!")
print(e)
ERROR!
add operation with inputs: i32[4@X,4], i32[4,4@X] produces an illegally sharded result: i32[4@X,4@X]
如果我们只想指定结果的分片,并让编译器处理其余部分,我们可以使用 auto_axes。
@auto_axes
def add2(x, y):
print("We're in auto-sharding mode here. This is the current mesh:\n"
f"{jax.sharding.get_abstract_mesh()}")
return x + y
result = add2(x, y, out_sharding=jax.P("X", None))
print(f"Result type: {jax.typeof(result)}")
We're in auto-sharding mode here. This is the current mesh:
AbstractMesh('X': 4, 'Y': 2, axis_types=(Auto, Auto), device_kind=cpu, num_cores=None)
Result type: int32[4@X,4]
因此 auto_axes 允许您向任何操作组合添加 out_sharding 参数。
当上下文网格的轴类型为 Explicit 或 Auto 时,可以调用 auto_axes 装饰的函数,但不能处于 Manual 模式。默认情况下,它将所有网格轴类型切换为 Auto;使用 axes=... 可仅切换子集。
自动分片模式在编译期间自动决定分片#
虽然 auto_axes 装饰器对于将网格轴类型从 Explicit 临时切换到 Auto 很有用,但您也可以在顶层构造一个具有 Auto 轴类型的 Mesh。
Auto = jax.sharding.AxisType.Auto
auto_mesh = jax.make_mesh((4, 2), ('X', 'Y'), (Auto, Auto))
jax.set_mesh(auto_mesh)
x = jax.device_put(jnp.arange(8 * 4. ).reshape(8, 4 ), jax.P(None, 'X'))
y = jax.device_put(jnp.arange(4 * 16.).reshape(4, 16), jax.P('X', None))
z = jnp.dot(x, y) # not an error!
编译器没有报错,而是自动决定了结果的分片!
print(z.sharding) # works at the top-level only (i.e. outside `jit`)
NamedSharding(mesh=Mesh('X': 4, 'Y': 2, axis_types=(Auto, Auto)), spec=P(), memory_kind=device)
无论使用顶层 Auto 网格轴,还是使用 auto_axes 装饰器,您都可以使用 jax.lax.with_sharding_constraint 为编译器提供关于中间变量应如何分片的提示。
@jax.jit
def f(x, y):
z = jnp.dot(x, y)
z = jax.lax.with_sharding_constraint(z, jax.P('X', None))
return z
z = f(x, y)
print(z.sharding)
NamedSharding(mesh=Mesh('X': 4, 'Y': 2, axis_types=(Auto, Auto)), spec=P('X',), memory_kind=device)
使用 Explicit 模式轴调用 jax.lax.with_sharding_constraint 也是合法的;对于任何 Explicit 网格轴,它的作用相当于断言参数的分片与指定的分片匹配。
您可以使用 @explicit_axes 装饰器在局部将网格轴类型切换为 Explicit。
@explicit_axes
def explicit_g(y):
print(f'mesh inside g: {jax.sharding.get_abstract_mesh()}')
print(f'y.sharding inside g: {jax.typeof(y) = }')
z = y * 2
print(f'z.sharding inside g: {jax.typeof(z) = }', end='\n\n')
return z
@jax.jit
def f(arr1):
print(f'mesh inside f: {jax.sharding.get_abstract_mesh()}', end='\n\n')
x = jnp.sin(arr1)
z = explicit_g(x, in_sharding=jax.P("X", "Y"))
return z + 1
x = jax.device_put(np.arange(16).reshape(4, 4), jax.P("X", "Y"))
f(x)
mesh inside f: AbstractMesh('X': 4, 'Y': 2, axis_types=(Auto, Auto), device_kind=cpu, num_cores=None)
mesh inside g: AbstractMesh('X': 4, 'Y': 2, axis_types=(Explicit, Explicit), device_kind=cpu, num_cores=None)
y.sharding inside g: jax.typeof(y) = ShapedArray(float32[4@X,4@Y])
z.sharding inside g: jax.typeof(z) = ShapedArray(float32[4@X,4@Y])
Array([[ 1. , 2.682942 , 2.818595 , 1.28224 ],
[-0.513605 , -0.9178486 , 0.44116902, 2.3139732 ],
[ 2.9787164 , 1.824237 , -0.08804226, -0.99998045],
[-0.07314587, 1.840334 , 2.9812148 , 2.3005757 ]], dtype=float32)
它是 auto_axes 的一种对偶,您可以指定 in_shardings 而不是 out_shardings。
具体数组分片可以提及 Auto 网格轴#
具体 jax.Array 的分片可以通过 x.sharding 查询。这只能在顶层完成。您可能预期结果与与值类型关联的分片 jax.typeof(x).sharding 相同。事实可能并非如此!具体数组分片 x.sharding 描述了沿 Explicit 和 Auto 网格轴的分片。这是编译器最终选择的分片。而类型指定的分片 jax.typeof(x).sharding 仅描述了沿 Explicit 网格轴的分片。Auto 轴被刻意从类型中隐藏,因为它们是编译器的权限范围。我们可以认为具体数组分片与类型指定的分片是一致的,但更具体。例如:
def compare_shardings(x):
print(f"=== with mesh: {jax.sharding.get_abstract_mesh()} ===")
print(f"Concrete value sharding: {x.sharding.spec}")
print(f"Type-specified sharding: {jax.typeof(x).sharding.spec}\n")
my_array = jnp.sin(jax.device_put(np.arange(8), jax.P("X")))
compare_shardings(my_array)
@auto_axes
def check_in_auto_context(x):
compare_shardings(x)
return x
check_in_auto_context(my_array, out_sharding=jax.P("X"))
=== with mesh: AbstractMesh('X': 4, 'Y': 2, axis_types=(Auto, Auto), device_kind=cpu, num_cores=None) ===
Concrete value sharding: P('X',)
Type-specified sharding: P(None,)
=== with mesh: AbstractMesh('X': 4, 'Y': 2, axis_types=(Auto, Auto), device_kind=cpu, num_cores=None) ===
Concrete value sharding: P('X',)
Type-specified sharding: P(None,)
Array([ 0. , 0.84147096, 0.9092974 , 0.14112 , -0.7568025 ,
-0.9589243 , -0.2794155 , 0.6569866 ], dtype=float32)
请注意,在顶层,当我们处于完全 Explicit 的网格上下文中时,具体数组分片和类型指定分片是一致的。
但在 auto_axes 装饰器下,我们处于完全 Auto 的网格上下文中,两者的分片不一致:类型指定的分片是 P(None),而具体数组分片是 P("X")(尽管它可以是任何值!这取决于编译器)。
手动模式允许您编写显式的集合操作,并提供每个设备的数据视图#
使用 jax.shard_map 将网格轴类型设置为 Manual。
mesh = jax.make_mesh((4, 2), ('X', 'Y'))
jax.set_mesh(mesh)
x = jax.device_put(jnp.arange(8 * 4. ).reshape(8, 4 ), jax.P(None, 'X'))
y = jax.device_put(jnp.arange(4 * 16.).reshape(4, 16), jax.P('X', None))
@jax.shard_map(out_specs=jax.P('X', None))
def matmul(x_shard, y_shard):
z_summand = jnp.dot(x_shard, y_shard)
return jax.lax.psum_scatter(z_summand, 'X', tiled=True)
z = matmul(x, y)
print(jax.typeof(z))
z_ref = jnp.dot(x, y, out_sharding=jax.P('X', None))
print(jnp.allclose(z_ref, z))
float32[8@X,16]
True
有关详细信息,请参阅 shard_map 教程。