Open inColab

分布式数组与自动并行化#

JAX 拥有三种多设备分布式并行处理风格,它们可以混合和组合使用。它们的区别在于编译器自动决策的程度,以及程序中显式控制的程度。

  • 基于编译器的自动分片:在这种模式下,您可以像使用单台“全局视图”机器一样进行编程,编译器会选择如何分片数据(通过 with_sharding_constraint 提供一些用户约束),以及如何通过集合通信将计算划分为针对每个设备的程序。

  • 显式分片与自动分区:在这种模式下,您仍然拥有全局视图,但数据分片在 JAX 类型中是显式的,可以使用 jax.typeof 进行检查。计算的分区工作仍然由编译器完成。

  • 手动逐设备编程:在这种模式下,您拥有针对每个设备的数据和计算视图,并编写显式的通信集合操作,如 jax.lax.psum

模式

视图?

显式分片?

显式集合操作?

自动 (Auto)

全局

显式

全局

手动 (Manual)

逐设备

在深入细节之前,这里有一个使用显式模式的简单示例。首先,我们创建一个跨多个设备分片的 jax.Array

from __future__ import annotations
import enum
import jax
import jax.numpy as jnp
jax.config.update('jax_num_cpu_devices', 8)
jax.set_mesh(jax.make_mesh((4, 2), ('X', 'Y')))  # explicit mode by default

x = jnp.arange(8 * 4.).reshape(8, 4)
x = jax.device_put(x, jax.P('X', 'Y'))
print(jax.typeof(x))  # f32[8@X, 4@Y]
float32[8@X,4@Y]
jax.debug.visualize_array_sharding(x)
                  
  CPU 0    CPU 1  
                  
                  
  CPU 2    CPU 3  
                  
                  
  CPU 4    CPU 5  
                  
                  
  CPU 6    CPU 7  
                  

接下来,我们将对它进行计算,并观察到结果值也存储在多个设备上

y = jnp.sin(x).T
print(jax.typeof(y))  # f32[4@Y, 8@X]
float32[4@Y,8@X]

jnp.sin 和转置计算被自动并行化分布在存储输入值(和输出值)的设备上。

为了理解这些模式以及如何在它们之间切换,我们首先需要理解网格(meshes)。

Mesh 是一个具有命名轴的设备网格#

为了描述数据和计算如何跨设备分布,我们首先将设备组织成一个称为 Mesh 的多维网格。由于通信沿着网格轴发生,网格形状和设备顺序会决定通信性能。网格应该反映设备之间的物理连接拓扑。

我们区分具体 (concrete) 网格和抽象 (abstract) 网格。抽象网格仅包含形状、轴名称以及反映每个轴模式的轴类型。

class AbstractMesh:
  axis_sizes: tuple[int, ...]
  axis_names: tuple[str, ...]
  axis_types: tuple[AxisType, ...]

class AxisType(enum.Enum):
  Auto = enum.auto()
  Explicit = enum.auto()
  Manual = enum.auto()

# A concrete mesh additionally includes physical device objects with e.g.
# precise coordinates:
import numpy as np

class Mesh:
  devices: np.ndarray[jax.Device]
  axis_names: tuple[str, ...]
  axis_types: tuple[AxisType, ...]

  @property
  def axis_sizes(self) -> tuple[int, ...]:
    return self.devices.shape

在程序的顶层(即不在 jit 内部),我们可以使用类构造函数直接创建具体的 Mesh,这允许我们指定确切的设备顺序;或者使用 jax.make_mesh 辅助函数,它会通过考虑底层硬件拓扑自动选择设备顺序。

mesh = jax.make_mesh((4, 2), ('X', 'Y'))
print(mesh)
Mesh('X': 4, 'Y': 2, axis_types=(Explicit, Explicit))

默认情况下,所有网格轴类型均为 AxisType.Explicit

为了避免在整个程序中传递 mesh,请使用 jax.set_mesh 全局设置一个具体网格。

