JAX:高性能数组计算#

高性能数组计算

JAX 是一个 Python 库,用于面向加速器的数组计算和程序转换,专为高性能数值计算和大规模机器学习而设计。

熟悉的 API

JAX 提供了熟悉的 NumPy 风格 API,方便研究人员和工程师采用。

转换

JAX 包括可组合的函数转换,用于编译、批处理、自动微分和并行化。

随处运行

相同的代码可以在多个后端执行,包括 CPU、GPU 和 TPU

安装
安装
入门指南
JAX 入门
用户指南
用户指南

如果您希望训练神经网络,请使用 Flax 并从其教程开始。对于基于 JAX 构建的端到端 Transformer 库,请参阅 MaxText

生态系统#

JAX 本身范围狭窄,专注于高效的数组操作和程序转换。围绕 JAX 构建的是一个不断发展的机器学习和数值计算工具生态系统;以下只是其中一小部分:

神经网络

优化器和求解器

杂项工具

概率编程

物理和模拟

已经开发了更多基于 JAX 的库;社区运行的 Awesome JAX 页面维护了一个最新的列表。