常用问题解答 (FAQ)#

我们在此处收集常用问题的答案。欢迎贡献!

jit 改变了我的函数的行为#

如果您有一个 Python 函数在使用 jax.jit() 后更改了行为,可能是您的函数使用了全局状态,或者有副作用。在以下代码中,impure_func 使用了全局 y 并且由于 print 而具有副作用

y = 0

# @jit   # Different behavior with jit
def impure_func(x):
  print("Inside:", y)
  return x + y

for y in range(3):
  print("Result:", impure_func(y))

不使用 jit 的输出是

Inside: 0
Result: 0
Inside: 1
Result: 2
Inside: 2
Result: 4

使用 jit 的输出是

Inside: 0
Result: 0
Result: 1
Result: 2

对于 jax.jit(),该函数使用 Python 解释器执行一次,此时会发生 Inside 打印,并观察到 y 的第一个值。然后,该函数被编译并缓存,并使用不同的 x 值执行多次,但使用相同的 y 的第一个值。

延伸阅读

jit 改变了输出的精确数值#

有时用户会对使用 jit() 包装函数会改变函数输出这一事实感到惊讶。例如

>>> from jax import jit
>>> import jax.numpy as jnp
>>> def f(x):
...   return jnp.log(jnp.sqrt(x))
>>> x = jnp.pi
>>> print(f(x))
0.572365
>>> print(jit(f)(x))
0.5723649

输出的这种细微差异来自于 XLA 编译器内部的优化:在编译期间,XLA 有时会重新排列或省略某些操作,以使整体计算更有效率。

在这种情况下,XLA 利用对数的性质将 log(sqrt(x)) 替换为 0.5 * log(x),这是一个数学上相同的表达式,可以比原始表达式更有效地计算。输出的差异来自于浮点运算只是对实数数学的近似,因此计算相同表达式的不同方法可能会有细微不同的结果。

其他时候,XLA 的优化可能会导致更显著的差异。考虑以下示例

>>> def f(x):
...   return jnp.log(jnp.exp(x))
>>> x = 100.0
>>> print(f(x))
inf
>>> print(jit(f)(x))
100.0

在非 JIT 编译的逐操作模式下,结果是 inf,因为 jnp.exp(x) 溢出并返回 inf。但是,在 JIT 下,XLA 识别到 logexp 的逆运算,并从编译后的函数中移除这些操作,只是返回输入。在这种情况下,JIT 编译产生更精确的实数结果的浮点近似值。

不幸的是,XLA 的代数简化的完整列表没有得到很好的文档记录,但是如果您熟悉 C++ 并且对 XLA 编译器进行的优化类型感到好奇,您可以在源代码中查看它们:algebraic_simplifier.cc

jit 装饰的函数编译速度非常慢#

如果您的 jit 装饰的函数在您第一次调用时需要数十秒(或更长时间)才能运行,但在再次调用时执行速度很快,则 JAX 需要很长时间来跟踪或编译您的代码。

这通常表明调用您的函数会在 JAX 的内部表示中生成大量代码,通常是因为它大量使用了 Python 控制流,例如 for 循环。对于少量的循环迭代,Python 还可以,但是如果您需要大量循环迭代,则应该重写代码以使用 JAX 的结构化控制流原语(例如 lax.scan()),或者避免用 jit 包装循环(您仍然可以在循环内部使用 jit 装饰的函数)。

如果您不确定这是否是问题所在,您可以尝试在您的函数上运行 jax.make_jaxpr()。如果输出有数百或数千行长,您可以预期编译速度会很慢。

有时,不清楚如何重写代码以避免 Python 循环,因为您的代码使用了许多具有不同形状的数组。在这种情况下,推荐的解决方案是使用诸如 jax.numpy.where() 之类的函数,在具有固定形状的填充数组上进行计算。

如果您的函数由于其他原因编译速度很慢,请在 GitHub 上打开一个 issue。

如何将 jit 与方法一起使用?#

大多数 jax.jit() 的示例都与装饰独立的 Python 函数有关,但是装饰类中的方法会引入一些复杂性。例如,考虑以下简单类,我们在一个方法上使用了标准的 jit() 注解

>>> import jax.numpy as jnp
>>> from jax import jit

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   @jit  # <---- How to do this correctly?
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y

