Tracing#

jax.jit 和其他 JAX 变换的工作原理是跟踪一个函数,以确定其对特定形状和类型的输入的影响。为了了解跟踪过程,让我们在 JIT 编译的函数中放入一些 print() 语句,然后调用该函数。

from jax import jit
import jax.numpy as jnp
import numpy as np

@jit
def f(x, y):
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  result = jnp.dot(x + 1, y + 1)
  print(f"  result = {result}")
  return result

x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)
Running f():
  x = JitTracer<float32[3,4]>
  y = JitTracer<float32[4]>
  result = JitTracer<float32[3]>
Array([4.389271 , 4.1128426, 5.4673476], dtype=float32)

请注意,打印语句会执行,但它们打印的不是传递给函数的实际数据,而是代表这些数据的跟踪器对象(类似于 Traced<ShapedArray(float32[])>)。

这些跟踪器对象是 jax.jit 用来提取函数指定的计算序列的。基本的跟踪器是代表数组的形状数据类型的占位符,但对值不敏感。这个记录的计算序列随后可以在 XLA 中高效地应用于具有相同形状和数据类型的新输入,而无需重新执行 Python 代码。

当我们在匹配的输入上再次调用编译后的函数时,不需要重新编译,也没有任何内容被打印出来,因为结果是在编译后的 XLA 中计算的,而不是在 Python 中。

x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
f(x2, y2)
Array([1.2962357, 8.36254  , 8.948654 ], dtype=float32)

提取的操作序列被编码在一个 JAX 表达式中,或者简称为 jaxpr。您可以使用 jax.make_jaxpr 变换来查看 jaxpr。

from jax import make_jaxpr

def f(x, y):
  return jnp.dot(x + 1, y + 1)

make_jaxpr(f)(x, y)
{ lambda ; a:f32[3,4] b:f32[4]. let
    c:f32[3,4] = add a 1.0:f32[]
    d:f32[4] = add b 1.0:f32[]
    e:f32[3] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] c d
  in (e,) }

注意由此产生的一个后果:因为 JIT 编译是在没有数组内容信息的情况下进行的,所以函数中的控制流语句不能依赖于跟踪值(请参阅 JIT 中的控制流和逻辑运算符)。例如,这会失败。

@jit
def f(x, neg):
  return -x if neg else x

f(1, True)
---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent call last)
Cell In[4], line 5
      1 @jit
      2 def f(x, neg):
      3   return -x if neg else x
----> 5 f(1, True)

    [... skipping hidden 13 frame]

Cell In[4], line 3, in f(x, neg)
      1 @jit
      2 def f(x, neg):
----> 3   return -x if neg else x

    [... skipping hidden 1 frame]

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:1805, in concretization_function_error.<locals>.error(self, arg)
   1804 def error(self, arg):
-> 1805   raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /tmp/ipykernel_3399/2422663986.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument neg.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError

如果您有一些不想被跟踪的变量,可以将它们标记为 JIT 编译的静态参数。

from functools import partial

@partial(jit, static_argnums=(1,))
def f(x, neg):
  return -x if neg else x

f(1, True)
Array(-1, dtype=int32, weak_type=True)

请注意,使用不同的静态参数调用 JIT 编译的函数会导致重新编译,因此函数仍然按预期工作。

f(1, False)
Array(1, dtype=int32, weak_type=True)

静态 vs 跟踪操作#

正如值可以是静态的或跟踪的,操作也可以是静态的或跟踪的。静态操作在 Python 中于编译时进行评估;跟踪操作在 XLA 中于运行时进行编译和评估。

静态值和跟踪值之间的这种区别,使得思考如何保持静态值不变变得很重要。考虑以下函数。

import jax.numpy as jnp
from jax import jit

@jit
def f(x):
  return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2, 3))
f(x)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[7], line 9
      6   return x.reshape(jnp.array(x.shape).prod())
      8 x = jnp.ones((2, 3))
----> 9 f(x)

    [... skipping hidden 13 frame]

Cell In[7], line 6, in f(x)
      4 @jit
      5 def f(x):
----> 6   return x.reshape(jnp.array(x.shape).prod())

    [... skipping hidden 2 frame]

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:457, in _compute_newshape(arr, newshape)
    455 except:
    456   newshape = [newshape]
--> 457 newshape = core.canonicalize_shape(newshape)  # type: ignore[arg-type]
    458 neg1s = [i for i, d in enumerate(newshape) if type(d) is int and d == -1]
    459 if len(neg1s) > 1:

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:2017, in canonicalize_shape(shape, context)
   2015 except TypeError:
   2016   pass
-> 2017 raise _invalid_shape_error(shape, context)