jax.set_mesh(mesh)
<jax._src.sharding_impls.set_mesh at 0x7d4f1cca7040>

您也可以使用 with jax.set_mesh(mesh): ... 作为上下文管理器。仅在顶层,可以使用 jax.get_mesh() -> jax.sharding.Mesh 查询具体网格。

jit 内部,只能查询和更改抽象网格。使用 jax.sharding.get_abstract_mesh() -> jax.sharding.AbstractMesh 查询当前的抽象网格,并使用 with jax.sharding.use_abstract_mesh(m: AbstractMesh): ... 在上下文中更改抽象网格。轴大小、轴名称和轴类型可以更改,但网格的总大小(即轴大小的乘积)不得更改。

我们尚未解释分片,但这里有一个在 jax.jit 内部更改抽象网格的玩具示例。

@jax.jit
def f(x):
  abstract_mesh = jax.sharding.AbstractMesh((8,), ('A',), (jax.sharding.AxisType.Explicit,))
  with jax.sharding.use_abstract_mesh(abstract_mesh):
    y = jax.reshard(x, jax.P('A', None))
    return y * 2

z = f(x)
print(jax.typeof(z))  # f32[8@A, 4]
float32[8@A,4]

Sharding 描述了数组值如何在 Mesh 上分布#

jax.sharding.Sharding 描述了分布式内存布局。也就是说,它描述了数组的条目如何存储在不同设备的物理内存中,即它如何跨设备分片

在顶层,每个 jax.Array 都有一个关联的 Sharding,它由一个具体的 Mesh 和一个 jax.sharding.PartitionSpec(别名为 jax.P)组成。

print(x.sharding)
jax.debug.visualize_array_sharding(x)
NamedSharding(mesh=Mesh('X': 4, 'Y': 2, axis_types=(Explicit, Explicit)), spec=P('X', 'Y'), memory_kind=device)
                  
  CPU 0    CPU 1  
                  
                  
  CPU 2    CPU 3  
                  
                  
  CPU 4    CPU 5  
                  
                  
  CPU 6    CPU 7  
                  

这里,PartitionSpec('X', 'Y') 表示数组 x 的第一轴和第二轴分别在网格轴 ‘X’ 和 ‘Y’ 上进行分片。我们可以使用 addressable_shards 查看这如何转换为物理存储。

for s in x.addressable_shards:
  print(s.device, s.data, sep='\n', end='\n\n')
cpu:0
[[0. 1.]
 [4. 5.]]

cpu:1
[[2. 3.]
 [6. 7.]]

cpu:2
[[ 8.  9.]
 [12. 13.]]

cpu:3
[[10. 11.]
 [14. 15.]]

cpu:4
[[16. 17.]
 [20. 21.]]

cpu:5
[[18. 19.]
 [22. 23.]]

cpu:6
[[24. 25.]
 [28. 29.]]

cpu:7
[[26. 27.]
 [30. 31.]]

我们可以使用 jax.device_put(或 jax.reshard)来生成一个新的数组,该数组在相同的设备网格上分片,但具有由 jax.P 指定的不同布局。(jax.device_put 是一个运行时级 API,功能比 jax.reshard 更丰富。)由于我们通过上面的 jax.set_mesh 设置了网格上下文,我们可以将 jax.P 实例直接传递给 jax.device_put

y = jax.device_put(x, jax.P('Y', 'X'))
print(y.sharding)
jax.debug.visualize_array_sharding(y)
NamedSharding(mesh=Mesh('X': 4, 'Y': 2, axis_types=(Explicit, Explicit)), spec=P('Y', 'X'), memory_kind=device)
                                    
                                    
  CPU 0    CPU 2    CPU 4    CPU 6  
                                    
                                    
                                    
                                    
                                    
  CPU 1    CPU 3    CPU 5    CPU 7  
                                    
                                    
                                    
