jax.experimental.custom_dce.custom_dce#
- class jax.experimental.custom_dce.custom_dce(fun, *, static_argnums=())[源代码]#
自定义 JAX 可转换函数的 DCE 行为。
JAX 使用死代码消除 (DCE) 来移除 JAX 程序中未使用的计算。当程序完全由已知的 JAX 操作指定时,这通常是透明工作的,但像调用
pallas_call()或ffi_call()这样的不透明内核可能会导致问题。在 JAX 中,DCE 在使用
jax.jit()对函数进行阶段化时执行,因此在以 eager 模式运行 JAX 时不会应用它。同样,custom_dce装饰器要求被装饰的函数和自定义 DCE 规则都与jit()兼容。此装饰器允许用户通过定义自定义 DCE 规则来定制函数的 DCE 行为。对于一个被
custom_dce包裹的函数f(*args),DCE 规则的签名是dce_rule(used_outs, *args),其中used_outs是一个与f的输出具有相同结构的 Pytree,每个叶子是一个bool,指示应计算哪些输出。剩余的参数*args是f的原始参数。规则dce_rule应返回一个与f的原始输出具有相同结构的 Pytree,但任何未使用的输出都可以替换为None。例如
>>> @jax.experimental.custom_dce.custom_dce ... def f(x, y): ... return jnp.sin(x) * y, x * jnp.sin(y) ... >>> @f.def_dce ... def f_dce_rule(used_outs, x, y): ... return ( ... jnp.sin(x) * y if used_outs[0] else None, ... x * jnp.sin(y) if used_outs[1] else None, ... )
在此示例中,
used_outs是一个包含两个bool值的tuple,指示了需要哪些输出。DCE 规则只计算所需的输出,并将未使用的输出替换为None。如果为
custom_dce提供了static_argnums参数,则在追踪函数时,这些指定的参数将被视为静态参数,并且在调用 DCE 规则时会被移到前面。例如,如果fun接受两个参数fun(x, y),并且static_argnums是(1,),那么 DCE 规则将被调用为dce_rule(y, used_outs, x)。方法
__init__(fun, *[, static_argnums])def_dce(dce_rule)为此函数定义自定义 DCE 规则。
属性
funstatic_argnumsdce_rule