错误#

此页面列出了一些您在使用 JAX 时可能遇到的错误,以及如何修复这些错误的代表性示例。

class jax.errors.ConcretizationTypeError(tracer, context='')#

当 JAX Tracer 对象在需要具体值的上下文中使用时,会发生此错误(有关 Tracer 是什么的更多信息,请参阅不同类型的 JAX 值)。在某些情况下,可以通过将有问题的标记值标记为静态来轻松修复;在其他情况下,它可能表明您的程序正在执行 JAX 的 JIT 编译模型不直接支持的操作。

示例

在需要静态值的地方使用了跟踪值

此错误的一个常见原因是,在需要静态值的地方使用了跟踪值。例如

>>> from functools import partial
>>> from jax import jit
>>> import jax.numpy as jnp
>>> @jit
... def func(x, axis):
...   return x.min(axis)
>>> func(jnp.arange(4), 0)  
Traceback (most recent call last):
    ...
ConcretizationTypeError: Abstract tracer value encountered where concrete
value is expected: axis argument to jnp.min().

通常可以通过将有问题的参数标记为静态来修复此问题

>>> @partial(jit, static_argnums=1)
... def func(x, axis):
...   return x.min(axis)

>>> func(jnp.arange(4), 0)
Array(0, dtype=int32)
形状依赖于跟踪值

当 JIT 编译计算中的形状依赖于跟踪量中的值时,也可能出现此类错误。例如

>>> @jit
... def func(x):
...     return jnp.where(x < 0)

>>> func(jnp.arange(4))  
Traceback (most recent call last):
    ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected:
The error arose in jnp.nonzero.

这是一个与 JAX 的 JIT 编译模型不兼容的操作示例,该模型要求数组大小在编译时已知。在这里,返回数组的大小取决于 x 的内容,并且此类代码无法进行 JIT 编译。

在许多情况下,可以通过修改函数中使用的逻辑来解决此问题;例如,这是一个具有类似问题的代码

>>> @jit
... def func(x):
...     indices = jnp.where(x > 1)
...     return x[indices].sum()

>>> func(jnp.arange(4))  
Traceback (most recent call last):
    ...
ConcretizationTypeError: Abstract tracer value encountered where concrete
value is expected: The error arose in jnp.nonzero.

这是您如何以避免创建动态大小的索引数组的方式表达相同操作的方法

>>> @jit
... def func(x):
...   return jnp.where(x > 1, x, 0).sum()

>>> func(jnp.arange(4))
Array(5, dtype=int32)

要了解更多关于 tracer 与常规值以及具体值与抽象值之间的细微差别,您可能需要阅读不同类型的 JAX 值

参数:
  • tracer (core.Tracer)

  • context (str)

class jax.errors.KeyReuseError(message)#

当以不安全的方式重用 PRNG 键时,会发生此错误。仅当 jax_debug_key_reuse 设置为 True 时,才会检查键重用。

这是一个会导致此类错误的简单代码示例

>>> with jax.debug_key_reuse(True):  
...   key = jax.random.key(0)
...   value = jax.random.uniform(key)
...   new_value = jax.random.uniform(key)
...
---------------------------------------------------------------------------
KeyReuseError                             Traceback (most recent call last)
...
KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0

这种键重用是有问题的,因为 JAX PRNG 是无状态的,并且键必须手动拆分;有关此的更多信息,请参阅伪随机数教程

参数:

message (str)

jax.errors.JaxRuntimeError#

别名 XlaRuntimeError

class jax.errors.NonConcreteBooleanIndexError(tracer)#

当程序尝试在跟踪索引操作中使用非具体的布尔索引时,会发生此错误。在 JIT 编译下,JAX 数组必须具有静态形状(即在编译时已知的形状),因此必须谨慎使用布尔掩码。某些通过布尔掩码实现的逻辑在 jax.jit() 函数中根本不可能实现;在其他情况下,可以使用与 JIT 兼容的方式重新表达逻辑,通常使用 where() 的三参数版本。

以下是一些可能发生此错误的示例。

通过布尔掩码构造数组

当尝试在 JIT 上下文中通过布尔掩码创建数组时,最常出现这种情况。例如

>>> import jax
>>> import jax.numpy as jnp

>>> @jax.jit
... def positive_values(x):
...   return x[x > 0]

>>> positive_values(jnp.arange(-5, 5))  
Traceback (most recent call last):
    ...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])

此函数尝试仅返回输入数组中的正值;除非 x 标记为静态,否则无法在编译时确定此返回数组的大小,因此此类操作无法在 JIT 编译下执行。

可重新表达的布尔逻辑

虽然不直接支持创建动态大小的数组,但在许多情况下,可以根据与 JIT 兼容的操作重新表达计算的逻辑。例如,这是另一个由于相同原因在 JIT 下失败的函数