但是,当您尝试调用此方法时,此方法将导致错误

>>> c = CustomClass(2, True)
>>> c.calc(3)  
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
  File "<stdin>", line 1, in <module
TypeError: Argument '<CustomClass object at 0x7f7dd4125890>' of type <class 'CustomClass'> is not a valid JAX type.

问题在于该函数的第一个参数是 self,其类型为 CustomClass,而 JAX 不知道如何处理此类型。在这种情况下,我们可以使用三种基本策略,我们将在下面讨论它们。

策略 1:JIT 编译的辅助函数#

最直接的方法是创建一个类外部的辅助函数,该函数可以像往常一样进行 JIT 装饰。例如

>>> from functools import partial

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   def calc(self, y):
...     return _calc(self.mul, self.x, y)

>>> @partial(jit, static_argnums=0)
... def _calc(mul, x, y):
...   if mul:
...     return x * y
...   return y

结果将按预期工作

>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6

这种方法的好处是它简单、明确,并且避免了教 JAX 如何处理 CustomClass 类型的对象的需求。但是,您可能希望将所有方法逻辑都放在同一个位置。

策略 2:将 self 标记为静态#

另一种常见的模式是使用 static_argnumsself 参数标记为静态。但是,必须小心地执行此操作以避免意外结果。您可能会想简单地这样做

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   # WARNING: this example is broken, as we'll see below. Don't copy & paste!
...   @partial(jit, static_argnums=0)
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y

如果您调用该方法,它将不再引发错误

>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6

但是,这里有一个陷阱:如果您在第一次方法调用后改变对象,则后续方法调用可能会返回不正确的结果

>>> c.mul = False
>>> print(c.calc(3))  # Should print 3
6

这是为什么呢?当您将对象标记为静态时,它实际上将用作 JIT 内部编译缓存中的字典键,这意味着它的哈希值(即 hash(obj))相等性(即 obj1 == obj2)和对象标识(即 obj1 is obj2)将被假定为具有一致的行为。自定义对象的默认 __hash__ 是其对象 ID,因此 JAX 无法知道已改变的对象是否应触发重新编译。

您可以通过为对象定义适当的 __hash____eq__ 方法来部分解决此问题;例如

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   @partial(jit, static_argnums=0)
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y
...
...   def __hash__(self):
...     return hash((self.x, self.mul))
...
...   def __eq__(self, other):
...     return (isinstance(other, CustomClass) and
...             (self.x, self.mul) == (other.x, other.mul))

(有关覆盖 __hash__ 时的要求的更多讨论,请参见 object.__hash__() 文档)。

只要您永远不改变您的对象,这应该与 JIT 和其他转换一起正常工作。用作哈希键的对象的突变会导致一些微妙的问题,这就是为什么例如可变 Python 容器(例如 dictlist)不定义 __hash__,而它们的不可变对应物(例如 tuple)则定义了 __hash__

如果您的类依赖于就地突变(例如在其方法中设置 self.attr = ...),那么您的对象实际上不是“静态”的,将其标记为静态可能会导致问题。幸运的是,对于这种情况还有另一种选择。

策略 3:使 CustomClass 成为 PyTree#

正确 JIT 编译类方法的最灵活方法是将该类型注册为自定义 PyTree 对象;请参阅 扩展 pytrees。这使您可以精确指定类的哪些组件应被视为静态,哪些应被视为动态。这是它的样子

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   @jit
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y
...
...   def _tree_flatten(self):
...     children = (self.x,)  # arrays / dynamic values
...     aux_data = {'mul': self.mul}  # static values
...     return (children, aux_data)
...
...   @classmethod
...   def _tree_unflatten(cls, aux_data, children):
...     return cls(*children, **aux_data)

>>> from jax import tree_util
>>> tree_util.register_pytree_node(CustomClass,
...                                CustomClass._tree_flatten,
...                                CustomClass._tree_unflatten)

这当然更复杂,但是它解决了与上面使用的更简单方法相关的所有问题

>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6

>>> c.mul = False  # mutation is detected
>>> print(c.calc(3))
3

>>> c = CustomClass(jnp.array(2), True)  # non-hashable x is supported
>>> print(c.calc(3))
6

