jax.disable_jit#

jax.disable_jit(disable=True)[source]#

上下文管理器,用于在其动态上下文中禁用 jit() 行为。

为了进行调试,拥有一个机制来禁用动态上下文中所有位置的 jit() 非常有用。 请注意,这不仅会禁用用户显式使用 jit(),还会删除 JAX 库使用的任何隐式 JIT 编译: 这包括传递给高级原语(如 scan()while_loop())的 bodycond 函数的隐式 JIT 计算、jax.numpy 函数实现中使用的 JIT 以及 API 实现中任何其他使用 jit() 的情况。 但请注意,即使在 disable_jit 下,各个原始操作仍将由 XLA 编译,就像在正常的 eager op-by-op 执行中一样。

数据依赖于 jitted 函数参数的值会被跟踪和抽象。 例如,抽象值可以是 ShapedArray 实例,表示具有给定形状和 dtype 的所有可能数组的集合,但不表示具有特定值的具体数组。 如果你在 jitted 函数中使用良性副作用操作(例如 print),你可能会注意到这些。

>>> import jax
>>>
>>> @jax.jit
... def f(x):
...   y = x * 2
...   print("Value of y is", y)
...   return y + 3
...
>>> print(f(jax.numpy.array([1, 2, 3])))  
Value of y is Traced<ShapedArray(int32[3])>with<DynamicJaxprTrace...>
[5 7 9]

在这里,y 已被 jit() 抽象为 ShapedArray,它表示具有固定形状和类型但具有任意值的数组。 y 的值也被跟踪。 如果我们想在调试时看到一个具体的值,并且也避免 tracer,我们可以使用 disable_jit() 上下文管理器。

>>> import jax
>>>
>>> with jax.disable_jit():
...   print(f(jax.numpy.array([1, 2, 3])))
...
Value of y is [2 4 6]
[5 7 9]
参数:

disable (bool)