jax.experimental.pallas.mosaic_gpu.TilingTransform#

class jax.experimental.pallas.mosaic_gpu.TilingTransform(tiling)[源代码]#

表示内存引用的平铺变换。

在形状为 (M, N) 的数组上对 (X, Y) 进行平铺将导致变换后的形状为 (M // X, N // Y, X, Y)。 例如,一个 (256, 256) 的块,用 (64, 32) 的平铺进行平铺,将被平铺为 (4, 8, 64, 32)。

参数:

tiling (tuple[int, ...])

__init__(tiling)#
参数:

tiling (tuple[int, ...])

返回类型:

方法

__init__(tiling)

batch(leading_rank)

返回一个转换,它接受带有额外 leading_rank 维度的 ref。

to_gpu_transform()

to_gpu_transform_attr()

undo(ref)

属性

tiling