JEP 18137: JAX NumPy 和 SciPy 封装器的范围#

Jake VanderPlas

2023 年 10 月

到目前为止,jax.numpyjax.scipy 的预期范围相对而言定义不清。本文档为这些包提出了明确定义的范围,以便更好地指导和评估未来的贡献,并推动移除一些超出范围的代码。

背景#

从一开始,JAX 就旨在为在 XLA 中执行代码提供类似 NumPy 的 API,并且该项目开发的一大重心是构建 jax.numpyjax.scipy 命名空间,作为 NumPy 和 SciPy API 基于 JAX 的实现。一直以来都存在一种隐式的理解,即 numpyscipy 的某些部分超出了 JAX 的范围,但这个范围并没有明确定义。这可能会给贡献者带来困惑和挫败感,因为对于潜在的 jax.numpyjax.scipy 贡献是否会被 JAX 接受,并没有明确的答案。

为何限制范围?#

为了避免不言而喻,我们应该明确指出:事实是,任何包含在像 JAX 这样的项目中的代码都会给开发者带来虽小但非零的持续维护负担。一个项目随着时间的推移取得成功,直接关系到维护者为项目所有部分持续进行维护的能力:编写功能文档、回复问题、修复错误等等。为了任何软件工具的长期成功和可持续性,至关重要的是,维护者要仔细权衡任何特定的贡献是否会对项目的目标和资源产生净正面影响。

评估标准#

本文档提出了一个包含六个轴的标准,可以根据这些轴来判断任何特定的 numpyscipy API 是否应包含在 JAX 中。在所有轴上都很强的 API 是包含在 JAX 包中的绝佳候选;在六个轴中的任何一个轴上存在明显的弱点都是反对包含在 JAX 中的有力论据。

轴 1:XLA 对齐#

我们考虑的第一个轴是提议的 API 与原生 XLA 操作对齐的程度。例如,jax.numpy.exp() 是一个函数,它或多或少直接镜像了 jax.lax.expnumpyscipy.specialnumpy.linalgscipy.linalg 和其他模块中的大量函数都符合此标准:当考虑将这些函数包含到 JAX 中时,它们通过了 XLA 对齐检查。

另一方面,存在像 numpy.unique() 这样的函数,它们不直接对应于任何 XLA 操作,并且在某些情况下与 JAX 当前的计算模型(需要静态形状数组)从根本上不兼容(例如,unique 返回一个值相关的动态数组形状)。当考虑将这些函数包含到 JAX 中时,它们未通过 XLA 对齐检查。

作为此轴的一部分,我们还考虑了对纯函数语义的需求。例如,numpy.random 构建于隐式更新的基于状态的 RNG 之上,这与基于 XLA 构建的 JAX 计算模型从根本上不兼容。

轴 2:Array API 对齐#

我们考虑的第二个轴侧重于 Python Array API 标准:从某种意义上说,这是一个社区驱动的概述,其中列出了跨广泛用户社区的面向数组编程的核心数组操作。numpyscipy 中的 API 如果列在 Array API 标准中,则强烈表明 JAX 应该包含它。以上面的示例为例,Array API 标准包含 numpy.unique() 的几种变体(unique_allunique_countsunique_inverseunique_values),这表明,尽管该函数与 XLA 的精确对齐性不高,但它对于 Python 用户社区来说足够重要,JAX 或许应该实现它。

轴 3:下游实现的存在性#

对于不符合轴 1 或轴 2 的功能,包含到 JAX 中的一个重要考虑因素是,是否存在良好支持的下游软件包提供所讨论的功能。scipy.optimize 就是一个很好的例子:虽然 JAX 确实包含 scipy.optimize 功能的最小包装器集,但在 JAXopt 包中存在更完整的处理方案,该包由 JAX 协作者积极维护。在这种情况下,我们应该倾向于引导用户和贡献者使用这些专门的软件包,而不是在 JAX 本身中重新实现这些 API。

轴 4:实现的复杂性和鲁棒性#

对于不符合 XLA 的功能,一个考虑因素是提议实现的复杂程度。这在一定程度上与轴 1 对齐,但仍然需要强调出来。许多已贡献给 JAX 的函数具有相对复杂的实现,这些实现难以验证并带来了过大的维护负担;一个例子是 jax.scipy.special.bessel_jn():在撰写此 JEP 时,其当前的实现是一种非直接的迭代近似,在某些领域存在收敛问题,并且提议的修复引入了进一步的复杂性。如果我们在接受贡献时更仔细地权衡实现的复杂性和鲁棒性,我们可能会选择不接受此贡献到包中。

