jax.experimental.custom_dce.custom_dce#
- class jax.experimental.custom_dce.custom_dce(fun, *, static_argnums=())[源代码]#
自定义 JAX 可转换函数的 DCE 行为。
JAX 使用死代码消除 (DCE) 从 JAX 程序中移除未使用的计算。当程序完全由已知的 JAX 操作指定时,这通常可以透明地工作,但像调用
pallas_call()
或ffi_call()
这样的不透明内核可能会引起问题。在 JAX 中,当使用
jax.jit()
暂存输出函数时,会执行 DCE,因此在 eager 模式下运行 JAX 时不会应用它。 同样,custom_dce
装饰器要求被装饰的函数和自定义 DCE 规则都与jit()
兼容。此装饰器允许用户通过定义自定义 DCE 规则来自定义函数的 DCE 行为。 对于
custom_dce
包装的函数f(*args)
,DCE 规则的签名是dce_rule(used_outs, *args)
,其中used_outs
是一个 Pytree,其结构与f
的输出相同,并且每个叶子都是一个bool
,指示应计算哪些输出。 剩余的参数*args
是f
的原始参数。 规则dce_rule
应返回一个 Pytree,其结构与f
的原始输出相同,但任何未使用的输出都可以替换为None
。例如
>>> @jax.experimental.custom_dce.custom_dce ... def f(x, y): ... return jnp.sin(x) * y, x * jnp.sin(y) ... >>> @f.def_dce ... def f_dce_rule(used_outs, x, y): ... return ( ... jnp.sin(x) * y if used_outs[0] else None, ... x * jnp.sin(y) if used_outs[1] else None, ... )
在此示例中,
used_outs
是一个包含两个bool
值的tuple
,指示需要哪些输出。 DCE 规则仅计算所需的输出,并将未使用的输出替换为None
。如果为
custom_dce
提供了static_argnums
参数,则在跟踪函数时,指示的参数将被视为静态参数,并且在调用 DCE 规则时,它们将被移动到前面。 例如,如果fun
接受 2 个参数fun(x, y)
,并且static_argnums
是(1,)
,则 DCE 规则将以dce_rule(y, used_outs, x)
的形式调用。方法
__init__
(fun, *[, static_argnums])def_dce
(dce_rule)为此函数定义自定义 DCE 规则。
属性
fun
static_argnums
dce_rule