jax.experimental.pallas.mosaic_gpu.TransposeTransform# class jax.experimental.pallas.mosaic_gpu.TransposeTransform(permutation)[source]# 转置分块的 memref。 参数: permutation (tuple[int, ...]) __init__(permutation)# 参数: permutation (tuple[int, ...]) 返回类型: None 方法 __init__(permutation) batch(leading_rank) 返回一个转换,它接受带有额外 leading_rank 维度的 ref。 to_gpu_transform() undo(ref) 属性 permutation