jax.experimental.pallas.mosaic_gpu.multimem_load_reduce#

jax.experimental.pallas.mosaic_gpu.multimem_load_reduce(ref, *, collective_axes, reduction_op)[源代码]#

从 collective_axes 中所有设备的 GMEM 引用加载数据并对加载的值进行归约。

支持的数据类型为:jnp.float32jnp.float16jnp.bfloat16jnp.float8_e5m2jnp.float8_e4m3fnjnp.int32jnp.int64

8 位浮点数据类型仅在 Blackwell GPU 上支持。

参数:
  • ref (_Ref) – 要从中加载数据的 GMEM 引用。

  • collective_axes (Hashable | tuple[Hashable, ...]) – 指示要从中加载数据的设备的 JAX mesh 轴。

  • reduction_op (mgpu.MultimemReductionOp) – 对加载的值执行的归约操作。允许的值包括 add(所有数据类型)、min、max(所有数据类型,但 f32 除外),以及 and、or 和 xor(仅限整数类型)。

返回类型:

jax.Array