术语表#

Array#

JAX 中与 numpy.ndarray 对应的概念。请参阅 jax.Array

CPU#

中央处理器 (Central Processing Unit) 的缩写。CPU 是大多数计算机中标准的计算架构。JAX 可以在 CPU 上运行计算,但通常在 GPUTPU 上获得更好的性能。

Device#

JAX 用于执行计算的 CPUGPUTPU 的通用名称。

forward-mode autodiff#

请参阅 JVP

functional programming#

一种编程范式,其中程序通过应用和组合 纯函数 来定义。JAX 设计用于与函数式程序配合使用。

GPU#

图形处理单元 (Graphical Processing Unit) 的缩写。GPU 最初专门用于屏幕图像渲染相关的操作,但现在已变得更加通用。JAX 能够针对 GPU 进行快速数组操作(另请参阅 CPUTPU)。

jaxpr#

JAX expression 的缩写,jaxpr 是 JAX 生成的计算的中间表示,它被转发给 XLA 进行编译和执行。有关更多讨论和示例,请参阅 JAX 内部:jaxpr 语言

JIT#

即时 (Just In Time) 编译的缩写。JAX 中的 JIT 通常指的是将数组操作编译到 XLA,这通常通过 jax.jit() 来实现。

JVP#

Jacobian Vector Product 的缩写,有时也称为前向模式自动微分。更多详细信息,请参阅 Jacobian-Vector products (JVPs, aka forward-mode autodiff)。在 JAX 中,JVP 是一个通过 jax.jvp() 实现的转换。另请参阅 VJP

primitive#

primitive 是 JAX 程序中使用的基本计算单元。 jax.lax 中的大多数函数代表单个 primitive。在 jaxpr 中表示计算时,jaxpr 中的每个操作都是一个 primitive。

pure function#

纯函数是其输出仅取决于其输入且没有任何副作用的函数。JAX 的 转换 模型旨在与纯函数配合使用。另请参阅 函数式编程

pytree#

pytree 是一种抽象,它允许 JAX 以统一的方式处理元组、列表、字典和其他更通用的数组值容器。有关更详细的讨论,请参阅 使用 pytrees

reverse-mode autodiff#

请参阅 VJP

SPMD#

单程序多数据 (Single Program Multi Data) 的缩写,它指的是一种并行计算技术,在该技术中,相同的计算(例如,神经网络的前向传播)在不同的数据(例如,批次中的不同输入)上并行运行在不同的设备(例如,多个 TPU)上。 jax.pmap() 是实现 SPMD 并行化的 JAX 转换

static#

JIT 编译中,一个未被追踪的值(请参阅 Tracer)。也可能指在编译时对静态值进行的计算。

TPU#

张量处理单元 (Tensor Processing Unit) 的缩写。TPU 是专门为深度学习应用中 N 维张量的快速操作而设计的芯片。JAX 能够针对 TPU 进行快速数组操作(另请参阅 CPUGPU)。

Tracer#

一个对象,用作 JAX Array 的占位符,以便确定 Python 函数执行的操作序列。在内部,JAX 通过 `jax.core.Tracer` 类实现此功能。

transformation#

高阶函数,即接受函数作为输入并输出转换后函数的函数。JAX 中的示例包括 jax.jit()jax.vmap()jax.grad()

VJP#

Vector Jacobian Product 的缩写,有时也称为反向模式自动微分。更多详细信息,请参阅 Vector-Jacobian products (VJPs, aka reverse-mode autodiff)。在 JAX 中,VJP 是一个通过 jax.vjp() 实现的转换。另请参阅 JVP

XLA#

加速线性代数 (Accelerated Linear Algebra) 的缩写。XLA 是一个用于线性代数运算的领域特定编译器,是 JIT 编译的 JAX 代码的主要后端。请参阅 https://www.openxla.org/xla/

weak type#

一种 JAX 数据类型,其类型提升语义与 Python 标量相同;请参阅 JAX 中的弱类型值