jax.experimental.pallas.mosaic_gpu
模块#
Pallas 针对 H100 的实验性 GPU 后端。
这些 API 非常不稳定,每周都可能发生变化。使用风险自负。
类#
|
|
|
|
|
Mosaic GPU 编译器参数。 |
|
一个枚举。 |
|
一个枚举。 |
|
|
|
表示内存引用的平铺变换。 |
|
转置一个平铺的 memref。 |
|
函数#
|
到达给定的屏障。 |
|
等待给定的屏障。 |
提交对 SMEM 的所有写入,使其对加载、TMA 和 WGMMA 可见。 |
|
|
将 GMEM 引用异步复制到 SMEM 引用。 |
|
将 SMEM 引用异步复制到 GMEM 引用。 |
|
创建一个函数以在 Pallas 内核中发出手动流水线。 |
|
转换给定数组的布局。 |
|
设置一个 warp 拥有的最大寄存器数。 |
|
等待直到正在进行的 SMEM->GMEM 复制不超过 |
|
在给定引用上执行异步 warp 组矩阵乘法累加。 |
|
等待直到正在进行的 WGMMA 操作不超过 |
别名#
|
|
|