jax.experimental.pallas.mosaic_gpu.emit_pipeline_warp_specialized#
- jax.experimental.pallas.mosaic_gpu.emit_pipeline_warp_specialized(body, *, grid, memory_registers, in_specs=(), out_specs=(), max_concurrent_steps=2, wg_axis, num_compute_wgs, pipeline_state=None, manual_consumed_barriers=False, compute_context=None, memory_thread_idx=None)[来源]#
创建一个函数来发射一个 warp 特定的流水线。
body
函数应该具有以下签名(不带 carry)。如果 `manual_consumed_barriers` 参数为 True,则会传递一个可选参数 `consumed_barriers`。def body(indices, *input_refs, *output_refs, *consumed_barriers) -> None:
或者启用 carry(通过 `compute_context` 参数启用),此时 body 返回下一个 carry。
def body( indices, *input_refs, *output_refs, *consumed_barriers, carry ) -> Carry:
当 `manual_consumed_barriers` 为 True 时,用户必须在每个流水线步骤中到达所有计算 warpgroup 的所有已消耗 barrier。
- 参数:
body (Callable[..., None]) – 流水线主体。
grid (pallas_core.TupleGrid) – 用于流水线的网格。
memory_registers (int) – 为内存线程预留的寄存器数量。对于 H100 GPU,40 是一个合理的值。
in_specs (BlockSpecPytree) – 输入的块规范。
out_specs (BlockSpecPytree) – 输出的块规范。
max_concurrent_steps (int) – 同时活动的顺序 stage 的最大数量。默认为 2。
wg_axis (str) – warp group 轴的轴名称。
num_compute_wgs (int) – 计算 warpgroup 的数量。
manual_consumed_barriers (bool) – 如果为 True,已消耗的 barrier 将在输出引用之后传递给 body 函数。每个输入会有一个 barrier,并且传递顺序与输入相同。
compute_context (ComputeContext | None) – 如果指定,则启用流水线中的 carry,并允许用户指定的 prologue/epilogue,该 prologue/epilogue 仅在计算线程中执行。流水线 body 函数的签名将被修改,最后一个参数是当前 carry,并且必须返回下一个 carry。compute_context 本身应遵循 ComputeContext 的签名,并将流水线函数作为其唯一参数。使用初始 carry 调用流水线将运行流水线并返回最终 carry。
memory_thread_idx (int | None) – 内存线程的索引。如果未指定,则默认为最后一个线程。
pipeline_state (jax.Array | PipelinePipeline | None) –
如果多个具有几乎相同参数(只有 in/out_specs 和 body 可以不同)的流水线将按顺序进行评估,则可以使用此参数来避免它们调用之间的流水线气泡。序列中的第一个流水线应使用
START
状态,后跟任意数量的STEADY
状态,后跟一个STOP
状态。请注意,直到具有STOP
的流水线完成之前,内存线程不会等待计算线程完成并完全消耗它们的工作。除了调用另一个流水线之外,不允许修改它们的任何操作数。重要提示:为了实现无气泡执行,还必须通过调用返回函数的
get_allocations
来使用手动分配模式,将结果传递给pl.run_scoped
,并将提供的结果作为 `allocations` 关键字参数传递给返回的函数。否则,流水线函数将自行执行作用域分配,这可能导致同步,从而仍然引起流水线气泡。
- 返回类型:
WarpSpecializedPipeline