jax.experimental.custom_dce.custom_dce#

class jax.experimental.custom_dce.custom_dce(fun, *, static_argnums=())[源代码]#

自定义 JAX 可转换函数的 DCE 行为。

JAX 使用死代码消除 (DCE) 来移除 JAX 程序中未使用的计算。当程序完全由已知的 JAX 操作指定时,这通常是透明工作的,但像调用 pallas_call()ffi_call() 这样的不透明内核可能会导致问题。

在 JAX 中,DCE 在使用 jax.jit() 对函数进行阶段化时执行,因此在以 eager 模式运行 JAX 时不会应用它。同样,custom_dce 装饰器要求被装饰的函数和自定义 DCE 规则都与 jit() 兼容。

此装饰器允许用户通过定义自定义 DCE 规则来定制函数的 DCE 行为。对于一个被 custom_dce 包裹的函数 f(*args),DCE 规则的签名是 dce_rule(used_outs, *args),其中 used_outs 是一个与 f 的输出具有相同结构的 Pytree,每个叶子是一个 bool,指示应计算哪些输出。剩余的参数 *argsf 的原始参数。规则 dce_rule 应返回一个与 f 的原始输出具有相同结构的 Pytree,但任何未使用的输出都可以替换为 None

例如

>>> @jax.experimental.custom_dce.custom_dce
... def f(x, y):
...   return jnp.sin(x) * y, x * jnp.sin(y)
...
>>> @f.def_dce
... def f_dce_rule(used_outs, x, y):
...   return (
...       jnp.sin(x) * y if used_outs[0] else None,
...       x * jnp.sin(y) if used_outs[1] else None,
...   )

在此示例中,used_outs 是一个包含两个 bool 值的 tuple,指示了需要哪些输出。DCE 规则只计算所需的输出,并将未使用的输出替换为 None

如果为 custom_dce 提供了 static_argnums 参数,则在追踪函数时,这些指定的参数将被视为静态参数,并且在调用 DCE 规则时会被移到前面。例如,如果 fun 接受两个参数 fun(x, y),并且 static_argnums(1,),那么 DCE 规则将被调用为 dce_rule(y, used_outs, x)

参数:
__init__(fun, *, static_argnums=())[源代码]#
参数:

方法

__init__(fun, *[, static_argnums])

def_dce(dce_rule)

为此函数定义自定义 DCE 规则。

属性

fun

static_argnums

dce_rule