常见问题解答 (FAQ)#

我们在此收集常见问题的答案。欢迎投稿!

jit 改变了我的函数的行为#

如果您有一个 Python 函数在使用 jax.jit() 后行为发生改变,可能是因为您的函数使用了全局状态或产生了副作用。在下面的代码中,impure_func 使用了全局变量 y,并且由于 print 而产生副作用。

y = 0

# @jit   # Different behavior with jit
def impure_func(x):
  print("Inside:", y)
  return x + y

for y in range(3):
  print("Result:", impure_func(y))

没有 jit 时的输出是

Inside: 0
Result: 0
Inside: 1
Result: 2
Inside: 2
Result: 4

使用 jit 时的输出是

Inside: 0
Result: 0
Result: 1
Result: 2

对于 jax.jit(),函数会使用 Python 解释器执行一次,此时会发生 Inside 的打印,并观察到 y 的第一个值。然后,函数会被编译并缓存,并使用不同的 x 值执行多次,但 y 的第一个值保持不变。

附加阅读

jit 改变了输出的精确数值#

有时用户会对使用 jit() 包装函数会改变函数的输出感到惊讶。例如

>>> from jax import jit
>>> import jax.numpy as jnp
>>> def f(x):
...   return jnp.log(jnp.sqrt(x))
>>> x = jnp.pi
>>> print(f(x))
0.572365
>>> print(jit(f)(x))
0.5723649

输出的这种细微差异源于 XLA 编译器内部的优化:在编译过程中,XLA 有时会重新排列或消除某些操作,以使整个计算更有效率。

在这种情况下,XLA 利用对数函数的性质,将 log(sqrt(x)) 替换为 0.5 * log(x),这是一个数学上相同的表达式,并且比原始表达式计算起来更有效率。输出的差异来自于浮点数算术只是对实数运算的近似,因此计算相同表达式的不同方式可能会产生细微不同的结果。

有时,XLA 的优化可能导致更显著的差异。考虑以下示例

>>> def f(x):
...   return jnp.log(jnp.exp(x))
>>> x = 100.0
>>> print(f(x))
inf
>>> print(jit(f)(x))
100.0

在非 JIT 编译的逐个操作模式下,结果是 inf,因为 jnp.exp(x) 溢出并返回 inf。然而,在 JIT 下,XLA 识别出 logexp 的逆运算,并从编译的函数中移除这些操作,只返回输入。在这种情况下,JIT 编译产生了对实际结果更准确的浮点近似。

不幸的是,XLA 代数简化规则的完整列表并未得到充分记录,但如果您熟悉 C++ 并对 XLA 编译器进行的优化类型感到好奇,可以在源代码中查看:algebraic_simplifier.cc

带有 jit 装饰的函数编译速度非常慢#

如果您的 jit 装饰的函数第一次调用时花费数十秒(甚至更长时间)才能运行,但在再次调用时运行速度很快,这表明 JAX 在跟踪或编译您的代码时花费了很长时间。

这通常是调用您的函数会在 JAX 的内部表示中生成大量代码的迹象,通常是因为它大量使用了 Python 的控制流,例如 for 循环。对于少量的循环迭代,Python 是可以接受的,但如果您需要 *大量* 循环迭代,则应重写代码以利用 JAX 的 结构化控制流原语(如 lax.scan())或避免用 jit 包装循环(您仍然可以在循环 *内部* 使用 JIT 编译的函数)。

如果您不确定这是否是问题所在,可以尝试对您的函数运行 jax.make_jaxpr()。如果输出长达数百或数千行,则可以预期编译速度较慢。

有时,由于代码使用了许多不同形状的数组,因此不清楚如何重写代码以避免 Python 循环。在这种情况下,推荐的解决方案是使用 jax.numpy.where() 等函数在具有固定形状的填充数组上执行计算。

如果您的函数由于其他原因编译缓慢,请在 GitHub 上创建一个 issue。

如何在使用 jit 时使用方法?#

大多数 jax.jit() 的示例都涉及装饰独立的 Python 函数,但装饰类中的方法会带来一些复杂性。例如,考虑以下简单的类,我们在方法上使用了标准的 jit() 注解

>>> import jax.numpy as jnp
>>> from jax import jit

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   @jit  # <---- How to do this correctly?
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y

然而,这种方法在尝试调用此方法时会导致错误

