jax.export.symbolic_args_specs#
- jax.export.symbolic_args_specs(args, shapes_specs, constraints=(), scope=None)[source]#
为 export 构建 jax.ShapeDtypeSpec 参数规范的 pytree。
有关详细信息,请参阅
jax.export.symbolic_shape()
和 [形状多态性文档](https://jax.net.cn/en/latest/export/shape_poly.html)。- 参数:
args – 参数的 pytree。这些可以是 jax.Array 或 jax.ShapeDTypeSpec。它们用于学习参数的 pytree 结构、它们的 dtypes,并在 shapes_specs 包含占位符的地方填充实际形状。请注意,只有 shapes_specs 是占位符的形状维度才从 args 中使用。
shapes_specs – 应该是 None(所有参数都具有静态形状)、单个字符串(有关 shape_spec,请参阅
jax.export.symbolic_shape()
;应用于所有参数),或与 args 的前缀匹配的 pytree。请参阅[可选参数如何匹配到参数](https://jax.net.cn/en/latest/pytrees.html#applying-optional-parameters-to-pytrees)。constraints (Sequence[str]) – 与
jax.export.symbolic_shape()
相同。scope (SymbolicScope | None | None) – 与
jax.export.symbolic_shape()
相同。
- 返回:与 args 匹配的 jax.ShapeDTypeStruct 的 pytree,其形状
替换为由 shapes_specs 指定的符号维度。