jax.numpy.linalg.norm#
- jax.numpy.linalg.norm(x, ord=None, axis=None, keepdims=False)[source]#
计算矩阵或向量的范数。
JAX 实现的
numpy.linalg.norm()
。- 参数:
- 返回:
包含 x 指定范数的数组。
- 返回类型:
注释
计算的范数类型取决于
ord
的值和缩减轴的数量。对于向量范数 (即,单个轴缩减)
ord=None
(默认) 计算 2-范数ord=inf
计算max(abs(x))
ord=-inf
计算 min(abs(x))``ord=0
计算sum(x!=0)
对于其他数值,计算
sum(abs(x) ** ord)**(1/ord)
对于矩阵范数 (即,两个轴缩减)
ord='fro'
或ord=None
(默认) 计算 Frobenius 范数ord='nuc'
计算核范数,或奇异值之和ord=1
计算max(abs(x).sum(0))
ord=-1
计算min(abs(x).sum(0))
ord=2
计算 2-范数,即最大奇异值ord=-2
计算最小奇异值
在
ord=None
和axis=None
的特殊情况下,此函数接受任何维度的数组,并计算扁平化数组的向量 2-范数。示例
向量范数
>>> x = jnp.array([3., 4., 12.]) >>> jnp.linalg.norm(x) Array(13., dtype=float32) >>> jnp.linalg.norm(x, ord=1) Array(19., dtype=float32) >>> jnp.linalg.norm(x, ord=0) Array(3., dtype=float32)
矩阵范数
>>> x = jnp.array([[1., 2., 3.], ... [4., 5., 7.]]) >>> jnp.linalg.norm(x) # Frobenius norm Array(10.198039, dtype=float32) >>> jnp.linalg.norm(x, ord='nuc') # nuclear norm Array(10.762535, dtype=float32) >>> jnp.linalg.norm(x, ord=1) # 1-norm Array(10., dtype=float32)
批量向量范数
>>> jnp.linalg.norm(x, axis=1) Array([3.7416575, 9.486833 ], dtype=float32)