jax.nn.scaled_matmul#
- jax.nn.scaled_matmul(lhs, rhs, lhs_scales, rhs_scales, preferred_element_type=<class 'jax.numpy.float32'>)[来源]#
缩放矩阵乘法函数。
使用 a_scales 和 b_scales 执行 a 和 b 的块缩放矩阵乘法。最后一个维度是收缩维度,块大小是推断出来的。
从数学上讲,此操作等效于
a_block_size = a.shape[-1] // a_scales.shape[-1] b_block_size = b.shape[-1] // b_scales.shape[-1] a_scaled = a * jnp.repeat(a_scales, a_block_size, axis=-1) b_scaled = b * jnp.repeat(b_scales, b_block_size, axis=-1) jnp.einsum('BMK,BNK->BMN', a_scaled, b_scaled)
- 参数:
- 返回:
形状为 (B, M, N) 的 Array。
- 返回类型:
注释
我们目前不支持用户定义的 precision 来定制计算数据类型。它固定为 jnp.float32。
块大小被推断为 a 的 K // K_a 和 b 的 K // K_b。
要将 cuDNN 与 Nvidia Blackwell GPU 一起使用,输入必须匹配
# mxfp8 a, b: jnp.float8_e4m3fn | jnp.float8_e5m2 a_scales, b_scales: jnp.float8_e8m0fnu block_size: 32 # nvfp4 a, b: jnp.float4_e2m1fn a_scales, b_scales: jnp.float8_e4m3fn block_size: 16
示例
基本情况
>>> a = jnp.array([1, 2, 3]).reshape((1, 1, 3)) >>> b = jnp.array([4, 5, 6]).reshape((1, 1, 3)) >>> a_scales = jnp.array([0.5]).reshape((1, 1, 1)) >>> b_scales = jnp.array([0.5]).reshape((1, 1, 1)) >>> scaled_matmul(a, b, a_scales, b_scales) Array([[[8.]]], dtype=float32)
在 Blackwell GPU 上使用融合的 cuDNN 调用
>>> dtype = jnp.float8_e4m3fn >>> a = jax.random.normal(jax.random.PRNGKey(1), (3, 128, 64), dtype=dtype) >>> b = jax.random.normal(jax.random.PRNGKey(2), (3, 128, 64), dtype=dtype) >>> a_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu) >>> b_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu) >>> scaled_matmul(a, b, a_scales, b_scales)