形状多态#

当 JAX 在 JIT 模式下使用时,函数将被跟踪、降低为 StableHLO,并针对每种输入类型和形状的组合进行编译。在导出函数并在另一个系统上反序列化它之后,我们不再有可用的 Python 源代码,因此我们无法重新跟踪和重新降低它。形状多态是 JAX 导出的一个特性,允许一些导出的函数用于整个系列的输入形状。这些函数在导出期间被跟踪和降低一次,并且 Exported 对象包含能够在许多具体输入形状上编译和执行函数所需的信息。我们通过在导出时指定包含维度变量(符号形状)的形状来实现这一点,如下例所示

>>> import jax
>>> from jax import export
>>> from jax import numpy as jnp
>>> def f(x):  # f: f32[a, b]
...   return jnp.concatenate([x, x], axis=1)

>>> # We construct symbolic dimension variables.
>>> a, b = export.symbolic_shape("a, b")

>>> # We can use the symbolic dimensions to construct shapes.
>>> x_shape = (a, b)
>>> x_shape
(a, b)

>>> # Then we export with symbolic shapes:
>>> exp: export.Exported = export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct(x_shape, jnp.int32))
>>> exp.in_avals
(ShapedArray(int32[a,b]),)
>>> exp.out_avals
(ShapedArray(int32[a,2*b]),)

>>> # We can later call with concrete shapes (with a=3 and b=4), without re-tracing `f`.
>>> res = exp.call(np.ones((3, 4), dtype=np.int32))
>>> res.shape
(3, 8)

请注意,此类函数仍然会根据需要在每次调用的具体输入形状上重新编译。只有跟踪和降低被保存。

上面的示例中使用 jax.export.symbolic_shape() 将符号形状的字符串表示形式解析为维度表达式对象(类型为 _DimExpr),这些对象可以代替整数常量来构造形状。维度表达式对象重载了大多数整数运算符,因此在大多数情况下您可以像使用整数常量一样使用它们。有关更多详细信息,请参阅 使用维度变量进行计算

此外,我们提供了 jax.export.symbolic_args_specs(),可用于基于多态形状规范构造 jax.ShapeDtypeStruct 对象的 pytree

>>> def f1(x, y): # x: f32[a, 1], y : f32[a, 4]
...  return x + y

>>> # Assuming you have some actual args with concrete shapes
>>> x = np.ones((3, 1), dtype=np.int32)
>>> y = np.ones((3, 4), dtype=np.int32)
>>> args_specs = export.symbolic_args_specs((x, y), "a, ...")
>>> exp = export.export(jax.jit(f1))(* args_specs)
>>> exp.in_avals
(ShapedArray(int32[a,1]), ShapedArray(int32[a,4]))

请注意,多态形状规范 "a, ..." 如何包含占位符 ...,该占位符将从参数 (x, y) 的具体形状中填充。占位符 ... 代表 0 个或多个维度,而占位符 _ 代表一个维度。jax.export.symbolic_args_specs() 支持参数的 pytree,这些参数用于填充 dtypes 和任何占位符。该函数将构造一个参数规范 pytree (jax.ShapeDtypeStruct),该 pytree 与传递给它的参数结构相匹配。在规范应应用于多个参数的情况下,多态形状规范可以是 pytree 前缀,如上面的示例所示。请参阅 可选参数如何与 pytree 匹配

一些形状规范的示例

  • ("(b, _, _)", None) 可用于具有两个参数的函数,第一个参数是 3D 数组,其批处理前导维度应为符号。第一个参数的其他维度和第二个参数的形状基于实际参数进行专门化。请注意,如果第一个参数是 3D 数组的 pytree,所有数组都具有相同的前导维度,但可能具有不同的尾部维度,则相同的规范也适用。第二个参数的 None 值表示该参数不是符号的。等效地,可以使用 ...

  • ("(batch, ...)", "(batch,)") 指定两个参数具有匹配的前导维度,第一个参数的秩至少为 1,第二个参数的秩为 1。

