jax.experimental.pallas.mosaic_gpu.dynamic_scheduling_loop#
- jax.experimental.pallas.mosaic_gpu.dynamic_scheduling_loop(grid_names: Sequence[Hashable], *, thread_axis: Hashable | None = None, init_carry: None = None) Callable[[Callable[[NDLoopInfo], None]], None] [源码]#
- jax.experimental.pallas.mosaic_gpu.dynamic_scheduling_loop(grid_names: Sequence[Hashable], *, thread_axis: Hashable | None = None, init_carry: _T) Callable[[Callable[[NDLoopInfo, _T], _T]], _T]
使用动态工作调度遍历程序实例的循环。
此循环将遍历可用的程序实例,直到所有工作都已调度。内核应使用等于要完成的逻辑工作量的网格进行实例化(而不是使用设置为核心数量的持久性内核)。运行此循环的每个核心将连续查询下一个可用工作块,并且当整个网格都已调度时,循环将终止。
示例用法
@plgpu.dynamic_scheduling_loop(grid_names) def body(loop_info): work(loop_info.index) # do work...
- 参数:
grid_names – 网格中轴的名称。
thread_axis – 线程轴的名称。如果内核使用多个线程,则必须传递此参数。
init_carry – 循环的可选初始载体。如果传递,则 body 函数应期望一个
carry
关键字参数并返回下一个载体值。