jax.extend.core.Jaxpr#

class jax.extend.core.Jaxpr(constvars, invars, outvars, eqns, effects=frozenset({}), debug_info=None, is_high=False)[源代码]#
参数:
  • constvars (Sequence[Var])

  • invars (Sequence[Var])

  • outvars (Sequence[Atom])

  • eqns (Sequence[JaxprEqn])

  • effects (Effects)

  • debug_info (DebugInfo)

  • is_high (bool)

__init__(constvars, invars, outvars, eqns, effects=frozenset({}), debug_info=None, is_high=False)[源代码]#
参数:
  • constvars (Sequence[Var]) – 用于常量的变量列表。数组常量被替换为此类变量,而标量常量则保持内联。

  • invars (Sequence[Var]) – 输入变量列表。`constvars` 和 `invars` 一起构成了 Jaxpr 的输入。

  • outvars (Sequence[Atom]) – 输出原子的列表。

  • eqns (Sequence[JaxprEqn]) – 方程的列表。

  • effects (Effects) – 效应集。Jaxpr 上的效应是每个方程效应并集的超集。

  • debug_info (DebugInfo) – 调试信息。

  • is_high (bool)

方法

__init__(constvars, invars, outvars, eqns[, ...])

pretty_print(*[, source_info, print_shapes, ...])

replace(**kwargs)

属性

constvars

debug_info

effects

eqns

final_aval_qdds

in_aval_qdds

in_avals

invars

is_high

out_avals

outvars