jax.experimental.pallas.triton.CompilerParams# class jax.experimental.pallas.triton.CompilerParams(num_warps=None, num_stages=None)[源代码]# Triton 的编译器参数。 参数: num_warps (int | None) num_stages (int | None) num_warps# 用于核函数的 warp 数量。每个 warp 由 32 个线程组成。 类型: int | None num_stages# 编译器在软件流水线循环中应使用的阶段数。 类型: int | None __init__(num_warps=None, num_stages=None)# 参数: num_warps (int | None) num_stages (int | None) 返回类型: 无 方法 __init__([num_warps, num_stages]) 属性 BACKEND num_stages num_warps