使用pytree#

JAX内置支持看起来像数组字典(dicts)、列表的列表的字典或其他嵌套结构的对象——在JAX中,这些对象被称为pytree。本节将解释如何使用它们,提供有用的代码示例,并指出常见的“陷阱”和模式。

什么是pytree?#

Pytree是一种由容器类Python对象构建的类似容器的结构——“叶子”pytree和/或更多pytree。Pytree可以包括列表、元组和字典。叶子是任何非pytree的对象,例如数组,但单个叶子也是一个pytree。

在机器学习(ML)中,pytree可以包含

  • 模型参数

  • 数据集条目

  • 强化学习智能体观测值

在使用数据集时,您经常会遇到pytree(例如列表的列表的字典)。

下面是一个简单pytree的示例。在JAX中,您可以使用jax.tree.leaves(),从树中提取扁平化的叶子,如下所示

import jax
import jax.numpy as jnp

example_trees = [
    [1, 'a', object()],
    (1, (2, 3), ()),
    [1, {'k1': 2, 'k2': (3, 4)}, 5],
    {'a': 2, 'b': (2, 3)},
    jnp.array([1, 2, 3]),
]

# Print how many leaves the pytrees have.
for pytree in example_trees:
  # This `jax.tree.leaves()` method extracts the flattened leaves from the pytrees.
  leaves = jax.tree.leaves(pytree)
  print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")
[1, 'a', <object object at 0x75acb64a4070>]   has 3 leaves: [1, 'a', <object object at 0x75acb64a4070>]
(1, (2, 3), ())                               has 3 leaves: [1, 2, 3]
[1, {'k1': 2, 'k2': (3, 4)}, 5]               has 5 leaves: [1, 2, 3, 4, 5]
{'a': 2, 'b': (2, 3)}                         has 3 leaves: [2, 2, 3]
Array([1, 2, 3], dtype=int32)                 has 1 leaves: [Array([1, 2, 3], dtype=int32)]

任何由容器类Python对象构建的树状结构都可以在JAX中被视为pytree。如果类在pytree注册表中,则它们被认为是容器类的,默认情况下包括列表、元组和字典。任何其类型不在pytree容器注册表中的对象都将被视为树中的叶节点。

pytree注册表可以通过注册类并提供指定如何展平树的函数来扩展,以包含用户定义的容器类;请参阅下面的自定义pytree节点

常用pytree函数#

JAX提供了许多用于操作pytree的实用工具。这些工具可以在jax.tree_util子包中找到;为了方便,其中许多在jax.tree模块中也有别名。

常用函数:jax.tree.map#

最常用的pytree函数是jax.tree.map()。它与Python原生的map类似,但能透明地操作整个pytree。

这里有一个示例

list_of_lists = [
    [1, 2, 3],
    [1, 2],
    [1, 2, 3, 4]
]

jax.tree.map(lambda x: x*2, list_of_lists)
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]

jax.tree.map()还允许将一个N元函数映射到多个参数上。例如

another_list_of_lists = list_of_lists
jax.tree.map(lambda x, y: x+y, list_of_lists, another_list_of_lists)
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]

jax.tree.map()使用多个参数时,输入结构必须完全匹配。即列表必须具有相同数量的元素,字典必须具有相同的键等。

使用jax.tree.map处理ML模型参数的示例#

本示例演示了在训练简单的多层感知器(MLP)时,pytree操作如何发挥作用。

首先定义初始模型参数

import numpy as np

def init_mlp_params(layer_widths):
  params = []
  for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
    params.append(
        dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),
             biases=np.ones(shape=(n_out,))
            )
    )
  return params

params = init_mlp_params([1, 128, 128, 1])

使用jax.tree.map()检查初始参数的形状

jax.tree.map(lambda x: x.shape, params)
[{'biases': (128,), 'weights': (1, 128)},
 {'biases': (128,), 'weights': (128, 128)},
 {'biases': (1,), 'weights': (128, 1)}]

接下来,定义用于训练MLP模型的函数

# Define the forward pass.
def forward(params, x):
  *hidden, last = params
  for layer in hidden:
    x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
  return x @ last['weights'] + last['biases']

# Define the loss function.
def loss_fn(params, x, y):
  return jnp.mean((forward(params, x) - y) ** 2)

# Set the learning rate.
LEARNING_RATE = 0.0001

# Using the stochastic gradient descent, define the parameter update function.
# Apply `@jax.jit` for JIT compilation (speed).
@jax.jit
def update(params, x, y):
  # Calculate the gradients with `jax.grad`.
  grads = jax.grad(loss_fn)(params, x, y)
  # Note that `grads` is a pytree with the same structure as `params`.
  # `jax.grad` is one of many JAX functions that has
  # built-in support for pytrees.
  # This is useful - you can apply the SGD update using JAX pytree utilities.
  return jax.tree.map(
      lambda p, g: p - LEARNING_RATE * g, params, grads
  )

