jax.experimental.custom_dce.custom_dce#

class jax.experimental.custom_dce.custom_dce(fun, *, static_argnums=())[源代码]#

自定义 JAX 可转换函数的 DCE 行为。

JAX 使用死代码消除 (DCE) 从 JAX 程序中移除未使用的计算。当程序完全由已知的 JAX 操作指定时,这通常可以透明地工作,但像调用 pallas_call()ffi_call() 这样的不透明内核可能会引起问题。

在 JAX 中,当使用 jax.jit() 暂存输出函数时,会执行 DCE,因此在 eager 模式下运行 JAX 时不会应用它。 同样,custom_dce 装饰器要求被装饰的函数和自定义 DCE 规则都与 jit() 兼容。

此装饰器允许用户通过定义自定义 DCE 规则来自定义函数的 DCE 行为。 对于 custom_dce 包装的函数 f(*args),DCE 规则的签名是 dce_rule(used_outs, *args),其中 used_outs 是一个 Pytree,其结构与 f 的输出相同,并且每个叶子都是一个 bool,指示应计算哪些输出。 剩余的参数 *argsf 的原始参数。 规则 dce_rule 应返回一个 Pytree,其结构与 f 的原始输出相同,但任何未使用的输出都可以替换为 None

例如

>>> @jax.experimental.custom_dce.custom_dce
... def f(x, y):
...   return jnp.sin(x) * y, x * jnp.sin(y)
...
>>> @f.def_dce
... def f_dce_rule(used_outs, x, y):
...   return (
...       jnp.sin(x) * y if used_outs[0] else None,
...       x * jnp.sin(y) if used_outs[1] else None,
...   )

在此示例中,used_outs 是一个包含两个 bool 值的 tuple,指示需要哪些输出。 DCE 规则仅计算所需的输出,并将未使用的输出替换为 None

如果为 custom_dce 提供了 static_argnums 参数,则在跟踪函数时,指示的参数将被视为静态参数,并且在调用 DCE 规则时,它们将被移动到前面。 例如,如果 fun 接受 2 个参数 fun(x, y),并且 static_argnums(1,),则 DCE 规则将以 dce_rule(y, used_outs, x) 的形式调用。

参数:
__init__(fun, *, static_argnums=())[源代码]#
参数:

方法

__init__(fun, *[, static_argnums])

def_dce(dce_rule)

为此函数定义自定义 DCE 规则。

属性

fun

static_argnums

dce_rule