jax.export.register_namedtuple_serialization#

jax.export.register_namedtuple_serialization(nodetype, *, serialized_name)[源代码]#

注册用于序列化和反序列化的 namedtuple。

JAX 对 collections.namedtuple 具有原生 PyTree 支持,不需要调用 jax.tree_util.register_pytree_node。但是,如果要序列化具有 namedtuple 类型输入或输出的函数,则必须注册该类型以进行序列化。

参数:
  • nodetype (type[T]) – 我们要序列化其 PyTree 节点的类型。尝试为 nodetype 注册多个序列化是错误的。在反序列化时,此类型必须具有与序列化期间存在的相同键集。

  • serialized_name (str) – 一个字符串,它将出现在序列化中,并将用于在反序列化期间查找注册。尝试为 serialized_name 注册多个序列化是错误的。

返回:

与作为 nodetype 传递的类型相同,因此此函数可以用作类装饰器。

返回类型:

type[T]