JEP 18137: JAX NumPy & SciPy 封装器的范围#

Jake VanderPlas

2023年10月

到目前为止,jax.numpyjax.scipy 的预期范围一直相对定义不清。本文提出为这些包设定明确定义的范围,以更好地指导和评估未来的贡献,并推动移除一些超出范围的代码。

背景#

从一开始,JAX 的目标就是提供一个类似 NumPy 的 API,用于在 XLA 中执行代码。该项目开发的一个重要部分是构建 jax.numpyjax.scipy 命名空间,作为 NumPy 和 SciPy API 的基于 JAX 的实现。一直以来,人们都隐含地认为 numpyscipy 的某些部分超出了 JAX 的范围,但这个范围尚未明确定义。这可能导致贡献者感到困惑和沮丧,因为对于潜在的 jax.numpyjax.scipy 贡献是否会被 JAX 接受,没有明确的答案。

为何限制范围?#

为了避免不言而喻,我们应该明确指出:一个像 JAX 这样的项目所包含的任何代码,都会给开发者带来虽小但非零的持续维护负担。一个项目的长期成功与维护者持续维护项目所有部分的能力直接相关:包括编写功能文档、回答问题、修复错误等。对于任何软件工具的长期成功和可持续性而言,维护者仔细权衡任何特定贡献是否能根据其目标和资源为项目带来净积极影响至关重要。

评估准则#

本文提出了一个包含六个维度的评估准则,可据此判断任何特定的 numpyscipy API 是否应纳入 JAX。如果一个 API 在所有维度上都很强大,那么它就是纳入 JAX 包的优秀候选;如果它在六个维度中任何一个存在明显弱点,则是有力论据反对将其纳入 JAX。

轴 1: XLA 对齐#

我们考虑的第一个维度是所提议的 API 与原生 XLA 操作的对齐程度。例如,jax.numpy.exp() 是一个几乎直接反映 jax.lax.exp 的函数。numpyscipy.specialnumpy.linalgscipy.linalg 以及其他许多函数都符合此标准:在考虑将其纳入 JAX 时,这些函数通过了 XLA 对齐检查。

另一方面,有些函数如 numpy.unique(),它们不直接对应任何 XLA 操作,在某些情况下,它们与 JAX 当前的计算模型根本不兼容,因为 JAX 需要静态形状数组(例如 unique 返回一个值依赖的动态数组形状)。在考虑将其纳入 JAX 时,这些函数未能通过 XLA 对齐检查。

我们还将纯函数语义的需求视为此轴的一部分。例如,numpy.random 构建在基于隐式更新状态的 RNG 上,这与 JAX 基于 XLA 构建的计算模型根本不兼容。

轴 2: Array API 对齐#

我们考虑的第二个维度侧重于 Python Array API 标准:这在某种意义上是一个社区驱动的概要,列出了在各种用户社区中哪些数组操作是面向数组编程的核心。如果在 numpyscipy 中的 API 被列入 Array API 标准,这是一个强烈信号,表明 JAX 应该包含它。以上述示例为例,Array API 标准包括 numpy.unique() 的几种变体(unique_allunique_countsunique_inverseunique_values),这表明,尽管该函数与 XLA 未能精确对齐,但它对 Python 用户社区足够重要,JAX 也许应该实现它。

轴 3: 下游实现的存在性#

对于与轴 1 或轴 2 不对齐的功能,纳入 JAX 的一个重要考量是是否存在提供该功能的、得到良好支持的下游包。一个很好的例子是 scipy.optimize:虽然 JAX 确实包含了一套最小的 scipy.optimize 功能封装器,但在 JAXopt 包中存在更完整的处理方式,该包由 JAX 协作者积极维护。在这种情况下,我们应该倾向于引导用户和贡献者使用这些专用包,而不是在 JAX 自身中重新实现这些 API。

轴 4: 实现的复杂性与健壮性#

对于与 XLA 不对齐的功能,一个考量是提议实现的复杂程度。这在某种程度上与轴 1 保持一致,但仍然很重要,值得指出。JAX 中有许多贡献的函数具有相对复杂的实现,这些实现难以验证,并带来巨大的维护负担;一个例子是 jax.scipy.special.bessel_jn():截至本 JEP 撰写之时,其当前实现是一种非直观的迭代近似,在某些领域存在收敛问题,并且提议的修复引入了进一步的复杂性。如果我们在接受贡献时更仔细地权衡了实现的复杂性和健壮性,我们可能就不会接受这项贡献。

