jax.nn.get_scaled_dot_general_config#

jax.nn.get_scaled_dot_general_config(mode, global_scale=None)[source]#

获取scaled_dot_general的量化配置。

jax.nn.scaled_dot_general创建量化配置。

另请参阅

参数:
  • mode (Literal['nvfp4', 'mxfp8'])

  • global_scale (Array | None)