>>> c = CustomClass(2, True)
>>> c.calc(3)  
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
  File "<stdin>", line 1, in <module
TypeError: Argument '<CustomClass object at 0x7f7dd4125890>' of type <class 'CustomClass'> is not a valid JAX type.

问题在于函数的第一个参数是 self,其类型是 CustomClass,而 JAX 无法处理此类型。在这种情况下,我们有三种基本策略,下面将进行讨论。

策略 1:JIT 编译的辅助函数#

最直接的方法是创建一个位于类外部的辅助函数,该函数可以按常规方式进行 JIT 装饰。例如

>>> from functools import partial

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   def calc(self, y):
...     return _calc(self.mul, self.x, y)

>>> @partial(jit, static_argnums=0)
... def _calc(mul, x, y):
...   if mul:
...     return x * y
...   return y

结果将按预期工作

>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6

这种方法的优点是它简单、明确,并且避免了教 JAX 如何处理 CustomClass 类型对象的需要。但是,您可能希望将所有方法逻辑保留在同一位置。

策略 2:将 self 标记为静态#

另一个常见的模式是使用 static_argnumsself 参数标记为静态。但这必须谨慎进行,以避免意外结果。您可能会忍不住直接这样做

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   # WARNING: this example is broken, as we'll see below. Don't copy & paste!
...   @partial(jit, static_argnums=0)
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y

如果您调用该方法,它将不再引发错误

>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6

但是,有一个陷阱:如果您在第一次方法调用后修改了对象,后续的方法调用可能会返回不正确的结果

>>> c.mul = False
>>> print(c.calc(3))  # Should print 3
6

为什么会这样?当您将对象标记为静态时,它将在 JIT 的内部编译缓存中被用作字典键,这意味着假定其哈希值(即 hash(obj))、相等性(即 obj1 == obj2)和对象身份(即 obj1 is obj2)具有一致的行为。自定义对象的默认 __hash__ 是其对象 ID,因此 JAX 无法知道修改后的对象应触发重新编译。

您可以通过为您的对象定义适当的 __hash____eq__ 方法来部分解决这个问题;例如

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   @partial(jit, static_argnums=0)
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y
...
...   def __hash__(self):
...     return hash((self.x, self.mul))
...
...   def __eq__(self, other):
...     return (isinstance(other, CustomClass) and
...             (self.x, self.mul) == (other.x, other.mul))

(有关覆盖 __hash__ 时要求的信息,请参阅 object.__hash__() 文档)。

这应该可以与 JIT 和其他转换一起正常工作,*前提是您永远不要修改您的对象*。用作哈希键的对象的修改会导致各种微妙的问题,这就是为什么例如可变 Python 容器(例如 dictlist)不定义 __hash__,而它们的不可变对应项(例如 tuple)则定义。

如果您的类依赖于原地修改(例如在其方法中设置 self.attr = ...),那么您的对象实际上并不是“静态”的,将其标记为静态可能会导致问题。幸运的是,这种情况还有另一种选择。

策略 3:将 CustomClass 设为 PyTree#

正确 JIT 编译类方法的 सर्वात灵活的方法是注册该类型为自定义 PyTree 对象;请参阅 扩展 pytrees。这允许您精确指定类的哪些组件应被视为静态,哪些应被视为动态。下面是它的样子

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   @jit
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y
...
...   def _tree_flatten(self):
...     children = (self.x,)  # arrays / dynamic values
...     aux_data = {'mul': self.mul}  # static values
...     return (children, aux_data)
...
...   @classmethod
...   def _tree_unflatten(cls, aux_data, children):
...     return cls(*children, **aux_data)

>>> from jax import tree_util
>>> tree_util.register_pytree_node(CustomClass,
...                                CustomClass._tree_flatten,
...                                CustomClass._tree_unflatten)

这当然更复杂,但它解决了前面简单方法带来的所有问题

>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6

>>> c.mul = False  # mutation is detected
>>> print(c.calc(3))
3

>>> c = CustomClass(jnp.array(2), True)  # non-hashable x is supported
>>> print(c.calc(3))
6

只要您的 tree_flattentree_unflatten 函数能够正确处理类中的所有相关属性,您就可以直接将此类型的对象用作 JIT 编译函数的参数,而无需任何特殊注解。

控制数据和计算在设备上的放置#

首先,让我们了解 JAX 中数据和计算放置的原理。