自定义pytree节点#

本节解释了在JAX中,您如何通过使用jax.tree_util.register_pytree_node()jax.tree.map()来扩展在pytree中被视为内部节点的Python类型集(pytree节点)。

为什么您需要这样做?在之前的示例中,pytree显示为列表、元组和字典,而其他所有内容都是pytree叶子。这是因为如果您定义自己的容器类,除非您将其注册到JAX,否则它将被视为pytree叶子。即使您的容器类内部包含树,情况也是如此。例如

class Special(object):
  def __init__(self, x, y):
    self.x = x
    self.y = y

jax.tree.leaves([
    Special(0, 1),
    Special(2, 4),
])
[<__main__.Special at 0x75acb420d5b0>, <__main__.Special at 0x75acb420d610>]

因此,如果您尝试使用jax.tree.map()并期望叶子是容器内部的元素,您将得到一个错误

jax.tree.map(lambda x: x + 1,
  [
    Special(0, 1),
    Special(2, 4)
  ])
TypeError: unsupported operand type(s) for +: 'Special' and 'int'

作为解决方案,JAX允许通过全局类型注册表扩展被视为内部pytree节点的类型集。此外,注册类型的值会被递归遍历。

首先,使用jax.tree_util.register_pytree_node()注册一个新类型

from jax.tree_util import register_pytree_node

class RegisteredSpecial(Special):
  def __repr__(self):
    return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)

def special_flatten(v):
  """Specifies a flattening recipe.

  Params:
    v: The value of the registered type to flatten.
  Returns:
    A pair of an iterable with the children to be flattened recursively,
    and some opaque auxiliary data to pass back to the unflattening recipe.
    The auxiliary data is stored in the treedef for use during unflattening.
    The auxiliary data could be used, for example, for dictionary keys.
  """
  children = (v.x, v.y)
  aux_data = None
  return (children, aux_data)

def special_unflatten(aux_data, children):
  """Specifies an unflattening recipe.

  Params:
    aux_data: The opaque data that was specified during flattening of the
      current tree definition.
    children: The unflattened children

  Returns:
    A reconstructed object of the registered type, using the specified
    children and auxiliary data.
  """
  return RegisteredSpecial(*children)

# Global registration
register_pytree_node(
    RegisteredSpecial,
    special_flatten,    # Instruct JAX what are the children nodes.
    special_unflatten   # Instruct JAX how to pack back into a `RegisteredSpecial`.
)

现在您可以遍历特殊的容器结构

jax.tree.map(lambda x: x + 1,
  [
   RegisteredSpecial(0, 1),
   RegisteredSpecial(2, 4),
  ])
[RegisteredSpecial(x=1, y=2), RegisteredSpecial(x=3, y=5)]

现代Python配备了有用的工具,使容器的定义更加容易。有些可以直接与JAX配合使用,而另一些则需要更多关注。

例如,Python的NamedTuple子类不需要注册即可被视为pytree节点类型

from typing import NamedTuple, Any

class MyOtherContainer(NamedTuple):
  name: str
  a: Any
  b: Any
  c: Any

# NamedTuple subclasses are handled as pytree nodes, so
# this will work out-of-the-box.
jax.tree.leaves([
    MyOtherContainer('Alice', 1, 2, 3),
    MyOtherContainer('Bob', 4, 5, 6)
])
['Alice', 1, 2, 3, 'Bob', 4, 5, 6]

请注意,name字段现在显示为叶子,因为所有元组元素都是子项。这就是您不必以困难方式注册类时发生的情况。

NamedTuple子类不同,使用@dataclass装饰的类不会自动成为pytree。但是,可以使用jax.tree_util.register_dataclass()装饰器将其注册为pytree

from dataclasses import dataclass
import functools

@functools.partial(jax.tree_util.register_dataclass,
                   data_fields=['a', 'b', 'c'],
                   meta_fields=['name'])
@dataclass
class MyDataclassContainer(object):
  name: str
  a: Any
  b: Any
  c: Any

# MyDataclassContainer is now a pytree node.
jax.tree.leaves([
  MyDataclassContainer('apple', 5.3, 1.2, jnp.zeros([4])),
  MyDataclassContainer('banana', np.array([3, 4]), -1., 0.)
])
[5.3, 1.2, Array([0., 0., 0., 0.], dtype=float32), array([3, 4]), -1.0, 0.0]

