jax.disable_jit#

jax.disable_jit(disable=True)[source]#

一个上下文管理器,用于禁用 jit() 在其动态上下文中的行为。

出于调试目的,拥有一种能在动态上下文中全局禁用 jit() 的机制会很有用。请注意,这不仅会禁用用户显式使用的 jit(),还会移除 JAX 库中使用的任何隐式 JIT 编译:这包括传递给高级原语(如 scan()while_loop())的 bodycond 函数的隐式 JIT 计算,以及 jax.numpy 函数实现中使用的 JIT,以及 jit() 在 API 实现中使用的任何其他情况。但请注意,即使在 disable_jit 下,单个原语操作仍会像正常的即时逐操作执行一样由 XLA 编译。

对 JIT 编译函数参数具有数据依赖关系的值会被跟踪和抽象化。例如,一个抽象值可能是一个 ShapedArray 实例,它表示具有给定形状和数据类型的所有可能数组的集合,但并不表示具有特定值的具体数组。如果您在 JIT 编译函数中使用了无害的副作用操作(例如打印),您可能会注意到这些。

>>> import jax
>>>
>>> @jax.jit
... def f(x):
...   y = x * 2
...   print("Value of y is", y)
...   return y + 3
...
>>> print(f(jax.numpy.array([1, 2, 3])))
Value of y is JitTracer<int32[3]>
[5 7 9]

在这里,y 已被 jit() 抽象化为一个 ShapedArray,它表示一个具有固定形状和类型但值任意的数组。y 的值也被跟踪。如果我们在调试时想查看具体值,并同时避免跟踪器,我们可以使用 disable_jit() 上下文管理器

>>> import jax
>>>
>>> with jax.disable_jit():
...   print(f(jax.numpy.array([1, 2, 3])))
...
Value of y is [2 4 6]
[5 7 9]
参数:

disable (bool)