Grids 和 BlockSpecs#

grid,即循环中的内核#

在使用 jax.experimental.pallas.pallas_call() 时,内核函数会在不同的输入上执行多次,具体方式由传递给 pallas_callgrid 参数指定。从概念上讲:

pl.pallas_call(some_kernel, grid=(n,))(...)

映射到

for i in range(n):
  some_kernel(...)

Grids 可以推广到多维,对应于嵌套循环。例如:

pl.pallas_call(some_kernel, grid=(n, m))(...)

等价于

for i in range(n):
  for j in range(m):
    some_kernel(...)

这可以推广到任何整数元组(长度为 d 的 grid 将对应于 d 个嵌套循环)。内核执行的次数为 prod(grid) 次。默认的 grid 值 () 会导致内核调用一次。每一次这样的调用都被称为一个“程序”(program)。要访问当前内核正在执行的是哪个程序(即 grid 中的哪个元素),我们使用 jax.experimental.pallas.program_id()。例如,对于调用 (1, 2)program_id(axis=0) 返回 1,而 program_id(axis=1) 返回 2。你还可以使用 jax.experimental.pallas.num_programs() 来获取给定轴的 grid 大小。

请参阅 Grids 示例 以获取使用此 API 的简单内核。

BlockSpec,即如何对输入进行分块#

结合 grid 参数,我们需要向 Pallas 提供关于如何为每次调用切分输入的信息。具体来说,我们需要提供从“循环迭代”到“要操作的输入和输出块”之间的映射。这是通过 jax.experimental.pallas.BlockSpec 对象来提供的。

在深入研究 BlockSpec 的细节之前,建议先回顾一下 Pallas 快速入门中的 Block specs 示例

BlockSpec 通过 in_specsout_specs 提供给 pallas_call,每个输入和输出各对应一个。

首先,我们讨论 indexing_mode == pl.Blocked()BlockSpec 的语义。

非正式地讲,BlockSpecindex_map 将调用索引(数量与 grid 元组的长度相同)作为参数,并返回**块索引**(整体数组的每个轴对应一个块索引)。然后,将每个块索引乘以 block_shape 中相应的轴大小,以获得相应数组轴上的实际元素索引。

注意

并非所有块形状都受支持。

  • 在 TPU 上,仅支持秩至少为 1 的块。此外,块形状的最后两个维度必须等于整体数组的相应维度,或者分别能被 8 和 128 整除。对于秩为 1 的块,块维度必须等于数组维度,或者必须是 1024 的倍数,或者是 2 的幂且至少为 128 * (32 / bitwidth(dtype))

  • 在 GPU 上,使用 Mosaic GPU 后端时,块的大小不受限制。然而,由于硬件限制,最内层数组维度的大小必须是 16 字节的倍数。例如,如果输入是 jnp.float16,则它必须是 8 的倍数。

  • 在 GPU 上,使用 Triton 后端时,块的大小不受限制,但每个操作(包括加载或存储)必须针对大小为 2 的幂的数组进行操作。

如果块形状不能均匀整除整体形状,则每个轴上的最后一次迭代仍将接收对 block_shape 大小的块的引用,但越界的元素在输入时会被填充,在输出时会被丢弃。填充的值未指定,你应该假定它们是垃圾数据。在 interpret=True 模式下,我们会使用 NaN 进行填充,以便用户有机会发现越界访问元素的情况,但这种行为不应作为依赖项。请注意,每个块中至少有一个元素必须在边界内。

更准确地说,形状为 x_shape 的输入 x 的每个轴的切片计算方式如下面的 slice_for_invocation 函数所示:

>>> import jax
>>> from jax.experimental import pallas as pl
>>> def slices_for_invocation(x_shape: tuple[int, ...],
...                           x_spec: pl.BlockSpec,
...                           grid: tuple[int, ...],
...                           invocation_indices: tuple[int, ...]) -> tuple[slice, ...]:
...   assert len(invocation_indices) == len(grid)
...   assert all(0 <= i < grid_size for i, grid_size in zip(invocation_indices, grid))
...   block_indices = x_spec.index_map(*invocation_indices)
...   assert len(x_shape) == len(x_spec.block_shape) == len(block_indices)
...   elem_indices = []
...   for x_size, block_size, block_idx in zip(x_shape, x_spec.block_shape, block_indices):
...     start_idx = block_idx * block_size
...     # At least one element of the block must be within bounds
...     assert start_idx < x_size
...     elem_indices.append(slice(start_idx, start_idx + block_size))
...   return elem_indices

