jax.lax.optimization_barrier#
- jax.lax.optimization_barrier(operand, /)[source]#
阻止编译器跨越 barrier 移动操作。
优化 barrier 有许多可能的用途
优化 barrier 确保在任何依赖于 barrier 输出的运算符之前评估所有输入。这可以用于强制执行特定的操作顺序。
优化 barrier 阻止公共子表达式消除。JAX 使用它来实现重物化。
优化 barrier 阻止编译器融合。也就是说,barrier 之前的操作可能不会与 barrier 之后的操作融合到同一个内核中,这是由编译器决定的。
JAX 没有为优化 barrier 定义导数或批处理规则。
优化 barrier 在编译函数之外不起作用。
- 参数:
operand – JAX 值的 pytree。
- 返回:
JAX 值的 pytree,具有与
operand
相同的结构和内容。
示例
防止对 sin 的两次调用之间进行公共子表达式消除
>>> def f(x): ... return jax.lax.optimization_barrier(jax.lax.sin(x)) + jax.lax.sin(x) >>> jax.jit(f)(0.) Array(0., dtype=float32, weak_type=True)