使用 Pytrees#
JAX 内置支持看起来像数组字典(dicts)、列表的列表、字典的列表或其他嵌套结构的类——在 JAX 中,这些被称为 Pytrees。本节将解释如何使用它们,提供有用的代码示例,并指出常见的“陷阱”和模式。
什么是 Pytree?#
Pytree 是一个由类容器 Python 对象构建的容器状结构——“叶子” Pytree 和/或更多的 Pytree。Pytree 可以包含列表、元组和字典。叶子是任何不是 Pytree 的东西,例如数组,但单个叶子也是一个 Pytree。
在机器学习(ML)的上下文中,Pytree 可以包含
模型参数
数据集条目
强化学习代理的观察
在使用数据集时,您经常会遇到 Pytrees(例如,字典的列表的列表)。
下面是一个简单的 Pytree 的示例。在 JAX 中,您可以使用 jax.tree.leaves() 来提取 Pytree 的扁平化叶子,如下所示。
import jax
import jax.numpy as jnp
example_trees = [
[1, 'a', object()],
(1, (2, 3), ()),
[1, {'k1': 2, 'k2': (3, 4)}, 5],
{'a': 2, 'b': (2, 3)},
jnp.array([1, 2, 3]),
]
# Print how many leaves the pytrees have.
for pytree in example_trees:
# This `jax.tree.leaves()` method extracts the flattened leaves from the pytrees.
leaves = jax.tree.leaves(pytree)
print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")
[1, 'a', <object object at 0x78d81854c0d0>] has 3 leaves: [1, 'a', <object object at 0x78d81854c0d0>]
(1, (2, 3), ()) has 3 leaves: [1, 2, 3]
[1, {'k1': 2, 'k2': (3, 4)}, 5] has 5 leaves: [1, 2, 3, 4, 5]
{'a': 2, 'b': (2, 3)} has 3 leaves: [2, 2, 3]
Array([1, 2, 3], dtype=int32) has 1 leaves: [Array([1, 2, 3], dtype=int32)]
任何由类容器 Python 对象构建的树形结构都可以被 JAX 视为 Pytree。类被认为是类容器,如果它们在 Pytree 注册表中,该注册表默认包含列表、元组和字典。任何类型不在 Pytree 容器注册表中的对象都将被视为树中的叶子节点。
可以通过注册类及其用于扁平化树的函数来扩展 Pytree 注册表以包含用户定义的容器类;请参阅下面的 自定义 Pytree 节点。
常用的 Pytree 函数#
JAX 提供了许多用于操作 Pytrees 的实用程序。这些可以在 jax.tree_util 子包中找到;为了方便起见,其中许多在 jax.tree 模块中都有别名。
常用函数:jax.tree.map#
最常用的 Pytree 函数是 jax.tree.map()。它的作用类似于 Python 原生的 map,但可以透明地操作整个 Pytrees。
这是一个例子
list_of_lists = [
[1, 2, 3],
[1, 2],
[1, 2, 3, 4]
]
jax.tree.map(lambda x: x*2, list_of_lists)
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]
jax.tree.map() 还允许对多个参数映射一个 N 元函数。例如
another_list_of_lists = list_of_lists
jax.tree.map(lambda x, y: x+y, list_of_lists, another_list_of_lists)
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]
当使用多个参数与 jax.tree.map() 一起使用时,输入的结构必须完全匹配。也就是说,列表必须有相同数量的元素,字典必须有相同的键,等等。
jax.tree.map 与 ML 模型参数的示例#
本示例演示了 Pytree 操作在训练一个简单的 多层感知器(MLP)时可能很有用。
首先定义初始模型参数
import numpy as np
def init_mlp_params(layer_widths):
params = []
for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
params.append(
dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),
biases=np.ones(shape=(n_out,))
)
)
return params
params = init_mlp_params([1, 128, 128, 1])
使用 jax.tree.map() 检查初始参数的形状
jax.tree.map(lambda x: x.shape, params)
[{'biases': (128,), 'weights': (1, 128)},
{'biases': (128,), 'weights': (128, 128)},
{'biases': (1,), 'weights': (128, 1)}]
接下来,定义训练 MLP 模型的函数
# Define the forward pass.
def forward(params, x):
*hidden, last = params
for layer in hidden:
x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
return x @ last['weights'] + last['biases']
# Define the loss function.
def loss_fn(params, x, y):
return jnp.mean((forward(params, x) - y) ** 2)
# Set the learning rate.
LEARNING_RATE = 0.0001
# Using the stochastic gradient descent, define the parameter update function.
# Apply `@jax.jit` for JIT compilation (speed).
@jax.jit
def update(params, x, y):
# Calculate the gradients with `jax.grad`.
grads = jax.grad(loss_fn)(params, x, y)
# Note that `grads` is a pytree with the same structure as `params`.
# `jax.grad` is one of many JAX functions that has
# built-in support for pytrees.
# This is useful - you can apply the SGD update using JAX pytree utilities.
return jax.tree.map(
lambda p, g: p - LEARNING_RATE * g, params, grads
)
自定义 Pytree 节点#
本节解释了在 JAX 中如何通过使用 jax.tree_util.register_pytree_node() 和 jax.tree.map() 来扩展将被视为 Pytree 中内部节点(Pytree 节点)的 Python 类型集合。
为什么需要这个?在前面的示例中,Pytrees 被显示为列表、元组和字典,其余的都被视为 Pytree 叶子。这是因为如果您定义了自己的容器类,它将被视为 Pytree 叶子,除非您将其注册到 JAX。即使您的容器类内部包含 Pytrees,也是如此。例如
class Special(object):
def __init__(self, x, y):
self.x = x
self.y = y
jax.tree.leaves([
Special(0, 1),
Special(2, 4),
])
[<__main__.Special at 0x78d8317e7f80>, <__main__.Special at 0x78d8183061b0>]
因此,如果您尝试使用 jax.tree.map() 并期望叶子是容器内的元素,您将收到一个错误。
jax.tree.map(lambda x: x + 1,
[
Special(0, 1),
Special(2, 4)
])
TypeError: unsupported operand type(s) for +: 'Special' and 'int'
作为一种解决方案,JAX 允许通过全局类型注册表来扩展被视为内部 Pytree 节点的类型集合。此外,注册类型的 Pytree 值会被递归遍历。
首先,使用 jax.tree_util.register_pytree_node() 注册一个新类型。
from jax.tree_util import register_pytree_node
class RegisteredSpecial(Special):
def __repr__(self):
return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)
def special_flatten(v):
"""Specifies a flattening recipe.
Params:
v: The value of the registered type to flatten.
Returns:
A pair of an iterable with the children to be flattened recursively,
and some opaque auxiliary data to pass back to the unflattening recipe.
The auxiliary data is stored in the treedef for use during unflattening.
The auxiliary data could be used, for example, for dictionary keys.
"""
children = (v.x, v.y)
aux_data = None
return (children, aux_data)
def special_unflatten(aux_data, children):
"""Specifies an unflattening recipe.
Params:
aux_data: The opaque data that was specified during flattening of the
current tree definition.
children: The unflattened children
Returns:
A reconstructed object of the registered type, using the specified
children and auxiliary data.
"""
return RegisteredSpecial(*children)
# Global registration
register_pytree_node(
RegisteredSpecial,
special_flatten, # Instruct JAX what are the children nodes.
special_unflatten # Instruct JAX how to pack back into a `RegisteredSpecial`.
)
现在您可以遍历特殊的容器结构了。
jax.tree.map(lambda x: x + 1,
[
RegisteredSpecial(0, 1),
RegisteredSpecial(2, 4),
])
[RegisteredSpecial(x=1, y=2), RegisteredSpecial(x=3, y=5)]
现代 Python 配备了有用的工具来简化容器的定义。其中一些可以开箱即用地与 JAX 一起使用,但另一些需要更多注意。
例如,Python 的 NamedTuple 子类不需要注册即可被视为 Pytree 节点类型。
from typing import NamedTuple, Any
class MyOtherContainer(NamedTuple):
name: str
a: Any
b: Any
c: Any
# NamedTuple subclasses are handled as pytree nodes, so
# this will work out-of-the-box.
jax.tree.leaves([
MyOtherContainer('Alice', 1, 2, 3),
MyOtherContainer('Bob', 4, 5, 6)
])
['Alice', 1, 2, 3, 'Bob', 4, 5, 6]
请注意,name 字段现在显示为一个叶子,因为所有元组元素都是其子节点。这就是不需要硬编码注册类时发生的情况。
与 NamedTuple 子类不同,用 @dataclass 装饰的类不是自动 Pytrees。但是,它们可以使用 jax.tree_util.register_dataclass() 装饰器注册为 Pytrees。
from dataclasses import dataclass
import functools
@functools.partial(jax.tree_util.register_dataclass,
data_fields=['a', 'b', 'c'],
meta_fields=['name'])
@dataclass
class MyDataclassContainer(object):
name: str
a: Any
b: Any
c: Any
# MyDataclassContainer is now a pytree node.
jax.tree.leaves([
MyDataclassContainer('apple', 5.3, 1.2, jnp.zeros([4])),
MyDataclassContainer('banana', np.array([3, 4]), -1., 0.)
])
[5.3, 1.2, Array([0., 0., 0., 0.], dtype=float32), array([3, 4]), -1.0, 0.0]
请注意,name 字段没有显示为叶子。这是因为我们在 jax.tree_util.register_dataclass() 的 meta_fields 参数中包含了它,表示它应该被视为元数据/辅助数据,就像上面 RegisteredSpecial 中的 aux_data 一样。现在 MyDataclassContainer 的实例可以传递给 JIT 编译的函数,并且 name 将被视为静态(有关静态参数的更多信息,请参阅 将参数标记为静态)。
@jax.jit
def f(x: MyDataclassContainer | MyOtherContainer):
return x.a + x.b
# Works fine! `mdc.name` is static.
mdc = MyDataclassContainer('mdc', 1, 2, 3)
y = f(mdc)
这与 MyOtherContainer(NamedTuple 子类)形成对比。由于 name 字段是一个 Pytree 叶子,JIT 期望它能够被转换为 jax.Array,因此以下代码会引发错误。
moc = MyOtherContainer('moc', 1, 2, 3)
y = f(moc)
TypeError: Error interpreting argument to <function f at 0x78d818351580> as an abstract array. The problematic value is of type <class 'str'> and was passed to the function at path x.name.
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.
Pytrees 与 JAX 变换#
许多 JAX 函数,如 jax.lax.scan(),都操作数组的 Pytrees。此外,所有 JAX 函数变换都可以应用于接受 Pytrees 数组作为输入并产生 Pytrees 数组作为输出的函数。
一些 JAX 函数变换接受可选参数,这些参数指定如何处理某些输入或输出值(例如,jax.vmap() 的 in_axes 和 out_axes 参数)。这些参数也可以是 Pytrees,并且它们的结构必须与相应参数的 Pytree 结构相对应。特别是,为了能够将这些参数 Pytrees 中的叶子与参数 Pytrees 中的值“匹配”起来,参数 Pytrees 通常被约束为参数 Pytrees 的树前缀。
例如,如果您将以下输入传递给 jax.vmap()(请注意,函数的输入参数被视为一个元组)
vmap(f, in_axes=(a1, {"k1": a2, "k2": a3}))
那么您可以使用以下 in_axes Pytree 来指定只有 k2 参数被映射(axis=0),而其他参数不被映射(axis=None)。
vmap(f, in_axes=(None, {"k1": None, "k2": 0}))
可选参数 Pytree 的结构必须与主输入 Pytree 的结构匹配。但是,可选参数可以选择性地指定为“前缀” Pytree,这意味着单个叶子值可以应用于整个子 Pytree。
例如,如果您有与上面相同的 jax.vmap() 输入,但希望只映射字典参数,您可以使用
vmap(f, in_axes=(None, 0)) # equivalent to (None, {"k1": 0, "k2": 0})
或者,如果您希望映射每个参数,您可以编写一个单个叶子值,该值应用于整个参数元组 Pytree。
vmap(f, in_axes=0) # equivalent to (0, {"k1": 0, "k2": 0})
这恰好是 jax.vmap() 的默认 in_axes 值。
同样的逻辑适用于其他可选参数,这些参数引用了变换函数的特定输入或输出值,例如 jax.vmap() 中的 out_axes。
显式键路径#
在 Pytree 中,每个叶子都有一个键路径。叶子的键路径是一个键的 list,其中列表的长度等于叶子在 Pytree 中的深度。每个键是一个 可哈希对象,它代表了对相应 Pytree 节点类型的索引。键的类型取决于 Pytree 节点类型;例如,dict 的键类型与 tuple 的键类型不同。
对于内置 Pytree 节点类型,任何 Pytree 节点实例的键集都是唯一的。对于包含具有此属性的节点的 Pytree,每个叶子的键路径都是唯一的。
JAX 提供了以下 jax.tree_util.* 方法来处理键路径。
jax.tree_util.tree_flatten_with_path():功能类似于jax.tree.flatten(),但返回键路径。jax.tree_util.tree_map_with_path():功能类似于jax.tree.map(),但该函数也接受键路径作为参数。jax.tree_util.keystr():给定一个通用键路径,返回一个易于阅读的字符串表达式。
例如,一个用例是打印与某个叶子值相关的调试信息。
import collections
ATuple = collections.namedtuple("ATuple", ('name'))
tree = [1, {'k1': 2, 'k2': (3, 4)}, ATuple('foo')]
flattened, _ = jax.tree_util.tree_flatten_with_path(tree)
for key_path, value in flattened:
print(f'Value of tree{jax.tree_util.keystr(key_path)}: {value}')
Value of tree[0]: 1
Value of tree[1]['k1']: 2
Value of tree[1]['k2'][0]: 3
Value of tree[1]['k2'][1]: 4
Value of tree[2].name: foo
为了表示键路径,JAX 为内置 Pytree 节点类型提供了一些默认的键类型,即
SequenceKey(idx: int):用于列表和元组。DictKey(key: Hashable):用于字典。GetAttrKey(name: str):用于namedtuples,以及最好是自定义 Pytree 节点(更多内容请参阅下一节)。
您可以自由地为自定义节点定义自己的键类型。只要它们的 __str__() 方法也被重写为易于阅读的表达式,它们就可以与 jax.tree_util.keystr() 一起使用。
for key_path, _ in flattened:
print(f'Key path of tree{jax.tree_util.keystr(key_path)}: {repr(key_path)}')
Key path of tree[0]: (SequenceKey(idx=0),)
Key path of tree[1]['k1']: (SequenceKey(idx=1), DictKey(key='k1'))
Key path of tree[1]['k2'][0]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=0))
Key path of tree[1]['k2'][1]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=1))
Key path of tree[2].name: (SequenceKey(idx=2), GetAttrKey(name='name'))
常见的 Pytree 陷阱#
本节涵盖了在使用 JAX Pytrees 时遇到的一些最常见的问题(“陷阱”)。
将 Pytree 节点误认为叶子节点#
一个常见的需要注意的陷阱是意外地引入了树节点而不是叶子。
a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]
# Try to make another pytree with ones instead of zeros.
shapes = jax.tree.map(lambda x: x.shape, a_tree)
jax.tree.map(jnp.ones, shapes)
[(Array([1., 1.], dtype=float32), Array([1., 1., 1.], dtype=float32)),
(Array([1., 1., 1.], dtype=float32), Array([1., 1., 1., 1.], dtype=float32))]
这里发生的是,数组的 shape 是一个元组,它是一个 Pytree 节点,其元素是叶子。因此,在 map 中,它不是对例如 (2, 3) 调用 jnp.ones,而是对 2 和 3 调用。
解决方案将取决于具体情况,但有两种广泛适用的选项:
重写代码以避免中间的
jax.tree.map()。将元组转换为 NumPy 数组(
np.array)或 JAX NumPy 数组(jnp.array),使整个序列成为一个叶子。
jax.tree_util 对 None 的处理#
jax.tree_util 函数将 None 视为 Pytree 节点的缺失,而不是叶子。
jax.tree.leaves([None, None, None])
[]
要将 None 视为叶子,您可以使用 is_leaf 参数。
jax.tree.leaves([None, None, None], is_leaf=lambda x: x is None)
[None, None, None]
自定义 Pytrees 和使用意外值进行初始化#
用户定义的 Pytree 对象遇到的另一个常见陷阱是 JAX 变换有时会使用意外值对其进行初始化,因此任何在初始化时进行的输入验证都可能失败。例如
class MyTree:
def __init__(self, a):
self.a = jnp.asarray(a)
register_pytree_node(MyTree, lambda tree: ((tree.a,), None),
lambda _, args: MyTree(*args))
tree = MyTree(jnp.arange(5.0))
jax.vmap(lambda x: x)(tree) # Error because object() is passed to `MyTree`.
<__main__.MyTree at 0x78d818307b00>
jax.jacobian(lambda x: x)(tree) # Error because MyTree(...) is passed to `MyTree`.
ValueError: None is not a valid value for jnp.array
在第一种情况
jax.vmap(...)(tree)中,JAX 的内部使用object()值的数组来推断 Pytree 的结构。在第二种情况
jax.jacobian(...)(tree)中,一个将 Pytree 映射到 Pytree 的函数的雅可比矩阵被定义为一个 Pytree 的 Pytree。
潜在解决方案 1
自定义 Pytree 类的
__init__和__new__方法通常应避免进行任何数组转换或其他输入验证,或者预料并处理这些特殊情况。例如。
class MyTree:
def __init__(self, a):
if not (type(a) is object or a is None or isinstance(a, MyTree)):
a = jnp.asarray(a)
self.a = a
潜在解决方案 2
将自定义的
tree_unflatten函数结构化,使其避免调用__init__。如果您选择此路径,请确保您的tree_unflatten函数在代码更新时保持与__init__同步。示例。
def tree_unflatten(aux_data, children):
del aux_data # Unused in this class.
obj = object.__new__(MyTree)
obj.a = a
return obj
常见的 Pytree 模式#
本节涵盖了 JAX Pytrees 的一些最常见模式。
使用 jax.tree.map 和 jax.tree.transpose 转置 Pytrees#
要转置 Pytree(将 Pytree 列表转换为 Pytree 列表),JAX 有两个函数:jax.tree.map()(更基础)和 jax.tree.transpose()(更灵活、复杂和冗长)。
选项 1: 使用 jax.tree.map()。这里有一个例子。
def tree_transpose(list_of_trees):
"""
Converts a list of trees of identical structure into a single tree of lists.
"""
return jax.tree.map(lambda *xs: list(xs), *list_of_trees)
# Convert a dataset from row-major to column-major.
episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]
tree_transpose(episode_steps)
{'obs': [3, 4], 't': [1, 2]}
选项 2: 对于更复杂的转置,使用 jax.tree.transpose(),它更冗长,但允许您指定内部和外部 Pytree 的结构以获得更大的灵活性。例如。
jax.tree.transpose(
outer_treedef = jax.tree.structure([0 for e in episode_steps]),
inner_treedef = jax.tree.structure(episode_steps[0]),
pytree_to_transpose = episode_steps
)
{'obs': [3, 4], 't': [1, 2]}