jax.custom_vjp#
- 类 jax.custom_vjp(fun, nondiff_argnums=(), nondiff_argnames=())[源代码]#
设置一个可供 JAX 转换的函数,用于自定义 VJP 规则定义。
此类旨在用作函数装饰器。实例是可调用对象,其行为与应用装饰器的底层函数类似,但当应用反向模式微分变换(例如
jax.grad()
)时,将使用用户提供的自定义 VJP 规则函数,而不是跟踪并对底层函数的实现执行自动微分。有一个实例方法,defvjp()
,可用于定义自定义 VJP 规则。此装饰器排除了使用前向模式自动微分。
例如
@jax.custom_vjp def f(x, y): return jnp.sin(x) * y def f_fwd(x, y): return f(x, y), (jnp.cos(x), jnp.sin(x), y) def f_bwd(res, g): cos_x, sin_x, y = res return (cos_x * g * y, sin_x * g) f.defvjp(f_fwd, f_bwd)
有关更详细的介绍,请参阅教程。
方法
__init__
(fun[, nondiff_argnums, ...])defvjp
(fwd, bwd[, symbolic_zeros, ...])为此实例表示的函数定义自定义 VJP 规则。