常见问题 (FAQ)#

我们在此收集常见问题的答案。欢迎贡献!

jit 改变了我的函数行为#

如果你的 Python 函数在使用 jax.jit() 后行为发生变化,可能是你的函数使用了全局状态,或者有副作用。在以下代码中,impure_func 使用了全局变量 y 并因 print 产生了副作用

y = 0

# @jit   # Different behavior with jit
def impure_func(x):
  print("Inside:", y)
  return x + y

for y in range(3):
  print("Result:", impure_func(y))

没有 jit 时,输出为

Inside: 0
Result: 0
Inside: 1
Result: 2
Inside: 2
Result: 4

而使用 jit 时,输出为

Inside: 0
Result: 0
Result: 1
Result: 2

对于 jax.jit(),函数会使用 Python 解释器执行一次,此时会发生 Inside 的打印,并观察到 y 的第一个值。然后,函数会被编译和缓存,并使用不同的 x 值多次执行,但 y 的值保持为第一次观察到的值。

延伸阅读

jit 改变了输出的精确数值#

有时用户会惊讶于 jit() 包装函数会改变函数输出的事实。例如

>>> from jax import jit
>>> import jax.numpy as jnp
>>> def f(x):
...   return jnp.log(jnp.sqrt(x))
>>> x = jnp.pi
>>> print(f(x))
0.572365
>>> print(jit(f)(x))
0.5723649

输出的细微差别来自 XLA 编译器的优化:在编译过程中,XLA 有时会重新排列或省略某些操作,以提高整体计算效率。

在这种情况下,XLA 利用对数的性质,将 log(sqrt(x)) 替换为 0.5 * log(x),这是一个数学上相同的表达式,但计算效率更高。输出的差异源于浮点算术只是对实数数学的近似,因此用不同方式计算同一表达式可能会产生细微不同的结果。

在其他情况下,XLA 的优化可能会导致更剧烈的差异。考虑以下示例

>>> def f(x):
...   return jnp.log(jnp.exp(x))
>>> x = 100.0
>>> print(f(x))
inf
>>> print(jit(f)(x))
100.0

在非 JIT 编译的逐操作模式下,结果是 inf,因为 jnp.exp(x) 溢出并返回 inf。然而,在 JIT 下,XLA 识别出 logexp 的逆运算,并从编译函数中移除这些操作,直接返回输入。在这种情况下,JIT 编译产生了更准确的实数结果浮点近似。

不幸的是,XLA 代数简化的完整列表没有很好的文档,但如果你熟悉 C++ 并对 XLA 编译器进行的优化类型感到好奇,你可以在源代码中查看它们:algebraic_simplifier.cc

jit 装饰的函数编译非常慢#

如果你的 jit 装饰函数在第一次调用时需要几十秒(甚至更长!)才能运行,但再次调用时执行很快,那么 JAX 正在花费很长时间跟踪或编译你的代码。

这通常表明调用你的函数在 JAX 的内部表示中生成了大量的代码,通常是因为它大量使用了 Python 控制流,例如 for 循环。对于少量循环迭代,Python 还可以,但如果你需要很多循环迭代,你应该重写你的代码以利用 JAX 的结构化控制流原语(例如 lax.scan())或避免用 jit 包装循环(你仍然可以在循环内部使用 jit 装饰的函数)。

如果你不确定这是否是问题,可以尝试在你的函数上运行 jax.make_jaxpr()。如果输出有数百或数千行,则编译会很慢。

有时,如何重写代码以避免 Python 循环并不明显,因为你的代码使用了许多形状不同的数组。在这种情况下,推荐的解决方案是使用 jax.numpy.where() 等函数,在具有固定形状的填充数组上进行计算。

如果你的函数因其他原因编译缓慢,请在 GitHub 上提出问题。

如何在方法中使用 jit#

大多数 jax.jit() 的例子都涉及装饰独立的 Python 函数,但在类中装饰方法会带来一些复杂性。例如,考虑以下简单的类,我们已在此方法上使用标准的 jit() 注解

>>> import jax.numpy as jnp
>>> from jax import jit

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   @jit  # <---- How to do this correctly?
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y

