jax.tree_util.register_dataclass#
- jax.tree_util.register_dataclass(nodetype, data_fields=None, meta_fields=None, drop_fields=())[源代码]#
扩展了被视为 pytree 内部节点的类型集合。
这与
register_pytree_with_keys_class
的不同之处在于,C++ 注册表使用优化的 C++ 数据类内置函数,而不是参数函数。有关注册 pytree 的更多信息,请参阅 扩展 pytree。
- 参数:
nodetype (Typ) – 一个 Python 类型,将被视为内部 pytree 节点。假设它具有
dataclass
的语义:即,类属性表示对象的整个状态,并且可以作为关键字传递给类构造函数以创建对象的副本。所有定义的属性都应列在meta_fields
或data_fields
中。meta_fields (Sequence[str] | None | None) – 元数据字段名称:当此 pytree 传递给
jax.jit()
时,这些属性将被视为 {term}`静态` 。仅当nodetype
是数据类时,meta_fields
才是可选的,在这种情况下,可以使用dataclasses.field()
将各个字段标记为静态(请参见以下示例)。元数据字段必须是静态的、可哈希的、不可变的对象,因为这些对象用于生成 JIT 缓存键。特别是,元数据字段不能包含jax.Array
或numpy.ndarray
对象。data_fields (Sequence[str] | None | None) – 数据字段名称:当此 pytree 传递给
jax.jit()
时,这些属性将被视为非静态的。仅当nodetype
是数据类时,data_fields
才是可选的,在这种情况下,除非使用dataclasses.field()
标记为静态,否则字段被假定为数据字段(请参见以下示例)。数据字段必须是与 JAX 兼容的对象,例如数组(jax.Array
或numpy.ndarray
)、标量或叶子为数组或标量的 pytree。请注意,None
是有效的数据字段,因为 JAX 将其识别为空 pytree。drop_fields (Sequence[str])
- 返回:
输入类
nodetype
在添加到 JAX 的 pytree 注册表后保持不变,以便register_dataclass()
可以用作装饰器。- 返回类型:
Typ
示例
在 JAX v0.4.35 或更早版本中,您必须指定
data_fields
和meta_fields
才能使用此装饰器>>> import jax >>> from dataclasses import dataclass >>> from functools import partial ... >>> @partial(jax.tree_util.register_dataclass, ... data_fields=['x', 'y'], ... meta_fields=['op']) ... @dataclass ... class MyStruct: ... x: jax.Array ... y: jax.Array ... op: str ... >>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add') >>> m MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
从 JAX v0.4.36 开始,
data_fields
和meta_fields
参数对于dataclass()
输入是可选的,字段默认为data_fields
,除非使用dataclasses.field()
中的static 元数据标记为静态。>>> import jax >>> from dataclasses import dataclass, field ... >>> @jax.tree_util.register_dataclass ... @dataclass ... class MyStruct: ... x: jax.Array # defaults to non-static data field ... y: jax.Array # defaults to non-static data field ... op: str = field(metadata=dict(static=True)) # marked as static meta field. ... >>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add') >>> m MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
一旦注册了此类,就可以将其与
jax.tree
和jax.tree_util
中的函数一起使用。>>> leaves, treedef = jax.tree.flatten(m) >>> leaves [Array([1., 1., 1.], dtype=float32), Array([0, 1, 2], dtype=int32)] >>> treedef PyTreeDef(CustomNode(MyStruct[('add',)], [*, *])) >>> jax.tree.unflatten(treedef, leaves) MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
特别是,此注册允许
m
无缝地通过包装在jax.jit()
和其他 JAX 转换中的代码传递,其中data_fields
被视为动态参数,而meta_fields
被视为静态参数>>> @jax.jit ... def compiled_func(m): ... if m.op == 'add': ... return m.x + m.y ... else: ... raise ValueError(f"{m.op=}") ... >>> compiled_func(m) Array([1., 2., 3.], dtype=float32)