jax.experimental.pallas.mosaic_gpu.wgmma_wait#

jax.experimental.pallas.mosaic_gpu.wgmma_wait(n)[源代码]#

等待直到飞行中的 WGMMA 操作不超过 n 个。

参数:

n (int)