然而,这种方法在你尝试调用此方法时会导致错误

>>> c = CustomClass(2, True)
>>> c.calc(3)  
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
  File "<stdin>", line 1, in <module
TypeError: Argument '<CustomClass object at 0x7f7dd4125890>' of type <class 'CustomClass'> is not a valid JAX type.

问题在于函数的第一个参数是 self,其类型为 CustomClass,而 JAX 不知道如何处理这种类型。在这种情况下,我们有三种基本策略,我们将在下面讨论它们。

策略 1:JIT 编译的辅助函数#

最直接的方法是创建一个独立于类的辅助函数,该函数可以以正常方式进行 JIT 装饰。例如

>>> from functools import partial

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   def calc(self, y):
...     return _calc(self.mul, self.x, y)

>>> @partial(jit, static_argnums=0)
... def _calc(mul, x, y):
...   if mul:
...     return x * y
...   return y

结果将如预期般工作

>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6

这种方法的好处是它简单、明确,并且避免了需要让 JAX 学习如何处理 CustomClass 类型对象的问题。但是,你可能希望将所有方法逻辑放在同一个地方。

策略 2:将 self 标记为静态#

另一种常见模式是使用 static_argnumsself 参数标记为静态。但这必须小心进行,以避免意外结果。你可能很想简单地这样做

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   # WARNING: this example is broken, as we'll see below. Don't copy & paste!
...   @partial(jit, static_argnums=0)
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y

如果你调用该方法,它将不再引发错误

>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6

然而,有一个陷阱:如果你在第一次方法调用后修改了对象,后续的方法调用可能会返回不正确的结果

>>> c.mul = False
>>> print(c.calc(3))  # Should print 3
6

这是为什么呢?当你将一个对象标记为静态时,它将有效地用作 JIT 内部编译缓存中的字典键,这意味着它的哈希值(即 hash(obj))相等性(即 obj1 == obj2)和对象身份(即 obj1 is obj2)将被假定为具有一致的行为。自定义对象的默认 __hash__ 是其对象 ID,因此 JAX 无法知道被修改的对象应该触发重新编译。

你可以通过为你的对象定义适当的 __hash____eq__ 方法来部分解决这个问题;例如

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   @partial(jit, static_argnums=0)
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y
...
...   def __hash__(self):
...     return hash((self.x, self.mul))
...
...   def __eq__(self, other):
...     return (isinstance(other, CustomClass) and
...             (self.x, self.mul) == (other.x, other.mul))

(有关覆盖 __hash__ 时要求的更多讨论,请参阅 object.__hash__() 文档)。

只要你从不更改你的对象,这应该与 JIT 和其他转换正确配合。(例如,可变 Python 容器(例如 dictlist)不定义 __hash__,而它们不可变对应物(例如 tuple)定义 __hash__,就是因为将可变对象用作哈希键会导致一些微妙的问题)。

如果你的类依赖于原地修改(例如在其方法中设置 self.attr = ...),那么你的对象并非真正“静态”,将其标记为静态可能会导致问题。幸运的是,对于这种情况还有另一个选择。

策略 3:将 CustomClass 设为 PyTree#

正确 JIT 编译类方法最灵活的方法是将类型注册为自定义 PyTree 对象;请参阅扩展 pytree。这允许你精确指定类中哪些组件应被视为静态,哪些应被视为动态。下面是它可能的样子

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   @jit
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y
...
...   def _tree_flatten(self):
...     children = (self.x,)  # arrays / dynamic values
...     aux_data = {'mul': self.mul}  # static values
...     return (children, aux_data)
...
...   @classmethod
...   def _tree_unflatten(cls, aux_data, children):
...     return cls(*children, **aux_data)

>>> from jax import tree_util
>>> tree_util.register_pytree_node(CustomClass,
...                                CustomClass._tree_flatten,
...                                CustomClass._tree_unflatten)

这无疑更复杂,但它解决了上面使用的更简单方法相关的所有问题

>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6

>>> c.mul = False  # mutation is detected
>>> print(c.calc(3))
3

>>> c = CustomClass(jnp.array(2), True)  # non-hashable x is supported
>>> print(c.calc(3))
6

