JIT 下的控制流和逻辑运算符#

在即时执行(在 `jit` 之外)时,JAX 代码就像 Numpy 代码一样处理 Python 的控制流和逻辑运算符。但在 `jit` 中使用控制流和逻辑运算符会更复杂一些。

简而言之,Python 的控制流和逻辑运算符在 JIT 编译时进行求值,编译后的函数代表了通过控制流图的单一路径(逻辑运算符通过短路影响路径)。如果路径依赖于输入的计算值,该函数(默认情况下)将无法被 JIT 编译。路径可能依赖于输入的形状或数据类型,并且每当函数使用具有新形状或数据类型的输入进行调用时,该函数都会被重新编译。

from jax import grad, jit
import jax.numpy as jnp

例如,这样可以正常工作

@jit
def f(x):
  for i in range(3):
    x = 2 * x
  return x

print(f(3))
24

这样也可以

@jit
def g(x):
  y = 0.
  for i in range(x.shape[0]):
    y = y + x[i]
  return y

print(g(jnp.array([1., 2., 3.])))
6.0

但这不行,至少默认情况下不行

@jit
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

# This will fail!
f(2)
---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent call last)
Cell In[4], line 9
      6     return -4 * x
      8 # This will fail!
----> 9 f(2)

    [... skipping hidden 13 frame]

Cell In[4], line 3, in f(x)
      1 @jit
      2 def f(x):
----> 3   if x < 3:
      4     return 3. * x ** 2
      5   else:

    [... skipping hidden 1 frame]

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:1805, in concretization_function_error.<locals>.error(self, arg)
   1804 def error(self, arg):
-> 1805   raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /tmp/ipykernel_1560/3402096563.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError

这个也不行

@jit
def g(x):
  return (x > 0) and (x < 3)

# This will fail!
g(2)
---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent call last)
Cell In[5], line 6
      3   return (x > 0) and (x < 3)
      5 # This will fail!
----> 6 g(2)

    [... skipping hidden 13 frame]

Cell In[5], line 3, in g(x)
      1 @jit
      2 def g(x):
----> 3   return (x > 0) and (x < 3)

    [... skipping hidden 1 frame]

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:1805, in concretization_function_error.<locals>.error(self, arg)
   1804 def error(self, arg):
-> 1805   raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function g at /tmp/ipykernel_1560/543860509.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError

这是怎么回事!?

当我们对函数进行 `jit` 编译时,我们通常希望编译一个能处理许多不同参数值的函数版本,以便我们可以缓存并重用编译后的代码。这样,我们就不需要在每次函数求值时都重新编译。

例如,如果我们使用数组 `jnp.array([1., 2., 3.], jnp.float32)` 来求值一个 `@jit` 函数,我们可能希望编译的代码能够重用于求值 `jnp.array([4., 5., 6.], jnp.float32)`,以节省编译时间。

为了获得一个可以处理许多不同参数值的 Python 代码视图,JAX 使用 `ShapedArray` 抽象将输入进行追踪,其中每个抽象值代表具有固定形状和数据类型的数组值集合。例如,如果我们使用抽象值 `ShapedArray((3,), jnp.float32)` 进行追踪,我们就可以得到一个可以重用于相应数组集合中任何具体值的函数视图。这意味着我们可以节省编译时间。

但这里有一个权衡:如果我们使用一个非具体化值的抽象值 `ShapedArray((), jnp.float32)` 来追踪一个 Python 函数,当遇到像 `if x < 3` 这样的行时,表达式 `x < 3` 将会计算出一个抽象值 `ShapedArray((), jnp.bool_)`,它代表了 `{True, False}` 集合。当 Python 尝试将其强制转换为具体的 `True` 或 `False` 时,我们会收到一个错误:我们不知道应该选择哪个分支,也无法继续追踪!权衡是,随着抽象级别的提高,我们对 Python 代码的通用视图会增强(从而节省重新编译),但我们也需要对 Python 代码施加更多约束才能完成追踪。

好消息是,您可以自己控制这种权衡。通过让 `jit` 在更精炼的抽象值上进行追踪,您可以放宽可追踪性的约束。例如,使用 `jit` 的 `static_argnames`(或 `static_argnums`)参数,我们可以指定某些参数的追踪应基于具体值。这里再次给出那个示例函数

def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

f = jit(f, static_argnames='x')

print(f(2.))
12.0

再举一个例子,这次涉及一个循环

def f(x, n):
  y = 0.
  for i in range(n):
    y = y + x[i]
  return y