>>> @jax.jit
... def sum_of_positive(x):
...   return x[x > 0].sum()

>>> sum_of_positive(jnp.arange(-5, 5))  
Traceback (most recent call last):
    ...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])

但是,在这种情况下,有问题的数组只是一个中间值,我们可以根据 jax.numpy.where() 的与 JIT 兼容的三参数版本来表达相同的逻辑

>>> @jax.jit
... def sum_of_positive(x):
...   return jnp.where(x > 0, x, 0).sum()

>>> sum_of_positive(jnp.arange(-5, 5))
Array(10, dtype=int32)

用三参数 where() 替换布尔掩码是解决此类问题的常用方法。

布尔索引到 JAX 数组中

经常出现此错误的另一种情况是使用布尔索引时,例如使用 .at[...].set(...)。这是一个简单的例子

>>> @jax.jit
... def manual_clip(x):
...   return x.at[x < 0].set(0)

>>> manual_clip(jnp.arange(-2, 2))  
Traceback (most recent call last):
    ...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[4])

此函数尝试将小于零的值设置为标量填充值。如上所述,可以通过根据 where() 重新表达逻辑来解决此问题

>>> @jax.jit
... def manual_clip(x):
...   return jnp.where(x < 0, 0, x)

>>> manual_clip(jnp.arange(-2, 2))
Array([0, 0, 0, 1], dtype=int32)
参数:

tracer (core.Tracer)

class jax.errors.TracerArrayConversionError(tracer)#

当程序尝试将 JAX Tracer 对象转换为标准 NumPy 数组时,会发生此错误(有关 Tracer 是什么的更多信息,请参阅不同类型的 JAX 值)。它通常在以下几种情况下发生。

在 JAX 转换中使用非 JAX 函数

如果您尝试在 JAX 转换(jit()grad()jax.vmap() 等)内部使用非 JAX 库(如 numpyscipy)时,可能会发生此错误。例如

>>> from jax import jit
>>> import numpy as np

>>> @jit
... def func(x):
...   return np.sin(x)

>>> func(np.arange(4))  
Traceback (most recent call last):
    ...
TracerArrayConversionError: The numpy.ndarray conversion method
__array__() was called on traced array with shape int32[4]

在这种情况下,您可以通过使用 jax.numpy.sin() 代替 numpy.sin() 来修复此问题

>>> import jax.numpy as jnp
>>> @jit
... def func(x):
...   return jnp.sin(x)

>>> func(jnp.arange(4))
Array([0.        , 0.84147096, 0.9092974 , 0.14112   ], dtype=float32)

另请参阅外部回调,了解从转换后的 JAX 代码回调到主机端计算的选项。

使用 tracer 索引 numpy 数组

如果此错误发生在涉及数组索引的行上,则可能是被索引的数组 x 是标准 numpy.ndarray,而索引 idx 是跟踪的 JAX 数组。例如

>>> x = np.arange(10)

>>> @jit
... def func(i):
...   return x[i]

>>> func(0)  
Traceback (most recent call last):
    ...
TracerArrayConversionError: The numpy.ndarray conversion method
__array__() was called on traced array with shape int32[0]

根据上下文,您可以通过将 numpy 数组转换为 JAX 数组来修复此问题

>>> @jit
... def func(i):
...   return jnp.asarray(x)[i]

>>> func(0)
Array(0, dtype=int32)

或通过将索引声明为静态参数

>>> from functools import partial
>>> @partial(jit, static_argnums=(0,))
... def func(i):
...   return x[i]

>>> func(0)
Array(0, dtype=int32)

要了解更多关于 tracer 与常规值以及具体值与抽象值之间的细微差别,您可能需要阅读不同类型的 JAX 值

参数:

tracer (core.Tracer)

class jax.errors.TracerBoolConversionError(tracer)#

当 JAX 中的跟踪值在需要布尔值的上下文中使用时,会发生此错误(有关 Tracer 是什么的更多信息,请参阅不同类型的 JAX 值)。

布尔转换可能是显式的(例如 bool(x))或隐式的,通过使用控制流(例如 if x > 0while x)、使用 Python 布尔运算符(例如 z = x and yz = x or yz = not x)或使用它们的函数(例如 z = max(x, y)z = min(x, y) 等)。

在某些情况下,可以通过将跟踪值标记为静态来轻松修复此问题;在其他情况下,它可能表明您的程序正在执行 JAX 的 JIT 编译模型不直接支持的操作。

示例

在控制流中使用的跟踪值

经常出现这种情况的一种情况是,在 Python 控制流中使用了跟踪值。例如

>>> from jax import jit
>>> import jax.numpy as jnp
>>> @jit
... def func(x, y):
...   return x if x.sum() < y.sum() else y

