JEP 18137:JAX NumPy 和 SciPy 包装器的范围#

Jake VanderPlas

2023 年 10 月

到目前为止,jax.numpyjax.scipy 的预期范围一直比较模糊。本文档为这些包提出了一个明确的范围,以更好地指导和评估未来的贡献,并促使移除一些超出范围的代码。

背景#

从一开始,JAX 就致力于提供一个类似 NumPy 的 API 来执行 XLA 代码,而该项目开发的重要组成部分就是构建 jax.numpyjax.scipy 命名空间,作为 JAX 版的 NumPy 和 SciPy API 实现。一直以来,人们都隐约明白 numpyscipy 的某些部分不在 JAX 的范围之内,但这个范围从未被明确定义。这可能会导致贡献者感到困惑和沮丧,因为他们无法清楚地知道潜在的 jax.numpyjax.scipy 贡献是否会被 JAX 接受。

为何限制范围?#

为了避免留下未说明之处,我们应该明确:像 JAX 这样的项目中所包含的任何代码,都会给开发人员带来少量但非零的持续维护负担。一个项目能否长期成功,直接取决于维护者能否继续为项目的所有组成部分承担这些维护工作:记录功能、回答问题、修复错误等。为了任何软件工具的长期成功和可持续性,维护者必须仔细权衡特定贡献是否会根据项目的目标和资源,为项目带来净收益。

评估标准#

本文档提出了一个包含六个轴的标准,用于评估任何特定的 numpyscipy API 是否适合纳入 JAX。一个在所有轴上都表现出色的 API 是纳入 JAX 包的绝佳候选;而在**任何**一个轴上存在明显弱点,都是反对将其纳入 JAX 的有力论据。

轴 1:XLA 对齐#

我们考虑的第一个轴是拟议 API 与原生 XLA 操作的对齐程度。例如,jax.numpy.exp() 函数与 jax.lax.exp 或多或少直接对应。许多在 numpyscipy.specialnumpy.linalgscipy.linalg 等模块中的函数都符合这一标准:这类函数在考虑纳入 JAX 时,会通过 XLA 对齐检查。

另一方面,也存在像 numpy.unique() 这样的函数,它不直接对应任何 XLA 操作,并且在某些情况下与 JAX 当前的计算模型(需要静态形状的数组)根本不兼容(例如,unique 返回一个依赖于值的动态数组形状)。这类函数在考虑纳入 JAX 时,不会通过 XLA 对齐检查。

我们还将纯函数语义的需求纳入此轴的考虑范围。例如,numpy.random 是基于一个隐式更新的状态化伪随机数生成器构建的,这与 JAX 基于 XLA 的计算模型根本不兼容。

轴 2:数组 API 对齐#

我们考虑的第二个轴侧重于 Python 数组 API 标准:这在某种意义上是社区驱动的,概述了哪些数组操作对于跨广泛用户社区的面向数组的编程至关重要。如果 numpyscipy 中的某个 API 列在数组 API 标准中,则强烈表明 JAX 应该包含它。以上述示例为例,数组 API 标准包含 numpy.unique() 的几个变体(unique_allunique_countsunique_inverseunique_values),这表明,尽管该函数与 XLA 的对齐程度不精确,但它对 Python 用户社区足够重要,JAX 应该考虑实现它。

轴 3:下游实现的存在性#

对于不符合轴 1 或轴 2 功能的 API,纳入 JAX 的一个重要考虑因素是是否存在维护良好的下游包来提供所需的功能。scipy.optimize 是一个很好的例子:虽然 JAX 确实包含了一组最小化的 scipy.optimize 功能的包装器,但在 JAXopt 包中存在更全面的处理,该包由 JAX 合作者积极维护。在这种情况下,我们应该倾向于引导用户和贡献者使用这些专门的包,而不是在 JAX 本身中重新实现这些 API。

轴 4:实现的复杂性和鲁棒性#

对于不符合 XLA 的功能,一个考虑因素是拟议实现的复杂程度。这在一定程度上与轴 1 相关,但仍然很重要。JAX 已贡献了许多实现相对复杂的函数,这些函数难以验证并带来了不成比例的维护负担;一个例子是 jax.scipy.special.bessel_jn():截至撰写本文时,其当前实现是一种非直观的迭代近似,在某些域中存在收敛问题,而建议的修复又引入了进一步的复杂性。如果我们当初在接受贡献时更仔细地权衡实现的复杂性和鲁棒性,我们可能会选择不接受该贡献。

