jax.tree_util.register_dataclass#

jax.tree_util.register_dataclass(nodetype, data_fields=None, meta_fields=None, drop_fields=())[源代码]#

扩展了被视为 pytree 内部节点的类型集合。

这与 register_pytree_with_keys_class 的不同之处在于,C++ 注册表使用优化的 C++ 数据类内置函数,而不是参数函数。

有关注册 pytree 的更多信息,请参阅 扩展 pytree

参数:
  • nodetype (Typ) – 一个 Python 类型,将被视为内部 pytree 节点。假设它具有 dataclass 的语义:即,类属性表示对象的整个状态,并且可以作为关键字传递给类构造函数以创建对象的副本。所有定义的属性都应列在 meta_fieldsdata_fields 中。

  • meta_fields (Sequence[str] | None | None) – 元数据字段名称:当此 pytree 传递给 jax.jit() 时,这些属性将被视为 {term}`静态` 。仅当 nodetype 是数据类时,meta_fields 才是可选的,在这种情况下,可以使用 dataclasses.field() 将各个字段标记为静态(请参见以下示例)。元数据字段必须是静态的、可哈希的、不可变的对象,因为这些对象用于生成 JIT 缓存键。特别是,元数据字段不能包含 jax.Arraynumpy.ndarray 对象。

  • data_fields (Sequence[str] | None | None) – 数据字段名称:当此 pytree 传递给 jax.jit() 时,这些属性将被视为非静态的。仅当 nodetype 是数据类时,data_fields 才是可选的,在这种情况下,除非使用 dataclasses.field() 标记为静态,否则字段被假定为数据字段(请参见以下示例)。数据字段必须是与 JAX 兼容的对象,例如数组(jax.Arraynumpy.ndarray)、标量或叶子为数组或标量的 pytree。请注意,None 是有效的数据字段,因为 JAX 将其识别为空 pytree。

  • drop_fields (Sequence[str])

返回:

输入类 nodetype 在添加到 JAX 的 pytree 注册表后保持不变,以便 register_dataclass() 可以用作装饰器。

返回类型:

Typ

示例

在 JAX v0.4.35 或更早版本中,您必须指定 data_fieldsmeta_fields 才能使用此装饰器

>>> import jax
>>> from dataclasses import dataclass
>>> from functools import partial
...
>>> @partial(jax.tree_util.register_dataclass,
...          data_fields=['x', 'y'],
...          meta_fields=['op'])
... @dataclass
... class MyStruct:
...   x: jax.Array
...   y: jax.Array
...   op: str
...
>>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add')
>>> m
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')

从 JAX v0.4.36 开始,data_fieldsmeta_fields 参数对于 dataclass() 输入是可选的,字段默认为 data_fields,除非使用 dataclasses.field() 中的static 元数据标记为静态。

>>> import jax
>>> from dataclasses import dataclass, field
...
>>> @jax.tree_util.register_dataclass
... @dataclass
... class MyStruct:
...   x: jax.Array  # defaults to non-static data field
...   y: jax.Array  # defaults to non-static data field
...   op: str = field(metadata=dict(static=True))  # marked as static meta field.
...
>>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add')
>>> m
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')

一旦注册了此类,就可以将其与 jax.treejax.tree_util 中的函数一起使用。

>>> leaves, treedef = jax.tree.flatten(m)
>>> leaves
[Array([1., 1., 1.], dtype=float32), Array([0, 1, 2], dtype=int32)]
>>> treedef
PyTreeDef(CustomNode(MyStruct[('add',)], [*, *]))
>>> jax.tree.unflatten(treedef, leaves)
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')

特别是,此注册允许 m 无缝地通过包装在 jax.jit() 和其他 JAX 转换中的代码传递,其中 data_fields 被视为动态参数,而 meta_fields 被视为静态参数

>>> @jax.jit
... def compiled_func(m):
...   if m.op == 'add':
...     return m.x + m.y
...   else:
...     raise ValueError(f"{m.op=}")
...
>>> compiled_func(m)
Array([1., 2., 3.], dtype=float32)