jax.nn.get_scaled_dot_general_config#

jax.nn.get_scaled_dot_general_config(mode, global_scale=None)[源代码]#

获取 scaled_dot_general 的量化配置。

jax.nn.scaled_dot_general 创建量化配置。

另请参阅

参数:
  • mode (Literal['nvfp4', 'mxfp8'])

  • global_scale (Array | None | None)