显式分片(也称为“类型分片”)#

JAX 传统的自动分片将分片决策留给编译器。您可以使用 jax.lax.with_sharding_constraint 向编译器提供提示,但在大多数情况下,您应该专注于数学,而编译器负责分片。

但是,如果您对程序的切分方式有强烈的看法,该怎么办?通过足够多次调用 with_sharding_constraint,您或许可以引导编译器按照您的意愿执行。但是,“挠编译器痒痒”是出了名的不好玩的编程模型。您应该在哪里放置分片约束?您可以将它们放在每个中间步骤上,但这会花费大量精力,而且也很容易犯错,因为无法检查分片是否合理地结合在一起。更常见的情况是,人们只添加足够的分片注释来约束编译器。但这是一个缓慢的迭代过程。很难预先知道 XLA 的 GSPMD 传递会做什么(这是一个整体程序优化),所以您所能做的就是添加注释,检查 XLA 的分片选择以查看发生了什么,然后重复。

为了解决这个问题,我们提出了一种不同的分片编程风格,我们称之为“显式分片”或“类型分片”。其思想是,分片传播发生在 JAX 级别的跟踪时。每个 JAX 操作都有一个分片规则,该规则接受操作参数的分片,并为操作结果生成一个分片。对于大多数操作,这些规则都很简单明了,因为只有一个合理的选择。但对于某些操作,如何对结果进行分片尚不清楚。在这种情况下,我们要求程序员显式提供 out_sharding 参数,否则我们会抛出一个(跟踪时)错误。由于分片是在跟踪时传播的,因此也可以在跟踪时查询分片。在本文档的其余部分,我们将描述如何使用显式分片模式。请注意,这是一个新功能,因此我们预计会出现错误和未实现的情况。当您发现任何无法工作的情况时,请告诉我们!

import jax
import numpy as np
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, AxisType, set_mesh, get_abstract_mesh
from jax.experimental.shard import reshard, auto_axes, explicit_axes

jax.config.update('jax_num_cpu_devices', 8)

设置显式网格#

显式分片(也称为类型分片)背后的主要思想是,JAX 级别的类型值包含值如何分片的描述。我们可以使用 jax.typeof 查询任何 JAX 值(或 Numpy 数组,或 Python 标量)的 JAX 级别类型

some_array = np.arange(8)
print(f"JAX-level type of some_array: {jax.typeof(some_array)}")
JAX-level type of some_array: ShapedArray(int32[8])

重要的是,即使在 jit 下跟踪时,我们也可以查询类型(JAX 级别类型几乎定义为“我们在 jit 下可以访问的关于值的信息)。

@jax.jit
def foo(x):
  print(f"JAX-level type of x during tracing: {jax.typeof(x)}")
  return x + x

foo(some_array)
JAX-level type of x during tracing: ShapedArray(int32[8])
Array([ 0,  2,  4,  6,  8, 10, 12, 14], dtype=int32)

这些类型显示了数组的形状和 dtype,但它们似乎没有显示分片。(实际上,它们确实显示了分片,但分片是微不足道的。请参阅下面的“具体数组分片”。)要开始看到一些有趣的分片,我们需要设置一个显式分片网格。我们使用 set_mesh 将其设置为笔记本剩余部分的当前网格。(如果您只想为某个特定范围设置网格,并在之后返回到之前的网格,则可以使用上下文管理器 jax.sharding.use_mesh 代替。)

mesh = jax.make_mesh((2, 4), ("X", "Y"),
                     axis_types=(AxisType.Explicit, AxisType.Explicit))
set_mesh(mesh)

print(f"Current mesh is: {get_abstract_mesh()}")
Current mesh is: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit))

现在我们可以使用 reshard 创建一些分片数组

replicated_array = np.arange(8).reshape(4, 2)
sharded_array = reshard(replicated_array, P("X", None))

print(f"replicated_array type: {jax.typeof(replicated_array)}")
print(f"sharded_array type: {jax.typeof(sharded_array)}")
replicated_array type: ShapedArray(int32[4,2])
sharded_array type: ShapedArray(int32[4@X,2])

我们应该将类型 f32[4@X, 2] 读取为“一个 4x2 的 32 位浮点数数组,其第一个维度沿网格轴 ‘X’ 分片。该数组沿所有其他网格轴复制”

这些与 JAX 级别类型关联的分片通过操作传播。例如