形状多态的正确性#

我们希望相信,对于任何适用的具体形状,导出的程序产生的结果与原始 JAX 程序编译和执行的结果相同。更准确地说

对于任何 JAX 函数 f 和任何包含符号形状的参数规范 arg_spec,以及任何形状与 arg_spec 匹配的具体参数 arg

  • 如果 JAX 原生执行在具体参数上成功:res = f(arg)

  • 并且如果导出使用符号形状成功:exp = export.export(f)(arg_spec)

  • 那么编译和运行导出将成功并获得相同的结果:res == exp.call(arg)

至关重要的是要理解,f(arg) 可以自由地重新调用 JAX 跟踪机制,实际上,对于每个不同的具体 arg 形状,它都会这样做,而 exp.call(arg) 的执行不能再使用 JAX 跟踪(此执行可能发生在 f 的源代码不可用的环境中)。

确保这种形式的正确性是困难的,在最困难的情况下,导出会失败。本章的其余部分描述了如何处理这些失败。

使用维度变量进行计算#

JAX 跟踪所有中间结果的形状。当这些形状依赖于维度变量时,JAX 将它们计算为涉及维度变量的符号维度表达式。维度变量代表大于或等于 1 的整数值。符号表达式可以表示对维度表达式和整数应用算术运算符(加、减、乘、整除、模,包括 NumPy 变体 np.sumnp.prod 等)的结果intnp.int 或任何可通过 operator.index 转换的对象)。然后,这些符号维度可以在 JAX 原语和 API 的形状参数中使用,例如,在 jnp.reshapejnp.arange、切片索引等中。

例如,在以下代码中,为了展平一个 2D 数组,计算 x.shape[0] * x.shape[1] 将符号维度 4 * b 计算为新形状

>>> f = lambda x: jnp.reshape(x, (x.shape[0] * x.shape[1],))
>>> arg_spec = jax.ShapeDtypeStruct(export.symbolic_shape("b, 4"), jnp.int32)
>>> exp = export.export(jax.jit(f))(arg_spec)
>>> exp.out_avals
(ShapedArray(int32[4*b]),)

可以使用 jnp.array(x.shape[0]) 甚至 jnp.array(x.shape) 将维度表达式显式转换为 JAX 数组。这些操作的结果可以用作常规 JAX 数组,但不能再用作形状中的维度,例如,在 reshape

>>> exp = export.export(jax.jit(lambda x: jnp.array(x.shape[0]) + x))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("b"), np.int32))
>>> exp.call(jnp.arange(3, dtype=np.int32))
Array([3, 4, 5], dtype=int32)

>>> exp = export.export(jax.jit(lambda x: x.reshape(jnp.array(x.shape[0]) + 2)))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("b"), np.int32))  
Traceback (most recent call last):
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>].

当符号维度用于与非整数(例如,floatnp.floatnp.ndarray 或 JAX 数组)进行算术运算时,它会自动使用 jnp.array 转换为 JAX 数组。例如,在下面的函数中,x.shape[0] 的所有出现都隐式转换为 jnp.array(x.shape[0]),因为它们涉及与非整数标量或 JAX 数组的操作

>>> exp = export.export(jax.jit(
...     lambda x: (5. + x.shape[0],
...                x.shape[0] - np.arange(5, dtype=jnp.int32),
...                x + x.shape[0] + jnp.sin(x.shape[0]))))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("b"), jnp.int32))
>>> exp.out_avals
(ShapedArray(float32[], weak_type=True),
 ShapedArray(int32[5]),
 ShapedArray(float32[b], weak_type=True))

>>> exp.call(jnp.ones((3,), jnp.int32))
 (Array(8., dtype=float32, weak_type=True),
  Array([ 3, 2, 1, 0, -1], dtype=int32),
  Array([4.14112, 4.14112, 4.14112], dtype=float32, weak_type=True))

另一个典型的例子是计算平均值(观察 x.shape[0] 如何自动转换为 JAX 数组)

