jax.export.register_pytree_node_serialization#

jax.export.register_pytree_node_serialization(nodetype, *, serialized_name, serialize_auxdata, deserialize_auxdata, from_children=None)[source]#

注册一个自定义的 PyTree 节点,用于序列化和反序列化。

在使用此函数之前,您必须先调用它,然后才能序列化和反序列化 PyTree 节点,这些节点是原生不支持的类型。我们为 Exportedin_treeout_tree 字段序列化 PyTree 节点,它们是已导出函数调用约定的组成部分。

此函数必须在调用 jax.tree_util.register_pytree_node() 之后调用(除了 collections.namedtuple,它不需要调用 register_pytree_node)。

参数:
  • nodetype (type[T]) – 我们想要序列化的 PyTree 节点的类型。为 nodetype 注册多个序列化是错误的。

  • serialized_name (str) – 将出现在序列化中并用于在反序列化期间查找注册的字符串。为 serialized_name 注册多个序列化是错误的。

  • serialize_auxdata (_SerializeAuxData) – 序列化 PyTree 的 auxdata(由 jax.tree_util.register_pytree_node()flatten_func 参数返回)。

  • deserialize_auxdata (_DeserializeAuxData) – 反序列化由 serialize_auxdata 序列化的 auxdata。

  • from_children (_BuildFromChildren | None) – 如果存在,这是一个函数,它接收 deserialize_auxdata 的结果以及一些子节点,并创建一个 nodetype 的实例。这类似于传递给 jax.tree_util.register_pytree_node()unflatten_func。如果不存在,我们将查找并使用 unflatten_func。这对于 collections.namedtuple 是必需的,因为它没有 register_pytree_node,但重写该函数可能很有用。请注意,from_children 的结果仅与 jax.tree_util.tree_structure() 一起用于构建一个正确的 PyTree 节点,它不用于构建已序列化函数的输出。

返回:

与作为 nodetype 传递的类型相同,因此此函数可以用作类装饰器。

返回类型:

type[T]