在 JAX 中,计算遵循数据放置。JAX 数组有两个放置属性:1)数据所在的设备;2)数据是否被 **承诺** 给设备(数据有时被称为“粘性”到设备)。

默认情况下,JAX 数组以未承诺的方式放置在默认设备(默认情况下是第一个 GPU 或 TPU)jax.devices()[0] 上。如果不存在 GPU 或 TPU,jax.devices()[0] 就是 CPU。默认设备可以临时使用 jax.default_device() 上下文管理器覆盖,或者通过将环境变量 JAX_PLATFORMS 或 absl 标志 --jax_platforms 设置为“cpu”、“gpu”或“tpu”(JAX_PLATFORMS 也可以是平台列表,按优先级顺序确定可用平台)来为整个进程设置。

>>> from jax import numpy as jnp
>>> print(jnp.ones(3).devices())  
{CudaDevice(id=0)}

涉及未承诺数据的计算将在默认设备上执行,结果也将以未承诺的方式放置在默认设备上。

数据也可以使用带有 device 参数的 jax.device_put() 显式放置在设备上,在这种情况下,数据将 **承诺** 给该设备

>>> import jax
>>> from jax import device_put
>>> arr = device_put(1, jax.devices()[2])  
>>> print(arr.devices())  
{CudaDevice(id=2)}

涉及某些已承诺输入的计算将在已承诺设备上执行,结果也将以已承诺方式放置在同一设备上。在承诺给多个设备的参数上调用操作将引发错误。

您也可以在不带 device 参数的情况下使用 jax.device_put()。如果数据已经在某个设备上(已承诺或未承诺),则保持不变。如果数据不在任何设备上——即,它是常规的 Python 或 NumPy 值——它将以未承诺的方式放置在默认设备上。

Jitted 函数的行为与其他基本操作一样——它们将跟随数据,并在调用数据已承诺给多个设备时引发错误。