>>> exp = export.export(jax.jit(
...     lambda x: jnp.sum(x, axis=0) / x.shape[0]))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("b, c"), jnp.int32))
>>> exp.call(jnp.arange(12, dtype=jnp.int32).reshape((3, 4)))
Array([4., 5., 6., 7.], dtype=float32)

存在形状多态时的错误#

大多数 JAX 代码都假定 JAX 数组的形状是整数元组,但是对于形状多态,某些维度可能是符号表达式。这可能会导致许多错误。例如,我们可能会遇到常见的 JAX 形状检查错误

>>> v, = export.symbolic_shape("v,")
>>> export.export(jax.jit(lambda x, y: x + y))(
...     jax.ShapeDtypeStruct((v,), dtype=np.int32),
...     jax.ShapeDtypeStruct((4,), dtype=np.int32))
Traceback (most recent call last):
TypeError: add got incompatible shapes for broadcasting: (v,), (4,).

>>> export.export(jax.jit(lambda x: jnp.matmul(x, x)))(
...     jax.ShapeDtypeStruct((v, 4), dtype=np.int32))
Traceback (most recent call last):
TypeError: dot_general requires contracting dimensions to have the same shape, got (4,) and (v,).

我们可以通过指定参数具有形状 (v, v) 来修复上面的 matmul 示例。

部分支持符号维度的比较#

在 JAX 内部,存在许多涉及形状的相等和不等比较,例如,用于进行形状检查,甚至用于为某些原语选择实现。比较的支持方式如下

  • 相等性比较受到部分支持,但需要注意:如果两个符号维度在维度变量的所有估值下都表示相同的值,则相等性评估为 True,例如,对于 b + b == 2*b;否则,相等性评估为 False。有关此行为的重要后果的讨论,请参见下面的 注意事项

  • 不等性始终是相等性的否定。

  • 不等性比较受到部分支持,其方式与部分相等性类似。但是,在这种情况下,我们考虑到维度变量的范围是严格正整数。例如,b >= 1b >= 02 * a + b >= 3True,而 b >= 2a >= ba - b >= 0 则不确定,并导致异常。

在比较操作无法解析为布尔值的情况下,我们会引发 InconclusiveDimensionOperation。例如,

import jax
>>> export.export(jax.jit(lambda x: 0 if x.shape[0] + 1 >= x.shape[1] else 1))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("a, b"), dtype=np.int32))  # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
jax._src.export.shape_poly.InconclusiveDimensionOperation: Symbolic dimension comparison 'a + 1' >= 'b' is inconclusive.
This error arises for comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a boolean value for all values of the symbolic dimensions involved.

如果您确实遇到了 InconclusiveDimensionOperation,则可以尝试以下几种策略

  • 如果您的代码使用内置的 maxmin,或者 np.maxnp.min,则可以将它们替换为 core.max_dimcore.min_dim,这具有将不等比较延迟到编译时的效果,此时形状将变得已知。

  • 尝试使用 core.max_dimcore.min_dim 重写条件语句,例如,使用 core.max_dim(d, 0) 而不是 d if d > 0 else 0

  • 尝试重写代码,使其较少依赖于维度应为整数的事实,并依赖于符号维度在大多数算术运算中都类似于整数的事实。例如,使用 d + 5 而不是 int(d) + 5

  • 指定符号约束,如下所述。

用户指定的符号约束#

默认情况下,JAX 假定所有维度变量的范围都大于或等于 1 的值,并且它尝试从中导出其他简单的不等式,例如

  • a + 2 >= 3,

  • a * 2 >= 1,

  • a + b + c >= 3,

  • a // 4 >= 0, a**2 >= 1, 等等。

如果您更改符号形状规范以添加维度大小的隐式约束,则可以避免某些不等比较失败。例如,

  • 您可以使用 2*b 作为维度来约束它为偶数且大于或等于 2。

  • 您可以使用 b + 15 作为维度来约束它至少为 16。例如,以下代码在没有 + 15 部分的情况下会失败,因为 JAX 将要验证切片大小是否最多与轴大小一样大。