y = jax.device_put(x, jax.P('X', None))
print(y.sharding)
jax.debug.visualize_array_sharding(y)
NamedSharding(mesh=Mesh('X': 4, 'Y': 2, axis_types=(Explicit, Explicit)), spec=P('X', None), memory_kind=device)
            
  CPU 0,1   
            
            
  CPU 2,3   
            
            
  CPU 4,5   
            
            
  CPU 6,7   
            

这里,因为网格轴名称 ‘Y’ 没有在 jax.P('X', None) 中提及,所以数组在网格轴 ‘Y’ 上进行了复制。(作为简写,末尾的 None 占位符可以省略,因此这里的 P(‘X’, None) 与 P(‘X’) 含义相同。但显式写出也无妨!)

for s in y.addressable_shards:
  print(s.device, s.data, sep='\n', end='\n\n')
cpu:0
[[0. 1. 2. 3.]
 [4. 5. 6. 7.]]

cpu:1
[[0. 1. 2. 3.]
 [4. 5. 6. 7.]]

cpu:2
[[ 8.  9. 10. 11.]
 [12. 13. 14. 15.]]

cpu:3
[[ 8.  9. 10. 11.]
 [12. 13. 14. 15.]]

cpu:4
[[16. 17. 18. 19.]
 [20. 21. 22. 23.]]

cpu:5
[[16. 17. 18. 19.]
 [20. 21. 22. 23.]]

cpu:6
[[24. 25. 26. 27.]
 [28. 29. 30. 31.]]

cpu:7
[[24. 25. 26. 27.]
 [28. 29. 30. 31.]]

通过在 PartitionSpec 中使用轴名称元组,我们可以将一个数组轴在多个网格轴上进行分片。

y = jax.device_put(x, jax.P(('X', 'Y')))
print(y.sharding)
jax.debug.visualize_array_sharding(y)
NamedSharding(mesh=Mesh('X': 4, 'Y': 2, axis_types=(Explicit, Explicit)), spec=P(('X', 'Y'),), memory_kind=device)
   CPU 0    
            
   CPU 1    
            
   CPU 2    
            
   CPU 3    
            
   CPU 4    
            
   CPU 5    
            
   CPU 6    
            
   CPU 7    
            

因此,数组数据可以在网格轴上进行复制,或者数组的一个轴可以在该网格轴上进行分片,但还有另一种可能性:数组可以在网格轴上保持未归约 (unreduced) 状态。

y = jax.device_put(x, jax.P('X', None, unreduced={'Y'}))
print(y.sharding)
NamedSharding(mesh=Mesh('X': 4, 'Y': 2, axis_types=(Explicit, Explicit)), spec=P('X', None, unreduced={'Y'}), memory_kind=device)

未归约意味着逻辑值等于物理分片值在该轴上的分布式求和。

for s in y.addressable_shards:
  print(s.device, s.data, sep='\n', end='\n\n')
cpu:0
[[0. 1. 0. 0.]
 [4. 5. 0. 0.]]

cpu:1
[[0. 0. 2. 3.]
 [0. 0. 6. 7.]]

cpu:2
[[ 8.  9.  0.  0.]
 [12. 13.  0.  0.]]

cpu:3
[[ 0.  0. 10. 11.]
 [ 0.  0. 14. 15.]]

cpu:4
[[16. 17.  0.  0.]
 [20. 21.  0.  0.]]

cpu:5
[[ 0.  0. 18. 19.]
 [ 0.  0. 22. 23.]]

cpu:6
[[24. 25.  0.  0.]
 [28. 29.  0.  0.]]

cpu:7
[[ 0.  0. 26. 27.]
 [ 0.  0. 30. 31.]]

未归约对于延迟分布式归约(reduction)很有用,特别是在自动微分的上下文中。稍后会详细介绍。

请注意,由于每个数组都有自己的 Sharding 实例,并且每个 Sharding 实例都有自己的 Mesh 实例,作用域内的数组可以关联到不同的网格。为了说明这一点,我们可以使用带有完整 jax.NamedSharding 实例参数的 jax.device_put,而不是使用上下文中的网格。

