🔪 JAX - 注意事项 🔪#

Open in Colab Open in Kaggle

漫步在意大利乡村时,人们会毫不犹豫地告诉你,**JAX** 拥有 “una anima di pura programmazione funzionale”

**JAX** 是一种用于**表达**和**组合**数值程序**转换**的语言。**JAX** 还能够为 CPU 或加速器(GPU/TPU)**编译**数值程序。JAX 非常适用于许多数值和科学程序,但**前提是它们必须符合我们下面描述的某些约束**。

import numpy as np
from jax import jit
from jax import lax
from jax import random
import jax
import jax.numpy as jnp

🔪 纯函数#

JAX 的转换和编译旨在仅适用于函数式纯净的 Python 函数:所有输入数据通过函数参数传入,所有结果通过函数结果输出。如果纯函数以相同的输入调用,它将始终返回相同的结果。

以下是一些非函数式纯净的函数示例,对于这些函数,JAX 的行为与 Python 解释器不同。请注意,JAX 系统不保证这些行为;使用 JAX 的正确方法是仅将其用于函数式纯净的 Python 函数。

def impure_print_side_effect(x):
  print("Executing function")  # This is a side-effect
  return x

# The side-effects appear during the first run
print ("First call: ", jit(impure_print_side_effect)(4.))

# Subsequent runs with parameters of same type and shape may not show the side-effect
# This is because JAX now invokes a cached compilation of the function
print ("Second call: ", jit(impure_print_side_effect)(5.))

# JAX re-runs the Python function when the type or shape of the argument changes
print ("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([5.])))
Executing function
First call:  4.0
Second call:  5.0
Executing function
Third call, different type:  [5.]
g = 0.
def impure_uses_globals(x):
  return x + g

# JAX captures the value of the global during the first run
print ("First call: ", jit(impure_uses_globals)(4.))
g = 10.  # Update the global

# Subsequent runs may silently use the cached value of the globals
print ("Second call: ", jit(impure_uses_globals)(5.))

# JAX re-runs the Python function when the type or shape of the argument changes
# This will end up reading the latest value of the global
print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))
First call:  4.0
Second call:  5.0
Third call, different type:  [14.]
g = 0.
def impure_saves_global(x):
  global g
  g = x
  return x

# JAX runs once the transformed function with special Traced values for arguments
print ("First call: ", jit(impure_saves_global)(4.))
print ("Saved global: ", g)  # Saved global has an internal JAX value
First call:  4.0
Saved global:  JitTracer<~float32[]>

一个 Python 函数可以是函数式纯净的,即使它内部实际使用了有状态对象,只要它不读取或写入外部状态。

def pure_uses_internal_state(x):
  state = dict(even=0, odd=0)
  for i in range(10):
    state['even' if i % 2 == 0 else 'odd'] += x
  return state['even'] + state['odd']

print(jit(pure_uses_internal_state)(5.))
50.0

不建议在任何您希望 jit 的 JAX 函数或任何控制流原语中使用迭代器。原因是迭代器是一个 Python 对象,它引入了状态来检索下一个元素。因此,它与 JAX 的函数式编程模型不兼容。在下面的代码中,有一些错误地尝试将迭代器与 JAX 一起使用的示例。它们中的大多数都会返回错误,但有些会给出意想不到的结果。

import jax.numpy as jnp
from jax import make_jaxpr

# lax.fori_loop
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0

# lax.scan
def func11(arr, extra):
    ones = jnp.ones(arr.shape)
    def body(carry, aelems):
        ae1, ae2 = aelems
        return (carry + ae1 * ae2 + extra, carry)
    return lax.scan(body, 0., (arr, ones))
make_jaxpr(func11)(jnp.arange(16), 5.)
# make_jaxpr(func11)(iter(range(16)), 5.) # throws error

# lax.cond
array_operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, array_operand)
iter_operand = iter(range(10))
# lax.cond(True, lambda x: next(x)+1, lambda x: next(x)-1, iter_operand) # throws error
45
0

🔪 原地更新#

在 Numpy 中,您习惯于这样做

numpy_array = np.zeros((3,3), dtype=np.float32)
print("original array:")
print(numpy_array)