>>> _ = export.export(jax.jit(lambda x: x[0:16]))(
...    jax.ShapeDtypeStruct(export.symbolic_shape("b + 15"), dtype=np.int32))

此类隐式符号约束用于确定比较,并在编译时进行检查,如下面的 说明

您还可以指定显式符号约束

>>> # Introduce dimension variable with constraints.
>>> a, b = export.symbolic_shape("a, b",
...                              constraints=("a >= b", "b >= 16"))
>>> _ = export.export(jax.jit(lambda x: x[:x.shape[1], :16]))(
...    jax.ShapeDtypeStruct((a, b), dtype=np.int32))

约束与隐式约束一起形成合取。您可以指定 >=<=== 约束。目前,JAX 对使用符号约束进行推理的支持有限

  • 您可以从变量大于或等于或小于或等于常量的形式的约束中获得最多的好处。例如,从约束 a >= 16b >= 8 中,我们可以推断出 a + 2*b >= 32

  • 当约束涉及更复杂的表达式时,您获得的能力会受到限制。例如,从 a >= b + 8 我们可以推断出 a - b >= 8,但不能推断出 a >= 9。我们将来可能会在这一领域有所改进。

  • 等式约束被视为重写规则:每当遇到 == 左侧的符号表达式时,它都会被重写为右侧的表达式。例如,floordiv(a, b) == c 的工作原理是将所有出现的 floordiv(a, b) 替换为 c。等式约束的左侧顶层不得包含加法或减法。有效的左侧示例包括 a * b,或 4 * a,或 floordiv(a + c, b)

>>> # Introduce dimension variable with equality constraints.
>>> a, b, c, d = export.symbolic_shape("a, b, c, d",
...                                    constraints=("a * b == c + d",))
>>> 2 * b * a
2*d + 2*c

>>> a * b * b
b*d + b*c

符号约束还可以帮助解决 JAX 推理机制中的限制。例如,在下面的代码中,JAX 将尝试证明切片大小 x.shape[0] % 3(即符号表达式 mod(b, 3))小于或等于轴大小(即 b)。对于 b 的所有严格正值,这恰好为真,但 JAX 的符号比较规则无法证明这一点。因此,以下代码会引发错误

from jax import lax
>>> b, = export.symbolic_shape("b")
>>> f = lambda x: lax.slice_in_dim(x, 0, x.shape[0] % 3)
>>> export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct((b,), dtype=np.int32))  # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
jax._src.export.shape_poly.InconclusiveDimensionOperation: Symbolic dimension comparison 'b' >= 'mod(b, 3)' is inconclusive.
This error arises for comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a boolean value for all values of the symbolic dimensions involved.

这里的一个选择是将代码限制为仅在 3 的倍数的轴大小上工作(通过在形状中将 b 替换为 3*b)。然后,JAX 将能够将模运算 mod(3*b, 3) 简化为 0。另一种选择是添加一个符号约束,其中包含 JAX 尝试证明的确切的不确定不等式

>>> b, = export.symbolic_shape("b",
...                            constraints=["b >= mod(b, 3)"])
>>> f = lambda x: lax.slice_in_dim(x, 0, x.shape[0] % 3)
>>> _ = export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct((b,), dtype=np.int32))

与隐式约束一样,显式符号约束在编译时进行检查,使用的机制与下面解释的相同。

符号维度作用域#

符号约束存储在 αn jax.export.SymbolicScope 对象中,该对象为每次调用 jax.export.symbolic_shapes() 隐式创建。您必须小心不要混合使用使用不同作用域的符号表达式。例如,以下代码将失败,因为 a1a2 使用不同的作用域(由 jax.export.symbolic_shape() 的不同调用创建)

>>> a1, = export.symbolic_shape("a,")
>>> a2, = export.symbolic_shape("a,", constraints=("a >= 8",))

