jax.lax.custom_linear_solve#
- jax.lax.custom_linear_solve(matvec, b, solve, transpose_solve=None, symmetric=False, has_aux=False)[源代码]#
执行具有隐式定义梯度的无矩阵线性求解。
此函数允许通过在解处直接进行隐式微分来覆盖或定义线性求解的梯度,而不是通过对求解运算进行微分。这有时会更快或更数值稳定,或者甚至可能未实现通过求解运算进行微分(例如,如果
solve
使用lax.while_loop
)。必需的不变量
x = solve(matvec, b) # solve the linear equation assert matvec(x) == b # not checked
- 参数:
matvec (Callable) – 要反转的线性函数。必须是可微分的。
b (Any) – 方程的常数右侧句柄。可以是数组的任何嵌套结构。
solve (Callable[[Callable, Any], Any]) – 更高级别的函数,用于求解线性方程的解,即对于与
b
形式相同的所有x
,solve(matvec, x) == x
。此函数不需要是可微分的。transpose_solve (Callable[[Callable, Any], Any] | None) – 用于求解转置线性方程的更高级别的函数,即
transpose_solve(vecmat, x) == x
,其中vecmat
是线性映射matvec
的转置(使用自动微分自动计算)。反向模式自动微分是必需的,除非symmetric=True
,在这种情况下,solve
提供默认值。symmetric – 布尔值,指示是否可以安全地假定线性映射对应于对称矩阵,即
matvec == vecmat
。has_aux – 布尔值,指示
solve
和transpose_solve
函数是否将辅助数据(如求解器诊断)作为第二个参数返回。
- 返回:
solve(matvec, b)
的结果,梯度定义为假设解
x
满足线性方程matvec(x) == b
。