JEP 18137: JAX NumPy 和 SciPy 包装器的范围#

Jake VanderPlas

2023 年 10 月

到目前为止,jax.numpyjax.scipy 的预期范围相对来说定义不明确。本文档为这些包提出了一个明确定义的范围,以便更好地指导和评估未来的贡献,并推动删除一些超出范围的代码。

背景#

从一开始,JAX 的目标就是为在 XLA 中执行代码提供类似 NumPy 的 API,并且该项目开发的一个重要部分是构建 jax.numpyjax.scipy 命名空间,作为基于 JAX 的 NumPy 和 SciPy API 的实现。一直以来,人们都隐含地理解到 numpyscipy 的某些部分超出了 JAX 的范围,但这个范围没有得到很好的定义。这可能会给贡献者带来困惑和沮丧,因为对于潜在的 jax.numpyjax.scipy 贡献是否会被 JAX 接受,没有明确的答案。

为什么要限制范围?#

为了避免对此保持沉默,我们应该明确指出:将任何代码包含到像 JAX 这样的项目中,都会给开发人员带来少量但非零的持续维护负担。一个项目随着时间的推移取得成功,直接关系到维护人员是否能够继续为项目的所有部分进行维护:编写功能文档、回复问题、修复错误等等。对于任何软件工具的长期成功和可持续性而言,维护人员仔细权衡任何特定贡献是否会根据其目标和资源为项目带来净收益至关重要。

评估标准#

本文档提出了六个维度的评估标准,用于判断任何特定的 numpyscipy API 是否应包含到 JAX 中。在所有维度上都很强的 API 是包含在 JAX 包中的绝佳候选者;在六个维度的任何一个上存在明显的弱点都是反对包含在 JAX 中的一个很好的理由。

维度 1:XLA 对齐#

我们考虑的第一个维度是提议的 API 与原生 XLA 操作的对齐程度。例如,jax.numpy.exp() 是一个或多或少直接镜像 jax.lax.exp 的函数。numpyscipy.specialnumpy.linalgscipy.linalg 和其他中的大量函数都满足此标准:在考虑将其包含到 JAX 中时,此类函数通过了 XLA 对齐检查。

另一方面,存在像 numpy.unique() 这样的函数,它们不直接对应于任何 XLA 操作,并且在某些情况下与 JAX 当前的计算模型(需要静态形状的数组)根本不兼容(例如,unique 返回一个值相关的动态数组形状)。在考虑将其包含到 JAX 中时,此类函数未通过 XLA 对齐检查。

我们还将纯函数语义的需求视为此维度的一部分。例如,numpy.random 构建在隐式更新的基于状态的 RNG 上,这与基于 XLA 构建的 JAX 计算模型根本不兼容。

维度 2:Array API 对齐#

我们考虑的第二个维度侧重于 Python Array API 标准:在某种意义上,这是一个由社区驱动的概述,其中列出了对广泛用户社区中的面向数组的编程至关重要的数组操作。如果 numpyscipy 中的 API 列在 Array API 标准中,则强烈表明 JAX 应将其包括在内。使用上面的示例,Array API 标准包括 numpy.unique() 的几种变体(unique_allunique_countsunique_inverseunique_values),这表明尽管该函数与 XLA 不完全一致,但它对 Python 用户社区来说足够重要,JAX 可能应该实现它。

维度 3:下游实现的存在#

对于不符合维度 1 或 2 的功能,是否支持良好的下游包提供所讨论的功能是包含到 JAX 中的一个重要考虑因素。一个很好的例子是 scipy.optimize:尽管 JAX 确实包含了一组 scipy.optimize 功能的最小包装器,但 JAXopt 包中存在更完整的处理方式,该包由 JAX 合作者积极维护。在这种情况下,我们应该倾向于引导用户和贡献者使用这些专门的包,而不是在 JAX 本身中重新实现这些 API。

维度 4:实现的复杂性和稳健性#

对于不符合 XLA 的功能,一个考虑因素是建议实现的复杂程度。这在某种程度上与维度 1 一致,但仍然值得一提。许多函数被贡献给了 JAX,它们的实现相对复杂,难以验证,并会引入过多的维护负担;一个例子是 jax.scipy.special.bessel_jn():在编写此 JEP 时,它当前的实现是一个不直接的迭代近似,在某些域中存在 收敛问题,并且提议的修复引入了进一步的复杂性。如果我们在接受贡献时更仔细地权衡实现的复杂性和稳健性,我们可能选择不接受此贡献到该包中。

维度 5:函数式与面向对象的 API#

JAX 最适合函数式 API,而不是面向对象的 API。面向对象的 API 通常会隐藏不纯粹的语义,使得它们通常难以很好地实现。NumPy 和 SciPy 通常坚持使用函数式 API,但有时会提供面向对象的便利包装器。

