Pytree#

什么是 Pytree?#

在 JAX 中,我们使用术语 pytree 来指代由容器型 Python 对象构建的树状结构。如果类在 Pytree 注册表中,则被认为是容器型的,默认情况下包括列表(lists)、元组(tuples)和字典(dicts)。也就是说

  1. 任何类型 在 Pytree 容器注册表中的对象都被视为 叶子 Pytree;

  2. 任何类型在 Pytree 容器注册表中且包含 Pytree 的对象,都被视为 Pytree。

对于 Pytree 容器注册表中的每个条目,都会注册一个容器型类型以及一对函数,这些函数指定如何将容器类型的实例转换为 (children, metadata) 对,以及如何将这样的对转换回容器类型的实例。利用这些函数,JAX 可以将任何注册容器对象的树规范化为元组。

Pytree 示例

[1, "a", object()]  # 3 leaves

(1, (2, 3), ())  # 3 leaves

[1, {"k1": 2, "k2": (3, 4)}, 5]  # 5 leaves

JAX 可以扩展为将其他容器类型视为 Pytree;参见下方的 扩展 Pytree

Pytree 和 JAX 函数#

许多 JAX 函数,例如 jax.lax.scan(),对数组 Pytree 进行操作。JAX 函数变换可以应用于接受数组 Pytree 作为输入并生成数组 Pytree 作为输出的函数。

将可选参数应用于 Pytree#

一些 JAX 函数变换接受可选参数,这些参数指定如何处理某些输入或输出值(例如 vmap()in_axesout_axes 参数)。这些参数也可以是 Pytree,它们的结构必须与相应参数的 Pytree 结构相对应。特别地,为了能够将这些参数 Pytree 中的叶子与参数 Pytree 中的值“匹配”起来,参数 Pytree 通常被限制为参数 Pytree 的树前缀。

例如,如果我们向 vmap() 传递以下输入(请注意,函数的输入参数被视为一个元组)

(a1, {"k1": a2, "k2": a3})

我们可以使用以下 in_axes Pytree 来指定只有 k2 参数被映射(axis=0),其余参数不被映射(axis=None

(None, {"k1": None, "k2": 0})

可选参数 Pytree 的结构必须与主输入 Pytree 的结构匹配。然而,可选参数也可以指定为“前缀”Pytree,这意味着单个叶子值可以应用于整个子 Pytree。例如,如果我们有与上面相同的 vmap() 输入,但希望只映射字典参数,我们可以使用

(None, 0)  # equivalent to (None, {"k1": 0, "k2": 0})

或者,如果我们希望每个参数都被映射,我们可以简单地编写一个应用于整个参数元组 Pytree 的单个叶子值

0

这恰好是 vmap() 的默认 in_axes 值!

同样的逻辑也适用于引用变换函数特定输入或输出值的其他可选参数,例如 vmapout_axes

查看对象的 Pytree 定义#

为了调试目的,查看任意 object 的 Pytree 定义,您可以使用

from jax.tree_util import tree_structure
print(tree_structure(object))

开发者信息#

这主要是 JAX 内部文档,终端用户通常不需要理解这些内容即可使用 JAX,除非他们要向 JAX 注册新的用户定义容器类型。其中一些细节可能会发生变化。

Pytree 内部处理#

JAX 在 api.py 边界(以及控制流原语中)将 Pytree 展平为叶子列表。这使得下游 JAX 内部更简单:grad()jit()vmap() 等变换可以处理接受和返回各种不同 Python 容器的用户函数,而系统的所有其他部分可以对只接受(多个)数组参数并始终返回扁平数组列表的函数进行操作。

当 JAX 展平一个 Pytree 时,它会生成一个叶子列表和一个编码原始值结构的 treedef 对象。然后可以使用 treedef 在变换叶子后构造一个匹配的结构化值。Pytree 是树状的,而不是 DAG(有向无环图)或图状的,因为我们处理它们时假定引用透明性,并且它们不能包含引用循环。

下面是一个简单示例

from jax.tree_util import tree_flatten, tree_unflatten
import jax.numpy as jnp

# The structured value to be transformed
value_structured = [1., (2., 3.)]

# The leaves in value_flat correspond to the `*` markers in value_tree
value_flat, value_tree = tree_flatten(value_structured)
print(f"{value_flat=}\n{value_tree=}")

# Transform the flat value list using an element-wise numeric transformer
transformed_flat = list(map(lambda v: v * 2., value_flat))
print(f"{transformed_flat=}")

# Reconstruct the structured output, using the original
transformed_structured = tree_unflatten(value_tree, transformed_flat)
print(f"{transformed_structured=}")
value_flat=[1.0, 2.0, 3.0]
value_tree=PyTreeDef([*, (*, *)])
transformed_flat=[2.0, 4.0, 6.0]
transformed_structured=[2.0, (4.0, 6.0)]

默认情况下,Pytree 容器可以是 lists、tuples、dicts、namedtuple、None、OrderedDict。其他类型的值,包括数值和 ndarray 值,都被视为叶子。

from collections import namedtuple
Point = namedtuple('Point', ['x', 'y'])

example_containers = [
    (1., [2., 3.]),
    (1., {'b': 2., 'a': 3.}),
    1.,
    None,
    jnp.zeros(2),
    Point(1., 2.)
]
def show_example(structured):
  flat, tree = tree_flatten(structured)
  unflattened = tree_unflatten(tree, flat)
  print(f"{structured=}\n  {flat=}\n  {tree=}\n  {unflattened=}")

for structured in example_containers:
  show_example(structured)
structured=(1.0, [2.0, 3.0])
  flat=[1.0, 2.0, 3.0]
  tree=PyTreeDef((*, [*, *]))
  unflattened=(1.0, [2.0, 3.0])
structured=(1.0, {'b': 2.0, 'a': 3.0})
  flat=[1.0, 3.0, 2.0]
  tree=PyTreeDef((*, {'a': *, 'b': *}))
  unflattened=(1.0, {'a': 3.0, 'b': 2.0})
structured=1.0
  flat=[1.0]
  tree=PyTreeDef(*)
  unflattened=1.0
structured=None
  flat=[]
  tree=PyTreeDef(None)
  unflattened=None
structured=Array([0., 0.], dtype=float32)
  flat=[Array([0., 0.], dtype=float32)]
  tree=PyTreeDef(*)
  unflattened=Array([0., 0.], dtype=float32)
structured=Point(x=1.0, y=2.0)
  flat=[1.0, 2.0]
  tree=PyTreeDef(CustomNode(namedtuple[Point], [*, *]))
  unflattened=Point(x=1.0, y=2.0)

扩展 Pytree#

默认情况下,结构化值中任何未被识别为内部 Pytree 节点(即容器型)的部分都被视为叶子。

class Special(object):
  def __init__(self, x, y):
    self.x = x
    self.y = y

  def __repr__(self):
    return "Special(x={}, y={})".format(self.x, self.y)


show_example(Special(1., 2.))
structured=Special(x=1.0, y=2.0)
  flat=[Special(x=1.0, y=2.0)]
  tree=PyTreeDef(*)
  unflattened=Special(x=1.0, y=2.0)

被视为内部 Pytree 节点的 Python 类型集是可扩展的,通过一个全局类型注册表,注册类型的值会被递归遍历。要注册新类型,可以使用 register_pytree_node()

from jax.tree_util import register_pytree_node

class RegisteredSpecial(Special):
  def __repr__(self):
    return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)

