jax.experimental.pallas.mosaic_gpu.async_load_tmem#

jax.experimental.pallas.mosaic_gpu.async_load_tmem(src, *, layout=None)[source]#

从 TMEM 数组执行异步加载。

加载操作仅部分异步。返回的数组可以立即使用,无需任何额外的同步。然而,不能假定函数返回时 TMEM 的读取已完成。如果您尝试覆盖已读取的区域,则应确保在此操作发生之前已调用 wait_load_tmem。否则可能导致不确定的数据竞争。

例如,即使 TMEM 加载从未被等待,以下在内核末尾的操作序列也是有效的

smem_ref[...] = plgpu.async_load_tmem(tmem_ref)
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(smem_ref, gmem_ref)
plgpu.wait_smem_to_gmem(0)

然而,如果内核是持久的并且可能再次重用 TMEM,则应通过调用 wait_load_tmem 来扩展该序列。

参数:
  • src (_Ref) – 要从中加载的 TMEM 引用。

  • layout (SomeLayout | None) – 用于结果数组的可选布局提示。

返回类型:

jax.Array