jax.export.symbolic_args_specs#
- jax.export.symbolic_args_specs(args, shapes_specs, constraints=(), scope=None)[源]#
为 export 构造一个 jax.ShapeDtypeSpec 参数规范的 Pytree。
请参阅
jax.export.symbolic_shape()
的文档以及[形状多态性文档](https://jax.net.cn/en/latest/export/shape_poly.html)以获取详细信息。- 参数:
args – 参数的 Pytree。这些可以是 jax.Array 或 jax.ShapeDTypeSpec。它们用于学习参数的 Pytree 结构、其数据类型,并在 shapes_specs 包含占位符时填入实际形状。请注意,只有 shapes_specs 作为占位符的形状维度才从 args 中使用。
shapes_specs – 应为 None(所有参数都具有静态形状)、单个字符串(请参阅
jax.export.symbolic_shape()
的 shape_spec;适用于所有参数),或与 args 前缀匹配的 Pytree。请参阅[可选参数如何与参数匹配](https://jax.net.cn/en/latest/pytrees.html#applying-optional-parameters-to-pytrees)。constraints (序列[str]) – 同
jax.export.symbolic_shape()
。scope (SymbolicScope | None) – 同
jax.export.symbolic_shape()
。
- 返回:一个与 args 匹配的 jax.ShapeDTypeStruct Pytree,其中形状已
根据 shapes_specs 指定的符号维度进行替换。