轴 5: 函数式与面向对象 API#

JAX 最适合与函数式 API 配合使用,而不是面向对象 API。面向对象 API 通常可以隐藏不纯语义,使得它们难以良好实现。NumPy 和 SciPy 通常坚持使用函数式 API,但有时会提供面向对象的便捷封装器。

这方面的例子有 numpy.polynomial.Polynomial,它封装了 numpy.polyadd()numpy.polydiv() 等底层操作。一般来说,当函数式和面向对象 API 都可用时,JAX 应避免为面向对象 API 提供封装器,而应为函数式 API 提供封装器。

在仅存在面向对象 API 的情况下,除非在其他轴上存在强烈理由,JAX 均应避免提供封装器。

轴 6: 对 JAX 用户和利益相关者的普遍“重要性”#

将 NumPy/SciPy API 纳入 JAX 的决定,还应考虑该算法对广大用户社区的重要性。诚然,量化谁是“利益相关者”以及如何衡量这种重要性是困难的;但我们将其纳入其中,以明确指出,关于 JAX 的 NumPy 和 SciPy 封装器中应包含哪些内容的任何决定,都将涉及一定程度的酌情判断,而这种判断无法轻易量化。

对于现有 API,在 GitHub 上搜索使用情况可能有助于确定其重要性或缺乏重要性;例如,我们可以回到上面讨论的 jax.scipy.special.bessel_jn():搜索显示该函数在 GitHub 上只有少数用途,这可能部分与之前提到的准确性问题有关。

评估:哪些在范围内?#

在本节中,我们将根据上述评估准则,尝试评估 NumPy 和 SciPy API,包括当前 JAX API 的一些示例。这不会是所有现有函数和类的全面列表,而是按子模块和主题进行的更一般性讨论,并附有相关示例。

NumPy API#

numpy 命名空间#

我们认为主 numpy 命名空间中的函数基本都在 JAX 的范围内,因为它与 XLA(轴 1)和 Python Array API(轴 2)普遍对齐,并且对 JAX 用户社区普遍重要(轴 6)。有些函数可能处于临界状态(例如 numpy.intersect1d()np.setdiff1d()np.union1d() 等函数可以说未能通过部分准则),但为简单起见,我们声明主 numpy 命名空间中的所有数组函数都包含在 JAX 的范围内。

numpy.linalg & numpy.fft#

numpy.linalgnumpy.fft 子模块包含许多与 XLA 提供功能广泛对齐的函数。其他一些函数具有复杂的设备特定下层实现,但这代表了一个对利益相关者重要性(轴 6)超过复杂性的案例。因此,我们认为这两个子模块都在 JAX 的范围内。

numpy.random#

numpy.random 不在 JAX 的范围内,因为基于状态的 RNG 与 JAX 的计算模型根本不兼容。我们转而关注 jax.random,它使用基于计数器的 PRNG 提供类似的功能。

numpy.ma & numpy.polynomial#

numpy.manumpy.polynomial 子模块主要关注为可以通过其他函数式方式表达的计算提供面向对象的接口(轴 5);因此,我们认为它们超出了 JAX 的范围。

numpy.testing#

NumPy 的测试功能只对主机端计算有意义,因此我们不将任何相关封装器包含在 JAX 中。尽管如此,JAX 数组与 numpy.testing 兼容,JAX 在其整个测试套件中也频繁使用它。

SciPy API#

SciPy 在顶级命名空间中没有函数,但包含许多子模块。我们将在下面逐一考虑,不包括已弃用的模块。

scipy.cluster#

scipy.cluster 模块包括用于层次聚类、k-均值和相关算法的工具。这些工具在多个轴上表现不佳,更适合由下游包提供服务。JAX 中已经存在一个函数(jax.scipy.cluster.vq.vq()),但在 GitHub 上没有明显的使用情况:这表明聚类对 JAX 用户来说并不普遍重要。

建议:弃用并移除 jax.scipy.cluster.vq()

scipy.constants#

scipy.constants 模块包含数学和物理常数。这些常数可以直接与 JAX 配合使用,因此没有理由在 JAX 中重新实现它。

scipy.datasets#

scipy.datasets 模块包含用于获取和加载数据集的工具。这些获取的数据集可以直接与 JAX 配合使用,因此没有理由在 JAX 中重新实现它。

