custom_vjp
和 nondiff_argnums
更新指南#
mattjj@ 2020年10月14日
本文档假定您熟悉 jax.custom_vjp
,如 JAX 可变换 Python 函数的自定义导数规则 笔记本中所述。
更新内容#
在 JAX PR #4008 之后,传递给 custom_vjp
函数的 nondiff_argnums
参数不能是 Tracer
(或 Tracer
的容器),这基本上意味着,为了支持任意可变换的代码,nondiff_argnums
不应用于数组值参数。相反,nondiff_argnums
仅应用于非数组值,例如 Python 可调用对象、形状元组或字符串。
过去我们使用 nondiff_argnums
处理数组值的地方,我们应该将它们作为常规参数传递。在 bwd
规则中,我们需要为它们生成值,但我们可以直接生成 None
值,以表明没有相应的梯度值。
例如,以下是编写 clip_gradient
的旧方式,当 hi
和/或 lo
是来自某个 JAX 变换的 Tracer
时,这种方式将无法工作。
from functools import partial
import jax
@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def clip_gradient(lo, hi, x):
return x # identity function
def clip_gradient_fwd(lo, hi, x):
return x, None # no residual values to save
def clip_gradient_bwd(lo, hi, _, g):
return (jnp.clip(g, lo, hi),)
clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
以下是支持任意变换的新方式,它非常出色
import jax
@jax.custom_vjp # no nondiff_argnums!
def clip_gradient(lo, hi, x):
return x # identity function
def clip_gradient_fwd(lo, hi, x):
return x, (lo, hi) # save lo and hi values as residuals
def clip_gradient_bwd(res, g):
lo, hi = res
return (None, None, jnp.clip(g, lo, hi)) # return None for lo and hi
clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
如果您使用旧方式而非新方式,在任何可能出错的情况下(即当 Tracer
被传递到 nondiff_argnums
参数中时),您都会收到一个响亮的错误提示。
这是一个我们确实需要将 nondiff_argnums
配合 custom_vjp
使用的例子
from functools import partial
import jax
@partial(jax.custom_vjp, nondiff_argnums=(0,))
def skip_app(f, x):
return f(x)
def skip_app_fwd(f, x):
return skip_app(f, x), None
def skip_app_bwd(f, _, g):
return (g,)
skip_app.defvjp(skip_app_fwd, skip_app_bwd)
说明#
将 Tracer
传递给 nondiff_argnums
参数一直存在错误。虽然有些情况能正常工作,但其他情况会导致复杂且令人困惑的错误消息。
这个 bug 的核心在于,nondiff_argnums
的实现方式非常类似于词法闭包。但在当时,对 Tracer
使用词法闭包并非旨在与 custom_jvp
/custom_vjp
协同工作。以这种方式实现 nondiff_argnums
是个错误!
PR #4008 修复了 custom_jvp
和 custom_vjp
的所有词法闭包问题。 太棒了!也就是说,现在 custom_jvp
和 custom_vjp
函数和规则可以随意地对 Tracer
进行闭包操作。对于所有非自动微分变换,一切都将正常工作。对于自动微分变换,我们将收到一个清晰的错误消息,说明为什么我们不能对 custom_jvp
或 custom_vjp
闭包的值进行微分。
检测到对 custom_jvp 函数相对于闭包值进行微分。这不支持,因为自定义 JVP 规则仅指定了如何对 custom_jvp 函数相对于显式输入参数进行微分。
尝试将闭包值作为参数传递给 custom_jvp 函数,并调整 custom_jvp 规则。
通过这种方式收紧并强化 custom_jvp
和 custom_vjp
时,我们发现,允许 custom_vjp
在其 nondiff_argnums
中接受 Tracer
将需要大量的簿记工作:我们需要重写用户的 fwd
函数以将值作为残差返回,并重写用户的 bwd
函数以将它们作为正常残差接受(而不是像 nondiff_argnums
那样将它们作为特殊的先行参数接受)。这看起来也许可以管理,直到你考虑到我们必须如何处理任意 pytrees!此外,这种复杂性没有必要:如果用户代码将类数组的不可微分参数像常规参数和残差一样处理,那么一切都已正常工作。(在 #4039 之前,JAX 可能会抱怨自动微分中涉及整数值输入和输出,但 #4039 之后,它们将正常工作!)
与 custom_vjp
不同,让 custom_jvp
处理作为 Tracer
的 nondiff_argnums
参数很容易。因此,这些更新只需针对 custom_vjp
进行。