jax.value_and_grad#

jax.value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())[源]#

创建一个同时评估 fun 及其梯度的函数。

参数:
  • fun (Callable) – 待求导函数。其由 argnums 指定位置的参数应为数组、标量或标准 Python 容器。它应返回一个标量(包括形状为 () 的数组,但不包括形状为 (1,) 等的数组)。

  • argnums (int | Sequence[int]) – 可选,整数或整数序列。指定对哪些位置参数求导(默认值为 0)。

  • has_aux (bool) – 可选,布尔值。指示 fun 是否返回一个对,其中第一个元素被视为待求导数学函数的输出,第二个元素为辅助数据。默认为 False。

  • holomorphic (bool) – 可选,布尔值。指示 fun 是否保证为全纯函数。如果为 True,则输入和输出必须是复数。默认为 False。

  • allow_int (bool) – 可选,布尔值。是否允许对整数值输入求导。整数输入的梯度将具有微不足道的向量空间数据类型 (float0)。默认为 False。

  • reduce_axes (Sequence[AxisName])

返回:

一个与 fun 具有相同参数的函数,它同时评估 fun 及其梯度,并将它们作为一对(一个两元素元组)返回。如果 argnums 是一个整数,则梯度具有与该整数指示的位置参数相同的形状和类型。如果 argnums 是一个整数序列,则梯度是一个值元组,具有与相应参数相同的形状和类型。如果 has_aux 为 True,则返回一个 ((值, 辅助数据), 梯度) 的元组。

返回类型:

Callable[…, tuple[Any, Any]]