术语表#

Array#

JAX 中与 numpy.ndarray 类似的概念。参见 jax.Array

CPU#

中央处理器 (Central Processing Unit) 的缩写,CPU 是大多数计算机中可用的标准计算架构。JAX 可以在 CPU 上运行计算,但通常可以在 GPUTPU 上获得更好的性能。

设备 (Device)#

用于指代 JAX 用于执行计算的 CPUGPUTPU 的通用名称。

前向模式自动微分 (forward-mode autodiff)#

参见 JVP

函数式编程 (functional programming)#

一种编程范式,其中程序通过应用和组合纯函数来定义。JAX 设计用于函数式程序。

GPU#

图形处理器 (Graphical Processing Unit) 的缩写,GPU 最初专门用于与屏幕图像渲染相关的操作,但现在用途更加广泛。JAX 能够以 GPU 为目标,以快速执行数组操作(另请参阅 CPUTPU)。

jaxpr#

JAX 表达式 (JAX expression) 的缩写,jaxpr 是 JAX 生成的计算的中间表示,并转发到 XLA 进行编译和执行。有关更多讨论和示例,请参阅 JAX 内部原理:jaxpr 语言

JIT#

即时 (Just In Time) 编译的缩写,JAX 中的 JIT 通常指将数组操作编译为 XLA,最常见的是使用 jax.jit() 完成。

JVP#

雅可比向量积 (Jacobian Vector Product) 的缩写,有时也称为前向模式自动微分。有关更多详细信息,请参阅 雅可比-向量积(JVP,又名前向模式自动微分)。在 JAX 中,JVP 是一种转换,通过 jax.jvp() 实现。另请参阅 VJP

原语 (primitive)#

原语是 JAX 程序中使用的基本计算单元。jax.lax 中的大多数函数代表单个原语。在 jaxpr 中表示计算时,jaxpr 中的每个操作都是一个原语。

纯函数 (pure function)#

纯函数是一种函数,其输出仅基于其输入,并且没有副作用。JAX 的转换模型旨在与纯函数一起使用。另请参阅函数式编程

pytree#

pytree 是一种抽象,允许 JAX 以统一的方式处理元组、列表、字典以及其他更通用的数组值容器。有关更详细的讨论,请参阅使用 pytrees

反向模式自动微分 (reverse-mode autodiff)#

参见 VJP

SPMD#

单程序多数据 (Single Program Multi Data) 的缩写,它指的是一种并行计算技术,其中相同的计算(例如,神经网络的前向传递)在不同的输入数据(例如,批处理中的不同输入)上并行运行在不同的设备上(例如,多个 TPU)。jax.pmap() 是实现 SPMD 并行性的 JAX 转换

静态 (static)#

JIT 编译中,未被跟踪的值(参见 Tracer)。有时也指对静态值的编译时计算。

TPU#

张量处理单元 (Tensor Processing Unit) 的缩写,TPU 是专门为深度学习应用中使用的 N 维张量快速运算而设计的芯片。JAX 能够以 TPU 为目标,以快速执行数组操作(另请参阅 CPUGPU)。

Tracer#

用作 JAX Array 的占位符的对象,以确定 Python 函数执行的操作序列。在内部,JAX 通过 jax.core.Tracer 类实现此功能。

转换 (transformation)#

高阶函数:即,将函数作为输入并输出转换后的函数的函数。JAX 中的示例包括 jax.jit()jax.vmap()jax.grad()

VJP#

向量雅可比积 (Vector Jacobian Product) 的缩写,有时也称为反向模式自动微分。有关更多详细信息,请参阅 向量-雅可比积(VJP,又名反向模式自动微分)。在 JAX 中,VJP 是一种转换,通过 jax.vjp() 实现。另请参阅 JVP

XLA#

加速线性代数 (Accelerated Linear Algebra) 的缩写,XLA 是线性代数运算的特定领域编译器,是 JIT 编译的 JAX 代码的主要后端。请参阅 https://tensorflowcn.cn/xla/

弱类型 (weak type)#

一种 JAX 数据类型,其类型提升语义与 Python 标量相同;请参阅 JAX 中的弱类型值