jax.experimental.pallas.mosaic_gpu.emit_pipeline#

jax.experimental.pallas.mosaic_gpu.emit_pipeline(body, *, grid, in_specs=(), out_specs=(), max_concurrent_steps=1, delay_release=0, init_carry=None)[源]#

创建一个函数,用于在 Pallas 内核中发出手动流水线。

参数:
  • body (Callable[..., T]) ——

    流水线主体函数,其调用时带以下参数:

    • indices: 当前循环索引的元组。

    • *input_refs: 输入数据的 SMEM 引用。

    • *output_refs: 输出数据的 SMEM 引用。

    如果提供了 init_carry,则 body 会接收一个额外的参数 carry —— 来自前一次迭代的进位值。然后它必须返回下一个进位值。

  • grid (pallas_core.TupleGrid) —— 流水线的网格维度。

  • in_specs (Sequence[pallas_core.BlockSpec]) —— 输入数据的 BlockSpec 序列。

  • out_specs (Sequence[pallas_core.BlockSpec]) —— 输出数据的 BlockSpec 序列。

  • max_concurrent_steps (int) —— 最大并发活跃流水线阶段数。

  • delay_release (int) —— 在重用输入/输出引用之前延迟的步数。必须小于 max_concurrent_steps。有助于隐藏 WGMMA 延迟(通常设置为 1)。

  • init_carry (T | None) —— 可选的初始进位值。如果提供,body 会处理迭代间的进位状态,并且流水线会返回最终进位值。

返回:

一个函数,当用 GMEM 输入和输出引用调用时,执行流水线并返回最终进位值(如果使用了 init_carry),否则返回 None。