这方面的例子包括 numpy.polynomial.Polynomial,它封装了较低级别的操作,如 numpy.polyadd(), numpy.polydiv() 等。一般来说,当同时存在函数式和面向对象 API 时,JAX 应避免为面向对象 API 提供包装器,而应为函数式 API 提供包装器。

在仅存在面向对象 API 的情况下,除非在其他方面有很强的理由,否则 JAX 应避免提供包装器。

轴 6:对 JAX 用户和利益相关者的总体“重要性”#

在 JAX 中包含 NumPy/SciPy API 的决定还应考虑到该算法对一般用户社区的重要性。诚然,量化谁是“利益相关者”以及如何衡量这种重要性是困难的;但我们加入这一条是为了明确,关于在 JAX 的 NumPy 和 SciPy 包装器中包含什么内容的任何决定都将涉及一定程度的自由裁量权,而这无法轻易量化。

对于现有的 API,在 GitHub 上搜索使用情况可能有助于确定其重要性或缺乏重要性;例如,我们可以回到上面讨论的 jax.scipy.special.bessel_jn():搜索显示此函数在 GitHub 上只有少数几个用途,这可能部分与之前提到的精度问题有关。

评估:什么在范围内?#

在本节中,我们将尝试根据上述规则评估 NumPy 和 SciPy API,包括当前 JAX API 中的一些示例。这不会是所有现有函数和类的完整列表,而是按子模块和主题进行的更一般的讨论,并附有相关示例。

NumPy API#

numpy 命名空间#

我们认为主 numpy 命名空间中的函数基本上都在 JAX 的范围内,因为它们与 XLA(轴 1)和 Python 数组 API(轴 2)总体一致,并且对 JAX 用户社区也具有普遍重要性(轴 6)。某些函数可能处于边缘(例如 numpy.intersect1d(), np.setdiff1d(), np.union1d() 等函数,它们在某些方面不符合规则),但为了简单起见,我们声明主 numpy 命名空间中的所有数组函数都在 JAX 的范围内。

numpy.linalgnumpy.fft#

numpy.linalgnumpy.fft 子模块包含许多与 XLA 提供的功能广泛一致的函数。其他函数具有复杂的特定于设备的降级,但代表了利益相关者的重要性(轴 6)超过复杂性的情况。因此,我们认为这两个子模块都在 JAX 的范围内。

numpy.random#

numpy.random 不在 JAX 的范围内,因为基于状态的 RNG 从根本上与 JAX 的计算模型不兼容。我们转而关注 jax.random,它使用基于计数器的 PRNG 提供类似的功能。

numpy.manumpy.polynomial#

numpy.manumpy.polynomial 子模块主要关注为可以通过其他功能方式表达的计算提供面向对象的接口(轴 5);因此,我们认为它们不在 JAX 的范围内。

numpy.testing#

NumPy 的测试功能仅对主机端计算有意义,因此我们不在 JAX 中包含任何包装器。也就是说,JAX 数组与 numpy.testing 兼容,并且 JAX 在整个 JAX 测试套件中频繁使用它。

SciPy API#

SciPy 在顶层命名空间中没有函数,但包含许多子模块。我们在下面考虑每个子模块,忽略已弃用的模块。

scipy.cluster#

scipy.cluster 模块包含用于层次聚类、k 均值和相关算法的工具。这些工具在多个轴上都很弱,最好由下游软件包提供。JAX 中已经存在一个函数 (jax.scipy.cluster.vq.vq()) 但在 GitHub 上没有明显的用途:这表明聚类对 JAX 用户来说并不重要。

建议:弃用并删除 jax.scipy.cluster.vq()

scipy.constants#

scipy.constants 模块包含数学和物理常数。这些常数可以直接与 JAX 一起使用,因此没有理由在 JAX 中重新实现它。

scipy.datasets#

scipy.datasets 模块包含用于获取和加载数据集的工具。这些获取的数据集可以直接与 JAX 一起使用,因此没有理由在 JAX 中重新实现它。

scipy.fft#

scipy.fft 模块包含与 XLA 提供的功能广泛一致的函数,并且在其他方面也表现良好。因此,我们认为它们在 JAX 的范围内。

scipy.integrate#

scipy.integrate 模块包含用于数值积分的函数。其中更复杂的函数(quad, dblquad, ode)因轴 1 和 4 而不在 JAX 的范围内,因为它们往往是基于动态评估次数的循环算法。jax.experimental.ode.odeint() 是相关的,但相当有限且没有积极的开发。

