jax.jacrev#
- jax.jacrev(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False)[source]#
使用反向模式 AD 逐行计算
fun
的 Jacobian。- 参数:
fun (Callable) – 要计算 Jacobian 的函数。
argnums (int | Sequence[int]) – 可选,整数或整数序列。指定要对其求导的位置参数(默认为
0
)。has_aux (bool) – 可选,布尔值。指示
fun
是否返回一对值,其中第一个元素被认为是数学函数的输出,第二个元素是辅助数据。默认为 False。holomorphic (bool) – 可选,布尔值。指示
fun
是否保证是全纯的。默认为 False。allow_int (bool) – 可选,布尔值。是否允许对整数值输入进行微分。整数输入的梯度将具有一个微不足道的向量空间 dtype (float0)。默认为 False。
- 返回:
一个与
fun
具有相同参数的函数,它使用反向模式自动微分来评估fun
的 Jacobian。如果has_aux
为 True,则返回一对 (jacobian, auxiliary_data)。- 返回类型:
Callable
>>> import jax >>> import jax.numpy as jnp >>> >>> def f(x): ... return jnp.asarray( ... [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])]) ... >>> print(jax.jacrev(f)(jnp.array([1., 2., 3.]))) [[ 1. 0. 0. ] [ 0. 0. 5. ] [ 0. 16. -2. ] [ 1.6209 0. 0.84147]]