只要你的 tree_flattentree_unflatten 函数正确处理类中所有相关属性,你就应该能够直接将此类型的对象用作 JIT 编译函数的参数,而无需任何特殊注解。

控制设备上的数据和计算放置#

我们首先了解 JAX 中数据和计算放置的原则。

在 JAX 中,计算遵循数据放置。JAX 数组有两个放置属性:1) 数据所在的设备;2) 是否已提交到设备(数据有时被称为对设备具有粘性)。

默认情况下,JAX 数组以未提交状态放置在默认设备上(jax.devices()[0]),默认情况下是第一个 GPU 或 TPU。如果没有 GPU 或 TPU,jax.devices()[0] 是 CPU。默认设备可以通过 jax.default_device() 上下文管理器临时覆盖,或者通过设置环境变量 JAX_PLATFORMS 或 absl 标志 --jax_platforms 为“cpu”、“gpu”或“tpu”来为整个进程设置(JAX_PLATFORMS 也可以是平台列表,它决定了平台的优先级顺序)。

>>> from jax import numpy as jnp
>>> print(jnp.ones(3).devices())  
{CudaDevice(id=0)}

涉及未提交数据的计算在默认设备上执行,结果也以未提交状态位于默认设备上。

数据也可以使用带有 device 参数的 jax.device_put() 显式放置在设备上,在这种情况下,数据会提交到设备

>>> import jax
>>> from jax import device_put
>>> arr = device_put(1, jax.devices()[2])  
>>> print(arr.devices())  
{CudaDevice(id=2)}

涉及某些已提交输入的计算将在已提交设备上进行,并且结果也将提交到同一设备。在提交到多个设备的参数上调用操作将引发错误。

你也可以使用不带 device 参数的 jax.device_put()。如果数据已在某个设备上(已提交或未提交),则保持不变。如果数据不在任何设备上——也就是说,它是一个普通的 Python 或 NumPy 值——它将以未提交状态放置在默认设备上。

JIT 编译函数 behave like any other primitive operations—they will follow the data and will show errors if invoked on data committed on more than one device.

