jax.experimental.pallas.mosaic_gpu.copy_gmem_to_smem#

jax.experimental.pallas.mosaic_gpu.copy_gmem_to_smem(src, dst, barrier, *, collective_axes=None, partitioned_axis=None)[源文件]#

将GMEM引用异步复制到SMEM引用。

如果指定了 collective_axes,这将执行多播复制,其中所有沿 collective axis 共享相同索引的CUDA块都将接收从 dst 加载到 src 的相同数据块的副本。

如果同时指定了 collective_axes 和 partitioned_axis,这将执行分区集体复制,其中集群中的每个块都将从 src Ref 接收 transfer_size // cluster_size 大小的数据块。例如,如果 src 的形状为 (256, 256),并且沿轴 0 执行集群大小为 2 的分区复制,则第一个块将接收 src[0:128, :],第二个块将接收 src[128:256, :]。注意:只有集群中的第一个块会到达屏障,并且需要额外的集群屏障来确保集群中的所有块都已完成复制。

参数:
  • src (_Ref) – 源引用。必须位于GMEM中。

  • dst (_Ref) – 目标引用。必须位于SMEM中。

  • barrier (_Ref) – 用于跟踪复制完成情况的屏障。

  • collective_axes (str | tuple[str, ...] | None) – 用于复制的集体轴。

  • partitioned_axis (int | None) – 指示在分区集体复制期间沿 src/dst Refs 分区的数组轴。需要同时指定 collective_axes。

返回类型:

另请参阅

jax.experimental.mosaic.gpu.barrier_arrive() jax.experimental.mosaic.gpu.barrier_wait()