使用 JIT 的控制流和逻辑运算符#

在急切执行(jit 之外)时,JAX 代码像 Numpy 代码一样使用 Python 控制流和逻辑运算符。将控制流和逻辑运算符与 jit 一起使用会更加复杂。

简而言之,Python 控制流和逻辑运算符在 JIT 编译时进行评估,这样编译后的函数表示通过控制流图的单一路径(逻辑运算符通过短路影响路径)。如果路径取决于输入的值,则该函数(默认情况下)无法进行 JIT 编译。路径可能取决于输入的形状或 dtype,并且每次在具有新形状或 dtype 的输入上调用该函数时,都会重新编译该函数。

from jax import grad, jit
import jax.numpy as jnp

例如,这行得通

@jit
def f(x):
  for i in range(3):
    x = 2 * x
  return x

print(f(3))
24

这个也行得通

@jit
def g(x):
  y = 0.
  for i in range(x.shape[0]):
    y = y + x[i]
  return y

print(g(jnp.array([1., 2., 3.])))
6.0

但这个不行,至少默认情况下不行

@jit
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

# This will fail!
f(2)
---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent call last)
Cell In[4], line 9
      6     return -4 * x
      8 # This will fail!
----> 9 f(2)

    [... skipping hidden 13 frame]

Cell In[4], line 3, in f(x)
      1 @jit
      2 def f(x):
----> 3   if x < 3:
      4     return 3. * x ** 2
      5   else:

    [... skipping hidden 1 frame]

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py:1476, in concretization_function_error.<locals>.error(self, arg)
   1475 def error(self, arg):
-> 1476   raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /tmp/ipykernel_821/3402096563.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

这个也不行

@jit
def g(x):
  return (x > 0) and (x < 3)

# This will fail!
g(2)
---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent call last)
Cell In[5], line 6
      3   return (x > 0) and (x < 3)
      5 # This will fail!
----> 6 g(2)

    [... skipping hidden 13 frame]

Cell In[5], line 3, in g(x)
      1 @jit
      2 def g(x):
----> 3   return (x > 0) and (x < 3)

    [... skipping hidden 1 frame]

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py:1476, in concretization_function_error.<locals>.error(self, arg)
   1475 def error(self, arg):
-> 1476   raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function g at /tmp/ipykernel_821/543860509.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

这是怎么回事!?

当我们使用 jit 编译一个函数时,我们通常希望编译一个适用于许多不同参数值的函数版本,以便我们可以缓存和重用编译后的代码。这样,我们就不必在每次函数评估时都重新编译。

例如,如果我们对数组 jnp.array([1., 2., 3.], jnp.float32) 评估一个 @jit 函数,我们可能希望编译可以重用的代码,以便在 jnp.array([4., 5., 6.], jnp.float32) 上评估该函数,从而节省编译时间。

为了获得适用于许多不同参数值的 Python 代码视图,JAX 使用 ShapedArray 抽象作为输入对其进行追踪,其中每个抽象值表示具有固定形状和 dtype 的所有数组值的集合。例如,如果我们使用抽象值 ShapedArray((3,), jnp.float32) 进行追踪,我们会得到一个可以重用于相应数组集合中任何具体值的函数视图。这意味着我们可以节省编译时间。

但是这里有一个权衡:如果我们使用未提交给特定具体值的 ShapedArray((), jnp.float32) 追踪一个 Python 函数,当遇到类似 if x < 3 的行时,表达式 x < 3 将评估为一个抽象的 ShapedArray((), jnp.bool_),它表示集合 {True, False}。当 Python 尝试将其强制转换为具体的 TrueFalse 时,我们会得到一个错误:我们不知道该走哪个分支,并且无法继续追踪!权衡之处在于,使用更高层次的抽象,我们可以获得更通用的 Python 代码视图(从而节省重新编译的时间),但我们需要对 Python 代码施加更多约束才能完成追踪。

好消息是,您可以自己控制这种权衡。通过让 jit 在更精细的抽象值上进行追踪,您可以放宽可追踪性约束。例如,使用 jitstatic_argnames (或 static_argnums) 参数,我们可以指定在某些参数的具体值上进行追踪。这是另一个示例函数

def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

f = jit(f, static_argnames='x')

print(f(2.))
12.0

