jax.export.symbolic_args_specs#

jax.export.symbolic_args_specs(args, shapes_specs, constraints=(), scope=None)[源代码]#

export 构建 jax.ShapeDtypeSpec 参数规范的 pytree。

有关详细信息,请参阅 jax.export.symbolic_shape() 和 [形状多态性文档](https://jax.net.cn/en/latest/export/shape_poly.html)的文档。

参数:
返回值:与 args 匹配的 jax.ShapeDTypeStruct 的 pytree,其中形状

根据 shapes_specs 的指定,替换为符号维度。