轴 5:函数式与面向对象 API#

JAX 最适合函数式 API,而不是面向对象 API。面向对象 API 通常会隐藏不纯粹的语义,这使得它们通常难以良好地实现。NumPy 和 SciPy 通常坚持使用函数式 API,但有时会提供面向对象的便捷包装器。

这方面的示例包括 numpy.polynomial.Polynomial,它包装了较低级别的操作,例如 numpy.polyadd()numpy.polydiv() 等。一般来说,当函数式和面向对象 API 都可用时,JAX 应避免为面向对象 API 提供包装器,而应为函数式 API 提供包装器。

在仅存在面向对象 API 的情况下,除非在其他轴上表现强势,否则 JAX 应避免提供包装器。

轴 6:JAX 用户和利益相关者的总体“重要性”#

在 JAX 中包含 NumPy/SciPy API 的决定也应考虑到该算法对一般用户社区的重要性。诚然,量化谁是“利益相关者”以及如何衡量这种重要性是困难的;但我们包含这一点是为了明确,关于在 JAX 的 NumPy 和 SciPy 封装器中包含哪些内容的任何决定都将涉及一定程度的酌情权,而这种酌情权无法轻易量化。

对于现有的 API,在 github 中搜索使用情况可能有助于确定其重要性或缺乏重要性;例如,我们可以回到上面讨论的 jax.scipy.special.bessel_jn():搜索显示该函数在 github 上只有少数几个用途,这可能部分与之前提到的准确性问题有关。

评估:哪些在范围内?#

在本节中,我们将尝试根据上述标准评估 NumPy 和 SciPy API,包括当前 JAX API 中的一些示例。这不会是所有现有函数和类的完整列表,而是在子模块和主题方面进行更一般的讨论,并附带相关示例。

NumPy API#

numpy 命名空间#

我们认为,由于其与 XLA(轴 1)和 Python Array API(轴 2)的总体对齐,以及其对 JAX 用户社区的总体重要性(轴 6),主 numpy 命名空间中的函数基本上都在 JAX 的范围内。某些函数可能处于边缘(像 numpy.intersect1d()np.setdiff1d()np.union1d() 这样的函数可以说在某些方面不符合标准),但为了简单起见,我们声明主 numpy 命名空间中的所有数组函数都在 JAX 的范围内。

numpy.linalgnumpy.fft#

numpy.linalgnumpy.fft 子模块包含许多与 XLA 提供的功能广泛对齐的函数。其他函数具有复杂的设备特定降低,但代表了利益相关者的重要性(轴 6)超过复杂性的情况。因此,我们认为这两个子模块都在 JAX 的范围内。

numpy.random#

numpy.random 超出了 JAX 的范围,因为基于状态的 RNG 与 JAX 的计算模型从根本上不兼容。我们转而关注 jax.random,它使用基于计数器的 PRNG 提供类似的功能。

numpy.manumpy.polynomial#

numpy.manumpy.polynomial 子模块主要关注为可以通过其他函数方式表达的计算提供面向对象的接口(轴 5);因此,我们认为它们超出了 JAX 的范围。

numpy.testing#

NumPy 的测试功能实际上只对主机端计算有意义,因此我们在 JAX 中不包含任何用于它的包装器。也就是说,JAX 数组与 numpy.testing 兼容,并且 JAX 在整个 JAX 测试套件中频繁使用它。

SciPy API#

SciPy 在顶级命名空间中没有函数,但包含许多子模块。我们在下面分别考虑每个模块,并省略已弃用的模块。

scipy.cluster#

scipy.cluster 模块包括用于层次聚类、k 均值和相关算法的工具。这些在多个轴上都很弱,最好由下游软件包提供服务。JAX 中已经存在一个函数 (jax.scipy.cluster.vq.vq()),但在 github 上没有明显的用法:这表明聚类对于 JAX 用户来说并不普遍重要。

建议:弃用并删除 jax.scipy.cluster.vq()

scipy.constants#

scipy.constants 模块包含数学和物理常数。这些常数可以直接与 JAX 一起使用,因此没有理由在 JAX 中重新实现它。

scipy.datasets#

scipy.datasets 模块包含用于获取和加载数据集的工具。这些获取的数据集可以直接与 JAX 一起使用,因此没有理由在 JAX 中重新实现它。

scipy.fft#