只要您的 tree_flattentree_unflatten 函数正确处理类中的所有相关属性,您就应该能够直接将此类型的对象用作 JIT 编译函数的参数,而无需任何特殊注解。

控制设备上的数据和计算放置#

让我们首先看一下 JAX 中数据和计算放置的原则。

在 JAX 中,计算跟随数据放置。JAX 数组具有两个放置属性:1) 数据驻留的设备;2) 是否提交到设备(数据有时被称为粘性到设备)。

默认情况下,JAX 数组以未提交状态放置在默认设备 (jax.devices()[0]) 上,默认情况下这是第一个 GPU 或 TPU。如果不存在 GPU 或 TPU,则 jax.devices()[0] 是 CPU。可以使用 jax.default_device() 上下文管理器临时覆盖默认设备,或者通过将环境变量 JAX_PLATFORMS 或 absl 标志 --jax_platforms 设置为 “cpu”、“gpu” 或 “tpu” 来为整个进程设置默认设备(JAX_PLATFORMS 也可以是平台列表,它决定了哪些平台按优先级顺序可用)。

>>> from jax import numpy as jnp
>>> print(jnp.ones(3).devices())  
{CudaDevice(id=0)}

涉及未提交数据的计算在默认设备上执行,并且结果在默认设备上未提交。

也可以使用带有 device 参数的 jax.device_put() 将数据显式放置在设备上,在这种情况下,数据将提交到该设备

>>> import jax
>>> from jax import device_put
>>> arr = device_put(1, jax.devices()[2])  
>>> print(arr.devices())  
{CudaDevice(id=2)}

涉及某些已提交输入的计算将在已提交设备上发生,并且结果将提交到同一设备。在提交到多个设备的参数上调用操作将引发错误。

您还可以使用不带 device 参数的 jax.device_put()。如果数据已在设备上(已提交或未提交),则保持原样。如果数据不在任何设备上——也就是说,它是常规 Python 或 NumPy 值——则将其以未提交状态放置在默认设备上。

Jitted 函数的行为类似于任何其他原始操作——它们将跟随数据,并且如果在提交到多个设备的数据上调用,将显示错误。

