jax.experimental.custom_dce 模块#

API#

custom_dce(fun, *[, static_argnums])

自定义 JAX 可转换函数的 DCE 行为。

custom_dce.def_dce(dce_rule)

为此函数定义自定义 DCE 规则。