jax.experimental.jet 模块#

Jet 是一个实验性的高阶自动微分模块,它不依赖于重复的一阶自动微分。

如何实现?通过截断泰勒多项式的传播。考虑一个函数 \(f = g \circ h\)、一个点 \(x\) 和一个偏移量 \(v\)。一阶自动微分(例如 jax.jvp())从 \((h(x), \partial h(x)[v])\) 对计算出 \((f(x), \partial f(x)[v])\) 对。

jet() 实现高阶的类似功能:给定元组

\[(h_0, ... h_K) := (h(x), \partial h(x)[v], \partial^2 h(x)[v, v], ..., \partial^K h(x)[v,...,v]),\]

它代表了 \(h\)\(x\) 点的 \(K\) 阶泰勒近似,jet() 返回 \(f\)\(x\) 点的 \(K\) 阶泰勒近似

\[(f_0, ..., f_K) := (f(x), \partial f(x)[v], \partial^2 f(x)[v, v], ..., \partial^K f(x)[v,...,v]).\]

更具体地说,jet() 计算

\[f_0, (f_1, . . . , f_K) = \texttt{jet} (f, h_0, (h_1, . . . , h_K))\]

因此可以用于 \(f\) 的高阶自动微分。详细信息请参阅 这些笔记

注意

通过贡献 待处理的原生规则,帮助改进 jet()

API#

jax.experimental.jet.jet(fun, primals, series, factorial_scaled=True, **_)[source]#

泰勒模式高阶自动微分。

参数:
  • fun – 要微分的函数。其参数应为数组、标量或数组/标量的标准 Python 容器。它应返回一个数组、标量或数组/标量的标准 Python 容器。

  • primals – 应在其中评估 fun 的泰勒近似的原始值。应为参数的元组或列表,并且其长度应等于 fun 的位置参数的数量。

  • series – 高阶泰勒级数系数。 primalsseries 一起构成一个截断的泰勒多项式。应为元组或元组/列表的列表,并且其长度决定截断泰勒多项式的次数。

  • factorial_scaled – 如果为 True,则输入和输出级数中的每个项都乘以其阶数的阶乘,以便输入和输出级数为泰勒级数。这是默认行为,因此输入和输出级数中的 n 阶项是函数的 n 阶导数。如果为 False,则输入和输出级数为非阶乘缩放的泰勒系数(即,泰勒级数中每项的常数系数)。

返回:

一个 (primals_out, series_out) 对,其中 primals_outfun(*primals),并且 primals_outseries_out 一起构成 \(f(h(\cdot))\) 的截断泰勒多项式。primals_out 值具有与 primals 相同的 Python 树结构,而 series_out 值具有与 series 相同的 Python 树结构。

例如

>>> import jax
>>> import jax.numpy as np

考虑函数 \(h(z) = z^3\)\(x = 0.5\) 以及前几个泰勒系数 \(h_0=x^3\)\(h_1=3x^2\)\(h_2=6x\)。设 \(f(y) = \sin(y)\)

>>> h0, h1, h2 = 0.5**3., 3.*0.5**2., 6.*0.5
>>> f, df, ddf = np.sin, np.cos, lambda *args: -np.sin(*args)

jet() 根据 Faà di Bruno 公式返回 \(f(h(z)) = \sin(z^3)\) 的泰勒系数

>>> f0, (f1, f2) =  jet(f, (h0,), ((h1, h2),))
>>> print(f0,  f(h0))
0.12467473 0.12467473
>>> print(f1, df(h0) * h1)
0.74414825 0.74414825
>>> print(f2, ddf(h0) * h1 ** 2 + df(h0) * h2)
2.9064636 2.9064634