(在 2021 年 3 月的 PR #6002 之前,数组常量的创建存在一些延迟,因此 jax.device_put(jnp.zeros(...), jax.devices()[1]) 或类似操作实际上会在 jax.devices()[1] 上创建零数组,而不是在默认设备上创建数组然后移动它。但是,此优化已被删除,以便简化实现。)

(截至 2020 年 4 月,jax.jit() 具有影响设备放置的 device 参数。该参数是实验性的,可能会被删除或更改,不建议使用。)

对于一个完整的示例,我们建议通读 test_computation_follows_data in multi_device_test.py

基准测试 JAX 代码#

您刚刚将一个棘手的函数从 NumPy/SciPy 移植到 JAX。这真的加快了速度吗?

在测量使用 JAX 的代码的速度时,请记住与 NumPy 的这些重要区别

  1. JAX 代码是即时 (JIT) 编译的。 大多数用 JAX 编写的代码都可以以支持 JIT 编译的方式编写,这可以使其运行快得多(请参阅 To JIT or not to JIT)。要从 JAX 获得最佳性能,您应该在最外层的函数调用上应用 jax.jit()

    请记住,首次运行 JAX 代码时,速度会较慢,因为它正在被编译。即使您在自己的代码中没有使用 jit,也是如此,因为 JAX 的内置函数也经过 JIT 编译。

  2. JAX 具有异步调度。 这意味着您需要调用 .block_until_ready() 以确保计算实际发生(请参阅 异步调度)。

  3. JAX 默认仅使用 32 位 dtype。 您可能需要在 NumPy 中显式使用 32 位 dtype,或者在 JAX 中启用 64 位 dtype(有关公平比较,请参阅 双精度(64 位))。

  4. 在 CPU 和加速器之间传输数据需要时间。 如果您只想测量评估函数所需的时间,您可能需要首先将数据传输到您要运行它的设备上(请参阅 控制设备上的数据和计算放置)。

这是一个如何将所有这些技巧组合成一个微基准测试的示例,用于比较 JAX 与 NumPy,使用了 IPython 便利的 %time 和 %timeit magic 命令

import numpy as np
import jax.numpy as jnp
import jax

def f(x):  # function we're benchmarking (works in both NumPy & JAX)
  return x.T @ (x - x.mean(axis=0))

x_np = np.ones((1000, 1000), dtype=np.float32)  # same as JAX default dtype
%timeit f(x_np)  # measure NumPy runtime

%time x_jax = jax.device_put(x_np)  # measure JAX device transfer time
f_jit = jax.jit(f)
%time f_jit(x_jax).block_until_ready()  # measure JAX compilation time
%timeit f_jit(x_jax).block_until_ready()  # measure JAX runtime

当在 Colab 中使用 GPU 运行时,我们看到

  • NumPy 在 CPU 上每次评估耗时 16.2 毫秒

  • JAX 复制 NumPy 数组到 GPU 耗时 1.26 毫秒

  • JAX 编译函数耗时 193 毫秒

  • JAX 在 GPU 上每次评估耗时 485 微秒

在这种情况下,我们看到一旦数据传输完成且函数编译完成,GPU 上的 JAX 对于重复评估来说速度大约快 30 倍。

这是一个公平的比较吗?也许是。最终重要的性能是运行完整的应用程序,这不可避免地包括一定数量的数据传输和编译。此外,我们还小心地选择了足够大的数组(1000x1000)和足够密集的计算(@ 运算符执行矩阵-矩阵乘法),以摊销 JAX/加速器与 NumPy/CPU 相比增加的开销。例如,如果我们将此示例切换为使用 10x10 输入,则 JAX/GPU 的运行速度比 NumPy/CPU 慢 10 倍(100 微秒 vs 10 微秒)。

JAX 比 NumPy 快吗?#

用户经常试图通过此类基准测试来回答的一个问题是 JAX 是否比 NumPy 快;由于这两个软件包的差异,没有简单的答案。

概括来说

  • NumPy 操作是急切的、同步的,并且仅在 CPU 上执行。

  • JAX 操作可以急切执行,也可以在编译后执行(如果在 jit() 内部);它们是异步调度的(请参阅 异步调度);并且它们可以在 CPU、GPU 或 TPU 上执行,每个都具有截然不同且不断发展的性能特征。

这些架构差异使得 NumPy 和 JAX 之间进行有意义的直接基准比较变得困难。

此外,这些差异导致了软件包之间不同的工程重点:例如,NumPy 在降低单个数组操作的每次调用调度开销方面投入了大量精力,因为在 NumPy 的计算模型中,这种开销是无法避免的。另一方面,JAX 有几种方法可以避免调度开销(例如,JIT 编译、异步调度、批处理转换等),因此降低每次调用开销的重要性较低。

考虑到所有这些,总结一下:如果您正在 CPU 上对单个数组操作进行微基准测试,您通常可以预期 NumPy 的性能优于 JAX,因为它每次操作的调度开销较低。如果您在 GPU 或 TPU 上运行代码,或者正在对 CPU 上更复杂的 JIT 编译操作序列进行基准测试,您通常可以预期 JAX 的性能优于 NumPy。

不同类型的 JAX 值#

在转换函数的过程中,JAX 会将一些函数参数替换为特殊的 tracer 值。

如果您使用 print 语句,您可以看到这一点

def func(x):
  print(x)
  return jnp.cos(x)

res = jax.jit(func)(0.)

上面的代码确实返回了正确的值 1.,但它也为 x 的值打印了 Traced<ShapedArray(float32[])>。通常,JAX 以透明的方式在内部处理这些 tracer 值,例如,在用于实现 jax.numpy 函数的数值 JAX 原语中。这就是为什么 jnp.cos 在上面的示例中有效。

更准确地说,对于 JAX 转换函数的参数,会引入 tracer 值,但由特殊参数标识的参数除外,例如 jax.jit()static_argnumsjax.pmap()static_broadcasted_argnums。通常,至少涉及一个 tracer 值的计算将产生一个 tracer 值。除了 tracer 值之外,还有 常规 Python 值:在 JAX 转换之外计算的值,或者来自上述某些 JAX 转换的静态参数,或者仅从其他常规 Python 值计算的值。这些是在没有 JAX 转换的情况下在任何地方使用的值。

tracer 值携带一个 抽象 值,例如,带有关于数组的形状和 dtype 信息的 ShapedArray。我们在这里将此类 tracer 称为 抽象 tracer。一些 tracer,例如,为自动微分转换的参数引入的 tracer,携带 ConcreteArray 抽象值,该值实际上包括常规数组数据,并用于例如解析条件语句。我们在这里将此类 tracer 称为 具体 tracer。从这些具体 tracer 计算出的 tracer 值,可能与常规值组合,会产生具体 tracer。具体值 要么是常规值,要么是具体 tracer。

大多数情况下,从 tracer 值计算出的值本身就是 tracer 值。只有极少数例外,当计算可以完全使用 tracer 携带的抽象值完成时,在这种情况下,结果可以是常规值。例如,获取具有 ShapedArray 抽象值的 tracer 的形状。另一个示例是将具体 tracer 值显式转换为常规类型,例如,int(x)x.astype(float)。另一种情况是 bool(x),当具体性使其成为可能时,它会生成一个 Python 布尔值。这种情况尤其突出,因为它经常出现在控制流中。

以下是转换如何引入抽象或具体 tracer

  • jax.jit():为除 static_argnums 表示的参数之外的所有位置参数引入 抽象 tracer,这些参数仍然是常规值。

  • jax.pmap():为除 static_broadcasted_argnums 表示的参数之外的所有位置参数引入 抽象 tracer

  • jax.vmap()jax.make_jaxpr()xla_computation():为所有位置参数引入 抽象 tracer

  • jax.jvp()jax.grad() 为所有位置参数引入 具体 tracer。一个例外是当这些转换位于外部转换内部,并且实际参数本身是抽象 tracer 时;在这种情况下,自动微分转换引入的 tracer 也是抽象 tracer。

  • 所有高阶控制流原语(lax.cond()lax.while_loop()lax.fori_loop()lax.scan())在处理函数时引入 抽象 tracer,无论是否正在进行 JAX 转换。

当您的代码只能在常规 Python 值上运行时,所有这些都相关,例如具有基于数据的条件控制流的代码

def divide(x, y):
  return x / y if y >= 1. else 0.

如果我们想应用 jax.jit(),我们必须确保指定 static_argnums=1 以确保 y 保持为常规值。这是由于布尔表达式 y >= 1.,它需要具体值(常规值或 tracer)。如果我们显式地写 bool(y >= 1.)int(y)float(y),也会发生同样的情况。

有趣的是,jax.grad(divide)(3., 2.) 可以工作,因为 jax.grad() 使用具体 tracer,并使用 y 的具体值解析条件语句。

缓冲区捐赠#

当 JAX 执行计算时,它在设备上为所有输入和输出使用缓冲区。如果您知道计算后不再需要某个输入,并且它与某个输出的形状和元素类型匹配,您可以指定您希望捐赠相应的输入缓冲区以保存输出。这将减少执行所需的内存,减少量为捐赠缓冲区的大小。

如果您有类似以下模式的内容,您可以使用缓冲区捐赠

params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params, state)