mesh2 = jax.make_mesh((8,), ('A',))
z = jax.device_put(x, jax.NamedSharding(mesh2, jax.P('A', None)))
print(z.sharding)
print(y.sharding)
NamedSharding(mesh=Mesh('A': 8, axis_types=(Explicit,)), spec=P('A', None), memory_kind=device)
NamedSharding(mesh=Mesh('X': 4, 'Y': 2, axis_types=(Explicit, Explicit)), spec=P('X', None, unreduced={'Y'}), memory_kind=device)

现在我们了解了顶层的网格形状、轴名称和分片,接下来我们可以深入探讨网格轴类型,以及显式模式和自动模式的区别。

显式分片模式使分片在追踪期间可查询#

在显式分片模式下,分片始终可以通过 jax.typeof 查询,即使是在 jax.jit 内部也是如此。

print(jax.typeof(x).sharding)
NamedSharding(mesh=AbstractMesh('X': 4, 'Y': 2, axis_types=(Explicit, Explicit), device_kind=cpu, num_cores=None), spec=P('X', 'Y'))
jax.jit(lambda x: print(jax.typeof(x).sharding))(x)
NamedSharding(mesh=AbstractMesh('X': 4, 'Y': 2, axis_types=(Explicit, Explicit), device_kind=cpu, num_cores=None), spec=P('X', 'Y'))

我们也称这种模式为“类型中的分片”。

就打印表示而言,类型语言大致如下:

 <array_type> ::= <dtype>[<size_and_sharding>, ...]
 <size_and_sharding> ::= <size> | <size>@<MeshAxisName>

其中:

  • 作用域内的 MeshAxisName 是来自 jax.typeof(x).sharding.mesh 的那些名称。

  • 每个 MeshAxisName 必须是 Explicit 轴类型。

  • 每个 MeshAxisName 在数组类型中最多只能提及一次。

这些与 JAX 级类型关联的分片会通过操作进行传播。例如:

arg0 = jax.device_put(np.arange(4).reshape(4, 1), jax.P("X", None))
arg1 = jax.device_put(np.arange(8).reshape(1, 8), jax.P(None, "Y"))

result = arg0 + arg1

print(f"{jax.typeof(arg0)=!s}")
print(f"{jax.typeof(arg1)=!s}")
print(f"{jax.typeof(result)=!s}")
jax.typeof(arg0)=int32[4@X,1]
jax.typeof(arg1)=int32[1,8@Y]
jax.typeof(result)=int32[4@X,8@Y]

我们可以在 jit 内部进行相同的类型查询。

@jax.jit
def add_arrays(x, y):
  ans = x + y
  print(f"{jax.typeof(arg0)=!s}")
  print(f"{jax.typeof(arg1)=!s}")
  print(f"{jax.typeof(result)=!s}")
  return ans

add_arrays(arg0, arg1)
jax.typeof(arg0)=int32[4@X,1]
jax.typeof(arg1)=int32[1,8@Y]
jax.typeof(result)=int32[4@X,8@Y]
Array([[ 0,  1,  2,  3,  4,  5,  6,  7],
       [ 1,  2,  3,  4,  5,  6,  7,  8],
       [ 2,  3,  4,  5,  6,  7,  8,  9],
       [ 3,  4,  5,  6,  7,  8,  9, 10]], dtype=int32)

给定输入和输出分片,计算本身会自动在设备上进行分区。编译器会根据需要插入通信操作。例如:

x = jax.random.normal(jax.random.key(0), (8, 4),
                      out_sharding=jax.P('X', 'Y'))
print(jax.typeof(x))
float32[8@X,4@Y]
y = x.sum(0)
print(jax.typeof(y))
float32[4@Y]

在这里,当对计算进行分区时,编译器会自动插入通信集合操作以执行归约。

compile_txt = jax.jit(lambda x: x.sum(0)).lower(x).compile().as_text()
print('all-reduce(' in compile_txt)
True

