jax.experimental.pallas.mosaic_gpu
模块#
Pallas 针对 H100 的实验性 GPU 后端。
这些 API 非常不稳定,可能每周都会更改。使用风险自负。
类#
|
|
|
|
|
Mosaic GPU 编译器参数。 |
|
一个枚举。 |
|
一个枚举。 |
|
|
|
表示内存引用的平铺变换。 |
|
转置平铺的内存引用。 |
|
函数#
|
到达给定的屏障。 |
|
等待给定的屏障。 |
提交对 SMEM 的所有写入,使其对加载、TMA 和 WGMMA 可见。 |
|
|
异步地将 GMEM 引用复制到 SMEM 引用。 |
|
异步地将 SMEM 引用复制到 GMEM 引用。 |
|
创建一个函数,用于在 Pallas 内核中发出手动流水线。 |
|
转换给定数组的布局。 |
|
设置一个 warp 拥有的最大寄存器数量。 |
|
等待直到不超过 |
|
在给定的引用上执行异步 warp 组矩阵乘法-累加。 |
|
等待直到不超过 |
别名#
|
|
|