jax.experimental.pallas.triton.CompilerParams# 类 jax.experimental.pallas.triton.CompilerParams(num_warps=None, num_stages=None)[源]# Triton 的编译器参数。 参数: num_warps (int | None) num_stages (int | None) num_warps# 用于内核的 warp 数量。每个 warp 包含 32 个线程。 类型: int | None num_stages# 编译器用于软件流水线循环的阶段数量。 类型: int | None __init__(num_warps=None, num_stages=None)# 参数: num_warps (int | None) num_stages (int | None) 返回类型: 无 方法 __init__([num_warps, num_stages]) 属性 后端 num_stages num_warps