jax.experimental.pallas.triton.CompilerParams#

jax.experimental.pallas.triton.CompilerParams(num_warps=None, num_stages=None)[源]#

Triton 的编译器参数。

参数:
  • num_warps (int | None)

  • num_stages (int | None)

num_warps#

用于内核的 warp 数量。每个 warp 包含 32 个线程。

类型:

int | None

num_stages#

编译器用于软件流水线循环的阶段数量。

类型:

int | None

__init__(num_warps=None, num_stages=None)#
参数:
  • num_warps (int | None)

  • num_stages (int | None)

返回类型:

方法

__init__([num_warps, num_stages])

属性

后端

num_stages

num_warps