jax.scipy.linalg.qr#
- jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal['full', 'economic'], pivoting: Literal[False] = False, check_finite: bool = True) tuple[Array, Array] [源代码]#
- jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal['full', 'economic'], pivoting: Literal[True] = True, check_finite: bool = True) tuple[Array, Array, Array]
- jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal['full', 'economic'], pivoting: bool = False, check_finite: bool = True) tuple[Array, Array] | tuple[Array, Array, Array]
- jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal['r'], pivoting: Literal[False] = False, check_finite: bool = True) tuple[Array]
- jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal['r'], pivoting: Literal[True] = True, check_finite: bool = True) tuple[Array, Array]
- jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal['r'], pivoting: bool = False, check_finite: bool = True) tuple[Array] | tuple[Array, Array]
- jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = 'full', pivoting: bool = False, check_finite: bool = True) tuple[Array] | tuple[Array, Array] | tuple[Array, Array, Array]
计算数组的 QR 分解
scipy.linalg.qr()
的 JAX 实现。矩阵 A 的 QR 分解由下式给出
\[A = QR\]其中 Q 是酉矩阵(即 \(Q^HQ=I\)),R 是上三角矩阵。
- 参数:
a – 形状为 (…, M, N) 的数组
mode –
计算模式。支持的值有
"full"
(默认):返回形状为(M, M)
的 Q 和形状为(M, N)
的 R。"r"
:仅返回 R"economic"
:返回形状为(M, K)
的 Q 和形状为(K, N)
的 R,其中 K = min(M, N)。
pivoting – 允许 QR 分解揭示秩。如果
True
,则计算列主元分解A[:, P] = Q @ R
,其中选择P
使得R
的对角线非递增。overwrite_a – 在 JAX 中未使用
lwork – 在 JAX 中未使用
check_finite – 在 JAX 中未使用
- 返回:
一个元组
(Q, R)
或(Q, R, P)
,如果mode
不是"r"
且pivoting
分别为False
或True
,否则是一个数组R
或元组(R, P)
(如果 mode 为"r"
且pivoting
分别为False
或True
),其中Q
是形状为(..., M, M)
(如果mode
为"full"
)或(..., M, K)
(如果mode
为"economic"
)的正交矩阵,R
是形状为(..., M, N)
(如果mode
为"r"
或"full"
)或(..., K, N)
(如果mode
为"economic"
)的上三角矩阵,P
是形状为(..., N)
的索引向量。
其中
K = min(M, N)
。
注意事项
目前,主元仅在 CPU 和 GPU 后端上实现。有关 GPU 实现的更多详细信息,请参阅
jax.lax.linalg.qr()
的文档。
另请参阅
jax.numpy.linalg.qr()
: NumPy 风格的 QR 分解 APIjax.lax.linalg.qr()
: XLA 风格的 QR 分解 API
示例
计算矩阵的 QR 分解
>>> a = jnp.array([[1., 2., 3., 4.], ... [5., 4., 2., 1.], ... [6., 3., 1., 5.]]) >>> Q, R = jax.scipy.linalg.qr(a) >>> Q Array([[-0.12700021, -0.7581426 , -0.6396022 ], [-0.63500065, -0.43322435, 0.63960224], [-0.7620008 , 0.48737738, -0.42640156]], dtype=float32) >>> R Array([[-7.8740077, -5.080005 , -2.4130025, -4.953006 ], [ 0. , -1.7870499, -2.6534991, -1.028908 ], [ 0. , 0. , -1.0660033, -4.050814 ]], dtype=float32)
检查
Q
是否是正交的>>> jnp.allclose(Q.T @ Q, jnp.eye(3), atol=1E-5) Array(True, dtype=bool)
重建输入
>>> jnp.allclose(Q @ R, a) Array(True, dtype=bool)