关于项目#
JAX 项目由 JAX 核心团队领导。我们以开放的方式进行开发,并欢迎来自社区的开源贡献。我们经常看到来自 Google DeepMind、更广泛的 Alphabet、NVIDIA 以及其他地方的贡献。
该项目的核心是 JAX 核心库,它专注于大规模机器学习和数值计算的基础知识。
在开发核心时,我们希望保持敏捷和专注的范围,因此我们严重依赖周围的模块化技术堆栈。首先,我们将 jax
模块设计为可组合的和可扩展的,以便各种特定领域的库可以在分散的方式下在其外部蓬勃发展。其次,我们严重依赖模块化的后端堆栈(编译器和运行时)来针对不同的加速器。无论您是编写使用 JAX 构建的新的特定领域库,还是希望支持新的硬件,您通常都可以通过极少甚至无需修改 JAX 核心代码库来贡献这些内容。
许多 JAX 的核心贡献者都扎根于开源软件和研究,领域涵盖计算机科学和自然科学。我们努力不断推动机器学习和数值计算的前沿——跨所有计算平台和加速器——并发现大规模数组编程的真谛。
开放开发#
JAX 的日常开发在 GitHub 上公开进行,使用拉取请求、问题跟踪器、讨论和JAX 增强提案 (JEP)。阅读并参与这些是参与其中的好方法。我们还维护开发人员笔记,其中涵盖了 JAX 的内部设计。
JAX 核心团队决定是否接受更改和增强。维护一个简单的决策结构目前可以帮助我们以研究前沿的速度进行开发。开放开发是我们的核心价值,如果/当它变得有用时,我们可能会随着时间的推移适应更复杂的决策结构(例如,指定区域负责人)。
有关更多信息,请参阅贡献 JAX。
模块化堆栈#
为了实现 (a) 跨数值领域的不断增长的用户社区,以及 (b) 不断发展的硬件环境,我们严重依赖模块化。
构建于 JAX 之上的库#
虽然 JAX 核心库专注于基础知识,但我们希望鼓励在 JAX 之上构建特定领域的库和工具。事实上,许多库已经围绕 JAX 出现,以提供更高级别的功能和扩展。
我们如何鼓励这种分散式开发?我们通过几个技术选择来指导它。首先,JAX 的主要 API 专注于基本构建块(例如,数值原语、NumPy 操作、数组和转换),鼓励辅助库根据其领域的需要开发实用程序。此外,JAX 还公开了一些更高级的 API,用于自定义和可扩展性。库可以依赖这些 API,以便将 JAX 用作内部实现手段,与自动微分等转换进行更多集成等等。
JAX 生态系统中的项目以分布式且通常开放的方式进行开发。它们不受 JAX 核心团队的管辖,即使有时团队成员会为它们做出贡献或与它们的开发人员保持联系。
可插拔后端#
我们希望 JAX 能够在 CPU、GPU、TPU 以及其他新兴的硬件平台上运行。为了鼓励在新的平台上无障碍地支持 JAX,JAX 核心也强调其后端中的模块化。
为了管理硬件设备和内存,以及针对此类设备的编译,JAX 调用开放的XLA 编译器和PJRT 运行时。这两者都是 JAX 核心外部的项目,由 OpenXLA 管理和维护(同样,JAX 核心开发人员经常做出贡献和讨论)。
XLA 旨在实现跨加速器的互操作性(例如,通过摄取StableHLO 作为输入),PJRT 通过插件设备 API 提供可扩展性。添加对新设备的支持是通过实现 XLA 的后端降低以及实现 PJRT 定义的插件设备 API 来完成的。如果您希望为编译做出贡献,或支持新的硬件,我们鼓励您在 XLA 和 PJRT 层做出贡献。
这些开放的系统组件允许第三方在新的加速器平台上支持 JAX,而无需更改 JAX 核心。目前有几个插件正在开发中。例如,苹果公司的一个团队正在开发一个 PJRT 插件,以使 JAX 在 Apple Metal 上运行。