# In place, mutating update
numpy_array[1, :] = 1.0
print("updated array:")
print(numpy_array)
original array:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
updated array:
[[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]

然而,如果我们尝试对 jax.Array 进行原地索引更新,我们会得到一个**错误**!(☉_☉)

%xmode Minimal
Exception reporting mode: Minimal
jax_array = jnp.zeros((3,3), dtype=jnp.float32)

# In place update of JAX's array will yield an error!
jax_array[1, :] = 1.0
TypeError: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://jax.net.cn/en/latest/_autosummary/jax.numpy.ndarray.at.html

如果我们尝试进行 __iadd__ 风格的原地更新,我们会得到与 **NumPy 不同的行为**!(☉_☉)(☉_☉)

jax_array = jnp.array([10, 20])
jax_array_new = jax_array
jax_array_new += 10
print(jax_array_new)  # `jax_array_new` is rebound to a new value [20, 30], but...
print(jax_array)      # the original value is unodified as [10, 20] !

numpy_array = np.array([10, 20])
numpy_array_new = numpy_array
numpy_array_new += 10
print(numpy_array_new)  # `numpy_array_new is numpy_array`, and it was updated
print(numpy_array)      # in-place, so both are [20, 30] !
[20 30]
[10 20]
[20 30]
[20 30]

这是因为 NumPy 将 __iadd__ 定义为执行原地修改。相比之下,jax.Array 没有定义 __iadd__,所以 Python 将 jax_array_new += 10 视为 jax_array_new = jax_array_new + 10 的语法糖,重新绑定变量而无需修改任何数组。

允许变量原地修改会使程序分析和转换变得困难。JAX 要求程序是纯函数。

相反,JAX 提供了一种函数式数组更新方式,使用 JAX 数组上的 .at 属性

️⚠️ 在 jit 编译的代码以及 lax.while_looplax.fori_loop 中,切片的**大小**不能是参数的函数,而只能是参数形状的函数——切片起始索引没有这样的限制。有关此限制的更多信息,请参阅下面的**控制流**部分。

数组更新:x.at[idx].set(y)#

例如,上面的更新可以写成

jax_array = jnp.zeros((3,3), dtype=jnp.float32)
updated_array = jax_array.at[1, :].set(1.0)
print("updated array:\n", updated_array)
updated array:
 [[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]

JAX 的数组更新函数与 NumPy 版本不同,它们是异地操作(out-of-place)。也就是说,更新后的数组作为新数组返回,原始数组不会被更新修改。

print("original array unchanged:\n", jax_array)
original array unchanged:
 [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]

然而,在 **jit** 编译的代码中,如果 x.at[idx].set(y) 的**输入值** x 未被重复使用,编译器将优化数组更新以实现原地操作。

使用其他操作进行数组更新#

索引数组更新不仅仅限于覆盖值。例如,我们可以按如下方式执行索引加法

print("original array:")
jax_array = jnp.ones((5, 6))
print(jax_array)

new_jax_array = jax_array.at[::2, 3:].add(7.)
print("new array post-addition:")
print(new_jax_array)
original array:
[[1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]]
new array post-addition:
[[1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]]

有关索引数组更新的更多详细信息,请参阅 .at 属性的文档

🔪 越界索引#

在 Numpy 中,您习惯于在数组越界索引时抛出错误,如下所示

np.arange(10)[11]
IndexError: index 11 is out of bounds for axis 0 with size 10

然而,从在加速器上运行的代码中抛出错误可能很困难或不可能。因此,JAX 必须为越界索引选择一些非错误行为(类似于无效浮点运算导致 NaN 的方式)。当索引操作是数组索引更新(例如 index_add 或类似 scatter 的原语)时,越界索引处的更新将被跳过;当操作是数组索引检索(例如 NumPy 索引或类似 gather 的原语)时,索引将被限制在数组边界内,因为**必须**返回某些内容。例如,此索引操作将返回数组的最后一个值

jnp.arange(10)[11]
Array(9, dtype=int32)

如果您想对越界索引的行为进行更精细的控制,您可以使用 ndarray.at 的可选参数;例如

jnp.arange(10.0).at[11].get()
Array(9., dtype=float32)
jnp.arange(10.0).at[11].get(mode='fill', fill_value=jnp.nan)
Array(nan, dtype=float32)

请注意,由于索引检索的这种行为,诸如 jnp.nanargminjnp.nanargmax 等函数对于包含 NaNs 的切片返回 -1,而 Numpy 则会抛出错误。

另请注意,由于上述两种行为并非互逆,反向模式自动微分(将索引更新转换为索引检索,反之亦然)将无法保留越界索引的语义。因此,将 JAX 中的越界索引视为 未定义行为 的一种情况可能是个好主意。

🔪 非数组输入:NumPy 与 JAX#

NumPy 通常乐于接受 Python 列表或元组作为其 API 函数的输入

np.sum([1, 2, 3])
np.int64(6)

JAX 与此不同,通常会返回一个有用的错误

jnp.sum([1, 2, 3])
TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.

这是一个有意的设计选择,因为将列表或元组传递给跟踪函数可能会导致难以察觉的静默性能下降。

例如,考虑以下允许列表输入的 jnp.sum 宽松版本

def permissive_sum(x):
  return jnp.sum(jnp.array(x))

x = list(range(10))
permissive_sum(x)
Array(45, dtype=int32)

输出符合我们的预期,但这在底层隐藏了潜在的性能问题。在 JAX 的跟踪和 JIT 编译模型中,Python 列表或元组中的每个元素都被视为一个独立的 JAX 变量,并单独处理并推送到设备。这可以在上面 permissive_sum 函数的 jaxpr 中看到

make_jaxpr(permissive_sum)(x)
{ lambda ; a:i32[] b:i32[] c:i32[] d:i32[] e:i32[] f:i32[] g:i32[] h:i32[] i:i32[]
    j:i32[]. let
    k:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
    l:i32[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] k
    m:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    n:i32[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] m
    o:i32[] = convert_element_type[new_dtype=int32 weak_type=False] c
    p:i32[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] o
    q:i32[] = convert_element_type[new_dtype=int32 weak_type=False] d
    r:i32[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] q
    s:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
    t:i32[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] s
    u:i32[] = convert_element_type[new_dtype=int32 weak_type=False] f
    v:i32[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] u
    w:i32[] = convert_element_type[new_dtype=int32 weak_type=False] g
    x:i32[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] w
    y:i32[] = convert_element_type[new_dtype=int32 weak_type=False] h
    z:i32[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] y
    ba:i32[] = convert_element_type[new_dtype=int32 weak_type=False] i
    bb:i32[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] ba
    bc:i32[] = convert_element_type[new_dtype=int32 weak_type=False] j
    bd:i32[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] bc
    be:i32[10] = concatenate[dimension=0] l n p r t v x z bb bd
    bf:i32[] = reduce_sum[axes=(0,)] be
  in (bf,) }

列表的每个条目都被作为一个独立的输入处理,导致跟踪和编译开销随列表大小线性增长。为了防止此类意外,JAX 避免将列表和元组隐式转换为数组。

如果您想将元组或列表传递给 JAX 函数,您可以通过先将其显式转换为数组来完成此操作

jnp.sum(jnp.array(x))
Array(45, dtype=int32)

🔪 随机数#

JAX 的伪随机数生成与 Numpy 的在重要方面有所不同。有关快速入门,请参阅 伪随机数。有关更多详细信息,请参阅 伪随机数 教程。

🔪 控制流#

已移至 JIT 下的控制流和逻辑运算符

🔪 动态形状#

jax.jitjax.vmapjax.grad 等转换中使用的 JAX 代码要求所有输出数组和中间数组都具有静态形状:也就是说,形状不能依赖于其他数组中的值。

例如,如果您要实现自己的 jnp.nansum 版本,您可能会从这样的代码开始

def nansum(x):
  mask = ~jnp.isnan(x)  # boolean mask selecting non-nan values
  x_without_nans = x[mask]
  return x_without_nans.sum()

在 JIT 和其他转换之外,这按预期工作

x = jnp.array([1, 2, jnp.nan, 3, 4])
print(nansum(x))
10.0

如果您尝试将 jax.jit 或其他转换应用于此函数,它将出错

jax.jit(nansum)(x)
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got bool[5]

See https://jax.net.cn/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError

问题在于 x_without_nans 的大小取决于 x 中的值,换句话说,其大小是动态的。在 JAX 中,通常可以通过其他方式解决对动态大小数组的需求。例如,这里可以使用 jnp.where 的三参数形式将 NaN 值替换为零,从而在避免动态形状的同时计算出相同的结果

@jax.jit
def nansum_2(x):
  mask = ~jnp.isnan(x)  # boolean mask selecting non-nan values
  return jnp.where(mask, x, 0).sum()

print(nansum_2(x))
10.0

在出现动态形状数组的其他情况下,也可以采用类似的技巧。

🔪 NaNs#

调试 NaNs#

如果您想跟踪 NaNs 在函数或梯度中出现的位置,您可以通过以下方式打开 NaN 检查器

  • 设置 JAX_DEBUG_NANS=True 环境变量;

  • 在您的主文件顶部附近添加 jax.config.update("jax_debug_nans", True)

  • 在您的主文件中添加 jax.config.parse_flags_with_absl(),然后使用命令行标志(例如 --jax_debug_nans=True)设置选项;

这会导致计算在产生 NaN 时立即报错。启用此选项会将 NaN 检查添加到 XLA 生成的每个浮点类型值中。这意味着对于不在 @jit 下的每个原始操作,值都会被拉回到主机并作为 ndarray 进行检查。对于 @jit 下的代码,会检查每个 @jit 函数的输出,如果存在 NaN,它将以去优化(op-by-op)模式重新运行该函数,从而每次有效移除一层 @jit

可能会出现一些棘手的情况,例如仅在 @jit 下发生但在去优化模式下不产生的 NaN。在这种情况下,您会看到一条警告消息,但您的代码将继续执行。

如果 NaN 在梯度评估的反向传播中产生,当堆栈跟踪中上层几帧抛出异常时,您将处于 backward_pass 函数中,它本质上是一个简单的 jaxpr 解释器,以反向顺序遍历原始操作序列。在下面的示例中,我们使用命令行 env JAX_DEBUG_NANS=True ipython 启动了一个 ipython repl,然后运行了以下代码

In [1]: import jax.numpy as jnp

In [2]: jnp.divide(0., 0.)
---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
<ipython-input-2-f2e2c413b437> in <module>()
----> 1 jnp.divide(0., 0.)

.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
    343     return floor_divide(x1, x2)
    344   else:
--> 345     return true_divide(x1, x2)
    346
    347

.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
    332   x1, x2 = _promote_shapes(x1, x2)
    333   return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334                  lax.convert_element_type(x2, result_dtype))
    335
    336

.../jax/jax/lax.pyc in div(x, y)
    244 def div(x, y):
    245   r"""Elementwise division: :math:`x \over y`."""
--> 246   return div_p.bind(x, y)
    247
    248 def rem(x, y):

... stack trace ...

.../jax/jax/interpreters/xla.pyc in handle_result(device_buffer)
    103         py_val = device_buffer.to_py()
    104         if np.any(np.isnan(py_val)):
--> 105           raise FloatingPointError("invalid value")
    106         else:
    107           return Array(device_buffer, *result_shape)

FloatingPointError: invalid value

生成的 NaN 被捕获了。通过运行 %debug,我们可以获得事后调试器。这对于 @jit 下的函数也有效,如下面的示例所示。

In [4]: from jax import jit

In [5]: @jit
   ...: def f(x, y):
   ...:     a = x * y
   ...:     b = (x + y) / (x - y)
   ...:     c = a + 2
   ...:     return a + b * c
   ...:

In [6]: x = jnp.array([2., 0.])

In [7]: y = jnp.array([3., 0.])

In [8]: f(x, y)
Invalid value encountered in the output of a jit function. Calling the de-optimized version.
---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
<ipython-input-8-811b7ddb3300> in <module>()
----> 1 f(x, y)

 ... stack trace ...

<ipython-input-5-619b39acbaac> in f(x, y)
      2 def f(x, y):
      3     a = x * y
----> 4     b = (x + y) / (x - y)
      5     c = a + 2
      6     return a + b * c

.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
    343     return floor_divide(x1, x2)
    344   else:
--> 345     return true_divide(x1, x2)
    346
    347

.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
    332   x1, x2 = _promote_shapes(x1, x2)
    333   return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334                  lax.convert_element_type(x2, result_dtype))
    335
    336