>>> a1 + a2  
Traceback (most recent call last):
ValueError: Invalid mixing of symbolic scopes for linear combination.
Expected  scope 4776451856 created at <doctest shape_poly.md[31]>:1:6 (<module>)
and found for 'a' (unknown) scope 4776979920 created at <doctest shape_poly.md[32]>:1:6 (<module>) with constraints:
  a >= 8

源自对 jax.export.symbolic_shape() 的单次调用的符号表达式共享一个作用域,并且可以混合在算术运算中。结果也将共享相同的作用域。

您可以重用作用域

>>> a, = export.symbolic_shape("a,", constraints=("a >= 8",))
>>> b, = export.symbolic_shape("b,", scope=a.scope)  # Reuse the scope of `a`

>>> a + b  # Allowed
b + a

您还可以显式创建作用域

>>> my_scope = export.SymbolicScope()
>>> c, = export.symbolic_shape("c", scope=my_scope)
>>> d, = export.symbolic_shape("d", scope=my_scope)
>>> c + d  # Allowed
d + c

JAX 跟踪使用部分由形状键控的缓存,如果符号形状使用不同的作用域,即使它们打印出来相同,也会被认为是不同的。

等式比较的注意事项#

对于 b + 1 == bb == 0(在这种情况下,可以确定对于维度变量的所有值,维度都是不同的),以及对于 b == 1a == b,等式比较返回 False。这是不健全的,我们应该引发 core.InconclusiveDimensionOperation,因为在某些估值下,结果应该是 True,而在其他估值下,结果应该是 False。我们选择使等式完全成立,从而允许不健全,因为否则,当哈希维度表达式或包含它们的对象(形状、core.AbstractValuecore.Jaxpr)时,我们可能会在哈希冲突的情况下得到虚假错误。除了哈希错误之外,等式的部分语义还会导致以下表达式出现错误 b == a or b == bb in [a, b],即使如果我们更改比较顺序,也可以避免错误。

即使使用这种等式处理方式,if x.shape[0] != 1: raise NiceErrorMessage 形式的代码也是健全的,但 if x.shape[0] != 1: return 1 形式的代码是不健全的。

维度变量必须可以从输入形状中求解#

目前,当调用导出的对象时,传递维度变量值的唯一方法是间接地通过数组参数的形状。例如,b 的值可以在调用站点从类型为 f32[b] 的第一个参数的形状推断出来。这在大多数用例中都有效,并且反映了 JIT 函数的调用约定。

有时,您可能希望导出一个由整数值参数化的函数,该整数值确定程序中的某些形状。例如,我们可能希望导出下面定义的函数 my_top_k,该函数由 k 的值参数化,该值确定结果的形状。以下尝试将导致错误,因为维度变量 k 无法从输入 x: i32[4, 10] 的形状派生出来

>>> def my_top_k(k, x):  # x: i32[4, 10], k <= 10
...   return lax.top_k(x, k)[0]  # : i32[4, 3]
>>> x = np.arange(40, dtype=np.int32).reshape((4, 10))

>>> # Export with static `k=3`. Since `k` appears in shapes it must be in `static_argnums`.
>>> exp_static_k = export.export(jax.jit(my_top_k, static_argnums=0))(3, x)
>>> exp_static_k.in_avals[0]
ShapedArray(int32[4,10])

>>> exp_static_k.out_avals[0]
ShapedArray(int32[4,3])

>>> # When calling the exported function we pass only the non-static arguments
>>> exp_static_k.call(x)
Array([[ 9,  8,  7],
       [19, 18, 17],
       [29, 28, 27],
       [39, 38, 37]], dtype=int32)

>>> # Now attempt to export with symbolic `k` so that we choose `k` after export.
>>> k, = export.symbolic_shape("k", constraints=["k <= 10"])
>>> export.export(jax.jit(my_top_k, static_argnums=0))(k, x)  
Traceback (most recent call last):
UnexpectedDimVar: "Encountered dimension variable 'k' that is not appearing in the shapes of the function arguments