JAX 目前确实包含 jax.scipy.integrate.trapezoid(),但这仅仅是因为最近 numpy.trapz() 被弃用而支持它。对于任何特定的输入,它的实现都可以用一行 jax.numpy 表达式替换,因此它不是一个特别有用的 API。

基于轴 1、2、4 和 6,scipy.integrate 应该被认为超出 JAX 的范围。

建议:移除在 JAX 0.4.14 中添加的 jax.scipy.integrate.trapezoid()

scipy.interpolate#

scipy.interpolate 模块提供用于在一维或多维中进行插值的低级和面向对象的例程。这些 API 在上述多个轴上的评分都很差:它们是基于类的而不是低级的,并且除了最简单的方法之外,没有其他方法可以有效地用 XLA 操作表示。

JAX 目前确实有 scipy.interpolate.RegularGridInterpolator 的包装器。如果我们今天考虑这个贡献,我们可能会根据上述标准拒绝它。但是这段代码相当稳定,因此继续维护它并没有太大的缺点。

展望未来,我们应该认为 scipy.interpolate 的其他成员超出 JAX 的范围。

scipy.io#

scipy.io 子模块与文件输入/输出有关。没有理由在 JAX 中重新实现它。

scipy.linalg#

scipy.linalg 子模块包含与 XLA 提供的功能广泛一致的函数,并且快速线性代数对 JAX 用户社区来说非常重要。因此,我们认为它在 JAX 的范围内。

scipy.ndimage#

scipy.ndimage 子模块包含一组用于处理图像数据的工具。其中许多与 scipy.signal 中的工具重叠(例如卷积和滤波)。 JAX 目前在 jax.scipy.ndimage.map_coordinates() 中提供了一个 scipy.ndimage API。此外,JAX 在 jax.image 模块中提供了一些与图像相关的工具。 deepmind 生态系统包括 dm-pix,这是一套更完善的用于在 JAX 中进行图像处理的工具。考虑到所有这些因素,我建议 scipy.ndimage 应该被认为超出 JAX 核心的范围;我们可以将感兴趣的用户和贡献者指向 dm-pix。我们可以考虑将 map_coordinates 移动到 dm-pix 或另一个合适的软件包。

scipy.odr#

scipy.odr 模块为执行正交距离回归提供了围绕 ODRPACK 的面向对象的包装器。目前尚不清楚是否可以使用现有的 JAX 原语干净地表达这一点,因此我们认为它超出 JAX 本身的范围。

scipy.optimize#

scipy.optimize 模块提供了用于优化的低级和高级接口。此类功能对许多 JAX 用户来说非常重要,并且 JAX 很早就创建了 jax.scipy.optimize 包装器。然而,这些例程的开发人员很快意识到 scipy.optimize API 的约束太多,不同的团队开始研究 JAXopt 包和 Optimistix 包,每个包都包含一套更全面且经过更好测试的 JAX 优化例程。

由于这些有良好支持的外部软件包,我们现在认为 scipy.optimize 超出 JAX 的范围。

建议:弃用 jax.scipy.optimize 和/或使其成为围绕可选的 JAXopt 或 Optimistix 依赖项的轻量级包装器。

🟡 scipy.signal#

scipy.signal 模块是混合的:一些函数完全在 JAX 的范围内(例如 correlateconvolve,它们是 lax.conv_general_dilated 的更用户友好的包装器),而其他许多函数完全超出范围(具有没有可行的降低到 XLA 路径的特定领域工具)。对 jax.scipy.signal 的潜在贡献必须逐个案例进行权衡。

🟡 scipy.sparse#

scipy.sparse 子模块主要包含用于存储和操作各种格式的稀疏矩阵和数组的数据结构。此外,scipy.sparse.linalg 包含许多无矩阵求解器,适用于稀疏矩阵、稠密矩阵和线性运算符。

scipy.sparse 数组和矩阵数据结构超出 JAX 的范围,因为它们与 JAX 的计算模型不一致(例如,许多操作依赖于动态大小的缓冲区)。 JAX 开发了 jax.experimental.sparse 模块作为一组与 JAX 的计算约束更一致的替代数据结构。由于这些原因,我们认为 scipy.sparse 中的数据结构超出 JAX 的范围。

另一方面,scipy.sparse.linalg 已被证明是一个有趣的领域,并且 jax.scipy.sparse.linalg 包括 bicgstabcggmres 求解器。这些对 JAX 用户社区很有用(轴 6),但除此之外,在其他轴上表现不佳。它们非常适合移动到下游库中;一个潜在的选择可能是 Lineax,它具有许多基于 JAX 构建的线性求解器。

建议:探索将稀疏求解器移动到 Lineax,否则将 `scipy.sparse` 视为超出 JAX 的范围。

scipy.spatial#