jax.experimental.pallas.triton 模块#

Triton 特定的 Pallas API。

#

TritonCompilerParams([num_warps, ...])

Triton 的编译器参数。

函数#

approx_tanh(x)

逐元素近似双曲正切:\(\mathrm{tanh}(x)\)

debug_barrier()

同步网格中的所有内核执行。

elementwise_inline_asm(asm, *, args, ...)

应用逐元素操作的内联汇编。