jax.experimental.pallas.triton.TritonCompilerParams#
- class jax.experimental.pallas.triton.TritonCompilerParams(num_warps=None, num_stages=None, serialized_metadata=None)[source]#
Triton 的编译器参数。
- 参数:
num_warps (int | None)
num_stages (int | None)
serialized_metadata (bytes | None)
- num_warps#
内核使用的 warp 数量。每个 warp 由 32 个线程组成。
- 类型:
int | None
- num_stages#
编译器用于软件流水线循环的阶段数。
- 类型:
int | None
- serialized_metadata#
额外的编译器元数据。此字段不稳定,将来可能会被移除。
- 类型:
bytes | None
- __init__(num_warps=None, num_stages=None, serialized_metadata=None)#
- 参数:
num_warps (int | None | None)
num_stages (int | None | None)
serialized_metadata (bytes | None | None)
- 返回类型:
None
方法
__init__
([num_warps, num_stages, ...])属性
PLATFORM