设备本地数组布局控制#
jax.experimental.layout 包提供了控制 JAX 数组在设备本地内存中布局方式的方法。
术语#
数组布局与数组 分片 (sharding) 紧密耦合。布局和分片共同完整地描述了数组的值如何在(分布式)内存中分布。基于此,我们使用以下术语:
布局 (Layout):指数组值在其所驻留的每个内存(例如,单个设备内存)中的排列方式。典型的布局规范是数组维度的从次要到主要 (minor-to-major) 的排序列表。
分片 (Sharding):指数组的值如何跨不同的内存空间分布,例如跨多个设备内存(例如通过对某些维度进行分片并复制其他维度来描述)。
格式 (Format):布局和分片的配对,提供了数组内存放置的完整图景。
类型#
在控制数组布局时会用到两种 Python 类型:Layout 和 Format。
Layout类用于定义数组的内存中布局。它具有以下关键属性:major_to_minor:一个整数元组,用于指定内存中的维度排序。例如,对于二维数组,(0, 1)表示行优先布局,(1, 0)表示列优先布局。_tiling:一个刻意隐藏的、高度实验性的可选属性,用于指定平铺 (tiled) 布局。AUTO:一个特殊的静态哨兵对象,可与jax.jit一起使用,请求编译器自动为编译函数的输入或输出数组确定最佳布局。
Format类同时携带Layout和Sharding,当未指定时,两者均采用默认值。当显式指定布局时,也必须同时指定分片。
JAX API 函数(如 jax.jit 和 jax.device_put)接受用于分片控制的 Sharding,或用于额外布局控制的 Format。它们通常不直接接受 Layout 实例。
指定和读取布局#
通过将 Format 对象传递给 jax.jit 来代替分片(在 in_shardings 和 out_shardings 参数中),您可以指导编译器的布局决策。同样,您可以将 Format 代替 Sharding 传递给 jax.device_put,以控制结果数组的布局。
让我们看一个同时使用显式布局和自动布局(如 Layout.AUTO)的示例。假设我们有两个编译函数 init_fn 和 apply_fn。假设 init_fn 大约只被调用一次,而 apply_fn 会在 init_fn 的输出上多次调用,因此我们更关心 apply_fn 的性能。我们可能希望让编译器为 apply_fn 选择一个良好的布局,并约束 init_fn 生成具有该布局的数组。我们可以按照以下方式实现:
import jax, jax.numpy as jnp
from jax.experimental.layout import Layout, Format
from jax.sharding import SingleDeviceSharding
import numpy as np
def init_fn(x, y):
return x * 2, y * 3
def apply_fn(x, y):
return x[0, :], y[:, 0]
由于 apply_fn 读取其第二个参数 y 的连续列,因此将其布局为列优先(列连续存储)是有意义的。使用 Layout.AUTO,我们可以请求编译器推断出良好的输入布局,并观察到它确实请求了列优先布局的第二个参数。
shape = (4 * 128, 8 * 128)
duck = jax.ShapeDtypeStruct(shape, jnp.float32)
# Compile the `apply` function with layouts inferred automatically
apply_exe = jax.jit(
apply_fn,
in_shardings=Format(Layout.AUTO),
out_shardings=Format(Layout.AUTO),
).trace(duck, duck).lower().compile()
# Read back the inferred input layout
arg_formats, kwarg_formats = apply_exe.input_formats
assert len(kwarg_formats) == 0
assert arg_formats[0].layout.major_to_minor == (0, 1)
assert arg_formats[1].layout.major_to_minor == (1, 0)
然后,我们可以编译 init_fn,使其输出显式匹配该布局。
init_exe = jax.jit(init_fn, out_shardings=arg_formats).trace(
duck, duck).lower().compile()
assert init_exe.output_formats == arg_formats
最后,我们可以观察编译后的 apply_fn 在使用不同布局的输入数组调用时的行为。其行为取决于输入是否已 提交 (committed)。正如以下测试所示,如果参数数组已提交,则预编译的 apply_fn 要求它们匹配上述由编译器确定的布局。同时,它接受任何布局的未提交数组(当然包括推断出的布局)。在这种情况下,数组可以在调用编译计算之前进行重布局。
def test(x, y, msg):
print(f'-- {msg}:')
print('x major_to_minor =', x.format.layout.major_to_minor)
print('y major_to_minor =', y.format.layout.major_to_minor)
try:
apply_exe(x, y)
print('-> `apply` called successfully')
except ValueError as e:
assert 'does not match' in str(e)
print('-> error: mismatched input layouts')
print()
dev = jax.devices()[0]
x1 = y1 = jnp.ones(shape)
test(x1, y1, 'uncommitted with mismatched layout')
x2, y2 = init_exe(x1, y1)
test(x2, y2, 'uncommitted with matching layout')
x3 = jnp.ones(shape)
y3 = jax.device_put(np.ones(shape), Format(Layout(major_to_minor=(1, 0)),
SingleDeviceSharding(dev)))
test(x3, y3, 'committed with matching layout')
x4 = jnp.ones(shape)
y4 = jax.device_put(np.ones(shape), Format(Layout(major_to_minor=(0, 1)),
SingleDeviceSharding(dev)))
test(x4, y4, 'committed with mismatched layout')
-- uncommitted with mismatched layout:
x major_to_minor = (0, 1)
y major_to_minor = (0, 1)
-> `apply` called successfully
-- uncommitted with matching layout:
x major_to_minor = (0, 1)
y major_to_minor = (1, 0)
-> `apply` called successfully
-- committed with matching layout:
x major_to_minor = (0, 1)
y major_to_minor = (1, 0)
-> `apply` called successfully
-- committed with mismatched layout:
x major_to_minor = (0, 1)
y major_to_minor = (0, 1)
-> error: mismatched input layouts
约束中间布局#
我们还可以使用 with_layout_constraint 在 JIT 编译函数内强制执行特定的布局:
from jax.experimental.layout import with_layout_constraint
@jax.jit
def f(x):
y = x.T
# Enforce a specific layout on `y`
y = with_layout_constraint(y, Layout(major_to_minor=(0, 1)))
return y * 2
这类似于 jax.lax.with_sharding_constraint,但用于约束布局而非分片。