jax.experimental.pallas.triton.CompilerParams#

class jax.experimental.pallas.triton.CompilerParams(num_warps=None, num_stages=None)[源代码]#

Triton 的编译器参数。

参数:
  • num_warps (int | None)

  • num_stages (int | None)

num_warps#

用于核函数的 warp 数量。每个 warp 由 32 个线程组成。

类型:

int | None

num_stages#

编译器在软件流水线循环中应使用的阶段数。

类型:

int | None

__init__(num_warps=None, num_stages=None)#
参数:
  • num_warps (int | None)

  • num_stages (int | None)

返回类型:

方法

__init__([num_warps, num_stages])

属性

BACKEND

num_stages

num_warps