JAX 直接线性化#

发生了什么?#

我们正在改变 JAX 内部实现自动微分的方式。 以前,grad 是通过一个三阶段过程完成的:JVP、部分评估、转置。 通过此更改,我们将前两个步骤(JVP 和部分评估)捆绑在一起,形成一个新的转换:线性化。

这通常不会改变用户可见的行为。 一些例外情况:

  • 如果在自动微分期间打印出跟踪值,您将看到 LinearizeTracer 而不是 JVPTracer。

  • 可能某些数值会发生变化,原因与程序中的任何扰动都可能稍微改变数值结果的原因相同。

为什么?#

该升级解锁了几个新功能,例如:

  • 涉及 Pallas 风格的可变数组引用的微分;

  • 更简单、更灵活的用户自定义自动微分规则,例如 custom_vjp/jvp;

  • 控制用户定义类型上的自动微分行为。

这个更改破坏了我的代码!#

目前,您仍然可以通过取消设置 use_direct_linearize 配置选项来获得旧的行为:

  • 将 shell 环境变量设置为 falsey 值,例如 JAX_USE_DIRECT_LINEARIZE=0

  • 设置配置选项 jax.config.update('jax_use_direct_linearize', False)

  • 如果您使用 absl 解析标志,则可以传递命令行标志 –jax_use_direct_linearize=false

我们计划在 2025 年 8 月 16 日删除此配置选项。