您可以将其视为在不可变的 JAX 数组上执行内存高效的函数式更新的一种方式。在计算 XLA 的边界内,XLA 可以为您进行此优化,但在 jit/pmap 边界处,您需要向 XLA 保证在调用捐赠函数后您不会使用捐赠的输入缓冲区。

您可以通过使用函数 jax.jit()jax.pjit()jax.pmap()donate_argnums 参数来实现这一点。此参数是位置参数列表的索引序列(从 0 开始)

def add(x, y):
  return x + y

x = jax.device_put(np.ones((2, 3)))
y = jax.device_put(np.ones((2, 3)))
# Execute `add` with donation of the buffer for `y`. The result has
# the same shape and type as `y`, so it will share its buffer.
z = jax.jit(add, donate_argnums=(1,))(x, y)

请注意,当使用关键字参数调用函数时,这目前不起作用!以下代码不会捐赠任何缓冲区

params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params=params, state=state)

如果缓冲区被捐赠的参数是一个 pytree,则其所有组件的缓冲区都会被捐赠

def add_ones(xs: List[Array]):
  return [x + 1 for x in xs]

xs = [jax.device_put(np.ones((2, 3))), jax.device_put(np.ones((3, 4)))]
# Execute `add_ones` with donation of all the buffers for `xs`.
# The outputs have the same shape and type as the elements of `xs`,
# so they will share those buffers.
z = jax.jit(add_ones, donate_argnums=0)(xs)

