形状多态性#
当 JAX 在 JIT 模式下使用时,函数会针对每种输入类型和形状组合进行追踪、降级到 StableHLO,并编译。在导出函数并将其反序列化到另一个系统后,我们不再拥有 Python 源文件,因此无法重新追踪和重新降级。形状多态性是 JAX 导出的一项功能,允许某些导出的函数适用于一整族输入形状。这些函数在导出期间只追踪和降级一次,并且 Exported 对象包含编译和执行许多具体输入形状函数所需的信息。我们通过在导出时指定包含维度变量(符号形状)的形状来实现这一点,示例如下:
>>> import jax
>>> from jax import export
>>> from jax import numpy as jnp
>>> def f(x): # f: f32[a, b]
... return jnp.concatenate([x, x], axis=1)
>>> # We construct symbolic dimension variables.
>>> a, b = export.symbolic_shape("a, b")
>>> # We can use the symbolic dimensions to construct shapes.
>>> x_shape = (a, b)
>>> x_shape
(a, b)
>>> # Then we export with symbolic shapes:
>>> exp: export.Exported = export.export(jax.jit(f))(
... jax.ShapeDtypeStruct(x_shape, jnp.int32))
>>> exp.in_avals
(ShapedArray(int32[a,b]),)
>>> exp.out_avals
(ShapedArray(int32[a,2*b]),)
>>> # We can later call with concrete shapes (with a=3 and b=4), without re-tracing `f`.
>>> res = exp.call(np.ones((3, 4), dtype=np.int32))
>>> res.shape
(3, 8)
请注意,此类函数在每次被调用时,仍会根据具体的输入形状进行即时重新编译。只有追踪和降级过程被保存。
上述示例中使用了 jax.export.symbolic_shape() 将符号形状的字符串表示解析为维度表达式对象(类型为 _DimExpr),这些对象可用于代替整数常量来构造形状。维度表达式对象重载了大多数整数运算符,因此在大多数情况下,您可以像使用整数常量一样使用它们。有关更多详细信息,请参阅使用维度变量进行计算。
此外,我们还提供了 jax.export.symbolic_args_specs(),可用于根据多态形状规范构造 jax.ShapeDtypeStruct 对象的 pytree:
>>> def f1(x, y): # x: f32[a, 1], y : f32[a, 4]
... return x + y
>>> # Assuming you have some actual args with concrete shapes
>>> x = np.ones((3, 1), dtype=np.int32)
>>> y = np.ones((3, 4), dtype=np.int32)
>>> args_specs = export.symbolic_args_specs((x, y), "a, ...")
>>> exp = export.export(jax.jit(f1))(* args_specs)
>>> exp.in_avals
(ShapedArray(int32[a,1]), ShapedArray(int32[a,4]))
请注意多态形状规范 "a, ..." 如何包含占位符 ...,该占位符将由参数 (x, y) 的具体形状填充。占位符 ... 代表 0 个或更多维度,而占位符 _ 代表一个维度。jax.export.symbolic_args_specs() 支持参数的 pytree,它们用于填充 dtype 和任何占位符。该函数将构造一个与传递给它的参数结构匹配的参数规范 (jax.ShapeDtypeStruct) pytree。在一种规范适用于多个参数的情况下,多态形状规范可以是 pytree 前缀,如上述示例所示。请参阅 可选参数如何与参数匹配。
一些形状规范的例子:
("(b, _, _)", None)可用于具有两个参数的函数,其中第一个参数是具有批处理前导维度的 3D 数组,该维度应为符号。第一个参数的其他维度和第二个参数的形状根据实际参数进行特殊化。请注意,如果第一个参数是 3D 数组的 pytree,所有数组都具有相同的前导维度但可能具有不同的尾随维度,则相同的规范也将有效。第二个参数的值None表示该参数不是符号。同样,也可以使用...。("(batch, ...)", "(batch,)")指定两个参数具有匹配的前导维度,第一个参数的秩至少为 1,第二个参数的秩为 1。
形状多态性的正确性#
我们希望相信,导出的程序在编译并针对任何适用的具体形状执行时,会产生与原始 JAX 程序相同的结果。更准确地说:
对于任何 JAX 函数 f 和任何包含符号形状的参数规范 arg_spec,以及形状与 arg_spec 匹配的任何具体参数 arg:
如果 JAX 本机执行在具体参数上成功:
res = f(arg),并且如果使用符号形状导出成功:
exp = export.export(f)(arg_spec),则编译并运行导出的结果将成功并得到相同的结果:
res == exp.call(arg)
理解这一点至关重要:f(arg) 可以自由地重新调用 JAX 追踪机制,实际上对于每个不同的具体 arg 形状它都这样做,而 exp.call(arg) 的执行不能再使用 JAX 追踪(此执行可能发生在 f 的源代码不可用的环境中)。
确保这种形式的正确性很困难,在最困难的情况下导出会失败。本章的其余部分描述了如何处理这些失败。
使用维度变量进行计算#
JAX 会跟踪所有中间结果的形状。当这些形状依赖于维度变量时,JAX 会将它们计算为涉及维度变量的符号维度表达式。维度变量代表大于或等于 1 的整数值。符号表达式可以表示算术运算符(加、减、乘、整除、取模,包括 NumPy 变体 np.sum、np.prod 等)应用于维度表达式和整数(int、np.int 或任何可通过 operator.index 转换的对象)的结果。这些符号维度随后可以在 JAX 原语和 API 的形状参数中使用,例如,在 jnp.reshape、jnp.arange、切片索引等。
例如,在以下用于扁平化 2D 数组的代码中,计算 x.shape[0] * x.shape[1] 会将符号维度 4 * b 计算为新形状:
>>> f = lambda x: jnp.reshape(x, (x.shape[0] * x.shape[1],))
>>> arg_spec = jax.ShapeDtypeStruct(export.symbolic_shape("b, 4"), jnp.int32)
>>> exp = export.export(jax.jit(f))(arg_spec)
>>> exp.out_avals
(ShapedArray(int32[4*b]),)
可以将维度表达式显式转换为 JAX 数组,使用 jnp.array(x.shape[0]) 甚至 jnp.array(x.shape)。这些操作的结果可以作为常规 JAX 数组使用,但不能再作为形状中的维度,例如在 reshape 中:
>>> exp = export.export(jax.jit(lambda x: jnp.array(x.shape[0]) + x))(
... jax.ShapeDtypeStruct(export.symbolic_shape("b"), np.int32))
>>> exp.call(jnp.arange(3, dtype=np.int32))
Array([3, 4, 5], dtype=int32)
>>> exp = export.export(jax.jit(lambda x: x.reshape(jnp.array(x.shape[0]) + 2)))(
... jax.ShapeDtypeStruct(export.symbolic_shape("b"), np.int32))
Traceback (most recent call last):
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>].
当符号维度与非整数(例如 float、np.float、np.ndarray 或 JAX 数组)进行算术运算时,它会自动使用 jnp.array 转换为 JAX 数组。例如,在下面的函数中,所有 x.shape[0] 的出现都会隐式转换为 jnp.array(x.shape[0]),因为它们涉及与非整数标量或 JAX 数组的操作:
>>> exp = export.export(jax.jit(
... lambda x: (5. + x.shape[0],
... x.shape[0] - np.arange(5, dtype=jnp.int32),
... x + x.shape[0] + jnp.sin(x.shape[0]))))(
... jax.ShapeDtypeStruct(export.symbolic_shape("b"), jnp.int32))
>>> exp.out_avals
(ShapedArray(float32[], weak_type=True),
ShapedArray(int32[5]),
ShapedArray(float32[b], weak_type=True))
>>> exp.call(jnp.ones((3,), jnp.int32))
(Array(8., dtype=float32, weak_type=True),
Array([ 3, 2, 1, 0, -1], dtype=int32),
Array([4.14112, 4.14112, 4.14112], dtype=float32, weak_type=True))
另一个典型例子是在计算平均值时(请注意 x.shape[0] 如何自动变为 JAX 数组):
>>> exp = export.export(jax.jit(
... lambda x: jnp.sum(x, axis=0) / x.shape[0]))(
... jax.ShapeDtypeStruct(export.symbolic_shape("b, c"), jnp.int32))
>>> exp.call(jnp.arange(12, dtype=jnp.int32).reshape((3, 4)))
Array([4., 5., 6., 7.], dtype=float32)
存在形状多态性时的错误#
大多数 JAX 代码假定 JAX 数组的形状是整数元组,但通过形状多态性,某些维度可能是符号表达式。这可能导致许多错误。例如,我们可能会遇到常见的 JAX 形状检查错误:
>>> v, = export.symbolic_shape("v,")
>>> export.export(jax.jit(lambda x, y: x + y))(
... jax.ShapeDtypeStruct((v,), dtype=np.int32),
... jax.ShapeDtypeStruct((4,), dtype=np.int32))
Traceback (most recent call last):
TypeError: add got incompatible shapes for broadcasting: (v,), (4,).
>>> export.export(jax.jit(lambda x: jnp.matmul(x, x)))(
... jax.ShapeDtypeStruct((v, 4), dtype=np.int32))
Traceback (most recent call last):
TypeError: dot_general requires contracting dimensions to have the same shape, got (4,) and (v,).
我们可以通过指定参数的形状为 (v, v) 来修复上述矩阵乘法示例。
符号维度的比较得到部分支持#
在 JAX 内部,存在许多涉及形状的等式和不等式比较,例如用于形状检查,甚至用于选择某些原语的实现。比较得到如下支持:
相等性支持存在一个注意事项:如果两个符号维度在维度变量的所有求值下表示相同的值,则相等性求值为
True,例如b + b == 2*b;否则,相等性求值为False。有关此行为的重要后果的讨论,请参阅下文。不等式总是相等性的否定。
不等式得到部分支持,类似于部分相等性。但是,在这种情况下,我们考虑维度变量的范围是严格正整数。例如,
b >= 1、b >= 0、2 * a + b >= 3为True,而b >= 2、a >= b、a - b >= 0则不确定并导致异常。
在比较操作无法解析为布尔值的情况下,我们会引发 InconclusiveDimensionOperation。例如:
import jax
>>> export.export(jax.jit(lambda x: 0 if x.shape[0] + 1 >= x.shape[1] else 1))(
... jax.ShapeDtypeStruct(export.symbolic_shape("a, b"), dtype=np.int32)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
jax._src.export.shape_poly.InconclusiveDimensionOperation: Symbolic dimension comparison 'a + 1' >= 'b' is inconclusive.
This error arises for comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a boolean value for all values of the symbolic dimensions involved.
如果您确实遇到 InconclusiveDimensionOperation,可以尝试以下几种策略:
如果您的代码使用内置的
max或min,或者np.max或np.min,那么您可以将其替换为core.max_dim和core.min_dim,这会延迟不等式比较到编译时,届时形状将已知。尝试使用
core.max_dim和core.min_dim重写条件语句,例如,代替d if d > 0 else 0,您可以编写core.max_dim(d, 0)。尝试重写代码,使其更少依赖于维度必须是整数的事实,并依赖于符号维度在大多数算术运算中具有整数的鸭子类型特性。例如,代替
int(d) + 5,编写d + 5。指定符号约束,如下所述。
用户指定的符号约束#
默认情况下,JAX 假设所有维度变量的取值范围都大于或等于 1,并尝试从中推导出其他简单的不等式,例如:
a + 2 >= 3,a * 2 >= 1,a + b + c >= 3,a // 4 >= 0、a**2 >= 1等等。
如果更改符号形状规范以添加维度大小的隐式约束,则可以避免一些不等式比较失败。例如:
您可以对维度使用
2*b以将其约束为偶数且大于或等于 2。您可以将
b + 15用于维度,将其约束为至少 16。例如,以下代码如果没有+ 15部分,则会失败,因为 JAX 会希望验证切片大小最多与轴大小相同。
>>> _ = export.export(jax.jit(lambda x: x[0:16]))(
... jax.ShapeDtypeStruct(export.symbolic_shape("b + 15"), dtype=np.int32))
这些隐式符号约束用于决定比较,并在编译时进行检查,如下文所述。
您还可以指定显式符号约束:
>>> # Introduce dimension variable with constraints.
>>> a, b = export.symbolic_shape("a, b",
... constraints=("a >= b", "b >= 16"))
>>> _ = export.export(jax.jit(lambda x: x[:x.shape[1], :16]))(
... jax.ShapeDtypeStruct((a, b), dtype=np.int32))
这些约束与隐式约束共同构成一个合取。您可以指定 >=、<= 和 == 约束。目前,JAX 对符号约束的推理支持有限:
当约束形式为变量大于或等于或小于或等于常量时,您将获得最大的收益。例如,从
a >= 16和b >= 8的约束中,我们可以推断出a + 2*b >= 32。当约束涉及更复杂的表达式时,您将获得有限的功能,例如,从
a >= b + 8中,我们可以推断出a - b >= 8,但不能推断出a >= 9。我们将来可能会在一定程度上改进这个领域。等式约束被视为重写规则:只要遇到
==左侧的符号表达式,它就会被重写为右侧的表达式。例如,floordiv(a, b) == c的工作原理是将所有出现的floordiv(a, b)替换为c。等式约束的左侧顶层不能包含加法或减法。有效左侧的示例包括a * b,或4 * a,或floordiv(a + c, b)。
>>> # Introduce dimension variable with equality constraints.
>>> a, b, c, d = export.symbolic_shape("a, b, c, d",
... constraints=("a * b == c + d",))
>>> 2 * b * a
2*d + 2*c
>>> a * b * b
b*d + b*c
符号约束也有助于解决 JAX 推理机制的局限性。例如,在下面的代码中,JAX 将尝试证明切片大小 x.shape[0] % 3(即符号表达式 mod(b, 3))小于或等于轴大小 b。对于所有严格正的 b 值,这都是成立的,但这并不是 JAX 的符号比较规则能够证明的。因此,以下代码会引发错误:
from jax import lax
>>> b, = export.symbolic_shape("b")
>>> f = lambda x: lax.slice_in_dim(x, 0, x.shape[0] % 3)
>>> export.export(jax.jit(f))(
... jax.ShapeDtypeStruct((b,), dtype=np.int32)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
jax._src.export.shape_poly.InconclusiveDimensionOperation: Symbolic dimension comparison 'b' >= 'mod(b, 3)' is inconclusive.
This error arises for comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a boolean value for all values of the symbolic dimensions involved.
这里一个选项是限制代码只适用于 3 的倍数的轴大小(通过在形状中用 3*b 替换 b)。然后,JAX 将能够将模运算 mod(3*b, 3) 简化为 0。另一个选项是添加一个符号约束,其值与 JAX 试图证明的模糊不等式完全相同:
>>> b, = export.symbolic_shape("b",
... constraints=["b >= mod(b, 3)"])
>>> f = lambda x: lax.slice_in_dim(x, 0, x.shape[0] % 3)
>>> _ = export.export(jax.jit(f))(
... jax.ShapeDtypeStruct((b,), dtype=np.int32))
与隐式约束一样,显式符号约束在编译时进行检查,使用与下文所述相同的机制。
符号维度范围#
符号约束存储在 αn jax.export.SymbolicScope 对象中,该对象在每次调用 jax.export.symbolic_shapes() 时隐式创建。您必须小心,不要混合使用不同范围的符号表达式。例如,以下代码将失败,因为 a1 和 a2 使用不同的范围(由不同的 jax.export.symbolic_shape() 调用创建):
>>> a1, = export.symbolic_shape("a,")
>>> a2, = export.symbolic_shape("a,", constraints=("a >= 8",))
>>> a1 + a2
Traceback (most recent call last):
ValueError: Invalid mixing of symbolic scopes for linear combination.
Expected scope 4776451856 created at <doctest shape_poly.md[31]>:1:6 (<module>)
and found for 'a' (unknown) scope 4776979920 created at <doctest shape_poly.md[32]>:1:6 (<module>) with constraints:
a >= 8
源自单次调用 jax.export.symbolic_shape() 的符号表达式共享一个作用域,并且可以在算术运算中混合使用。结果也将共享相同的范围。
您可以重用范围:
>>> a, = export.symbolic_shape("a,", constraints=("a >= 8",))
>>> b, = export.symbolic_shape("b,", scope=a.scope) # Reuse the scope of `a`
>>> a + b # Allowed
b + a
您也可以显式创建范围:
>>> my_scope = export.SymbolicScope()
>>> c, = export.symbolic_shape("c", scope=my_scope)
>>> d, = export.symbolic_shape("d", scope=my_scope)
>>> c + d # Allowed
d + c
JAX 追踪使用部分以形状为键的缓存,并且如果符号形状使用不同的范围,即使它们打印出来看起来相同,也会被认为是不同的。
等式比较的注意事项#
对于 b + 1 == b 或 b == 0(在这种情况下,可以确定维度在维度变量的所有值下都不同),以及 b == 1 和 a == b,等式比较都返回 False。这是不健全的,我们应该引发 core.InconclusiveDimensionOperation,因为在某些求值下结果应该是 True,而在其他求值下结果应该是 False。我们选择使等式完全化,从而允许不健全,因为否则在哈希维度表达式或包含它们的对象的哈希冲突时(形状、core.AbstractValue、core.Jaxpr)可能会出现虚假错误。除了哈希错误之外,部分等式语义会导致以下表达式的错误:b == a or b == b 或 b in [a, b],尽管如果我们改变比较顺序就可以避免错误。
形式为 if x.shape[0] != 1: raise NiceErrorMessage 的代码即使在处理相等性时也是健全的,但形式为 if x.shape[0] != 1: return 1 的代码则是不健全的。
维度变量必须可从输入形状中求解#
目前,在调用导出对象时传递维度变量值的唯一方法是间接通过数组参数的形状。例如,b 的值可以在调用点从类型为 f32[b] 的第一个参数的形状中推断出来。这对于大多数用例都有效,并且它反映了 JIT 函数的调用约定。
有时您可能希望导出一个由整数值参数化的函数,该整数值决定程序中的某些形状。例如,我们可能希望导出下面定义的函数 my_top_k,它由 k 的值参数化,该值决定结果的形状。以下尝试将导致错误,因为维度变量 k 无法从输入 x: i32[4, 10] 的形状中推导出来:
>>> def my_top_k(k, x): # x: i32[4, 10], k <= 10
... return lax.top_k(x, k)[0] # : i32[4, 3]
>>> x = np.arange(40, dtype=np.int32).reshape((4, 10))
>>> # Export with static `k=3`. Since `k` appears in shapes it must be in `static_argnums`.
>>> exp_static_k = export.export(jax.jit(my_top_k, static_argnums=0))(3, x)
>>> exp_static_k.in_avals[0]
ShapedArray(int32[4,10])
>>> exp_static_k.out_avals[0]
ShapedArray(int32[4,3])
>>> # When calling the exported function we pass only the non-static arguments
>>> exp_static_k.call(x)
Array([[ 9, 8, 7],
[19, 18, 17],
[29, 28, 27],
[39, 38, 37]], dtype=int32)
>>> # Now attempt to export with symbolic `k` so that we choose `k` after export.
>>> k, = export.symbolic_shape("k", constraints=["k <= 10"])
>>> export.export(jax.jit(my_top_k, static_argnums=0))(k, x)
Traceback (most recent call last):
UnexpectedDimVar: "Encountered dimension variable 'k' that is not appearing in the shapes of the function arguments
将来,除了通过输入形状隐式传递之外,我们可能会添加一种额外的机制来传递维度变量的值。与此同时,上述用例的解决方案是将函数参数 k 替换为形状为 (0, k) 的数组,这样 k 就可以从数组的输入形状中推导出来。第一个维度为 0,以确保整个数组为空,并且在调用导出函数时没有性能损失。
>>> def my_top_k_with_dimensions(dimensions, x): # dimensions: i32[0, k], x: i32[4, 10]
... return my_top_k(dimensions.shape[1], x)
>>> exp = export.export(jax.jit(my_top_k_with_dimensions))(
... jax.ShapeDtypeStruct((0, k), dtype=np.int32),
... x)
>>> exp.in_avals
(ShapedArray(int32[0,k]), ShapedArray(int32[4,10]))
>>> exp.out_avals[0]
ShapedArray(int32[4,k])
>>> # When we invoke `exp` we must construct and pass an array of shape (0, k)
>>> exp.call(np.zeros((0, 3), dtype=np.int32), x)
Array([[ 9, 8, 7],
[19, 18, 17],
[29, 28, 27],
[39, 38, 37]], dtype=int32)
另一种可能出现错误的情况是,某些维度变量确实出现在输入形状中,但以 JAX 目前无法解决的非线性表达式形式出现:
>>> a, = export.symbolic_shape("a")
>>> export.export(jax.jit(lambda x: x.shape[0]))(
... jax.ShapeDtypeStruct((a * a,), dtype=np.int32))
Traceback (most recent call last):
ValueError: Cannot solve for values of dimension variables {'a'}.
We can only solve linear uni-variate constraints.
Using the following polymorphic shapes specifications: args[0].shape = (a^2,).
Unprocessed specifications: 'a^2' for dimension size args[0].shape[0].
形状断言错误#
JAX 假定维度变量的取值范围为严格正整数,并且当代码针对具体输入形状进行编译时,会检查此假定。
例如,给定符号输入形状 (b, b, 2*d),当使用实际参数 arg 调用时,JAX 将生成代码以检查以下断言:
arg.shape[0] >= 1arg.shape[1] == arg.shape[0]arg.shape[2] % 2 == 0arg.shape[2] // 2 >= 1
例如,当我们对形状为 (3, 3, 5) 的参数调用导出时,我们会得到以下错误:
>>> def f(x): # x: f32[b, b, 2*d]
... return x
>>> exp = export.export(jax.jit(f))(
... jax.ShapeDtypeStruct(export.symbolic_shape("b, b, 2*d"), dtype=np.int32))
>>> exp.call(np.ones((3, 3, 5), dtype=np.int32))
Traceback (most recent call last):
ValueError: Input shapes do not match the polymorphic shapes specification.
Division had remainder 1 when computing the value of 'd'.
Using the following polymorphic shapes specifications:
args[0].shape = (b, b, 2*d).
Obtained dimension variables: 'b' = 3 from specification 'b' for dimension args[0].shape[0] (= 3), .
Please see https://jax.net.cn/en/latest/export/shape_poly.html#shape-assertion-errors for more details.
这些错误发生在编译之前的预处理步骤中。
调试#
首先,请参阅调试文档。此外,您可以调试形状细化,它会在编译时为具有维度变量或多平台支持的模块调用。
如果在形状细化过程中出现错误,您可以设置 JAX_DUMP_IR_TO 环境变量以查看形状细化之前的 HLO 模块转储(命名为 ..._before_refine_polymorphic_shapes.mlir)。此模块应该已经具有静态输入形状。
要启用形状细化所有阶段的日志记录,您可以在 OSS 中设置环境变量 TF_CPP_VMODULE=refine_polymorphic_shapes=3(在 Google 内部,您传递 --vmodule=refine_polymorphic_shapes=3)。
# Log from python
JAX_DUMP_IR_TO=/tmp/export.dumps/ TF_CPP_VMODULE=refine_polymorphic_shapes=3 python tests/shape_poly_test.py ShapePolyTest.test_simple_unary -v=3