jax.numpy.linalg.norm#

jax.numpy.linalg.norm(x, ord=None, axis=None, keepdims=False)[源代码]#

计算矩阵或向量的范数。

JAX 对 numpy.linalg.norm() 的实现。

参数:
  • x (ArrayLike) – 将计算其范数的 N 维数组。

  • ord (int | str | None) – 指定要计算的范数类型。对于矩阵,默认为 Frobenius 范数;对于向量,默认为 2-范数。其他选项请参阅下方的“备注”。

  • axis (None | tuple[int, ...] | int) – 指定将计算范数的轴的整数或整数序列。对于单个轴,计算向量范数。对于两个轴,计算矩阵范数。默认为 x 的所有轴。

  • keepdims (bool) – 如果为 True,则输出数组将与输入数组具有相同的维度数,缩减轴的大小将被替换为 1(默认为 False)。

返回:

包含 x 指定范数的数组。

返回类型:

Array

注意事项

计算的范数类型取决于 ord 的值以及要缩减的轴的数量。

对于向量范数(即单个轴的缩减)

  • ord=None(默认)计算 2-范数

  • ord=inf 计算 max(abs(x))

  • ord=-inf 计算 min(abs(x))``

  • ord=0 计算 sum(x!=0)

  • 对于其他数值,计算 sum(abs(x) ** ord)**(1/ord)

对于矩阵范数(即两个轴的缩减)

  • ord='fro'ord=None(默认)计算 Frobenius 范数

  • ord='nuc' 计算核范数,即奇异值之和

  • ord=1 计算 max(abs(x).sum(0))

  • ord=-1 计算 min(abs(x).sum(0))

  • ord=2 计算 2-范数,即最大的奇异值

  • ord=-2 计算最小的奇异值

ord=Noneaxis=None 的特殊情况下,此函数接受任何维度的数组并计算展平数组的向量 2-范数。

示例

向量范数

>>> x = jnp.array([3., 4., 12.])
>>> jnp.linalg.norm(x)
Array(13., dtype=float32)
>>> jnp.linalg.norm(x, ord=1)
Array(19., dtype=float32)
>>> jnp.linalg.norm(x, ord=0)
Array(3., dtype=float32)

矩阵范数

>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 7.]])
>>> jnp.linalg.norm(x)  # Frobenius norm
Array(10.198039, dtype=float32)
>>> jnp.linalg.norm(x, ord='nuc')  # nuclear norm
Array(10.762535, dtype=float32)
>>> jnp.linalg.norm(x, ord=1)  # 1-norm
Array(10., dtype=float32)

批量向量范数

>>> jnp.linalg.norm(x, axis=1)
Array([3.7416575, 9.486833 ], dtype=float32)