jax.experimental.pallas.mosaic_gpu.SwizzleTransform#

class jax.experimental.pallas.mosaic_gpu.SwizzleTransform(swizzle: 'int')[源代码]#
参数

swizzle (int)

__init__(swizzle)#
参数

swizzle (int)

返回类型

None

方法

__init__(swizzle)

batch(leading_rank)

返回一个转换,它接受一个带有额外的leading_rank维度的引用。

to_gpu_transform()

undo(ref)

undo_to_gpu_transform()

属性

swizzle