jax.experimental.pallas.triton.TritonCompilerParams#

class jax.experimental.pallas.triton.TritonCompilerParams(num_warps=None, num_stages=None, serialized_metadata=None)[source]#

Triton 的编译器参数。

参数:
  • num_warps (int | None)

  • num_stages (int | None)

  • serialized_metadata (bytes | None)

num_warps#

内核使用的 warp 数量。每个 warp 由 32 个线程组成。

类型:

int | None

num_stages#

编译器用于软件流水线循环的阶段数。

类型:

int | None

serialized_metadata#

额外的编译器元数据。此字段不稳定,将来可能会被移除。

类型:

bytes | None

__init__(num_warps=None, num_stages=None, serialized_metadata=None)#
参数:
  • num_warps (int | None | None)

  • num_stages (int | None | None)

  • serialized_metadata (bytes | None | None)

返回类型:

None

方法

__init__([num_warps, num_stages, ...])

属性

PLATFORM

num_stages

num_warps

serialized_metadata