jax.numpy.linalg.norm#

jax.numpy.linalg.norm(x, ord=None, axis=None, keepdims=False)[source]#

计算矩阵或向量的范数。

JAX 实现的 numpy.linalg.norm()

参数:
  • x (ArrayLike) – 将要计算范数的 N 维数组。

  • ord (int | str | None) – 指定要采用的范数类型。默认为矩阵的 Frobenius 范数和向量的 2-范数。有关其他选项,请参见下面的注释。

  • axis (None | tuple[int, ...] | int) – 指定将要计算范数的轴的整数或整数序列。对于单个轴,计算向量范数。对于两个轴,计算矩阵范数。默认为 x 的所有轴。

  • keepdims (bool) – 如果为 True,则输出数组将具有与输入相同的维数,并且缩减轴的大小将替换为 1 (默认值: False)。

返回:

包含 x 指定范数的数组。

返回类型:

Array

注释

计算的范数类型取决于 ord 的值和缩减轴的数量。

对于向量范数 (即,单个轴缩减)

  • ord=None (默认) 计算 2-范数

  • ord=inf 计算 max(abs(x))

  • ord=-inf 计算 min(abs(x))``

  • ord=0 计算 sum(x!=0)

  • 对于其他数值,计算 sum(abs(x) ** ord)**(1/ord)

对于矩阵范数 (即,两个轴缩减)

  • ord='fro'ord=None (默认) 计算 Frobenius 范数

  • ord='nuc' 计算核范数,或奇异值之和

  • ord=1 计算 max(abs(x).sum(0))

  • ord=-1 计算 min(abs(x).sum(0))

  • ord=2 计算 2-范数,即最大奇异值

  • ord=-2 计算最小奇异值

ord=Noneaxis=None 的特殊情况下,此函数接受任何维度的数组,并计算扁平化数组的向量 2-范数。

示例

向量范数

>>> x = jnp.array([3., 4., 12.])
>>> jnp.linalg.norm(x)
Array(13., dtype=float32)
>>> jnp.linalg.norm(x, ord=1)
Array(19., dtype=float32)
>>> jnp.linalg.norm(x, ord=0)
Array(3., dtype=float32)

矩阵范数

>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 7.]])
>>> jnp.linalg.norm(x)  # Frobenius norm
Array(10.198039, dtype=float32)
>>> jnp.linalg.norm(x, ord='nuc')  # nuclear norm
Array(10.762535, dtype=float32)
>>> jnp.linalg.norm(x, ord=1)  # 1-norm
Array(10., dtype=float32)

批量向量范数

>>> jnp.linalg.norm(x, axis=1)
Array([3.7416575, 9.486833 ], dtype=float32)