轴 5:函数式 API 与面向对象 API#

JAX 最适合使用函数式 API,而不是面向对象 API。面向对象 API 常常会隐藏不纯的语义,使其难以很好地实现。NumPy 和 SciPy 通常遵循函数式 API,但有时也提供面向对象的便捷包装器。

这方面的一个例子是 numpy.polynomial.Polynomial,它包装了低级操作,如 numpy.polyadd()numpy.polydiv() 等。总的来说,当存在函数式和面向对象 API 时,JAX 应该避免提供面向对象 API 的包装器,而是提供函数式 API 的包装器。

如果只存在面向对象 API,JAX 应该避免提供包装器,除非在其他轴上情况非常有利。

轴 6:对 JAX 用户和利益相关者的普遍“重要性”#

将 NumPy/SciPy API 纳入 JAX 的决定,还应考虑该算法对广大用户社区的重要性。不可否认,量化“利益相关者”以及如何衡量这种重要性是困难的;但我们将其包含在内,是为了明确 JAX 的 NumPy 和 SciPy 包装器中包含哪些内容的任何决定,都将涉及一定程度的酌情决定权,这种决定很难量化。

对于现有 API,在 GitHub 上搜索使用情况可能有助于确定其重要性或缺乏重要性;例如,我们可以回顾上面轴 4 讨论的 jax.scipy.special.bessel_jn():搜索显示,该函数在 GitHub 上的使用次数寥寥无几,这很可能与之前提到的准确性问题有关。

评估:在范围内?#

在本节中,我们将尝试根据上述标准评估 NumPy 和 SciPy API,包括当前 JAX API 中的一些示例。这不会是所有现有函数和类的详尽列表,而是一个更普遍的按子模块和主题进行的讨论,并附有相关示例。

NumPy API#

numpy 命名空间#

我们认为主 numpy 命名空间中的函数基本都在 JAX 的范围内,因为它们总体上与 XLA(轴 1)和 Python 数组 API(轴 2)对齐,并且对 JAX 用户社区(轴 6)普遍重要。有些函数可能处于边缘地带(例如 numpy.intersect1d()np.setdiff1d()np.union1d() 等函数,它们可以被认为在评估标准的部分不合格),但为了简单起见,我们宣布主 NumPy 命名空间中的所有数组函数都属于 JAX 的范围。

numpy.linalgnumpy.fft#

numpy.linalgnumpy.fft 子模块包含许多与 XLA 提供的功能广泛对齐的函数。其他函数可能具有复杂的设备特定低级实现,但它们代表了重要性(轴 6)超过复杂性的情况。因此,我们认为这两个子模块都属于 JAX 的范围。

numpy.random#

numpy.random 不属于 JAX 的范围,因为基于状态的伪随机数生成器与 JAX 的计算模型根本不兼容。我们转而关注 jax.random,它提供使用基于计数器的 PRNG 的类似功能。

numpy.manumpy.polynomial#

numpy.manumpy.polynomial 子模块主要提供对可以通过其他函数式方法表达的计算的面向对象接口(轴 5);因此,我们认为它们不属于 JAX 的范围。

numpy.testing#

NumPy 的测试功能主要适用于主机端计算,因此我们在 JAX 中不包含任何包装器。尽管如此,JAX 数组与 numpy.testing 兼容,并且 JAX 在整个 JAX 测试套件中频繁使用它。

SciPy API#

SciPy 在顶级命名空间中没有函数,但包含许多子模块。我们下面逐一进行考虑,省略已弃用的模块。

scipy.cluster#

scipy.cluster 模块包含用于层次聚类、k-means 和相关算法的工具。这些在多个轴上都表现不佳,最好由下游包来提供。JAX 中已经存在一个函数(jax.scipy.cluster.vq.vq()),但它在 GitHub 上的使用次数不多:这表明聚类对 JAX 用户普遍不重要。

建议:弃用并移除 jax.scipy.cluster.vq()

scipy.constants#

scipy.constants 模块包含数学和物理常数。这些常数可以直接与 JAX 一起使用,因此没有理由在 JAX 中重新实现。

scipy.datasets#

scipy.datasets 模块包含用于获取和加载数据集的工具。这些获取的数据集可以直接与 JAX 一起使用,因此没有理由在 JAX 中重新实现。