结果分片遵循简单规则,否则会报错并要求进行标注#

每个原始操作都有一个分片传播规则,用于根据输入分片确定结果的分片。如果没有明显输出分片,则会引发错误。其目标是让重要的并行化决策显现出来,而不是隐藏它们,以免您意外错过。换句话说,分片传播规则倾向于报错并要求标注,而不是回退到随意选择的默认值。

每个操作都能实现自己的分片传播规则,但通常的模式是:

  1. 对于每个输出数组轴,将其与零个或多个对应的输入数组轴进行标识。

  2. 如果所有这些输入轴的分片方式相同,则以相同方式分片输出轴;否则,报错(并要求显式的 out_sharding 参数)。

  3. 在确定所有输出数组轴后,如果输出数组分片多次提及同一个网格轴,则报错(并要求显式的 out_sharding)。

以下是一些示例规则:

  • 空元操作,如 jnp.zerosjnp.arange:这些操作凭空创建数组,因此没有可传播的输入分片。除非被 out_sharding 关键字参数覆盖,否则其输出默认是不分片的。

  • 一元逐元素操作,如 sinexp:输出的分片方式与输入相同。

  • 二元操作(+-* 等):被“压缩”维度的轴分片必须匹配(或为 None)。“外积”维度(仅出现在一个参数中的维度)的分片方式与其在输入中相同。如果结果最终提及同一个网格轴多次,则为错误。

收缩操作(如 jnp.dotjnp.einsum)也有一些有趣的情况。例如,jnp.dot(x: f32[8,4@X], y:f32[4@X,16]) 的结果(其中共享的收缩轴分片方式相同)合理的结果可能是:

  • f32[8,16](执行 all-reduce)

  • f32[8@X,16](在第一轴上进行 reduce-scatter)

  • f32[8,16@X](在第二轴上进行 reduce-scatter)

  • f32[8,16]{U:X}(无通信)

JAX 在这种情况下不会自动选择一个,而是报错并要求提供 out_sharding,例如 jnp.dot(x, y, out_sharding=jax.P('X', None))

x = jax.device_put(jnp.arange(8 * 4.).reshape(8, 4), jax.P(None, 'X'))
y = jax.device_put(jnp.arange(4 * 16.).reshape(4, 16), jax.P('X', None))

try:
  jnp.dot(x, y)
except Exception as e:
  print("ERROR!")
  print(e)
ERROR!
Contracting dimensions are sharded and it is ambiguous how the output should be sharded. Please specify the output sharding via the `out_sharding` parameter. Got lhs_contracting_spec=('X',) and rhs_contracting_spec=('X',)
z = jnp.dot(x, y, out_sharding=jax.P('X', None))

print(jax.typeof(z))
float32[8@X,16]

但也存在其他导致通信的 jnp.dot 情况,JAX 会自动执行这些通信,例如 jnp.dot(x:f32[8,4], y:f32[4@x,16]) 会导致 f32[8,16],这很可能是通过在 y 上执行 all-gather(类似于 FSDP)来实现的。

使用 @auto_axes,编译器会在被装饰的函数内选择分片#

如果您不想指定某些中间变量的分片,而是希望编译器自动选择,可以使用 @auto_axes 装饰器。在此装饰器下,分片无法使用 jax.typeof 进行查询。更具体地说,auto_axes 将部分或全部网格轴类型切换为 Auto,而 Auto 网格轴不能出现在数组类型中。

@auto_axes 装饰函数会向函数的签名中添加一个 out_sharding 参数,因此最终的输出分片可以由调用者设置。或者,使用 @auto_axes(out_sharding=...) 在函数定义处指定最终输出分片。

例如,当我们的网格轴为 Explicit 时,我们无法将两个分片方式不同的数组相加。

from jax.sharding import auto_axes, explicit_axes

x = jax.device_put(np.arange(16).reshape(4, 4), jax.P("X", None))
y = jax.device_put(np.arange(16).reshape(4, 4), jax.P(None, "X"))

