jax.experimental.pallas.mosaic_gpu.nd_loop#
- jax.experimental.pallas.mosaic_gpu.nd_loop(grid: Sequence[int], *, collective_axes: Sequence[Hashable] | Hashable, tiling: Sequence[int] | None = None, init_carry: None = None) Callable[[Callable[[NDLoopInfo], None]], None] [源代码]#
- jax.experimental.pallas.mosaic_gpu.nd_loop(grid: Sequence[int], *, collective_axes: Sequence[Hashable] | Hashable, tiling: Sequence[int] | None = None, init_carry: _T) Callable[[Callable[[NDLoopInfo, _T], _T]], _T]
沿给定轴分区的多维网格上的循环。
循环体接受一个名为
loop_info
的参数,该参数是一个包含索引和迭代信息的 NDLoopInfo 对象。但是,如果指定了 carry,则循环体将期望一个名为 carry 的第二个关键字参数,其中包含循环 carry。例如,如果
collective_axes
是"x"
,且lax.axis_size()
等于 4,网格为 (2, 3),则实现将产生以下迭代顺序循环步
索引
轴索引
0
(0, 0)
0
1
(0, 1)
1
2
(0, 2)
2
3
(1, 0)
3
4
(1, 1)
0
5
(1, 2)
1
这是通过根据
"x"
轴索引以交错方式将平坦的迭代空间划分为块来实现的。请注意,在此示例中,总循环步数不能被
"x"
的轴大小整除,因此对于某些"x"
轴索引,循环将少迭代一次。轴索引
indices
0
(0, 0), (1, 1)
1
(0, 1), (1, 2)
2
(0, 2)
3
(1, 0)
如果传递了
init_carry
,则nd_loop()
将期望循环体接收并返回 carry。如果它是None
,则不期望 carry 参数。另请参阅
jax.experimental.pallas.loop()
:单维循环。