f = jit(f, static_argnames='n')

f(jnp.array([2., 3., 4.]), 2)
Array(5., dtype=float32)

实际上,这个循环被静态展开了。JAX 也可以在 *更高* 的抽象级别进行追踪,例如 `Unshaped`,但这目前并不是任何转换的默认选项。

⚠️ 参数值依赖形状的函数

这些控制流问题也会以更微妙的方式出现:我们想要 `jit` 的数值函数无法根据参数 *值* 来专门化内部数组的形状(根据参数 *形状* 来专门化是可以的)。作为最简单的例子,我们创建一个输出恰好依赖于输入变量 `length` 的函数。

def example_fun(length, val):
  return jnp.ones((length,)) * val
# un-jit'd works fine
print(example_fun(5, 4))
[4. 4. 4. 4. 4.]
bad_example_jit = jit(example_fun)
# this will fail:
bad_example_jit(10, 4)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[9], line 3
      1 bad_example_jit = jit(example_fun)
      2 # this will fail:
----> 3 bad_example_jit(10, 4)

    [... skipping hidden 13 frame]

Cell In[8], line 2, in example_fun(length, val)
      1 def example_fun(length, val):
----> 2   return jnp.ones((length,)) * val

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/numpy/array_creation.py:139, in ones(shape, dtype, device, out_sharding)
    137   raise TypeError("expected sequence object with len >= 0 or a single integer")
    138 if (m := _check_forgot_shape_tuple("ones", shape, dtype)): raise TypeError(m)
--> 139 shape = canonicalize_shape(shape)
    140 dtype = dtypes.check_and_canonicalize_user_dtype(
    141     float if dtype is None else dtype, "ones")
    142 sharding = util.choose_device_or_out_sharding(
    143     device, out_sharding, 'jnp.ones')

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/numpy/array_creation.py:46, in canonicalize_shape(shape, context)
     44   return core.canonicalize_shape((shape,), context)
     45 else:
---> 46   return core.canonicalize_shape(shape, context)

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:2017, in canonicalize_shape(shape, context)
   2015 except TypeError:
   2016   pass
-> 2017 raise _invalid_shape_error(shape, context)

TypeError: Shapes must be 1D sequences of concrete values of integer type, got (JitTracer<~int32[]>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function example_fun at /tmp/ipykernel_1560/1210496444.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument length.
# static_argnames tells JAX to recompile on changes at these argument positions:
good_example_jit = jit(example_fun, static_argnames='length')
# first compile
print(good_example_jit(10, 4))
# recompiles
print(good_example_jit(5, 4))
[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]
[4. 4. 4. 4. 4.]

如果我们的示例中的 `length` 很少改变,那么 `static_argnames` 会很有用,但如果它经常改变,那将是灾难性的!

最后,如果您的函数有全局的副作用,JAX 的追踪器可能会导致奇怪的行为。一个常见的陷阱是在 `jit` 编译的函数内部尝试打印数组。

@jit
def f(x):
  print(x)
  y = 2 * x
  print(y)
  return y
f(2)
JitTracer<~int32[]>
JitTracer<~int32[]>
Array(4, dtype=int32, weak_type=True)

结构化控制流原语#

JAX 提供了更多控制流的选项。假设您想避免重新编译,但仍然想使用可追踪的控制流,并且避免展开大型循环。那么您可以使用这 4 个结构化控制流原语。

  • `lax.cond` *可微分*

  • `lax.while_loop` **前向模式可微分**

  • `lax.fori_loop` 通常是**前向模式可微分**;如果端点是静态的,则是**前向和后向模式都可微分**。

  • `lax.scan` *可微分*

cond#

python 等价实现

def cond(pred, true_fun, false_fun, operand):
  if pred:
    return true_fun(operand)
  else:
    return false_fun(operand)
from jax import lax

operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, operand)
# --> array([1.], dtype=float32)
lax.cond(False, lambda x: x+1, lambda x: x-1, operand)
# --> array([-1.], dtype=float32)
Array([-1.], dtype=float32)

`jax.lax` 提供了另外两个允许基于动态谓词进行分支的函数。

  • `lax.select` 类似于 `lax.cond` 的批量版本,其选择是预先计算的数组而不是函数。

  • `lax.switch` 类似于 `lax.cond`,但允许在任意数量的可调用选项之间切换。

