🔪 JAX - 锋芒毕露 🔪#
当在意大利乡村漫步时,人们会毫不犹豫地告诉你 JAX 具有 “una anima di pura programmazione funzionale”。
JAX 是一种用于表达和组合数值程序转换的语言。JAX 也能够为 CPU 或加速器 (GPU/TPU) 编译数值程序。JAX 在许多数值和科学程序中表现出色,但前提是它们必须以我们下面描述的某些约束条件编写。
import numpy as np
from jax import jit
from jax import lax
from jax import random
import jax
import jax.numpy as jnp
🔪 纯函数#
JAX 转换和编译旨在仅适用于功能纯粹的 Python 函数:所有输入数据都通过函数参数传递,所有结果都通过函数结果输出。如果使用相同的输入调用,纯函数将始终返回相同的结果。
以下是一些非功能纯粹的函数的示例,JAX 对它们的行为与 Python 解释器不同。请注意,这些行为并非由 JAX 系统保证;使用 JAX 的正确方法是仅在功能纯粹的 Python 函数上使用它。
def impure_print_side_effect(x):
print("Executing function") # This is a side-effect
return x
# The side-effects appear during the first run
print ("First call: ", jit(impure_print_side_effect)(4.))
# Subsequent runs with parameters of same type and shape may not show the side-effect
# This is because JAX now invokes a cached compilation of the function
print ("Second call: ", jit(impure_print_side_effect)(5.))
# JAX re-runs the Python function when the type or shape of the argument changes
print ("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([5.])))
Executing function
First call: 4.0
Second call: 5.0
Executing function
Third call, different type: [5.]
g = 0.
def impure_uses_globals(x):
return x + g
# JAX captures the value of the global during the first run
print ("First call: ", jit(impure_uses_globals)(4.))
g = 10. # Update the global
# Subsequent runs may silently use the cached value of the globals
print ("Second call: ", jit(impure_uses_globals)(5.))
# JAX re-runs the Python function when the type or shape of the argument changes
# This will end up reading the latest value of the global
print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))
First call: 4.0
Second call: 5.0
Third call, different type: [14.]
g = 0.
def impure_saves_global(x):
global g
g = x
return x
# JAX runs once the transformed function with special Traced values for arguments
print ("First call: ", jit(impure_saves_global)(4.))
print ("Saved global: ", g) # Saved global has an internal JAX value
First call: 4.0
Saved global: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>
即使 Python 函数实际上在内部使用了有状态对象,它也可以是功能纯粹的,只要它不读取或写入外部状态
def pure_uses_internal_state(x):
state = dict(even=0, odd=0)
for i in range(10):
state['even' if i % 2 == 0 else 'odd'] += x
return state['even'] + state['odd']
print(jit(pure_uses_internal_state)(5.))
50.0
不建议在任何您想要 jit
的 JAX 函数或任何控制流原语中使用迭代器。原因是迭代器是一个 python 对象,它引入了状态来检索下一个元素。因此,它与 JAX 的函数式编程模型不兼容。在下面的代码中,有一些尝试将迭代器与 JAX 一起使用的不正确示例。它们中的大多数返回错误,但有些给出意外的结果。
import jax.numpy as jnp
from jax import make_jaxpr
# lax.fori_loop
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0
# lax.scan
def func11(arr, extra):
ones = jnp.ones(arr.shape)
def body(carry, aelems):
ae1, ae2 = aelems
return (carry + ae1 * ae2 + extra, carry)
return lax.scan(body, 0., (arr, ones))
make_jaxpr(func11)(jnp.arange(16), 5.)
# make_jaxpr(func11)(iter(range(16)), 5.) # throws error
# lax.cond
array_operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, array_operand)
iter_operand = iter(range(10))
# lax.cond(True, lambda x: next(x)+1, lambda x: next(x)-1, iter_operand) # throws error
45
0
🔪 原位更新#
在 Numpy 中,您习惯于这样做
numpy_array = np.zeros((3,3), dtype=np.float32)
print("original array:")
print(numpy_array)
# In place, mutating update
numpy_array[1, :] = 1.0
print("updated array:")
print(numpy_array)
original array:
[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]
updated array:
[[0. 0. 0.]
[1. 1. 1.]
[0. 0. 0.]]
但是,如果我们尝试在 jax.Array
上进行原位索引更新,我们会收到一个错误!(☉_☉)
%xmode Minimal
Exception reporting mode: Minimal
jax_array = jnp.zeros((3,3), dtype=jnp.float32)
# In place update of JAX's array will yield an error!
jax_array[1, :] = 1.0
TypeError: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://jax.net.cn/en/latest/_autosummary/jax.numpy.ndarray.at.html
如果我们尝试进行 __iadd__
风格的原位更新,我们会得到与 NumPy 不同的行为!(☉_☉) (☉_☉)
jax_array = jnp.array([10, 20])
jax_array_new = jax_array
jax_array_new += 10
print(jax_array_new) # `jax_array_new` is rebound to a new value [20, 30], but...
print(jax_array) # the original value is unodified as [10, 20] !
numpy_array = np.array([10, 20])
numpy_array_new = numpy_array
numpy_array_new += 10
print(numpy_array_new) # `numpy_array_new is numpy_array`, and it was updated
print(numpy_array) # in-place, so both are [20, 30] !
[20 30]
[10 20]
[20 30]
[20 30]
这是因为 NumPy 定义 __iadd__
来执行原位突变。相比之下,jax.Array
没有定义 __iadd__
,因此 Python 将 jax_array_new += 10
视为 jax_array_new = jax_array_new + 10
的语法糖,重新绑定变量而不突变任何数组。
允许原位突变变量会使程序分析和转换变得困难。JAX 要求程序是纯函数。
相反,JAX 提供了一种使用 .at
属性在 JAX 数组上的函数式数组更新。
️⚠️ 在 jit
代码和 lax.while_loop
或 lax.fori_loop
内部,切片的大小不能是参数值的函数,而只能是参数形状的函数 – 切片起始索引没有这样的限制。有关此限制的更多信息,请参阅下面的控制流部分。
数组更新:x.at[idx].set(y)
#
例如,上面的更新可以写成
jax_array = jnp.zeros((3,3), dtype=jnp.float32)
updated_array = jax_array.at[1, :].set(1.0)
print("updated array:\n", updated_array)
updated array:
[[0. 0. 0.]
[1. 1. 1.]
[0. 0. 0.]]
与 NumPy 版本不同,JAX 的数组更新函数是异地操作的。也就是说,更新后的数组作为新数组返回,原始数组不会被更新修改。
print("original array unchanged:\n", jax_array)
original array unchanged:
[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]
但是,在 jit 编译的代码内部,如果 x.at[idx].set(y)
的输入值 x
没有被重用,编译器将优化数组更新以使其原位发生。
使用其他操作的数组更新#
索引数组更新不仅限于简单地覆盖值。例如,我们可以执行索引加法,如下所示
print("original array:")
jax_array = jnp.ones((5, 6))
print(jax_array)
new_jax_array = jax_array.at[::2, 3:].add(7.)
print("new array post-addition:")
print(new_jax_array)
original array:
[[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]]
new array post-addition:
[[1. 1. 1. 8. 8. 8.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 8. 8. 8.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 8. 8. 8.]]
有关索引数组更新的更多详细信息,请参阅 .at
属性的文档。
🔪 越界索引#
在 Numpy 中,您习惯于在索引数组超出其边界时抛出错误,例如这样
np.arange(10)[11]
IndexError: index 11 is out of bounds for axis 0 with size 10
但是,从加速器上运行的代码引发错误可能很困难甚至不可能。因此,JAX 必须为越界索引选择一些非错误行为(类似于无效的浮点算术如何导致 NaN
)。当索引操作是数组索引更新(例如 index_add
或类似 scatter
的原语)时,将跳过越界索引处的更新;当操作是数组索引检索(例如 NumPy 索引或类似 gather
的原语)时,索引将被钳制到数组的边界,因为必须返回某些内容。例如,将从此索引操作返回数组的最后一个值
jnp.arange(10)[11]
Array(9, dtype=int32)
如果您想要更细粒度地控制越界索引的行为,可以使用 ndarray.at
的可选参数;例如
jnp.arange(10.0).at[11].get()
Array(9., dtype=float32)
jnp.arange(10.0).at[11].get(mode='fill', fill_value=jnp.nan)
Array(nan, dtype=float32)
请注意,由于索引检索的这种行为,像 jnp.nanargmin
和 jnp.nanargmax
这样的函数对于由 NaN 组成的切片返回 -1,而 Numpy 会抛出错误。
另请注意,由于上面描述的两种行为不是彼此的逆运算,因此反向模式自动微分(将索引更新转换为索引检索,反之亦然)将不会保留越界索引的语义。因此,将 JAX 中的越界索引视为 未定义行为 的情况可能是个好主意。
🔪 非数组输入:NumPy 与 JAX#
NumPy 通常很乐意接受 Python 列表或元组作为其 API 函数的输入
np.sum([1, 2, 3])
np.int64(6)
JAX 偏离了这一点,通常返回有用的错误
jnp.sum([1, 2, 3])
TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.
这是一个故意的设计选择,因为将列表或元组传递给跟踪函数可能会导致难以察觉的静默性能下降。
例如,考虑以下允许列表输入的 jnp.sum
的宽松版本
def permissive_sum(x):
return jnp.sum(jnp.array(x))
x = list(range(10))
permissive_sum(x)
Array(45, dtype=int32)
输出是我们所期望的,但这隐藏了底层潜在的性能问题。在 JAX 的跟踪和 JIT 编译模型中,Python 列表或元组中的每个元素都被视为单独的 JAX 变量,并单独处理并推送到设备。这可以在上面 permissive_sum
函数的 jaxpr 中看到
make_jaxpr(permissive_sum)(x)
{ lambda ; a:i32[] b:i32[] c:i32[] d:i32[] e:i32[] f:i32[] g:i32[] h:i32[] i:i32[]
j:i32[]. let
k:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
l:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
m:i32[] = convert_element_type[new_dtype=int32 weak_type=False] c
n:i32[] = convert_element_type[new_dtype=int32 weak_type=False] d
o:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
p:i32[] = convert_element_type[new_dtype=int32 weak_type=False] f
q:i32[] = convert_element_type[new_dtype=int32 weak_type=False] g
r:i32[] = convert_element_type[new_dtype=int32 weak_type=False] h
s:i32[] = convert_element_type[new_dtype=int32 weak_type=False] i
t:i32[] = convert_element_type[new_dtype=int32 weak_type=False] j
u:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] k
v:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] l
w:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] m
x:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] n
y:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] o
z:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] p
ba:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] q
bb:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] r
bc:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] s
bd:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] t
be:i32[10] = concatenate[dimension=0] u v w x y z ba bb bc bd
bf:i32[] = reduce_sum[axes=(0,)] be
in (bf,) }
列表的每个条目都作为单独的输入处理,导致跟踪和编译开销随着列表大小线性增长。为了防止出现此类意外情况,JAX 避免了将列表和元组隐式转换为数组。
如果您想将元组或列表传递给 JAX 函数,您可以先显式将其转换为数组
jnp.sum(jnp.array(x))
Array(45, dtype=int32)
🔪 随机数#
JAX 的伪随机数生成在重要方面与 Numpy 的不同。有关快速入门指南,请参阅 伪随机数。有关更多详细信息,请参阅 伪随机数 教程。
🔪 控制流#
已移至 使用 JIT 的控制流和逻辑运算符。
🔪 动态形状#
在 jax.jit
、jax.vmap
、jax.grad
等转换中使用的 JAX 代码要求所有输出数组和中间数组都具有静态形状:即,形状不能依赖于其他数组中的值。
例如,如果您要实现自己的 jnp.nansum
版本,您可能会从类似这样的代码开始
def nansum(x):
mask = ~jnp.isnan(x) # boolean mask selecting non-nan values
x_without_nans = x[mask]
return x_without_nans.sum()
在 JIT 和其他转换之外,这可以按预期工作
x = jnp.array([1, 2, jnp.nan, 3, 4])
print(nansum(x))
10.0
如果您尝试将 jax.jit
或其他转换应用于此函数,它将出错
jax.jit(nansum)(x)
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[5])
See https://jax.net.cn/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError
问题在于 x_without_nans
的大小取决于 x
中的值,这又是说它的大小是动态的另一种方式。通常在 JAX 中,可以通过其他方式解决对动态大小数组的需求。例如,这里可以使用 jnp.where
的三参数形式将 NaN 值替换为零,从而在避免动态形状的同时计算出相同的结果
@jax.jit
def nansum_2(x):
mask = ~jnp.isnan(x) # boolean mask selecting non-nan values
return jnp.where(mask, x, 0).sum()
print(nansum_2(x))
10.0
在其他出现动态形状数组的情况下,可以使用类似的技巧。
🔪 NaNs#
调试 NaNs#
如果您想跟踪 NaN 在您的函数或梯度中出现的位置,您可以打开 NaN 检查器,方法是
设置
JAX_DEBUG_NANS=True
环境变量;在您的主文件顶部附近添加
jax.config.update("jax_debug_nans", True)
;将
jax.config.parse_flags_with_absl()
添加到您的主文件,然后使用命令行标志(如--jax_debug_nans=True
)设置选项;
这将导致计算在产生 NaN 时立即出错。打开此选项会在 XLA 生成的每个浮点类型值中添加 NaN 检查。这意味着值被拉回到主机并作为 ndarray 进行检查,以用于 @jit
下的每个原始操作。对于 @jit
下的代码,将检查每个 @jit
函数的输出,如果存在 NaN,它将以去优化操作模式重新运行该函数,从而有效地一次删除一个级别的 @jit
。
可能会出现棘手的情况,例如 NaN 仅在 @jit
下发生,但在去优化模式下不会产生。在这种情况下,您会看到打印输出警告消息,但您的代码将继续执行。
如果 NaN 是在梯度评估的反向传递中产生的,当堆栈跟踪中向上几帧引发异常时,您将位于 backward_pass 函数中,这本质上是一个简单的 jaxpr 解释器,它反向遍历原始操作序列。在下面的示例中,我们使用命令行 env JAX_DEBUG_NANS=True ipython
启动了 ipython repl,然后运行了以下代码
In [1]: import jax.numpy as jnp
In [2]: jnp.divide(0., 0.)
---------------------------------------------------------------------------
FloatingPointError Traceback (most recent call last)
<ipython-input-2-f2e2c413b437> in <module>()
----> 1 jnp.divide(0., 0.)
.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
343 return floor_divide(x1, x2)
344 else:
--> 345 return true_divide(x1, x2)
346
347
.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
332 x1, x2 = _promote_shapes(x1, x2)
333 return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334 lax.convert_element_type(x2, result_dtype))
335
336
.../jax/jax/lax.pyc in div(x, y)
244 def div(x, y):
245 r"""Elementwise division: :math:`x \over y`."""
--> 246 return div_p.bind(x, y)
247
248 def rem(x, y):
... stack trace ...
.../jax/jax/interpreters/xla.pyc in handle_result(device_buffer)
103 py_val = device_buffer.to_py()
104 if np.any(np.isnan(py_val)):
--> 105 raise FloatingPointError("invalid value")
106 else:
107 return Array(device_buffer, *result_shape)
FloatingPointError: invalid value
捕获了生成的 NaN。通过运行 %debug
,我们可以获得事后调试器。这对于 @jit
下的函数也有效,如下例所示。
In [4]: from jax import jit
In [5]: @jit
...: def f(x, y):
...: a = x * y
...: b = (x + y) / (x - y)
...: c = a + 2
...: return a + b * c
...:
In [6]: x = jnp.array([2., 0.])
In [7]: y = jnp.array([3., 0.])
In [8]: f(x, y)
Invalid value encountered in the output of a jit function. Calling the de-optimized version.
---------------------------------------------------------------------------
FloatingPointError Traceback (most recent call last)
<ipython-input-8-811b7ddb3300> in <module>()
----> 1 f(x, y)
... stack trace ...
<ipython-input-5-619b39acbaac> in f(x, y)
2 def f(x, y):
3 a = x * y
----> 4 b = (x + y) / (x - y)
5 c = a + 2
6 return a + b * c
.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
343 return floor_divide(x1, x2)
344 else:
--> 345 return true_divide(x1, x2)
346
347
.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
332 x1, x2 = _promote_shapes(x1, x2)
333 return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334 lax.convert_element_type(x2, result_dtype))
335
336
.../jax/jax/lax.pyc in div(x, y)
244 def div(x, y):
245 r"""Elementwise division: :math:`x \over y`."""
--> 246 return div_p.bind(x, y)
247
248 def rem(x, y):
... stack trace ...
当此代码在 @jit
函数的输出中看到 NaN 时,它会调用去优化代码,因此我们仍然获得清晰的堆栈跟踪。我们可以使用 %debug
运行事后调试器,以检查所有值以找出错误。
⚠️ 如果您不进行调试,则不应打开 NaN 检查器,因为它可能会引入大量的设备-主机往返和性能回归!
⚠️ NaN 检查器不适用于 pmap
。要在 pmap
代码中调试 NaN,可以尝试将 pmap
替换为 vmap
。
🔪 双精度(64 位)#
目前,JAX 默认强制使用单精度数字,以缓解 Numpy API 积极将操作数提升为 double
的趋势。这对于许多机器学习应用程序来说是期望的行为,但它可能会让您感到惊讶!
x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
x.dtype
/tmp/ipykernel_1221/1258726447.py:1: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
dtype('float32')
要使用双精度数字,您需要在启动时设置 jax_enable_x64
配置变量。
有几种方法可以做到这一点
您可以通过设置环境变量
JAX_ENABLE_X64=True
来启用 64 位模式。您可以在启动时手动设置
jax_enable_x64
配置标志# again, this only works on startup! import jax jax.config.update("jax_enable_x64", True)
您可以使用
absl.app.run(main)
解析命令行标志import jax jax.config.config_with_absl()
如果您希望 JAX 为您运行 absl 解析,即您不想执行
absl.app.run(main)
,则可以改用import jax if __name__ == '__main__': # calls jax.config.config_with_absl() *and* runs absl parsing jax.config.parse_flags_with_absl()
请注意,#2-#4 适用于 JAX 的任何配置选项。
然后我们可以确认 x64
模式已启用,例如
import jax
import jax.numpy as jnp
from jax import random
jax.config.update("jax_enable_x64", True)
x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
x.dtype # --> dtype('float64')
注意事项#
⚠️ XLA 不支持所有后端上的 64 位卷积!
🔪 与 NumPy 的其他差异#
虽然 jax.numpy
尽一切努力复制 numpy API 的行为,但确实存在行为不同的极端情况。上面各节详细讨论了许多此类情况;这里我们列出其他几个已知 API 不同的地方。
对于二元运算,JAX 的类型提升规则与 NumPy 使用的规则略有不同。有关更多详细信息,请参阅 类型提升语义。
当执行不安全的类型转换(即目标 dtype 无法表示输入值的转换)时,JAX 的行为可能取决于后端,并且通常可能与 NumPy 的行为不同。Numpy 允许通过
casting
参数控制这些场景中的结果(请参阅np.ndarray.astype
);JAX 不提供任何此类配置,而是直接继承 XLA:ConvertElementType 的行为。以下是一个不安全转换的示例,NumPy 和 JAX 之间的结果不同
>>> np.arange(254.0, 258.0).astype('uint8') array([254, 255, 0, 1], dtype=uint8) >>> jnp.arange(254.0, 258.0).astype('uint8') Array([254, 255, 255, 255], dtype=uint8)
当从浮点类型转换为整数类型或反之亦然时,通常会出现这种不匹配。
当对 次正规 浮点数进行操作时,JAX 操作在某些后端使用刷新为零的语义。例如
>>> import jax.numpy as jnp >>> subnormal = jnp.float32(1E-45) >>> subnormal # subnormals are representable Array(1.e-45, dtype=float32) >>> subnormal + 0 # but are flushed to zero within operations Array(0., dtype=float32)
次正规值的详细操作语义通常会因后端而异。
完。#
如果这里没有涵盖让您痛哭流涕的内容,请告诉我们,我们将扩展这些入门提示!