scipy.fft#

scipy.fft 模块包含与 XLA 提供的功能广泛对齐的函数,并且在其他轴上表现良好。因此,我们认为它们属于 JAX 的范围。

scipy.integrate#

scipy.integrate 模块包含数值积分函数。其中更复杂的函数(quaddblquadode)根据轴 1 和轴 4 的标准,不属于 JAX 的范围,因为它们往往是基于动态计算次数的循环算法。jax.experimental.ode.odeint() 相关,但功能有限且未积极开发。

JAX 目前确实包含 jax.scipy.integrate.trapezoid(),但这仅仅是因为 numpy.trapz() 最近被弃用并推荐使用此函数。对于任何特定输入,其实现可以用一行 jax.numpy 表达式替换,因此它不是一个特别有用的 API。

根据轴 1、2、4 和 6,scipy.integrate 应被视为不属于 JAX 的范围。

建议:移除 jax.scipy.integrate.trapezoid(),该函数在 JAX 0.4.14 中添加。

scipy.interpolate#

scipy.interpolate 模块提供了用于一维或多维插值的低级和面向对象例程。这些 API 在多个评估标准上得分较低:它们是基于类的而不是低级的,并且除了最简单的方法之外,没有一种可以有效地用 XLA 操作来表达。

JAX 目前确实有 scipy.interpolate.RegularGridInterpolator 的包装器。如果我们今天考虑这项贡献,我们可能会根据上述标准拒绝它。但这些代码一直相当稳定,因此继续维护它们并不会带来太大坏处。

今后,我们应认为 scipy.interpolate 的其他成员不属于 JAX 的范围。

scipy.io#

scipy.io 子模块与文件输入/输出相关。没有理由在 JAX 中重新实现这一点。

scipy.linalg#

scipy.linalg 子模块包含与 XLA 提供的功能广泛对齐的函数,并且快速线性代数对 JAX 用户社区(轴 6)普遍重要。因此,我们认为它属于 JAX 的范围。

scipy.ndimage#

scipy.ndimage 子模块包含一套用于处理图像数据的工具。其中许多与 scipy.signal 中的工具(例如卷积和滤波)重叠。JAX 目前提供一个 scipy.ndimage API,即 jax.scipy.ndimage.map_coordinates()。此外,JAX 在 jax.image 模块中提供了一些图像相关工具。deepmind 生态系统包含 dm-pix,这是一套更全面的 JAX 图像处理工具。考虑到所有这些因素,我建议 scipy.ndimage 应被视为 JAX 核心范围之外;我们可以将感兴趣的用户和贡献者引向 dm-pix。我们可以考虑将 map_coordinates 移动到 dm-pix 或其他合适的包。

scipy.odr#

scipy.odr 模块提供了一个面向对象的包装器 ODRPACK,用于执行正交距离回归。目前尚不清楚这是否能使用现有的 JAX 原始函数干净地表达出来,因此我们认为它不属于 JAX 本身的范围。

scipy.optimize#

scipy.optimize 模块提供高低级优化接口。此类功能对许多 JAX 用户很重要,并且 JAX 很早就创建了 jax.scipy.optimize 包装器。然而,这些例程的开发人员很快意识到 scipy.optimize API 过于受限,不同的团队开始开发 JAXopt 包和 Optimistix 包,它们都包含更全面、经过更好测试的 JAX 优化例程集。

由于这些维护良好的外部包,我们现在认为 scipy.optimize 不属于 JAX 的范围。

建议:弃用 jax.scipy.optimize,或将其设为可选的 JAXopt 或 Optimistix 依赖项的轻量级包装器。

🟡 scipy.signal#

scipy.signal 模块是混合的:一些函数属于 JAX 的范围(例如,correlateconvolve,它们是 lax.conv_general_dilated 的更用户友好的包装器),而许多其他函数则明确不属于范围(特定领域的工具,没有可行的 XLA 低级实现)。对 jax.scipy.signal 的潜在贡献必须逐案权衡。

🟡 scipy.sparse#

scipy.sparse 子模块主要包含用于以各种格式存储和操作稀疏矩阵和稀疏数组的数据结构。此外,scipy.sparse.linalg 包含许多无矩阵求解器,适用于稀疏矩阵、密集矩阵和线性算子。

