jax.vjp#
- jax.vjp(fun: Callable[..., T], *primals: Any, has_aux: Literal[False] = False, reduce_axes: Sequence[AxisName] = ()) tuple[T, Callable] [源代码]#
- jax.vjp(fun: Callable[..., tuple[T, U]], *primals: Any, has_aux: Literal[True], reduce_axes: Sequence[AxisName] = ()) tuple[T, Callable, U]
计算函数
fun
的(反向模式)向量-雅可比积。- 参数:
fun – 待求导的函数。其参数应为数组、标量,或包含数组或标量的标准 Python 容器。它应返回一个数组、标量,或包含数组或标量的标准 Python 容器。
primals – 用于评估函数
fun
雅可比的原始值序列。primals
的数量应等于fun
的位置参数数量。每个原始值应为数组、标量或它们的 PyTree 结构(标准 Python 容器)。has_aux – 可选,布尔值。指示函数
fun
是否返回一个对,其中第一个元素被认为是待求导数学函数的输出,第二个元素是辅助数据。默认值为 `False`。
- 返回:
如果
has_aux
是False
,则返回一个(primals_out, vjpfun)
对,其中primals_out
是fun(*primals)
。如果has_aux
是True
,则返回一个(primals_out, vjpfun, aux)
元组,其中aux
是函数fun
返回的辅助数据。vjpfun
是一个函数,它将与primals_out
具有相同形状的余切向量映射为与primals
具有相同数量和形状的余切向量元组,表示在primals
处评估的函数fun
的向量-雅可比积。
>>> import jax >>> >>> def f(x, y): ... return jax.numpy.sin(x), jax.numpy.cos(y) ... >>> primals, f_vjp = jax.vjp(f, 0.5, 1.0) >>> xbar, ybar = f_vjp((-0.7, 0.3)) >>> print(xbar) -0.61430776 >>> print(ybar) -0.2524413