>>> func(jnp.ones(4), jnp.zeros(4))  
Traceback (most recent call last):
    ...
TracerBoolConversionError: Attempted boolean conversion of JAX Tracer [...]

我们可以将输入 xy 都标记为静态,但这会破坏此处使用 jax.jit() 的目的。另一种选择是以三项 jax.numpy.where() 的形式重新表达 if 语句

>>> @jit
... def func(x, y):
...   return jnp.where(x.sum() < y.sum(), x, y)

>>> func(jnp.ones(4), jnp.zeros(4))
Array([0., 0., 0., 0.], dtype=float32)

对于包括循环在内的更复杂的控制流,请参阅控制流运算符

对跟踪值进行控制流

此错误的另一个常见原因是,如果您不小心跟踪了布尔标志。例如

>>> @jit
... def func(x, normalize=True):
...   if normalize:
...     return x / x.sum()
...   return x

>>> func(jnp.arange(5), True)  
Traceback (most recent call last):
    ...
TracerBoolConversionError: Attempted boolean conversion of JAX Tracer ...

在这里,由于标志 normalize 被跟踪,因此无法在 Python 控制流中使用它。在这种情况下,最好的解决方案可能是将此值标记为静态

>>> from functools import partial
>>> @partial(jit, static_argnames=['normalize'])
... def func(x, normalize=True):
...   if normalize:
...     return x / x.sum()
...   return x

>>> func(jnp.arange(5), True)
Array([0. , 0.1, 0.2, 0.3, 0.4], dtype=float32)

有关 static_argnums 的更多信息,请参阅 jax.jit() 的文档。

使用非 JAX 感知函数

此错误的另一个常见原因是,在 JAX 代码中使用非 JAX 感知函数。例如

>>> @jit
... def func(x):
...   return min(x, 0)
>>> func(2)  
Traceback (most recent call last):
    ...
TracerBoolConversionError: Attempted boolean conversion of JAX Tracer ...

在这种情况下,发生错误是因为 Python 的内置 min 函数与 JAX 转换不兼容。可以通过将其替换为 jnp.minumum 来修复此问题

>>> @jit
... def func(x):
...   return jnp.minimum(x, 0)
>>> print(func(2))
0

要了解更多关于 tracer 与常规值以及具体值与抽象值之间的细微差别,您可能需要阅读不同类型的 JAX 值

参数:

tracer (core.Tracer)

class jax.errors.TracerIntegerConversionError(tracer)#

当 JAX Tracer 对象在需要 Python 整数的上下文中使用时,可能会发生此错误(有关 Tracer 是什么的更多信息,请参阅不同类型的 JAX 值)。它通常在以下几种情况下发生。

传递 tracer 代替整数

如果您尝试将跟踪值传递给需要静态整数参数的函数,则可能会发生此错误;例如

>>> from jax import jit
>>> import numpy as np

>>> @jit
... def func(x, axis):
...   return np.split(x, 2, axis)

>>> func(np.arange(4), 0)  
Traceback (most recent call last):
    ...
TracerIntegerConversionError: The __index__() method was called on
traced array with shape int32[0]

发生这种情况时,解决方案通常是将有问题的参数标记为静态

>>> from functools import partial
>>> @partial(jit, static_argnums=1)
... def func(x, axis):
...   return np.split(x, 2, axis)

>>> func(np.arange(10), 0)
[Array([0, 1, 2, 3, 4], dtype=int32),
 Array([5, 6, 7, 8, 9], dtype=int32)]

另一种方法是将转换应用于封装要保护的参数的闭包,可以手动执行,如下所示,也可以使用 functools.partial()

>>> jit(lambda arr: np.split(arr, 2, 0))(np.arange(4))
[Array([0, 1], dtype=int32), Array([2, 3], dtype=int32)]

请注意,每次调用都会创建一个新的闭包,这会破坏编译缓存机制,这就是为什么首选 static_argnums 的原因。

使用 Tracer 索引列表

如果您尝试使用跟踪量索引 Python 列表,则可能会发生此错误。例如

>>> import jax.numpy as jnp
>>> from jax import jit

>>> L = [1, 2, 3]

>>> @jit
... def func(i):
...   return L[i]

>>> func(0)  
Traceback (most recent call last):
    ...
TracerIntegerConversionError: The __index__() method was called on
traced array with shape int32[0]

根据上下文,您通常可以通过将列表转换为 JAX 数组来修复此问题

>>> @jit
... def func(i):
...   return jnp.array(L)[i]

>>> func(0)
Array(1, dtype=int32)

或通过将索引声明为静态参数

>>> from functools import partial
>>> @partial(jit, static_argnums=0)
... def func(i):
...   return L[i]

>>> func(0)
Array(1, dtype=int32, weak_type=True)

