jax.test_util.check_vjp#

jax.test_util.check_vjp(f, f_vjp, args, atol=None, rtol=None, eps=0.0001, err_msg='')[source]#

对照有限差分检查来自自动微分的 VJP。

梯度仅在单个随机选择的方向上进行检查,这确保了即使对于大型输入/输出空间,有限差分计算也不会变得过于昂贵。

参数:
  • f – 要在 f(*args) 处检查的函数。

  • f_vjp – 计算应用于 fjax.vjp 的函数。通常这应该是 functools.partial(jax.jvp, f))

  • args – 参数值的元组。

  • atol – 梯度相等的绝对容差。

  • rtol – 梯度相等的相对容差。

  • eps – 用于有限差分的步长。

  • err_msg – 如果检查失败,则包含的其他错误消息。

引发:

AssertionError – 如果梯度不匹配。