arg0 = reshard(np.arange(4).reshape(4, 1), P("X", None))
arg1 = reshard(np.arange(8).reshape(1, 8), P(None, "Y"))

result = arg0 + arg1

print(f"arg0 sharding: {jax.typeof(arg0)}")
print(f"arg1 sharding: {jax.typeof(arg1)}")
print(f"result sharding: {jax.typeof(result)}")
arg0 sharding: ShapedArray(int32[4@X,1])
arg1 sharding: ShapedArray(int32[1,8@Y])
result sharding: ShapedArray(int32[4@X,8@Y])

我们可以在 jit 下进行相同的类型查询

@jax.jit
def add_arrays(x, y):
  ans = x + y
  print(f"x sharding: {jax.typeof(x)}")
  print(f"y sharding: {jax.typeof(y)}")
  print(f"ans sharding: {jax.typeof(ans)}")
  return ans

add_arrays(arg0, arg1)
x sharding: ShapedArray(int32[4@X,1])
y sharding: ShapedArray(int32[1,8@Y])
ans sharding: ShapedArray(int32[4@X,8@Y])
Array([[ 0,  1,  2,  3,  4,  5,  6,  7],
       [ 1,  2,  3,  4,  5,  6,  7,  8],
       [ 2,  3,  4,  5,  6,  7,  8,  9],
       [ 3,  4,  5,  6,  7,  8,  9, 10]], dtype=int32)

这就是它的要点。分片在跟踪时确定性地传播,我们可以在跟踪时查询它们。

分片规则和具有歧义分片的操作#

每个操作都有一个分片规则,该规则指定其输出分片(给定其输入分片)。分片规则也可能抛出一个(跟踪时)错误。每个操作都可以自由地实现它喜欢的任何分片规则,但通常的模式如下:对于每个输出轴,我们识别零个或多个对应的输入轴。然后,根据对应输入轴的“共识”分片对输出轴进行分片。即,如果输入分片都为 None,则为 None,如果只有一个非 None 输入分片,则为公共非 None 输入分片,否则为错误(需要显式的 out_sharding=… kwarg)。

此过程在逐轴的基础上完成。完成后,我们最终可能会得到一个多次提及网格轴的数组分片,这是非法的。在这种情况下,我们会引发一个(跟踪时)分片错误,并要求提供显式的 out_sharding。

以下是一些示例分片规则

  • 空操作,如 jnp.zerosjnp.arange:这些操作从整体上创建数组,因此它们没有要传播的输入分片。它们的输出默认情况下是不分片的,除非被 out_sharding kwarg 覆盖。

  • 一元元素级操作,如 sinexp:输出与输入分片相同。

  • 二元操作(+-* 等): “压缩”维度的轴分片必须匹配(或为 None)。“外积”维度(仅在一个参数中出现的维度)与其在输入中一样分片。如果结果最终多次提及网格轴,则会出错。

  • reshape. Reshape 是一个特别棘手的操作。输出轴可以映射到多个输入轴(当 reshape 用于合并轴时),或者仅映射到输入轴的一部分(当 reshape 用于拆分轴时)。我们通常的规则不适用。相反,我们按以下方式处理 reshape。我们剥离单例轴(无论如何这些轴都不能分片)。然后,我们决定 reshape 是“拆分”(将单个轴拆分为两个或多个相邻轴)、“合并”(将两个或多个相邻轴合并为一个轴)还是其他情况。如果我们有一个拆分/合并情况,其中拆分/合并的轴分片为 None,那么我们将生成的拆分/合并轴分片为 None,其他轴根据其对应的输入轴分片进行分片。在所有其他情况下,我们会抛出一个错误,并要求用户提供 out_shardings 参数。

JAX 转换和高阶函数#

JAX 程序的暂存输出表示是显式类型的。(我们称类型为“avals”,但这并不重要。)在显式分片模式下,分片是该类型的一部分。这意味着分片需要在类型需要匹配的任何地方匹配。例如,lax.cond 的两侧需要具有分片匹配的结果。lax.scan 的 carry 需要在扫描主体的输入和输出处具有相同的分片。当您使用 make_jaxpr 构建没有具体参数的 jaxpr 时,您也需要提供分片。某些 JAX 转换执行类型级操作。自动微分为原始计算中的每个原始类型构造一个切线类型(例如,TangentOf(float) == floatTangentOf(int) == float0)。使用类型中的分片,这意味着切线值以与其原始值相同的方式分片。Vmap 和 scan 也执行类型级操作,它们将数组形状提升为该形状的秩增强版本。额外的数组轴需要一个分片。我们可以从 vmap/scan 的参数中推断出来,但它们都需要一致。空 vmap/scan 需要显式的分片参数,就像它需要显式的长度参数一样。