scipy.fft#

scipy.fft 模块包含的函数与 XLA 提供功能广泛对齐,并且在其他轴上也表现良好。因此,我们认为它们在 JAX 的范围内。

scipy.integrate#

scipy.integrate 模块包含用于数值积分的函数。其中更复杂的函数(quaddblquadode)根据轴 1 和 4 超出了 JAX 的范围,因为它们倾向于基于动态评估次数的循环算法。jax.experimental.ode.odeint() 与之相关,但相当有限,且未在积极开发中。

JAX 目前确实包含 jax.scipy.integrate.trapezoid(),但这仅仅是因为 numpy.trapz() 最近被弃用,转而支持此函数。对于任何特定输入,其实现可以用一行 jax.numpy 表达式替换,因此它不是一个特别有用的 API。

基于轴 1、2、4 和 6,scipy.integrate 应被视为超出了 JAX 的范围。

建议:移除 JAX 0.4.14 中添加的 jax.scipy.integrate.trapezoid()

scipy.interpolate#

scipy.interpolate 模块提供了一维或多维插值的低层级和面向对象例程。这些 API 在上述多个轴上评分不佳:它们是基于类的而不是低层级的,并且除了最简单的方法外,都无法有效地用 XLA 操作表达。

JAX 目前确实包含了 scipy.interpolate.RegularGridInterpolator 的封装器。如果今天我们考虑这项贡献,我们可能会根据上述标准拒绝它。但是这部分代码相当稳定,所以继续维护它并没有太大弊端。

展望未来,我们应将 scipy.interpolate 的其他成员视为超出 JAX 范围。

scipy.io#

scipy.io 子模块与文件输入/输出有关。没有理由在 JAX 中重新实现它。

scipy.linalg#

scipy.linalg 子模块包含的函数与 XLA 提供功能广泛对齐,并且快速线性代数对 JAX 用户社区普遍重要。因此,我们认为它在 JAX 的范围内。

scipy.ndimage#

scipy.ndimage 子模块包含一组用于处理图像数据的工具。其中许多与 scipy.signal 中的工具重叠(例如卷积和滤波)。JAX 目前提供一个 scipy.ndimage API,即 jax.scipy.ndimage.map_coordinates()。此外,JAX 在 jax.image 模块中提供了一些图像相关工具。DeepMind 生态系统包括 dm-pix,这是一套更全面的 JAX 图像处理工具。考虑到所有这些因素,我建议将 scipy.ndimage 视为超出 JAX 核心范围;我们可以将感兴趣的用户和贡献者引导至 dm-pix。我们可以考虑将 map_coordinates 移动到 dm-pix 或其他合适的包中。

scipy.odr#

scipy.odr 模块提供了围绕 ODRPACK 的面向对象封装器,用于执行正交距离回归。尚不清楚这是否能用现有 JAX 原语清晰表达,因此我们认为它超出了 JAX 自身的范围。

scipy.optimize#

scipy.optimize 模块提供了高级和低级优化接口。这种功能对许多 JAX 用户来说很重要,很早以前 JAX 就创建了 jax.scipy.optimize 封装器。然而,这些例程的开发者很快意识到 scipy.optimize API 限制性太强,于是不同的团队开始开发 JAXopt 包和 Optimistix 包,这两个包都包含了更全面、经过更好测试的 JAX 优化例程。

由于这些得到良好支持的外部包,我们现在认为 scipy.optimize 超出了 JAX 的范围。

建议:弃用 jax.scipy.optimize 和/或使其成为可选的 JAXopt 或 Optimistix 依赖项的轻量级封装器。

🟡 scipy.signal#

scipy.signal 模块是混杂的:有些函数完全在 JAX 的范围内(例如 correlateconvolve,它们是 lax.conv_general_dilated 的更用户友好型封装器),而另一些则完全超出范围(领域特定工具,没有可行的 XLA 下层路径)。对 jax.scipy.signal 的潜在贡献将需要逐案权衡。

🟡 scipy.sparse#

scipy.sparse 子模块主要包含用于以各种格式存储和操作稀疏矩阵和数组的数据结构。此外,scipy.sparse.linalg 包含许多无矩阵求解器,适用于稀疏矩阵、稠密矩阵和线性算子。

