jax.experimental.pallas.mosaic_gpu.copy_gmem_to_smem#
- jax.experimental.pallas.mosaic_gpu.copy_gmem_to_smem(src, dst, barrier, *, collective_axes=None, partitioned_axis=None)[源文件]#
将GMEM引用异步复制到SMEM引用。
如果指定了 collective_axes,这将执行多播复制,其中所有沿 collective axis 共享相同索引的CUDA块都将接收从 dst 加载到 src 的相同数据块的副本。
如果同时指定了 collective_axes 和 partitioned_axis,这将执行分区集体复制,其中集群中的每个块都将从 src Ref 接收 transfer_size // cluster_size 大小的数据块。例如,如果 src 的形状为 (256, 256),并且沿轴 0 执行集群大小为 2 的分区复制,则第一个块将接收 src[0:128, :],第二个块将接收 src[128:256, :]。注意:只有集群中的第一个块会到达屏障,并且需要额外的集群屏障来确保集群中的所有块都已完成复制。
- 参数:
- 返回类型:
无
另请参阅
jax.experimental.mosaic.gpu.barrier_arrive()
jax.experimental.mosaic.gpu.barrier_wait()