jax.experimental.pallas.mosaic_gpu 模块#
Pallas 针对 H100 的实验性 GPU 后端。
这些 API 非常不稳定,每周都可能发生变化。请自行承担使用风险。
类#
|
描述一个屏障引用。 |
|
GPU 特定的 |
|
Mosaic GPU 编译器参数。 |
|
|
|
|
|
|
|
表示内存引用的平铺转换。 |
|
转置一个平铺的 memref。 |
|
函数#
|
使 Mosaic GPU 内核能够使用 PyTorch 张量进行调用。 |
|
|
|
转换给定数组的布局。 |
|
设置 warp 所拥有的最大寄存器数。 |
|
将线性索引转换为 shape 中的索引,并尝试优化局部性。 |
类循环函数#
|
创建一个函数,用于在 Pallas 内核中发出手动流水线。 |
|
创建一个函数,用于发出 warp 特定的流水线。 |
|
沿给定轴分区的多维网格上的循环。 |
使用动态工作调度在程序实例上进行循环。 |
同步#
|
到达给定屏障。 |
|
等待给定屏障。 |
|
在不保证信号到达顺序的情况下,对多个信号量发出信号。 |
|
异步复制#
提交所有写入 SMEM 的操作,使其对 TMA 和 MMA 操作可见。 |
|
|
异步地将 GMEM 引用复制到 SMEM 引用。 |
|
异步地将 SMEM 引用复制到 GMEM 引用。 |
|
等待直到调用线程发出的 SMEM->GMEM 复制不再超过 |
Hopper 特定函数#
|
在给定的引用上执行异步 warp 组矩阵乘法累加。 |
|
等待直到正在进行的 WGMMA 操作不超过 |
Blackwell 特定函数#
|
TensorCore gen 5 (Blackwell) 的异步矩阵乘法累加。 |
|
跟踪先前 |
|
执行 TMEM 数组的异步加载。 |
|
将值存储到 TMEM。 |
等待调用线程发出的所有先前异步 TMEM 加载。 |
|
提交当前线程发出的所有写入 TMEM 的操作。 |
|
|
发起一个异步请求,以从网格中领取一个新的工作单元。 |
|
解码 |
多内存操作#
|
将值存储到 collective_axes 中存在的所有设备上的 ref。 |
|
从 collective_axes 中存在的所有设备上的 GMEM 引用加载并规约加载的值。 |
别名#
别名 |
|
别名 |