不允许捐赠随后在计算中使用的缓冲区,JAX 会给出错误,因为 y 的缓冲区在捐赠后已失效

# Donate the buffer for `y`
z = jax.jit(add, donate_argnums=(1,))(x, y)
w = y + 1  # Reuses `y` whose buffer was donated above
# >> RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer

如果捐赠的缓冲区未使用,您将收到警告,例如,因为捐赠的缓冲区多于可用于输出的缓冲区

# Execute `add` with donation of the buffers for both `x` and `y`.
# One of those buffers will be used for the result, but the other will
# not be used.
z = jax.jit(add, donate_argnums=(0, 1))(x, y)
# >> UserWarning: Some donated buffers were not usable: f32[2,3]{1,0}

如果没有输出的形状与捐赠匹配,则捐赠也可能未使用

y = jax.device_put(np.ones((1, 3)))  # `y` has different shape than the output
# Execute `add` with donation of the buffer for `y`.
z = jax.jit(add, donate_argnums=(1,))(x, y)
# >> UserWarning: Some donated buffers were not usable: f32[1,3]{1,0}

使用 where 时,梯度包含 NaN#

如果您使用 where 定义一个函数以避免未定义的值,如果您不小心,您可能会在反向微分中获得 NaN

def my_log(x):
  return jnp.where(x > 0., jnp.log(x), 0.)

my_log(0.) ==> 0.  # Ok
jax.grad(my_log)(0.)  ==> NaN

一个简短的解释是,在 grad 计算期间,对应于未定义的 jnp.log(x) 的伴随值是 NaN,并且它会累积到 jnp.where 的伴随值中。编写此类函数的正确方法是确保在部分定义的函数内部存在 jnp.where,以确保伴随值始终是有限的

def safe_for_grad_log(x):
  return jnp.log(jnp.where(x > 0., x, 1.))

safe_for_grad_log(0.) ==> 0.  # Ok
jax.grad(safe_for_grad_log)(0.)  ==> 0.  # Ok

除了原始的 jnp.where 之外,可能还需要内部的 jnp.where,例如

def my_log_or_y(x, y):
  """Return log(x) if x > 0 or y"""
  return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.)), y)

延伸阅读

为什么基于排序顺序的函数的梯度为零?#

如果您定义一个函数,该函数使用依赖于输入相对顺序的操作(例如 maxgreaterargsort 等)处理输入,那么您可能会惊讶地发现梯度处处为零。这是一个示例,我们定义 f(x) 为阶跃函数,当 x 为负数时返回 0,当 x 为正数时返回 1

import jax
import numpy as np
import jax.numpy as jnp

def f(x):
  return (x > 0).astype(float)

df = jax.vmap(jax.grad(f))

x = jnp.array([-1.0, -0.5, 0.0, 0.5, 1.0])

print(f"f(x)  = {f(x)}")
# f(x)  = [0. 0. 0. 1. 1.]

print(f"df(x) = {df(x)}")
# df(x) = [0. 0. 0. 0. 0.]

梯度处处为零的事实起初可能会令人困惑:毕竟,输出确实会响应输入而变化,那么梯度怎么可能为零呢?但是,在这种情况下,零被证明是正确的结果。

这是为什么?请记住,微分测量的是给定 x 的无穷小变化时 f 的变化。对于 x=1.0f 返回 1.0。如果我们扰动 x 使其略大或略小,这不会改变输出,因此根据定义,grad(f)(1.0) 应该为零。对于所有大于零的 f 值,此逻辑同样成立:无穷小地扰动输入不会改变输出,因此梯度为零。类似地,对于所有小于零的 x 值,输出为零。扰动 x 不会改变此输出,因此梯度为零。这给我们留下了棘手的 x=0 情况。当然,如果您向上扰动 x,它会改变输出,但这有问题:x 的无穷小变化会导致函数值的有限变化,这意味着梯度未定义。幸运的是,我们还有另一种方法来衡量这种情况下的梯度:我们向下扰动函数,在这种情况下,输出不会改变,因此梯度为零。JAX 和其他自动微分系统倾向于以这种方式处理不连续性:如果正梯度和负梯度不一致,但一个是定义的,另一个是未定义的,我们使用定义的那个。根据梯度的这种定义,从数学和数值上来说,此函数的梯度处处为零。

