错误#
此页面列出了您在使用 JAX 时可能遇到的一些错误,以及修复这些错误的代表性示例。
- class jax.errors.ConcretizationTypeError(tracer, context='')#
当 JAX Tracer 对象在需要具体值的上下文中被使用时,就会出现此错误(有关 Tracer 的更多信息,请参阅 不同种类的 JAX 值)。在某些情况下,可以通过将有问题的数值标记为静态来轻松修复;在其他情况下,这可能表明您的程序正在执行 JAX 的 JIT 编译模型不支持的操作。
示例
- 需要静态值时使用的跟踪值
此错误的一个常见原因是,在需要静态值的地方使用了跟踪值。例如:
>>> from functools import partial >>> from jax import jit >>> import jax.numpy as jnp >>> @jit ... def func(x, axis): ... return x.min(axis)
>>> func(jnp.arange(4), 0) Traceback (most recent call last): ... ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: axis argument to jnp.min().
通常可以通过将有问题的参数标记为静态来修复此问题。
>>> @partial(jit, static_argnums=1) ... def func(x, axis): ... return x.min(axis) >>> func(jnp.arange(4), 0) Array(0, dtype=int32)
- 形状取决于跟踪值
当您在 JIT 编译的计算中的形状依赖于跟踪数量的值时,也可能出现此类错误。例如:
>>> @jit ... def func(x): ... return jnp.where(x < 0) >>> func(jnp.arange(4)) Traceback (most recent call last): ... ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: The error arose in jnp.nonzero.
这是与 JAX 的 JIT 编译模型不兼容的操作示例,该模型要求在编译时知道数组的大小。此处返回数组的大小取决于 x 的内容,因此无法对此类代码进行 JIT 编译。
在许多情况下,可以通过修改函数中使用的逻辑来解决此问题;例如,这是一段有类似问题的代码:
>>> @jit ... def func(x): ... indices = jnp.where(x > 1) ... return x[indices].sum() >>> func(jnp.arange(4)) Traceback (most recent call last): ... ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: The error arose in jnp.nonzero.
以下是如何表达相同操作的方式,该方式避免了创建动态大小的索引数组:
>>> @jit ... def func(x): ... return jnp.where(x > 1, x, 0).sum() >>> func(jnp.arange(4)) Array(5, dtype=int32)
要了解更多关于跟踪值与常规值,以及具体值与抽象值之间细微差别的知识,您可能需要阅读 不同种类的 JAX 值。
- 参数:
tracer (core.Tracer)
context (str)
- class jax.errors.KeyReuseError(message)#
当 PRNG 密钥被不安全地重用时,会发生此错误。仅当 jax_debug_key_reuse 设置为 True 时,才会检查密钥重用。
下面是一个会导致此类错误的代码简单示例。
>>> with jax.debug_key_reuse(True): ... key = jax.random.key(0) ... value = jax.random.uniform(key) ... new_value = jax.random.uniform(key) ... --------------------------------------------------------------------------- KeyReuseError Traceback (most recent call last) ... KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0
这种密钥重用是有问题的,因为 JAX PRNG 是无状态的,并且必须手动拆分密钥;有关更多信息,请参阅 伪随机数教程。
- 参数:
message (str)
- class jax.errors.JaxRuntimeError#
JAX 运行时引发的运行时错误。虽然 JAX 运行时也可能引发其他异常,但运行时引发的大多数异常都是此类的实例。
- class jax.errors.NonConcreteBooleanIndexError(tracer)#
当程序尝试在跟踪的索引操作中使用非具体的布尔索引时,会发生此错误。在 JIT 编译下,JAX 数组必须具有静态形状(即在编译时已知的形状),因此布尔掩码必须谨慎使用。通过布尔掩码实现的一些逻辑在
jax.jit()函数中根本无法实现;在其他情况下,逻辑可以以 JIT 兼容的方式重新表达,通常使用三参数版本的where()。以下是此错误可能出现的一些示例。
- 通过布尔掩码构建数组
当尝试在 JIT 上下文中通过布尔掩码创建数组时,通常会出现此问题。例如:
>>> import jax >>> import jax.numpy as jnp >>> @jax.jit ... def positive_values(x): ... return x[x > 0] >>> positive_values(jnp.arange(-5, 5)) Traceback (most recent call last): ... NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])
此函数尝试仅返回输入数组中的正值;除非将 x 标记为静态,否则无法在编译时确定此返回数组的大小,因此在 JIT 编译下无法执行此类操作。
- 可重新表达的布尔逻辑
虽然直接不支持创建动态大小的数组,但在许多情况下,可以将计算逻辑重新表达为 JIT 兼容的操作。例如,另一个因相同原因在 JIT 下失败的函数如下:
>>> @jax.jit ... def sum_of_positive(x): ... return x[x > 0].sum() >>> sum_of_positive(jnp.arange(-5, 5)) Traceback (most recent call last): ... NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])
在这种情况下,有问题的数组只是一个中间值,我们可以改用三参数版本的
where()来表达相同的逻辑。>>> @jax.jit ... def sum_of_positive(x): ... return jnp.where(x > 0, x, 0).sum() >>> sum_of_positive(jnp.arange(-5, 5)) Array(10, dtype=int32)
用三参数
where()替换布尔掩码是解决此类问题的常见方法。- 对 JAX 数组进行布尔索引
此错误经常出现的另一个情况是使用布尔索引,例如使用
.at[...].set(...)。下面是一个简单的例子:>>> @jax.jit ... def manual_clip(x): ... return x.at[x < 0].set(0) >>> manual_clip(jnp.arange(-2, 2)) Traceback (most recent call last): ... NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[4])
此函数尝试将小于零的值设置为标量填充值。与上面一样,可以通过将逻辑重新表达为
where()来解决。>>> @jax.jit ... def manual_clip(x): ... return jnp.where(x < 0, 0, x) >>> manual_clip(jnp.arange(-2, 2)) Array([0, 0, 0, 1], dtype=int32)
- 参数:
tracer (core.Tracer)
- class jax.errors.TracerArrayConversionError(tracer)#
当程序尝试将 JAX Tracer 对象转换为标准 NumPy 数组时,会发生此错误(有关 Tracer 的更多信息,请参阅 不同种类的 JAX 值)。它通常发生在以下几种情况之一。
- 在 JAX 转换中使用非 JAX 函数
如果在 JAX 转换(
jit()、grad()、jax.vmap()等)内部尝试使用numpy或scipy等非 JAX 库,则可能发生此错误。例如:>>> from jax import jit >>> import numpy as np >>> @jit ... def func(x): ... return np.sin(x) >>> func(np.arange(4)) Traceback (most recent call last): ... TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int32[4]
在这种情况下,可以通过使用
jax.numpy.sin()代替numpy.sin()来解决此问题。>>> import jax.numpy as jnp >>> @jit ... def func(x): ... return jnp.sin(x) >>> func(jnp.arange(4)) Array([0. , 0.84147096, 0.9092974 , 0.14112 ], dtype=float32)
另请参阅 外部回调,了解有关从转换后的 JAX 代码调用宿主计算的选项。
- 使用 Tracer 索引 numpy 数组
如果此错误发生在涉及数组索引的行上,则可能是正在索引的数组
x是标准的 numpy.ndarray,而索引idx是跟踪的 JAX 数组。例如:>>> x = np.arange(10) >>> @jit ... def func(i): ... return x[i] >>> func(0) Traceback (most recent call last): ... TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int32[0]
根据具体情况,您可以通过将 numpy 数组转换为 JAX 数组来修复此问题:
>>> @jit ... def func(i): ... return jnp.asarray(x)[i] >>> func(0) Array(0, dtype=int32)
或者将索引声明为静态参数:
>>> from functools import partial >>> @partial(jit, static_argnums=(0,)) ... def func(i): ... return x[i] >>> func(0) Array(0, dtype=int32)
要了解更多关于跟踪值与常规值,以及具体值与抽象值之间细微差别的知识,您可能需要阅读 不同种类的 JAX 值。
- 参数:
tracer (core.Tracer)
- class jax.errors.TracerBoolConversionError(tracer)#
当 JAX 中的跟踪值在需要布尔值的上下文中被使用时,会发生此错误(有关 Tracer 的更多信息,请参阅 不同种类的 JAX 值)。
布尔转换可能是显式的(例如
bool(x))或隐式的,通过使用控制流(例如if x > 0或while x)、使用 Python 布尔运算符(例如z = x and y、z = x or y、z = not x)或使用它们的函数(例如z = max(x, y)、z = min(x, y)等)。在某些情况下,可以通过将跟踪值标记为静态来轻松解决此问题;在其他情况下,这可能表明您的程序正在执行 JAX 的 JIT 编译模型不支持的操作。
示例
- 在控制流中使用跟踪值
当跟踪值用于 Python 控制流时,经常会出现这种情况。例如:
>>> from jax import jit >>> import jax.numpy as jnp >>> @jit ... def func(x, y): ... return x if x.sum() < y.sum() else y >>> func(jnp.ones(4), jnp.zeros(4)) Traceback (most recent call last): ... TracerBoolConversionError: Attempted boolean conversion of JAX Tracer [...]
我们可以将输入
x和y都标记为静态,但这会破坏在此处使用jax.jit()的目的。另一个选择是将 if 语句重新表达为三项jax.numpy.where()。>>> @jit ... def func(x, y): ... return jnp.where(x.sum() < y.sum(), x, y) >>> func(jnp.ones(4), jnp.zeros(4)) Array([0., 0., 0., 0.], dtype=float32)
有关包括循环在内的更复杂的控制流,请参阅 控制流运算符。
- 对跟踪值的控制流
此错误的另一个常见原因是您无意中跟踪了布尔标志。例如:
>>> @jit ... def func(x, normalize=True): ... if normalize: ... return x / x.sum() ... return x >>> func(jnp.arange(5), True) Traceback (most recent call last): ... TracerBoolConversionError: Attempted boolean conversion of JAX Tracer ...
在这里,由于标志
normalize被跟踪,因此不能在 Python 控制流中使用它。在这种情况下,最好的解决方案可能是将此值标记为静态。>>> from functools import partial >>> @partial(jit, static_argnames=['normalize']) ... def func(x, normalize=True): ... if normalize: ... return x / x.sum() ... return x >>> func(jnp.arange(5), True) Array([0. , 0.1, 0.2, 0.3, 0.4], dtype=float32)
有关
static_argnums的更多信息,请参阅jax.jit()的文档。- 使用非 JAX 感知函数
此错误的另一个常见原因是,在 JAX 代码中使用非 JAX 感知函数。例如:
>>> @jit ... def func(x): ... return min(x, 0)
>>> func(2) Traceback (most recent call last): ... TracerBoolConversionError: Attempted boolean conversion of JAX Tracer ...
在这种情况下,错误发生是因为 Python 的内置
min函数与 JAX 转换不兼容。可以通过将其替换为jnp.minimum来修复。>>> @jit ... def func(x): ... return jnp.minimum(x, 0)
>>> print(func(2)) 0
要了解更多关于跟踪值与常规值,以及具体值与抽象值之间细微差别的知识,您可能需要阅读 不同种类的 JAX 值。
- 参数:
tracer (core.Tracer)
- class jax.errors.TracerIntegerConversionError(tracer)#
当 JAX Tracer 对象在需要 Python 整数的上下文中被使用时,可能会发生此错误(有关 Tracer 的更多信息,请参阅 不同种类的 JAX 值)。它通常发生在以下几种情况。
- 将 Tracer 传递代替整数
如果您尝试将跟踪值传递给需要静态整数参数的函数,可能会发生此错误;例如:
>>> from jax import jit >>> import numpy as np >>> @jit ... def func(x, axis): ... return np.split(x, 2, axis) >>> func(np.arange(4), 0) Traceback (most recent call last): ... TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[0]
发生这种情况时,解决方案通常是将有问题的参数标记为静态。
>>> from functools import partial >>> @partial(jit, static_argnums=1) ... def func(x, axis): ... return np.split(x, 2, axis) >>> func(np.arange(10), 0) [Array([0, 1, 2, 3, 4], dtype=int32), Array([5, 6, 7, 8, 9], dtype=int32)]
另一种方法是将转换应用于封装要保护的参数的闭包,无论是手动操作(如下所示)还是使用
functools.partial()。>>> jit(lambda arr: np.split(arr, 2, 0))(np.arange(4)) [Array([0, 1], dtype=int32), Array([2, 3], dtype=int32)]
请注意,每次调用都会创建一个新的闭包,这会破坏编译缓存机制,因此首选 static_argnums。
- 使用 Tracer 索引列表
如果您尝试使用跟踪数量来索引 Python 列表,可能会发生此错误。例如:
>>> import jax.numpy as jnp >>> from jax import jit >>> L = [1, 2, 3] >>> @jit ... def func(i): ... return L[i] >>> func(0) Traceback (most recent call last): ... TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[0]
根据具体情况,通常可以通过将列表转换为 JAX 数组来解决此问题:
>>> @jit ... def func(i): ... return jnp.array(L)[i] >>> func(0) Array(1, dtype=int32)
或者将索引声明为静态参数:
>>> from functools import partial >>> @partial(jit, static_argnums=0) ... def func(i): ... return L[i] >>> func(0) Array(1, dtype=int32, weak_type=True)
要了解更多关于跟踪值与常规值,以及具体值与抽象值之间细微差别的知识,您可能需要阅读 不同种类的 JAX 值。
- 参数:
tracer (core.Tracer)
- class jax.errors.UnexpectedTracerError(msg)#
当您使用了一个已从函数中“泄漏”出来的 JAX 值时,会发生此错误。什么是泄漏值?如果您对函数
f使用 JAX 转换,而该函数在f之外的某个作用域中存储了对中间值的引用,那么该值将被视为已泄漏。泄漏值是一种副作用。(有关避免副作用的更多信息,请参阅 纯函数)JAX 在您稍后将泄漏的值用于另一个操作时检测到泄漏,此时会引发
UnexpectedTracerError。要解决此问题,请避免副作用:如果一个函数计算了一个在外层作用域需要的值,请显式地从转换后的函数中返回该值。具体来说,
Tracer是 JAX 在转换期间(例如,在jit()、pmap()、vmap()等)表示函数中间值的内部表示。在转换之外遇到Tracer暗示着泄漏。- 泄漏值的生命周期
考虑以下转换后的函数泄漏值到外层作用域的示例:
>>> from jax import jit >>> import jax.numpy as jnp >>> outs = [] >>> @jit # 1 ... def side_effecting(x): ... y = x + 1 # 3 ... outs.append(y) # 4 >>> x = 1 >>> side_effecting(x) # 2 >>> outs[0] + 1 # 5 Traceback (most recent call last): ... UnexpectedTracerError: Encountered an unexpected tracer.
在此示例中,我们将一个跟踪值从内部转换作用域泄漏到外层作用域。当使用泄漏值时,我们会得到一个
UnexpectedTracerError,而不是在值泄漏时。此示例还演示了泄漏值的生命周期:
函数被转换(在本例中,通过
jit())。调用转换后的函数(启动函数的抽象跟踪,并将
x转换为Tracer)。创建中间值
y,该值稍后将被泄漏(跟踪函数的中间值也是一个Tracer)。值被泄漏(附加到外层作用域的列表中,通过侧通道逃逸函数)。
泄漏的值被使用,并引发 UnexpectedTracerError。
UnexpectedTracerError 消息试图通过包含每个阶段的信息来指向代码中的这些位置。分别:
转换后的函数(
side_effecting)的名称以及哪个转换启动了跟踪(jit())。泄漏的 Tracer 被创建处的重构堆栈跟踪,其中包括转换后的函数被调用的位置。(
When the Tracer was created, the final 5 stack frames were...)。从重构的堆栈跟踪中,创建泄漏 Tracer 的代码行。
泄漏位置未包含在错误消息中,因为它很难确定!JAX 只能告诉您泄漏值是什么样子(它有什么形状以及在哪里创建的),以及它越过了哪个边界(转换的名称和转换后的函数的名称)。
当前错误的堆栈跟踪指向值被使用的地方。
通过从转换后的函数中返回该值来修复此错误。
>>> from jax import jit >>> import jax.numpy as jnp >>> outs = [] >>> @jit ... def not_side_effecting(x): ... y = x+1 ... return y >>> x = 1 >>> y = not_side_effecting(x) >>> outs.append(y) >>> outs[0] + 1 # all good! no longer a leaked value. Array(3, dtype=int32, weak_type=True)
- 泄漏检查器
如上面第 2 点和第 3 点所述,JAX 显示了一个重构的堆栈跟踪,指向泄漏值被创建的位置。这是因为 JAX 仅在泄漏值被使用时才引发错误,而不是在值泄漏时。这不是引发此错误的最佳位置,因为您需要知道 Tracer 被泄漏的位置才能修复错误。
为了更容易地追踪到这个位置,您可以使用泄漏检查器。当启用泄漏检查器时,一旦
Tracer被泄漏,就会引发错误。(更准确地说,当从该Tracer被泄漏的转换后的函数返回时,它会引发错误)。要启用泄漏检查器,您可以使用
JAX_CHECK_TRACER_LEAKS环境变量或with jax.checking_leaks()上下文管理器。注意
请注意,此工具是实验性的,可能会报告误报。它通过禁用一些 JAX 缓存来工作,因此会对性能产生负面影响,并且只应在调试时使用。
示例用法
>>> from jax import jit >>> import jax.numpy as jnp >>> outs = [] >>> @jit ... def side_effecting(x): ... y = x+1 ... outs.append(y) >>> x = 1 >>> with jax.checking_leaks(): ... y = side_effecting(x) Traceback (most recent call last): ... Exception: Leaked Trace
- 参数:
msg (str)