jax.experimental.pallas.mosaic_gpu.emit_pipeline#

jax.experimental.pallas.mosaic_gpu.emit_pipeline(body, *, grid, in_specs=(), out_specs=(), max_concurrent_steps=1, delay_release=0)[源代码]#

创建一个函数,在 Pallas 内核中发出手动流水线。

参数:
  • body (Callable[..., None]) – 流水线主体。

  • grid (pallas_core.StaticGrid) – 用于流水线的网格。

  • in_specs (Sequence[pallas_core.BlockSpec]) – 输入的块规范。

  • out_specs (Sequence[pallas_core.BlockSpec]) – 输出的块规范。

  • max_concurrent_steps (int) – 同时处于活动状态的最大顺序阶段数。默认为 1。

  • delay_release (int) – 在重用输入/输出引用之前等待的步数。默认为 0,且必须严格小于 max_concurrent_steps。通常,如果你不在主体中等待 WGMMA,则应将其设置为 1。