jax.scipy.optimize.minimize#

jax.scipy.optimize.minimize(fun, x0, args=(), *, method, tol=None, options=None)[source]#

最小化一个或多个变量的标量函数。

此函数 API 与 SciPy 匹配,但存在一些细微差别。

  • 当需要时,使用 JAX 的自动微分支持自动计算 fun 的梯度。

  • 必须提供 method 参数。您必须指定一个求解器。

  • SciPy 接口中的各种可选参数尚未实现。

  • 由于线搜索实现上的差异,优化结果可能与 SciPy 不同。

minimize 支持 jit() 编译。它尚不支持微分或多维数组形式的参数,但计划同时支持这两者。

参数:
  • fun (Callable) – 要最小化的目标函数,fun(x, *args) -> float,其中 x 是一个形状为 (n,) 的一维数组,args 是完全指定函数所需的固定参数的元组。fun 必须支持微分。

  • x0 (Array) – 初始猜测值。大小为 (n,) 的实数元素数组,其中 n 是独立变量的数量。

  • args (tuple) – 传递给目标函数的额外参数。

  • method (str) – 求解器类型。目前仅支持 "BFGS"

  • tol (float | None) – 终止容差。有关详细控制,请使用特定于求解器的选项。

  • options (Mapping[str, Any] | None) –

    求解器选项字典。所有方法都接受以下通用选项:

    • maxiter (int):要执行的最大迭代次数。根据方法不同,每次迭代可能使用多次函数评估。

返回:

一个 OptimizeResults 对象。

返回类型:

OptimizeResults