Building on JAX#
学习高级 JAX 用法的一个好方法是了解其他库如何使用 JAX,包括它们如何将库集成到其 API 中、它在数学上添加了哪些功能,以及它如何在其他库中用于计算加速。
以下是 JAX 的功能如何用于定义跨多个领域和软件包的加速计算的示例。
梯度计算#
简易的梯度计算是 JAX 的一个关键特性。在 JaxOpt 库中,value 和 grad 直接用于 其源代码中多种优化算法的用户。
同样,上面提到的 Dynamax Optax 配对是梯度能够实现历史上具有挑战性的估计方法的一个例子:使用 Optax 的最大似然期望。
跨多个设备的单核计算加速#
在 JAX 中定义的模型可以被编译,通过 JIT 编译实现单次计算加速。然后,相同的编译代码可以发送到 CPU 设备、GPU 或 TPU 设备以实现额外的加速,通常无需额外的更改。这实现了从开发到生产的平滑 workflow。在 Dynamax 中,线性状态空间模型求解器中计算量大的部分已经进行了 jitted。一个更复杂的例子来自 PyTensor,它动态编译一个 JAX 函数,然后 jits 构造的函数。
使用并行化的单机和多机加速#
JAX 的另一个好处是使用 pmap
和 vmap
函数调用或装饰器并行化计算的简单性。在 Dynamax 中,状态空间模型使用 VMAP 装饰器进行并行化,一个实际用例是多对象跟踪。
将 JAX 代码整合到您或用户的 workflow 中#
JAX 非常可组合,可以以多种方式使用。JAX 可以与独立模式一起使用,用户在其中定义所有自己的计算。然而,其他模式,例如使用建立在 jax 之上的库来提供特定功能。这些可以是定义特定类型模型的库,例如神经网络或状态空间模型或其他模型,或提供特定功能的库,例如优化。以下是每种模式的更具体示例。
直接使用#
Jax 可以直接导入和使用,以“从头开始”构建模型,如本网站所示,例如在 JAX 教程或 使用 JAX 的神经网络中。如果您无法为您的特定挑战找到预构建的代码,或者如果您希望减少代码库中的依赖项数量,这可能是最佳选择。
具有暴露 JAX 的可组合领域特定库#
另一种常见方法是提供预构建功能的软件包,无论是模型定义还是某种类型的计算。这些软件包的组合可以混合和匹配,以实现完整的端到端 workflow,在其中定义模型并估计其参数。
一个例子是 Flax,它简化了神经网络的构建。然后,Flax 通常与 Optax 配对,其中 Flax 定义神经网络架构,而 Optax 提供优化和模型拟合功能。
另一个是 Dynamax,它允许轻松定义状态空间模型。使用 Dynamax,可以使用 使用 Optax 的最大似然法估计参数,或者使用 来自 Blackjax 的 MCMC 估计完整的贝叶斯后验。