jax.experimental.custom_partitioning module#
API#
- jax.experimental.custom_partitioning.custom_partitioning(fun, static_argnums=())[source]#
Inserts a CustomCallOp into the XLA graph with custom SPMD lowering rules.
```python
@custom_partitioning
def f(*args):
    return ...

def propagate_user_sharding(mesh, user_shape):
    '''Update the sharding of the op from a user's shape.sharding.'''
    user_sharding = jax.tree.map(lambda x: x.sharding, user_shape)
    return user_sharding

def partition(mesh, arg_shapes, result_shape):
    def lower_fn(*args):
        ...  # builds computation on per-device shapes
    result_shardings = jax.tree.map(lambda x: x.sharding, result_shape)
    arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
    # result_shardings and arg_shardings may optionally be modified and the
    # partitioner will insert collectives to reshape.
    return mesh, lower_fn, result_shardings, arg_shardings

def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
    '''Compute the result sharding from the sharding of the operands.'''
    arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
    ...  # compute and return a NamedSharding for the result

f.def_partition(partition,
                propagate_user_sharding,
                infer_sharding_from_operands=infer_sharding_from_operands,
                sharding_rule='i j -> i j')
```
The arguments to def_partition are as follows:

- propagate_user_sharding: Callable which takes the sharding of a user (in the DAG) and returns a suggestion for a new NamedSharding. The default is None. A trivial implementation is to just return the input sharding.
- partition: Callable which takes the SPMD suggested partitioned shapes and partition specs and returns the mesh, a per-shard lowering function, and the final input and output sharding specs (the SPMD partitioner will repartition the inputs to match). The mesh is returned to allow configuring axis_names for collectives when no mesh is provided.
- infer_sharding_from_operands: Callable which computes an output NamedSharding from the NamedSharding chosen for each argument.
- decode_shardings: When set to True, convert input GSPMDShardings to NamedSharding if possible. This may not be possible if the user does not provide a contextual mesh.
- sharding_rule: an SdyShardingRule object, an Einsum-like notation string describing the sharding rule, or a Callable that produces either of these. We call the index labels of the Einsum notation the factors of the sharding rule. Borrowing from einops.rearrange strings, factors are separated by spaces and factor names may contain multiple letters. By default, factors correspond to passthrough/elementwise dimensions; factors for other kinds of dimensions are specified via the keyword arguments described below. See jax-shardy-guide for more details and examples.
- reduction_factors: a tuple of strings specifying the reduction factors of a string sharding_rule. Reduction factors correspond to dimensions that appear in the operands but not in the result, such as the contracting dimensions of a matrix multiplication. If a reduction factor is sharded, the result needs an all-reduce along the corresponding axes.
- need_replication_factors: a tuple of strings specifying the need_replication factors of a string sharding_rule. These correspond to dimensions that cannot be sharded because the implementation does not support it.
- permutation_factors: a tuple of strings specifying the permutation factors of a string sharding_rule. Permutation factors correspond to dimensions that trigger a collective permute if they are sharded.
- factor_sizes: a dictionary of variadic keyword arguments specifying the sizes of factors that are used only in compound factors of a string sharding_rule.
When config.use_shardy_partitioner.value is True, sharding_rule is used; otherwise, propagate_user_sharding and infer_sharding_from_operands are used.
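To illustrate the Einsum-like notation (space-separated, possibly multi-letter factor names; factors that appear in the operands but not in the result are reduction factors), here is a minimal, hypothetical parser sketch — not JAX's actual rule parser:

```python
def parse_rule(rule):
    """Split an Einsum-like sharding rule into operand and result factor tuples.

    Illustrative sketch only -- not JAX's real parser. Handles space-separated
    factor names, comma-separated operands, and an '->' separator.
    """
    lhs, rhs = rule.split('->')
    operands = [tuple(term.split()) for term in lhs.split(',')]
    results = [tuple(term.split()) for term in rhs.split(',')]
    return operands, results

# A matmul-like rule: the factor 'k' appears in the operands but not in the
# result, so it would be declared via reduction_factors=('k',).
ops, res = parse_rule('i k, k j -> i j')
print(ops)  # [('i', 'k'), ('k', 'j')]
print(res)  # [('i', 'j')]
```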
Positional arguments can be specified as static using static_argnums. JAX uses inspect.signature(fun) to resolve these positional arguments.

Example
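How static_argnums interacts with inspect.signature can be sketched in plain Python (the function f and its argument names below are made up for illustration):

```python
import inspect

def f(x, y, n):
    """A stand-in for the function passed to custom_partitioning."""
    return x

# JAX resolves positional arguments via the function's signature, roughly:
sig = inspect.signature(f)
names = list(sig.parameters)
print(names)  # ['x', 'y', 'n']

# With static_argnums=(2,), the argument named 'n' would be treated as static.
static_argnums = (2,)
print([names[i] for i in static_argnums])  # ['n']
```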
For example, assume we want to enhance the existing jax.numpy.fft.fft. This function computes the discrete Fourier transform of an N-dimensional input along its last dimension, and is batched along the first N-1 dimensions. By default, however, it ignores the sharding of the input and gathers the input onto all devices. Since jax.numpy.fft.fft is batched along the first N-1 dimensions, this is unnecessary. We will create a new my_fft op that, instead, does not alter the sharding along the first N-1 dimensions, and only gathers the input along the last dimension when needed.

```python
import jax
from jax.sharding import NamedSharding
from jax.experimental.custom_partitioning import custom_partitioning
from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P
from jax.sharding import Mesh
from jax.numpy.fft import fft
import regex as re
import numpy as np

# Pattern to detect all-gather or dynamic-slice in the generated HLO
_PATTERN = '(dynamic-slice|all-gather)'

# For an N-D input, keeps sharding along the first N-1 dimensions
# but replicates along the last dimension
def supported_sharding(sharding, shape):
    rank = len(shape.shape)
    max_shared_dims = min(len(sharding.spec), rank-1)
    names = tuple(sharding.spec[:max_shared_dims]) + tuple(None for _ in range(rank - max_shared_dims))
    return NamedSharding(sharding.mesh, P(*names))

def partition(mesh, arg_shapes, result_shape):
    result_shardings = jax.tree.map(lambda x: x.sharding, result_shape)
    arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
    return mesh, fft, \
        supported_sharding(arg_shardings[0], arg_shapes[0]), \
        (supported_sharding(arg_shardings[0], arg_shapes[0]),)

def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
    arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
    return supported_sharding(arg_shardings[0], arg_shapes[0])

@custom_partitioning
def my_fft(x):
    return fft(x)

# Use Einsum-like notation to specify the sharding rule.
my_fft.def_partition(
    infer_sharding_from_operands=infer_sharding_from_operands,
    partition=partition,
    sharding_rule='...i -> ...i')

# Use an SdyShardingRule object to specify the sharding rule.
my_fft.def_partition(
    infer_sharding_from_operands=infer_sharding_from_operands,
    partition=partition,
    sharding_rule=SdyShardingRule(operand_mappings=((BATCHING, 'i'),),
                                  result_mappings=((BATCHING, 'i'),)))
```
Now create a 2D array sharded along the first axis, pass it through my_fft, and see that it is still sharded as expected and identical to the output of fft. However, inspecting the HLO (with lower(x).compile().runtime_executable().hlo_modules()) reveals that my_fft does not create any all-gather or dynamic-slice, while fft does.

```python
with Mesh(np.array(jax.devices()), ('x',)):
    x = np.asarray(np.random.randn(32*1024, 1024), dtype=np.complex64)
    y = pjit(lambda x: x, in_shardings=None, out_shardings=P('x'))(x)
    pjit_my_fft = pjit(my_fft, in_shardings=P('x'), out_shardings=P('x'))
    pjit_fft = pjit(fft, in_shardings=P('x'), out_shardings=P('x'))
    print(pjit_my_fft(y))
    print(pjit_fft(y))
    # dynamic-slice or all-gather are not present in the HLO for my_fft, because x is a 2D array
    assert(re.search(_PATTERN, pjit_my_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is None)
    # dynamic-slice or all-gather are present in the HLO for fft
    assert(re.search(_PATTERN, pjit_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is not None)
```
```
# my_fft
[[-38.840824 +0.j -40.649452 +11.845365j ... -1.6937828 +0.8402481j 15.999859 -4.0156755j]]

# jax.numpy.fft.fft
[[-38.840824 +0.j -40.649452 +11.845365j ... -1.6937828 +0.8402481j 15.999859 -4.0156755j]]
```
Because of the logic in supported_sharding, my_fft also works on 1-dimensional arrays. In that case, however, the HLO of my_fft does show a dynamic-slice, since the last dimension is the dimension along which the FFT is computed, and it needs to be replicated on all devices before the computation can be done.

```python
with Mesh(np.array(jax.devices()), ('x',)):
    x = np.asarray(np.random.randn(32*1024*1024), dtype=np.complex64)
    y = pjit(lambda x: x, in_shardings=None, out_shardings=P('x'))(x)
    pjit_my_fft = pjit(my_fft, in_shardings=P('x'), out_shardings=P('x'))
    pjit_fft = pjit(fft, in_shardings=P('x'), out_shardings=P('x'))
    print(pjit_my_fft(y))
    print(pjit_fft(y))
    # dynamic-slice or all-gather are present in the HLO for my_fft, because x is a 1D array
    assert(re.search(_PATTERN, pjit_my_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is not None)
    # dynamic-slice or all-gather are present in the HLO for fft
    assert(re.search(_PATTERN, pjit_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is not None)
```
```
# my_fft
[ 7.217285 +0.j -3012.4937 +4287.635j -405.83594 +3042.984j ... 1422.4502 +7271.4297j -405.84033 -3042.983j -3012.4963 -4287.6343j]

# jax.numpy.fft.fft
[ 7.217285 +0.j -3012.4937 +4287.635j -405.83594 +3042.984j ... 1422.4502 +7271.4297j -405.84033 -3042.983j -3012.4963 -4287.6343j]
```
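The 2D-vs-1D behaviour follows directly from the spec truncation inside supported_sharding, which can be mimicked with plain tuples (a sketch only; a PartitionSpec behaves like a tuple of axis names here, and truncate_spec is a hypothetical helper, not a JAX API):

```python
def truncate_spec(spec, rank):
    """Keep sharding on the first rank-1 dims; replicate the last dim.

    A plain-tuple sketch of the spec arithmetic in supported_sharding;
    truncate_spec itself is a hypothetical helper, not a JAX API.
    """
    max_shared_dims = min(len(spec), rank - 1)
    return tuple(spec[:max_shared_dims]) + (None,) * (rank - max_shared_dims)

# 2D input sharded along 'x': the batch dimension stays sharded.
print(truncate_spec(('x',), rank=2))  # ('x', None)
# 1D input sharded along 'x': the only dimension is the FFT dimension, so it
# must be replicated -- this is what forces the dynamic-slice in the HLO.
print(truncate_spec(('x',), rank=1))  # (None,)
```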