jax.experimental.pallas.mosaic_gpu.emit_pipeline#
- jax.experimental.pallas.mosaic_gpu.emit_pipeline(body, *, grid, in_specs=(), out_specs=(), max_concurrent_steps=1, delay_release=0, init_carry=None)[源]#
创建一个函数,用于在 Pallas 内核中发出手动流水线。
- 参数:
body (Callable[..., T]) ——
流水线主体函数,其调用时带以下参数:
indices
: 当前循环索引的元组。*input_refs
: 输入数据的 SMEM 引用。*output_refs
: 输出数据的 SMEM 引用。
如果提供了
init_carry
,则body
会接收一个额外的参数carry
—— 来自前一次迭代的进位值。然后它必须返回下一个进位值。grid (pallas_core.TupleGrid) —— 流水线的网格维度。
in_specs (Sequence[pallas_core.BlockSpec]) —— 输入数据的
BlockSpec
序列。out_specs (Sequence[pallas_core.BlockSpec]) —— 输出数据的
BlockSpec
序列。max_concurrent_steps (int) —— 最大并发活跃流水线阶段数。
delay_release (int) —— 在重用输入/输出引用之前延迟的步数。必须小于
max_concurrent_steps
。有助于隐藏 WGMMA 延迟(通常设置为 1)。init_carry (T | None) —— 可选的初始进位值。如果提供,
body
会处理迭代间的进位状态,并且流水线会返回最终进位值。
- 返回:
一个函数,当用 GMEM 输入和输出引用调用时,执行流水线并返回最终进位值(如果使用了
init_carry
),否则返回 None。