def special_flatten(v):
  """Specifies a flattening recipe.

  Params:
    v: the value of registered type to flatten.
  Returns:
    a pair of an iterable with the children to be flattened recursively,
    and some opaque auxiliary data to pass back to the unflattening recipe.
    The auxiliary data is stored in the treedef for use during unflattening.
    The auxiliary data could be used, e.g., for dictionary keys.
  """
  children = (v.x, v.y)
  aux_data = None
  return (children, aux_data)

def special_unflatten(aux_data, children):
  """Specifies an unflattening recipe.

  Params:
    aux_data: the opaque data that was specified during flattening of the
      current treedef.
    children: the unflattened children

  Returns:
    a re-constructed object of the registered type, using the specified
    children and auxiliary data.
  """
  return RegisteredSpecial(*children)

# Global registration
register_pytree_node(
    RegisteredSpecial,
    special_flatten,    # tell JAX what are the children nodes
    special_unflatten   # tell JAX how to pack back into a RegisteredSpecial
)

show_example(RegisteredSpecial(1., 2.))
structured=RegisteredSpecial(x=1.0, y=2.0)
  flat=[1.0, 2.0]
  tree=PyTreeDef(CustomNode(RegisteredSpecial[None], [*, *]))
  unflattened=RegisteredSpecial(x=1.0, y=2.0)

另外,您可以在类上定义适当的 tree_flattentree_unflatten 方法,并用 register_pytree_node_class() 进行装饰。

from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class RegisteredSpecial2(Special):
  def __repr__(self):
    return "RegisteredSpecial2(x={}, y={})".format(self.x, self.y)

  def tree_flatten(self):
    children = (self.x, self.y)
    aux_data = None
    return (children, aux_data)

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children)

show_example(RegisteredSpecial2(1., 2.))
structured=RegisteredSpecial2(x=1.0, y=2.0)
  flat=[1.0, 2.0]
  tree=PyTreeDef(CustomNode(RegisteredSpecial2[None], [*, *]))
  unflattened=RegisteredSpecial2(x=1.0, y=2.0)

在定义 unflattening 函数时,通常 children 应包含数据结构的所有动态元素(数组、动态标量和 Pytree),而 aux_data 应包含所有将并入 treedef 结构中的静态元素。JAX 有时需要比较 treedef 的相等性,或者计算其哈希值以用于 JIT 缓存,因此必须注意确保展平方案中指定的辅助数据支持有意义的哈希和相等性比较。

所有用于操作 Pytree 的函数都位于 jax.tree_util 中。

自定义 PyTree 和初始化#

用户定义的 PyTree 对象一个常见的陷阱是,JAX 变换有时会用意想不到的值初始化它们,从而可能导致初始化时进行的任何输入验证失败。例如

class MyTree:
  def __init__(self, a):
    self.a = jnp.asarray(a)

register_pytree_node(MyTree, lambda tree: ((tree.a,), None),
    lambda _, args: MyTree(*args))

tree = MyTree(jnp.arange(5.0))

jax.vmap(lambda x: x)(tree)      # Error because object() is passed to MyTree.
jax.jacobian(lambda x: x)(tree)  # Error because MyTree(...) is passed to MyTree

在第一种情况下,JAX 的内部机制使用 object() 值的数组来推断树的结构;在第二种情况下,将树映射到树的函数的雅可比(jacobian)被定义为树的树。

因此,自定义 PyTree 类的 __init____new__ 方法通常应避免执行任何数组转换或其他输入验证,否则需要预见并处理这些特殊情况。例如

class MyTree:
  def __init__(self, a):
    if not (type(a) is object or a is None or isinstance(a, MyTree)):
      a = jnp.asarray(a)
    self.a = a

另一种可能性是构建您的 tree_unflatten 函数,使其避免调用 __init__;例如

def tree_unflatten(aux_data, children):
  del aux_data  # unused in this class
  obj = object.__new__(MyTree)
  obj.a = a
  return obj

如果您选择这条路径,请确保在代码更新时,您的 tree_unflatten 函数与 __init__ 保持同步。