(在 2021 年 3 月的 PR #6002 之前,数组常量创建存在一些惰性,因此 jax.device_put(jnp.zeros(...), jax.devices()[1]) 或类似操作实际上会在 jax.devices()[1] 上创建零数组,而不是在默认设备上创建后再移动。但为了简化实现,此优化已被移除。)

(截至 2020 年 4 月,jax.jit() 有一个 device 参数会影响设备放置。该参数是实验性的,可能会被移除或更改,不建议使用。)

有关一个实际示例,我们建议通读 multi_device_test.py 中的 test_computation_follows_data

基准测试 JAX 代码#

你刚刚将一个复杂的函数从 NumPy/SciPy 移植到 JAX。这真的加快了速度吗?

在测量使用 JAX 的代码速度时,请记住与 NumPy 的这些重要差异

  1. JAX 代码是即时 (JIT) 编译的。 大多数用 JAX 编写的代码都可以以支持 JIT 编译的方式编写,这可以使其运行快得多(请参阅是否 JIT)。为了从 JAX 获得最大性能,你应该在最外层函数调用上应用 jax.jit()

    请记住,你第一次运行 JAX 代码时会比较慢,因为它正在被编译。即使你自己的代码中没有使用 jit,情况也是如此,因为 JAX 的内置函数也是 JIT 编译的。

  2. JAX 具有异步调度功能。 这意味着你需要调用 .block_until_ready() 以确保计算实际发生(请参阅异步调度)。

  3. JAX 默认仅使用 32 位 dtype。 你可能希望在 NumPy 中明确使用 32 位 dtype,或在 JAX 中启用 64 位 dtype(请参阅双精度(64 位))以进行公平比较。

  4. CPU 和加速器之间的数据传输需要时间。 如果你只想测量函数评估所需的时间,你可能希望首先将数据传输到你想要运行它的设备上(请参阅控制设备上的数据和计算放置)。

以下是将所有这些技巧组合成一个微基准测试的示例,用于比较 JAX 和 NumPy,并利用 IPython 便捷的 %time 和 %timeit magic 命令

import numpy as np
import jax

def f(x):  # function we're benchmarking (works in both NumPy & JAX)
  return x.T @ (x - x.mean(axis=0))

x_np = np.ones((1000, 1000), dtype=np.float32)  # same as JAX default dtype
%timeit f(x_np)  # measure NumPy runtime

# measure JAX device transfer time
%time x_jax = jax.device_put(x_np).block_until_ready()

f_jit = jax.jit(f)
%time f_jit(x_jax).block_until_ready()  # measure JAX compilation time
%timeit f_jit(x_jax).block_until_ready()  # measure JAX runtime

Colab 中使用 GPU 运行时,我们看到

  • NumPy 在 CPU 上每次评估需要 16.2 毫秒

  • JAX 复制 NumPy 数组到 GPU 需要 1.26 毫秒

  • JAX 编译函数需要 193 毫秒

  • JAX 在 GPU 上每次评估需要 485 微秒

在这种情况下,我们看到一旦数据传输完毕并且函数编译完成,JAX 在 GPU 上重复评估的速度比 NumPy 快约 30 倍。

这是否是一个公平的比较?也许是。最终重要的性能是运行完整应用程序的性能,其中不可避免地会包含一定量的数据传输和编译。此外,我们特意选择了足够大的数组(1000x1000)和足够密集的计算(@ 运算符执行矩阵-矩阵乘法),以分摊 JAX/加速器与 NumPy/CPU 之间增加的开销。例如,如果我们将此示例切换为使用 10x10 输入,JAX/GPU 的运行速度比 NumPy/CPU 慢 10 倍(100 微秒 对 10 微秒)。

JAX 比 NumPy 快吗?#

用户经常试图通过此类基准测试回答的一个问题是 JAX 是否比 NumPy 快;由于两个包的差异,没有简单的答案。

广义上讲

  • NumPy 操作是即时、同步执行的,并且只在 CPU 上执行。

  • JAX 操作可以即时执行,也可以在编译后执行(如果在 jit() 内部);它们是异步调度的(请参阅异步调度);它们可以在 CPU、GPU 或 TPU 上执行,每种设备都有截然不同且不断变化的性能特征。

这些架构差异使得 NumPy 和 JAX 之间有意义的直接基准比较变得困难。

此外,这些差异导致了软件包之间工程重点的不同:例如,NumPy 在减少单个数组操作的每次调用分派开销方面投入了大量精力,因为在 NumPy 的计算模型中,这种开销是无法避免的。另一方面,JAX 有多种方法可以避免分派开销(例如 JIT 编译、异步分派、批处理转换等),因此减少每次调用开销的优先级较低。

考虑到所有这些,总而言之:如果你在 CPU 上对单个数组操作进行微基准测试,由于其较低的每次操作分派开销,通常可以预期 NumPy 会优于 JAX。如果你在 GPU 或 TPU 上运行代码,或者在 CPU 上对更复杂的 JIT 编译操作序列进行基准测试,通常可以预期 JAX 会优于 NumPy。

不同类型的 JAX 值#

在转换函数的过程中,JAX 会用特殊的跟踪器值替换一些函数参数。

如果你使用 print 语句,你可能会看到这一点

def func(x):
  print(x)
  return jnp.cos(x)

res = jax.jit(func)(0.)

上面的代码确实返回了正确的值 1.,但它也打印了 Traced<ShapedArray(float32[])> 作为 x 的值。通常,JAX 会以透明的方式在内部处理这些跟踪器值,例如,在用于实现 jax.numpy 函数的数字 JAX 原语中。这就是为什么 jnp.cos 在上面的例子中有效的原因。

更准确地说,对于 JAX 转换函数的参数会引入跟踪器值,但由特殊参数(例如 jax.jit()static_argnumsjax.pmap()static_broadcasted_argnums)标识的参数除外,它们仍然是常规值。通常,涉及至少一个跟踪器值的计算将产生一个跟踪器值。除了跟踪器值之外,还有常规 Python 值:在 JAX 转换之外计算的值,或由上述某些 JAX 转换的静态参数产生的值,或仅由其他常规 Python 值计算的值。这些是在没有 JAX 转换的情况下在任何地方使用的值。

跟踪器值带有一个抽象值,例如,ShapedArray,其中包含有关数组形状和 dtype 的信息。我们在此将此类跟踪器称为抽象跟踪器。某些跟踪器(例如为自动微分转换的参数引入的跟踪器)带有 ConcreteArray 抽象值,这些值实际上包含常规数组数据,并用于(例如)解决条件语句。我们在此将此类跟踪器称为具体跟踪器。从这些具体跟踪器(可能与常规值结合)计算出的跟踪器值会产生具体跟踪器。具体值是常规值或具体跟踪器。

通常,从跟踪器值计算出的值本身就是跟踪器值。只有极少数例外情况,当计算可以完全使用跟踪器携带的抽象值完成时,结果可以是常规值。例如,获取具有 ShapedArray 抽象值的跟踪器的形状。另一个例子是当将具体跟踪器值显式转换为常规类型时,例如 int(x)x.astype(float)。另一种情况是 bool(x),当具体性允许时,它会生成一个 Python 布尔值。这种情况特别突出,因为它经常出现在控制流中。

以下是转换如何引入抽象或具体跟踪器

  • jax.jit():为所有位置参数引入抽象跟踪器,除了 static_argnums 指定的参数,后者保持常规值。

  • jax.pmap():为所有位置参数引入抽象跟踪器,除了 static_broadcasted_argnums 指定的参数。

  • jax.vmap()jax.make_jaxpr()xla_computation():为所有位置参数引入抽象跟踪器

  • jax.jvp()jax.grad() 为所有位置参数引入具体跟踪器。一个例外是当这些转换位于外部转换之内且实际参数本身是抽象跟踪器时;在这种情况下,自动微分转换引入的跟踪器也是抽象跟踪器。

  • 所有高阶控制流原语(lax.cond()lax.while_loop()lax.fori_loop()lax.scan())在处理函数时会引入抽象跟踪器,无论当前是否有 JAX 转换。

当你有一些只能对常规 Python 值进行操作的代码时,所有这些都是相关的,例如基于数据的条件控制流的代码

def divide(x, y):
  return x / y if y >= 1. else 0.

如果我们要应用 jax.jit(),我们必须确保指定 static_argnums=1 以确保 y 保持为常规值。这是由于布尔表达式 y >= 1.,它需要具体值(常规或跟踪器)。如果我们显式编写 bool(y >= 1.),或 int(y),或 float(y),也会发生同样的情况。

有趣的是,jax.grad(divide)(3., 2.) 之所以有效,是因为 jax.grad() 使用具体跟踪器,并使用 y 的具体值解析条件。

缓冲区捐赠#

JAX 执行计算时,会为所有输入和输出使用设备上的缓冲区。如果你知道其中一个输入在计算后不再需要,并且它与其中一个输出的形状和元素类型匹配,则可以指定将相应的输入缓冲区捐赠以存储输出。这将减少执行所需的内存,减少量为捐赠缓冲区的大小。

如果你有以下模式,可以使用缓冲区捐赠

params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params, state)