(在 2021 年 3 月的 PR #6002 之前,数组常量的创建存在一些延迟,因此 jax.device_put(jnp.zeros(...), jax.devices()[1]) 或类似的代码实际上会在 jax.devices()[1] 上创建零数组,而不是在默认设备上创建数组然后移动。但此优化已被移除,以简化实现。)

(截至 2020 年 4 月,jax.jit() 有一个 device 参数会影响设备放置。该参数是实验性的,可能会被移除或更改,不建议使用。)

对于一个完整的示例,我们建议阅读 multi_device_test.py 中的 test_computation_follows_data

基准测试 JAX 代码#

您刚刚将一个复杂的 NumPy/SciPy 函数移植到 JAX。这真的加快了速度吗?

在测量使用 JAX 的代码速度时,请牢记与 NumPy 的这些重要区别

  1. JAX 代码是即时 (JIT) 编译的。 JAX 中编写的大多数代码都可以以支持 JIT 编译的方式编写,这可以使其运行速度 *快得多*(参见 JIT 或不 JIT)。为了从 JAX 中获得最大性能,您应该在最外层的函数调用上应用 jax.jit()

    请注意,第一次运行 JAX 代码时,它会比较慢,因为它正在被编译。即使您自己的代码中没有使用 jit,这也是如此,因为 JAX 的内置函数也会被 JIT 编译。

  2. JAX 具有异步分派。 这意味着您需要调用 .block_until_ready() 来确保计算已实际发生(参见 异步分派)。

  3. JAX 默认只使用 32 位 dtype。 您可能希望在 NumPy 中显式使用 32 位 dtype,或者在 JAX 中启用 64 位 dtype(参见 双精度 (64 位))以进行公平比较。

  4. CPU 和加速器之间传输数据需要时间。 如果您只想测量评估函数需要多长时间,您可能希望首先将数据传输到您想要运行它的设备上(参见 控制数据和计算在设备上的放置)。

以下是如何将所有这些技巧结合起来进行 JAX 与 NumPy 比较的微基准测试示例,利用 IPython 方便的 %time 和 %timeit 魔术命令

import numpy as np
import jax

def f(x):  # function we're benchmarking (works in both NumPy & JAX)
  return x.T @ (x - x.mean(axis=0))

x_np = np.ones((1000, 1000), dtype=np.float32)  # same as JAX default dtype
%timeit f(x_np)  # measure NumPy runtime

# measure JAX device transfer time
%time x_jax = jax.device_put(x_np).block_until_ready()

f_jit = jax.jit(f)
%time f_jit(x_jax).block_until_ready()  # measure JAX compilation time
%timeit f_jit(x_jax).block_until_ready()  # measure JAX runtime

在 Colab 中使用 GPU 运行时,我们看到

  • NumPy 在 CPU 上每次评估需要 16.2 毫秒

  • JAX 复制 NumPy 数组到 GPU 需要 1.26 毫秒

  • JAX 编译函数需要 193 毫秒

  • JAX 在 GPU 上每次评估需要 485 微秒

在这种情况下,我们看到一旦数据传输完成且函数编译完成,JAX 在 GPU 上的重复评估速度大约快 30 倍。

这是公平的比较吗?也许吧。最终重要的性能是运行完整的应用程序,这不可避免地包含一定量的数据传输和编译。此外,我们小心地选择了足够大的数组(1000x1000)和足够密集的计算(@ 运算符执行矩阵-矩阵乘法),以摊销 JAX/加速器与 NumPy/CPU 之间增加的开销。例如,如果我们在此示例中使用 10x10 的输入,JAX/GPU 的运行速度将比 NumPy/CPU 慢 10 倍(100 微秒 vs 10 微秒)。

JAX 比 NumPy 快吗?#

用户经常试图用这些基准测试来回答的一个问题是 JAX 是否比 NumPy 快;由于这两个包的差异,没有简单的答案。

总的来说

  • NumPy 操作是立即执行、同步执行,并且仅在 CPU 上执行。

  • JAX 操作可能立即执行或在编译后执行(如果位于 jit() 内部);它们是异步分派的(参见 异步分派);并且它们可以在 CPU、GPU 或 TPU 上执行,每个设备都具有截然不同且不断发展的性能特征。

这些架构差异使得有意义的 NumPy 和 JAX 之间的直接基准测试比较变得困难。

此外,这些差异导致了两个包之间工程重点的不同:例如,NumPy 在降低单个数组操作的每次调用分派开销方面付出了巨大的努力,因为在 NumPy 的计算模型中,这种开销是无法避免的。另一方面,JAX 有多种方法可以避免分派开销(例如 JIT 编译、异步分派、批处理转换等),因此降低每次调用开销的优先级较低。

考虑到所有这些,总结如下:如果您在 CPU 上对单个数组操作进行微基准测试,您通常可以预期 NumPy 的性能优于 JAX,因为其每次操作的分派开销较低。如果您在 GPU 或 TPU 上运行代码,或者对 CPU 上的更复杂的 JIT 编译操作序列进行基准测试,您通常可以预期 JAX 的性能优于 NumPy。

缓冲区捐赠#

当 JAX 执行计算时,它使用设备上的缓冲区来处理所有输入和输出。如果您知道其中一个输入在计算后不再需要,并且其形状和元素类型与其中一个输出匹配,您可以指定您希望将相应的输入缓冲区捐赠用于保存输出。这将通过捐赠缓冲区的大小来减少执行所需的内存。

如果您有类似以下模式,您可以使用缓冲区捐赠

params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params, state)

您可以将此视为一种对不可变 JAX 数组进行内存高效的函数式更新的方法。在计算的边界内,XLA 可以为您进行此优化,但在 jit/pmap 边界,您需要向 XLA 保证您在调用捐赠函数后不会使用被捐赠的输入缓冲区。

您可以通过使用 donate_argnums 参数来实现这一点,该参数用于 jax.jit()jax.pjit()jax.pmap() 函数。此参数是位置参数列表的索引(从 0 开始)的序列

def add(x, y):
  return x + y

x = jax.device_put(np.ones((2, 3)))
y = jax.device_put(np.ones((2, 3)))
# Execute `add` with donation of the buffer for `y`. The result has
# the same shape and type as `y`, so it will share its buffer.
z = jax.jit(add, donate_argnums=(1,))(x, y)

请注意,目前当使用关键字参数调用函数时,这不起作用!以下代码将不会捐赠任何缓冲区

params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params=params, state=state)

如果一个被捐赠缓冲区的参数是 pytree,那么它的所有组件的缓冲区都会被捐赠

def add_ones(xs: List[Array]):
  return [x + 1 for x in xs]

