jax.value_and_grad#
- jax.value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())[源代码]#
创建一个同时计算
fun和fun的梯度的函数。- 参数:
fun (Callable) – 要微分的函数。其在
argnums指定位置的参数应该是数组、标量或标准的 Python 容器。它应该返回一个标量(包括形状为()的数组,但不包括形状为(1,)等的数组)。argnums (int | Sequence[int]) – 可选,整数或整数序列。指定要微分的位置参数(默认为 0)。
has_aux (bool) – 可选,布尔值。指示
fun是否返回一个对,其中第一个元素被视为要微分的数学函数的输出,第二个元素是辅助数据。默认为 False。holomorphic (bool) – 可选,布尔值。指示
fun是否保证是全纯的。如果为 True,输入和输出必须是复数。默认为 False。allow_int (bool) – 可选,布尔值。是否允许对整数值输入进行微分。整数输入的梯度将具有微弱的向量空间数据类型(float0)。默认为 False。
reduce_axes (Sequence[AxisName])
- 返回:
一个与
fun具有相同参数的函数,该函数计算fun和fun的梯度,并将它们作为一对(一个二元组)返回。如果argnums是一个整数,则梯度与由该整数指定的位置参数具有相同的形状和类型。如果 argnums 是一个整数序列,则梯度是与相应参数具有相同形状和类型的值的元组。如果has_aux为 True,则返回一个元组 ((value, auxiliary_data), gradient)。- 返回类型:
Callable[…, tuple[Any, Any]]