jax.test_util 模块#

函数列表#

check_grads(f, args, order[, modes, atol, ...])

检查自动微分的梯度与有限差分是否一致。

check_jvp(f, f_jvp, args[, atol, rtol, eps, ...])

检查自动微分的 JVP 与有限差分是否一致。

check_vjp(f, f_vjp, args[, atol, rtol, eps, ...])

检查自动微分的 VJP 与有限差分是否一致。