使用 auto_sharding 解决未实现的分片规则#

显式分片的实现仍然是一个进行中的工作,并且有很多操作缺少分片规则。例如,scattergather(即索引操作)。

通常,我们不会建议使用具有如此多未实现情况的功能,但在这种情况下,有一个合理的后备方案可以使用:auto_axes。其思想是,您可以暂时进入一个上下文,其中网格轴是“auto”而不是“explicit”。您显式指定您希望 auto_axes 的最终结果如何分片,因为它会被返回到调用上下文。

这可以用作具有未实现分片规则的操作的后备方案。当您想要覆盖类型分片类型系统时,它也适用。例如,假设我们想要将 f32[4@X, 4] 添加到 f32[4, 4@X]。我们的加法分片规则会抛出一个错误:结果需要是 f32[4@X, 4@X],这将尝试两次使用网格轴,这是非法的。但是,假设您仍然想要执行该操作,并且您希望结果仅沿第一个轴分片,例如 f32[4@X, 4]。您可以按如下方式执行此操作

some_x = reshard(np.arange(16).reshape(4, 4), P("X", None))
some_y = reshard(np.arange(16).reshape(4, 4), P(None, "X"))

try:
  some_x + some_y
except Exception as e:
  print("ERROR!")
  print(e)

print("=== try again with auto_axes ===")

@auto_axes
def add_with_out_sharding_kwarg(x, y):
  print(f"We're in auto-sharding mode here. This is the current mesh: {get_abstract_mesh()}")
  return x + y

result = add_with_out_sharding_kwarg(some_x, some_y, out_shardings=P("X", None))
print(f"Result type: {jax.typeof(result)}")
ERROR!
add operation with inputs: i32[4@X,4], i32[4,4@X] produces an illegally sharded result: i32[4@X,4@X]
=== try again with auto_axes ===
We're in auto-sharding mode here. This is the current mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto))
Result type: ShapedArray(int32[4@X,4])

混合使用分片模式#

JAX 现在有三种并行风格

  • 自动分片是指您将所有设备视为单个逻辑机器,并为该机器编写“全局视图”数组程序。编译器决定如何跨可用设备对数据和计算进行分区。您可以使用 with_sharding_constraint 向编译器提供提示。

  • 显式分片(*新*)类似于自动分片,因为您正在编写全局视图程序。不同之处在于,每个数组的分片都是数组 JAX 级别类型的一部分,使其成为编程模型的显式部分。这些分片在 JAX 级别传播,并且在跟踪时可查询。将整体数组程序转换为每设备程序(例如,将 jnp.sum 转换为 psum)仍然是编译器的责任,但编译器受到用户提供的分片的严格约束。

  • 手动分片 (shard_map) 是指您从单个设备的角度编写程序。设备之间的通信通过显式集体操作(如 psum)进行。

摘要表

模式

显式分片?

显式集合?

自动

显式(新)

手动

当前网格告诉我们我们处于哪种分片模式。我们可以使用 get_abstract_mesh 查询它

print(f"Current mesh is: {get_abstract_mesh()}")
Current mesh is: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit))

由于 axis_types=(Explicit, Explicit),这意味着我们处于完全显式模式。请注意,分片模式与网格相关联,而不是与整个网格相关联。我们实际上可以通过为每个网格轴设置不同的分片模式来混合分片模式。分片(在 JAX 级别类型上)只能提及显式网格轴,而集体操作(如 psum)只能提及手动网格轴。

您可以使用 auto_axes API 在某些网格轴上为 Auto,而在其他网格轴上为 Explicit。例如

import functools

@functools.partial(auto_axes, axes='X')
def g(y):
  print(f'mesh inside g: {get_abstract_mesh()}')
  print(f'y.sharding inside g: {jax.typeof(y) = }', end='\n\n')
  return y * 2

@jax.jit
def f(arr1):
  print(f'mesh inside f: {get_abstract_mesh()}')
  x = jnp.sin(arr1)
  print(f'x.sharding: {jax.typeof(x)}', end='\n\n')

  z = g(x, out_shardings=P("X", "Y"))

  print(f'z.sharding: {jax.typeof(z)}', end="\n\n")
  return z + 1