try:
  x + y
except Exception as e:
  print("ERROR!")
  print(e)
ERROR!
add operation with inputs: i32[4@X,4], i32[4,4@X] produces an illegally sharded result: i32[4@X,4@X]

如果我们只想指定结果的分片,并让编译器处理其余部分,我们可以使用 auto_axes

@auto_axes
def add2(x, y):
  print("We're in auto-sharding mode here. This is the current mesh:\n"
        f"{jax.sharding.get_abstract_mesh()}")
  return x + y

result = add2(x, y, out_sharding=jax.P("X", None))
print(f"Result type: {jax.typeof(result)}")
We're in auto-sharding mode here. This is the current mesh:
AbstractMesh('X': 4, 'Y': 2, axis_types=(Auto, Auto), device_kind=cpu, num_cores=None)
Result type: int32[4@X,4]

因此 auto_axes 允许您向任何操作组合添加 out_sharding 参数。

当上下文网格的轴类型为 ExplicitAuto 时,可以调用 auto_axes 装饰的函数,但不能处于 Manual 模式。默认情况下,它将所有网格轴类型切换为 Auto;使用 axes=... 可仅切换子集。

自动分片模式在编译期间自动决定分片#

虽然 auto_axes 装饰器对于将网格轴类型从 Explicit 临时切换到 Auto 很有用,但您也可以在顶层构造一个具有 Auto 轴类型的 Mesh

Auto = jax.sharding.AxisType.Auto
auto_mesh = jax.make_mesh((4, 2), ('X', 'Y'), (Auto, Auto))
jax.set_mesh(auto_mesh)

x = jax.device_put(jnp.arange(8 * 4. ).reshape(8, 4 ), jax.P(None, 'X'))
y = jax.device_put(jnp.arange(4 * 16.).reshape(4, 16), jax.P('X', None))

z = jnp.dot(x, y)  # not an error!

编译器没有报错,而是自动决定了结果的分片!

print(z.sharding)  # works at the top-level only (i.e. outside `jit`)
NamedSharding(mesh=Mesh('X': 4, 'Y': 2, axis_types=(Auto, Auto)), spec=P(), memory_kind=device)

无论使用顶层 Auto 网格轴,还是使用 auto_axes 装饰器,您都可以使用 jax.lax.with_sharding_constraint 为编译器提供关于中间变量应如何分片的提示。

@jax.jit
def f(x, y):
  z = jnp.dot(x, y)
  z = jax.lax.with_sharding_constraint(z, jax.P('X', None))
  return z

z = f(x, y)
print(z.sharding)
NamedSharding(mesh=Mesh('X': 4, 'Y': 2, axis_types=(Auto, Auto)), spec=P('X',), memory_kind=device)

使用 Explicit 模式轴调用 jax.lax.with_sharding_constraint 也是合法的;对于任何 Explicit 网格轴,它的作用相当于断言参数的分片与指定的分片匹配。

您可以使用 @explicit_axes 装饰器在局部将网格轴类型切换为 Explicit

@explicit_axes
def explicit_g(y):
  print(f'mesh inside g: {jax.sharding.get_abstract_mesh()}')
  print(f'y.sharding inside g: {jax.typeof(y) = }')
  z = y * 2
  print(f'z.sharding inside g: {jax.typeof(z) = }', end='\n\n')
  return z

@jax.jit
def f(arr1):
  print(f'mesh inside f: {jax.sharding.get_abstract_mesh()}', end='\n\n')
  x = jnp.sin(arr1)
  z = explicit_g(x, in_sharding=jax.P("X", "Y"))
  return z + 1

x = jax.device_put(np.arange(16).reshape(4, 4), jax.P("X", "Y"))
f(x)
mesh inside f: AbstractMesh('X': 4, 'Y': 2, axis_types=(Auto, Auto), device_kind=cpu, num_cores=None)

