术语表#
- 数组 (Array)#
JAX 中
numpy.ndarray
的模拟。请参阅jax.Array
。- CPU#
中央处理器 (Central Processing Unit) 的缩写,CPU 是大多数计算机中可用的标准计算架构。JAX 可以在 CPU 上运行计算,但通常可以在 GPU 和 TPU 上获得更好的性能。
- 设备 (Device)#
- 前向模式自动微分 (forward-mode autodiff)#
请参阅 JVP。
- 函数式编程 (functional programming)#
一种编程范式,其中程序通过应用和组合 纯函数 来定义。JAX 被设计用于函数式程序。
- GPU#
图形处理单元 (Graphical Processing Unit) 的缩写,GPU 最初专门用于屏幕上图像渲染相关的操作,但现在更加通用。JAX 可以针对 GPU 进行快速的数组操作(另请参阅 CPU 和 TPU)。
- jaxpr#
JAX 表达式 (JAX expression) 的缩写,jaxpr 是 JAX 生成的计算的中间表示形式,并被转发到 XLA 进行编译和执行。有关更多讨论和示例,请参阅 JAX 内部原理:jaxpr 语言。
- JIT#
即时 (Just In Time) 编译的缩写,JAX 中的 JIT 通常是指将数组操作编译到 XLA,最常见的是使用
jax.jit()
完成。- JVP#
雅可比向量积 (Jacobian Vector Product) 的缩写,有时也称为前向模式自动微分。有关更多详细信息,请参阅 雅可比-向量积(JVP,又名 前向模式自动微分)。在 JAX 中,JVP 是一种通过
jax.jvp()
实现的 变换。另请参阅 VJP。- 原语 (primitive)#
原语是 JAX 程序中使用的基本计算单元。
jax.lax
中的大多数函数代表单独的原语。当在 jaxpr 中表示计算时,jaxpr 中的每个操作都是一个原语。- 纯函数 (pure function)#
纯函数是指其输出仅基于其输入,且没有副作用的函数。JAX 的 变换 模型被设计为与纯函数一起工作。另请参阅 函数式编程。
- pytree#
pytree 是一种抽象,它允许 JAX 以统一的方式处理元组、列表、字典以及其他更通用的数组值容器。有关更详细的讨论,请参阅 使用 pytree。
- 反向模式自动微分 (reverse-mode autodiff)#
请参阅 VJP。
- SPMD#
单程序多数据 (Single Program Multi Data) 的缩写,它指的是一种并行计算技术,其中相同的计算(例如,神经网络的前向传递)在不同的设备(例如,多个 TPU)上并行地运行在不同的输入数据(例如,批次中的不同输入)上。
jax.pmap()
是一种实现 SPMD 并行的 JAX 变换。- 静态 (static)#
- TPU#
张量处理单元 (Tensor Processing Unit) 的缩写,TPU 是专门为深度学习应用中使用的 N 维张量快速操作而设计的芯片。JAX 能够针对 TPU 进行快速的数组操作(另请参阅 CPU 和 GPU)。
- Tracer#
用作 JAX Array 的占位符的对象,用于确定 Python 函数执行的操作序列。在内部,JAX 通过
jax.core.Tracer
类实现此功能。- 变换 (transformation)#
高阶函数:也就是说,一个将函数作为输入并输出转换后的函数的函数。JAX 中的示例包括
jax.jit()
、jax.vmap()
和jax.grad()
。- VJP#
向量雅可比积 (Vector Jacobian Product) 的缩写,有时也称为反向模式自动微分。有关更多详细信息,请参阅 向量-雅可比积(VJP,又名反向模式自动微分)。在 JAX 中,VJP 是一种通过
jax.vjp()
实现的 变换。另请参阅 JVP。- XLA#
加速线性代数 (Accelerated Linear Algebra) 的缩写,XLA 是一个用于线性代数操作的特定领域编译器,它是 JIT 编译的 JAX 代码的主要后端。请参阅 https://tensorflowcn.cn/xla/。
- 弱类型 (weak type)#
一种具有与 Python 标量相同类型提升语义的 JAX 数据类型;请参阅 JAX 中的弱类型值。