使用 JIT 的控制流和逻辑运算符#

在急切执行(即 jit 之外)时,JAX 代码的控制流和逻辑运算符与 Numpy 代码的工作方式相同。但在 jit 中使用控制流和逻辑运算符则更为复杂。

简而言之,Python 控制流和逻辑运算符在 JIT 编译时进行评估,使得编译后的函数代表控制流图中的单一路径(逻辑运算符通过短路影响路径)。如果路径依赖于输入的值,则函数(默认情况下)无法进行 JIT 编译。路径可能依赖于输入的形状或数据类型,并且每次使用新形状或数据类型的输入调用函数时,函数都会被重新编译。

from jax import grad, jit
import jax.numpy as jnp

例如,这行得通

@jit
def f(x):
  for i in range(3):
    x = 2 * x
  return x

print(f(3))
24

这也能行

@jit
def g(x):
  y = 0.
  for i in range(x.shape[0]):
    y = y + x[i]
  return y

print(g(jnp.array([1., 2., 3.])))
6.0

但这个不行,至少默认情况下不行

@jit
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

# This will fail!
f(2)
---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent call last)
Cell In[4], line 9
      6     return -4 * x
      8 # This will fail!
----> 9 f(2)

    [... skipping hidden 13 frame]

Cell In[4], line 3, in f(x)
      1 @jit
      2 def f(x):
----> 3   if x < 3:
      4     return 3. * x ** 2
      5   else:

    [... skipping hidden 1 frame]

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:1721, in concretization_function_error.<locals>.error(self, arg)
   1720 def error(self, arg):
-> 1721   raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /tmp/ipykernel_1553/3402096563.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError

这个也不行

@jit
def g(x):
  return (x > 0) and (x < 3)

# This will fail!
g(2)
---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent call last)
Cell In[5], line 6
      3   return (x > 0) and (x < 3)
      5 # This will fail!
----> 6 g(2)

    [... skipping hidden 13 frame]

Cell In[5], line 3, in g(x)
      1 @jit
      2 def g(x):
----> 3   return (x > 0) and (x < 3)

    [... skipping hidden 1 frame]

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:1721, in concretization_function_error.<locals>.error(self, arg)
   1720 def error(self, arg):
-> 1721   raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function g at /tmp/ipykernel_1553/543860509.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError

这是为什么?!

当我们对函数进行 jit 编译时,我们通常希望编译一个适用于多种不同参数值的版本,以便我们可以缓存和重用编译后的代码。这样就不必在每次函数求值时都重新编译。

例如,如果我们在数组 jnp.array([1., 2., 3.], jnp.float32) 上评估一个 @jit 函数,我们可能希望编译一份代码,可以重用于在 jnp.array([4., 5., 6.], jnp.float32) 上评估该函数,以节省编译时间。

为了获得一份适用于多种不同参数值的 Python 代码视图,JAX 使用 ShapedArray 抽象作为输入进行追踪,其中每个抽象值代表具有固定形状和数据类型的所有数组值的集合。例如,如果我们使用抽象值 ShapedArray((3,), jnp.float32) 进行追踪,我们会得到一个函数视图,该视图可以重用于相应数组集合中的任何具体值。这意味着我们可以节省编译时间。

但这有一个权衡:如果我们在一个未绑定到特定具体值的 ShapedArray((), jnp.float32) 上追踪 Python 函数,当我们遇到 if x < 3 这样的代码行时,表达式 x < 3 会被评估为一个抽象的 ShapedArray((), jnp.bool_),它代表集合 {True, False}。当 Python 尝试将其强制转换为具体的 TrueFalse 时,就会出现错误:我们不知道该走哪个分支,也无法继续追踪!权衡在于,通过更高层次的抽象,我们能获得更通用的 Python 代码视图(从而节省重新编译),但要完成追踪,Python 代码需要更多约束。

好消息是你可以自己控制这种权衡。通过让 jit 在更精细的抽象值上进行追踪,你可以放宽可追踪性约束。例如,使用 jitstatic_argnames(或 static_argnums)参数,我们可以指定在某些参数的具体值上进行追踪。下面是那个示例函数:

def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

f = jit(f, static_argnames='x')

print(f(2.))
12.0

这是另一个例子,这次涉及一个循环

def f(x, n):
  y = 0.
  for i in range(n):
    y = y + x[i]
  return y