请注意,name字段没有显示为叶子。这是因为我们将其包含在jax.tree_util.register_dataclass()meta_fields参数中,表明它应被视为元数据/辅助数据,就像上面RegisteredSpecial中的aux_data一样。现在,MyDataclassContainer的实例可以传递给JIT编译的函数,并且name将被视为静态(有关静态参数的更多信息,请参阅将参数标记为静态

@jax.jit
def f(x: MyDataclassContainer | MyOtherContainer):
  return x.a + x.b

# Works fine! `mdc.name` is static.
mdc = MyDataclassContainer('mdc', 1, 2, 3)
y = f(mdc)

MyOtherContainerNamedTuple子类)相比,由于name字段是一个pytree叶子,JIT期望它能够转换为jax.Array,因此以下代码会引发错误

moc = MyOtherContainer('moc', 1, 2, 3)
y = f(moc)
TypeError: Error interpreting argument to <function f at 0x75acb4217c40> as an abstract array. The problematic value is of type <class 'str'> and was passed to the function at path x.name.
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.

Pytree与JAX变换#

许多JAX函数,如jax.lax.scan(),都在数组的pytree上操作。此外,所有JAX函数变换都可以应用于接受数组pytree作为输入并生成数组pytree作为输出的函数。

一些JAX函数变换接受可选参数,这些参数指定如何处理某些输入或输出值(例如jax.vmap()in_axesout_axes参数)。这些参数也可以是pytree,并且它们的结构必须与相应参数的pytree结构相对应。特别是,为了能够将这些参数pytree中的叶子与参数pytree中的值“匹配”起来,参数pytree通常被限制为参数pytree的树前缀。

例如,如果您将以下输入传递给jax.vmap()(请注意,函数的输入参数被视为一个元组)

vmap(f, in_axes=(a1, {"k1": a2, "k2": a3}))

那么您可以使用以下in_axes pytree来指定只有k2参数被映射(axis=0),其余的则不被映射(axis=None

vmap(f, in_axes=(None, {"k1": None, "k2": 0}))

可选参数pytree的结构必须与主输入pytree的结构匹配。但是,可选参数也可以选择指定为“前缀”pytree,这意味着单个叶值可以应用于整个子pytree。

例如,如果您有与上面相同的jax.vmap()输入,但希望只映射字典参数,您可以使用

vmap(f, in_axes=(None, 0))  # equivalent to (None, {"k1": 0, "k2": 0})

或者,如果您希望所有参数都被映射,您可以编写一个单个叶值,该值应用于整个参数元组pytree

vmap(f, in_axes=0)  # equivalent to (0, {"k1": 0, "k2": 0})

这恰好是jax.vmap()的默认in_axes值。

同样的逻辑适用于指代转换函数特定输入或输出值的其他可选参数,例如jax.vmap()中的out_axes

显式键路径#

在pytree中,每个叶子都有一个键路径。叶子的键路径是一个list,其中列表的长度等于叶子在pytree中的深度。每个都是一个可哈希对象,表示对应pytree节点类型的索引。键的类型取决于pytree节点类型;例如,dict的键类型与tuple的键类型不同。

对于内置的pytree节点类型,任何pytree节点实例的键集都是唯一的。对于包含具有此属性的节点的pytree,每个叶子的键路径都是唯一的。

JAX提供了以下jax.tree_util.*方法来处理键路径

例如,一个用例是打印与特定叶值相关的调试信息

import collections

ATuple = collections.namedtuple("ATuple", ('name'))

tree = [1, {'k1': 2, 'k2': (3, 4)}, ATuple('foo')]
flattened, _ = jax.tree_util.tree_flatten_with_path(tree)

for key_path, value in flattened:
  print(f'Value of tree{jax.tree_util.keystr(key_path)}: {value}')
Value of tree[0]: 1
Value of tree[1]['k1']: 2
Value of tree[1]['k2'][0]: 3
Value of tree[1]['k2'][1]: 4
Value of tree[2].name: foo

为了表达键路径,JAX为内置pytree节点类型提供了几种默认键类型,即

  • SequenceKey(idx: int):用于列表和元组。

  • DictKey(key: Hashable):用于字典。

  • GetAttrKey(name: str):用于namedtuple以及最好是自定义pytree节点(下一节有更多内容)

您可以自由为自定义节点定义自己的键类型。只要它们的__str__()方法也被易于阅读的表达式覆盖,它们就可以与jax.tree_util.keystr()一起使用。

for key_path, _ in flattened:
  print(f'Key path of tree{jax.tree_util.keystr(key_path)}: {repr(key_path)}')
Key path of tree[0]: (SequenceKey(idx=0),)
Key path of tree[1]['k1']: (SequenceKey(idx=1), DictKey(key='k1'))
Key path of tree[1]['k2'][0]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=0))
Key path of tree[1]['k2'][1]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=1))
Key path of tree[2].name: (SequenceKey(idx=2), GetAttrKey(name='name'))

