jax.grad#
- jax.grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())[source]#
创建一个用于计算
fun梯度的函数。- 参数:
fun (Callable) – 待求导函数。其位于
argnums指定位置的参数应为数组、标量或标准 Python 容器。位于argnums指定位置的参数数组必须为非精确类型(即浮点数或复数类型)。它应返回一个标量(包括形状为()的数组,但不包括形状为(1,)等的数组)。argnums (int | Sequence[int]) – 可选参数,整数或整数序列。指定对哪些位置参数求导(默认为 0)。
has_aux (bool) – 可选,布尔值。指示
fun是否返回一个对,其中第一个元素被视为要微分的数学函数的输出,第二个元素是辅助数据。默认为 False。holomorphic (bool) – 可选参数,布尔值。指示是否保证
fun是全纯函数(holomorphic)。如果为 True,则输入和输出必须为复数。默认为 False。allow_int (bool) – 可选,布尔值。是否允许对整数值输入进行微分。整数输入的梯度将具有微弱的向量空间数据类型(float0)。默认为 False。
reduce_axes (Sequence[AxisName])
- 返回:
返回一个与
fun具有相同参数的函数,用于计算fun的梯度。如果argnums是一个整数,则梯度具有与该整数所指位置参数相同的形状和类型。如果argnums是一个整数元组,则梯度是一个值元组,其形状和类型与对应的参数相同。如果has_aux为 True,则返回 (梯度, 辅助数据) 对。- 返回类型:
例如
>>> import jax >>> >>> grad_tanh = jax.grad(jax.numpy.tanh) >>> print(grad_tanh(0.2)) 0.961043