关键概念#
本节简要介绍 JAX 软件包的一些关键概念。
JAX 数组 (jax.Array
)#
JAX 中的默认数组实现是 jax.Array
。在许多方面,它类似于您可能熟悉的 NumPy 软件包中的 numpy.ndarray
类型,但它有一些重要的区别。
数组创建#
我们通常不直接调用 jax.Array
构造函数,而是通过 JAX API 函数创建数组。例如,jax.numpy
提供了熟悉的 NumPy 风格的数组构造功能,例如 jax.numpy.zeros()
、jax.numpy.linspace()
、jax.numpy.arange()
等。
import jax
import jax.numpy as jnp
x = jnp.arange(5)
isinstance(x, jax.Array)
True
如果您在代码中使用 Python 类型注解,则 jax.Array
是 jax 数组对象的适当注解(有关更多讨论,请参阅 jax.typing
)。
数组设备和分片#
JAX 数组对象有一个 devices
方法,可让您检查数组内容存储在何处。在最简单的情况下,这将是单个 CPU 设备
x.devices()
{CpuDevice(id=0)}
一般来说,一个数组可能会 *分片* 到多个设备,其方式可以通过 sharding
属性进行检查
x.sharding
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)
这里数组在单个设备上,但一般来说,JAX 数组可以分片到多个设备,甚至多个主机。要阅读有关分片数组和并行计算的更多信息,请参阅并行编程入门
转换#
除了操作数组的函数外,JAX 还包括许多对 JAX 函数进行操作的 转换。这些包括
jax.vmap()
:向量化转换;请参阅自动向量化jax.grad()
:梯度转换;请参阅自动微分
以及其他一些。转换接受一个函数作为参数,并返回一个新的转换后的函数。例如,以下是如何 JIT 编译一个简单的 SELU 函数的方法
def selu(x, alpha=1.67, lambda_=1.05):
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
selu_jit = jax.jit(selu)
print(selu_jit(1.0))
1.05
通常,为了方便起见,您会看到使用 Python 的装饰器语法应用转换
@jax.jit
def selu(x, alpha=1.67, lambda_=1.05):
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
追踪#
转换背后的魔力是 Tracer 的概念。Tracer 是数组对象的抽象替代品,并传递给 JAX 函数,以便提取函数编码的操作序列。
您可以通过打印转换后的 JAX 代码中的任何数组值来看到这一点;例如
@jax.jit
def f(x):
print(x)
return x + 1
x = jnp.arange(5)
result = f(x)
Traced<ShapedArray(int32[5])>with<DynamicJaxprTrace>
打印的值不是数组 x
,而是一个 Tracer
实例,它表示 x
的基本属性,例如其 shape
和 dtype
。通过使用追踪值执行函数,JAX 可以在实际执行这些操作之前确定函数编码的操作序列:像 jit()
、vmap()
和 grad()
这样的转换可以将此输入操作序列映射到转换后的操作序列。
Jaxpr#
JAX 有自己的操作序列中间表示,称为 jaxpr。jaxpr(*JAX exPRession* 的缩写)是函数式程序的简单表示,包括一系列 原语 操作。
例如,考虑我们上面定义的 selu
函数
def selu(x, alpha=1.67, lambda_=1.05):
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
我们可以使用 jax.make_jaxpr()
实用程序将此函数转换为给定特定输入的 jaxpr
x = jnp.arange(5.0)
jax.make_jaxpr(selu)(x)
{ lambda ; a:f32[5]. let
b:bool[5] = gt a 0.0
c:f32[5] = exp a
d:f32[5] = mul 1.6699999570846558 c
e:f32[5] = sub d 1.6699999570846558
f:f32[5] = pjit[
name=_where
jaxpr={ lambda ; b:bool[5] a:f32[5] e:f32[5]. let
f:f32[5] = select_n b e a
in (f,) }
] b a e
g:f32[5] = mul 1.0499999523162842 f
in (g,) }
将其与 Python 函数定义进行比较,我们看到它编码了函数表示的精确操作序列。我们将在后面的JAX 内部原理:jaxpr 语言中更深入地介绍 jaxpr。
Pytree#
JAX 函数和转换从根本上对数组进行操作,但在实践中,编写处理数组集合的代码很方便:例如,神经网络可能会将其参数组织成带有有意义键的数组字典。JAX 不是逐个处理此类结构,而是依赖 pytree 抽象以统一的方式处理此类集合。
以下是一些可以作为 pytree 处理的对象的示例
# (nested) list of parameters
params = [1, 2, (jnp.arange(3), jnp.ones(2))]
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef([*, *, (*, *)])
[1, 2, Array([0, 1, 2], dtype=int32), Array([1., 1.], dtype=float32)]
# Dictionary of parameters
params = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef({'W': *, 'b': *, 'n': *})
[Array([[1., 1.],
[1., 1.]], dtype=float32), Array([0., 0.], dtype=float32), 5]
# Named tuple of parameters
from typing import NamedTuple
class Params(NamedTuple):
a: int
b: float
params = Params(1, 5.0)
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef(CustomNode(namedtuple[Params], [*, *]))
[1, 5.0]
JAX 有许多用于处理 PyTree 的通用实用程序;例如,函数 jax.tree.map()
可用于将函数映射到树中的每个叶子,而 jax.tree.reduce()
可用于在树的叶子上应用归约。
您可以在“使用 pytree”教程中了解更多信息。
伪随机数#
一般来说,JAX 力求与 NumPy 兼容,但伪随机数生成是一个明显的例外。NumPy 支持一种基于全局 state
的伪随机数生成方法,可以使用 numpy.random.seed()
设置。全局随机状态与 JAX 的计算模型交互不佳,并且难以在不同的线程、进程和设备之间强制执行可重复性。JAX 而是通过随机 key
显式跟踪状态
from jax import random
key = random.key(43)
print(key)
Array((), dtype=key<fry>) overlaying:
[ 0 43]
该密钥实际上是 NumPy 隐藏状态对象的替代品,但我们将其显式传递给 jax.random()
函数。重要的是,随机函数会消耗密钥,但不会修改它:将相同的密钥对象馈送到随机函数将始终导致生成相同的样本。
print(random.normal(key))
print(random.normal(key))
0.07520543
0.07520543
经验法则是:永远不要重复使用密钥(除非您想要相同的输出)。
为了生成不同且独立的样本,您必须在将密钥传递给随机函数之前显式地 split()
密钥
for i in range(3):
new_key, subkey = random.split(key)
del key # The old key is consumed by split() -- we must never use it again.
val = random.normal(subkey)
del subkey # The subkey is consumed by normal().
print(f"draw {i}: {val}")
key = new_key # new_key is safe to use in the next iteration.
draw 0: -1.9133632183074951
draw 1: -1.4749839305877686
draw 2: -0.36703771352767944
请注意,此代码是线程安全的,因为本地随机状态消除了可能涉及全局状态的竞争条件。jax.random.split()
是一个确定性函数,它将一个密钥转换为多个独立的(在伪随机意义上)密钥。
有关 JAX 中伪随机数的更多信息,请参阅“伪随机数”教程。