显式分片(又名“类型中的分片”)#
JAX 传统的自动分片将分片决策留给编译器。您可以使用 jax.lax.with_sharding_constraint
向编译器提供提示,但大多数情况下,您应该专注于数学计算,而编译器则负责分片。
但是,如果您对如何分片程序有强烈意见呢?通过足够多次调用 with_sharding_constraint
,您大概可以引导编译器按照您的意愿执行。但“编译器调优”是一个众所周知的令人不快的编程模型。应该在哪里放置分片约束?您可以将它们放在每个中间结果上,但这工作量很大,而且这样做很容易出错,因为无法检查分片是否协同工作。更常见的情况是,人们添加足够多的分片注释来约束编译器。但这是一个缓慢的迭代过程。很难提前知道 XLA 的 GSPMD 传递会做什么(这是一个全程序优化),所以您所能做的就是添加注释,检查 XLA 的分片选择,看看发生了什么,然后重复。
为了解决这个问题,我们提出了一种不同的分片编程风格,我们称之为“显式分片”或“类型中的分片”。其思想是,分片传播发生在 JAX 级别的跟踪时。每个 JAX 运算符都有一个分片规则,该规则接受运算符参数的分片并生成结果的分片。对于大多数运算符,这些规则简单明了,因为只有一种合理的选择。但对于某些运算符,如何分片结果尚不清楚。在这种情况下,我们要求程序员显式提供 out_sharding
参数,否则我们会抛出(跟踪时)错误。由于分片在跟踪时传播,因此它们也可以在跟踪时被*查询*。在本文档的其余部分,我们将描述如何使用显式分片模式。请注意,这是一个新功能,因此我们预计会有错误和未实现的情况。当您发现不起作用的东西时,请告知我们!另请参阅 The Training Cookbook,其中包含一个使用显式分片的真实机器学习训练示例。
import jax
import numpy as np
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, AxisType, get_abstract_mesh, reshard
jax.config.update('jax_num_cpu_devices', 8)
设置显式网格#
显式分片的(又名类型中的分片)主要思想是,JAX 级别的*类型*包含一个关于值如何分片的描述。我们可以使用 jax.typeof
查询任何 JAX 值(或 Numpy 数组、Python 标量)的 JAX 级别类型。
some_array = np.arange(8)
print(f"JAX-level type of some_array: {jax.typeof(some_array)}")
JAX-level type of some_array: int32[8]
重要的是,我们可以在 jit
跟踪时查询类型(JAX 级别的类型几乎*定义*为“在 jit 中可访问的值的信息”)。
@jax.jit
def foo(x):
print(f"JAX-level type of x during tracing: {jax.typeof(x)}")
return x + x
foo(some_array)
JAX-level type of x during tracing: int32[8]
Array([ 0, 2, 4, 6, 8, 10, 12, 14], dtype=int32)
这些类型显示了数组的形状和 dtype,但它们似乎没有显示分片。(实际上,它们*确实*显示了分片,但分片是微不足道的。参见下面的“具体数组分片”)。要开始看到一些有趣的分片,我们需要设置一个显式分片网格。
jax.set_mesh
可以用作全局设置器或上下文管理器。在此笔记本中,我们将 jax.set_mesh
用作全局设置器。您可以通过 with jax.set_mesh(mesh)
将其用作作用域上下文管理器。
mesh = jax.make_mesh((2, 4), ("X", "Y"),
axis_types=(AxisType.Explicit, AxisType.Explicit))
jax.set_mesh(mesh)
print(f"Current mesh is: {get_abstract_mesh()}")
Current mesh is: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit), device_kind=cpu, num_cores=None)
现在我们可以使用 reshard
创建一些分片数组
replicated_array = np.arange(8).reshape(4, 2)
sharded_array = reshard(replicated_array, P("X", None))
print(f"replicated_array type: {jax.typeof(replicated_array)}")
print(f"sharded_array type: {jax.typeof(sharded_array)}")
replicated_array type: int32[4,2]
sharded_array type: int32[4@X,2]
我们应该将类型 f32[4@X, 2]
理解为“一个 4x2 的 32 位浮点数数组,其第一个维度沿网格轴‘X’分片。数组沿所有其他网格轴复制”。
这些与 JAX 级别类型相关联的分片会通过运算符传播。例如
arg0 = reshard(np.arange(4).reshape(4, 1), P("X", None))
arg1 = reshard(np.arange(8).reshape(1, 8), P(None, "Y"))
result = arg0 + arg1
print(f"arg0 sharding: {jax.typeof(arg0)}")
print(f"arg1 sharding: {jax.typeof(arg1)}")
print(f"result sharding: {jax.typeof(result)}")
arg0 sharding: int32[4@X,1]
arg1 sharding: int32[1,8@Y]
result sharding: int32[4@X,8@Y]
我们可以在 jit 下执行相同的类型查询
@jax.jit
def add_arrays(x, y):
ans = x + y
print(f"x sharding: {jax.typeof(x)}")
print(f"y sharding: {jax.typeof(y)}")
print(f"ans sharding: {jax.typeof(ans)}")
return ans
add_arrays(arg0, arg1)
x sharding: int32[4@X,1]
y sharding: int32[1,8@Y]
ans sharding: int32[4@X,8@Y]
Array([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 1, 2, 3, 4, 5, 6, 7, 8],
[ 2, 3, 4, 5, 6, 7, 8, 9],
[ 3, 4, 5, 6, 7, 8, 9, 10]], dtype=int32)
这就是要点。分片在跟踪时确定性地传播,我们可以在跟踪时查询它们。
JAX 转换和高阶函数#
JAX 程序的阶段输出表示是显式类型的。(我们称这些类型为“avals”,但这并不重要。)在显式分片模式下,分片是该类型的一部分。这意味着分片需要在类型匹配的地方匹配。例如,lax.cond
的两侧需要具有匹配分片的结果。并且 lax.scan
的 carry 在扫描体的输入和输出时需要具有相同分片。并且当您使用 make_jaxpr
在没有具体参数的情况下构造 jaxpr 时,您还需要提供分片。某些 JAX 转换执行类型级别的操作。自动微分会为原始计算中的每个原始类型构造一个切线类型(例如,TangentOf(float) == float
,TangentOf(int) == float0
)。由于类型中存在分片,这意味着切线值与其原始值以相同的方式分片。Vmap 和 scan 也执行类型级别的操作,它们将数组形状提升到该形状的秩增强版本。那个额外的数组轴需要一个分片。我们可以从 vmap/scan 的参数中推断出来,但它们都需要一致。并且空 vmap/scan 需要一个显式分片参数,就像它需要一个显式长度参数一样。
使用 auto_axes
解决未实现的分片规则问题#
显式分片的实现仍在进行中,并且有很多运算符缺少分片规则。例如,scatter
和 gather
(即索引运算符)。
通常我们不会建议使用有如此多未实现情况的功能,但在这种情况下,有一个合理的后备方案可以使用:auto_axes
。其思想是,您可以暂时进入一个网格轴为“auto”而不是“explicit”的上下文。您显式指定 auto_axes
的最终结果在返回到调用上下文时应该如何分片。
这适用于缺少分片规则的运算符的后备方案。当您想覆盖类型中的分片类型系统时,它也适用。例如,假设我们想将 f32[4@X, 4]
添加到 f32[4, 4@X]
。我们添加的规则将抛出错误:结果需要是 f32[4@X, 4@X]
,这试图重复使用网格轴,这是非法的。但假设您仍想执行该操作,并且您希望结果沿第一个轴分片,例如 f32[4@X, 4]
。您可以按如下方式执行此操作:
from jax.sharding import auto_axes, explicit_axes
some_x = reshard(np.arange(16).reshape(4, 4), P("X", None))
some_y = reshard(np.arange(16).reshape(4, 4), P(None, "X"))
try:
some_x + some_y
except Exception as e:
print("ERROR!")
print(e)
print("=== try again with auto_axes ===")
@auto_axes
def add_with_out_sharding_kwarg(x, y):
print(f"We're in auto-sharding mode here. This is the current mesh: {get_abstract_mesh()}")
return x + y
result = add_with_out_sharding_kwarg(some_x, some_y, out_sharding=P("X", None))
print(f"Result type: {jax.typeof(result)}")
ERROR!
add operation with inputs: i32[4@X,4], i32[4,4@X] produces an illegally sharded result: i32[4@X,4@X]
=== try again with auto_axes ===
We're in auto-sharding mode here. This is the current mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto), device_kind=cpu, num_cores=None)
Result type: int32[4@X,4]
使用混合分片模式#
JAX 现在有三种并行风格
自动分片 是您将所有设备视为单个逻辑机器,并为该机器编写“全局视图”数组程序。编译器决定如何在可用设备之间划分数据和计算。您可以使用
with_sharding_constraint
向编译器提供提示。显式分片(*新*)与自动分片类似,因为您正在编写全局视图程序。不同之处在于,每个数组的分片是数组 JAX 级别类型的一部分,使其成为编程模型的显式部分。这些分片在 JAX 级别传播,并在跟踪时可查询。编译器仍负责将全数组程序转换为每个设备程序(例如,将
jnp.sum
转换为psum
),但编译器受到用户提供的分片的大大约束。手动分片(
shard_map
)是您从单个设备的角度编写程序。设备之间的通信通过显式的集体操作(如 psum)进行。
总结表
模式 |
视图? |
显式分片? |
显式集体操作? |
---|---|---|---|
自动 |
全局 |
❌ |
❌ |
显式 |
全局 |
✅ |
❌ |
手动 |
每个设备 |
✅ |
✅ |
当前的网格告诉我们我们处于哪种分片模式。我们可以使用 get_abstract_mesh
查询它。
print(f"Current mesh is: {get_abstract_mesh()}")
Current mesh is: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit), device_kind=cpu, num_cores=None)
由于 axis_types=(Explicit, Explicit)
,这意味着我们处于完全显式模式。请注意,分片模式与网格*轴*相关联,而不是与整个网格相关联。我们实际上可以通过为每个网格轴设置不同的分片模式来混合分片模式。分片(在 JAX 级别类型上)只能提及*显式*网格轴,而集体操作(如 psum
)只能提及*手动*网格轴。
您可以使用 auto_axes
API 在某些网格轴上为 Auto
,而在其他轴上为 Explicit
。例如
import functools
@functools.partial(auto_axes, axes='X')
def g(y):
print(f'mesh inside g: {get_abstract_mesh()}')
print(f'y.sharding inside g: {jax.typeof(y) = }', end='\n\n')
return y * 2
@jax.jit
def f(arr1):
print(f'mesh inside f: {get_abstract_mesh()}')
x = jnp.sin(arr1)
print(f'x.sharding: {jax.typeof(x)}', end='\n\n')
z = g(x, out_sharding=P("X", "Y"))
print(f'z.sharding: {jax.typeof(z)}', end="\n\n")
return z + 1
some_x = reshard(np.arange(16).reshape(4, 4), P("X", "Y"))
f(some_x)
mesh inside f: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit), device_kind=cpu, num_cores=None)
x.sharding: float32[4@X,4@Y]
mesh inside g: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Explicit), device_kind=cpu, num_cores=None)
y.sharding inside g: jax.typeof(y) = ShapedArray(float32[4,4@Y])
z.sharding: float32[4@X,4@Y]
Array([[ 1. , 2.682942 , 2.818595 , 1.28224 ],
[-0.513605 , -0.9178486 , 0.44116902, 2.3139732 ],
[ 2.9787164 , 1.824237 , -0.08804226, -0.99998045],
[-0.07314587, 1.840334 , 2.9812148 , 2.3005757 ]], dtype=float32)
正如您所见,在 g
内部,arr1
的类型是 ShapedArray(float32[4,4@Y])
,这表示它在 Y
网格轴上是显式的,而在 X
上是自动的。
您还可以使用 explicit_axes
API 进入某些或所有网格轴的 Explicit
模式。
auto_mesh = jax.make_mesh((2, 4), ("X", "Y"),
axis_types=(AxisType.Auto, AxisType.Auto))
@functools.partial(explicit_axes, axes=('X', 'Y'))
def explicit_g(y):
print(f'mesh inside g: {get_abstract_mesh()}')
print(f'y.sharding inside g: {jax.typeof(y) = }')
z = y * 2
print(f'z.sharding inside g: {jax.typeof(z) = }', end='\n\n')
return z
@jax.jit
def f(arr1):
print(f'mesh inside f: {get_abstract_mesh()}', end='\n\n')
x = jnp.sin(arr1)
z = explicit_g(x, in_sharding=P("X", "Y"))
return z + 1
with jax.set_mesh(auto_mesh):
some_x = jax.device_put(np.arange(16).reshape(4, 4), P("X", "Y"))
f(some_x)
mesh inside f: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto), device_kind=cpu, num_cores=None)
mesh inside g: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit), device_kind=cpu, num_cores=None)
y.sharding inside g: jax.typeof(y) = ShapedArray(float32[4@X,4@Y])
z.sharding inside g: jax.typeof(z) = ShapedArray(float32[4@X,4@Y])
正如您所见,f
内部的所有网格轴都是 Auto
类型,而在 g
内部,它们是 Explicit
类型。因此,分片在 g
内部数组的类型上可见。
具体数组分片可以提及 Auto
网格轴#
您可以使用 x.sharding
查询具体数组 x
的分片。您可能期望结果与值类型 jax.typeof(x).sharding
相关联的分片相同。可能不是!具体数组分片 x.sharding
描述了 Explicit
和 Auto
网格轴上的分片。这是编译器最终选择的分片。而类型指定的分片 jax.typeof(x).sharding
仅描述*显式*网格轴上的分片。Auto
轴被故意隐藏在类型之外,因为它们属于编译器。我们可以认为具体数组分片与类型指定的分片一致,但更具体。例如
def compare_shardings(x):
print(f"=== with mesh: {get_abstract_mesh()} ===")
print(f"Concrete value sharding: {x.sharding.spec}")
print(f"Type-specified sharding: {jax.typeof(x).sharding.spec}")
my_array = jnp.sin(reshard(np.arange(8), P("X")))
compare_shardings(my_array)
@auto_axes
def check_in_auto_context(x):
compare_shardings(x)
return x
check_in_auto_context(my_array, out_sharding=P("X"))
=== with mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit), device_kind=cpu, num_cores=None) ===
Concrete value sharding: PartitionSpec('X',)
Type-specified sharding: PartitionSpec('X',)
=== with mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto), device_kind=cpu, num_cores=None) ===
Concrete value sharding: PartitionSpec('X',)
Type-specified sharding: PartitionSpec(None,)
Array([ 0. , 0.84147096, 0.9092974 , 0.14112 , -0.7568025 ,
-0.9589243 , -0.2794155 , 0.6569866 ], dtype=float32)
请注意,在顶层,当我们处于完全 Explicit
网格上下文时,具体数组分片和类型指定分片是一致的。但在 auto_axes
装饰器下,我们处于完全 Auto
网格上下文,并且这两个分片不一致:类型指定分片是 P(None)
,而具体数组分片是 P("X")
(尽管它可能是任何东西!这取决于编译器)。