jax.experimental.pallas.mosaic_gpu.Layout

class jax.experimental.pallas.mosaic_gpu.Layout(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
__init__(*args, **kwds)

Methods

reduce(axes)

to_mgpu(*args, **kwargs)

Attributes

WGMMA

An [m, n] matrix, where m % 64 == 0 and n % 8 == 0.
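The following sketch checks whether a 2-D shape meets the WGMMA shape rule above. `satisfies_wgmma_shape` is a hypothetical helper, not part of the Pallas API. It encodes only the documented divisibility condition and does not check dtype or any other requirement the real layout may impose.

```python
def satisfies_wgmma_shape(shape):
    """Return True if a 2-D shape (m, n) meets the documented WGMMA rule.

    Hypothetical helper: checks only m % 64 == 0 and n % 8 == 0.
    It does not call into JAX or validate dtypes.
    """
    if len(shape) != 2:
        # The WGMMA layout describes an [m, n] matrix, so other ranks fail.
        return False
    m, n = shape
    return m % 64 == 0 and n % 8 == 0


print(satisfies_wgmma_shape((128, 64)))  # True
print(satisfies_wgmma_shape((64, 12)))   # False: 12 is not a multiple of 8
print(satisfies_wgmma_shape((32, 8)))    # False: 32 is not a multiple of 64
```

For example, a (128, 64) tile qualifies, while (64, 12) fails on the column dimension.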

WGMMA_TRANSPOSED

WG_SPLAT

WG_STRIDED

TCGEN05

TCGEN05_TRANSPOSED

TCGEN05_M64_COLLECTIVE

TCGEN05_TMEM_NATIVE

WGMMA_ROW

WGMMA_COL

TCGEN05_ROW

TCGEN05_COL

TCGEN05_TMEM_NATIVE_ROW