f = jit(f, static_argnames='n')

f(jnp.array([2., 3., 4.]), 2)
Array(5., dtype=float32)

实际上,循环是静态展开的。JAX 也可以在更高层次的抽象上进行追踪,例如 Unshaped,但这目前不是任何转换的默认设置。

️⚠️ 形状依赖于参数**值**的函数

这些控制流问题也以更微妙的方式出现:我们希望 **jit** 的数值函数不能根据参数**值**来特殊化内部数组的形状(根据参数**形状**进行特殊化是允许的)。举一个简单的例子,我们来创建一个输出恰好依赖于输入变量 length 的函数。

def example_fun(length, val):
  return jnp.ones((length,)) * val
# un-jit'd works fine
print(example_fun(5, 4))
[4. 4. 4. 4. 4.]
bad_example_jit = jit(example_fun)
# this will fail:
bad_example_jit(10, 4)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[9], line 3
      1 bad_example_jit = jit(example_fun)
      2 # this will fail:
----> 3 bad_example_jit(10, 4)

    [... skipping hidden 13 frame]

Cell In[8], line 2, in example_fun(length, val)
      1 def example_fun(length, val):
----> 2   return jnp.ones((length,)) * val

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/numpy/array_creation.py:138, in ones(shape, dtype, device, out_sharding)
    136   raise TypeError("expected sequence object with len >= 0 or a single integer")
    137 if (m := _check_forgot_shape_tuple("ones", shape, dtype)): raise TypeError(m)
--> 138 shape = canonicalize_shape(shape)
    139 dtypes.check_user_dtype_supported(dtype, "ones")
    140 sharding = util.choose_device_or_out_sharding(
    141     device, out_sharding, 'jnp.ones')

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/numpy/array_creation.py:46, in canonicalize_shape(shape, context)
     44   return core.canonicalize_shape((shape,), context)
     45 else:
---> 46   return core.canonicalize_shape(shape, context)

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:1924, in canonicalize_shape(shape, context)
   1922 except TypeError:
   1923   pass
-> 1924 raise _invalid_shape_error(shape, context)

TypeError: Shapes must be 1D sequences of concrete values of integer type, got (JitTracer<~int32[]>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function example_fun at /tmp/ipykernel_1553/1210496444.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument length.
# static_argnames tells JAX to recompile on changes at these argument positions:
good_example_jit = jit(example_fun, static_argnames='length')
# first compile
print(good_example_jit(10, 4))
# recompiles
print(good_example_jit(5, 4))
[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]
[4. 4. 4. 4. 4.]

static_argnames 会很方便,如果示例中的 length 很少改变,但如果它经常改变,那将是灾难性的!

最后,如果你的函数有全局副作用,JAX 的追踪器可能会导致奇怪的事情发生。一个常见的陷阱是尝试在 jit 编译的函数内部打印数组。

@jit
def f(x):
  print(x)
  y = 2 * x
  print(y)
  return y
f(2)
JitTracer<~int32[]>
JitTracer<~int32[]>
Array(4, dtype=int32, weak_type=True)

结构化控制流原语#

JAX 中有更多的控制流选项。假设你想避免重新编译,但仍希望使用可追踪且能避免展开大型循环的控制流。那么你可以使用以下 4 个结构化控制流原语:

  • lax.cond 可微分

  • lax.while_loop **前向模式可微分**

  • lax.fori_loop 通常**前向模式可微分**;如果端点是静态的,则**前向和反向模式均可微分**。

  • lax.scan 可微分

cond#

Python 等价物

def cond(pred, true_fun, false_fun, operand):
  if pred:
    return true_fun(operand)
  else:
    return false_fun(operand)
from jax import lax

operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, operand)
# --> array([1.], dtype=float32)
lax.cond(False, lambda x: x+1, lambda x: x-1, operand)
# --> array([-1.], dtype=float32)
Array([-1.], dtype=float32)

jax.lax 提供了另外两个允许基于动态谓词进行分支的函数:

  • lax.select 类似于 lax.cond 的批处理版本,其中选择项表示为预先计算的数组,而非函数。

  • lax.switch 类似于 lax.cond,但允许在任意数量的可调用选项之间切换。

