jax.export.symbolic_args_specs#

jax.export.symbolic_args_specs(args, shapes_specs, constraints=(), scope=None)[源]#

export 构造一个 jax.ShapeDtypeSpec 参数规范的 Pytree。

请参阅 jax.export.symbolic_shape() 的文档以及[形状多态性文档](https://jax.net.cn/en/latest/export/shape_poly.html)以获取详细信息。

参数:
返回:一个与 args 匹配的 jax.ShapeDTypeStruct Pytree,其中形状已

根据 shapes_specs 指定的符号维度进行替换。