.../jax/jax/lax.pyc in div(x, y)
    244 def div(x, y):
    245   r"""Elementwise division: :math:`x \over y`."""
--> 246   return div_p.bind(x, y)
    247
    248 def rem(x, y):

 ... stack trace ...

当这段代码在 @jit 函数的输出中看到 NaN 时,它会调用去优化代码,因此我们仍然可以获得清晰的堆栈跟踪。我们可以使用 %debug 运行事后调试器,检查所有值以找出错误。

⚠️ 如果您不是在调试,则不应开启 NaN 检查器,因为它会引入大量的设备-主机往返开销和性能下降!

⚠️ NaN 检查器不适用于 pmap。要调试 pmap 代码中的 NaN,一种方法是尝试用 vmap 替换 pmap

🔪 双精度(64位)#

目前,JAX 默认强制使用单精度数字,以减轻 Numpy API 将操作数积极提升为 double 的倾向。这对于许多机器学习应用程序是期望的行为,但它可能会让您感到惊讶!

x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
x.dtype
/tmp/ipykernel_1924/1258726447.py:1: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'>  is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
  x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
dtype('float32')

要使用双精度数字,您需要在**启动时**设置 jax_enable_x64 配置变量。

有几种方法可以做到这一点

  1. 您可以通过设置环境变量 JAX_ENABLE_X64=True 来启用 64 位模式。

  2. 您可以在启动时手动设置 jax_enable_x64 配置标志

    # again, this only works on startup!
    import jax
    jax.config.update("jax_enable_x64", True)
    
  3. 您可以使用 absl.app.run(main) 解析命令行标志

    import jax
    jax.config.config_with_absl()
    
  4. 如果您希望 JAX 为您运行 absl 解析,即您不想执行 absl.app.run(main),您可以改为使用

    import jax
    if __name__ == '__main__':
      # calls jax.config.config_with_absl() *and* runs absl parsing
      jax.config.parse_flags_with_absl()
    

