jax.test_util.check_vjp#
- jax.test_util.check_vjp(f, f_vjp, args, atol=None, rtol=None, eps=0.0001, err_msg='')[源代码]#
使用有限差分法检查自动微分的 VJP。
梯度仅在单个随机选择的方向上检查,这确保了即使对于大型输入/输出空间,有限差分计算也不会变得过于昂贵。
- 参数:
f – 要在
f(*args)处检查的函数。f_vjp – 计算应用于
f的jax.vjp的函数。通常这应该是functools.partial(jax.jvp, f))。args – 参数值的元组。
atol – 梯度相等的绝对容差。
rtol – 梯度相等的相对容差。
eps – 用于有限差分的步长。
err_msg – 如果检查失败,则包含的附加错误消息。
- 引发:
AssertionError – 如果梯度不匹配。