jax.experimental.pallas.mosaic_gpu.Layout#

class jax.experimental.pallas.mosaic_gpu.Layout(value)[源代码]#

一个枚举。

__init__()#

方法

to_mgpu(*args, **kwargs)

属性

WGMMA

[m, n] 矩阵,其中 m % 64 == 0 == n % 8。

WGMMA_ROW

[m] 矩阵,其中 m % 64 == 0。

WGMMA_COL

[n] 矩阵,其中 n % 8 == 0。

WGMMA_TRANSPOSED

WG_SPLAT

WG_STRIDED