常见问题解答 (FAQ)#
我们在此收集常见问题的答案。欢迎投稿!
jit
改变了我的函数的行为#
如果您有一个 Python 函数在使用 jax.jit()
后行为发生改变,可能是因为您的函数使用了全局状态或产生了副作用。在下面的代码中,impure_func
使用了全局变量 y
,并且由于 print
而产生副作用。
y = 0
# @jit # Different behavior with jit
def impure_func(x):
print("Inside:", y)
return x + y
for y in range(3):
print("Result:", impure_func(y))
没有 jit
时的输出是
Inside: 0
Result: 0
Inside: 1
Result: 2
Inside: 2
Result: 4
使用 jit
时的输出是
Inside: 0
Result: 0
Result: 1
Result: 2
对于 jax.jit()
,函数会使用 Python 解释器执行一次,此时会发生 Inside
的打印,并观察到 y
的第一个值。然后,函数会被编译并缓存,并使用不同的 x
值执行多次,但 y
的第一个值保持不变。
附加阅读
jit
改变了输出的精确数值#
有时用户会对使用 jit()
包装函数会改变函数的输出感到惊讶。例如
>>> from jax import jit
>>> import jax.numpy as jnp
>>> def f(x):
... return jnp.log(jnp.sqrt(x))
>>> x = jnp.pi
>>> print(f(x))
0.572365
>>> print(jit(f)(x))
0.5723649
输出的这种细微差异源于 XLA 编译器内部的优化:在编译过程中,XLA 有时会重新排列或消除某些操作,以使整个计算更有效率。
在这种情况下,XLA 利用对数函数的性质,将 log(sqrt(x))
替换为 0.5 * log(x)
,这是一个数学上相同的表达式,并且比原始表达式计算起来更有效率。输出的差异来自于浮点数算术只是对实数运算的近似,因此计算相同表达式的不同方式可能会产生细微不同的结果。
有时,XLA 的优化可能导致更显著的差异。考虑以下示例
>>> def f(x):
... return jnp.log(jnp.exp(x))
>>> x = 100.0
>>> print(f(x))
inf
>>> print(jit(f)(x))
100.0
在非 JIT 编译的逐个操作模式下,结果是 inf
,因为 jnp.exp(x)
溢出并返回 inf
。然而,在 JIT 下,XLA 识别出 log
是 exp
的逆运算,并从编译的函数中移除这些操作,只返回输入。在这种情况下,JIT 编译产生了对实际结果更准确的浮点近似。
不幸的是,XLA 代数简化规则的完整列表并未得到充分记录,但如果您熟悉 C++ 并对 XLA 编译器进行的优化类型感到好奇,可以在源代码中查看:algebraic_simplifier.cc。
带有 jit
装饰的函数编译速度非常慢#
如果您的 jit
装饰的函数第一次调用时花费数十秒(甚至更长时间)才能运行,但在再次调用时运行速度很快,这表明 JAX 在跟踪或编译您的代码时花费了很长时间。
这通常是调用您的函数会在 JAX 的内部表示中生成大量代码的迹象,通常是因为它大量使用了 Python 的控制流,例如 for
循环。对于少量的循环迭代,Python 是可以接受的,但如果您需要 *大量* 循环迭代,则应重写代码以利用 JAX 的 结构化控制流原语(如 lax.scan()
)或避免用 jit
包装循环(您仍然可以在循环 *内部* 使用 JIT 编译的函数)。
如果您不确定这是否是问题所在,可以尝试对您的函数运行 jax.make_jaxpr()
。如果输出长达数百或数千行,则可以预期编译速度较慢。
有时,由于代码使用了许多不同形状的数组,因此不清楚如何重写代码以避免 Python 循环。在这种情况下,推荐的解决方案是使用 jax.numpy.where()
等函数在具有固定形状的填充数组上执行计算。
如果您的函数由于其他原因编译缓慢,请在 GitHub 上创建一个 issue。
如何在使用 jit
时使用方法?#
大多数 jax.jit()
的示例都涉及装饰独立的 Python 函数,但装饰类中的方法会带来一些复杂性。例如,考虑以下简单的类,我们在方法上使用了标准的 jit()
注解
>>> import jax.numpy as jnp
>>> from jax import jit
>>> class CustomClass:
... def __init__(self, x: jnp.ndarray, mul: bool):
... self.x = x
... self.mul = mul
...
... @jit # <---- How to do this correctly?
... def calc(self, y):
... if self.mul:
... return self.x * y
... return y
然而,这种方法在尝试调用此方法时会导致错误
>>> c = CustomClass(2, True)
>>> c.calc(3)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
File "<stdin>", line 1, in <module
TypeError: Argument '<CustomClass object at 0x7f7dd4125890>' of type <class 'CustomClass'> is not a valid JAX type.
问题在于函数的第一个参数是 self
,其类型是 CustomClass
,而 JAX 无法处理此类型。在这种情况下,我们有三种基本策略,下面将进行讨论。
策略 1:JIT 编译的辅助函数#
最直接的方法是创建一个位于类外部的辅助函数,该函数可以按常规方式进行 JIT 装饰。例如
>>> from functools import partial
>>> class CustomClass:
... def __init__(self, x: jnp.ndarray, mul: bool):
... self.x = x
... self.mul = mul
...
... def calc(self, y):
... return _calc(self.mul, self.x, y)
>>> @partial(jit, static_argnums=0)
... def _calc(mul, x, y):
... if mul:
... return x * y
... return y
结果将按预期工作
>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6
这种方法的优点是它简单、明确,并且避免了教 JAX 如何处理 CustomClass
类型对象的需要。但是,您可能希望将所有方法逻辑保留在同一位置。
策略 2:将 self
标记为静态#
另一个常见的模式是使用 static_argnums
将 self
参数标记为静态。但这必须谨慎进行,以避免意外结果。您可能会忍不住直接这样做
>>> class CustomClass:
... def __init__(self, x: jnp.ndarray, mul: bool):
... self.x = x
... self.mul = mul
...
... # WARNING: this example is broken, as we'll see below. Don't copy & paste!
... @partial(jit, static_argnums=0)
... def calc(self, y):
... if self.mul:
... return self.x * y
... return y
如果您调用该方法,它将不再引发错误
>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6
但是,有一个陷阱:如果您在第一次方法调用后修改了对象,后续的方法调用可能会返回不正确的结果
>>> c.mul = False
>>> print(c.calc(3)) # Should print 3
6
为什么会这样?当您将对象标记为静态时,它将在 JIT 的内部编译缓存中被用作字典键,这意味着假定其哈希值(即 hash(obj)
)、相等性(即 obj1 == obj2
)和对象身份(即 obj1 is obj2
)具有一致的行为。自定义对象的默认 __hash__
是其对象 ID,因此 JAX 无法知道修改后的对象应触发重新编译。
您可以通过为您的对象定义适当的 __hash__
和 __eq__
方法来部分解决这个问题;例如
>>> class CustomClass:
... def __init__(self, x: jnp.ndarray, mul: bool):
... self.x = x
... self.mul = mul
...
... @partial(jit, static_argnums=0)
... def calc(self, y):
... if self.mul:
... return self.x * y
... return y
...
... def __hash__(self):
... return hash((self.x, self.mul))
...
... def __eq__(self, other):
... return (isinstance(other, CustomClass) and
... (self.x, self.mul) == (other.x, other.mul))
(有关覆盖 __hash__
时要求的信息,请参阅 object.__hash__()
文档)。
这应该可以与 JIT 和其他转换一起正常工作,*前提是您永远不要修改您的对象*。用作哈希键的对象的修改会导致各种微妙的问题,这就是为什么例如可变 Python 容器(例如 dict
、list
)不定义 __hash__
,而它们的不可变对应项(例如 tuple
)则定义。
如果您的类依赖于原地修改(例如在其方法中设置 self.attr = ...
),那么您的对象实际上并不是“静态”的,将其标记为静态可能会导致问题。幸运的是,这种情况还有另一种选择。
策略 3:将 CustomClass
设为 PyTree#
正确 JIT 编译类方法的 सर्वात灵活的方法是注册该类型为自定义 PyTree 对象;请参阅 扩展 pytrees。这允许您精确指定类的哪些组件应被视为静态,哪些应被视为动态。下面是它的样子
>>> class CustomClass:
... def __init__(self, x: jnp.ndarray, mul: bool):
... self.x = x
... self.mul = mul
...
... @jit
... def calc(self, y):
... if self.mul:
... return self.x * y
... return y
...
... def _tree_flatten(self):
... children = (self.x,) # arrays / dynamic values
... aux_data = {'mul': self.mul} # static values
... return (children, aux_data)
...
... @classmethod
... def _tree_unflatten(cls, aux_data, children):
... return cls(*children, **aux_data)
>>> from jax import tree_util
>>> tree_util.register_pytree_node(CustomClass,
... CustomClass._tree_flatten,
... CustomClass._tree_unflatten)
这当然更复杂,但它解决了前面简单方法带来的所有问题
>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6
>>> c.mul = False # mutation is detected
>>> print(c.calc(3))
3
>>> c = CustomClass(jnp.array(2), True) # non-hashable x is supported
>>> print(c.calc(3))
6
只要您的 tree_flatten
和 tree_unflatten
函数能够正确处理类中的所有相关属性,您就可以直接将此类型的对象用作 JIT 编译函数的参数,而无需任何特殊注解。
控制数据和计算在设备上的放置#
首先,让我们了解 JAX 中数据和计算放置的原理。
在 JAX 中,计算遵循数据放置。JAX 数组有两个放置属性:1)数据所在的设备;2)数据是否被 **承诺** 给设备(数据有时被称为“粘性”到设备)。
默认情况下,JAX 数组以未承诺的方式放置在默认设备(默认情况下是第一个 GPU 或 TPU)jax.devices()[0]
上。如果不存在 GPU 或 TPU,jax.devices()[0]
就是 CPU。默认设备可以临时使用 jax.default_device()
上下文管理器覆盖,或者通过将环境变量 JAX_PLATFORMS
或 absl 标志 --jax_platforms
设置为“cpu”、“gpu”或“tpu”(JAX_PLATFORMS
也可以是平台列表,按优先级顺序确定可用平台)来为整个进程设置。
>>> from jax import numpy as jnp
>>> print(jnp.ones(3).devices())
{CudaDevice(id=0)}
涉及未承诺数据的计算将在默认设备上执行,结果也将以未承诺的方式放置在默认设备上。
数据也可以使用带有 device
参数的 jax.device_put()
显式放置在设备上,在这种情况下,数据将 **承诺** 给该设备
>>> import jax
>>> from jax import device_put
>>> arr = device_put(1, jax.devices()[2])
>>> print(arr.devices())
{CudaDevice(id=2)}
涉及某些已承诺输入的计算将在已承诺设备上执行,结果也将以已承诺方式放置在同一设备上。在承诺给多个设备的参数上调用操作将引发错误。
您也可以在不带 device
参数的情况下使用 jax.device_put()
。如果数据已经在某个设备上(已承诺或未承诺),则保持不变。如果数据不在任何设备上——即,它是常规的 Python 或 NumPy 值——它将以未承诺的方式放置在默认设备上。
Jitted 函数的行为与其他基本操作一样——它们将跟随数据,并在调用数据已承诺给多个设备时引发错误。
(在 2021 年 3 月的 PR #6002 之前,数组常量的创建存在一些延迟,因此 jax.device_put(jnp.zeros(...), jax.devices()[1])
或类似的代码实际上会在 jax.devices()[1]
上创建零数组,而不是在默认设备上创建数组然后移动。但此优化已被移除,以简化实现。)
(截至 2020 年 4 月,jax.jit()
有一个 device 参数会影响设备放置。该参数是实验性的,可能会被移除或更改,不建议使用。)
对于一个完整的示例,我们建议阅读 multi_device_test.py 中的 test_computation_follows_data
。
基准测试 JAX 代码#
您刚刚将一个复杂的 NumPy/SciPy 函数移植到 JAX。这真的加快了速度吗?
在测量使用 JAX 的代码速度时,请牢记与 NumPy 的这些重要区别
JAX 代码是即时 (JIT) 编译的。 JAX 中编写的大多数代码都可以以支持 JIT 编译的方式编写,这可以使其运行速度 *快得多*(参见 JIT 或不 JIT)。为了从 JAX 中获得最大性能,您应该在最外层的函数调用上应用
jax.jit()
。请注意,第一次运行 JAX 代码时,它会比较慢,因为它正在被编译。即使您自己的代码中没有使用
jit
,这也是如此,因为 JAX 的内置函数也会被 JIT 编译。JAX 具有异步分派。 这意味着您需要调用
.block_until_ready()
来确保计算已实际发生(参见 异步分派)。JAX 默认只使用 32 位 dtype。 您可能希望在 NumPy 中显式使用 32 位 dtype,或者在 JAX 中启用 64 位 dtype(参见 双精度 (64 位))以进行公平比较。
CPU 和加速器之间传输数据需要时间。 如果您只想测量评估函数需要多长时间,您可能希望首先将数据传输到您想要运行它的设备上(参见 控制数据和计算在设备上的放置)。
以下是如何将所有这些技巧结合起来进行 JAX 与 NumPy 比较的微基准测试示例,利用 IPython 方便的 %time 和 %timeit 魔术命令
import numpy as np
import jax
def f(x): # function we're benchmarking (works in both NumPy & JAX)
return x.T @ (x - x.mean(axis=0))
x_np = np.ones((1000, 1000), dtype=np.float32) # same as JAX default dtype
%timeit f(x_np) # measure NumPy runtime
# measure JAX device transfer time
%time x_jax = jax.device_put(x_np).block_until_ready()
f_jit = jax.jit(f)
%time f_jit(x_jax).block_until_ready() # measure JAX compilation time
%timeit f_jit(x_jax).block_until_ready() # measure JAX runtime
在 Colab 中使用 GPU 运行时,我们看到
NumPy 在 CPU 上每次评估需要 16.2 毫秒
JAX 复制 NumPy 数组到 GPU 需要 1.26 毫秒
JAX 编译函数需要 193 毫秒
JAX 在 GPU 上每次评估需要 485 微秒
在这种情况下,我们看到一旦数据传输完成且函数编译完成,JAX 在 GPU 上的重复评估速度大约快 30 倍。
这是公平的比较吗?也许吧。最终重要的性能是运行完整的应用程序,这不可避免地包含一定量的数据传输和编译。此外,我们小心地选择了足够大的数组(1000x1000)和足够密集的计算(@ 运算符执行矩阵-矩阵乘法),以摊销 JAX/加速器与 NumPy/CPU 之间增加的开销。例如,如果我们在此示例中使用 10x10 的输入,JAX/GPU 的运行速度将比 NumPy/CPU 慢 10 倍(100 微秒 vs 10 微秒)。
JAX 比 NumPy 快吗?#
用户经常试图用这些基准测试来回答的一个问题是 JAX 是否比 NumPy 快;由于这两个包的差异,没有简单的答案。
总的来说
NumPy 操作是立即执行、同步执行,并且仅在 CPU 上执行。
JAX 操作可能立即执行或在编译后执行(如果位于
jit()
内部);它们是异步分派的(参见 异步分派);并且它们可以在 CPU、GPU 或 TPU 上执行,每个设备都具有截然不同且不断发展的性能特征。
这些架构差异使得有意义的 NumPy 和 JAX 之间的直接基准测试比较变得困难。
此外,这些差异导致了两个包之间工程重点的不同:例如,NumPy 在降低单个数组操作的每次调用分派开销方面付出了巨大的努力,因为在 NumPy 的计算模型中,这种开销是无法避免的。另一方面,JAX 有多种方法可以避免分派开销(例如 JIT 编译、异步分派、批处理转换等),因此降低每次调用开销的优先级较低。
考虑到所有这些,总结如下:如果您在 CPU 上对单个数组操作进行微基准测试,您通常可以预期 NumPy 的性能优于 JAX,因为其每次操作的分派开销较低。如果您在 GPU 或 TPU 上运行代码,或者对 CPU 上的更复杂的 JIT 编译操作序列进行基准测试,您通常可以预期 JAX 的性能优于 NumPy。
缓冲区捐赠#
当 JAX 执行计算时,它使用设备上的缓冲区来处理所有输入和输出。如果您知道其中一个输入在计算后不再需要,并且其形状和元素类型与其中一个输出匹配,您可以指定您希望将相应的输入缓冲区捐赠用于保存输出。这将通过捐赠缓冲区的大小来减少执行所需的内存。
如果您有类似以下模式,您可以使用缓冲区捐赠
params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params, state)
您可以将此视为一种对不可变 JAX 数组进行内存高效的函数式更新的方法。在计算的边界内,XLA 可以为您进行此优化,但在 jit/pmap 边界,您需要向 XLA 保证您在调用捐赠函数后不会使用被捐赠的输入缓冲区。
您可以通过使用 donate_argnums 参数来实现这一点,该参数用于 jax.jit()
、jax.pjit()
和 jax.pmap()
函数。此参数是位置参数列表的索引(从 0 开始)的序列
def add(x, y):
return x + y
x = jax.device_put(np.ones((2, 3)))
y = jax.device_put(np.ones((2, 3)))
# Execute `add` with donation of the buffer for `y`. The result has
# the same shape and type as `y`, so it will share its buffer.
z = jax.jit(add, donate_argnums=(1,))(x, y)
请注意,目前当使用关键字参数调用函数时,这不起作用!以下代码将不会捐赠任何缓冲区
params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params=params, state=state)
如果一个被捐赠缓冲区的参数是 pytree,那么它的所有组件的缓冲区都会被捐赠
def add_ones(xs: List[Array]):
return [x + 1 for x in xs]
xs = [jax.device_put(np.ones((2, 3))), jax.device_put(np.ones((3, 4)))]
# Execute `add_ones` with donation of all the buffers for `xs`.
# The outputs have the same shape and type as the elements of `xs`,
# so they will share those buffers.
z = jax.jit(add_ones, donate_argnums=0)(xs)
不允许捐赠在计算中后续使用的缓冲区,JAX 会因为 y 的缓冲区在被捐赠后失效而给出错误
# Donate the buffer for `y`
z = jax.jit(add, donate_argnums=(1,))(x, y)
w = y + 1 # Reuses `y` whose buffer was donated above
# >> RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer
如果您捐赠的缓冲区未被使用,您将收到一个警告,例如,因为捐赠的缓冲区多于可以用于输出的缓冲区
# Execute `add` with donation of the buffers for both `x` and `y`.
# One of those buffers will be used for the result, but the other will
# not be used.
z = jax.jit(add, donate_argnums=(0, 1))(x, y)
# >> UserWarning: Some donated buffers were not usable: f32[2,3]{1,0}
如果不存在形状与捐赠匹配的输出,捐赠也可能未被使用
y = jax.device_put(np.ones((1, 3))) # `y` has different shape than the output
# Execute `add` with donation of the buffer for `y`.
z = jax.jit(add, donate_argnums=(1,))(x, y)
# >> UserWarning: Some donated buffers were not usable: f32[1,3]{1,0}
使用 where
时,梯度包含 NaN#
如果您使用 where
定义了一个函数以避免未定义值,如果您不小心,您可能会得到一个用于反向微分的 NaN
def my_log(x):
return jnp.where(x > 0., jnp.log(x), 0.)
my_log(0.) ==> 0. # Ok
jax.grad(my_log)(0.) ==> NaN
简短的解释是,在 grad
计算过程中,与未定义的 jnp.log(x)
对应的伴随是 NaN
,它会被累加到 jnp.where
的伴随中。编写此类函数的正确方法是确保在部分定义的函数 *内部* 有一个 jnp.where
,以确保伴随始终是有限的
def safe_for_grad_log(x):
return jnp.log(jnp.where(x > 0., x, 1.))
safe_for_grad_log(0.) ==> 0. # Ok
jax.grad(safe_for_grad_log)(0.) ==> 0. # Ok
可能需要除了原始的 jnp.where
之外,还有一个内部的 jnp.where
,例如
def my_log_or_y(x, y):
"""Return log(x) if x > 0 or y"""
return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.)), y)
附加阅读
为什么基于排序顺序的函数的梯度为零?#
如果您定义了一个函数,该函数使用依赖于输入相对顺序的操作(例如 max
、greater
、argsort
等)来处理输入,那么您可能会惊讶地发现梯度处处为零。这是一个例子,我们定义 f(x) 为一个阶跃函数,当 x 为负时返回 0,当 x 为正时返回 1
import jax
import numpy as np
import jax.numpy as jnp
def f(x):
return (x > 0).astype(float)
df = jax.vmap(jax.grad(f))
x = jnp.array([-1.0, -0.5, 0.0, 0.5, 1.0])
print(f"f(x) = {f(x)}")
# f(x) = [0. 0. 0. 1. 1.]
print(f"df(x) = {df(x)}")
# df(x) = [0. 0. 0. 0. 0.]
梯度处处为零的事实乍一看可能会令人困惑:毕竟,输出确实会响应输入而变化,那么梯度怎么会为零呢?然而,在这种情况下,零是正确的答案。
为什么会这样?请记住,微分测量的是给定 x
的无穷小变化时 f
的变化。对于 x=1.0
,f
返回 1.0
。如果我们微调 x
使其稍微增大或减小,这不会改变输出,因此根据定义,grad(f)(1.0)
应该为零。同样的逻辑适用于所有大于零的 f
值:无穷小地扰动输入不会改变输出,因此梯度为零。同样,对于所有小于零的 x
值,输出为零。扰动 x
不会改变此输出,因此梯度为零。这使我们面临 x=0
的棘手情况。当然,如果您向上扰动 x
,它会改变输出,但这存在问题:x
的无穷小变化会导致函数值发生有限变化,这意味着梯度未定义。幸运的是,在这种情况下,我们还有另一种测量梯度的方法:我们向下扰动函数,此时输出不变,因此梯度为零。JAX 和其他自动微分系统倾向于以这种方式处理不连续性:如果正梯度和负梯度不一致,但一个已定义而另一个未定义,则使用已定义的梯度。根据此函数定义,数学上和数值上,此函数的梯度处处为零。
问题源于我们的函数在 x = 0
处存在不连续性。我们这里的 f
本质上是一个 海维赛德阶跃函数,我们可以使用 Sigmoid 函数 作为平滑的替代。当 x 远离零时,sigmoid 近似于海维赛德函数,但在 x = 0
处的间断被一个平滑、可微分的曲线取代。使用 jax.nn.sigmoid()
的结果是,我们得到了一个类似的计算,并具有良好的微分梯度
def g(x):
return jax.nn.sigmoid(x)
dg = jax.vmap(jax.grad(g))
x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])
with np.printoptions(suppress=True, precision=2):
print(f"g(x) = {g(x)}")
# g(x) = [0. 0.27 0.5 0.73 1. ]
print(f"dg(x) = {dg(x)}")
# dg(x) = [0. 0.2 0.25 0.2 0. ]
jax.nn
子模块还具有其他常用排序函数的光滑版本,例如 jax.nn.softmax()
可以替代 jax.numpy.argmax()
的用法,jax.nn.soft_sign()
可以替代 jax.numpy.sign()
的用法,jax.nn.softplus()
或 jax.nn.squareplus()
可以替代 jax.nn.relu()
的用法,等等。
如何将 JAX Tracer 转换为 NumPy 数组?#
在运行时检查转换后的 JAX 函数时,您会发现数组值被 jax.core.Tracer 对象替换了
@jax.jit
def f(x):
print(type(x))
return x
f(jnp.arange(5))
这将打印以下内容
<class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>
一个常见的问题是如何将这样的 tracer 转换回普通的 NumPy 数组。简而言之,**无法将 Tracer 转换为 NumPy 数组**,因为 tracer 是表示具有给定形状和 dtype 的 *所有可能* 值的抽象表示,而 numpy 数组是该抽象类的具体成员。有关 tracer 在 JAX 转换的上下文中如何工作的更多讨论,请参阅 JIT 机制。
将 Tracer 转换回数组的问题通常出现在另一个目标相关的上下文中,例如访问计算中的中间值。例如
如果您想在运行时打印一个跟踪值用于调试目的,可以考虑使用
jax.debug.print()
。如果您想在转换后的 JAX 函数中使用非 JAX 代码,可以考虑使用
jax.pure_callback()
,其示例可在 纯回调示例 中找到。如果您希望在运行时输入或输出数组缓冲区(例如,从文件加载数据,或将数组内容记录到磁盘),可以考虑使用
jax.experimental.io_callback()
,其示例可在 IO 回调示例 中找到。
有关运行时回调及其用法的更多信息,请参阅 JAX 中的外部回调。
为什么某些 CUDA 库加载/初始化失败?#
在解析动态库时,JAX 使用标准的 动态链接器搜索模式。JAX 设置 RPATH
指向 pip 安装的 NVIDIA CUDA 包的 JAX 相关位置,如果已安装则优先使用。如果 ld.so
在其常规搜索路径中找不到您的 CUDA 运行时库,那么您必须在 LD_LIBRARY_PATH
中显式包含这些库的路径。确保您的 CUDA 文件可被发现的最简单方法是安装 nvidia-*-cu12
pip 包,这些包包含在标准的 jax[cuda_12]
安装选项中。
有时,即使您已确保运行时库可被发现,但在加载或初始化它们时仍可能存在一些问题。这类问题的常见原因是运行时 CUDA 库初始化内存不足。这有时会发生,因为 JAX 会为更快的执行预先分配过大的当前可用设备内存块,偶尔导致为运行时 CUDA 库初始化留下的内存不足。
当运行多个 JAX 实例,JAX 与执行自身预分配的 TensorFlow 协同运行时,或在 GPU 被其他进程大量利用的系统上运行 JAX 时,这种情况尤其可能发生。如有疑问,请尝试通过减少预分配来重新运行程序,方法是降低 XLA_PYTHON_CLIENT_MEM_FRACTION
(默认值为 .75
),或将 XLA_PYTHON_CLIENT_PREALLOCATE=false
设置为 true。更多详细信息,请参阅 JAX GPU 内存分配 页面。