请注意,#2-#4 适用于 JAX 的任何配置选项。

然后我们可以确认 x64 模式已启用,例如

import jax
import jax.numpy as jnp
from jax import random

jax.config.update("jax_enable_x64", True)
x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
x.dtype # --> dtype('float64')

注意事项#

⚠️ XLA 并非在所有后端都支持 64 位卷积!

🔪 与 NumPy 的其他差异#

虽然 jax.numpy 尽力复制 NumPy API 的行为,但在某些特殊情况下,行为确实存在差异。许多此类情况已在上面章节中详细讨论;这里我们列出其他几个已知的 API 差异之处。

  • 对于二元运算,JAX 的类型提升规则与 NumPy 使用的规则有所不同。有关更多详细信息,请参阅 类型提升语义

  • 当执行不安全的类型转换时(即目标 dtype 无法表示输入值的转换),JAX 的行为可能依赖于后端,并且通常可能与 NumPy 的行为有所不同。Numpy 允许通过 casting 参数控制这些情况下的结果(参见 np.ndarray.astype);JAX 不提供任何此类配置,而是直接继承 XLA:ConvertElementType 的行为。

    以下是一个 NumPy 和 JAX 之间结果不同的不安全类型转换示例

    >>> np.arange(254.0, 258.0).astype('uint8')
    array([254, 255,   0,   1], dtype=uint8)
    
    >>> jnp.arange(254.0, 258.0).astype('uint8')
    Array([254, 255, 255, 255], dtype=uint8)
    
    

    这种不匹配通常发生在将极端浮点值转换为整数类型或反之时。

  • 当对 非规范 浮点数进行操作时,JAX 操作在某些后端使用冲零语义(flush-to-zero semantics)。例如

    >>> import jax.numpy as jnp
    >>> subnormal = jnp.float32(1E-45)
    >>> subnormal  # subnormals are representable
    Array(1.e-45, dtype=float32)
    >>> subnormal + 0  # but are flushed to zero within operations
    Array(0., dtype=float32)
    
    

    非规范值的详细操作语义通常会因后端而异。

🔪 教程中涵盖的注意事项#

  • JIT 下的控制流和逻辑运算符 讨论了如何处理 jit 对 Python 控制流和逻辑运算符的使用施加的约束。

  • 有状态计算 提供了一些关于如何在 JAX 程序中正确处理状态的建议,鉴于 JAX 转换只能应用于纯函数。

结束。#

如果这里没有涵盖的内容让您痛哭流涕、咬牙切齿,请告知我们,我们将扩展这些入门性的建议