JAX 直接线性化#
发生了什么?#
我们正在改变 JAX 内部实现自动微分的方式。 以前,grad 是通过一个三阶段过程完成的:JVP、部分评估、转置。 通过此更改,我们将前两个步骤(JVP 和部分评估)捆绑在一起,形成一个新的转换:线性化。
这通常不会改变用户可见的行为。 一些例外情况:
如果在自动微分期间打印出跟踪值,您将看到 LinearizeTracer 而不是 JVPTracer。
可能某些数值会发生变化,原因与程序中的任何扰动都可能稍微改变数值结果的原因相同。
为什么?#
该升级解锁了几个新功能,例如:
涉及 Pallas 风格的可变数组引用的微分;
更简单、更灵活的用户自定义自动微分规则,例如 custom_vjp/jvp;
控制用户定义类型上的自动微分行为。
这个更改破坏了我的代码!#
目前,您仍然可以通过取消设置 use_direct_linearize 配置选项来获得旧的行为:
将 shell 环境变量设置为 falsey 值,例如 JAX_USE_DIRECT_LINEARIZE=0
设置配置选项 jax.config.update('jax_use_direct_linearize', False)
如果您使用 absl 解析标志,则可以传递命令行标志 –jax_use_direct_linearize=false
我们计划在 2025 年 8 月 16 日删除此配置选项。