关于本项目#
JAX 项目由 JAX 核心团队领导。我们以开放方式进行开发,并欢迎来自社区各界的开源贡献。我们经常看到来自 Google DeepMind、更广泛的 Alphabet、NVIDIA 以及其他地方的贡献。
该项目的核心是 JAX 核心库,它专注于大规模机器学习和数值计算的基础。
在开发核心时,我们希望保持敏捷性和集中的范围,因此我们严重依赖周围的模块化技术堆栈。首先,我们将 jax
模块设计为可组合和可扩展的,以便各种特定领域的库可以在分散的方式下在其外部蓬勃发展。其次,我们严重依赖模块化后端堆栈(编译器和运行时)来定位不同的加速器。无论您是编写使用 JAX 构建的新领域特定库,还是希望支持新的硬件,您通常都可以通过最少或无需修改 JAX 核心代码库来贡献这些内容。
JAX 的许多核心贡献者都扎根于开源软件和研究领域,领域涵盖计算机科学和自然科学。我们努力不断推动机器学习和数值计算的前沿——跨所有计算平台和加速器——并发现大规模数组编程的真谛。
开放式开发#
JAX 的日常开发在 GitHub 上公开进行,使用拉取请求、问题跟踪器、讨论和JAX 增强提案 (JEP)。阅读和参与这些是参与其中的好方法。我们还维护了涵盖 JAX 内部设计的开发者笔记。
JAX 核心团队决定是否接受更改和增强功能。维护一个简单的决策结构目前有助于我们以研究前沿的速度进行开发。开放式开发是我们的核心价值观,如果/当它变得有用时,我们可能会随着时间的推移适应更复杂的决策结构(例如,指定区域负责人)。
更多信息,请参阅为 JAX 做贡献。
模块化堆栈#
为了实现 (a) 跨数值领域的不断增长的用户社区,以及 (b) 不断发展的硬件环境,我们严重依赖模块化。
基于 JAX 构建的库#
虽然 JAX 核心库专注于基础知识,但我们希望鼓励在 JAX 之上构建特定领域的库和工具。实际上,许多库已经围绕 JAX 出现,以提供更高级别的功能和扩展。
我们如何鼓励这种分散式开发?我们通过几个技术选择来指导它。首先,JAX 的主要 API 专注于基本构建块(例如,数值原语、NumPy 操作、数组和转换),鼓励辅助库根据其领域的需求开发实用程序。此外,JAX 公开了一些更高级的 API,用于自定义和可扩展性。库可以依赖这些 API,以便将 JAX 用作内部实现手段,以更深入地集成其转换(如自动微分)等等。
JAX 生态系统中的项目以分布式且通常开放的方式开发。它们不受 JAX 核心团队的管辖,即使有时团队成员会为它们做出贡献或与它们的开发人员保持联系。
可插拔后端#
我们希望 JAX 能够在 CPU、GPU、TPU 和其他新兴硬件平台上运行。为了鼓励在新的平台上不受阻碍地支持 JAX,JAX 核心在其后端也强调模块化。
为了管理硬件设备和内存,以及对此类设备进行编译,JAX 调用开放的 XLA 编译器和 PJRT 运行时。这两者都是 JAX 核心之外的项目,由 OpenXLA 管理和维护(同样,JAX 核心开发人员经常做出贡献并进行讨论)。
XLA 旨在实现跨加速器的互操作性(例如,通过摄取 StableHLO 作为输入),而 PJRT 通过插件设备 API 提供可扩展性。添加对新设备的支持是通过为 XLA 实现后端降低,并实现由 PJRT 定义的插件设备 API 来完成的。如果您希望为编译做出贡献,或支持新的硬件,我们鼓励您在 XLA 和 PJRT 层做出贡献。
这些开放系统组件允许第三方在新加速器平台上支持 JAX,而无需更改 JAX 核心。如今有几个插件正在开发中。例如,Apple 的一个团队正在开发 PJRT 插件,以使 JAX 在 Apple Metal 上运行。