你可以将其视为对不可变 JAX 数组进行内存高效函数式更新的一种方式。在计算边界内,XLA 可以为你进行此优化,但在 jit/pmap 边界处,你需要向 XLA 保证在调用捐赠函数后不再使用已捐赠的输入缓冲区。

你可以通过使用 jax.jit()jax.pjit()jax.pmap() 函数的 donate_argnums 参数来实现这一点。该参数是位置参数列表中的索引序列(从 0 开始)

def add(x, y):
  return x + y

x = jax.device_put(np.ones((2, 3)))
y = jax.device_put(np.ones((2, 3)))
# Execute `add` with donation of the buffer for `y`. The result has
# the same shape and type as `y`, so it will share its buffer.
z = jax.jit(add, donate_argnums=(1,))(x, y)

请注意,这目前在使用关键字参数调用函数时无效!以下代码将不会捐赠任何缓冲区

params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params=params, state=state)

如果其缓冲区被捐赠的参数是 pytree,则其所有组件的缓冲区都将被捐赠

def add_ones(xs: List[Array]):
  return [x + 1 for x in xs]

xs = [jax.device_put(np.ones((2, 3))), jax.device_put(np.ones((3, 4)))]
# Execute `add_ones` with donation of all the buffers for `xs`.
# The outputs have the same shape and type as the elements of `xs`,
# so they will share those buffers.
z = jax.jit(add_ones, donate_argnums=0)(xs)

