jax.experimental.custom_dce.custom_dce#

class jax.experimental.custom_dce.custom_dce(fun, *, static_argnums=())[source]#

自定义 JAX 可转换函数的 DCE 行为。

JAX 使用死代码消除 (DCE) 从 JAX 程序中删除未使用的计算。当程序完全由已知的 JAX 操作指定时,这通常可以透明地工作,但像调用 pallas_call()ffi_call() 这样的不透明内核可能会导致问题。

在 JAX 中,DCE 在使用 jax.jit() 暂存函数时执行,因此在 Eager 模式下运行 JAX 时不会应用它。 同样,custom_dce 装饰器要求被装饰的函数和自定义 DCE 规则都与 jit() 兼容。

此装饰器允许用户通过定义自定义 DCE 规则来自定义函数的 DCE 行为。对于 custom_dce 包装的函数 f(*args),DCE 规则的签名是 dce_rule(used_outs, *args),其中 used_outs 是一个 Pytree,其结构与 f 的输出相同,并且每个叶子都是一个 bool,指示应该计算哪些输出。 剩余参数 *argsf 的原始参数。 规则 dce_rule 应该返回一个 Pytree,其结构与 f 的原始输出相同,但任何未使用的输出都可以替换为 None

例如

>>> @jax.experimental.custom_dce.custom_dce
... def f(x, y):
...   return jnp.sin(x) * y, x * jnp.sin(y)
...
>>> @f.def_dce
... def f_dce_rule(used_outs, x, y):
...   return (
...       jnp.sin(x) * y if used_outs[0] else None,
...       x * jnp.sin(y) if used_outs[1] else None,
...   )

在此示例中,used_outs 是一个带有两个 bool 值的 tuple,指示需要哪些输出。 DCE 规则仅计算所需的输出,并将未使用的输出替换为 None

如果 static_argnums 参数提供给 custom_dce,则在跟踪函数时,指示的参数将被视为静态,并且在调用 DCE 规则时,它们将被移动到前面。 例如,如果 fun 接受 2 个参数 fun(x, y),并且 static_argnums(1,),那么 DCE 规则将被调用为 dce_rule(y, used_outs, x)

参数:
__init__(fun, *, static_argnums=())[source]#
参数:

方法

__init__(fun, *[, static_argnums])

def_dce(dce_rule)

为此函数定义自定义 DCE 规则。

属性

fun

static_argnums

dce_rule