将来,除了通过输入形状隐式传递维度变量的值之外,我们可能会添加额外的机制来传递维度变量的值。同时,上述用例的解决方法是将函数参数 k 替换为形状为 (0, k) 的数组,以便可以从数组的输入形状派生出 k。第一个维度为 0,以确保整个数组为空,并且在调用导出的函数时不会产生性能损失。

>>> def my_top_k_with_dimensions(dimensions, x):  # dimensions: i32[0, k], x: i32[4, 10]
...   return my_top_k(dimensions.shape[1], x)
>>> exp = export.export(jax.jit(my_top_k_with_dimensions))(
...     jax.ShapeDtypeStruct((0, k), dtype=np.int32),
...     x)
>>> exp.in_avals
(ShapedArray(int32[0,k]), ShapedArray(int32[4,10]))

>>> exp.out_avals[0]
ShapedArray(int32[4,k])

>>> # When we invoke `exp` we must construct and pass an array of shape (0, k)
>>> exp.call(np.zeros((0, 3), dtype=np.int32), x)
Array([[ 9,  8,  7],
       [19, 18, 17],
       [29, 28, 27],
       [39, 38, 37]], dtype=int32)

当某些维度变量确实出现在输入形状中,但在 JAX 当前无法求解的非线性表达式中时,您可能会遇到另一种错误情况

>>> a, = export.symbolic_shape("a")
>>> export.export(jax.jit(lambda x: x.shape[0]))(
...    jax.ShapeDtypeStruct((a * a,), dtype=np.int32))  
Traceback (most recent call last):
ValueError: Cannot solve for values of dimension variables {'a'}.
We can only solve linear uni-variate constraints.
Using the following polymorphic shapes specifications: args[0].shape = (a^2,).
Unprocessed specifications: 'a^2' for dimension size args[0].shape[0].

形状断言错误#

JAX 假设维度变量的范围为严格正整数,并且在为具体输入形状编译代码时会检查此假设。

例如,给定符号输入形状 (b, b, 2*d),当使用实际参数 arg 调用时,JAX 将生成代码来检查以下断言

  • arg.shape[0] >= 1

  • arg.shape[1] == arg.shape[0]

  • arg.shape[2] % 2 == 0

  • arg.shape[2] // 2 >= 1

例如,这是当我们在形状为 (3, 3, 5) 的参数上调用导出时得到的错误

>>> def f(x):  # x: f32[b, b, 2*d]
...   return x
>>> exp = export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("b, b, 2*d"), dtype=np.int32))   
>>> exp.call(np.ones((3, 3, 5), dtype=np.int32))  
Traceback (most recent call last):
ValueError: Input shapes do not match the polymorphic shapes specification.
Division had remainder 1 when computing the value of 'd'.
Using the following polymorphic shapes specifications:
  args[0].shape = (b, b, 2*d).
Obtained dimension variables: 'b' = 3 from specification 'b' for dimension args[0].shape[0] (= 3), .
Please see https://jax.net.cn/en/latest/export/shape_poly.html#shape-assertion-errors for more details.

这些错误发生在编译之前的预处理步骤中。

调试#

首先,请参阅 调试 文档。此外,您可以调试形状细化,它在编译时为具有维度变量或多平台支持的模块调用。

如果在形状细化期间出现错误,您可以设置 JAX_DUMP_IR_TO 环境变量以查看形状细化之前 HLO 模块的转储(名为 ..._before_refine_polymorphic_shapes.mlir)。此模块应已具有静态输入形状。

要启用所有形状细化阶段的日志记录,您可以在 OSS 中设置环境变量 TF_CPP_VMODULE=refine_polymorphic_shapes=3(在 Google 内部,您传递 --vmodule=refine_polymorphic_shapes=3

# Log from python
JAX_DUMP_IR_TO=/tmp/export.dumps/ TF_CPP_VMODULE=refine_polymorphic_shapes=3 python tests/shape_poly_test.py ShapePolyTest.test_simple_unary -v=3