jax.nn.scaled_matmul#

jax.nn.scaled_matmul(lhs, rhs, lhs_scales, rhs_scales, preferred_element_type=<class 'jax.numpy.float32'>)[来源]#

缩放矩阵乘法函数。

使用 a_scalesb_scales 执行 ab 的块缩放矩阵乘法。最后一个维度是收缩维度,块大小是推断出来的。

从数学上讲,此操作等效于

a_block_size = a.shape[-1] // a_scales.shape[-1]
b_block_size = b.shape[-1] // b_scales.shape[-1]
a_scaled = a * jnp.repeat(a_scales, a_block_size, axis=-1)
b_scaled = b * jnp.repeat(b_scales, b_block_size, axis=-1)
jnp.einsum('BMK,BNK->BMN', a_scaled, b_scaled)
参数:
  • lhs (Array) – 操作数 a,形状 (B, M, K)。

  • rhs (Array) – 操作数 b,形状 (B, N, K)。

  • lhs_scales (Array) – 形状 (B, M, K_a),其中 K % K_a == 0

  • rhs_scales (Array) – 形状 (B, N, K_b),其中 K % K_b == 0

  • preferred_element_type (DTypeLike, 可选) – 默认为 jnp.float32

返回:

形状为 (B, M, N) 的 Array。

返回类型:

Array

注释

  • 我们目前不支持用户定义的 precision 来定制计算数据类型。它固定为 jnp.float32

  • 块大小被推断为 aK // K_abK // K_b

  • 要将 cuDNN 与 Nvidia Blackwell GPU 一起使用,输入必须匹配

    # mxfp8
    a, b: jnp.float8_e4m3fn | jnp.float8_e5m2
    a_scales, b_scales: jnp.float8_e8m0fnu
    block_size: 32
    # nvfp4
    a, b: jnp.float4_e2m1fn
    a_scales, b_scales: jnp.float8_e4m3fn
    block_size: 16
    

示例

基本情况

>>> a = jnp.array([1, 2, 3]).reshape((1, 1, 3))
>>> b = jnp.array([4, 5, 6]).reshape((1, 1, 3))
>>> a_scales = jnp.array([0.5]).reshape((1, 1, 1))
>>> b_scales = jnp.array([0.5]).reshape((1, 1, 1))
>>> scaled_matmul(a, b, a_scales, b_scales)  
Array([[[8.]]], dtype=float32)

在 Blackwell GPU 上使用融合的 cuDNN 调用

>>> dtype = jnp.float8_e4m3fn
>>> a = jax.random.normal(jax.random.PRNGKey(1), (3, 128, 64), dtype=dtype)
>>> b = jax.random.normal(jax.random.PRNGKey(2), (3, 128, 64), dtype=dtype)
>>> a_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu)
>>> b_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu)
>>> scaled_matmul(a, b, a_scales, b_scales)