jax.test_util 模块# 函数列表# check_grads(f, args, order[, modes, atol, ...]) 检查自动微分的梯度与有限差分是否一致。 check_jvp(f, f_jvp, args[, atol, rtol, eps, ...]) 检查自动微分的 JVP 与有限差分是否一致。 check_vjp(f, f_vjp, args[, atol, rtol, eps, ...]) 检查自动微分的 VJP 与有限差分是否一致。