jax.experimental.pallas.mosaic_gpu.multimem_store# jax.experimental.pallas.mosaic_gpu.multimem_store(source, ref, collective_axes)[源]# 将值存储到 collective_axes 中所有设备上的 ref。 存储操作使用 multimem 指令完成,这意味着数据仅传输一次到交换机,然后在交换机上广播到所有其他设备。 参数: source (jax.Array) – 要存储的值。 ref (_Ref) – 要存储值的 GMEM 引用。 collective_axes (Hashable | tuple[Hashable, ...]) – 指示要存储的设备的 JAX mesh 轴。