xs = [jax.device_put(np.ones((2, 3))), jax.device_put(np.ones((3, 4)))]
# Execute `add_ones` with donation of all the buffers for `xs`.
# The outputs have the same shape and type as the elements of `xs`,
# so they will share those buffers.
z = jax.jit(add_ones, donate_argnums=0)(xs)

不允许捐赠在计算中后续使用的缓冲区,JAX 会因为 y 的缓冲区在被捐赠后失效而给出错误

# Donate the buffer for `y`
z = jax.jit(add, donate_argnums=(1,))(x, y)
w = y + 1  # Reuses `y` whose buffer was donated above
# >> RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer

如果您捐赠的缓冲区未被使用,您将收到一个警告,例如,因为捐赠的缓冲区多于可以用于输出的缓冲区

# Execute `add` with donation of the buffers for both `x` and `y`.
# One of those buffers will be used for the result, but the other will
# not be used.
z = jax.jit(add, donate_argnums=(0, 1))(x, y)
# >> UserWarning: Some donated buffers were not usable: f32[2,3]{1,0}

如果不存在形状与捐赠匹配的输出,捐赠也可能未被使用

y = jax.device_put(np.ones((1, 3)))  # `y` has different shape than the output
# Execute `add` with donation of the buffer for `y`.
z = jax.jit(add, donate_argnums=(1,))(x, y)
# >> UserWarning: Some donated buffers were not usable: f32[1,3]{1,0}

使用 where 时,梯度包含 NaN#

如果您使用 where 定义了一个函数以避免未定义值,如果您不小心,您可能会得到一个用于反向微分的 NaN

def my_log(x):
  return jnp.where(x > 0., jnp.log(x), 0.)

my_log(0.) ==> 0.  # Ok
jax.grad(my_log)(0.)  ==> NaN

简短的解释是,在 grad 计算过程中,与未定义的 jnp.log(x) 对应的伴随是 NaN,它会被累加到 jnp.where 的伴随中。编写此类函数的正确方法是确保在部分定义的函数 *内部* 有一个 jnp.where,以确保伴随始终是有限的

def safe_for_grad_log(x):
  return jnp.log(jnp.where(x > 0., x, 1.))

safe_for_grad_log(0.) ==> 0.  # Ok
jax.grad(safe_for_grad_log)(0.)  ==> 0.  # Ok

可能需要除了原始的 jnp.where 之外,还有一个内部的 jnp.where,例如

def my_log_or_y(x, y):
  """Return log(x) if x > 0 or y"""
  return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.)), y)

附加阅读

为什么基于排序顺序的函数的梯度为零?#

如果您定义了一个函数,该函数使用依赖于输入相对顺序的操作(例如 maxgreaterargsort 等)来处理输入,那么您可能会惊讶地发现梯度处处为零。这是一个例子,我们定义 f(x) 为一个阶跃函数,当 x 为负时返回 0,当 x 为正时返回 1

import jax
import numpy as np
import jax.numpy as jnp

def f(x):
  return (x > 0).astype(float)

df = jax.vmap(jax.grad(f))

x = jnp.array([-1.0, -0.5, 0.0, 0.5, 1.0])

print(f"f(x)  = {f(x)}")
# f(x)  = [0. 0. 0. 1. 1.]

print(f"df(x) = {df(x)}")
# df(x) = [0. 0. 0. 0. 0.]

梯度处处为零的事实乍一看可能会令人困惑:毕竟,输出确实会响应输入而变化,那么梯度怎么会为零呢?然而,在这种情况下,零是正确的答案。

为什么会这样?请记住,微分测量的是给定 x 的无穷小变化时 f 的变化。对于 x=1.0f 返回 1.0。如果我们微调 x 使其稍微增大或减小,这不会改变输出,因此根据定义,grad(f)(1.0) 应该为零。同样的逻辑适用于所有大于零的 f 值:无穷小地扰动输入不会改变输出,因此梯度为零。同样,对于所有小于零的 x 值,输出为零。扰动 x 不会改变此输出,因此梯度为零。这使我们面临 x=0 的棘手情况。当然,如果您向上扰动 x,它会改变输出,但这存在问题:x 的无穷小变化会导致函数值发生有限变化,这意味着梯度未定义。幸运的是,在这种情况下,我们还有另一种测量梯度的方法:我们向下扰动函数,此时输出不变,因此梯度为零。JAX 和其他自动微分系统倾向于以这种方式处理不连续性:如果正梯度和负梯度不一致,但一个已定义而另一个未定义,则使用已定义的梯度。根据此函数定义,数学上和数值上,此函数的梯度处处为零。

