jax.tree_util.register_dataclass#
- jax.tree_util.register_dataclass(nodetype, data_fields=None, meta_fields=None, drop_fields=())[source]#
扩展被视为 pytrees 内部节点的类型集。
这与
register_pytree_with_keys_class不同之处在于,C++ 注册表使用优化的 C++ dataclass 内置类型,而不是参数函数。有关注册 pytree 的更多信息,请参阅 扩展 pytrees。
- 参数:
nodetype (Typ) – 要视为内部 pytree 节点的 Python 类型。假定其具有
dataclass的语义:即,类属性代表对象的全部状态,并且可以作为关键字传递给类构造函数来创建对象的副本。所有定义的属性都应列在meta_fields或data_fields中。meta_fields (Sequence[str] | None) – 元数据字段名:这些是当此 pytree 传递给
jax.jit()时将被视为 静态 的属性。仅当nodetype是 dataclass 时,meta_fields才是可选的,在这种情况下,单个字段可以通过dataclasses.field()标记为静态(如下面的示例所示)。元数据字段必须是静态的、可哈希的、不可变的对象,因为这些对象用于生成 JIT 缓存键。特别是,元数据字段不能包含jax.Array或numpy.ndarray对象。data_fields (Sequence[str] | None) – 数据字段名:这些是当此 pytree 传递给
jax.jit()时将被视为非静态的属性。仅当nodetype是 dataclass 时,data_fields才是可选的,在这种情况下,除非通过dataclasses.field()标记为静态,否则假定字段为数据字段(如下面的示例所示)。数据字段必须是 JAX 兼容的对象,例如数组(jax.Array或numpy.ndarray)、标量或叶子节点为数组或标量的 pytree。请注意,None是有效的数据字段,因为 JAX 将其识别为空 pytree。drop_fields (Sequence[str])
- 返回:
在将输入类
nodetype添加到 JAX 的 pytree 注册表后,它将保持不变返回,因此register_dataclass()可以用作装饰器。- 返回类型:
类型
示例
在 JAX v0.4.35 或更早版本中,您必须指定
data_fields和meta_fields才能使用此装饰器。>>> import jax >>> from dataclasses import dataclass >>> from functools import partial ... >>> @partial(jax.tree_util.register_dataclass, ... data_fields=['x', 'y'], ... meta_fields=['op']) ... @dataclass ... class MyStruct: ... x: jax.Array ... y: jax.Array ... op: str ... >>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add') >>> m MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
从 JAX v0.4.36 开始,对于
dataclass()输入,data_fields和meta_fields参数是可选的,字段默认被视为data_fields,除非使用dataclasses.field()中的 static 元数据进行标记。>>> import jax >>> from dataclasses import dataclass, field ... >>> @jax.tree_util.register_dataclass ... @dataclass ... class MyStruct: ... x: jax.Array # defaults to non-static data field ... y: jax.Array # defaults to non-static data field ... op: str = field(metadata=dict(static=True)) # marked as static meta field. ... >>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add') >>> m MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
一旦注册了此类,它就可以与
jax.tree和jax.tree_util中的函数一起使用。>>> leaves, treedef = jax.tree.flatten(m) >>> leaves [Array([1., 1., 1.], dtype=float32), Array([0, 1, 2], dtype=int32)] >>> treedef PyTreeDef(CustomNode(MyStruct[('add',)], [*, *])) >>> jax.tree.unflatten(treedef, leaves) MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
特别是,此注册允许
m无缝地通过用jax.jit()和其他 JAX 转换包装的代码,其中data_fields被视为动态参数,而meta_fields被视为静态参数。>>> @jax.jit ... def compiled_func(m): ... if m.op == 'add': ... return m.x + m.y ... else: ... raise ValueError(f"{m.op=}") ... >>> compiled_func(m) Array([1., 2., 3.], dtype=float32)