jax.experimental.pallas.mosaic_gpu.multimem_store#

jax.experimental.pallas.mosaic_gpu.multimem_store(source, ref, collective_axes)[源]#

将值存储到 collective_axes 中所有设备上的 ref。

存储操作使用 multimem 指令完成,这意味着数据仅传输一次到交换机,然后在交换机上广播到所有其他设备。

参数:
  • source (jax.Array) – 要存储的值。

  • ref (_Ref) – 要存储值的 GMEM 引用。

  • collective_axes (Hashable | tuple[Hashable, ...]) – 指示要存储的设备的 JAX mesh 轴。