jax.experimental.pallas.mosaic_gpu 模块#

Pallas 针对 H100 的实验性 GPU 后端。

这些 API 非常不稳定,可能每周都会更改。使用风险自负。

#

Barrier(num_arrivals[, num_barriers])

GPUBlockSpec([block_shape, index_map, ...])

GPUCompilerParams(*[, approx_math, ...])

Mosaic GPU 编译器参数。

GPUMemorySpace(value)

一个枚举。

Layout(value)

一个枚举。

SwizzleTransform(swizzle)

TilingTransform(tiling)

表示内存引用的平铺变换。

TransposeTransform(permutation)

转置平铺的内存引用。

WGMMAAccumulatorRef(shape, dtype, _init)

函数#

barrier_arrive(barrier)

到达给定的屏障。

barrier_wait(barrier)

等待给定的屏障。

commit_smem()

提交对 SMEM 的所有写入,使其对加载、TMA 和 WGMMA 可见。

copy_gmem_to_smem(src, dst, barrier, *[, ...])

异步地将 GMEM 引用复制到 SMEM 引用。

copy_smem_to_gmem(src, dst[, predicate, ...])

异步地将 SMEM 引用复制到 GMEM 引用。

emit_pipeline(body, *, grid[, in_specs, ...])

创建一个函数,用于在 Pallas 内核中发出手动流水线。

layout_cast(x, new_layout)

转换给定数组的布局。

set_max_registers(n, *, action)

设置一个 warp 拥有的最大寄存器数量。

wait_smem_to_gmem(n[, wait_read_only])

等待直到不超过 n 个 SMEM->GMEM 副本在飞行中。

wgmma(acc, a, b)

在给定的引用上执行异步 warp 组矩阵乘法-累加。

wgmma_wait(n)

等待直到不超过 n 个 WGMMA 操作在飞行中。

别名#

ACC

别名 WGMMAAccumulatorRef

GMEM

jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.GMEM 的别名。

SMEM

jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.SMEM 的别名。