问题源于我们的函数在 x = 0 处存在不连续性。我们这里的 f 本质上是一个 海维赛德阶跃函数,我们可以使用 Sigmoid 函数 作为平滑的替代。当 x 远离零时,sigmoid 近似于海维赛德函数,但在 x = 0 处的间断被一个平滑、可微分的曲线取代。使用 jax.nn.sigmoid() 的结果是,我们得到了一个类似的计算,并具有良好的微分梯度

def g(x):
  return jax.nn.sigmoid(x)

dg = jax.vmap(jax.grad(g))

x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])

with np.printoptions(suppress=True, precision=2):
  print(f"g(x)  = {g(x)}")
  # g(x)  = [0.   0.27 0.5  0.73 1.  ]

  print(f"dg(x) = {dg(x)}")
  # dg(x) = [0.   0.2  0.25 0.2  0.  ]

jax.nn 子模块还具有其他常用排序函数的光滑版本,例如 jax.nn.softmax() 可以替代 jax.numpy.argmax() 的用法,jax.nn.soft_sign() 可以替代 jax.numpy.sign() 的用法,jax.nn.softplus()jax.nn.squareplus() 可以替代 jax.nn.relu() 的用法,等等。

如何将 JAX Tracer 转换为 NumPy 数组?#

在运行时检查转换后的 JAX 函数时,您会发现数组值被 jax.core.Tracer 对象替换了

@jax.jit
def f(x):
  print(type(x))
  return x

f(jnp.arange(5))

这将打印以下内容

<class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>

一个常见的问题是如何将这样的 tracer 转换回普通的 NumPy 数组。简而言之,**无法将 Tracer 转换为 NumPy 数组**,因为 tracer 是表示具有给定形状和 dtype 的 *所有可能* 值的抽象表示,而 numpy 数组是该抽象类的具体成员。有关 tracer 在 JAX 转换的上下文中如何工作的更多讨论,请参阅 JIT 机制

将 Tracer 转换回数组的问题通常出现在另一个目标相关的上下文中,例如访问计算中的中间值。例如

  • 如果您想在运行时打印一个跟踪值用于调试目的,可以考虑使用 jax.debug.print()

  • 如果您想在转换后的 JAX 函数中使用非 JAX 代码,可以考虑使用 jax.pure_callback(),其示例可在 纯回调示例 中找到。

  • 如果您希望在运行时输入或输出数组缓冲区(例如,从文件加载数据,或将数组内容记录到磁盘),可以考虑使用 jax.experimental.io_callback(),其示例可在 IO 回调示例 中找到。

有关运行时回调及其用法的更多信息,请参阅 JAX 中的外部回调

为什么某些 CUDA 库加载/初始化失败?#

在解析动态库时,JAX 使用标准的 动态链接器搜索模式。JAX 设置 RPATH 指向 pip 安装的 NVIDIA CUDA 包的 JAX 相关位置,如果已安装则优先使用。如果 ld.so 在其常规搜索路径中找不到您的 CUDA 运行时库,那么您必须在 LD_LIBRARY_PATH 中显式包含这些库的路径。确保您的 CUDA 文件可被发现的最简单方法是安装 nvidia-*-cu12 pip 包,这些包包含在标准的 jax[cuda_12] 安装选项中。

有时,即使您已确保运行时库可被发现,但在加载或初始化它们时仍可能存在一些问题。这类问题的常见原因是运行时 CUDA 库初始化内存不足。这有时会发生,因为 JAX 会为更快的执行预先分配过大的当前可用设备内存块,偶尔导致为运行时 CUDA 库初始化留下的内存不足。

当运行多个 JAX 实例,JAX 与执行自身预分配的 TensorFlow 协同运行时,或在 GPU 被其他进程大量利用的系统上运行 JAX 时,这种情况尤其可能发生。如有疑问,请尝试通过减少预分配来重新运行程序,方法是降低 XLA_PYTHON_CLIENT_MEM_FRACTION(默认值为 .75),或将 XLA_PYTHON_CLIENT_PREALLOCATE=false 设置为 true。更多详细信息,请参阅 JAX GPU 内存分配 页面。