jax.custom_vjp.defvjp#
- custom_vjp.defvjp(fwd, bwd, symbolic_zeros=False, optimize_remat=False)[source]#
为此实例表示的函数定义自定义 VJP 规则。
- 参数:
fwd (Callable[..., tuple[ReturnValue, Any]]) – 一个 Python 可调用对象,表示自定义 VJP 规则的前向传递。当没有
nondiff_argnums
时,fwd
函数具有与底层原始函数相同的输入签名。它应该返回一个输出对,其中第一个元素表示原始输出,第二个元素表示要从前向传递存储的任何“残差”值,供函数bwd
在后向传递中使用。输入参数和输出对的元素可以是数组或嵌套的元组/列表/字典。bwd (Callable[..., tuple[Any, ...]]) – 一个 Python 可调用对象,表示自定义 VJP 规则的后向传递。当没有
nondiff_argnums
时,bwd
函数接受两个参数,其中第一个参数是由fwd
在前向传递中生成的“残差”值,第二个参数是与原始函数输出具有相同结构的输出余切。bwd
的输出必须是一个元组,其长度等于原始函数的参数数量,并且元组元素可以是数组或嵌套的元组/列表/字典,以便与原始输入参数的结构匹配。symbolic_zeros (bool) –
布尔值,确定是否向
fwd
和bwd
规则指示符号零。启用此选项允许自定义导数规则检测何时某些输入以及何时某些输出余切不参与微分。如果为True
fwd
必须接受,代替 pytree 中包含原始函数参数的每个叶值x
,一个具有两个属性的对象(类型为jax.custom_derivatives.CustomVJPPrimal
):value
和perturbed
。value
字段是原始原始参数,perturbed
是一个布尔值。perturbed
位指示参数是否参与微分(即,如果为False
,则对应的 Jacobian “列”为零)。bwd
将在其余切参数中传递表示静态符号零的对象,以对应于未扰动的值;否则,仅传递标准 JAX 类型(例如,类数组)。
将此选项设置为
True
允许这些规则检测某些输入和输出是否不参与微分,但代价是需要特殊处理。例如fwd
的签名会更改,并且传递给它的对象不能直接从规则输出。bwd
规则传递的对象并非完全是类数组的,并且不能传递给大多数jax.numpy
函数。原始函数参数中涉及的任何自定义 pytree 节点都必须在其解展平函数中接受作为输入叶子提供给
fwd
规则的双字段记录对象。
默认为
False
。optimize_remat (bool) – 布尔值,一个实验性标志,用于在此函数在
jax.remat()
下使用时启用自动优化。当fwd
规则是不透明调用(例如 Pallas 内核或自定义调用)时,这将最有用。默认为False
。
- 返回:
无。
- 返回类型:
None
示例
>>> @jax.custom_vjp ... def f(x, y): ... return jnp.sin(x) * y ... >>> def f_fwd(x, y): ... return f(x, y), (jnp.cos(x), jnp.sin(x), y) ... >>> def f_bwd(res, g): ... cos_x, sin_x, y = res ... return (cos_x * g * y, sin_x * g) ... >>> f.defvjp(f_fwd, f_bwd)
>>> x = jnp.float32(1.0) >>> y = jnp.float32(2.0) >>> with jnp.printoptions(precision=2): ... print(jax.value_and_grad(f)(x, y)) (Array(1.68, dtype=float32), Array(1.08, dtype=float32))