jax.numpy.linalg.solve#

jax.numpy.linalg.solve(a, b)[source]#

求解线性方程组。

numpy.linalg.solve() 的 JAX 实现。

这求解了给定 ab 的(批量)线性方程组 a @ x = bx

如果 a 是奇异的,这将返回 naninf 值。

参数:
  • a (ArrayLike) – 形状为 (..., N, N) 的数组。

  • b (ArrayLike) – 形状为 (N,) (对于 1 维右侧)或 (..., N, M) (对于批量 2 维右侧)的数组。

返回:

如果 a 是非奇异的,则返回包含线性求解结果的数组。如果 b 的形状为 (N,),则结果的形状为 (..., N),否则形状为 (..., N, M)。如果 a 是奇异的,则结果包含 naninf 值。

返回类型:

Array

另请参阅

示例

一个简单的 3x3 线性系统

>>> A = jnp.array([[1., 2., 3.],
...                [2., 4., 2.],
...                [3., 2., 1.]])
>>> b = jnp.array([14., 16., 10.])
>>> x = jnp.linalg.solve(A, b)
>>> x
Array([1., 2., 3.], dtype=float32)

确认结果求解了系统

>>> jnp.allclose(A @ x, b)
Array(True, dtype=bool)