显式分片(也称“类型中的分片”)#
JAX 传统的自动分片将分片决策留给编译器。您可以使用 jax.lax.with_sharding_constraint
向编译器提供提示,但大多数情况下,您应该专注于数学计算,而由编译器来处理分片。
但是,如果您对程序如何分片有强烈的主见怎么办?通过足够多的 with_sharding_constraint
调用,您或许可以引导编译器执行您想要的操作。但“编译器挠痒痒”(compiler tickling)众所周知不是一种有趣的编程模型。您应该将分片约束放在哪里?您可以将它们放在每一个中间结果上,但这工作量很大,而且也很容易出错,因为无法检查这些分片是否协同工作。更常见的是,人们只添加足够的分片注解来约束编译器。但这是一个缓慢的迭代过程。很难提前知道 XLA 的 GSPMD pass 会做什么(它是一种全程序优化),所以您能做的就是添加注解,检查 XLA 的分片选择以查看发生了什么,然后重复此过程。
为了解决这个问题,我们提出了一种不同的分片编程风格,我们称之为“显式分片”或“类型中的分片”。其思想是分片传播在 JAX 层面于跟踪时(trace time)发生。每个 JAX 操作都有一条分片规则,它接收操作参数的分片并为操作结果生成一个分片。对于大多数操作,这些规则简单明了,因为只有一个合理的选择。但对于某些操作,如何对结果进行分片尚不清楚。在这种情况下,我们要求程序员显式地提供一个 out_sharding
参数,否则我们将抛出(跟踪时)错误。由于分片在跟踪时传播,因此它们也可以在跟踪时被查询。在本文档的其余部分,我们将介绍如何使用显式分片模式。请注意,这是一个新功能,因此我们预计会存在 bug 和未实现的情况。如果您发现任何不起作用的地方,请告诉我们!
import jax
import numpy as np
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, AxisType, set_mesh, get_abstract_mesh
from jax.sharding import reshard, auto_axes, explicit_axes
jax.config.update('jax_num_cpu_devices', 8)
设置显式网格#
显式分片(也称类型中的分片)背后的主要思想是,值的 JAX 级别类型包含对该值如何分片的描述。我们可以使用 jax.typeof
查询任何 JAX 值(或 Numpy 数组,或 Python 标量)的 JAX 级别类型。
some_array = np.arange(8)
print(f"JAX-level type of some_array: {jax.typeof(some_array)}")
JAX-level type of some_array: int32[8]
重要的是,即使在 jit
下进行跟踪时,我们也可以查询类型(JAX 级别类型几乎被定义为“在 jit 下我们可以访问的值的信息”)。
@jax.jit
def foo(x):
print(f"JAX-level type of x during tracing: {jax.typeof(x)}")
return x + x
foo(some_array)
JAX-level type of x during tracing: int32[8]
Array([ 0, 2, 4, 6, 8, 10, 12, 14], dtype=int32)
这些类型显示了数组的形状和数据类型(dtype),但它们似乎没有显示分片。(实际上,它们确实显示了分片,但分片是微不足道的。请参阅下面的“具体数组分片”。)要开始看到一些有趣的分片,我们需要设置一个显式分片网格。我们使用 set_mesh
将其设置为本笔记本其余部分的当前网格。(如果您只想在特定范围内设置网格并在之后返回到之前的网格,则可以使用上下文管理器 jax.sharding.use_mesh
。)
mesh = jax.make_mesh((2, 4), ("X", "Y"),
axis_types=(AxisType.Explicit, AxisType.Explicit))
set_mesh(mesh)
print(f"Current mesh is: {get_abstract_mesh()}")
Current mesh is: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit))
现在我们可以使用 reshard
创建一些分片数组。
replicated_array = np.arange(8).reshape(4, 2)
sharded_array = reshard(replicated_array, P("X", None))
print(f"replicated_array type: {jax.typeof(replicated_array)}")
print(f"sharded_array type: {jax.typeof(sharded_array)}")
replicated_array type: int32[4,2]
sharded_array type: int32[4@X,2]
我们应该将类型 f32[4@X, 2]
理解为“一个 4x2 的 32 位浮点数组,其第一维度沿网格轴‘X’分片。该数组沿所有其他网格轴复制。”
这些与 JAX 级别类型关联的分片会通过操作传播。例如:
arg0 = reshard(np.arange(4).reshape(4, 1), P("X", None))
arg1 = reshard(np.arange(8).reshape(1, 8), P(None, "Y"))
result = arg0 + arg1
print(f"arg0 sharding: {jax.typeof(arg0)}")
print(f"arg1 sharding: {jax.typeof(arg1)}")
print(f"result sharding: {jax.typeof(result)}")
arg0 sharding: int32[4@X,1]
arg1 sharding: int32[1,8@Y]
result sharding: int32[4@X,8@Y]
我们可以在 jit 下执行相同的类型查询。
@jax.jit
def add_arrays(x, y):
ans = x + y
print(f"x sharding: {jax.typeof(x)}")
print(f"y sharding: {jax.typeof(y)}")
print(f"ans sharding: {jax.typeof(ans)}")
return ans
add_arrays(arg0, arg1)
x sharding: int32[4@X,1]
y sharding: int32[1,8@Y]
ans sharding: int32[4@X,8@Y]
Array([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 1, 2, 3, 4, 5, 6, 7, 8],
[ 2, 3, 4, 5, 6, 7, 8, 9],
[ 3, 4, 5, 6, 7, 8, 9, 10]], dtype=int32)
这就是它的要点。分片在跟踪时确定性地传播,我们也可以在跟踪时查询它们。
JAX 转换和高阶函数#
JAX 程序的阶段化(staged-out)表示是显式类型化的。(我们称这些类型为“avals”,但这并不重要。)在显式分片模式下,分片是该类型的一部分。这意味着分片需要与类型匹配的地方一致。例如,lax.cond
的两端需要具有分片匹配的结果。并且 lax.scan
的 carry 在 scan 主体的输入和输出处需要具有相同的分片。当您使用 make_jaxpr
构造没有具体参数的 jaxpr 时,您也需要提供分片。某些 JAX 转换执行类型级别操作。自动微分为原始计算中的每个原始类型构造一个切线类型(例如,TangentOf(float) == float
,TangentOf(int) == float0
)。由于类型中包含分片,这意味着切线值与其原始值以相同的方式进行分片。Vmap 和 scan 也执行类型级别操作,它们将数组形状提升到该形状的秩增强版本。那个额外的数组轴需要一个分片。我们可以从 vmap/scan 的参数中推断出来,但它们都需要一致。并且零元 vmap/scan 需要一个显式分片参数,就像它需要一个显式长度参数一样。
使用 auto_axes
解决未实现的分片规则#
显式分片的实现仍在进行中,并且有许多操作缺少分片规则。例如,scatter
和 gather
(即索引操作)。
通常我们不建议使用有这么多未实现情况的功能,但在这种情况下,有一个合理的备用方案可以使用:auto_axes
。其思想是您可以暂时进入一个网格轴是“自动”而不是“显式”的上下文。您显式指定 auto_axes
的最终结果在返回到调用上下文时如何分片。
这可以作为未实现分片规则的操作的备用方案。当您想要覆盖“类型中的分片”类型系统时,它也适用。例如,假设我们想将 f32[4@X, 4]
添加到 f32[4, 4@X]
。我们的加法分片规则会抛出错误:结果需要是 f32[4@X, 4@X]
,这会尝试两次使用同一个网格轴,这是非法的。但是,假设您无论如何都想执行该操作,并且希望结果仅沿第一个轴分片,例如 f32[4@X, 4]
。您可以这样做:
some_x = reshard(np.arange(16).reshape(4, 4), P("X", None))
some_y = reshard(np.arange(16).reshape(4, 4), P(None, "X"))
try:
some_x + some_y
except Exception as e:
print("ERROR!")
print(e)
print("=== try again with auto_axes ===")
@auto_axes
def add_with_out_sharding_kwarg(x, y):
print(f"We're in auto-sharding mode here. This is the current mesh: {get_abstract_mesh()}")
return x + y
result = add_with_out_sharding_kwarg(some_x, some_y, out_sharding=P("X", None))
print(f"Result type: {jax.typeof(result)}")
ERROR!
add operation with inputs: i32[4@X,4], i32[4,4@X] produces an illegally sharded result: i32[4@X,4@X]
=== try again with auto_axes ===
We're in auto-sharding mode here. This is the current mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto))
Result type: int32[4@X,4]
混合使用分片模式#
JAX 现在有三种并行风格:
自动分片是指您将所有设备视为一台逻辑机器,并为该机器编写一个“全局视图”数组程序。编译器决定如何在可用设备之间划分数据和计算。您可以使用
with_sharding_constraint
向编译器提供提示。显式分片(*新*)与自动分片类似,都是编写全局视图程序。不同之处在于,每个数组的分片是该数组 JAX 级别类型的一部分,使其成为编程模型的显式部分。这些分片在 JAX 层面传播,并可在跟踪时查询。编译器仍然负责将整个数组程序转换为每个设备的程序(例如,将
jnp.sum
转换为psum
),但编译器受到用户提供的分片的严格约束。手动分片(
shard_map
)是指您从单个设备的视角编写程序。设备之间的通信通过显式集体操作(如 psum)进行。
摘要表
模式 |
视图? |
显式分片? |
显式集合操作? |
---|---|---|---|
自动 |
全局 |
❌ |
❌ |
显式 |
全局 |
✅ |
❌ |
手动 |
每设备 |
✅ |
✅ |
当前网格告诉我们处于哪种分片模式。我们可以使用 get_abstract_mesh
查询它。
print(f"Current mesh is: {get_abstract_mesh()}")
Current mesh is: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit))
由于 axis_types=(Explicit, Explicit)
,这意味着我们处于完全显式模式。请注意,分片模式是与网格轴关联的,而不是整个网格。我们实际上可以通过为每个网格轴设置不同的分片模式来混合使用分片模式。分片(在 JAX 级别类型上)只能提及显式网格轴,而集体操作(如 psum
)只能提及手动网格轴。
您可以使用 auto_axes
API,在某些网格轴上采用 Auto
模式,而在其他网格轴上采用 Explicit
模式。例如:
import functools
@functools.partial(auto_axes, axes='X')
def g(y):
print(f'mesh inside g: {get_abstract_mesh()}')
print(f'y.sharding inside g: {jax.typeof(y) = }', end='\n\n')
return y * 2
@jax.jit
def f(arr1):
print(f'mesh inside f: {get_abstract_mesh()}')
x = jnp.sin(arr1)
print(f'x.sharding: {jax.typeof(x)}', end='\n\n')
z = g(x, out_sharding=P("X", "Y"))
print(f'z.sharding: {jax.typeof(z)}', end="\n\n")
return z + 1
some_x = reshard(np.arange(16).reshape(4, 4), P("X", "Y"))
f(some_x)
mesh inside f: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit))
x.sharding: float32[4@X,4@Y]
mesh inside g: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Explicit))
y.sharding inside g: jax.typeof(y) = ShapedArray(float32[4,4@Y])
z.sharding: float32[4@X,4@Y]
Array([[ 1. , 2.682942 , 2.818595 , 1.28224 ],
[-0.513605 , -0.9178486 , 0.44116902, 2.3139732 ],
[ 2.9787164 , 1.824237 , -0.08804226, -0.99998045],
[-0.07314587, 1.840334 , 2.9812148 , 2.3005757 ]], dtype=float32)
如您所见,在 g
内部,arr1
的类型是 ShapedArray(float32[4,4@Y])
,这表明它在 Y
网格轴上是显式模式,而在 X
上是自动模式。
您还可以使用 explicit_axes
API,在部分或所有网格轴上进入 Explicit
模式。
auto_mesh = jax.make_mesh((2, 4), ("X", "Y"),
axis_types=(AxisType.Auto, AxisType.Auto))
@functools.partial(explicit_axes, axes=('X', 'Y'))
def explicit_g(y):
print(f'mesh inside g: {get_abstract_mesh()}')
print(f'y.sharding inside g: {jax.typeof(y) = }')
z = y * 2
print(f'z.sharding inside g: {jax.typeof(z) = }', end='\n\n')
return z
@jax.jit
def f(arr1):
print(f'mesh inside f: {get_abstract_mesh()}', end='\n\n')
x = jnp.sin(arr1)
z = explicit_g(x, in_sharding=P("X", "Y"))
return z + 1
with jax.sharding.use_mesh(auto_mesh):
some_x = jax.device_put(np.arange(16).reshape(4, 4), P("X", "Y"))
f(some_x)
mesh inside f: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto))
mesh inside g: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit))
y.sharding inside g: jax.typeof(y) = ShapedArray(float32[4@X,4@Y])
z.sharding inside g: jax.typeof(z) = ShapedArray(float32[4@X,4@Y])
如您所见,在 f
内部,网格的所有轴都是 Auto
类型,而在 g
内部,它们是 Explicit
类型。因此,分片在 g
内部的数组类型上是可见的。
具体数组分片可以提及 Auto
网格轴#
您可以使用 x.sharding
查询具体数组 x
的分片。您可能期望结果与该值类型关联的分片 jax.typeof(x).sharding
相同。但可能不是!具体数组分片 x.sharding
描述了沿 Explicit
和 Auto
网格轴的分片。它是编译器最终选择的分片。而类型指定的分片 jax.typeof(x).sharding
仅描述沿 Explicit
网格轴的分片。Auto
轴被故意从类型中隐藏,因为它们是编译器的职权范围。我们可以认为具体数组分片与类型指定的分片一致,但更具体。例如:
def compare_shardings(x):
print(f"=== with mesh: {get_abstract_mesh()} ===")
print(f"Concrete value sharding: {x.sharding.spec}")
print(f"Type-specified sharding: {jax.typeof(x).sharding.spec}")
my_array = jnp.sin(reshard(np.arange(8), P("X")))
compare_shardings(my_array)
@auto_axes
def check_in_auto_context(x):
compare_shardings(x)
return x
check_in_auto_context(my_array, out_sharding=P("X"))
=== with mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit)) ===
Concrete value sharding: PartitionSpec('X',)
Type-specified sharding: PartitionSpec('X',)
=== with mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto)) ===
Concrete value sharding: PartitionSpec('X',)
Type-specified sharding: PartitionSpec(None,)
Array([ 0. , 0.84147096, 0.9092974 , 0.14112 , -0.7568025 ,
-0.9589243 , -0.2794155 , 0.6569866 ], dtype=float32)
请注意,在顶层(我们当前处于完全 Explicit
网格上下文中),具体数组分片和类型指定的分片是一致的。但在 auto_axes
装饰器下,我们处于完全 Auto
网格上下文中,两种分片不一致:类型指定的分片是 P(None)
,而具体数组分片是 P("X")
(尽管它可以是任何东西!这取决于编译器)。