基于 JAX 构建#

学习高级 JAX 用法的一个好方法是了解其他库如何使用 JAX,包括它们如何将库集成到它们的 API 中、它们在数学上增加了哪些功能,以及它们如何在其他库中用于计算加速。

以下是 JAX 的功能如何在众多领域和软件包中用于定义加速计算的示例。

梯度计算#

JaxOpt 库中,值和 grad 直接用于用户在其中多个优化算法的 源代码 中,这使得梯度计算变得容易。这是 JAX 的一个关键特性。

同样,上面提到的 Dynamax Optax 组合也是梯度支持估计方法的示例,这些方法在历史上一直具有挑战性 使用 Optax 进行最大似然期望

跨多个设备在单个核心上的计算加速#

用 JAX 定义的模型随后可以进行编译,通过 JIT 编译实现单次计算加速。然后可以将相同的编译代码发送到 CPU 设备、GPU 或 TPU 设备以获得额外的加速,通常不需要任何其他更改。这使得从开发到生产的工作流程非常顺畅。在 Dynamax 中,线性状态空间模型求解器的计算密集型部分已被 jitted。一个更复杂的例子来自 PyTensor,它动态编译一个 JAX 函数,然后 jits 构造的函数

使用并行化的单机和多机加速#

JAX 的另一个好处是可以使用 pmapvmap 函数调用或装饰器轻松地并行化计算。在 Dynamax 中,状态空间模型使用 VMAP 装饰器 进行并行化,该用例的一个实际示例是多目标跟踪。

将 JAX 代码集成到您或您的用户的工作流程中#

JAX 非常可组合,并且可以用多种方式使用。JAX 可以与独立模式一起使用,用户自己定义所有计算。然而,其他模式,例如使用基于 JAX 构建的库,这些库提供特定的功能。这些库可以是定义特定类型模型的库,例如神经网络或状态空间模型或其他模型,或者提供特定功能的库,例如优化。以下是每种模式的更具体的示例。

直接使用#

可以像本网站上的示例一样,直接导入和使用 JAX 来“从头开始”构建模型,例如在 JAX 101 教程使用 JAX 构建神经网络 中。如果您找不到针对特定挑战的预构建代码,或者您想减少代码库中的依赖项数量,这可能是最佳选择。

具有 JAX 暴露的可组合的领域特定库#

另一种常见方法是使用提供预构建功能的包,无论是模型定义还是某种类型的计算。然后可以将这些包的组合混合搭配,以实现端到端的完整工作流程,在该工作流程中定义模型并估计其参数。

一个例子是 Flax,它简化了神经网络的构建。Flax 通常与 Optax 配对,其中 Flax 定义了神经网络架构,Optax 提供了优化和模型拟合功能。

另一个例子是 Dynamax,它允许轻松定义状态空间模型。使用 Dynamax,可以使用 Optax 进行最大似然估计 来估计参数,或者使用 Blackjax 的 MCMC 来估计完整的贝叶斯后验。

JAX 对用户完全隐藏#

其他库选择完全在它们的模型特定 API 中包装 JAX。一个例子是 PyMC 和 Pytensor,其中用户可能永远不会直接“看到”JAX,而是使用 PyMC 特定 API 包装 JAX 函数