jax.lax.custom_root#

jax.lax.custom_root(f, initial_guess, solve, tangent_solve, has_aux=False)[源代码]#

可微地求解函数的根。

这是一个底层例程,主要用于 JAX 内部使用。 custom_root() 的梯度是相对于提供的函数 f 中的封闭变量,通过隐函数定理定义的: https://en.wikipedia.org/wiki/Implicit_function_theorem

参数:
  • f (Callable) – 用于查找根的函数。 应接受单个参数,并返回一个数组树,其结构与其输入相同。

  • initial_guess (Any) – f 的零点的初始猜测。

  • solve (Callable[[Callable, Any], Any]) –

    用于求解 f 的根的函数。 应接受两个位置参数,f 和 initial_guess,并返回一个与 initial_guess 具有相同结构的解,使得 func(solution) = 0。 换句话说,假设以下为真(但未检查)

    solution = solve(f, initial_guess)
    error = f(solution)
    assert all(error == 0)
    

  • tangent_solve (Callable[[Callable, Any], Any]) –

    用于求解切线系统的函数。 应接受两个位置参数,一个线性函数 g (函数 f 在其根处线性化)和一个数组树 y,其结构与 initial_guess 相同,并返回一个解 x,使得 g(x)=y

    • 对于标量 y,使用 lambda g, y: y / g(1.0)

    • 对于向量 y,如果 y 的维度不太大,可以使用具有 Jacobian 的线性求解: lambda g, y: np.linalg.solve(jacobian(g)(y), y)

  • has_aux – 布尔值,指示 solve 函数是否将辅助数据(如求解器诊断)作为第二个参数返回。

返回:

调用 solve(f, initial_guess) 的结果,其梯度是通过隐式微分定义的,假设 f(solve(f, initial_guess)) == 0