scipy.sparse 数组和矩阵数据结构不属于 JAX 的范围,因为它们不符合 JAX 的计算模型(例如,许多操作依赖于动态大小的缓冲区)。JAX 开发了 jax.experimental.sparse 模块,作为一组更符合 JAX 计算约束的数据结构。出于这些原因,我们认为 scipy.sparse 中的数据结构不属于 JAX 的范围。

另一方面,scipy.sparse.linalg 已被证明是一个有趣的领域,而 jax.scipy.sparse.linalg 包含 bicgstabcggmres 求解器。这些对 JAX 用户社区(轴 6)很有用,但除此之外,它们在其他轴上表现不佳。它们非常适合迁移到下游库;一个潜在的选择可能是 Lineax,它包含许多基于 JAX 构建的线性求解器。

建议:考虑将稀疏求解器迁移到 Lineax,否则将 scipy.sparse 视为 JAX 范围之外。

scipy.spatial#

scipy.spatial 模块主要包含空间/距离计算和最近邻搜索的面向对象接口。它大部分不属于 JAX 的范围

scipy.spatial.transform 子模块提供用于处理三维空间旋转的工具。它是一个相对复杂的面向对象接口,可能最好由下游项目来提供。JAX 目前在 jax.scipy.spatial.transform 中包含 RotationSlerp 的部分实现;这些是基本函数的面向对象包装器,引入了非常大的 API 表面积且用户很少。我们判断它们不属于 JAX 本身的范围,用户最好由假想的下游项目来服务。

scipy.spatial.distance 子模块包含有用的距离度量集合,提供 JAX 包装器可能很诱人。然而,借助 jit 和 vmap,用户可以轻松地从头开始定义大多数这些距离度量的有效版本,因此将它们添加到 JAX 并不会带来特别的好处。

建议:考虑弃用并移除 Rotation 和 Slerp API,并将 scipy.spatial 整体视为未来贡献的范围之外。

scipy.special#

scipy.special 模块包含许多更专业函数的实现。在许多情况下,这些函数明确属于范围:例如,像 gammalnbetaincdigamma 等函数,以及许多其他函数,直接对应于可用的 XLA 原始函数,并且根据轴 1 和其他标准明确属于范围。

其他函数需要更复杂的实现;上面提到一个例子是 bessel_jn。尽管与轴 1 和轴 2 不符,但这些函数在轴 6 上往往表现出色:scipy.special 提供了计算各种领域所需的基本函数,因此即使是实现复杂的函数也应倾向于纳入范围,只要实现设计良好且鲁棒。

有几个现有的函数包装器需要我们仔细审查;例如

  • jax.scipy.special.lpmn():它通过一个复杂的 fori_loop 生成勒让德多项式,其方式与 scipy API 不匹配(例如,对于 scipy,z 必须是标量,而对于 JAX,z 必须是 1D 数组)。该函数可发现的用途很少,使其成为轴 1、2、4 和 6 的薄弱候选。

  • jax.scipy.special.lpmn_values():它具有与 lpmn 类似的弱点。

  • jax.scipy.special.sph_harm():它基于 lpmn,并且类似地,其 API 与相应的 scipy 函数有偏差。

  • jax.scipy.special.bessel_jn():如轴 4 中所述,它在实现鲁棒性方面存在弱点,并且使用量很少。我们可以考虑用新的、更鲁棒的实现替换它(例如,#17038)。

建议:重构并提高 bessel_jn 的鲁棒性和测试覆盖率。如果 lpmn、lpmn_values 和 sph_harm 无法修改以更接近 scipy API,则考虑弃用它们。

scipy.stats#

scipy.stats 模块包含广泛的统计函数,包括离散和连续分布、摘要统计量和假设检验。JAX 目前在 jax.scipy.stats 中包装了其中许多函数,主要包括约 20 种统计分布,以及一些其他函数(moderankdatagaussian_kde)。总体而言,它们与 JAX 很好地对齐:分布通常可以用高效的 XLA 操作来表达,并且 API 清晰且函数式。

我们目前没有假设检验函数的包装器,这可能是因为这些函数对 JAX 的主要用户群来说用处不大。

关于分布,在某些情况下,tensorflow_probability 提供了类似的功能,未来我们可以考虑是否弃用 scipy.stats 分布而采用该实现。

建议:今后,我们应将统计分布和摘要统计量视为在范围内,并将假设检验及其相关功能通常视为范围之外。