jax.experimental.pallas.mosaic_gpu.emit_pipeline#
- jax.experimental.pallas.mosaic_gpu.emit_pipeline(body, *, grid, in_specs=(), out_specs=(), max_concurrent_steps=1, init_carry=None)[源代码]#
创建一个函数,用于在 Pallas 内核中发出手动流水线。
- 参数:
body (Callable[..., T]) –
流水线主体函数,该函数将使用以下参数调用:
indices: 当前循环索引的元组。*input_refs: 输入的 SMEM 引用。*output_refs: 输出的 SMEM 引用。
如果提供了
init_carry,则body将接收一个额外的参数carry– 来自前一次迭代的 carry。然后它必须返回下一个 carry 值。grid (pallas_core.TupleGrid) – 流水线的网格尺寸。
in_specs (Sequence[pallas_core.BlockSpec]) – 输入的
BlockSpec序列。out_specs (Sequence[pallas_core.BlockSpec]) – 输出的
BlockSpec序列。max_concurrent_steps (int) – 最大并发活动流水线阶段数。
init_carry (T | None) – 可选的初始 carry。如果提供,
body将处理迭代之间的 carry 状态,流水线将返回最终的 carry。
- 返回:
一个函数,当使用 GMEM 输入和输出引用调用时,它将执行流水线并返回最终的 carry 值(如果使用了
init_carry),否则返回 None。