jax.experimental.pallas.mosaic_gpu.async_load_tmem#
- jax.experimental.pallas.mosaic_gpu.async_load_tmem(src, *, layout=None)[source]#
从 TMEM 数组执行异步加载。
加载操作仅部分异步。返回的数组可以立即使用,无需任何额外的同步。然而,不能假定函数返回时 TMEM 的读取已完成。如果您尝试覆盖已读取的区域,则应确保在此操作发生之前已调用
wait_load_tmem
。否则可能导致不确定的数据竞争。例如,即使 TMEM 加载从未被等待,以下在内核末尾的操作序列也是有效的
smem_ref[...] = plgpu.async_load_tmem(tmem_ref) plgpu.commit_smem() plgpu.copy_smem_to_gmem(smem_ref, gmem_ref) plgpu.wait_smem_to_gmem(0)
然而,如果内核是持久的并且可能再次重用 TMEM,则应通过调用
wait_load_tmem
来扩展该序列。- 参数:
src (_Ref) – 要从中加载的 TMEM 引用。
layout (SomeLayout | None) – 用于结果数组的可选布局提示。
- 返回类型: