jax.lax.stop_gradient#
- jax.lax.stop_gradient(x)[source]#
停止梯度计算。
在操作上,
stop_gradient
是恒等函数,也就是说,它返回参数 x 不变。但是,stop_gradient
阻止梯度在正向或反向模式自动微分期间流动。如果存在多个嵌套的梯度计算,stop_gradient
会停止所有这些计算的梯度。有关这在何处有用的更多讨论,请参阅 停止梯度。- 参数:
x (T) – 数组或数组的 pytree
- 返回:
输入值保持不变返回,但在自动微分中将被视为常量。
- 返回类型:
T
示例
考虑一个简单的函数,它返回输入值的平方
>>> def f1(x): ... return x ** 2 >>> x = jnp.float32(3.0) >>> f1(x) Array(9.0, dtype=float32) >>> jax.grad(f1)(x) Array(6.0, dtype=float32)
在
x
周围使用stop_gradient
的相同函数在正常求值下是等效的,但会返回零梯度,因为x
被有效视为常量>>> def f2(x): ... return jax.lax.stop_gradient(x) ** 2 >>> f2(x) Array(9.0, dtype=float32) >>> jax.grad(f2)(x) Array(0.0, dtype=float32)
这在 JAX 代码库中的许多地方使用;例如
jax.nn.softmax()
在内部通过其最大值对输入进行归一化,并且为了效率,此最大值被包装在stop_gradient
中。有关stop_gradient
适用性的更多讨论,请参阅 停止梯度。