jax.test_util.check_grads#
- jax.test_util.check_grads(f, args, order, modes=('fwd', 'rev'), atol=None, rtol=None, eps=None)[source]#
对照有限差分检查来自自动微分的梯度。
梯度仅在单个随机选择的方向上进行检查,这确保了即使对于大型输入/输出空间,有限差分计算也不会变得过于昂贵。
- 参数:
f – 要在
f(*args)
处检查的函数。args – 参数值的元组。
order – 检查到此阶的前向和后向梯度。
modes – 要检查的梯度模式列表(‘fwd’ 和/或 ‘rev’)。
atol – 梯度相等的绝对容差。
rtol – 梯度相等的相对容差。
eps – 用于有限差分的步长。
- Raises:
AssertionError – 如果梯度不匹配。