jax.export.register_namedtuple_serialization#

jax.export.register_namedtuple_serialization(nodetype, *, serialized_name)[source]#

注册一个 namedtuple 以便进行序列化和反序列化。

JAX 对 collections.namedtuple 具有原生的 PyTree 支持,不需要调用 jax.tree_util.register_pytree_node()。但是,如果您想序列化具有 namedtuple 类型输入或输出的函数,则必须为该类型注册序列化。

参数:
  • nodetype (type[T]) – 我们希望序列化的 PyTree 节点类型。尝试为 nodetype 注册多个序列化是错误的。在反序列化时,此类型必须具有在序列化期间存在的相同键集。

  • serialized_name (str) – 一个将出现在序列化中并用于在反序列化期间查找注册的字符串。尝试为 serialized_name 注册多个序列化是错误的。

返回:

与作为 nodetype 传递的类型相同,以便此函数可以用作类装饰器。

返回类型:

type[T]