jax.experimental.pallas.mosaic_gpu.emit_pipeline#

jax.experimental.pallas.mosaic_gpu.emit_pipeline(body, *, grid, in_specs=(), out_specs=(), max_concurrent_steps=1, init_carry=None)[源代码]#

创建一个函数,用于在 Pallas 内核中发出手动流水线。

参数:
  • body (Callable[..., T]) –

    流水线主体函数,该函数将使用以下参数调用:

    • indices: 当前循环索引的元组。

    • *input_refs: 输入的 SMEM 引用。

    • *output_refs: 输出的 SMEM 引用。

    如果提供了 init_carry,则 body 将接收一个额外的参数 carry – 来自前一次迭代的 carry。然后它必须返回下一个 carry 值。

  • grid (pallas_core.TupleGrid) – 流水线的网格尺寸。

  • in_specs (Sequence[pallas_core.BlockSpec]) – 输入的 BlockSpec 序列。

  • out_specs (Sequence[pallas_core.BlockSpec]) – 输出的 BlockSpec 序列。

  • max_concurrent_steps (int) – 最大并发活动流水线阶段数。

  • init_carry (T | None) – 可选的初始 carry。如果提供,body 将处理迭代之间的 carry 状态,流水线将返回最终的 carry。

返回:

一个函数,当使用 GMEM 输入和输出引用调用时,它将执行流水线并返回最终的 carry 值(如果使用了 init_carry),否则返回 None。