jax.experimental.pallas.mosaic_gpu 模块#

Pallas 针对 H100 的实验性 GPU 后端。

这些 API 非常不稳定,每周都可能发生变化。请自行承担使用风险。

#

Barrier(*[, num_arrivals, num_barriers, ...])

描述一个屏障引用。

BlockSpec([block_shape, index_map, ...])

GPU 特定的 BlockSpec

CompilerParams(*[, approx_math, ...])

Mosaic GPU 编译器参数。

MemorySpace(value[, names, module, ...])

Layout(value[, names, module, qualname, ...])

SwizzleTransform(swizzle)

TilingTransform(tiling)

表示内存引用的平铺转换。

TransposeTransform(permutation)

转置一个平铺的 memref。

WGMMAAccumulatorRef(shape, dtype, _init)

函数#

as_torch_kernel(fn)

使 Mosaic GPU 内核能够使用 PyTorch 张量进行调用。

kernel(body, out_shape, *, scratch_shapes, ...)

layout_cast(x, new_layout)

转换给定数组的布局。

set_max_registers(n, *, action)

设置 warp 所拥有的最大寄存器数。

planar_snake(lin_idx, shape, minor_dim, ...)

将线性索引转换为 shape 中的索引,并尝试优化局部性。

类循环函数#

emit_pipeline(body, *, grid[, in_specs, ...])

创建一个函数,用于在 Pallas 内核中发出手动流水线。

emit_pipeline_warp_specialized(body, *, ...)

创建一个函数,用于发出 warp 特定的流水线。

nd_loop()

沿给定轴分区的多维网格上的循环。

dynamic_scheduling_loop()

使用动态工作调度在程序实例上进行循环。

同步#

barrier_arrive(barrier)

到达给定屏障。

barrier_wait(barrier)

等待给定屏障。

semaphore_signal_parallel(*signals)

在不保证信号到达顺序的情况下,对多个信号量发出信号。

SemaphoreSignal(ref, *, device_id[, inc])

异步复制#

commit_smem()

提交所有写入 SMEM 的操作,使其对 TMA 和 MMA 操作可见。

copy_gmem_to_smem(src, dst, barrier, *[, ...])

异步地将 GMEM 引用复制到 SMEM 引用。

copy_smem_to_gmem(src, dst[, predicate, ...])

异步地将 SMEM 引用复制到 GMEM 引用。

wait_smem_to_gmem(n[, wait_read_only])

等待直到调用线程发出的 SMEM->GMEM 复制不再超过 n 个正在进行中。

Hopper 特定函数#

wgmma(acc, a, b)

在给定的引用上执行异步 warp 组矩阵乘法累加。

wgmma_wait(n)

等待直到正在进行的 WGMMA 操作不超过 n 个。

Blackwell 特定函数#

tcgen05_mma(acc, a, b[, barrier, a_scale, ...])

TensorCore gen 5 (Blackwell) 的异步矩阵乘法累加。

tcgen05_commit_arrive(barrier[, collective_axis])

跟踪先前 tcgen05_mma 调用完成情况。

async_load_tmem(src, *, layout)

执行 TMEM 数组的异步加载。

async_store_tmem(ref, value)

将值存储到 TMEM。

wait_load_tmem()

等待调用线程发出的所有先前异步 TMEM 加载。

commit_tmem()

提交当前线程发出的所有写入 TMEM 的操作。

try_cluster_cancel(result_ref, barrier)

发起一个异步请求,以从网格中领取一个新的工作单元。

query_cluster_cancel(result_ref, grid_names)

解码 try_cluster_cancel 操作的结果。

多内存操作#

multimem_store(source, ref, collective_axes)

将值存储到 collective_axes 中存在的所有设备上的 ref。

multimem_load_reduce(ref, *, ...)

从 collective_axes 中存在的所有设备上的 GMEM 引用加载并规约加载的值。

别名#

ACC

别名 WGMMAAccumulatorRef

GMEM

别名 jax.experimental.pallas.mosaic_gpu.MemorySpace.GMEM

SMEM

别名 jax.experimental.pallas.mosaic_gpu.MemorySpace.SMEM