jax.experimental.pallas.mosaic_gpu.emit_pipeline#
- jax.experimental.pallas.mosaic_gpu.emit_pipeline(body, *, grid, in_specs=(), out_specs=(), max_concurrent_steps=1, delay_release=0)[源代码]#
创建一个函数,在 Pallas 内核中发出手动流水线。
- 参数:
body (Callable[..., None]) – 流水线主体。
grid (pallas_core.StaticGrid) – 用于流水线的网格。
in_specs (Sequence[pallas_core.BlockSpec]) – 输入的块规范。
out_specs (Sequence[pallas_core.BlockSpec]) – 输出的块规范。
max_concurrent_steps (int) – 同时处于活动状态的最大顺序阶段数。默认为 1。
delay_release (int) – 在重用输入/输出引用之前等待的步数。默认为 0,且必须严格小于
max_concurrent_steps
。通常,如果你不在主体中等待 WGMMA,则应将其设置为 1。