scipy.fft 模块包含的函数与 XLA 提供的功能广泛一致,并且在其他方面也表现良好。 因此,我们认为它们在 JAX 的范围内。

scipy.integrate#

scipy.integrate 模块包含用于数值积分的函数。 其中更复杂的一些函数(quaddblquadode)因轴 1 和 4 的原因不在 JAX 的范围内,因为它们往往是基于动态评估次数的循环算法。jax.experimental.ode.odeint() 与之相关,但相当有限,并且没有任何积极的开发。

JAX 目前确实包含 jax.scipy.integrate.trapezoid(),但这仅仅是因为最近 numpy.trapz() 已被弃用,转而支持它。 对于任何特定的输入,它的实现都可以用一行 jax.numpy 表达式替换,因此它不是一个特别有用的 API。

基于轴 1、2、4 和 6,scipy.integrate 应被视为超出 JAX 的范围。

建议:删除 JAX 0.4.14 版本中添加的 jax.scipy.integrate.trapezoid()

scipy.interpolate#

scipy.interpolate 模块为一维或多维插值提供了底层和面向对象的例程。 这些 API 在上述许多轴上的评分都很差:它们是基于类的而不是底层的,并且除了最简单的方法外,没有一种方法可以用 XLA 操作有效地表达。

JAX 目前确实有 scipy.interpolate.RegularGridInterpolator 的包装器。 如果我们今天考虑这个贡献,我们可能会根据上述标准拒绝它。 但是这段代码相当稳定,因此继续维护它并没有太大的缺点。

展望未来,我们应将 scipy.interpolate 的其他成员视为超出 JAX 的范围。

scipy.io#

scipy.io 子模块与文件输入/输出有关。 没有理由在 JAX 中重新实现它。

scipy.linalg#

scipy.linalg 子模块包含的函数与 XLA 提供的功能广泛一致,并且快速线性代数对于 JAX 用户社区非常重要。 因此,我们认为它在 JAX 的范围内。

scipy.ndimage#

scipy.ndimage 子模块包含一组用于处理图像数据的工具。 其中许多工具与 scipy.signal 中的工具重叠(例如,卷积和滤波)。 JAX 目前在 jax.scipy.ndimage.map_coordinates() 中提供了一个 scipy.ndimage API。 此外,JAX 在 jax.image 模块中提供了一些与图像相关的工具。 DeepMind 生态系统包括 dm-pix,这是一套更完善的 JAX 图像操作工具。 考虑到所有这些因素,我建议 scipy.ndimage 应被视为超出 JAX 核心的范围; 我们可以将感兴趣的用户和贡献者指向 dm-pix。 我们可以考虑将 map_coordinates 移至 dm-pix 或另一个合适的软件包。

scipy.odr#

scipy.odr 模块围绕 ODRPACK 提供了一个面向对象的包装器,用于执行正交距离回归。 目前尚不清楚这是否可以使用现有的 JAX 原语干净地表达,因此我们认为它超出了 JAX 本身的范围。

scipy.optimize#

scipy.optimize 模块为优化提供了高级和低级接口。 这种功能对于许多 JAX 用户来说非常重要,并且 JAX 很早就在 jax.scipy.optimize 中创建了包装器。 然而,这些例程的开发人员很快意识到 scipy.optimize API 过于受限,不同的团队开始开发 JAXopt 软件包和 Optimistix 软件包,每个软件包都包含一套更全面且经过更好测试的 JAX 优化例程。

由于这些得到良好支持的外部软件包,我们现在认为 scipy.optimize 超出了 JAX 的范围。

建议:弃用 jax.scipy.optimize 和/或使其成为 JAXopt 或 Optimistix 可选依赖项的轻量级包装器。

🟡 scipy.signal#

scipy.signal 模块是混合的:某些函数完全在 JAX 的范围内(例如 correlateconvolve,它们是 lax.conv_general_dilated 的更用户友好的包装器),而许多其他函数则完全超出范围(特定领域的工具,没有可行的降级到 XLA 的路径)。 对 jax.scipy.signal 的潜在贡献将必须根据具体情况进行权衡。

🟡 scipy.sparse#

scipy.sparse 子模块主要包含用于以各种格式存储和操作稀疏矩阵和数组的数据结构。 此外,scipy.sparse.linalg 包含许多无矩阵求解器,适用于稀疏矩阵、稠密矩阵和线性算子。

