jax.experimental.pallas.mosaic_gpu.tcgen05_mma#
- jax.experimental.pallas.mosaic_gpu.tcgen05_mma(acc, a, b, barrier=None, *, a_scale=None, b_scale=None, a_sparse_metadata=None, accumulate=True, collective_axis=None)[源代码]#
TensorCore gen 5(Blackwell)的异步矩阵乘累加。
如果在集体模式下运行,则
acc
、a
(LHS)和b
(RHS)应对应于 MMA 输入总数的一半,其中acc
和a
(LHS)沿行分割成两半,而b
(RHS)沿列分割,如下所示:----------- ----------- ----------- | ACC1 | | LHS1 | | | | ----------- += ----------- @ |RHS1|RHS2| | ACC2 | | LHS2 | | | | ----------- ----------- -----------
要使用块缩放矩阵乘法,请提供
a_scale
和b_scale
操作数(必须同时存在或同时未指定)。- 参数:
acc (_Ref) – 累加器。必须是 TMEM Ref。
a (_Ref) – 左侧操作数。必须是 TMEM/SMEM Ref。
b (_Ref) – 右侧操作数。必须是 SMEM Ref。
barrier (_Ref | None) – 用于与 TensorCore 同步的可选 barrier Ref。必须将 orders_tensor_core 设置为 True。如果未指定,则应通过调用
jax.experimental.pallas.mosaic_gpu.tcgen05_commit_arrive()
来显式观察 MMA 的完成。a_scale (_Ref | None) –
a
操作数的可选缩放。如果存在,则必须是 TMEM Ref。b_scale (_Ref | None) –
b
操作数的可选缩放。如果存在,则必须是 TMEM Ref。a_sparse_metadata (_Ref | None) –
a
操作数的可选稀疏元数据。如果存在,则必须是 TMEM Ref。collective_axis (str | None) – 要执行集体 MMA 的集群轴的名称。集群轴的大小应正好为 2,并且必须位于最次的集群轴上。