使用 JIT 的控制流和逻辑运算符#
在急切执行(即 jit
之外)时,JAX 代码的控制流和逻辑运算符与 Numpy 代码的工作方式相同。但在 jit
中使用控制流和逻辑运算符则更为复杂。
简而言之,Python 控制流和逻辑运算符在 JIT 编译时进行评估,使得编译后的函数代表控制流图中的单一路径(逻辑运算符通过短路影响路径)。如果路径依赖于输入的值,则函数(默认情况下)无法进行 JIT 编译。路径可能依赖于输入的形状或数据类型,并且每次使用新形状或数据类型的输入调用函数时,函数都会被重新编译。
from jax import grad, jit
import jax.numpy as jnp
例如,这行得通
@jit
def f(x):
for i in range(3):
x = 2 * x
return x
print(f(3))
24
这也能行
@jit
def g(x):
y = 0.
for i in range(x.shape[0]):
y = y + x[i]
return y
print(g(jnp.array([1., 2., 3.])))
6.0
但这个不行,至少默认情况下不行
@jit
def f(x):
if x < 3:
return 3. * x ** 2
else:
return -4 * x
# This will fail!
f(2)
---------------------------------------------------------------------------
TracerBoolConversionError Traceback (most recent call last)
Cell In[4], line 9
6 return -4 * x
8 # This will fail!
----> 9 f(2)
[... skipping hidden 13 frame]
Cell In[4], line 3, in f(x)
1 @jit
2 def f(x):
----> 3 if x < 3:
4 return 3. * x ** 2
5 else:
[... skipping hidden 1 frame]
File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:1721, in concretization_function_error.<locals>.error(self, arg)
1720 def error(self, arg):
-> 1721 raise TracerBoolConversionError(arg)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /tmp/ipykernel_1553/3402096563.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError
这个也不行
@jit
def g(x):
return (x > 0) and (x < 3)
# This will fail!
g(2)
---------------------------------------------------------------------------
TracerBoolConversionError Traceback (most recent call last)
Cell In[5], line 6
3 return (x > 0) and (x < 3)
5 # This will fail!
----> 6 g(2)
[... skipping hidden 13 frame]
Cell In[5], line 3, in g(x)
1 @jit
2 def g(x):
----> 3 return (x > 0) and (x < 3)
[... skipping hidden 1 frame]
File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:1721, in concretization_function_error.<locals>.error(self, arg)
1720 def error(self, arg):
-> 1721 raise TracerBoolConversionError(arg)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function g at /tmp/ipykernel_1553/543860509.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError
这是为什么?!
当我们对函数进行 jit
编译时,我们通常希望编译一个适用于多种不同参数值的版本,以便我们可以缓存和重用编译后的代码。这样就不必在每次函数求值时都重新编译。
例如,如果我们在数组 jnp.array([1., 2., 3.], jnp.float32)
上评估一个 @jit
函数,我们可能希望编译一份代码,可以重用于在 jnp.array([4., 5., 6.], jnp.float32)
上评估该函数,以节省编译时间。
为了获得一份适用于多种不同参数值的 Python 代码视图,JAX 使用 ShapedArray
抽象作为输入进行追踪,其中每个抽象值代表具有固定形状和数据类型的所有数组值的集合。例如,如果我们使用抽象值 ShapedArray((3,), jnp.float32)
进行追踪,我们会得到一个函数视图,该视图可以重用于相应数组集合中的任何具体值。这意味着我们可以节省编译时间。
但这有一个权衡:如果我们在一个未绑定到特定具体值的 ShapedArray((), jnp.float32)
上追踪 Python 函数,当我们遇到 if x < 3
这样的代码行时,表达式 x < 3
会被评估为一个抽象的 ShapedArray((), jnp.bool_)
,它代表集合 {True, False}
。当 Python 尝试将其强制转换为具体的 True
或 False
时,就会出现错误:我们不知道该走哪个分支,也无法继续追踪!权衡在于,通过更高层次的抽象,我们能获得更通用的 Python 代码视图(从而节省重新编译),但要完成追踪,Python 代码需要更多约束。
好消息是你可以自己控制这种权衡。通过让 jit
在更精细的抽象值上进行追踪,你可以放宽可追踪性约束。例如,使用 jit
的 static_argnames
(或 static_argnums
)参数,我们可以指定在某些参数的具体值上进行追踪。下面是那个示例函数:
def f(x):
if x < 3:
return 3. * x ** 2
else:
return -4 * x
f = jit(f, static_argnames='x')
print(f(2.))
12.0
这是另一个例子,这次涉及一个循环
def f(x, n):
y = 0.
for i in range(n):
y = y + x[i]
return y
f = jit(f, static_argnames='n')
f(jnp.array([2., 3., 4.]), 2)
Array(5., dtype=float32)
实际上,循环是静态展开的。JAX 也可以在更高层次的抽象上进行追踪,例如 Unshaped
,但这目前不是任何转换的默认设置。
️⚠️ 形状依赖于参数**值**的函数
这些控制流问题也以更微妙的方式出现:我们希望 **jit** 的数值函数不能根据参数**值**来特殊化内部数组的形状(根据参数**形状**进行特殊化是允许的)。举一个简单的例子,我们来创建一个输出恰好依赖于输入变量 length
的函数。
def example_fun(length, val):
return jnp.ones((length,)) * val
# un-jit'd works fine
print(example_fun(5, 4))
[4. 4. 4. 4. 4.]
bad_example_jit = jit(example_fun)
# this will fail:
bad_example_jit(10, 4)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[9], line 3
1 bad_example_jit = jit(example_fun)
2 # this will fail:
----> 3 bad_example_jit(10, 4)
[... skipping hidden 13 frame]
Cell In[8], line 2, in example_fun(length, val)
1 def example_fun(length, val):
----> 2 return jnp.ones((length,)) * val
File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/numpy/array_creation.py:138, in ones(shape, dtype, device, out_sharding)
136 raise TypeError("expected sequence object with len >= 0 or a single integer")
137 if (m := _check_forgot_shape_tuple("ones", shape, dtype)): raise TypeError(m)
--> 138 shape = canonicalize_shape(shape)
139 dtypes.check_user_dtype_supported(dtype, "ones")
140 sharding = util.choose_device_or_out_sharding(
141 device, out_sharding, 'jnp.ones')
File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/numpy/array_creation.py:46, in canonicalize_shape(shape, context)
44 return core.canonicalize_shape((shape,), context)
45 else:
---> 46 return core.canonicalize_shape(shape, context)
File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:1924, in canonicalize_shape(shape, context)
1922 except TypeError:
1923 pass
-> 1924 raise _invalid_shape_error(shape, context)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (JitTracer<~int32[]>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function example_fun at /tmp/ipykernel_1553/1210496444.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument length.
# static_argnames tells JAX to recompile on changes at these argument positions:
good_example_jit = jit(example_fun, static_argnames='length')
# first compile
print(good_example_jit(10, 4))
# recompiles
print(good_example_jit(5, 4))
[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]
[4. 4. 4. 4. 4.]
static_argnames
会很方便,如果示例中的 length
很少改变,但如果它经常改变,那将是灾难性的!
最后,如果你的函数有全局副作用,JAX 的追踪器可能会导致奇怪的事情发生。一个常见的陷阱是尝试在 jit
编译的函数内部打印数组。
@jit
def f(x):
print(x)
y = 2 * x
print(y)
return y
f(2)
JitTracer<~int32[]>
JitTracer<~int32[]>
Array(4, dtype=int32, weak_type=True)
结构化控制流原语#
JAX 中有更多的控制流选项。假设你想避免重新编译,但仍希望使用可追踪且能避免展开大型循环的控制流。那么你可以使用以下 4 个结构化控制流原语:
lax.cond
可微分lax.while_loop
**前向模式可微分**lax.fori_loop
通常**前向模式可微分**;如果端点是静态的,则**前向和反向模式均可微分**。lax.scan
可微分
cond#
Python 等价物
def cond(pred, true_fun, false_fun, operand):
if pred:
return true_fun(operand)
else:
return false_fun(operand)
from jax import lax
operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, operand)
# --> array([1.], dtype=float32)
lax.cond(False, lambda x: x+1, lambda x: x-1, operand)
# --> array([-1.], dtype=float32)
Array([-1.], dtype=float32)
jax.lax
提供了另外两个允许基于动态谓词进行分支的函数:
lax.select
类似于lax.cond
的批处理版本,其中选择项表示为预先计算的数组,而非函数。lax.switch
类似于lax.cond
,但允许在任意数量的可调用选项之间切换。
此外,jax.numpy
提供了几个 NumPy 风格的接口来调用这些函数:
带有三个参数的
jnp.where
是lax.select
的 NumPy 风格封装。jnp.piecewise
是lax.switch
的 NumPy 风格封装,但它根据布尔条件列表而不是单个标量索引进行切换。jnp.select
的 API 类似于jnp.piecewise
,但其选择项是预先计算的数组,而非函数。它是通过多次调用lax.select
实现的。
while_loop#
Python 等价物
def while_loop(cond_fun, body_fun, init_val):
val = init_val
while cond_fun(val):
val = body_fun(val)
return val
init_val = 0
cond_fun = lambda x: x < 10
body_fun = lambda x: x+1
lax.while_loop(cond_fun, body_fun, init_val)
# --> array(10, dtype=int32)
Array(10, dtype=int32, weak_type=True)
fori_loop#
Python 等价物
def fori_loop(start, stop, body_fun, init_val):
val = init_val
for i in range(start, stop):
val = body_fun(i, val)
return val
init_val = 0
start = 0
stop = 10
body_fun = lambda i,x: x+i
lax.fori_loop(start, stop, body_fun, init_val)
# --> array(45, dtype=int32)
Array(45, dtype=int32, weak_type=True)
概述#
\(\ast\) = 循环条件**不**依赖于参数**值**——展开循环
逻辑运算符#
jax.numpy
提供了 logical_and
、logical_or
和 logical_not
,它们对数组进行元素级操作,并且可以在 jit
下评估而无需重新编译。与 NumPy 中的对应项一样,这些二元运算符不会短路。位运算符(&
、|
、~
)也可以与 jit
一起使用。
例如,考虑一个检查输入是否为正偶数的函数。当输入为标量时,纯 Python 版本和 JAX 版本给出相同的结果。
def python_check_positive_even(x):
is_even = x % 2 == 0
# `and` short-circults, so when `is_even` is `False`, `x > 0` is not evaluated.
return is_even and (x > 0)
@jit
def jax_check_positive_even(x):
is_even = x % 2 == 0
# `logical_and` does not short circuit, so `x > 0` is always evaluated.
return jnp.logical_and(is_even, x > 0)
print(python_check_positive_even(24))
print(jax_check_positive_even(24))
True
True
当带有 logical_and
的 JAX 版本应用于数组时,它返回元素级的值。
x = jnp.array([-1, 2, 5])
print(jax_check_positive_even(x))
[False True False]
当 Python 逻辑运算符应用于包含多个元素的 JAX 数组时,即使没有 jit
,也会报错。这复刻了 NumPy 的行为。
print(python_check_positive_even(x))
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[17], line 1
----> 1 print(python_check_positive_even(x))
Cell In[15], line 4, in python_check_positive_even(x)
2 is_even = x % 2 == 0
3 # `and` short-circults, so when `is_even` is `False`, `x > 0` is not evaluated.
----> 4 return is_even and (x > 0)
File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/array.py:305, in ArrayImpl.__bool__(self)
304 def __bool__(self):
--> 305 core.check_bool_conversion(self)
306 return bool(self._value)
File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:821, in check_bool_conversion(arr)
818 raise ValueError("The truth value of an empty array is ambiguous. Use"
819 " `array.size > 0` to check that an array is not empty.")
820 if arr.size > 1:
--> 821 raise ValueError("The truth value of an array with more than one element"
822 " is ambiguous. Use a.any() or a.all()")
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
Python 控制流 + 自动微分#
请记住,上述关于控制流和逻辑运算符的约束仅在与 jit
结合使用时才相关。如果你只是想对你的 Python 函数应用 grad
,而不使用 jit
,那么你可以毫无问题地使用常规 Python 控制流结构,就像使用 Autograd(或 PyTorch 或 TF Eager)一样。
def f(x):
if x < 3:
return 3. * x ** 2
else:
return -4 * x
print(grad(f)(2.)) # ok!
print(grad(f)(4.)) # ok!
12.0
-4.0