some_x = reshard(np.arange(16).reshape(4, 4), P("X", "Y"))
f(some_x)
mesh inside f: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit))
x.sharding: ShapedArray(float32[4@X,4@Y])

mesh inside g: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Explicit))
y.sharding inside g: jax.typeof(y) = ShapedArray(float32[4,4@Y])

z.sharding: ShapedArray(float32[4@X,4@Y])
Array([[ 1.        ,  2.682942  ,  2.818595  ,  1.28224   ],
       [-0.513605  , -0.9178486 ,  0.44116902,  2.3139732 ],
       [ 2.9787164 ,  1.824237  , -0.08804226, -0.99998045],
       [-0.07314587,  1.840334  ,  2.9812148 ,  2.3005757 ]],      dtype=float32)

如您所见,在 g 内部,arr1 的类型为 ShapedArray(float32[4,4@Y]),这表明它在 Y 网格轴上是 Explicit,而在 X 上是 auto。

您还可以使用 explicit_axes API 进入某些或所有网格轴上的 Explicit 模式。

auto_mesh = jax.make_mesh((2, 4), ("X", "Y"),
                           axis_types=(AxisType.Auto, AxisType.Auto))

@functools.partial(explicit_axes, axes=('X', 'Y'))
def explicit_g(y):
  print(f'mesh inside g: {get_abstract_mesh()}')
  print(f'y.sharding inside g: {jax.typeof(y) = }')
  z = y * 2
  print(f'z.sharding inside g: {jax.typeof(z) = }', end='\n\n')
  return z

@jax.jit
def f(arr1):
  print(f'mesh inside f: {get_abstract_mesh()}', end='\n\n')
  x = jnp.sin(arr1)

  z = explicit_g(x, in_shardings=P("X", "Y"))

  return z + 1

with jax.sharding.use_mesh(auto_mesh):
  some_x = jax.device_put(np.arange(16).reshape(4, 4), P("X", "Y"))
  f(some_x)
mesh inside f: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto))

mesh inside g: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit))
y.sharding inside g: jax.typeof(y) = ShapedArray(float32[4@X,4@Y])
z.sharding inside g: jax.typeof(z) = ShapedArray(float32[4@X,4@Y])

如您所见,在 f 内部,网格的所有轴的类型均为 Auto,而在 g 内部,它们的类型均为 Explicit。因此,分片在 g 内部的数组类型上可见。

具体数组分片可以提及 Auto 网格轴#

您可以使用 x.sharding 查询具体数组 x 的分片。您可能会期望结果与值类型关联的分片 jax.typeof(x).sharding 相同。但可能并非如此!具体数组分片 x.sharding 描述了沿 ExplicitAuto 网格轴的分片。它是编译器最终选择的分片。而类型指定的分片 jax.typeof(x).sharding 仅描述了沿 Explicit 网格轴的分片。Auto 轴被有意地从类型中隐藏,因为它们是编译器的职权范围。我们可以将具体数组分片视为与类型指定的分片一致,但比类型指定的分片更具体。例如

def compare_shardings(x):
  print(f"=== with mesh: {get_abstract_mesh()} ===")
  print(f"Concrete value sharding: {x.sharding.spec}")
  print(f"Type-specified sharding: {jax.typeof(x).sharding.spec}")

my_array = jnp.sin(reshard(np.arange(8), P("X")))
compare_shardings(my_array)

@auto_axes
def check_in_auto_context(x):
  compare_shardings(x)
  return x

check_in_auto_context(my_array, out_shardings=P("X"))
=== with mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit)) ===
Concrete value sharding: PartitionSpec('X',)
Type-specified sharding: PartitionSpec('X',)
=== with mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto)) ===
Concrete value sharding: PartitionSpec('X',)
Type-specified sharding: PartitionSpec(None,)
Array([ 0.        ,  0.84147096,  0.9092974 ,  0.14112   , -0.7568025 ,
       -0.9589243 , -0.2794155 ,  0.6569866 ], dtype=float32)

请注意,在顶层,我们当前处于完全 Explicit 网格上下文中,具体数组分片和类型指定的分片一致。但在 auto_axes 装饰器下,我们处于完全 Auto 网格上下文中,并且两个分片不一致:类型指定的分片是 P(None),而具体数组分片是 P("X")(尽管它可以是任何东西!这取决于编译器)。