例如

>>> slices_for_invocation(x_shape=(100, 100),
...                       x_spec = pl.BlockSpec((10, 20), lambda i, j: (i, j)),
...                       grid = (10, 5),
...                       invocation_indices = (2, 4))
[slice(20, 30, None), slice(80, 100, None)]

>>> # Same shape of the array and blocks, but we iterate over each block 4 times
>>> slices_for_invocation(x_shape=(100, 100),
...                       x_spec = pl.BlockSpec((10, 20), lambda i, j, k: (i, j)),
...                       grid = (10, 5, 4),
...                       invocation_indices = (2, 4, 0))
[slice(20, 30, None), slice(80, 100, None)]

>>> # An example when the block is partially out-of-bounds in the 2nd axis.
>>> slices_for_invocation(x_shape=(100, 90),
...                       x_spec = pl.BlockSpec((10, 20), lambda i, j: (i, j)),
...                       grid = (10, 5),
...                       invocation_indices = (2, 4))
[slice(20, 30, None), slice(80, 100, None)]

下面定义的 show_program_ids 函数使用 Pallas 来显示调用索引。iota_2D_kernel 将用十进制数字填充每个输出块,其中第一位数字代表第一个轴上的调用索引,第二位数字代表第二个轴上的调用索引:

>>> def show_program_ids(x_shape, block_shape, grid,
...                      index_map=lambda i, j: (i, j)):
...   def program_ids_kernel(o_ref):  # Fill the output block with 10*program_id(1) + program_id(0)
...     axes = 0
...     for axis in range(len(grid)):
...       axes += pl.program_id(axis) * 10**(len(grid) - 1 - axis)
...     o_ref[...] = jnp.full(o_ref.shape, axes)
...   res = pl.pallas_call(program_ids_kernel,
...                        out_shape=jax.ShapeDtypeStruct(x_shape, dtype=np.int32),
...                        grid=grid,
...                        in_specs=[],
...                        out_specs=pl.BlockSpec(block_shape, index_map),
...                        interpret=True)()
...   print(res)

例如

>>> show_program_ids(x_shape=(8, 6), block_shape=(2, 3), grid=(4, 2),
...                  index_map=lambda i, j: (i, j))
[[ 0  0  0  1  1  1]
 [ 0  0  0  1  1  1]
 [10 10 10 11 11 11]
 [10 10 10 11 11 11]
 [20 20 20 21 21 21]
 [20 20 20 21 21 21]
 [30 30 30 31 31 31]
 [30 30 30 31 31 31]]

>>> # An example with out-of-bounds accesses
>>> show_program_ids(x_shape=(7, 5), block_shape=(2, 3), grid=(4, 2),
...                  index_map=lambda i, j: (i, j))
[[ 0  0  0  1  1]
 [ 0  0  0  1  1]
 [10 10 10 11 11]
 [10 10 10 11 11]
 [20 20 20 21 21]
 [20 20 20 21 21]
 [30 30 30 31 31]]

>>> # It is allowed for the shape to be smaller than block_shape
>>> show_program_ids(x_shape=(1, 2), block_shape=(2, 3), grid=(1, 1),
...                  index_map=lambda i, j: (i, j))
[[0 0]]

当多个调用写入输出数组的相同元素时,结果取决于平台。

在下面的示例中,我们有一个 3D grid,其中最后一个 grid 维度未在块选择中使用(index_map=lambda i, j, k: (i, j))。因此,我们对同一个输出块进行了 10 次迭代。下面显示的输出是在 CPU 上使用 interpret=True 模式生成的,该模式目前按顺序执行调用。在 TPU 上,程序以并行和顺序的组合方式执行,并且此函数会生成所示的输出。请参阅 值得注意的特性和限制

