jax.nn.scaled_dot_general#
- jax.nn.scaled_dot_general(lhs, rhs, dimension_numbers, preferred_element_type=<class 'numpy.float32'>, configs=None, implementation=None)[源代码]#
缩放点积通用运算。
在 lhs 和 rhs 输入上执行具有块缩放量化的通用点积。 此操作扩展了 lax.dot_general 以支持用户定义的缩放配置。
本质上,该操作遵循
a, a_scales = quantize(lhs, configs[0]) b, b_scales = quantize(rhs, configs[1]) c = jax.nn.scaled_matmul(a, b, a_scales, b_scales)
- 参数:
lhs (ArrayLike) – 输入数组。
rhs (ArrayLike) – 输入数组。
dimension_numbers (DotDimensionNumbers) – 一个由两个元组组成的元组,指定收缩维度和批处理维度:((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))。
preferred_element_type (DTypeLike, 可选) – 点积的输出数据类型。 默认为 jnp.float32。 其他有效类型包括 jnp.bfloat16 和 jnp.float16。
configs (BlockScaleConfig 的 list,可选) – lhs、rhs 和梯度的缩放配置。 用户可以通过 jax.nn.get_scaled_dot_general_config 获取有效配置。 目前,支持 nvfp4 和 mxfp8。 如果为 None,则回退到 lax.dot_general。
implementation (Literal['cudnn'] | None) – str (已弃用) 后端选择器,现在被忽略。 系统会自动选择后端。 计划在未来版本中删除。
- 返回:
生成的张量,首先是批处理维度,然后是 lhs 的非收缩/非批处理维度,然后是 rhs 的非收缩/非批处理维度。
- 返回类型:
另请参阅
jax.nn.scaled_matmul()
: 缩放的 matmul 函数。jax.lax.dot_general()
: 通用点积运算符。
注意事项
与 nn.scaled_matmul 不同,后者假定量化的低精度输入具有显式缩放因子,此运算符采用高精度输入,在内部应用量化,并处理反向传递。
示例
为 mxfp8 创建配置
>>> configs = [jax.nn.get_scaled_dot_general_config('mxfp8')] * 3
为 nvfp4 创建配置
>>> global_scale = jnp.array([0.5], jnp.float32) >>> configs = [jax.nn.get_scaled_dot_general_config('nvfp4', global_scale)] * 3
使用带有配置的 scaled_dot_general
>>> import functools >>> scaled_dot_general_fn = functools.partial(jax.nn.scaled_dot_general, configs=configs) >>> lhs = jax.random.normal(jax.random.PRNGKey(1), (3, 128, 64)) >>> rhs = jax.random.normal(jax.random.PRNGKey(2), (3, 128, 64)) >>> out = scaled_dot_general_fn(lhs, rhs, (((2,), (2,)), ((0,), (0,))))