jax.linearize#
- jax.linearize(fun: Callable, *primals, has_aux: Literal[False] = False) tuple[Any, Callable] [source]#
- jax.linearize(fun: Callable, *primals, has_aux: Literal[True]) tuple[Any, Callable, Any]
使用
jvp()
和 partial eval 对fun
进行线性近似。- 参数:
fun – 要微分的函数。其参数应为数组、标量或数组或标量的标准 Python 容器。它应该返回数组、标量或数组或标量的标准 python 容器。
primals – 原始值,用于评估
fun
的 Jacobian 矩阵。应为数组、标量或其标准 Python 容器的元组。元组的长度等于fun
的位置参数的数量。has_aux – 可选,布尔值。指示
fun
是否返回一个对,其中第一个元素被认为是线性化的数学函数的输出,第二个元素是辅助数据。默认为 False。
- 返回:
如果
has_aux
为False
,则返回一个对,其中第一个元素是f(*primals)
的值,第二个元素是一个函数,该函数评估在primals
处评估的fun
的(前向模式) Jacobian 向量积,而无需重新进行线性化工作。如果has_aux
为True
,则返回一个(primals_out, lin_fn, aux)
元组,其中aux
是fun
返回的辅助数据。
就计算值而言,
linearize()
的行为很像柯里化的jvp()
,这两个代码块计算相同的值y, out_tangent = jax.jvp(f, (x,), (in_tangent,)) y, f_jvp = jax.linearize(f, x) out_tangent = f_jvp(in_tangent)
然而,不同之处在于
linearize()
使用了 partial evaluation,因此在调用f_jvp
时不会重新线性化函数f
。一般来说,这意味着内存使用量随计算规模而扩展,很像反向模式。(实际上,linearize()
具有与vjp()
类似的签名!)如果你想多次应用
f_jvp
,即在相同的线性化点评估许多不同输入切向量的前推,则此函数主要有用。此外,如果所有输入切向量都是一次性已知的,则使用vmap()
进行向量化可能更有效,例如pushfwd = partial(jvp, f, (x,)) y, out_tangents = vmap(pushfwd, out_axes=(None, 0))((in_tangents,))
通过像这样一起使用
vmap()
和jvp()
,我们避免了与计算深度成比例的存储线性化内存成本,linearize()
和vjp()
都会产生这种成本。这是一个更完整的
linearize()
用法示例>>> import jax >>> import jax.numpy as jnp >>> >>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.) ... >>> jax.jvp(f, (2.,), (3.,)) (Array(3.2681944, dtype=float32, weak_type=True), Array(-5.007528, dtype=float32, weak_type=True)) >>> y, f_jvp = jax.linearize(f, 2.) >>> print(y) 3.2681944 >>> print(f_jvp(3.)) -5.007528 >>> print(f_jvp(4.)) -6.676704