jax.tree_util.register_dataclass#
- jax.tree_util.register_dataclass(nodetype, data_fields=None, meta_fields=None, drop_fields=())[source]#
扩展被视为 pytrees 内部节点的类型集。
这与
register_pytree_with_keys_class
不同,因为 C++ 注册表使用优化的 C++ dataclass 内置函数而非参数函数。有关注册 pytree 的更多信息,请参阅 扩展 pytree。
- 参数:
nodetype (Typ) – 一个 Python 类型,用作内部 pytree 节点。这被假定具有
dataclass
的语义:即,类属性代表对象的整个状态,并且可以作为关键字传递给类构造函数以创建对象的副本。所有定义的属性都应列在meta_fields
或data_fields
中。meta_fields (Sequence[str] | None) – 元数据字段名称:当此 pytree 传递给
jax.jit()
时,这些属性将被视为 {term}`静态`。只有当nodetype
是一个 dataclass 时,meta_fields
才是可选的,在这种情况下,可以通过dataclasses.field()
将单个字段标记为静态(参见下面的示例)。元数据字段必须是静态的、可哈希的、不可变的对象,因为这些对象用于生成 JIT 缓存键。特别是,元数据字段不能包含jax.Array
或numpy.ndarray
对象。data_fields (Sequence[str] | None) – 数据字段名称:当此 pytree 传递给
jax.jit()
时,这些属性将被视为非静态。只有当nodetype
是一个 dataclass 时,data_fields
才是可选的,在这种情况下,除非通过dataclasses.field()
标记(参见下面的示例),否则字段被假定为数据字段。数据字段必须是 JAX 兼容的对象,例如数组(jax.Array
或numpy.ndarray
)、标量或其叶子是数组或标量的 pytree。请注意,None
是一个有效的数据字段,因为 JAX 将其识别为空 pytree。drop_fields (Sequence[str])
- 返回:
输入类
nodetype
在添加到 JAX 的 pytree 注册表后保持不变,因此register_dataclass()
可以用作装饰器。- 返回类型:
类型
示例
在 JAX v0.4.35 或更早版本中,您必须指定
data_fields
和meta_fields
才能使用此装饰器。>>> import jax >>> from dataclasses import dataclass >>> from functools import partial ... >>> @partial(jax.tree_util.register_dataclass, ... data_fields=['x', 'y'], ... meta_fields=['op']) ... @dataclass ... class MyStruct: ... x: jax.Array ... y: jax.Array ... op: str ... >>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add') >>> m MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
从 JAX v0.4.36 开始,对于
dataclass()
输入,data_fields
和meta_fields
参数是可选的,字段默认是data_fields
,除非在dataclasses.field()
中使用 static 元数据标记为静态。>>> import jax >>> from dataclasses import dataclass, field ... >>> @jax.tree_util.register_dataclass ... @dataclass ... class MyStruct: ... x: jax.Array # defaults to non-static data field ... y: jax.Array # defaults to non-static data field ... op: str = field(metadata=dict(static=True)) # marked as static meta field. ... >>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add') >>> m MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
一旦此类别被注册,它就可以与
jax.tree
和jax.tree_util
中的函数一起使用。>>> leaves, treedef = jax.tree.flatten(m) >>> leaves [Array([1., 1., 1.], dtype=float32), Array([0, 1, 2], dtype=int32)] >>> treedef PyTreeDef(CustomNode(MyStruct[('add',)], [*, *])) >>> jax.tree.unflatten(treedef, leaves) MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
特别是,此注册允许
m
无缝地通过jax.jit()
和其他 JAX 转换包裹的代码,其中data_fields
被视为动态参数,而meta_fields
被视为静态参数。>>> @jax.jit ... def compiled_func(m): ... if m.op == 'add': ... return m.x + m.y ... else: ... raise ValueError(f"{m.op=}") ... >>> compiled_func(m) Array([1., 2., 3.], dtype=float32)