jax.ensure_compile_time_eval#
- jax.ensure_compile_time_eval()[source]#
用于确保在跟踪/编译时进行评估(或报错)的上下文管理器。
某些 JAX API,例如
jax.jit()
和jax.lax.scan()
,涉及暂存(staging),即将数值表达式(如jax.numpy
函数应用)的评估延迟,以便它们的计算可以单独进行(例如在优化编译之后),而不是在评估相应的 Python 表达式时即时执行这些计算。但这种延迟可能不合需要。例如,数值可能需要用于评估 Python 控制流,因此其评估不能被延迟。另一个例子是,为了性能原因,确保编译时评估(或“常量折叠”)可能是有益的。此上下文管理器确保 JAX 计算被即时评估。如果无法即时评估,则会引发
ConcretizationTypeError
。这是一个刻意设计的例子
import jax import jax.numpy as jnp @jax.jit def f(x): with jax.ensure_compile_time_eval(): y = jnp.sin(3.0) z = jnp.sin(y) z_positive = z > 0 if z_positive: # z_positive is usable in Python control flow return jnp.sin(x) else: return jnp.cos(x)
这是一个来自 jax-ml/jax#3974 的真实世界示例
import jax import jax.numpy as jnp from jax import random @jax.jit def jax_fn(x): with jax.ensure_compile_time_eval(): y = random.randint(random.key(0), (1000,1000), 0, 100) y2 = y @ y x2 = jnp.sum(y2) * x return x2
通常可以通过简单地将常量表达式“提升”(hoisting)出相应的暂存 API 来实现类似的行为
y = random.randint(random.key(0), (1000,1000), 0, 100) @jax.jit def jax_fn(x): y2 = y @ y x2 = jnp.sum(y2)*x return x2
但在某些情况下,使用此上下文管理器可能更方便。