jax.linearize#
- jax.linearize(fun: Callable, *primals, has_aux: Literal[False] = False) tuple[Any, Callable] [source]#
- jax.linearize(fun: Callable, *primals, has_aux: Literal[True]) tuple[Any, Callable, Any]
使用
jvp()
和部分求值,对fun
进行线性近似。- 参数:
fun – 要微分的函数。其参数应为数组、标量,或包含数组或标量的标准 Python 容器。它应返回数组、标量或包含数组或标量的标准 Python 容器。
primals – 评估
fun
雅可比矩阵的原始值。应为数组、标量或其标准 Python 容器的元组。元组的长度等于fun
的位置参数数量。has_aux – 可选,布尔值。指示
fun
是否返回一个对,其中第一个元素被视为要线性化的数学函数的输出,第二个元素是辅助数据。默认为 False。
- 返回:
如果
has_aux
为False
,则返回一个对,其中第一个元素是f(*primals)
的值,第二个元素是一个函数,用于评估在primals
处计算的fun
的(前向模式)雅可比向量积,而无需重新进行线性化工作。如果has_aux
为True
,则返回一个(primals_out, lin_fn, aux)
元组,其中aux
是fun
返回的辅助数据。
就计算值而言,
linearize()
的行为很像柯里化的jvp()
,以下两个代码块计算相同的值y, out_tangent = jax.jvp(f, (x,), (in_tangent,)) y, f_jvp = jax.linearize(f, x) out_tangent = f_jvp(in_tangent)
然而,不同之处在于
linearize()
使用部分求值,因此函数f
在调用f_jvp
时不会被重新线性化。通常这意味着内存使用量随计算规模而变化,与反向模式非常相似。(事实上,linearize()
与vjp()
具有相似的签名!)此函数主要在你希望多次应用
f_jvp
时有用,即在同一线性化点为许多不同的输入切向量评估推前。此外,如果所有输入切向量都同时已知,则使用vmap()
进行向量化可能更高效,如下所示:pushfwd = partial(jvp, f, (x,)) y, out_tangents = vmap(pushfwd, out_axes=(None, 0))((in_tangents,))
通过这样同时使用
vmap()
和jvp()
,我们避免了随计算深度变化的存储线性化内存成本,这种成本是linearize()
和vjp()
都会产生的。以下是使用
linearize()
的更完整示例>>> import jax >>> import jax.numpy as jnp >>> >>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.) ... >>> jax.jvp(f, (2.,), (3.,)) (Array(3.2681944, dtype=float32, weak_type=True), Array(-5.007528, dtype=float32, weak_type=True)) >>> y, f_jvp = jax.linearize(f, 2.) >>> print(y) 3.2681944 >>> print(f_jvp(3.)) -5.007528 >>> print(f_jvp(4.)) -6.676704