jax.export.register_pytree_node_serialization#

jax.export.register_pytree_node_serialization(nodetype, *, serialized_name, serialize_auxdata, deserialize_auxdata, from_children=None)[源代码]#

注册自定义 PyTree 节点以进行序列化和反序列化。

您必须先使用此函数,然后才能序列化和反序列化本机不支持的类型的 PyTree 节点。我们序列化 Exportedin_treeout_tree 字段的 PyTree 节点,这些字段是导出函数的调用约定的一部分。

此函数必须在调用 jax.tree_util.register_pytree_node 后调用(collections.namedtuple 除外,它不需要调用 register_pytree_node)。

参数:
  • nodetype (type[T]) – 我们要序列化的 PyTree 节点的类型。尝试为 nodetype 注册多个序列化是错误的。

  • serialized_name (str) – 一个字符串,它将出现在序列化中,并在反序列化期间用于查找注册信息。尝试为 serialized_name 注册多个序列化是错误的。

  • serialize_auxdata (_SerializeAuxData) – 序列化 PyTree 的辅助数据 (由 jax.tree_util.register_pytree_nodeflatten_func 参数返回)。

  • deserialize_auxdata (_DeserializeAuxData) – 反序列化由 serialize_auxdata 序列化的辅助数据。

  • from_children (_BuildFromChildren | None | None) – 如果存在,这是一个函数,它接受 deserialize_auxdata 的结果以及一些子节点,并创建一个 nodetype 的实例。这类似于传递给 jax.tree_util.register_pytree_nodeunflatten_func。如果不存在,我们将查找并使用 unflatten_func。这对于 collections.namedtuple 是必需的,因为它没有 register_pytree_node,但覆盖该函数可能很有用。请注意,from_children 的结果仅与 jax.tree_util.tree_structure 一起使用,以构建正确的 PyTree 节点,它不用于构建序列化函数的输出。

返回:

与作为 nodetype 传递的类型相同的类型,以便此函数可以用作类装饰器。

返回类型:

type[T]