jax.experimental.custom_partitioning 模块

jax.experimental.custom_partitioning 模块#

API#

jax.experimental.custom_partitioning.custom_partitioning(fun, static_argnums=())[source]#

在 XLA 图中插入一个带有自定义 SPMD 降低(lowering)规则的 CustomCallOp。

@custom_partitioning
def f(*args):
  return ...

def propagate_user_sharding(mesh, user_shape):
  '''Update the sharding of the op from a user's shape.sharding.'''
  user_sharding = jax.tree.map(lambda x: x.sharding, user_shape)

def partition(mesh, arg_shapes, result_shape):
  def lower_fn(*args):
    ... builds computation on per-device shapes ...
  result_shardings = jax.tree.map(lambda x: x.sharding, result_shape)
  arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
  # result_sharding and arg_shardings may optionally be modified and the
  # partitioner will insert collectives to reshape.
  return mesh, lower_fn, result_sharding, arg_shardings

def infer_sharding_from_operands(mesh, arg_shapes, shape):
  '''Compute the result sharding from the sharding of the operands.'''
  arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)


f.def_partition(partition, propagate_user_sharding,
                infer_sharding_from_operands=infer_sharding_from_operands,
                sharding_rule='i j -> 'i j')

def_partition 的参数如下:

  • propagate_user_sharding:一个可调用对象,它接收用户(在 DAG 中)的分片信息,并返回对新的 NamedSharding 的建议。默认值为 None。一个简单的实现就是直接返回输入的切分。

  • partition:一个可调用对象,它接收 SPMD 建议的分区形状和分区规格(partition specs),并返回网格(mesh)、每个分片的降低函数,以及最终的输入和输出分片规格(SPMD 分区器将重新分区输入以进行匹配)。返回网格是为了在未提供网格时允许配置集合通信操作(collectives)的 axis_names。

  • infer_sharding_from_operands:一个可调用对象,它根据为每个参数选择的 NamedSharding 计算出输出的 NamedSharding

  • decode_shardings:当设置为 True 时,尽可能将输入的 GSPMDSharding 转换为 NamedSharding。如果用户未提供上下文网格,则可能无法实现。

  • sharding_rule:一个 SdyShardingRule 对象、一个描述分片规则的类 Einsum 符号字符串,或者一个可以生成上述任一内容的可调用对象。我们将分片规则中 Einsum 符号的索引标签称为“因子”(factors)。我们借鉴了 einops.rearrange 字符串的概念,在因子之间使用空格分隔,并允许使用多字母因子名称。默认情况下,因子对应于直通/元素级维度。对应于其他维度的因子可以通过下述关键字参数指定。详情和示例请参阅 jax-shardy-guide

  • reduction_factors:一个字符串元组,指定字符串 sharding_rule 的缩减因子(reduction factors)。缩减因子对应于出现在操作数中但不出现在结果中的维度,例如矩阵乘法中的收缩维度。如果缩减因子被分片,结果将需要在相同的轴上进行 all-reduce。

  • need_replication_factors:一个字符串元组,指定字符串 sharding_rule 的 need_replication 因子。need_replication 因子对应于为了支持实现而不应被分片的维度。

  • permutation_factors:一个字符串元组,指定字符串 sharding_rule 的置换因子。置换因子对应于如果被分片则会触发集合置换(collective permute)的维度。

  • factor_sizes:一个变量关键字参数字典,指定仅在字符串 sharding_rule 的复合因子中使用的因子大小。

当 config.use_shardy_partitioner.value 为 True 时,使用 sharding_rule;否则使用 propagate_user_shardinginfer_sharding_from_operands

可以使用 static_argnums 将位置参数指定为静态参数。JAX 使用 inspect.signature(fun) 来解析这些位置参数。

示例

例如,假设我们要增强现有的 jax.numpy.fft.fft。该函数计算 N 维输入沿最后一个维度的离散傅里叶变换,并沿前 N-1 个维度进行批处理。然而,默认情况下,它会忽略输入的分片并将输入收集到所有设备上。由于 jax.numpy.fft.fft 是沿前 N-1 个维度进行批处理的,这样做是不必要的。我们将创建一个新的 my_fft 操作,它不会改变沿前 N-1 个维度的分片,并且仅在必要时沿最后一个维度收集输入。

import jax
from jax.sharding import NamedSharding
from jax.experimental.custom_partitioning import custom_partitioning
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P
from jax.sharding import Mesh
from jax.numpy.fft import fft
import regex as re
import numpy as np

# Pattern to detect all-gather or dynamic-slice in the generated HLO
_PATTERN = '(dynamic-slice|all-gather)'

# For an N-D input, keeps sharding along the first N-1 dimensions
# but replicate along the last dimension
def supported_sharding(sharding, shape):
    rank = len(shape.shape)
    max_shared_dims = min(len(sharding.spec), rank-1)
    names = tuple(sharding.spec[:max_shared_dims]) + tuple(None for _ in range(rank - max_shared_dims))
    return NamedSharding(sharding.mesh, P(*names))

def partition(mesh, arg_shapes, result_shape):
    result_shardings = jax.tree.map(lambda x: x.sharding, result_shape)
    arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
    return mesh, fft,               supported_sharding(arg_shardings[0], arg_shapes[0]),               (supported_sharding(arg_shardings[0], arg_shapes[0]),)

def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
    arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
    return supported_sharding(arg_shardings[0], arg_shapes[0])

@custom_partitioning
def my_fft(x):
    return fft(x)

# Use Einsum-like notation to specify the sharding rule.
my_fft.def_partition(
  infer_sharding_from_operands=infer_sharding_from_operands,
  partition=partition,
  sharding_rule='...i -> ...i')
# Use SdyShardingRule object to specify the sharding rule.
my_fft.def_partition(
  infer_sharding_from_operands=infer_sharding_from_operands,
  partition=partition,
  sharding_rule=SdyShardingRule(operand_mappings=((BATCHING, 'i'),), result_mappings=((BATCHING, 'i'),))))