问题源于我们的函数在 x = 0 处具有不连续性。我们的 f 本质上是一个 Heaviside 阶跃函数,我们可以使用 Sigmoid 函数 作为平滑的替代品。当 x 远离零时,sigmoid 函数近似等于 heaviside 函数,但用平滑的、可微分的曲线替换了 x = 0 处的不连续性。由于使用了 jax.nn.sigmoid(),我们得到了类似的计算,但梯度定义良好

def g(x):
  return jax.nn.sigmoid(x)

dg = jax.vmap(jax.grad(g))

x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])

with np.printoptions(suppress=True, precision=2):
  print(f"g(x)  = {g(x)}")
  # g(x)  = [0.   0.27 0.5  0.73 1.  ]

  print(f"dg(x) = {dg(x)}")
  # dg(x) = [0.   0.2  0.25 0.2  0.  ]

jax.nn 子模块还具有其他常见基于排序的函数的平滑版本,例如 jax.nn.softmax() 可以替换 jax.numpy.argmax() 的用法,jax.nn.soft_sign() 可以替换 jax.numpy.sign() 的用法,jax.nn.softplus()jax.nn.squareplus() 可以替换 jax.nn.relu() 等的用法。

如何将 JAX Tracer 转换为 NumPy 数组?#

在运行时检查转换后的 JAX 函数时,您会发现数组值被 Tracer 对象替换

@jax.jit
def f(x):
  print(type(x))
  return x

f(jnp.arange(5))

这将打印以下内容

<class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>

一个常见的问题是如何将这样的 tracer 转换回普通的 NumPy 数组。简而言之,不可能将 Tracer 转换为 NumPy 数组,因为 tracer 是具有给定形状和 dtype 的每个可能值的抽象表示,而 numpy 数组是该抽象类的具体成员。有关 tracer 如何在 JAX 转换的上下文中工作的更多讨论,请参阅 JIT 机制

将 Tracer 转换回数组的问题通常出现在另一个目标的上下文中,该目标与在运行时访问计算中的中间值有关。例如

  • 如果您希望在运行时打印跟踪值以进行调试,您可以考虑使用 jax.debug.print()

  • 如果您希望在转换后的 JAX 函数中调用非 JAX 代码,您可以考虑使用 jax.pure_callback(),其示例可在 纯回调示例 中找到。

  • 如果您希望在运行时输入或输出数组缓冲区(例如,从文件加载数据,或将数组内容记录到磁盘),您可以考虑使用 jax.experimental.io_callback(),其示例可在 IO 回调示例 中找到。

有关运行时回调及其用法的更多信息,请参阅 JAX 中的外部回调

为什么某些 CUDA 库加载/初始化失败?#

在解析动态库时,JAX 使用通常的 动态链接器搜索模式。JAX 设置 RPATH 以指向 pip 安装的 NVIDIA CUDA 包的 JAX 相关位置,如果已安装,则优先使用它们。如果 ld.so 无法在其通常的搜索路径中找到您的 CUDA 运行时库,那么您必须在 LD_LIBRARY_PATH 中显式包含这些库的路径。确保您的 CUDA 文件可被发现的最简单方法是简单地安装 nvidia-*-cu12 pip 包,这些包包含在标准的 jax[cuda_12] 安装选项中。

偶尔,即使您已确保运行时库是可发现的,加载或初始化它们仍然可能存在一些问题。此类问题的常见原因是 CUDA 库在运行时初始化时内存不足。有时会发生这种情况,因为 JAX 会为更快的执行预先分配过大的当前可用设备内存块,偶尔会导致运行时 CUDA 库初始化可用的内存不足。

当运行多个 JAX 实例、与执行自身预分配的 TensorFlow 并行运行 JAX,或者在 GPU 被其他进程大量使用的系统上运行 JAX 时,这种情况尤其可能发生。如有疑问,请尝试通过从默认值 .75 降低 XLA_PYTHON_CLIENT_MEM_FRACTION 或设置 XLA_PYTHON_CLIENT_PREALLOCATE=false 来减少预分配,再次运行程序。有关更多详细信息,请参阅关于 JAX GPU 内存分配 的页面。