jax.export.register_pytree_node_serialization#
- jax.export.register_pytree_node_serialization(nodetype, *, serialized_name, serialize_auxdata, deserialize_auxdata, from_children=None)[source]#
注册一个自定义的 PyTree 节点,用于序列化和反序列化。
在使用此函数之前,您必须先调用它,然后才能序列化和反序列化 PyTree 节点,这些节点是原生不支持的类型。我们为
Exported的in_tree和out_tree字段序列化 PyTree 节点,它们是已导出函数调用约定的组成部分。此函数必须在调用
jax.tree_util.register_pytree_node()之后调用(除了collections.namedtuple,它不需要调用register_pytree_node)。- 参数:
nodetype (type[T]) – 我们想要序列化的 PyTree 节点的类型。为
nodetype注册多个序列化是错误的。serialized_name (str) – 将出现在序列化中并用于在反序列化期间查找注册的字符串。为
serialized_name注册多个序列化是错误的。serialize_auxdata (_SerializeAuxData) – 序列化 PyTree 的 auxdata(由
jax.tree_util.register_pytree_node()的flatten_func参数返回)。deserialize_auxdata (_DeserializeAuxData) – 反序列化由
serialize_auxdata序列化的 auxdata。from_children (_BuildFromChildren | None) – 如果存在,这是一个函数,它接收
deserialize_auxdata的结果以及一些子节点,并创建一个nodetype的实例。这类似于传递给jax.tree_util.register_pytree_node()的unflatten_func。如果不存在,我们将查找并使用unflatten_func。这对于collections.namedtuple是必需的,因为它没有register_pytree_node,但重写该函数可能很有用。请注意,from_children的结果仅与jax.tree_util.tree_structure()一起用于构建一个正确的 PyTree 节点,它不用于构建已序列化函数的输出。
- 返回:
与作为
nodetype传递的类型相同,因此此函数可以用作类装饰器。- 返回类型:
type[T]