jax.numpy.linalg.lstsq#

jax.numpy.linalg.lstsq(a, b, rcond=None, *, numpy_resid=False)[源代码]#

返回线性方程的最小二乘解。

JAX 实现的 numpy.linalg.lstsq()

参数:
  • a (ArrayLike) – 形状为 (M, N) 的数组,表示系数矩阵。

  • b (ArrayLike) – 形状为 (M,)(M, K) 的数组,表示等式右侧。

  • rcond (float | None) – 小奇异值的截止比率。小于 rcond * largest_singular_value 的奇异值被视为零。如果为 None (默认),将使用最佳值来减少浮点错误。

  • numpy_resid (bool) – 如果为 True,则以与 NumPy 的 linalg.lstsq 相同的方式计算并返回残差。如果您想精确复制 NumPy 的行为,这是必要的。如果为 False (默认),则使用更有效的方法来计算残差。

返回:

数组的元组 (x, resid, rank, s) 其中

  • x 是包含最小二乘解的形状为 (N,)(N, K) 的数组。

  • resid 是形状为 ()(K,) 的平方残差之和。

  • rank 是矩阵 a 的秩。

  • s 是矩阵 a 的奇异值。

返回类型:

tuple[Array, Array, Array, Array]

示例

>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> b = jnp.array([5, 6])
>>> x, _, _, _ = jnp.linalg.lstsq(a, b)
>>> with jnp.printoptions(precision=3):
...   print(x)
[-4.   4.5]