现在创建一个沿第一个轴分片的 2D 数组,将其通过 my_fft,观察它如何保持预期的分片方式,并且与 fft 的输出相同。然而,检查 HLO(使用 lower(x).compile().runtime_executable().hlo_modules())会发现,my_fft 不会创建任何 all-gather 或 dynamic-slice 操作,而 fft 会。

with Mesh(np.array(jax.devices()), ('x',)):
  x = np.asarray(np.random.randn(32*1024, 1024), dtype=np.complex64)
  y = pjit(lambda x: x, in_shardings=None, out_shardings=P('x'))(x)
  pjit_my_fft = pjit(my_fft, in_shardings=P('x'), out_shardings=P('x'))
  pjit_fft    = pjit(fft,    in_shardings=P('x'), out_shardings=P('x'))
  print(pjit_my_fft(y))
  print(pjit_fft(y))
  # dynamic-slice or all-gather are not present in the HLO for my_fft, because x is a 2D array
  assert(re.search(_PATTERN, pjit_my_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is None)
  # dynamic-slice or all-gather are present in the HLO for fft
  assert(re.search(_PATTERN, pjit_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string())    is not None)
# my_fft
[[-38.840824   +0.j        -40.649452  +11.845365j
...
  -1.6937828  +0.8402481j  15.999859   -4.0156755j]]

# jax.numpy.fft.fft
[[-38.840824   +0.j        -40.649452  +11.845365j
  ...
  -1.6937828  +0.8402481j  15.999859   -4.0156755j]]

由于 supported_sharding 中的逻辑,my_fft 也适用于一维数组。但在这种情况下,my_fft 的 HLO 确实显示了 dynamic-slice,因为最后一个维度是计算 FFT 的维度,在进行计算之前需要将其复制到所有设备上。

with Mesh(np.array(jax.devices()), ('x',)):
  x = np.asarray(np.random.randn(32*1024*1024), dtype=np.complex64)
  y = pjit(lambda x: x, in_shardings=None, out_shardings=P('x'))(x)
  pjit_my_fft = pjit(my_fft, in_shardings=P('x'), out_shardings=P('x'))
  pjit_fft    = pjit(fft,    in_shardings=P('x'), out_shardings=P('x'))
  print(pjit_my_fft(y))
  print(pjit_fft(y))
  # dynamic-slice or all-gather are present in the HLO for my_fft, because x is a 1D array
  assert(re.search(_PATTERN, pjit_my_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is None)
  # dynamic-slice or all-gather are present in the HLO for fft
  assert(re.search(_PATTERN, pjit_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string())    is not None)
# my_fft
[    7.217285   +0.j     -3012.4937  +4287.635j   -405.83594 +3042.984j
...  1422.4502  +7271.4297j  -405.84033 -3042.983j
-3012.4963  -4287.6343j]

# jax.numpy.fft.fft
[    7.217285   +0.j     -3012.4937  +4287.635j   -405.83594 +3042.984j
...  1422.4502  +7271.4297j  -405.84033 -3042.983j
-3012.4963  -4287.6343j]