scipy.sparse 的数组和矩阵数据结构超出了 JAX 的范围,因为它们与 JAX 的计算模型不符(例如,许多操作依赖于动态大小的缓冲区)。JAX 已经开发了 jax.experimental.sparse 模块作为一组替代数据结构,它们更符合 JAX 的计算约束。鉴于这些原因,我们认为 scipy.sparse 中的数据结构超出了 JAX 的范围。

另一方面,scipy.sparse.linalg 已被证明是一个有趣的领域,并且 jax.scipy.sparse.linalg 包含了 bicgstabcggmres 求解器。这些对 JAX 用户社区很有用(轴 6),但除此之外,在其他轴上表现不佳。它们非常适合迁移到下游库;一个潜在的选择可能是 Lineax,它包含了许多基于 JAX 构建的线性求解器。

建议:探讨将稀疏求解器迁移到 Lineax,否则将 scipy.sparse 视为超出 JAX 范围。

scipy.spatial#

scipy.spatial 模块主要包含空间/距离计算和最近邻搜索的面向对象接口。它大部分超出了 JAX 的范围。

scipy.spatial.transform 子模块提供了操作三维空间旋转的工具。它是一个相对复杂的面向对象接口,也许更适合由下游项目提供服务。JAX 目前在 jax.scipy.spatial.transform 中包含了 RotationSlerp 的部分实现;这些是其他基本函数的面向对象封装器,它们引入了非常大的 API 表面,且用户很少。我们的判断是,它们超出了 JAX 自身的范围,用户更适合由假设的下游项目提供服务。

scipy.spatial.distance 子模块包含了一组有用的距离度量,提供 JAX 封装器可能很有诱惑力。但话虽如此,有了 jit 和 vmap,用户可以轻松地从头开始定义其中大多数的高效版本(如果需要),因此将其添加到 JAX 中并没有特别的好处。

建议:考虑弃用并移除 RotationSlerp API,并考虑将 scipy.spatial 作为一个整体视为未来贡献的范围之外。

scipy.special#

scipy.special 模块包含了许多更专业函数的实现。在许多情况下,这些函数完全在范围内:例如,gammalnbetaincdigamma 等函数直接对应于可用的 XLA 原语,并根据轴 1 和其他轴明确地在范围内。

其他函数需要更复杂的实现;上面提到的一个例子是 bessel_jn。尽管在轴 1 和 2 上不对齐,但这些函数在轴 6 上往往非常强大:scipy.special 提供了各种领域计算所需的基本函数,因此即使是实现复杂的函数,也应倾向于将其纳入范围,只要实现设计良好且健壮。

目前存在一些函数封装器,我们应该仔细审查一下;例如:

  • jax.scipy.special.lpmn():它通过复杂的 fori_loop 生成勒让德多项式,其方式与 scipy API 不匹配(例如,对于 scipyz 必须是标量,而对于 JAX,z 必须是 1D 数组)。该函数很少有可发现的用途,使其在轴 1、2、4 和 6 上都是一个弱候选。

  • jax.scipy.special.lpmn_values():这与上面的 lmpn 存在类似的弱点。

  • jax.scipy.special.sph_harm():它基于 lpmn 构建,同样其 API 与相应的 scipy 函数有所分歧。

  • jax.scipy.special.bessel_jn():如上文轴 4 所述,该函数在实现健壮性方面存在弱点,且使用率较低。我们或许可以考虑用新的、更健壮的实现来替代它(例如 #17038)。

建议:重构并提高 bessel_jn 的健壮性和测试覆盖率。如果 lpmnlpmn_valuessph_harm 无法修改以更接近 scipy API,则考虑弃用它们。

scipy.stats#

scipy.stats 模块包含广泛的统计函数,包括离散和连续分布、汇总统计以及假设检验。JAX 目前在 jax.scipy.stats 中封装了其中许多函数,主要包括约 20 种统计分布,以及一些其他函数(moderankdatagaussian_kde)。总的来说,这些函数与 JAX 兼容性良好:分布通常可以用高效的 XLA 操作表达,并且 API 清晰且函数化。

我们目前没有任何假设检验函数的封装器,这可能是因为它们对 JAX 的主要用户群来说用处较小。

关于分布,在某些情况下,tensorflow_probability 提供了类似的功能,未来我们可能会考虑是否弃用 scipy.stats 分布,转而支持该实现。

建议:展望未来,我们应将统计分布和汇总统计视为在范围内,并普遍将假设检验及相关功能视为超出范围。