此外,jax.numpy 提供了几个 NumPy 风格的接口来调用这些函数:

  • 带有三个参数的 jnp.wherelax.select 的 NumPy 风格封装。

  • jnp.piecewiselax.switch 的 NumPy 风格封装,但它根据布尔条件列表而不是单个标量索引进行切换。

  • jnp.select 的 API 类似于 jnp.piecewise,但其选择项是预先计算的数组,而非函数。它是通过多次调用 lax.select 实现的。

while_loop#

Python 等价物

def while_loop(cond_fun, body_fun, init_val):
  val = init_val
  while cond_fun(val):
    val = body_fun(val)
  return val
init_val = 0
cond_fun = lambda x: x < 10
body_fun = lambda x: x+1
lax.while_loop(cond_fun, body_fun, init_val)
# --> array(10, dtype=int32)
Array(10, dtype=int32, weak_type=True)

fori_loop#

Python 等价物

def fori_loop(start, stop, body_fun, init_val):
  val = init_val
  for i in range(start, stop):
    val = body_fun(i, val)
  return val
init_val = 0
start = 0
stop = 10
body_fun = lambda i,x: x+i
lax.fori_loop(start, stop, body_fun, init_val)
# --> array(45, dtype=int32)
Array(45, dtype=int32, weak_type=True)

概述#

\[\begin{split} \begin{array} {r|rr} \hline \ \textrm{构造} & \textrm{jit} & \textrm{grad} \\ \hline \ \textrm{if} & ❌ & ✔ \\ \textrm{for} & ✔* & ✔\\ \textrm{while} & ✔* & ✔\\ \textrm{lax.cond} & ✔ & ✔\\ \textrm{lax.while_loop} & ✔ & \textrm{前向}\\ \textrm{lax.fori_loop} & ✔ & \textrm{前向}\\ \textrm{lax.scan} & ✔ & ✔\\ \hline \end{array} \end{split}\]

\(\ast\) = 循环条件**不**依赖于参数**值**——展开循环

逻辑运算符#

jax.numpy 提供了 logical_andlogical_orlogical_not,它们对数组进行元素级操作,并且可以在 jit 下评估而无需重新编译。与 NumPy 中的对应项一样,这些二元运算符不会短路。位运算符(&|~)也可以与 jit 一起使用。

例如,考虑一个检查输入是否为正偶数的函数。当输入为标量时,纯 Python 版本和 JAX 版本给出相同的结果。

def python_check_positive_even(x):
  is_even = x % 2 == 0
  # `and` short-circults, so when `is_even` is `False`, `x > 0` is not evaluated.
  return is_even and (x > 0)

@jit
def jax_check_positive_even(x):
  is_even = x % 2 == 0
  # `logical_and` does not short circuit, so `x > 0` is always evaluated.
  return jnp.logical_and(is_even, x > 0)

print(python_check_positive_even(24))
print(jax_check_positive_even(24))
True
True

当带有 logical_and 的 JAX 版本应用于数组时,它返回元素级的值。

x = jnp.array([-1, 2, 5])
print(jax_check_positive_even(x))
[False  True False]

当 Python 逻辑运算符应用于包含多个元素的 JAX 数组时,即使没有 jit,也会报错。这复刻了 NumPy 的行为。

print(python_check_positive_even(x))
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[17], line 1
----> 1 print(python_check_positive_even(x))

Cell In[15], line 4, in python_check_positive_even(x)
      2 is_even = x % 2 == 0
      3 # `and` short-circults, so when `is_even` is `False`, `x > 0` is not evaluated.
----> 4 return is_even and (x > 0)

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/array.py:305, in ArrayImpl.__bool__(self)
    304 def __bool__(self):
--> 305   core.check_bool_conversion(self)
    306   return bool(self._value)

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:821, in check_bool_conversion(arr)
    818   raise ValueError("The truth value of an empty array is ambiguous. Use"
    819                    " `array.size > 0` to check that an array is not empty.")
    820 if arr.size > 1:
--> 821   raise ValueError("The truth value of an array with more than one element"
    822                    " is ambiguous. Use a.any() or a.all()")

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

Python 控制流 + 自动微分#

请记住,上述关于控制流和逻辑运算符的约束仅在与 jit 结合使用时才相关。如果你只是想对你的 Python 函数应用 grad,而不使用 jit,那么你可以毫无问题地使用常规 Python 控制流结构,就像使用 Autograd(或 PyTorch 或 TF Eager)一样。

def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

print(grad(f)(2.))  # ok!
print(grad(f)(4.))  # ok!
12.0
-4.0