jax.scipy.linalg.svd#
- jax.scipy.linalg.svd(a: ArrayLike, full_matrices: bool = True, compute_uv: Literal[True] = True, overwrite_a: bool = False, check_finite: bool = True, lapack_driver: str = 'gesdd') tuple[Array, Array, Array][源代码]#
- jax.scipy.linalg.svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[False], overwrite_a: bool = False, check_finite: bool = True, lapack_driver: str = 'gesdd') Array
- jax.scipy.linalg.svd(a: ArrayLike, full_matrices: bool = True, *, compute_uv: Literal[False], overwrite_a: bool = False, check_finite: bool = True, lapack_driver: str = 'gesdd') Array
- jax.scipy.linalg.svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True, overwrite_a: bool = False, check_finite: bool = True, lapack_driver: str = 'gesdd') Array | tuple[Array, Array, Array]
计算奇异值分解。
JAX 对
scipy.linalg.svd()的实现。矩阵 A 的 SVD 给出为
\[A = U\Sigma V^H\]\(U\) 包含左奇异向量,并满足 \(U^HU=I\)
\(V\) 包含右奇异向量,并满足 \(V^HV=I\)
\(\Sigma\) 是奇异值的对角矩阵。
- 参数:
a – 输入数组,形状为
(..., N, M)full_matrices – 如果为 True (默认),则计算完整的矩阵;即
u和vh的形状分别为(..., N, N)和(..., M, M)。如果为 False,则形状为(..., N, K)和(..., K, M),其中K = min(N, M)。compute_uv – 如果为 True (默认),则返回完整的 SVD
(u, s, vh)。如果为 False,则仅返回奇异值s。overwrite_a – JAX 未使用
check_finite – JAX 未使用
lapack_driver – JAX 未使用。如果您想选择一个非默认的 SVD 驱动程序,请查看
jax.lax.linalg.svd(),它提供了此功能。
- 返回:
如果
compute_uv为 True,则返回一个数组元组(u, s, vh),否则返回数组s。u: 左奇异向量,如果full_matrices为 True,则形状为(..., N, N),否则为(..., N, K)。s: 奇异值,形状为(..., K)vh: 共轭转置的右奇异向量,如果full_matrices为 True,则形状为(..., M, M),否则为(..., K, M)。
其中
K = min(N, M)。
另请参阅
jax.numpy.linalg.svd(): NumPy 风格的 SVD APIjax.lax.linalg.svd(): XLA 风格的 SVD API
示例
考虑一个小实值数组的 SVD
>>> x = jnp.array([[1., 2., 3.], ... [6., 5., 4.]]) >>> u, s, vt = jax.scipy.linalg.svd(x, full_matrices=False) >>> s Array([9.361919 , 1.8315067], dtype=float32)
u和v = vt.T的列中包含奇异向量。这些向量是正交的,可以通过将矩阵乘积与单位矩阵进行比较来证明>>> jnp.allclose(u.T @ u, jnp.eye(2), atol=1E-5) Array(True, dtype=bool) >>> v = vt.T >>> jnp.allclose(v.T @ v, jnp.eye(2), atol=1E-5) Array(True, dtype=bool)
给定 SVD,
x可以通过矩阵乘法重构>>> x_reconstructed = u @ jnp.diag(s) @ vt >>> jnp.allclose(x_reconstructed, x) Array(True, dtype=bool)