Pytrees#
JAX 内置支持处理类似数组字典(dicts)、嵌套数组列表等结构的对象——在 JAX 中,这些被称为 pytrees。本节将介绍如何使用它们,提供有用的代码示例,并指出常见的“陷阱”和模式。
关于如何创建自定义 pytree 的解释,请参阅 自定义 pytree 节点。
什么是 Pytree?#
Pytree 是一种由类容器 Python 对象构成的容器状结构,它可以包含“叶子”pytree 或更多的 pytree。Pytree 可以包含列表、元组和字典。叶子是指任何非 pytree 的对象,例如数组,但单个叶子本身也可以被视为一个 pytree。
在机器学习(ML)上下文中,pytree 可以包含:
模型参数
数据集条目
强化学习智能体的观测值
在处理数据集时,你经常会遇到 pytree(例如列表的列表的字典)。
以下是一个简单的 pytree 示例。在 JAX 中,你可以使用 jax.tree.leaves() 从树中提取展平后的叶子,如下所示:
import jax
import jax.numpy as jnp
example_trees = [
[1, 'a', object()],
(1, (2, 3), ()),
[1, {'k1': 2, 'k2': (3, 4)}, 5],
{'a': 2, 'b': (2, 3)},
jnp.array([1, 2, 3]),
]
# Print how many leaves the pytrees have.
for pytree in example_trees:
# This `jax.tree.leaves()` method extracts the flattened leaves from the pytrees.
leaves = jax.tree.leaves(pytree)
print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")
[1, 'a', <object object at 0x703c2c668580>] has 3 leaves: [1, 'a', <object object at 0x703c2c668580>]
(1, (2, 3), ()) has 3 leaves: [1, 2, 3]
[1, {'k1': 2, 'k2': (3, 4)}, 5] has 5 leaves: [1, 2, 3, 4, 5]
{'a': 2, 'b': (2, 3)} has 3 leaves: [2, 2, 3]
Array([1, 2, 3], dtype=int32) has 1 leaves: [Array([1, 2, 3], dtype=int32)]
任何由类容器 Python 对象构建的树状结构在 JAX 中都可以被视为 pytree。如果类位于 pytree 注册表中(默认包括列表、元组和字典),它们就被认为是类容器的。任何类型不在 pytree 容器注册表中的对象都将被视为树中的叶子节点。
可以通过向注册表添加函数来扩展 pytree 注册表,以包含用户定义的容器类,这些函数指定了如何展平树结构;详见下方的 自定义 pytree 节点。
常用 Pytree 函数#
JAX 提供了许多用于操作 pytree 的实用程序。这些可以在 jax.tree_util 子包中找到;为了方便起见,其中许多在 jax.tree 模块中都有别名。
常用函数:jax.tree.map#
最常用的 pytree 函数是 jax.tree.map()。它的工作方式类似于 Python 原生的 map,但能透明地操作整个 pytree。
这里是一个示例:
list_of_lists = [
[1, 2, 3],
[1, 2],
[1, 2, 3, 4]
]
jax.tree.map(lambda x: x*2, list_of_lists)
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]
jax.tree.map() 也允许将一个 N 元函数映射到多个参数上。例如:
another_list_of_lists = list_of_lists
jax.tree.map(lambda x, y: x+y, list_of_lists, another_list_of_lists)
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]
当使用多个参数配合 jax.tree.map() 时,输入的结构必须完全匹配。也就是说,列表必须具有相同数量的元素,字典必须具有相同的键等。
使用 jax.tree.map 处理机器学习模型参数的示例#
此示例演示了在训练简单的 多层感知机 (MLP) 时,pytree 操作的实用性。
首先定义初始模型参数:
import numpy as np
def init_mlp_params(layer_widths):
params = []
for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
params.append(
dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),
biases=np.ones(shape=(n_out,))
)
)
return params
params = init_mlp_params([1, 128, 128, 1])
使用 jax.tree.map() 检查初始参数的形状:
jax.tree.map(lambda x: x.shape, params)
[{'biases': (128,), 'weights': (1, 128)},
{'biases': (128,), 'weights': (128, 128)},
{'biases': (1,), 'weights': (128, 1)}]
接下来,定义用于训练 MLP 模型的函数:
# Define the forward pass.
def forward(params, x):
*hidden, last = params
for layer in hidden:
x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
return x @ last['weights'] + last['biases']
# Define the loss function.
def loss_fn(params, x, y):
return jnp.mean((forward(params, x) - y) ** 2)
# Set the learning rate.
LEARNING_RATE = 0.0001
# Using the stochastic gradient descent, define the parameter update function.
# Apply `@jax.jit` for JIT compilation (speed).
@jax.jit
def update(params, x, y):
# Calculate the gradients with `jax.grad`.
grads = jax.grad(loss_fn)(params, x, y)
# Note that `grads` is a pytree with the same structure as `params`.
# `jax.grad` is one of many JAX functions that has
# built-in support for pytrees.
# This is useful - you can apply the SGD update using JAX pytree utilities.
return jax.tree.map(
lambda p, g: p - LEARNING_RATE * g, params, grads
)
查看对象的 Pytree 定义#
为了调试目的查看任意 object 的 pytree 定义,你可以使用:
from jax.tree_util import tree_structure
print(tree_structure(object))
PyTreeDef(*)
Pytree 与 JAX 转换#
许多 JAX 函数(如 jax.lax.scan())都在数组 pytree 上运行。此外,所有的 JAX 函数转换都可以应用于以数组 pytree 作为输入并产生数组 pytree 作为输出的函数。
一些 JAX 函数转换接受可选参数,用于指定应如何处理特定的输入或输出值(例如 jax.vmap() 的 in_axes 和 out_axes 参数)。这些参数也可以是 pytree,其结构必须与对应参数的 pytree 结构一致。特别是,为了能够将这些参数 pytree 中的叶子与参数 pytree 中的值“匹配”,参数 pytree 通常被限制为参数 pytree 的树前缀。
例如,如果你将以下输入传递给 jax.vmap()(注意函数的输入参数被视为一个元组):
vmap(f, in_axes=(a1, {"k1": a2, "k2": a3}))
那么你可以使用以下 in_axes pytree 来指定只有 k2 参数被映射(axis=0),其余参数不进行映射(axis=None):
vmap(f, in_axes=(None, {"k1": None, "k2": 0}))
可选参数 pytree 的结构必须与主输入 pytree 的结构匹配。但是,可选参数也可以选择性地指定为“前缀”pytree,这意味着单个叶子值可以应用于整个子 pytree。
例如,如果你有与上述相同的 jax.vmap() 输入,但只想映射字典参数,你可以使用:
vmap(f, in_axes=(None, 0)) # equivalent to (None, {"k1": 0, "k2": 0})
或者,如果你希望映射每一个参数,你可以写一个应用于整个参数元组 pytree 的单一叶子值:
vmap(f, in_axes=0) # equivalent to (0, {"k1": 0, "k2": 0})
这恰好是 jax.vmap() 的默认 in_axes 值。
同样的逻辑也适用于引用转换函数特定输入或输出值的其他可选参数,例如 jax.vmap() 中的 out_axes。
显式键路径#
在 pytree 中,每个叶子都有一个键路径 (key path)。叶子的键路径是一个键的 list,其中列表的长度等于该叶子在 pytree 中的深度。每个键都是一个 可哈希对象,表示对相应 pytree 节点类型的索引。键的类型取决于 pytree 节点的类型;例如,dict 的键类型与 tuple 的键类型不同。
对于内置的 pytree 节点类型,任何 pytree 实例的键集都是唯一的。对于由具有此属性的节点组成的 pytree,每个叶子的键路径都是唯一的。
JAX 拥有以下用于处理键路径的 jax.tree_util.* 方法:
jax.tree_util.tree_flatten_with_path():工作方式类似于jax.tree.flatten(),但会返回键路径。jax.tree_util.tree_map_with_path():工作方式类似于jax.tree.map(),但该函数还将键路径作为参数。jax.tree_util.keystr():给定一个通用键路径,返回一个易于阅读的字符串表达式。
例如,一个用例是打印与特定叶子值相关的调试信息:
import collections
ATuple = collections.namedtuple("ATuple", ('name'))
tree = [1, {'k1': 2, 'k2': (3, 4)}, ATuple('foo')]
flattened, _ = jax.tree_util.tree_flatten_with_path(tree)
for key_path, value in flattened:
print(f'Value of tree{jax.tree_util.keystr(key_path)}: {value}')
Value of tree[0]: 1
Value of tree[1]['k1']: 2
Value of tree[1]['k2'][0]: 3
Value of tree[1]['k2'][1]: 4
Value of tree[2].name: foo
为了表示键路径,JAX 为内置 pytree 节点类型提供了一些默认键类型,即:
SequenceKey(idx: int):用于列表和元组。DictKey(key: Hashable):用于字典。GetAttrKey(name: str):用于namedtuple以及(最好是)自定义 pytree 节点(下一节中有更多介绍)。
你可以自由定义自己的自定义节点键类型。只要它们的 __str__() 方法也被重写为易于阅读的表达式,它们就能与 jax.tree_util.keystr() 一起使用。
for key_path, _ in flattened:
print(f'Key path of tree{jax.tree_util.keystr(key_path)}: {repr(key_path)}')
Key path of tree[0]: (SequenceKey(idx=0),)
Key path of tree[1]['k1']: (SequenceKey(idx=1), DictKey(key='k1'))
Key path of tree[1]['k2'][0]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=0))
Key path of tree[1]['k2'][1]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=1))
Key path of tree[2].name: (SequenceKey(idx=2), GetAttrKey(name='name'))
Pytree 常见陷阱#
本节介绍了在使用 JAX pytree 时遇到的一些最常见的问题(“陷阱”)。
误将 Pytree 节点当作叶子节点#
一个需要注意的常见陷阱是无意中引入了树节点而不是叶子:
a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]
# Try to make another pytree with ones instead of zeros.
shapes = jax.tree.map(lambda x: x.shape, a_tree)
jax.tree.map(jnp.ones, shapes)
[(Array([1., 1.], dtype=float32), Array([1., 1., 1.], dtype=float32)),
(Array([1., 1., 1.], dtype=float32), Array([1., 1., 1., 1.], dtype=float32))]
这里发生的情况是,数组的 shape 是一个元组,它是一个 pytree 节点,其元素是叶子。因此,在 map 中,调用 jnp.ones 时,传入的不是例如 (2, 3),而是分别传入了 2 和 3。
解决方案取决于具体情况,但有两个广泛适用的选项:
重写代码以避免中间的
jax.tree.map()。将元组转换为 NumPy 数组(
np.array)或 JAX NumPy 数组(jnp.array),这会使整个序列成为一个叶子。
jax.tree_util 对 None 的处理#
jax.tree_util 函数将 None 视为 pytree 节点的缺失,而不是叶子:
jax.tree.leaves([None, None, None])
[]
要将 None 视为叶子,可以使用 is_leaf 参数:
jax.tree.leaves([None, None, None], is_leaf=lambda x: x is None)
[None, None, None]
常用 Pytree 模式#
本节涵盖了 JAX pytree 中一些最常见的模式。
使用 jax.tree.map 和 jax.tree.transpose 转置 Pytree#
为了转置 pytree(将树的列表转换为列表的树),JAX 有两个函数:jax.tree.map()(更基础)和 jax.tree.transpose()(更灵活、复杂且冗长)。
选项 1: 使用 jax.tree.map()。这是一个示例:
def tree_transpose(list_of_trees):
"""
Converts a list of trees of identical structure into a single tree of lists.
"""
return jax.tree.map(lambda *xs: list(xs), *list_of_trees)
# Convert a dataset from row-major to column-major.
episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]
tree_transpose(episode_steps)
{'obs': [3, 4], 't': [1, 2]}
选项 2: 对于更复杂的转置,使用 jax.tree.transpose(),它虽然更冗长,但允许你指定内部和外部 pytree 的结构以获得更大的灵活性。例如:
jax.tree.transpose(
outer_treedef = jax.tree.structure([0 for e in episode_steps]),
inner_treedef = jax.tree.structure(episode_steps[0]),
pytree_to_transpose = episode_steps
)
{'obs': [3, 4], 't': [1, 2]}
扩展 Pytree#
有关扩展 pytree 的材料已移至 自定义 pytree 节点。