mesh inside g: AbstractMesh('X': 4, 'Y': 2, axis_types=(Explicit, Explicit), device_kind=cpu, num_cores=None)
y.sharding inside g: jax.typeof(y) = ShapedArray(float32[4@X,4@Y])
z.sharding inside g: jax.typeof(z) = ShapedArray(float32[4@X,4@Y])
Array([[ 1.        ,  2.682942  ,  2.818595  ,  1.28224   ],
       [-0.513605  , -0.9178486 ,  0.44116902,  2.3139732 ],
       [ 2.9787164 ,  1.824237  , -0.08804226, -0.99998045],
       [-0.07314587,  1.840334  ,  2.9812148 ,  2.3005757 ]],      dtype=float32)

它是 auto_axes 的一种对偶,您可以指定 in_shardings 而不是 out_shardings

具体数组分片可以提及 Auto 网格轴#

具体 jax.Array 的分片可以通过 x.sharding 查询。这只能在顶层完成。您可能预期结果与与值类型关联的分片 jax.typeof(x).sharding 相同。事实可能并非如此!具体数组分片 x.sharding 描述了沿 ExplicitAuto 网格轴的分片。这是编译器最终选择的分片。而类型指定的分片 jax.typeof(x).sharding 仅描述了沿 Explicit 网格轴的分片。Auto 轴被刻意从类型中隐藏,因为它们是编译器的权限范围。我们可以认为具体数组分片与类型指定的分片是一致的,但更具体。例如:

def compare_shardings(x):
  print(f"=== with mesh: {jax.sharding.get_abstract_mesh()} ===")
  print(f"Concrete value sharding: {x.sharding.spec}")
  print(f"Type-specified sharding: {jax.typeof(x).sharding.spec}\n")

my_array = jnp.sin(jax.device_put(np.arange(8), jax.P("X")))
compare_shardings(my_array)

@auto_axes
def check_in_auto_context(x):
  compare_shardings(x)
  return x

check_in_auto_context(my_array, out_sharding=jax.P("X"))
=== with mesh: AbstractMesh('X': 4, 'Y': 2, axis_types=(Auto, Auto), device_kind=cpu, num_cores=None) ===
Concrete value sharding: P('X',)
Type-specified sharding: P(None,)

=== with mesh: AbstractMesh('X': 4, 'Y': 2, axis_types=(Auto, Auto), device_kind=cpu, num_cores=None) ===
Concrete value sharding: P('X',)
Type-specified sharding: P(None,)
Array([ 0.        ,  0.84147096,  0.9092974 ,  0.14112   , -0.7568025 ,
       -0.9589243 , -0.2794155 ,  0.6569866 ], dtype=float32)

请注意,在顶层,当我们处于完全 Explicit 的网格上下文中时,具体数组分片和类型指定分片是一致的。

但在 auto_axes 装饰器下,我们处于完全 Auto 的网格上下文中,两者的分片不一致:类型指定的分片是 P(None),而具体数组分片是 P("X")(尽管它可以是任何值!这取决于编译器)。

手动模式允许您编写显式的集合操作,并提供每个设备的数据视图#

使用 jax.shard_map 将网格轴类型设置为 Manual

mesh = jax.make_mesh((4, 2), ('X', 'Y'))
jax.set_mesh(mesh)

x = jax.device_put(jnp.arange(8 * 4. ).reshape(8, 4 ), jax.P(None, 'X'))
y = jax.device_put(jnp.arange(4 * 16.).reshape(4, 16), jax.P('X', None))

@jax.shard_map(out_specs=jax.P('X', None))
def matmul(x_shard, y_shard):
  z_summand = jnp.dot(x_shard, y_shard)
  return jax.lax.psum_scatter(z_summand, 'X', tiled=True)

z = matmul(x, y)
print(jax.typeof(z))

z_ref = jnp.dot(x, y, out_sharding=jax.P('X', None))
print(jnp.allclose(z_ref, z))
float32[8@X,16]
True

有关详细信息,请参阅 shard_map 教程