jax.custom_jvp#

class jax.custom_jvp(fun, nondiff_argnums=(), nondiff_argnames=())[源代码]#

设置一个可由 JAX 转换的函数,用于自定义 JVP 规则定义。

此类旨在用作函数装饰器。实例是可调用对象,其行为类似于装饰器所应用的底层函数,除了应用微分变换(例如 jax.jvp()jax.grad())时,在后一种情况下,将使用用户提供的自定义 JVP 规则函数,而不是跟踪并对底层函数的实现执行自动微分。

有两种实例方法可用于定义自定义 JVP 规则:defjvp() 用于为函数的所有输入定义单个自定义 JVP 规则,以及为方便起见提供的 defjvps(),它封装了 defjvp(),并允许您为函数关于其每个参数的偏导数提供单独的定义。

例如

@jax.custom_jvp
def f(x, y):
  return jnp.sin(x) * y

@f.defjvp
def f_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = f(x, y)
  tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
  return primal_out, tangent_out

有关更详细的介绍,请参阅教程

参数:
  • fun (Callable[..., ReturnValue])

  • nondiff_argnums (Sequence[int])

  • nondiff_argnames (Sequence[str])

__init__(fun, nondiff_argnums=(), nondiff_argnames=())[源代码]#
参数:
  • fun (Callable[..., ReturnValue])

  • nondiff_argnums (Sequence[int])

  • nondiff_argnames (Sequence[str])

方法

__init__(fun[, nondiff_argnums, ...])

defjvp(jvp[, symbolic_zeros])

为此实例表示的函数定义自定义 JVP 规则。

defjvps(*jvps)

方便地为每个参数分别定义 JVP 的封装器。

属性

jvp

symbolic_zeros

fun

nondiff_argnums

nondiff_argnames