scipy.sparse 数组和矩阵数据结构超出了 JAX 的范围,因为它们与 JAX 的计算模型不一致(例如,许多操作依赖于动态大小的缓冲区)。 JAX 开发了 jax.experimental.sparse 模块,作为一组替代数据结构,这些数据结构更符合 JAX 的计算约束。 出于这些原因,我们认为 scipy.sparse 中的数据结构超出了 JAX 的范围。

另一方面,scipy.sparse.linalg 已被证明是一个有趣的领域,并且 jax.scipy.sparse.linalg 包括 bicgstabcggmres 求解器。 这些对 JAX 用户社区(轴 6)很有用,但除此之外,在其他轴上表现不佳。 它们非常适合移动到下游库中; 一个可能的选择可能是 Lineax,它具有许多基于 JAX 构建的线性求解器。

建议:探索将稀疏求解器移动到 Lineax 中,否则将 `scipy.sparse` 视为超出 JAX 的范围。

scipy.spatial#

scipy.spatial 模块主要包含用于空间/距离计算和最近邻搜索的面向对象的接口。 它在很大程度上超出了 JAX 的范围

scipy.spatial.transform 子模块提供了用于操作三维空间旋转的工具。 这是一个相对复杂的面向对象的接口,也许可以通过下游项目更好地服务。 JAX 目前在 jax.scipy.spatial.transform 中包含 RotationSlerp 的部分实现; 这些是基本函数的面向对象包装器,它们引入了非常大的 API 表面,并且用户非常少。 我们判断它们超出了 JAX 本身的范围,用户最好通过假设的下游项目获得服务。

scipy.spatial.distance 子模块包含有用的距离度量集合,并且可能很想为这些提供 JAX 包装器。 也就是说,使用 jit 和 vmap,用户可以轻松地从头开始定义大多数这些度量的高效版本(如果需要),因此将它们添加到 JAX 并不是特别有利。

建议:考虑弃用和删除 RotationSlerp API,并将 scipy.spatial 整体视为超出未来 JAX 贡献的范围。

scipy.special#

scipy.special 模块包括许多更专业函数的实现。 在许多情况下,这些函数完全在范围内:例如,gammalnbetaincdigamma 等许多函数直接对应于可用的 XLA 原语,并且根据轴 1 和其他轴,它们显然在范围内。

其他函数需要更复杂的实现; 上面提到的一个例子是 bessel_jn。 尽管与轴 1 和 2 不一致,但这些函数在轴 6 上往往非常强大:scipy.special 提供了在各种领域中进行计算所需的基本函数,因此即使是实现复杂的函数也应倾向于在范围内,只要实现设计良好且稳健。

我们应该仔细研究一些现有的函数包装器; 例如

  • jax.scipy.special.lpmn():这通过复杂的 fori_loop 生成勒让德多项式,其方式与 scipy API 不匹配(例如,对于 scipyz 必须是标量,而对于 JAX,z 必须是一维数组)。 该函数的可发现用途很少,使其成为轴 1、2、4 和 6 的薄弱候选者。

  • jax.scipy.special.lpmn_values():这与上面的 lmpn 有类似的缺点。

  • jax.scipy.special.sph_harm():这是基于 lpmn 构建的,并且类似地具有与相应的 scipy 函数不同的 API。

  • jax.scipy.special.bessel_jn():正如上面轴 4 中讨论的那样,这在实现鲁棒性方面存在弱点,并且使用率很低。 我们可能会考虑用新的、更稳健的实现来替换它(例如 #17038)。

建议:重构并提高 bessel_jn 的鲁棒性和测试覆盖率。 如果无法修改 lpmnlpmn_valuessph_harm 以更紧密地匹配 scipy API,请考虑弃用它们。

scipy.stats#

scipy.stats 模块包含范围广泛的统计函数,包括离散和连续分布、摘要统计和假设检验。 JAX 目前在 jax.scipy.stats 中包装了许多这些函数,主要包括大约 20 个统计分布,以及其他一些函数(moderankdatagaussian_kde)。 总的来说,这些与 JAX 非常一致:分布通常可以用高效的 XLA 操作来表达,并且 API 清晰且功能齐全。

我们目前没有任何假设检验函数的包装器,可能是因为这些函数对于 JAX 的主要用户群不太有用。

关于分布,在某些情况下,tensorflow_probability 提供了类似的功能,将来我们可能会考虑是否弃用 scipy.stats 分布以支持该实现。

建议:展望未来,我们应将统计分布和摘要统计视为在范围内,并将假设检验和相关功能通常视为超出范围。