这是另一个示例,这次涉及一个循环

def f(x, n):
  y = 0.
  for i in range(n):
    y = y + x[i]
  return y

f = jit(f, static_argnames='n')

f(jnp.array([2., 3., 4.]), 2)
Array(5., dtype=float32)

实际上,循环会被静态展开。JAX 还可以在更高的抽象级别上进行追踪,例如 Unshaped,但这目前不是任何转换的默认设置

⚠️ 具有参数值依赖形状的函数

这些控制流问题也以更微妙的方式出现:我们想要 jit 的数值函数不能根据参数来专门化内部数组的形状(根据参数形状进行专门化是可以的)。作为一个简单的例子,让我们创建一个输出恰好取决于输入变量 length 的函数。

def example_fun(length, val):
  return jnp.ones((length,)) * val
# un-jit'd works fine
print(example_fun(5, 4))
[4. 4. 4. 4. 4.]
bad_example_jit = jit(example_fun)
# this will fail:
bad_example_jit(10, 4)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[9], line 3
      1 bad_example_jit = jit(example_fun)
      2 # this will fail:
----> 3 bad_example_jit(10, 4)

    [... skipping hidden 13 frame]

Cell In[8], line 2, in example_fun(length, val)
      1 def example_fun(length, val):
----> 2   return jnp.ones((length,)) * val

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:6182, in ones(shape, dtype, device)
   6180   raise TypeError("expected sequence object with len >= 0 or a single integer")
   6181 if (m := _check_forgot_shape_tuple("ones", shape, dtype)): raise TypeError(m)
-> 6182 shape = canonicalize_shape(shape)
   6183 dtypes.check_user_dtype_supported(dtype, "ones")
   6184 return lax.full(shape, 1, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device))

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:101, in canonicalize_shape(shape, context)
     99   return core.canonicalize_shape((shape,), context)
    100 else:
--> 101   return core.canonicalize_shape(shape, context)

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py:1621, in canonicalize_shape(shape, context)
   1619 except TypeError:
   1620   pass
-> 1621 raise _invalid_shape_error(shape, context)

TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function example_fun at /tmp/ipykernel_821/1210496444.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument length.
# static_argnames tells JAX to recompile on changes at these argument positions:
good_example_jit = jit(example_fun, static_argnames='length')
# first compile
print(good_example_jit(10, 4))
# recompiles
print(good_example_jit(5, 4))
[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]
[4. 4. 4. 4. 4.]

如果示例中的 length 很少更改,那么 static_argnames 会很方便,但如果它经常更改,那将是灾难性的!

最后,如果您的函数具有全局副作用,则 JAX 的追踪器可能会导致奇怪的事情发生。一个常见的陷阱是尝试在 jit 函数内部打印数组

@jit
def f(x):
  print(x)
  y = 2 * x
  print(y)
  return y
f(2)
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace>
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace>
Array(4, dtype=int32, weak_type=True)

结构化控制流原语#

JAX 中有更多控制流选项。假设您想避免重新编译,但仍然想使用可追踪的控制流,并避免展开大型循环。那么您可以使用这 4 个结构化控制流原语

  • lax.cond 可微分

  • lax.while_loop 前向模式可微分

  • lax.fori_loop 一般为前向模式可微分;如果端点是静态的,则为前向和反向模式可微分

  • lax.scan 可微分

cond#

python 等效项

def cond(pred, true_fun, false_fun, operand):
  if pred:
    return true_fun(operand)
  else:
    return false_fun(operand)
from jax import lax

operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, operand)
# --> array([1.], dtype=float32)
lax.cond(False, lambda x: x+1, lambda x: x-1, operand)
# --> array([-1.], dtype=float32)
Array([-1.], dtype=float32)

jax.lax 提供了另外两个允许对动态谓词进行分支的函数

  • lax.select 类似于 lax.cond 的批量版本,其选择表示为预先计算的数组,而不是函数。

  • lax.switch 类似于 lax.cond,但允许在任意数量的可调用选项之间切换。

此外,jax.numpy 提供了这些函数的几个 numpy 风格的接口

  • jnp.where 带有三个参数是 lax.select 的 numpy 风格包装器。

  • jnp.piecewiselax.switch 的 numpy 风格包装器,但基于布尔条件列表而不是单个标量索引进行切换。

  • jnp.select 的 API 与 jnp.piecewise 类似,但选择是作为预先计算的数组而不是函数给出的。它是通过多次调用 lax.select 来实现的。

