jax.experimental.pallas.mosaic_gpu.emit_pipeline_warp_specialized#

jax.experimental.pallas.mosaic_gpu.emit_pipeline_warp_specialized(body, *, grid, memory_registers, in_specs=(), out_specs=(), max_concurrent_steps=2, wg_axis, num_compute_wgs, pipeline_state=None, manual_consumed_barriers=False, compute_context=None, memory_thread_idx=None)[来源]#

创建一个函数来发射一个 warp 特定的流水线。

body 函数应该具有以下签名(不带 carry)。如果 `manual_consumed_barriers` 参数为 True,则会传递一个可选参数 `consumed_barriers`。

def body(indices, *input_refs, *output_refs, *consumed_barriers) -> None:

或者启用 carry(通过 `compute_context` 参数启用),此时 body 返回下一个 carry。

def body(
    indices, *input_refs, *output_refs, *consumed_barriers, carry
) -> Carry:

当 `manual_consumed_barriers` 为 True 时,用户必须在每个流水线步骤中到达所有计算 warpgroup 的所有已消耗 barrier。

参数:
  • body (Callable[..., None]) – 流水线主体。

  • grid (pallas_core.TupleGrid) – 用于流水线的网格。

  • memory_registers (int) – 为内存线程预留的寄存器数量。对于 H100 GPU,40 是一个合理的值。

  • in_specs (BlockSpecPytree) – 输入的块规范。

  • out_specs (BlockSpecPytree) – 输出的块规范。

  • max_concurrent_steps (int) – 同时活动的顺序 stage 的最大数量。默认为 2。

  • wg_axis (str) – warp group 轴的轴名称。

  • num_compute_wgs (int) – 计算 warpgroup 的数量。

  • manual_consumed_barriers (bool) – 如果为 True,已消耗的 barrier 将在输出引用之后传递给 body 函数。每个输入会有一个 barrier,并且传递顺序与输入相同。

  • compute_context (ComputeContext | None) – 如果指定,则启用流水线中的 carry,并允许用户指定的 prologue/epilogue,该 prologue/epilogue 仅在计算线程中执行。流水线 body 函数的签名将被修改,最后一个参数是当前 carry,并且必须返回下一个 carry。compute_context 本身应遵循 ComputeContext 的签名,并将流水线函数作为其唯一参数。使用初始 carry 调用流水线将运行流水线并返回最终 carry。

  • memory_thread_idx (int | None) – 内存线程的索引。如果未指定,则默认为最后一个线程。

  • pipeline_state (jax.Array | PipelinePipeline | None) –

    如果多个具有几乎相同参数(只有 in/out_specs 和 body 可以不同)的流水线将按顺序进行评估,则可以使用此参数来避免它们调用之间的流水线气泡。序列中的第一个流水线应使用 START 状态,后跟任意数量的 STEADY 状态,后跟一个 STOP 状态。请注意,直到具有 STOP 的流水线完成之前,内存线程不会等待计算线程完成并完全消耗它们的工作。除了调用另一个流水线之外,不允许修改它们的任何操作数。

    重要提示:为了实现无气泡执行,还必须通过调用返回函数的 get_allocations 来使用手动分配模式,将结果传递给 pl.run_scoped,并将提供的结果作为 `allocations` 关键字参数传递给返回的函数。否则,流水线函数将自行执行作用域分配,这可能导致同步,从而仍然引起流水线气泡。

返回类型:

WarpSpecializedPipeline