jax.experimental.pallas.mosaic_gpu.tcgen05_mma#

jax.experimental.pallas.mosaic_gpu.tcgen05_mma(acc, a, b, barrier=None, *, a_scale=None, b_scale=None, a_sparse_metadata=None, accumulate=True, collective_axis=None)[源代码]#

TensorCore gen 5(Blackwell)的异步矩阵乘累加。

如果在集体模式下运行,则 acca(LHS)和 b(RHS)应对应于 MMA 输入总数的一半,其中 acca(LHS)沿行分割成两半,而 b(RHS)沿列分割,如下所示:

-----------    -----------   -----------
|  ACC1   |    |  LHS1   |   |    |    |
----------- += ----------- @ |RHS1|RHS2|
|  ACC2   |    |  LHS2   |   |    |    |
-----------    -----------   -----------

要使用块缩放矩阵乘法,请提供 a_scaleb_scale 操作数(必须同时存在或同时未指定)。

参数:
  • acc (_Ref) – 累加器。必须是 TMEM Ref。

  • a (_Ref) – 左侧操作数。必须是 TMEM/SMEM Ref。

  • b (_Ref) – 右侧操作数。必须是 SMEM Ref。

  • barrier (_Ref | None) – 用于与 TensorCore 同步的可选 barrier Ref。必须将 orders_tensor_core 设置为 True。如果未指定,则应通过调用 jax.experimental.pallas.mosaic_gpu.tcgen05_commit_arrive() 来显式观察 MMA 的完成。

  • a_scale (_Ref | None) – a 操作数的可选缩放。如果存在,则必须是 TMEM Ref。

  • b_scale (_Ref | None) – b 操作数的可选缩放。如果存在,则必须是 TMEM Ref。

  • a_sparse_metadata (_Ref | None) – a 操作数的可选稀疏元数据。如果存在,则必须是 TMEM Ref。

  • accumulate (bool | jax.Array) – 是累加到 acc 还是覆盖它。

  • collective_axis (str | None) – 要执行集体 MMA 的集群轴的名称。集群轴的大小应正好为 2,并且必须位于最次的集群轴上。