jax.experimental.pallas.mosaic_gpu
模块
针对 H100 的 Pallas 实验性 GPU 后端。
这些 API 极不稳定,每周都可能更改。使用风险自负。
类
Barrier (*[, num_arrivals, num_barriers, ...])
|
描述一个屏障引用。 |
BlockSpec ([block_shape, index_map, ...])
|
|
CompilerParams (*[, approx_math, ...])
|
Mosaic GPU 编译器参数。 |
MemorySpace (value[, names, module, ...])
|
|
Layout (value[, names, module, qualname, ...])
|
|
SwizzleTransform (swizzle)
|
|
TilingTransform (tiling)
|
表示用于内存引用的平铺变换。 |
TransposeTransform (permutation)
|
转置一个平铺的 memref。 |
WGMMAAccumulatorRef (shape, dtype, _init)
|
|
函数
barrier_arrive (barrier)
|
到达给定屏障。 |
barrier_wait (barrier)
|
等待给定屏障。 |
commit_smem ()
|
提交所有对 SMEM 的写入,使其对 TMA 和 MMA 操作可见。 |
copy_gmem_to_smem (src, dst, barrier, *[, ...])
|
将 GMEM 引用异步复制到 SMEM 引用。 |
copy_smem_to_gmem (src, dst[, predicate, ...])
|
将 SMEM 引用异步复制到 GMEM 引用。 |
emit_pipeline (body, *, grid[, in_specs, ...])
|
创建一个函数,用于在 Pallas 内核中发出手动管道。 |
layout_cast (x, new_layout)
|
转换给定数组的布局。 |
set_max_registers (n, *, action)
|
设置 warp 拥有的最大寄存器数量。 |
wait_smem_to_gmem (n[, wait_read_only])
|
等待直到飞行中的 SMEM->GMEM 复制数量不多于 n 。 |
wgmma (acc, a, b)
|
对给定引用执行异步 warp 组矩阵乘法累加。 |
wgmma_wait (n)
|
等待直到飞行中的 WGMMA 操作数量不多于 n 。 |