jax.nn.get_scaled_dot_general_config#

jax.nn.get_scaled_dot_general_config(mode, global_scale=None)[源代码]#

为 scaled_dot_general 获取量化配置。

jax.nn.scaled_dot_general 创建量化配置。

另请参阅

参数:
  • mode (Literal['nvfp4', 'mxfp8'])

  • global_scale (Array | None)