从 collective_axes 中所有设备的 GMEM 引用加载数据并对加载的值进行归约。
支持的数据类型为:jnp.float32
、jnp.float16
、jnp.bfloat16
、jnp.float8_e5m2
、jnp.float8_e4m3fn
、jnp.int32
和 jnp.int64
。
8 位浮点数据类型仅在 Blackwell GPU 上支持。
- 参数:
ref (_Ref) – 要从中加载数据的 GMEM 引用。
collective_axes (Hashable | tuple[Hashable, ...]) – 指示要从中加载数据的设备的 JAX mesh 轴。
reduction_op (mgpu.MultimemReductionOp) – 对加载的值执行的归约操作。允许的值包括 add(所有数据类型)、min、max(所有数据类型,但 f32 除外),以及 and、or 和 xor(仅限整数类型)。
- 返回类型:
jax.Array