jax.disable_jit#
- jax.disable_jit(disable=True)[source]#
一个上下文管理器,用于禁用
jit()
在其动态上下文中的行为。出于调试目的,拥有一种能在动态上下文中全局禁用
jit()
的机制会很有用。请注意,这不仅会禁用用户显式使用的jit()
,还会移除 JAX 库中使用的任何隐式 JIT 编译:这包括传递给高级原语(如scan()
和while_loop()
)的 body 和 cond 函数的隐式 JIT 计算,以及jax.numpy
函数实现中使用的 JIT,以及jit()
在 API 实现中使用的任何其他情况。但请注意,即使在 disable_jit 下,单个原语操作仍会像正常的即时逐操作执行一样由 XLA 编译。对 JIT 编译函数参数具有数据依赖关系的值会被跟踪和抽象化。例如,一个抽象值可能是一个
ShapedArray
实例,它表示具有给定形状和数据类型的所有可能数组的集合,但并不表示具有特定值的具体数组。如果您在 JIT 编译函数中使用了无害的副作用操作(例如打印),您可能会注意到这些。>>> import jax >>> >>> @jax.jit ... def f(x): ... y = x * 2 ... print("Value of y is", y) ... return y + 3 ... >>> print(f(jax.numpy.array([1, 2, 3]))) Value of y is JitTracer<int32[3]> [5 7 9]
在这里,
y
已被jit()
抽象化为一个ShapedArray
,它表示一个具有固定形状和类型但值任意的数组。y
的值也被跟踪。如果我们在调试时想查看具体值,并同时避免跟踪器,我们可以使用disable_jit()
上下文管理器>>> import jax >>> >>> with jax.disable_jit(): ... print(f(jax.numpy.array([1, 2, 3]))) ... Value of y is [2 4 6] [5 7 9]
- 参数:
disable (bool)