custom_vjp 和 nondiff_argnums 更新指南#
mattjj@ 2020 年 10 月 14 日
本文档假定您已熟悉 jax.custom_vjp,该内容在 JAX 可转换 Python 函数的自定义导数规则 notebook 中进行了描述。
需要更新的内容#
在 JAX PR #4008 之后,传递给 custom_vjp 函数的 nondiff_argnums 的参数不能是 Tracer(或 Tracer 的容器),这基本上意味着为了允许任意转换的代码,nondiff_argnums 不应该用于数组值参数。相反,nondiff_argnums 应该只用于非数组值,例如 Python 可调用对象、形状元组或字符串。
以前我们对数组值使用 nondiff_argnums 的地方,现在应该将它们作为常规参数传递。在 bwd 规则中,我们需要为它们生成值,但我们可以只生成 None 值,以表示没有相应的梯度值。
例如,这是编写 clip_gradient 的旧方法,当 hi 和/或 lo 是来自某些 JAX 转换的 Tracer 时,这种方法将不起作用。
from functools import partial
import jax
@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def clip_gradient(lo, hi, x):
return x # identity function
def clip_gradient_fwd(lo, hi, x):
return x, None # no residual values to save
def clip_gradient_bwd(lo, hi, _, g):
return (jnp.clip(g, lo, hi),)
clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
这是新的、出色的方法,它支持任意转换
import jax
@jax.custom_vjp # no nondiff_argnums!
def clip_gradient(lo, hi, x):
return x # identity function
def clip_gradient_fwd(lo, hi, x):
return x, (lo, hi) # save lo and hi values as residuals
def clip_gradient_bwd(res, g):
lo, hi = res
return (None, None, jnp.clip(g, lo, hi)) # return None for lo and hi
clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
如果您使用旧方法而不是新方法,在任何可能出错的情况下(即当 Tracer 被传递到 nondiff_argnums 参数中时),您将收到一个明确的错误。
这是一个我们需要使用 nondiff_argnums 与 custom_vjp 的情况
from functools import partial
import jax
@partial(jax.custom_vjp, nondiff_argnums=(0,))
def skip_app(f, x):
return f(x)
def skip_app_fwd(f, x):
return skip_app(f, x), None
def skip_app_bwd(f, _, g):
return (g,)
skip_app.defvjp(skip_app_fwd, skip_app_bwd)
说明#
将 Tracer 传递到 nondiff_argnums 参数中一直存在 bug。虽然有些情况可以正确工作,但其他情况会导致复杂且令人困惑的错误消息。
bug 的本质是 nondiff_argnums 的实现方式非常类似于词法闭包。但当时词法闭包对 Tracer 的支持并不是为了与 custom_jvp/custom_vjp 一起使用。以这种方式实现 nondiff_argnums 是一个错误!
PR #4008 修复了 custom_jvp 和 custom_vjp 的所有词法闭包问题。 太棒了!也就是说,现在 custom_jvp 和 custom_vjp 函数和规则可以随意闭包 Tracer。对于所有非自动微分转换,一切都会正常工作。对于自动微分转换,我们将收到一条明确的错误消息,说明为什么我们无法对 custom_jvp 或 custom_vjp 闭包的值进行微分。
检测到对 custom_jvp 函数相对于闭包值进行微分。这不受支持,因为 custom JVP 规则仅指定如何相对于显式输入参数微分 custom_jvp 函数。
尝试将闭包值作为参数传递给 custom_jvp 函数,并调整 custom_jvp 规则。
通过这种方式来完善和加固 custom_jvp 和 custom_vjp,我们发现允许 custom_vjp 接受其 nondiff_argnums 中的 Tracer 需要大量的簿记工作:我们需要重写用户的 fwd 函数以将值作为残差返回,并重写用户的 bwd 函数以接受它们作为正常残差(而不是像 nondiff_argnums 那样作为特殊的前导参数接受)。这似乎是可行的,直到您考虑到我们必须如何处理任意的 Pytrees!此外,这种复杂性并非必需:如果用户代码将类数组的非微分参数视为常规参数和残差,那么一切都会正常工作。(在 #4039 之前,JAX 可能会抱怨将整数值输入和输出涉及自动微分,但在 #4039 之后,这些将正常工作!)
与 custom_vjp 不同,让 custom_jvp 与是 Tracer 的 nondiff_argnums 参数一起工作很容易。因此,这些更新只需要在 custom_vjp 上进行。