jax.export.register_pytree_node_serialization#

jax.export.register_pytree_node_serialization(nodetype, *, serialized_name, serialize_auxdata, deserialize_auxdata, from_children=None)[源代码]#

注册用于序列化和反序列化的自定义 PyTree 节点。

在序列化和反序列化本机不支持的类型的 PyTree 节点之前,您必须使用此函数。 我们序列化 Exportedin_treeout_tree 字段的 PyTree 节点,它们是导出函数的调用约定的一部分。

必须在调用 jax.tree_util.register_pytree_node 之后调用此函数(collections.namedtuple 除外,它不需要调用 register_pytree_node)。

参数:
  • nodetype (type[T]) – 我们想要序列化的 PyTree 节点的类型。 尝试为 nodetype 注册多个序列化是一个错误。

  • serialized_name (str) – 将出现在序列化中的字符串,并将用于在反序列化期间查找注册。 尝试为 serialized_name 注册多个序列化是一个错误。

  • serialize_auxdata (_SerializeAuxData) – 序列化 PyTree auxdata(由 jax.tree_util.register_pytree_nodeflatten_func 参数返回)。

  • deserialize_auxdata (_DeserializeAuxData) – 反序列化由 serialize_auxdata 序列化的 auxdata。

  • from_children (_BuildFromChildren | None) – 如果存在,这是一个函数,它接受 deserialize_auxdata 的结果以及一些子项,并创建 nodetype 的实例。 这类似于传递给 jax.tree_util.register_pytree_nodeunflatten_func。 如果不存在,我们将查找并使用 unflatten_func。 这对于 collections.namedtuple 是必需的,它没有 register_pytree_node,但它可以用于覆盖该函数。 请注意,from_children 的结果仅与 jax.tree_util.tree_structure 一起使用以构造正确的 PyTree 节点,它不用于构造序列化函数的输出。

返回:

作为 nodetype 传递的相同类型,以便此函数可以用作类装饰器。

返回类型:

type[T]