>>> show_program_ids(x_shape=(8, 6), block_shape=(2, 3), grid=(4, 2, 10),
...                  index_map=lambda i, j, k: (i, j))
[[  9   9   9  19  19  19]
 [  9   9   9  19  19  19]
 [109 109 109 119 119 119]
 [109 109 109 119 119 119]
 [209 209 209 219 219 219]
 [209 209 209 219 219 219]
 [309 309 309 319 319 319]
 [309 309 309 319 319 319]]

block_shape 中作为维度值出现的 None 值的行为类似于值 1,区别在于相应的块轴会被压缩(你也可以传入 pl.Squeezed() 来代替 None)。在下面的示例中,观察到当块形状指定为 (None, 2) 时,o_ref 的形状是 (2,)(前导维度被压缩了)。

>>> def kernel(o_ref):
...   assert o_ref.shape == (2,)
...   o_ref[...] = jnp.full((2,), 10 * pl.program_id(1) + pl.program_id(0))
>>> pl.pallas_call(kernel,
...                jax.ShapeDtypeStruct((3, 4), dtype=np.int32),
...                out_specs=pl.BlockSpec((None, 2), lambda i, j: (i, j)),
...                grid=(3, 2), interpret=True)()
Array([[ 0,  0, 10, 10],
       [ 1,  1, 11, 11],
       [ 2,  2, 12, 12]], dtype=int32)

当我们构建 BlockSpec 时,我们可以在 block_shape 参数中使用值 None,在这种情况下,整体数组的形状将被用作 block_shape。如果我们对 index_map 参数使用值 None,则会使用返回零元组的默认索引映射函数:index_map=lambda *invocation_indices: (0,) * len(block_shape)

>>> show_program_ids(x_shape=(4, 4), block_shape=None, grid=(2, 3),
...                  index_map=None)
[[12 12 12 12]
 [12 12 12 12]
 [12 12 12 12]
 [12 12 12 12]]

>>> show_program_ids(x_shape=(4, 4), block_shape=(4, 4), grid=(2, 3),
...                  index_map=None)
[[12 12 12 12]
 [12 12 12 12]
 [12 12 12 12]
 [12 12 12 12]]

“element”(元素)索引模式#

上述行为适用于默认的“blocked”(块)索引模式。当 block_shape 元组中使用整数时(例如 (4, 8)),它等同于传入 pl.Blocked(block_size) 对象(例如 (pl.Blocked(4), pl.Blocked(8)))。Blocked 索引模式意味着 index_map 返回的索引是*块索引*。我们可以传入 pl.Blocked 以外的对象来改变 index_map 的语义,最显著的是 pl.Element(block_size)。当使用 pl.Element 索引模式时,索引映射函数返回的值直接用作数组索引,而无需先按块大小进行缩放。使用 pl.Element 模式时,你可以将数组的虚拟填充指定为维度上低-高填充的元组:其行为就像整体数组在输入时进行了填充。在 element 模式下,填充值不作保证,类似于块形状不能整除整体数组形状时 blocked 索引模式下的填充值。

Element 模式目前仅在 TPU 上受支持。

>>> # element without padding
>>> show_program_ids(x_shape=(8, 6), block_shape=(pl.Element(2), pl.Element(3)),
...                  grid=(4, 2),
...                  index_map=lambda i, j: (2*i, 3*j))
    [[ 0  0  0  1  1  1]
     [ 0  0  0  1  1  1]
     [10 10 10 11 11 11]
     [10 10 10 11 11 11]
     [20 20 20 21 21 21]
     [20 20 20 21 21 21]
     [30 30 30 31 31 31]
     [30 30 30 31 31 31]]

>>> # element, first pad the array with 1 row and 2 columns.
>>> show_program_ids(x_shape=(7, 7),
...                  block_shape=(pl.Element(2, (1, 0)),
...                               pl.Element(3, (2, 0))),
...                  grid=(4, 3),
...                  index_map=lambda i, j: (2*i, 3*j))
    [[ 0  1  1  1  2  2  2]
     [10 11 11 11 12 12 12]
     [10 11 11 11 12 12 12]
     [20 21 21 21 22 22 22]
     [20 21 21 21 22 22 22]
     [30 31 31 31 32 32 32]
     [30 31 31 31 32 32 32]]