jax.scipy.sparse.linalg.cg#
- jax.scipy.sparse.linalg.cg(A, b, x0=None, *, tol=1e-05, atol=0.0, maxiter=None, M=None)[源码]#
使用共轭梯度迭代法求解
Ax = b。JAX 的
cg的数值计算应该与 SciPy 的cg精确匹配(在数值精度范围内),但请注意,接口略有不同:您需要将线性算子A作为函数提供,而不是作为稀疏矩阵或LinearOperator。对
cg的导数是通过另一个cg求解进行隐式微分来实现的,而不是通过微分穿过求解器来实现的。只有当两个求解都收敛时,它们才会准确。- 参数:
A (ndarray, function, or matmul-compatible object) – 二维数组或函数,用于计算线性映射(矩阵向量乘积)
Ax,调用方式为A(x)或A @ x。A必须表示一个厄米正定矩阵,并且必须返回与参数具有相同结构和形状的数组。b (数组 或 数组树) – 表示单个向量的线性系统的右侧。 可以存储为数组或具有任何形状的数组的 Python 容器。
x0 (array or tree of arrays) – 解的起始猜测值。必须与
b具有相同的结构。tol (float, optional) – 收敛的容差,
norm(residual) <= max(tol*norm(b), atol)。我们不实现 SciPy 的“旧版”行为,因此 JAX 的容差将与 SciPy 不同,除非您在 SciPy 的cg中显式传递atol。atol (float, optional) – 收敛的容差,
norm(residual) <= max(tol*norm(b), atol)。我们不实现 SciPy 的“旧版”行为,因此 JAX 的容差将与 SciPy 不同,除非您在 SciPy 的cg中显式传递atol。maxiter (integer) – 最大迭代次数。即使未达到指定的容差,迭代也将会在 maxiter 步后停止。
M (ndarray, function, or matmul-compatible object) – A 的预条件子。预条件子应近似 A 的逆。有效的预条件化会显著提高收敛速度,这意味着需要更少的迭代才能达到给定的误差容差。
- 返回:
x (数组或数组树) – 收敛的解。 与
b具有相同的结构。info (None) – 收敛信息的占位符。 未来,JAX 将报告未实现收敛时的迭代次数,就像 SciPy 一样。
另请参阅
scipy.sparse.linalg.cg,jax.lax.custom_linear_solve