不允许捐赠随后在计算中使用的缓冲区,JAX 将报错,因为 y 的缓冲区在捐赠后已失效

# Donate the buffer for `y`
z = jax.jit(add, donate_argnums=(1,))(x, y)
w = y + 1  # Reuses `y` whose buffer was donated above
# >> RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer

如果捐赠的缓冲区未使用,例如因为捐赠的缓冲区多于输出可用的缓冲区,你将收到警告

# Execute `add` with donation of the buffers for both `x` and `y`.
# One of those buffers will be used for the result, but the other will
# not be used.
z = jax.jit(add, donate_argnums=(0, 1))(x, y)
# >> UserWarning: Some donated buffers were not usable: f32[2,3]{1,0}

如果没有任何输出的形状与捐赠匹配,捐赠也可能未使用

y = jax.device_put(np.ones((1, 3)))  # `y` has different shape than the output
# Execute `add` with donation of the buffer for `y`.
z = jax.jit(add, donate_argnums=(1,))(x, y)
# >> UserWarning: Some donated buffers were not usable: f32[1,3]{1,0}

使用 where 时梯度包含 NaN#

如果你使用 where 定义一个函数以避免未定义的值,如果你不小心,可能会在反向微分时得到 NaN

def my_log(x):
  return jnp.where(x > 0., jnp.log(x), 0.)

my_log(0.) ==> 0.  # Ok
jax.grad(my_log)(0.)  ==> NaN

简短的解释是,在 grad 计算期间,对应于未定义的 jnp.log(x) 的伴随量是一个 NaN,并且它被累加到 jnp.where 的伴随量中。编写此类函数的正确方法是确保在部分定义的函数内部有一个 jnp.where,以确保伴随量始终是有限的

def safe_for_grad_log(x):
  return jnp.log(jnp.where(x > 0., x, 1.))

safe_for_grad_log(0.) ==> 0.  # Ok
jax.grad(safe_for_grad_log)(0.)  ==> 0.  # Ok

除了原始的 jnp.where 之外,可能还需要内部的 jnp.where,例如

def my_log_or_y(x, y):
  """Return log(x) if x > 0 or y"""
  return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.)), y)

延伸阅读

为什么基于排序的函数的梯度为零?#

如果你定义一个函数,该函数使用依赖于输入相对顺序的操作(例如 maxgreaterargsort 等)处理输入,那么你可能会惊讶地发现梯度处处为零。这里有一个例子,我们定义 f(x) 为一个阶跃函数,当 x 为负时返回 0,当 x 为正时返回 1

import jax
import numpy as np
import jax.numpy as jnp

def f(x):
  return (x > 0).astype(float)

df = jax.vmap(jax.grad(f))

x = jnp.array([-1.0, -0.5, 0.0, 0.5, 1.0])

print(f"f(x)  = {f(x)}")
# f(x)  = [0. 0. 0. 1. 1.]

print(f"df(x) = {df(x)}")
# df(x) = [0. 0. 0. 0. 0.]

梯度处处为零的事实乍一看可能令人困惑:毕竟,输出确实会随输入而变化,那么梯度怎么会为零呢?然而,在这种情况下,零是正确的结果。

