jax.experimental.pallas.mosaic_gpu.set_max_registers#

jax.experimental.pallas.mosaic_gpu.set_max_registers(n, *, action)[源代码]#

设置一个 warp 所拥有的最大寄存器数量。

参数:
  • n (int)

  • action (Literal['increase', 'decrease'])