TypeError: Shapes must be 1D sequences of concrete values of integer type, got [JitTracer<int32[]>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function f at /tmp/ipykernel_3399/1983583872.py:4 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i32[] = reduce_prod[axes=(0,)] b
    from line /tmp/ipykernel_3399/1983583872.py:6:19 (f)

这会产生一个错误,指出找到了一个跟踪器而不是整数类型的具体值的一维序列。让我们在函数中添加一些打印语句来理解为什么会发生这种情况。

@jit
def f(x):
  print(f"x = {x}")
  print(f"x.shape = {x.shape}")
  print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}")
  # comment this out to avoid the error:
  # return x.reshape(jnp.array(x.shape).prod())

f(x)
x = JitTracer<float32[2,3]>
x.shape = (2, 3)
jnp.array(x.shape).prod() = JitTracer<int32[]>

请注意,虽然 x 被跟踪了,但 x.shape 是一个静态值。然而,当我们在此静态值上使用 jnp.arrayjnp.prod 时,它变成了一个跟踪值,此时它不能用于像 reshape() 这样的函数,该函数需要一个静态输入(回想一下:数组形状必须是静态的)。

一个有用的模式是使用 numpy 来执行应该静态的(即在编译时完成)操作,并使用 jax.numpy 来执行应该被跟踪的(即在运行时编译和执行)操作。对于这个函数,它可能看起来像这样。

from jax import jit
import jax.numpy as jnp
import numpy as np

@jit
def f(x):
  return x.reshape((np.prod(x.shape),))

f(x)
Array([1., 1., 1., 1., 1., 1.], dtype=float32)

因此,JAX 程序中的一个标准约定是 import numpy as npimport jax.numpy as jnp,以便同时提供这两个接口,以便更精细地控制操作是静态执行(使用 numpy,一次在编译时)还是跟踪执行(使用 jax.numpy,在运行时优化)。

理解哪些值和操作将是静态的,哪些将被跟踪,是有效使用 jax.jit 的关键。

不同种类的 JAX 值#

跟踪器值携带一个抽象值,例如,带有数组形状和数据类型信息的 ShapedArray。这里我们将这类跟踪器称为抽象跟踪器。有些跟踪器,例如为自动微分变换的参数引入的跟踪器,携带包含实际常规数组数据的 ConcreteArray 抽象值,并用于解析条件等。这里我们将这类跟踪器称为具体跟踪器。由这些具体跟踪器计算出的跟踪器值,也许与常规值组合,会产生具体跟踪器。具体值是常规值或具体跟踪器。

通常,涉及至少一个跟踪器值的计算会产生一个跟踪器值。很少有例外,当一个计算可以完全使用跟踪器携带的抽象值完成时,在这种情况下,结果可以是一个常规 Python 值。例如,获取带有 ShapedArray 抽象值的跟踪器的形状。另一个例子是显式地将具体跟踪器值转换为常规类型,例如 int(x)x.astype(float)。另一种情况是 bool(x),当具体化成为可能时,它会产生一个 Python bool。这种情况尤其突出,因为它在控制流中经常出现。

以下是变换引入抽象或具体跟踪器的方式。

  • jax.jit():为所有位置参数引入抽象跟踪器,除了由 static_argnums 指定的参数,它们保持为常规值。

  • jax.pmap():为所有位置参数引入抽象跟踪器,除了由 static_broadcasted_argnums 指定的参数。

  • jax.vmap()jax.make_jaxpr()xla_computation():为所有位置参数引入抽象跟踪器

  • jax.jvp()jax.grad():为所有位置参数引入具体跟踪器。一个例外是当这些变换位于外部变换内,并且实际参数本身是抽象跟踪器时;在这种情况下,由自动微分变换引入的跟踪器也是抽象跟踪器。

  • 所有高阶控制流原语(lax.cond()lax.while_loop()lax.fori_loop()lax.scan())在处理函数式参数时会引入抽象跟踪器,无论是否有 JAX 变换在进行中。

所有这些对于那些只能在常规 Python 值上操作的代码都很重要,例如基于数据进行条件控制流的代码。

def divide(x, y):
  return x / y if y >= 1. else 0.

如果我们想应用 jax.jit(),我们必须确保指定 static_argnums=1 以确保 y 保持为常规值。这是因为布尔表达式 y >= 1. 需要具体值(常规值或跟踪器)。如果我们显式地写 bool(y >= 1.),或者 int(y),或者 float(y),也会发生同样的情况。

有趣的是,jax.grad(divide)(3., 2.) 可以工作,因为 jax.grad() 使用具体跟踪器,并使用 y 的具体值来解析条件。