jax.nn.scaled_dot_general#

jax.nn.scaled_dot_general(lhs, rhs, dimension_numbers, preferred_element_type=<class 'jax.numpy.float32'>, configs=None, implementation=None)[source]#

缩放点积通用操作。

在 lhs 和 rhs 输入上执行具有块缩放量化的通用点积操作。此操作扩展了 lax.dot_general 以支持用户定义的缩放配置。

本质上,该操作遵循:

a, a_scales = quantize(lhs, configs[0])
b, b_scales = quantize(rhs, configs[1])
c = jax.nn.scaled_matmul(a, b, a_scales, b_scales)
参数:
  • lhs (ArrayLike) – 输入数组。

  • rhs (ArrayLike) – 输入数组。

  • dimension_numbers (DotDimensionNumbers) – 一个元组,包含两个元组,用于指定收缩维度和批次维度:((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))

  • preferred_element_type (DTypeLike, 可选) – 点积的输出数据类型。默认为 jnp.float32。其他有效类型包括 jnp.bfloat16jnp.float16

  • configs (BlockScaleConfig列表, 可选) – lhs、rhs 和梯度的缩放配置。用户可以通过 jax.nn.get_scaled_dot_general_config 获取有效配置。目前,支持 nvfp4mxfp8。如果为 None,则回退到 lax.dot_general

  • implementation (Literal['cudnn'] | None | None) – str(已弃用)后端选择器,现在已被忽略。系统自动选择后端。计划在未来版本中移除。

返回:

结果张量,首先是批次维度,然后是 lhs 的非收缩/非批次维度,最后是 rhs 的非收缩/非批次维度。

返回类型:

Array

另请参阅

笔记

  • nn.scaled_matmul 不同,后者假定具有显式缩放因子的量化低精度输入,此运算符接受高精度输入,在内部应用量化,并处理反向传播。

示例

为 mxfp8 创建配置

>>> configs = [jax.nn.get_scaled_dot_general_config('mxfp8')] * 3

为 nvfp4 创建配置

>>> global_scale = jnp.array([0.5], jnp.float32)
>>> configs = [jax.nn.get_scaled_dot_general_config('nvfp4', global_scale)] * 3

将 scaled_dot_general 与配置一起使用

>>> import functools
>>> scaled_dot_general_fn = functools.partial(jax.nn.scaled_dot_general, configs=configs)
>>> lhs = jax.random.normal(jax.random.PRNGKey(1), (3, 128, 64))
>>> rhs = jax.random.normal(jax.random.PRNGKey(2), (3, 128, 64))
>>> out = scaled_dot_general_fn(lhs, rhs, (((2,), (2,)), ((0,), (0,))))