设备本地数组布局控制#

jax.experimental.layout 包提供了控制 JAX 数组在设备本地内存中如何布局的方法。

术语#

数组布局与数组分片紧密耦合。布局和分片一起完整地描述了数组的值如何在(分布式)内存中布局。沿着这些思路,我们使用以下术语

  • 布局:数组的值如何在它们所在的每个内存中布局(例如,在单个设备内存的内存中)。典型的布局规范是从最小到最大的数组维度顺序列表。

  • 分片:数组的值如何不同的内存空间分布,例如多个设备内存(例如,通过分片某些维度和复制其他维度来描述)。

  • 格式布局分片的配对,提供数组内存放置的完整图像。

类型#

控制数组布局时,会出现两种 Python 类型:LayoutFormat

  • Layout 类用于定义数组的内存布局。它具有以下关键属性

    • major_to_minor:一个整数元组,指定内存中的维度顺序。例如,对于二维数组,(0, 1) 表示行主序布局,(1, 0) 表示列主序布局。

    • _tiling:一个有意隐藏的、高度实验性的可选属性,用于指定平铺布局。

    • AUTO:一个特殊的静态 sentinel 对象,可以与 jax.jit 一起使用,以请求编译器自动确定已编译函数的输入或输出数组的良好布局。

  • Format 类包含 LayoutSharding,当未指定时,两者都会采用默认值。显式指定布局时,也必须指定分片。

JAX API 函数(例如 jax.jitjax.device_put)接受用于分片控制的 Sharding 或用于附加布局控制的 Format。它们通常不直接接受 Layout 实例。

指定和读取布局#

通过将 Format 对象传递给 jax.jit 以代替分片(在 in_shardingsout_shardings 参数中),您可以指导编译器的布局决策。类似地,您可以将 Format 传递给 jax.device_put 而不是 Sharding 以控制结果数组的布局。

让我们看一个同时使用显式和自动布局的示例(如 Layout.AUTO 中)。假设我们有两个已编译的函数,init_fnapply_fn。假设我们期望 init_fn 大致调用一次,但 apply_fninit_fn 的输出上调用多次,因此我们更关心 apply_fn 的性能。我们可能希望编译器为 apply_fn 选择一个好的布局,并约束 init_fn 生成这种布局的数组。我们可以这样做

import jax, jax.numpy as jnp
from jax.experimental.layout import Layout, Format
from jax.sharding import SingleDeviceSharding
import numpy as np

def init_fn(x, y):
  return x * 2, y * 3

def apply_fn(x, y):
  return x[0, :], y[:, 0]

由于 apply_fn 读取其第二个参数 y 的连续列,因此以列主序布局它是有意义的(其中列是连续存储的)。使用 Layout.AUTO,我们可以要求编译器推断良好的输入布局,并看到它确实选择以列主序布局请求第二个参数。

shape = (4 * 128, 8 * 128)
duck = jax.ShapeDtypeStruct(shape, jnp.float32)

# Compile the `apply` function with layouts inferred automatically
apply_exe = jax.jit(
    apply_fn,
    in_shardings=Format(Layout.AUTO),
    out_shardings=Format(Layout.AUTO),
).trace(duck, duck).lower().compile()

# Read back the inferred input layout
arg_formats, kwarg_formats = apply_exe.input_formats
assert len(kwarg_formats) == 0
assert arg_formats[0].layout.major_to_minor == (0, 1)
assert arg_formats[1].layout.major_to_minor == (1, 0)

然后我们可以编译 init_fn 以显式匹配其输出中的此布局。

init_exe = jax.jit(init_fn, out_shardings=arg_formats).trace(
    duck, duck).lower().compile()

assert init_exe.output_formats == arg_formats

最后,我们可以看到使用不同布局的输入数组调用编译后的 apply_fn 时的行为。行为随输入是否已提交而变化。正如以下测试所证明的那样,如果参数数组已提交,则预编译的 apply_fn 要求它们与上述编译器确定的布局匹配。同时,它接受任何布局的未提交数组(当然,包括推断的布局)。在这种情况下,数组可能会在调用编译后的计算之前重新布局。

def test(x, y, msg):
  print(f'-- {msg}:')
  print('x major_to_minor =', x.format.layout.major_to_minor)
  print('y major_to_minor =', y.format.layout.major_to_minor)
  try:
    apply_exe(x, y)
    print('-> `apply` called successfully')
  except ValueError as e:
    assert 'does not match' in str(e)
    print('-> error: mismatched input layouts')
  print()

dev = jax.devices()[0]

x1 = y1 = jnp.ones(shape)
test(x1, y1, 'uncommitted with mismatched layout')

x2, y2 = init_exe(x1, y1)
test(x2, y2, 'uncommitted with matching layout')

x3 = jnp.ones(shape)
y3 = jax.device_put(np.ones(shape), Format(Layout(major_to_minor=(1, 0)),
                                           SingleDeviceSharding(dev)))
test(x3, y3, 'committed with matching layout')

x4 = jnp.ones(shape)
y4 = jax.device_put(np.ones(shape), Format(Layout(major_to_minor=(0, 1)),
                                           SingleDeviceSharding(dev)))
test(x4, y4, 'committed with mismatched layout')
-- uncommitted with mismatched layout:
x major_to_minor = (0, 1)
y major_to_minor = (0, 1)
-> `apply` called successfully

-- uncommitted with matching layout:
x major_to_minor = (0, 1)
y major_to_minor = (1, 0)
-> `apply` called successfully

-- committed with matching layout:
x major_to_minor = (0, 1)
y major_to_minor = (1, 0)
-> `apply` called successfully

-- committed with mismatched layout:
x major_to_minor = (0, 1)
y major_to_minor = (0, 1)
-> error: mismatched input layouts

约束中间布局#

我们还可以使用 with_layout_constraint 在 JIT 编译的函数中对中间值强制执行特定布局

from jax.experimental.layout import with_layout_constraint

@jax.jit
def f(x):
  y = x.T
  # Enforce a specific layout on `y`
  y = with_layout_constraint(y, Layout(major_to_minor=(0, 1)))
  return y * 2

这类似于 jax.lax.with_sharding_constraint,用于约束布局而不是分片。