jax.experimental.pallas.mosaic_gpu.dynamic_scheduling_loop#

jax.experimental.pallas.mosaic_gpu.dynamic_scheduling_loop(grid_names: Sequence[Hashable], *, thread_axis: Hashable | None = None, init_carry: None = None) Callable[[Callable[[NDLoopInfo], None]], None][源码]#
jax.experimental.pallas.mosaic_gpu.dynamic_scheduling_loop(grid_names: Sequence[Hashable], *, thread_axis: Hashable | None = None, init_carry: _T) Callable[[Callable[[NDLoopInfo, _T], _T]], _T]

使用动态工作调度遍历程序实例的循环。

此循环将遍历可用的程序实例,直到所有工作都已调度。内核应使用等于要完成的逻辑工作量的网格进行实例化(而不是使用设置为核心数量的持久性内核)。运行此循环的每个核心将连续查询下一个可用工作块,并且当整个网格都已调度时,循环将终止。

示例用法

@plgpu.dynamic_scheduling_loop(grid_names)
def body(loop_info):
  work(loop_info.index)  # do work...
参数:
  • grid_names – 网格中轴的名称。

  • thread_axis – 线程轴的名称。如果内核使用多个线程,则必须传递此参数。

  • init_carry – 循环的可选初始载体。如果传递,则 body 函数应期望一个 carry 关键字参数并返回下一个载体值。