此外,`jax.numpy` 提供了这些函数的几种 numpy 风格的接口。

  • `jnp.where`(带三个参数)是 `lax.select` 的 numpy 风格封装。

  • `jnp.piecewise` 是 `lax.switch` 的 numpy 风格封装,但它根据一组布尔条件进行切换,而不是单个标量索引。

  • `jnp.select` 具有与 `jnp.piecewise` 类似的 API,但选项是预先计算的数组而不是函数。它通过多次调用 `lax.select` 来实现。

while_loop#

python 等价实现

def while_loop(cond_fun, body_fun, init_val):
  val = init_val
  while cond_fun(val):
    val = body_fun(val)
  return val
init_val = 0
cond_fun = lambda x: x < 10
body_fun = lambda x: x+1
lax.while_loop(cond_fun, body_fun, init_val)
# --> array(10, dtype=int32)
Array(10, dtype=int32, weak_type=True)

fori_loop#

python 等价实现

def fori_loop(start, stop, body_fun, init_val):
  val = init_val
  for i in range(start, stop):
    val = body_fun(i, val)
  return val
init_val = 0
start = 0
stop = 10
body_fun = lambda i,x: x+i
lax.fori_loop(start, stop, body_fun, init_val)
# --> array(45, dtype=int32)
Array(45, dtype=int32, weak_type=True)

概述#

\[\begin{split} \begin{array} {r|rr} \hline \ \textrm{构造} & \textrm{jit} & \textrm{grad} \\ \hline \ \textrm{if} & ❌ & ✔ \\ \textrm{for} & ✔* & ✔\\ \textrm{while} & ✔* & ✔\\ \textrm{lax.cond} & ✔ & ✔\\ \textrm{lax.while_loop} & ✔ & \textrm{fwd}\\ \textrm{lax.fori_loop} & ✔ & \textrm{fwd}\\ \textrm{lax.scan} & ✔ & ✔\\ \hline \end{array} \end{split}\]

\(\ast\) = 循环条件与参数**值**无关 - 循环展开

逻辑运算符#

`jax.numpy` 提供了 `logical_and`、`logical_or` 和 `logical_not`,它们可以对数组进行元素级操作,并在 `jit` 下进行求值而无需重新编译。与其 Numpy 对等项类似,二元运算符不会短路。按位运算符(`&`、`|`、`~`)也可以与 `jit` 一起使用。

例如,考虑一个检查其输入是否为正偶整数的函数。纯 Python 版本和 JAX 版本在输入是标量时给出相同的答案。

def python_check_positive_even(x):
  is_even = x % 2 == 0
  # `and` short-circults, so when `is_even` is `False`, `x > 0` is not evaluated.
  return is_even and (x > 0)

@jit
def jax_check_positive_even(x):
  is_even = x % 2 == 0
  # `logical_and` does not short circuit, so `x > 0` is always evaluated.
  return jnp.logical_and(is_even, x > 0)

print(python_check_positive_even(24))
print(jax_check_positive_even(24))
True
True

当 `logical_and` 的 JAX 版本应用于数组时,它会返回逐元素的(elementwise)值。

x = jnp.array([-1, 2, 5])
print(jax_check_positive_even(x))
[False  True False]

Python 的逻辑运算符在应用于包含多个元素的 JAX 数组时会出错,即使没有 `jit`。这复制了 NumPy 的行为。

print(python_check_positive_even(x))
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[17], line 1
----> 1 print(python_check_positive_even(x))

Cell In[15], line 4, in python_check_positive_even(x)
      2 is_even = x % 2 == 0
      3 # `and` short-circults, so when `is_even` is `False`, `x > 0` is not evaluated.
----> 4 return is_even and (x > 0)

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/array.py:297, in ArrayImpl.__bool__(self)
    296 def __bool__(self):
--> 297   core.check_bool_conversion(self)
    298   return bool(self._value)

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:882, in check_bool_conversion(arr)
    879   raise ValueError("The truth value of an empty array is ambiguous. Use"
    880                    " `array.size > 0` to check that an array is not empty.")
    881 if arr.size > 1:
--> 882   raise ValueError("The truth value of an array with more than one element"
    883                    " is ambiguous. Use a.any() or a.all()")

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

Python 控制流 + 自动微分#

请记住,上面关于控制流和逻辑运算符的限制仅在 `jit` 上相关。如果您只想将 `grad` 应用于您的 Python 函数,而不需要 `jit`,那么您可以毫无问题地使用常规的 Python 控制流构造,就像使用Autograd(或 Pytorch 或 TF Eager)一样。

def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

print(grad(f)(2.))  # ok!
print(grad(f)(4.))  # ok!
12.0
-4.0