jax.experimental.pallas.mosaic_gpu.nd_loop#

jax.experimental.pallas.mosaic_gpu.nd_loop(grid: Sequence[int], *, collective_axes: Sequence[Hashable] | Hashable, tiling: Sequence[int] | None = None, init_carry: None = None) Callable[[Callable[[NDLoopInfo], None]], None][源代码]#
jax.experimental.pallas.mosaic_gpu.nd_loop(grid: Sequence[int], *, collective_axes: Sequence[Hashable] | Hashable, tiling: Sequence[int] | None = None, init_carry: _T) Callable[[Callable[[NDLoopInfo, _T], _T]], _T]

沿给定轴分区的多维网格上的循环。

循环体接受一个名为 loop_info 的参数,该参数是一个包含索引和迭代信息的 NDLoopInfo 对象。但是,如果指定了 carry,则循环体将期望一个名为 carry 的第二个关键字参数,其中包含循环 carry。

例如,如果 collective_axes"x",且 lax.axis_size() 等于 4,网格为 (2, 3),则实现将产生以下迭代顺序

循环步

索引

轴索引

0

(0, 0)

0

1

(0, 1)

1

2

(0, 2)

2

3

(1, 0)

3

4

(1, 1)

0

5

(1, 2)

1

这是通过根据 "x" 轴索引以交错方式将平坦的迭代空间划分为块来实现的。

请注意,在此示例中,总循环步数不能被 "x" 的轴大小整除,因此对于某些 "x" 轴索引,循环将少迭代一次。

轴索引

indices

0

(0, 0), (1, 1)

1

(0, 1), (1, 2)

2

(0, 2)

3

(1, 0)

如果传递了 init_carry,则 nd_loop() 将期望循环体接收并返回 carry。如果它是 None,则不期望 carry 参数。

另请参阅

  • jax.experimental.pallas.loop():单维循环。