jax.experimental.pallas.mosaic_gpu 模块#

针对 H100 的 Pallas 实验性 GPU 后端。

这些 API 极不稳定,每周都可能更改。使用风险自负。

#

Barrier(*[, num_arrivals, num_barriers, ...])

描述一个屏障引用。

BlockSpec([block_shape, index_map, ...])

CompilerParams(*[, approx_math, ...])

Mosaic GPU 编译器参数。

MemorySpace(value[, names, module, ...])

Layout(value[, names, module, qualname, ...])

SwizzleTransform(swizzle)

TilingTransform(tiling)

表示用于内存引用的平铺变换。

TransposeTransform(permutation)

转置一个平铺的 memref。

WGMMAAccumulatorRef(shape, dtype, _init)

函数#

barrier_arrive(barrier)

到达给定屏障。

barrier_wait(barrier)

等待给定屏障。

commit_smem()

提交所有对 SMEM 的写入,使其对 TMA 和 MMA 操作可见。

copy_gmem_to_smem(src, dst, barrier, *[, ...])

将 GMEM 引用异步复制到 SMEM 引用。

copy_smem_to_gmem(src, dst[, predicate, ...])

将 SMEM 引用异步复制到 GMEM 引用。

emit_pipeline(body, *, grid[, in_specs, ...])

创建一个函数,用于在 Pallas 内核中发出手动管道。

layout_cast(x, new_layout)

转换给定数组的布局。

set_max_registers(n, *, action)

设置 warp 拥有的最大寄存器数量。

wait_smem_to_gmem(n[, wait_read_only])

等待直到飞行中的 SMEM->GMEM 复制数量不多于 n

wgmma(acc, a, b)

对给定引用执行异步 warp 组矩阵乘法累加。

wgmma_wait(n)

等待直到飞行中的 WGMMA 操作数量不多于 n

别名#

ACC

别名 WGMMAAccumulatorRef

GMEM

别名 jax.experimental.pallas.mosaic_gpu.MemorySpace.GMEM

SMEM

别名 jax.experimental.pallas.mosaic_gpu.MemorySpace.SMEM