设备本地数组布局控制#
jax.experimental.layout 包提供了控制 JAX 数组如何在设备本地内存中布局的方法。
术语#
数组布局与数组的 分片 紧密耦合。布局和分片共同完整地描述了数组的值如何在(分布式)内存中布局。在此基础上,我们使用以下术语:
布局 (Layout):数组的值如何在它们驻留的每个内存中布局(例如,在单个设备内存的内存中)。典型的布局规范是数组维度的次序到主序的列表。
分片 (Sharding):数组的值如何在不同的内存空间(例如,多个设备内存)之间 分布(例如,通过分片某些维度并复制其他维度来描述)。
格式 (Format):布局 和 分片 的配对,提供了数组内存放置的完整图景。
类型#
控制数组布局时会遇到两种 Python 类型:Layout 和 Format。
Layout类用于定义数组的内存布局。它具有以下主要属性:major_to_minor:一个整数元组,指定内存中的维度顺序。例如,对于一个二维数组,(0, 1)表示行主序布局,(1, 0)表示列主序。_tiling:一个故意隐藏的、高度实验性的、可选的属性,用于指定分块布局。AUTO:一个特殊的、静态的哨兵对象,可以与jax.jit一起使用,以请求编译器自动确定已编译函数的输入或输出数组的良好布局。
Format类同时承载Layout和Sharding,其中一个在未指定时会采用默认值。当显式指定布局时,分片也必须被指定。
JAX API 函数,例如 jax.jit 和 jax.device_put,接受 Shardings 用于分片控制或 Formats 用于附加的布局控制。它们通常不直接接受 Layout 实例。
指定和读取布局#
通过将 Format 对象传递给 jax.jit 的 in_shardings 和 out_shardings 参数(代替 shardings),您可以指导编译器的布局决策。类似地,您也可以将 Formats 代替 Shardings 传递给 jax.device_put 来控制结果数组的布局。
让我们看一个同时使用显式布局和自动布局(如 Layout.AUTO)的例子。假设我们有两个已编译的函数,init_fn 和 apply_fn。我们预期 init_fn 只会被调用一次,而 apply_fn 将多次调用 init_fn 的输出,因此我们更关心 apply_fn 的性能。我们可以让编译器为 apply_fn 选择一个良好的布局,并约束 init_fn 以产生具有该布局的数组。我们可以这样做:
import jax, jax.numpy as jnp
from jax.experimental.layout import Layout, Format
from jax.sharding import SingleDeviceSharding
import numpy as np
def init_fn(x, y):
return x * 2, y * 3
def apply_fn(x, y):
return x[0, :], y[:, 0]
由于 apply_fn 读取其第二个参数 y 的连续列,因此将其布局为列主序(其中列是连续存储的)是有意义的。使用 Layout.AUTO,我们可以要求编译器推断良好的输入布局,并看到它确实选择请求第二个参数为列主序布局。
shape = (4 * 128, 8 * 128)
duck = jax.ShapeDtypeStruct(shape, jnp.float32)
# Compile the `apply` function with layouts inferred automatically
apply_exe = jax.jit(
apply_fn,
in_shardings=Format(Layout.AUTO),
out_shardings=Format(Layout.AUTO),
).trace(duck, duck).lower().compile()
# Read back the inferred input layout
arg_formats, kwarg_formats = apply_exe.input_formats
assert len(kwarg_formats) == 0
assert arg_formats[0].layout.major_to_minor == (0, 1)
assert arg_formats[1].layout.major_to_minor == (1, 0)
然后,我们可以编译 init_fn 以显式匹配其输出的此布局。
init_exe = jax.jit(init_fn, out_shardings=arg_formats).trace(
duck, duck).lower().compile()
assert init_exe.output_formats == arg_formats
最后,我们可以看到当用不同布局的输入数组调用已编译的 apply_fn 时其行为如何。行为取决于输入是否被 提交 (committed)。如下面的测试所示,如果参数数组被提交,则预编译的 apply_fn 要求它们匹配上面编译器确定的布局。同时,它接受任何布局的未提交数组(当然,包括推断的布局)。在这种情况下,数组可以在调用已编译的计算之前重新布局。
def test(x, y, msg):
print(f'-- {msg}:')
print('x major_to_minor =', x.format.layout.major_to_minor)
print('y major_to_minor =', y.format.layout.major_to_minor)
try:
apply_exe(x, y)
print('-> `apply` called successfully')
except ValueError as e:
assert 'does not match' in str(e)
print('-> error: mismatched input layouts')
print()
dev = jax.devices()[0]
x1 = y1 = jnp.ones(shape)
test(x1, y1, 'uncommitted with mismatched layout')
x2, y2 = init_exe(x1, y1)
test(x2, y2, 'uncommitted with matching layout')
x3 = jnp.ones(shape)
y3 = jax.device_put(np.ones(shape), Format(Layout(major_to_minor=(1, 0)),
SingleDeviceSharding(dev)))
test(x3, y3, 'committed with matching layout')
x4 = jnp.ones(shape)
y4 = jax.device_put(np.ones(shape), Format(Layout(major_to_minor=(0, 1)),
SingleDeviceSharding(dev)))
test(x4, y4, 'committed with mismatched layout')
-- uncommitted with mismatched layout:
x major_to_minor = (0, 1)
y major_to_minor = (0, 1)
-> `apply` called successfully
-- uncommitted with matching layout:
x major_to_minor = (0, 1)
y major_to_minor = (1, 0)
-> `apply` called successfully
-- committed with matching layout:
x major_to_minor = (0, 1)
y major_to_minor = (1, 0)
-> `apply` called successfully
-- committed with mismatched layout:
x major_to_minor = (0, 1)
y major_to_minor = (0, 1)
-> error: mismatched input layouts
约束中间布局#
我们还可以使用 with_layout_constraint 来强制 JIT 编译函数内的中间值具有特定的布局。
from jax.experimental.layout import with_layout_constraint
@jax.jit
def f(x):
y = x.T
# Enforce a specific layout on `y`
y = with_layout_constraint(y, Layout(major_to_minor=(0, 1)))
return y * 2
这类似于 jax.lax.with_sharding_constraint,用于约束布局而不是分片。