jax.extend.core module

ClosedJaxpr(jaxpr, consts)
Jaxpr(constvars, invars, outvars, eqns[, ...])
JaxprEqn(invars, outvars, primitive, params, ...)
Literal(val, aval)
Primitive(name)
Token(buf)
Var(aval[, initial_qdd, final_qdd])
array_types: a set (its only description is Python's built-in set docstring: "set() -> new empty set object; set(iterable) -> new set object")
jaxpr_as_fun
primitives
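The sketch below shows how several of these names fit together. It assumes a recent JAX release in which `jax.extend.core` exports `ClosedJaxpr`, `JaxprEqn`, `Literal`, `Var`, `Primitive` and `jaxpr_as_fun`. It traces a function with `jax.make_jaxpr` (which returns a `ClosedJaxpr`), walks its equations, and turns the jaxpr back into a callable with `jaxpr_as_fun`. It also defines a custom `Primitive`. The function `f` and the primitive name `"double"` are hypothetical. The primitive gets only an eager implementation and an abstract-evaluation rule, so it should work eagerly and under `make_jaxpr`, but not under `jax.jit`, which would also need a lowering rule.

```python
import jax
import jax.numpy as jnp
from jax.extend.core import (ClosedJaxpr, JaxprEqn, Literal, Primitive, Var,
                             jaxpr_as_fun)


def f(x):
    # Hypothetical example function: sin followed by scaling by a Python scalar
    return jax.lax.sin(x) * 2.0


# make_jaxpr traces f and returns a ClosedJaxpr (a Jaxpr plus its constants)
closed = jax.make_jaxpr(f)(jnp.ones(3))
assert isinstance(closed, ClosedJaxpr)

# Each equation is a JaxprEqn; its inputs are either Vars or Literals
for eqn in closed.jaxpr.eqns:
    assert isinstance(eqn, JaxprEqn)
    kinds = ["Literal" if isinstance(v, Literal)
             else "Var" if isinstance(v, Var)
             else type(v).__name__
             for v in eqn.invars]
    print(eqn.primitive.name, kinds)

# jaxpr_as_fun turns the ClosedJaxpr back into a callable that returns a list
out = jaxpr_as_fun(closed)(jnp.arange(3.0))

# A custom primitive with the hypothetical name "double"
double_p = Primitive("double")
# Eager implementation: runs when bind() is called on concrete arrays
double_p.def_impl(lambda x: 2 * x)
# Abstract evaluation: the output has the same shape and dtype as the input
double_p.def_abstract_eval(lambda aval: aval)

eager = double_p.bind(jnp.arange(3.0))
traced = jax.make_jaxpr(lambda x: double_p.bind(x))(jnp.ones(3))
print(traced)
```

The identity abstract-eval rule (`lambda aval: aval`) is a shortcut that fits only because "double" preserves shape and dtype. A primitive that changes either would have to construct and return a new abstract value.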