jax.tree_util.register_dataclass#
- jax.tree_util.register_dataclass(nodetype, data_fields=None, meta_fields=None, drop_fields=())[source]#
扩展了在 pytrees 中被视为内部节点的类型集合。
这与
register_pytree_with_keys_class
的不同之处在于,C++ 注册表使用优化的 C++ dataclass 内置函数而不是参数函数。有关注册 pytrees 的更多信息,请参阅 扩展 pytrees。
- 参数:
nodetype (Typ) – 要视为内部 pytree 节点的 Python 类型。这被假定为具有
dataclass
的语义:即,类属性代表对象状态的整体,并且可以作为关键字传递给类构造函数以创建对象的副本。所有定义的属性都应在meta_fields
或data_fields
中列出。meta_fields (Sequence[str] | None | None) – 元数据字段名称:当此 pytree 传递给
jax.jit()
时,这些属性将被视为 {term}`static`。meta_fields
仅当nodetype
是 dataclass 时才是可选的,在这种情况下,可以通过dataclasses.field()
将各个字段标记为静态(请参阅下面的示例)。元数据字段必须是静态的、可哈希的、不可变的对象,因为这些对象用于生成 JIT 缓存键。特别是,元数据字段不能包含jax.Array
或numpy.ndarray
对象。data_fields (Sequence[str] | None | None) – 数据字段名称:当此 pytree 传递给
jax.jit()
时,这些属性将被视为非静态。data_fields
仅当nodetype
是 dataclass 时才是可选的,在这种情况下,除非通过dataclasses.field()
标记,否则字段被假定为数据字段(请参阅下面的示例)。数据字段必须是 JAX 兼容的对象,例如数组(jax.Array
或numpy.ndarray
)、标量或叶子是数组或标量的 pytrees。请注意,None
是有效的数据字段,因为 JAX 将其识别为空 pytree。drop_fields (Sequence[str])
- 返回:
输入类
nodetype
在添加到 JAX 的 pytree 注册表后保持不变返回,以便register_dataclass()
可以用作装饰器。- 返回类型:
Typ
示例
在 JAX v0.4.35 或更早版本中,您必须指定
data_fields
和meta_fields
才能使用此装饰器>>> import jax >>> from dataclasses import dataclass >>> from functools import partial ... >>> @partial(jax.tree_util.register_dataclass, ... data_fields=['x', 'y'], ... meta_fields=['op']) ... @dataclass ... class MyStruct: ... x: jax.Array ... y: jax.Array ... op: str ... >>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add') >>> m MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
从 JAX v0.4.36 开始,对于
dataclass()
输入,data_fields
和meta_fields
参数是可选的,字段默认为data_fields
,除非使用 static 元数据在dataclasses.field()
中标记为静态。>>> import jax >>> from dataclasses import dataclass, field ... >>> @jax.tree_util.register_dataclass ... @dataclass ... class MyStruct: ... x: jax.Array # defaults to non-static data field ... y: jax.Array # defaults to non-static data field ... op: str = field(metadata=dict(static=True)) # marked as static meta field. ... >>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add') >>> m MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
一旦此类被注册,它就可以与
jax.tree
和jax.tree_util
中的函数一起使用>>> leaves, treedef = jax.tree.flatten(m) >>> leaves [Array([1., 1., 1.], dtype=float32), Array([0, 1, 2], dtype=int32)] >>> treedef PyTreeDef(CustomNode(MyStruct[('add',)], [*, *])) >>> jax.tree.unflatten(treedef, leaves) MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
特别是,此注册允许
m
无缝地通过包装在jax.jit()
和其他 JAX 转换中的代码传递,其中data_fields
被视为动态参数,而meta_fields
被视为静态参数>>> @jax.jit ... def compiled_func(m): ... if m.op == 'add': ... return m.x + m.y ... else: ... raise ValueError(f"{m.op=}") ... >>> compiled_func(m) Array([1., 2., 3.], dtype=float32)