要了解更多关于 tracer 与常规值以及具体值与抽象值之间的细微差别,您可能需要阅读不同类型的 JAX 值

参数:

tracer (core.Tracer)

class jax.errors.UnexpectedTracerError(msg)#

当您使用已从函数中泄漏出来的 JAX 值时,会发生此错误。泄漏值是什么意思?如果您对函数 f 使用 JAX 转换,该函数在 f 之外的某个作用域中存储对中间值的引用,则该值被视为已泄漏。泄漏值是一种副作用。(阅读有关避免副作用的更多信息,请参阅纯函数

当您稍后在另一个操作中使用泄漏的值时,JAX 会检测到泄漏,此时它会引发 UnexpectedTracerError。要修复此问题,请避免副作用:如果函数计算了外部作用域中需要的值,请从转换后的函数中显式返回该值。

具体来说,Tracer 是 JAX 在转换期间函数中间值的内部表示,例如在 jit()pmap()vmap() 等中。在转换之外遇到 Tracer 意味着泄漏。

泄漏值的生命周期

考虑以下转换函数的示例,该函数将值泄漏到外部作用域

>>> from jax import jit
>>> import jax.numpy as jnp

>>> outs = []
>>> @jit                   # 1
... def side_effecting(x):
...   y = x + 1            # 3
...   outs.append(y)       # 4

>>> x = 1
>>> side_effecting(x)      # 2
>>> outs[0] + 1            # 5  
Traceback (most recent call last):
    ...
UnexpectedTracerError: Encountered an unexpected tracer.

在此示例中,我们将 Traced 值从内部转换作用域泄漏到外部作用域。当使用泄漏的值时,而不是在泄漏值时,我们会得到 UnexpectedTracerError

此示例还演示了泄漏值的生命周期

  1. 函数被转换(在本例中,通过 jit()

  2. 调用转换后的函数(启动函数的抽象跟踪并将 x 转换为 Tracer

  3. 创建中间值 y,稍后将泄漏该值(跟踪函数的中间值也是 Tracer

  4. 值泄漏(附加到外部作用域中的列表,通过侧通道逃逸函数)

  5. 使用泄漏的值,并引发 UnexpectedTracerError。

UnexpectedTracerError 消息尝试通过包含有关每个阶段的信息来指向代码中的这些位置。分别

  1. 转换函数的名称 (side_effecting) 以及哪个转换启动了跟踪 jit()

  2. 泄漏的 Tracer 创建位置的重建堆栈跟踪,其中包括调用转换函数的位置。( Tracer 被创建时,最后的 5 个堆栈帧是...)。

  3. 从重建的堆栈跟踪中,创建泄漏的 Tracer 的代码行。

  4. 泄漏位置未包含在错误消息中,因为它很难确定!JAX 只能告诉您泄漏的值是什么样子(它具有什么形状以及在何处创建)以及它泄漏到哪个边界(转换的名称和转换后的函数的名称)。

  5. 当前错误的堆栈跟踪会指向值被使用的位置。

通过从转换后的函数返回值,可以修复此错误。

>>> from jax import jit
>>> import jax.numpy as jnp

>>> outs = []
>>> @jit
... def not_side_effecting(x):
...   y = x+1
...   return y

>>> x = 1
>>> y = not_side_effecting(x)
>>> outs.append(y)
>>> outs[0] + 1  # all good! no longer a leaked value.
Array(3, dtype=int32, weak_type=True)
泄漏检查器

正如上面第 2 点和第 3 点所讨论的,JAX 显示的重建堆栈跟踪会指向泄漏值的创建位置。 这是因为 JAX 只在泄漏值被使用时才抛出错误,而不是在值被泄漏时。 这不是抛出此错误的最佳位置,因为你需要知道 Tracer 泄漏的位置才能修复错误。

为了更容易追踪到这个位置,你可以使用泄漏检查器。 当泄漏检查器启用后,一旦 Tracer 泄漏,就会立即抛出错误。(更准确地说,它会在 Tracer 泄漏的转换函数返回时抛出错误)

要启用泄漏检查器,你可以使用 JAX_CHECK_TRACER_LEAKS 环境变量或 with jax.checking_leaks() 上下文管理器。

注意

请注意,此工具是实验性的,可能会报告误报。 它的工作原理是禁用一些 JAX 缓存,因此会对性能产生负面影响,应仅在调试时使用。

示例用法

>>> from jax import jit
>>> import jax.numpy as jnp

>>> outs = []
>>> @jit
... def side_effecting(x):
...   y = x+1
...   outs.append(y)

>>> x = 1
>>> with jax.checking_leaks():
...   y = side_effecting(x)  
Traceback (most recent call last):
    ...
Exception: Leaked Trace
参数:

msg (字符串)