jax.experimental.custom_dce.custom_dce.def_dce#

custom_dce.def_dce(dce_rule)[源代码]#

为此函数定义自定义 DCE 规则。

参数:

dce_rule (Callable[[...], Any]) – 一个函数,该函数接受 (a) 使用 static_argnums 指定的任何静态参数,(b) 一个表示应计算哪些输出的 bool 值 Pytree (used_outs),以及 (c) 原始函数的其余 (非静态) 参数。该规则应返回一个 Pytree,其结构与原始函数的输出相同,但任何未使用的输出 (如 used_outs 所指示) 都可以替换为 None

返回类型:

Callable[[…], Any]