jax.tree_util.register_dataclass#

jax.tree_util.register_dataclass(nodetype, data_fields=None, meta_fields=None, drop_fields=())[source]#

扩展了在 pytrees 中被视为内部节点的类型集合。

这与 register_pytree_with_keys_class 的不同之处在于,C++ 注册表使用优化的 C++ dataclass 内置函数而不是参数函数。

有关注册 pytrees 的更多信息,请参阅 扩展 pytrees

参数:
  • nodetype (Typ) – 要视为内部 pytree 节点的 Python 类型。这被假定为具有 dataclass 的语义:即,类属性代表对象状态的整体,并且可以作为关键字传递给类构造函数以创建对象的副本。所有定义的属性都应在 meta_fieldsdata_fields 中列出。

  • meta_fields (Sequence[str] | None | None) – 元数据字段名称:当此 pytree 传递给 jax.jit() 时,这些属性将被视为 {term}`static`。meta_fields 仅当 nodetype 是 dataclass 时才是可选的,在这种情况下,可以通过 dataclasses.field() 将各个字段标记为静态(请参阅下面的示例)。元数据字段必须是静态的、可哈希的、不可变的对象,因为这些对象用于生成 JIT 缓存键。特别是,元数据字段不能包含 jax.Arraynumpy.ndarray 对象。

  • data_fields (Sequence[str] | None | None) – 数据字段名称:当此 pytree 传递给 jax.jit() 时,这些属性将被视为非静态。data_fields 仅当 nodetype 是 dataclass 时才是可选的,在这种情况下,除非通过 dataclasses.field() 标记,否则字段被假定为数据字段(请参阅下面的示例)。数据字段必须是 JAX 兼容的对象,例如数组(jax.Arraynumpy.ndarray)、标量或叶子是数组或标量的 pytrees。请注意,None 是有效的数据字段,因为 JAX 将其识别为空 pytree。

  • drop_fields (Sequence[str])

返回:

输入类 nodetype 在添加到 JAX 的 pytree 注册表后保持不变返回,以便 register_dataclass() 可以用作装饰器。

返回类型:

Typ

示例

在 JAX v0.4.35 或更早版本中,您必须指定 data_fieldsmeta_fields 才能使用此装饰器

>>> import jax
>>> from dataclasses import dataclass
>>> from functools import partial
...
>>> @partial(jax.tree_util.register_dataclass,
...          data_fields=['x', 'y'],
...          meta_fields=['op'])
... @dataclass
... class MyStruct:
...   x: jax.Array
...   y: jax.Array
...   op: str
...
>>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add')
>>> m
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')

从 JAX v0.4.36 开始,对于 dataclass() 输入,data_fieldsmeta_fields 参数是可选的,字段默认为 data_fields,除非使用 static 元数据在 dataclasses.field() 中标记为静态。

>>> import jax
>>> from dataclasses import dataclass, field
...
>>> @jax.tree_util.register_dataclass
... @dataclass
... class MyStruct:
...   x: jax.Array  # defaults to non-static data field
...   y: jax.Array  # defaults to non-static data field
...   op: str = field(metadata=dict(static=True))  # marked as static meta field.
...
>>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add')
>>> m
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')

一旦此类被注册,它就可以与 jax.treejax.tree_util 中的函数一起使用

>>> leaves, treedef = jax.tree.flatten(m)
>>> leaves
[Array([1., 1., 1.], dtype=float32), Array([0, 1, 2], dtype=int32)]
>>> treedef
PyTreeDef(CustomNode(MyStruct[('add',)], [*, *]))
>>> jax.tree.unflatten(treedef, leaves)
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')

特别是,此注册允许 m 无缝地通过包装在 jax.jit() 和其他 JAX 转换中的代码传递,其中 data_fields 被视为动态参数,而 meta_fields 被视为静态参数

>>> @jax.jit
... def compiled_func(m):
...   if m.op == 'add':
...     return m.x + m.y
...   else:
...     raise ValueError(f"{m.op=}")
...
>>> compiled_func(m)
Array([1., 2., 3.], dtype=float32)