这是为什么呢?请记住,微分测量的是 x 发生微小变化时 f 的变化。对于 x=1.0f 返回 1.0。如果我们对 x 进行轻微的增大或减小扰动,这不会改变输出,因此根据定义,grad(f)(1.0) 应该为零。同样的逻辑适用于所有大于零的 f 值:对输入进行微小扰动不会改变输出,因此梯度为零。类似地,对于所有小于零的 x 值,输出为零。扰动 x 不会改变此输出,因此梯度为零。这就剩下 x=0 的棘手情况。当然,如果你向上扰动 x,它会改变输出,但这有问题:x 的微小变化会产生函数值的有限变化,这意味着梯度是未定义的。幸运的是,在这种情况下我们还有另一种方法来测量梯度:我们向下扰动函数,在这种情况下输出不会改变,因此梯度为零。JAX 和其他自动微分系统倾向于以这种方式处理不连续性:如果正梯度和负梯度不一致,但其中一个已定义而另一个未定义,我们使用已定义的那个。在这种梯度定义下,该函数的梯度在数学上和数值上都处处为零。

问题源于我们的函数在 x = 0 处存在不连续性。这里的 f 本质上是一个Heaviside 阶跃函数,我们可以使用Sigmoid 函数作为平滑替代。当 x 远离零时,sigmoid 近似等于 Heaviside 函数,但用平滑、可微分的曲线取代了 x = 0 处的不连续性。由于使用 jax.nn.sigmoid(),我们得到了一个具有良好定义梯度的类似计算

def g(x):
  return jax.nn.sigmoid(x)

dg = jax.vmap(jax.grad(g))

x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])

with np.printoptions(suppress=True, precision=2):
  print(f"g(x)  = {g(x)}")
  # g(x)  = [0.   0.27 0.5  0.73 1.  ]

  print(f"dg(x) = {dg(x)}")
  # dg(x) = [0.   0.2  0.25 0.2  0.  ]

jax.nn 子模块还包含其他常见基于排序函数的平滑版本,例如 jax.nn.softmax() 可以替代 jax.numpy.argmax() 的用法,jax.nn.soft_sign() 可以替代 jax.numpy.sign() 的用法,jax.nn.softplus()jax.nn.squareplus() 可以替代 jax.nn.relu() 的用法等。

如何将 JAX Tracer 转换为 NumPy 数组?#

在运行时检查转换后的 JAX 函数时,你会发现数组值被替换为 jax.core.Tracer 对象

@jax.jit
def f(x):
  print(type(x))
  return x

f(jnp.arange(5))

这会打印以下内容

<class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>

一个常见的问题是如何将这种跟踪器转换回普通的 NumPy 数组。简而言之,无法将跟踪器转换为 NumPy 数组,因为跟踪器是具有给定形状和 dtype 的每个可能值的抽象表示,而 NumPy 数组是该抽象类的一个具体成员。有关跟踪器在 JAX 转换上下文中如何工作的更多讨论,请参阅JIT 机制

将跟踪器转换回数组的问题通常出现在另一个目标的情况下,即在运行时访问计算中的中间值。例如

有关运行时回调及其使用示例的更多信息,请参阅JAX 中的外部回调

为什么某些 CUDA 库加载/初始化失败?#

在解析动态库时,JAX 使用通常的动态链接器搜索模式。JAX 将 RPATH 设置为指向 pip 安装的 NVIDIA CUDA 包的 JAX 相对位置,如果已安装则优先使用。如果 ld.so 无法在其常用搜索路径中找到你的 CUDA 运行时库,则必须在 LD_LIBRARY_PATH 中明确包含这些库的路径。确保 CUDA 文件可发现的最简单方法是安装 nvidia-*-cu12 pip 包,这些包包含在标准的 jax[cuda_12] 安装选项中。

偶尔,即使你已确保运行时库可被发现,在加载或初始化它们时仍可能出现一些问题。此类问题的常见原因是 CUDA 库在运行时初始化时内存不足。这有时是因为 JAX 会预分配过大的当前可用设备内存块以实现更快的执行,偶尔会导致用于运行时 CUDA 库初始化的内存不足。

当运行多个 JAX 实例、JAX 与 TensorFlow(其执行自己的预分配)并行运行,或者 JAX 在 GPU 被其他进程大量利用的系统上运行时,这种情况尤其可能发生。如有疑问,请尝试通过减少 XLA_PYTHON_CLIENT_MEM_FRACTION(从默认的 .75)或将 XLA_PYTHON_CLIENT_PREALLOCATE=false 来减少预分配。有关更多详细信息,请参阅 JAX GPU 内存分配页面。