while_loop#

python 等效项

def while_loop(cond_fun, body_fun, init_val):
  val = init_val
  while cond_fun(val):
    val = body_fun(val)
  return val
init_val = 0
cond_fun = lambda x: x < 10
body_fun = lambda x: x+1
lax.while_loop(cond_fun, body_fun, init_val)
# --> array(10, dtype=int32)
Array(10, dtype=int32, weak_type=True)

fori_loop#

python 等效项

def fori_loop(start, stop, body_fun, init_val):
  val = init_val
  for i in range(start, stop):
    val = body_fun(i, val)
  return val
init_val = 0
start = 0
stop = 10
body_fun = lambda i,x: x+i
lax.fori_loop(start, stop, body_fun, init_val)
# --> array(45, dtype=int32)
Array(45, dtype=int32, weak_type=True)

总结#

\[\begin{split} \begin{array} {r|rr} \hline \ \textrm{构造} & \textrm{jit} & \textrm{grad} \\ \hline \ \textrm{if} & ❌ & ✔ \\ \textrm{for} & ✔* & ✔\\ \textrm{while} & ✔* & ✔\\ \textrm{lax.cond} & ✔ & ✔\\ \textrm{lax.while_loop} & ✔ & \textrm{fwd}\\ \textrm{lax.fori_loop} & ✔ & \textrm{fwd}\\ \textrm{lax.scan} & ✔ & ✔\\ \hline \end{array} \end{split}\]

\(\ast\) = 与参数无关的循环条件 - 展开循环

逻辑运算符#

jax.numpy 提供了 logical_andlogical_orlogical_not,它们在数组上逐元素运算,并且可以在 jit 下进行评估,而无需重新编译。与它们的 Numpy 对等项一样,二元运算符不会短路。按位运算符(&|~)也可以与 jit 一起使用。

例如,考虑一个检查其输入是否为正偶整数的函数。当输入为标量时,纯 Python 版本和 JAX 版本给出相同的答案。

def python_check_positive_even(x):
  is_even = x % 2 == 0
  # `and` short-circults, so when `is_even` is `False`, `x > 0` is not evaluated.
  return is_even and (x > 0)

@jit
def jax_check_positive_even(x):
  is_even = x % 2 == 0
  # `logical_and` does not short circuit, so `x > 0` is always evaluated.
  return jnp.logical_and(is_even, x > 0)

print(python_check_positive_even(24))
print(jax_check_positive_even(24))
True
True

当将带有 logical_and 的 JAX 版本应用于数组时,它将返回逐元素的值。

x = jnp.array([-1, 2, 5])
print(jax_check_positive_even(x))
[False  True False]

即使没有 jit,当应用于多个元素的 JAX 数组时,Python 逻辑运算符也会出错。这复制了 NumPy 的行为。

print(python_check_positive_even(x))
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[17], line 1
----> 1 print(python_check_positive_even(x))

Cell In[15], line 4, in python_check_positive_even(x)
      2 is_even = x % 2 == 0
      3 # `and` short-circults, so when `is_even` is `False`, `x > 0` is not evaluated.
----> 4 return is_even and (x > 0)

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/array.py:292, in ArrayImpl.__bool__(self)
    291 def __bool__(self):
--> 292   core.check_bool_conversion(self)
    293   return bool(self._value)

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py:655, in check_bool_conversion(arr)
    652   raise ValueError("The truth value of an empty array is ambiguous. Use"
    653                    " `array.size > 0` to check that an array is not empty.")
    654 if arr.size > 1:
--> 655   raise ValueError("The truth value of an array with more than one element"
    656                    " is ambiguous. Use a.any() or a.all()")

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

Python 控制流 + 自动微分#

请记住,以上对控制流和逻辑运算符的约束仅在 jit 的情况下才相关。如果您只想将 grad 应用于您的 python 函数,而没有 jit,则可以毫无问题地使用常规的 Python 控制流结构,就像您在使用 Autograd (或 Pytorch 或 TF Eager) 一样。

def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

print(grad(f)(2.))  # ok!
print(grad(f)(4.))  # ok!
12.0
-4.0