jax.scipy.linalg.solve#

jax.scipy.linalg.solve(a, b, lower=False, overwrite_a=False, overwrite_b=False, debug=False, check_finite=True, assume_a='gen')[source]#

求解线性方程组。

JAX 实现的 scipy.linalg.solve()

这求解 (批量) 线性方程组 a @ x = b,对于给定的 ab 求解 x

如果 a 是奇异的,这将返回 naninf 值。

参数:
  • a (ArrayLike) – 形状为 (..., N, N) 的数组。

  • b (ArrayLike) – 形状为 (..., N)(..., N, M) 的数组

  • lower (bool) – 仅当 assume_a != 'gen' 时引用。如果为 True,则仅使用输入的下三角部分,如果为 False(默认),则仅使用上三角部分。

  • assume_a (str) –

    指定可以假定 a 的哪些属性。选项有

    • "gen":通用矩阵(默认)

    • "sym":对称矩阵

    • "her":埃尔米特矩阵

    • "pos":正定矩阵

  • overwrite_a (bool) – JAX 未使用

  • overwrite_b (bool) – JAX 未使用

  • debug (bool) – JAX 未使用

  • check_finite (bool) – JAX 未使用

返回:

b 形状相同的数组,如果 a 非奇异,则包含线性系统的解。如果 a 是奇异的,则结果包含 naninf 值。

返回类型:

Array

另请参阅

示例

一个简单的 3x3 线性系统

>>> A = jnp.array([[1., 2., 3.],
...                [2., 4., 2.],
...                [3., 2., 1.]])
>>> b = jnp.array([14., 16., 10.])
>>> x = jax.scipy.linalg.solve(A, b)
>>> x
Array([1., 2., 3.], dtype=float32)

确认结果求解了系统

>>> jnp.allclose(A @ x, b)
Array(True, dtype=bool)