jax.experimental.pallas.triton 模块#

Triton 专用 Pallas API。

#

CompilerParams([num_warps, num_stages])

Triton 的编译器参数。

函数#

atomic_and(x_ref_or_view, idx, val, *[, mask])

原子地计算 x_ref_or_view[idx] &= val

atomic_add(x_ref_or_view, idx, val, *[, mask])

原子地计算 x_ref_or_view[idx] += val

atomic_cas(ref, cmp, val)

对引用中的值执行原子比较并交换操作。

atomic_max(x_ref_or_view, idx, val, *[, mask])

原子地计算 x_ref_or_view[idx] = max(x_ref_or_view[idx], val)

atomic_min(x_ref_or_view, idx, val, *[, mask])

原子地计算 x_ref_or_view[idx] = min(x_ref_or_view[idx], val)

atomic_or(x_ref_or_view, idx, val, *[, mask])

原子地计算 x_ref_or_view[idx] |= val

atomic_xchg(x_ref_or_view, idx, val, *[, mask])

原子地交换给定值与给定索引处的值。

atomic_xor(x_ref_or_view, idx, val, *[, mask])

原子地计算 x_ref_or_view[idx] ^= val

approx_tanh(x)

逐元素的近似双曲正切:\(\mathrm{tanh}(x)\)

debug_barrier()

同步网格中的所有内核执行。

elementwise_inline_asm(asm, *, args, ...)

应用逐元素操作的内联汇编。