JAX:高性能数组计算#
高性能数组计算
JAX 是一个用于面向加速器的数组计算和程序转换的 Python 库,专为高性能数值计算和大规模机器学习而设计。
熟悉的 API
JAX 提供了一个熟悉的 NumPy 风格 API,方便研究人员和工程师采用。
转换
JAX 包含可组合的函数转换,用于编译、批处理、自动微分和并行化。
随处运行
相同的代码可在多种后端上运行,包括 CPU、GPU 和 TPU。
安装
入门
JAX 101
如果您想使用 JAX 训练神经网络,请查看 JAX AI Stack!
生态系统#
JAX 本身范围较窄,专注于高效的数组操作和程序转换。围绕 JAX 构建的是一个不断发展的机器学习和数值计算工具生态系统;以下仅是部分示例:
已开发出许多基于 JAX 的库;社区维护的 Awesome JAX 页面维护着一个最新的列表。