jax.numpy.linalg.solve#
- jax.numpy.linalg.solve(a, b)[source]#
求解线性方程组。
numpy.linalg.solve()
的 JAX 实现。这求解了给定
a
和b
的(批量)线性方程组a @ x = b
的x
。如果
a
是奇异的,这将返回nan
或inf
值。- 参数:
a (ArrayLike) – 形状为
(..., N, N)
的数组。b (ArrayLike) – 形状为
(N,)
(对于 1 维右侧)或(..., N, M)
(对于批量 2 维右侧)的数组。
- 返回:
如果
a
是非奇异的,则返回包含线性求解结果的数组。如果b
的形状为(N,)
,则结果的形状为(..., N)
,否则形状为(..., N, M)
。如果a
是奇异的,则结果包含nan
或inf
值。- 返回类型:
另请参阅
jax.scipy.linalg.solve()
:用于求解线性系统的 SciPy 风格 API。jax.lax.custom_linear_solve()
:无矩阵线性求解器。
示例
一个简单的 3x3 线性系统
>>> A = jnp.array([[1., 2., 3.], ... [2., 4., 2.], ... [3., 2., 1.]]) >>> b = jnp.array([14., 16., 10.]) >>> x = jnp.linalg.solve(A, b) >>> x Array([1., 2., 3.], dtype=float32)
确认结果求解了系统
>>> jnp.allclose(A @ x, b) Array(True, dtype=bool)