基于 JAX 构建#
学习 JAX 高级用法的一个好方法是了解其他库如何使用 JAX,包括它们如何将 JAX 集成到其 API 中,它在数学上增加了哪些功能,以及它在其他库中如何用于计算加速。
以下是 JAX 功能如何用于在众多领域和软件包中定义加速计算的示例。
梯度计算#
简单的梯度计算是 JAX 的一个关键特性。在 JaxOpt 库中,值和梯度在 其源代码中的多个优化算法中直接供用户使用。
类似地,上述 Dynamax Optax 配对是梯度使历史上具有挑战性的估计方法成为可能的一个示例,例如 使用 Optax 进行最大似然期望。
跨多个设备的单核计算加速#
在 JAX 中定义的模型可以通过 JIT 编译来加速单次计算。相同的编译代码可以发送到 CPU 设备、GPU 或 TPU 设备以获得额外加速,通常无需进行额外更改。这使得从开发到生产的工作流程变得顺畅。在 Dynamax 中,线性状态空间模型求解器中计算开销大的部分已被 JIT 编译。一个更复杂的例子来自 PyTensor,它动态编译 JAX 函数,然后 对构建的函数进行 JIT 编译。
使用并行化实现单机和多机加速#
JAX 的另一个优点是使用 pmap
和 vmap
函数调用或装饰器进行并行计算的简便性。在 Dynamax 中,状态空间模型通过 VMAP 装饰器实现并行化,多目标跟踪是此用例的一个实际示例。
将 JAX 代码整合到您或您用户的工作流中#
JAX 具有高度可组合性,可以通过多种方式使用。JAX 可以以独立模式使用,即用户自行定义所有计算。然而,也有其他模式,例如使用基于 JAX 构建的、提供特定功能的库。这些库可以定义特定类型的模型,例如神经网络或状态空间模型等,或者提供特定功能,例如优化。以下是每种模式的更具体示例。
直接使用#
JAX 可以直接导入并用于“从零开始”构建模型,如本网站所示,例如在 JAX 教程或 使用 JAX 构建神经网络中。如果您无法为您的特定挑战找到预构建代码,或者如果您希望减少代码库中的依赖项数量,这可能是最佳选择。
可组合的特定领域库(JAX 可见)#
另一种常见方法是使用提供预构建功能的软件包,无论是模型定义还是某种类型的计算。这些软件包的组合可以混合搭配,以实现一个完整的端到端工作流,其中定义模型并估计其参数。
一个例子是 Flax,它简化了神经网络的构建。Flax 通常与 Optax 配对使用,其中 Flax 定义神经网络架构,而 Optax 提供优化和模型拟合功能。
另一个是 Dynamax,它允许轻松定义状态空间模型。使用 Dynamax,参数可以通过 使用 Optax 进行最大似然估计来估计,或者使用 Blackjax 的 MCMC 来估计完整的贝叶斯后验。