常见pytree陷阱#

本节涵盖了使用JAX pytree时遇到的一些最常见问题(“陷阱”)。

将pytree节点误认为是叶子#

一个常见的需要注意的陷阱是意外地引入了树节点而不是叶子

a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]

# Try to make another pytree with ones instead of zeros.
shapes = jax.tree.map(lambda x: x.shape, a_tree)
jax.tree.map(jnp.ones, shapes)
[(Array([1., 1.], dtype=float32), Array([1., 1., 1.], dtype=float32)),
 (Array([1., 1., 1.], dtype=float32), Array([1., 1., 1., 1.], dtype=float32))]

这里发生的情况是,数组的shape是一个元组,它是一个pytree节点,其元素是叶子。因此,在映射中,不是对例如(2, 3)调用jnp.ones,而是对23调用。

解决方案将取决于具体情况,但有两个广泛适用的选项

  • 重写代码以避免中间的jax.tree.map()

  • 将元组转换为NumPy数组(np.array)或JAX NumPy数组(jnp.array),这会使整个序列成为一个叶子。

jax.tree_utilNone的处理#

jax.tree_util函数将None视为pytree节点的缺失,而不是叶子

jax.tree.leaves([None, None, None])
[]

要将None视为叶子,可以使用is_leaf参数

jax.tree.leaves([None, None, None], is_leaf=lambda x: x is None)
[None, None, None]

自定义pytree及使用意外值初始化#

用户自定义pytree对象的另一个常见陷阱是,JAX变换有时会用意外值初始化它们,从而导致初始化时进行的任何输入验证都可能失败。例如

class MyTree:
  def __init__(self, a):
    self.a = jnp.asarray(a)

register_pytree_node(MyTree, lambda tree: ((tree.a,), None),
    lambda _, args: MyTree(*args))

tree = MyTree(jnp.arange(5.0))

jax.vmap(lambda x: x)(tree)      # Error because object() is passed to `MyTree`.
TypeError: Value '<object object at 0x75acf00eb9d0>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
jax.jacobian(lambda x: x)(tree)  # Error because MyTree(...) is passed to `MyTree`.
ValueError: None is not a valid value for jnp.array
  • 在第一个案例中,使用jax.vmap(...)(tree)时,JAX的内部使用object()值的数组来推断树的结构

  • 在第二个案例中,使用jax.jacobian(...)(tree)时,将树映射到树的函数的雅可比矩阵被定义为树的树。

潜在解决方案1

  • 自定义pytree类的__init____new__方法通常应避免进行任何数组转换或其他输入验证,否则应预料并处理这些特殊情况。例如

class MyTree:
  def __init__(self, a):
    if not (type(a) is object or a is None or isinstance(a, MyTree)):
      a = jnp.asarray(a)
    self.a = a

潜在解决方案2

  • 构建您的自定义tree_unflatten函数,使其避免调用__init__。如果您选择此路径,请确保您的tree_unflatten函数在代码更新时与__init__保持同步。示例

def tree_unflatten(aux_data, children):
  del aux_data  # Unused in this class.
  obj = object.__new__(MyTree)
  obj.a = a
  return obj

常见pytree模式#

本节涵盖了JAX pytree的一些最常见模式。

使用jax.tree.mapjax.tree.transpose转置pytree#

要转置pytree(将树列表转换为列表树),JAX有两个函数:jax.tree.map()(更基础)和jax.tree.transpose()(更灵活、复杂和冗长)。

选项1:使用jax.tree.map()。这里有一个示例

def tree_transpose(list_of_trees):
  """
  Converts a list of trees of identical structure into a single tree of lists.
  """
  return jax.tree.map(lambda *xs: list(xs), *list_of_trees)

# Convert a dataset from row-major to column-major.
episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]
tree_transpose(episode_steps)
{'obs': [3, 4], 't': [1, 2]}

选项2:对于更复杂的转置,使用jax.tree.transpose(),它更冗长,但允许您指定内部和外部pytree的结构,从而提供更大的灵活性。例如

jax.tree.transpose(
  outer_treedef = jax.tree.structure([0 for e in episode_steps]),
  inner_treedef = jax.tree.structure(episode_steps[0]),
  pytree_to_transpose = episode_steps
)
{'obs': [3, 4], 't': [1, 2]}