JAX 类型注解路线图#
作者:jakevdp
日期:2022 年 8 月
背景#
Python 3.0 引入了可选的函数注解 (PEP 3107),后来在 Python 3.5 发布前后被编纂用于静态类型检查 (PEP 484)。在某种程度上,类型注解和静态类型检查已成为许多 Python 开发工作流程不可或缺的一部分,为此,我们在整个 JAX API 中的许多地方添加了注解。JAX 中类型注解的当前状态有点零散,而添加更多注解的努力受到了更基本的设计问题的阻碍。本文档试图总结这些问题,并为 JAX 中类型注解的目标和非目标生成路线图。
为什么我们需要这样的路线图?更好/更全面的类型注解是用户(包括内部和外部用户)经常提出的要求。此外,我们经常收到外部用户的拉取请求(例如,PR #9917,PR #10322)寻求改进 JAX 的类型注解:对于审查代码的 JAX 团队成员来说,并不总是清楚这些贡献是否有益,特别是当它们引入复杂的协议来解决 JAX 在 Python 中使用时固有的挑战时。本文档详细介绍了 JAX 在包内进行类型注解的目标和建议。
为什么需要类型注解?#
Python 项目可能希望为其代码库添加注解的原因有很多;我们在本文档中将它们总结为级别 1、级别 2 和级别 3。
级别 1:作为文档的注解#
最初在 PEP 3107 中引入时,类型注解的部分动机是能够将其用作函数参数类型和返回类型的简洁内联文档。长期以来,JAX 一直以这种方式使用注解;一个例子是创建别名为 Any
的类型名称的常见模式。一个例子可以在 lax/slicing.py
中找到 [来源]
Array = Any
Shape = core.Shape
def slice(operand: Array, start_indices: Sequence[int],
limit_indices: Sequence[int],
strides: Optional[Sequence[int]] = None) -> Array:
...
对于静态类型检查的目的,这种使用 Array = Any
作为数组类型注解的方式,对参数值没有任何约束(Any
等同于根本没有注解),但它确实可以作为开发者有用的代码内文档形式。
为了生成文档,别名的名称会丢失(HTML 文档 中 jax.lax.slice
报告操作数为 Any
类型),因此文档的好处不会超出源代码(尽管我们可以启用一些 sphinx-autodoc
选项来改进这一点:请参阅 autodoc_type_aliases)。
这种级别类型注解的一个好处是,用 Any
注解一个值永远不会出错,因此它将以文档的形式为开发者和用户提供具体的好处,而不会增加满足任何特定静态类型检查器更严格需求的复杂性。
级别 2:用于智能自动完成的注解#
许多现代 IDE 利用类型注解作为 智能代码完成 系统的输入。一个例子是 VSCode 的 Pylance 扩展,它使用 Microsoft 的 pyright 静态类型检查器作为 VSCode IntelliSense 完成的信息来源。
这种类型检查的使用需要比上面使用的简单别名更进一步;例如,知道 slice
函数返回名为 Array
的 Any
别名,并不会为代码完成引擎添加任何有用的信息。但是,如果我们用 DeviceArray
返回类型注解该函数,自动完成功能将知道如何填充结果的命名空间,从而能够在开发过程中建议更相关的自动完成。
JAX 已经开始在一些地方添加这种级别的类型注解;一个例子是 jax.random
包中的 jnp.ndarray
返回类型 [来源]
def shuffle(key: KeyArray, x: Array, axis: int = 0) -> jnp.ndarray:
...
在这种情况下,jnp.ndarray
是一个抽象基类,它前向声明了 JAX 数组的属性和方法(参见来源),因此 VSCode 中的 Pylance 可以在此函数的结果上提供全套自动完成功能。以下是显示结果的屏幕截图
自动完成字段中列出的是抽象 ndarray
类声明的所有方法和属性。我们将在下面进一步讨论为什么有必要创建这个抽象类,而不是直接用 DeviceArray
进行注解。
级别 3:用于静态类型检查的注解#
如今,当人们考虑 Python 代码中类型注解的目的时,静态类型检查通常是人们首先想到的。虽然 Python 不对类型进行任何运行时检查,但存在几种成熟的静态类型检查工具,可以将其作为 CI 测试套件的一部分来完成。对于 JAX 来说,最重要的工具如下
python/mypy 或多或少是开放 Python 世界的标准。JAX 目前在 Github Actions CI 检查中对源文件的子集运行 mypy。
google/pytype 是 Google 的静态类型检查器,Google 内部依赖 JAX 的项目经常使用它。
microsoft/pyright 很重要,因为它是在 VSCode 中用于之前提到的 Pylance 完成的静态类型检查器。
完全静态类型检查是所有类型注解应用中最严格的,因为它会在您的类型注解不完全正确时显示错误。一方面,这很好,因为您的静态类型分析可能会捕获错误的类型注解(例如,jnp.ndarray
抽象类中缺少 DeviceArray
方法的情况)。
另一方面,这种严格性可能会使类型检查过程在经常依赖鸭子类型而不是严格类型安全 API 的包中变得非常脆弱。您目前会在整个 JAX 代码库中发现数百个地方都散布着诸如 #type: ignore
(用于 mypy)或 #pytype: disable
(用于 pytype)之类的代码注释。这些通常代表出现类型问题的情况;它们可能是 JAX 类型注解中的不准确之处,或者是静态类型检查器正确跟踪代码中控制流的能力不准确。有时,它们是由于 pytype 或 mypy 行为中的真实且细微的错误造成的。在极少数情况下,它们可能是由于 JAX 使用的 Python 模式难以甚至不可能用 Python 的静态类型注解语法来表达。
JAX 类型注解的挑战#
JAX 目前的类型注解混合了不同的风格,旨在达到上面讨论的所有三个级别的类型注解。部分原因在于,JAX 的源代码对 Python 的类型注解系统提出了一些独特的挑战。我们在这里概述它们。
挑战 1:pytype、mypy 和开发者摩擦#
JAX 当前面临的一个挑战是,包开发必须满足两个不同的静态类型检查系统 pytype
(内部 CI 和 Google 内部项目使用)和 mypy
(外部 CI 和外部依赖项使用)的约束。尽管这两个类型检查器在行为上具有广泛的重叠,但每个检查器都呈现出自己独特的极端情况,JAX 代码库中大量的 #type: ignore
和 #pytype: disable
语句就证明了这一点。
这会在开发中造成摩擦:内部贡献者可能会迭代直到测试通过,但最终发现他们的 pytype 批准的代码在导出时违反了 mypy。对于外部贡献者来说,情况通常相反:最近的一个例子是 #9596,在未能通过 Google 内部 pytype 检查后不得不回滚。每次我们将类型注解从级别 1(到处都是 Any
)移动到级别 2 或 3(更严格的注解)时,都会为这种令人沮丧的开发者体验引入更多可能性。
挑战 2:数组鸭子类型#
注解 JAX 代码的一个特殊挑战是其大量使用鸭子类型。标记为 Array
的函数的输入通常可能是多种不同类型之一:JAX DeviceArray
、NumPy np.ndarray
、NumPy 标量、Python 标量、Python 序列、具有 __array__
属性的对象、具有 __jax_array__
属性的对象,或任何类型的 jax.Tracer
。因此,像 def func(x: DeviceArray)
这样的简单注解是不够的,并且会导致许多有效用例的误报。这意味着 JAX 函数的类型注解不会是简短或微不足道的,但我们将不得不有效地开发一组 JAX 特定的类型扩展,类似于 numpy.typing
包中的那些扩展。
挑战 3:转换和装饰器#
JAX 的 Python API 在很大程度上依赖于函数转换(jit()
、vmap()
、grad()
等),这种类型的 API 对静态类型分析提出了特殊的挑战。装饰器的灵活注解一直是 mypy 包中的一个 长期存在的问题,直到最近才通过引入 ParamSpec
解决,这在 PEP 612 中讨论并在 Python 3.10 中添加。由于 JAX 遵循 NEP 29,因此在 2024 年年中之后的某个时间之前,它不能依赖 Python 3.10 功能。与此同时,协议可以用作此问题的部分解决方案(JAX 在 #9950 中为 jit 和其他方法添加了此功能),并且 ParamSpec 可以通过 typing_extensions
包使用(原型在 #9999 中),尽管这目前揭示了 mypy 中的基本错误(参见 python/mypy#12593)。总而言之:目前尚不清楚 JAX 函数转换的 API 是否可以在 Python 类型注解工具的当前约束范围内进行适当的注解。
挑战 4:数组注解缺乏粒度#
这里的另一个挑战是 Python 中所有面向数组的 API 共有的,并且多年来一直是 JAX 讨论的一部分(参见 #943)。类型注解与对象的 Python 类或类型有关,而在基于数组的语言中,通常类的属性更为重要。在 NumPy、JAX 和类似包的情况下,通常我们希望注解特定的数组形状和数据类型。
例如,jnp.linspace
函数的参数必须是标量值,但在 JAX 中,标量由零维数组表示。因此,为了使注解不产生误报,我们必须允许这些参数是任意数组。另一个例子是 jax.random.choice
的第二个参数,当 shape=()
时,它必须具有 dtype=int
。Python 计划通过可变类型泛型来实现具有这种粒度的类型注解(参见 PEP 646,计划在 Python 3.11 中实现),但与 ParamSpec
一样,对该功能的支持还需要一段时间才能稳定。
目前有一些第三方项目可能会有所帮助,特别是 google/jaxtyping,但这使用了非标准注解,可能不适合注解核心 JAX 库本身。总而言之,数组类型粒度挑战不如其他挑战那么重要,因为主要影响是类数组注解将不如它们原本可以的那样具体。
挑战 5:从 NumPy 继承的不精确 API#
JAX 的用户界面 API 的很大一部分继承自 NumPy,位于 jax.numpy
子模块中。NumPy 的 API 是在静态类型检查成为 Python 语言的一部分之前多年开发的,并且遵循 Python 的历史建议,使用 鸭子类型/EAFP 编码风格,其中不鼓励在运行时进行严格的类型检查。作为一个具体的例子,考虑 numpy.tile()
函数,其定义如下:
def tile(A, reps):
try:
tup = tuple(reps)
except TypeError:
tup = (reps,)
d = len(tup)
...
这里的意图是 reps
应该包含一个 int
或一个 int
值的序列,但实现允许 tup
是任何可迭代对象。当向这种鸭子类型代码添加注解时,我们可以采取以下两种途径之一:
我们可以选择注解函数 API 的意图,这里可能类似于
reps: Union[int, Sequence[int]]
。相反,我们可以选择注解函数的实现,这里可能看起来像
reps: Union[ConvertibleToInt, Iterable[ConvertibleToInt]]
,其中ConvertibleToInt
是一个特殊的协议,涵盖了我们的函数将输入转换为整数的确切机制(即通过__int__
、通过__index__
、通过__array__
等)。另请注意,严格来说,Iterable
在这里是不够的,因为 Python 中有一些对象可以作为可迭代对象进行鸭子类型化,但不能满足针对Iterable
的静态类型检查(即,通过__getitem__
而不是__iter__
可迭代的对象)。
注解意图(#1)的优点是,注解对于用户来说更有用,可以传达 API 契约;而对于开发人员来说,灵活性为必要时的重构留出了空间。缺点(特别是对于像 JAX 这样的渐进类型 API)是,很可能存在用户代码可以正确运行,但会被类型检查器标记为不正确的代码。对现有的鸭子类型 API 进行渐进类型化意味着当前的注解隐式为 Any
,因此将其更改为更严格的类型可能会向用户呈现为破坏性更改。
broadly speaking, annotating intent better serves Level 1 type checking, while annotating implementation better serves Level 3, while Level 2 is more of a mixed bag (both intent and implementation are important when it comes to annotations in IDEs).
JAX 类型注解路线图#
有了这个框架(Level 1/2/3)和 JAX 特有的挑战,我们可以开始制定我们的路线图,以便在整个 JAX 项目中实施一致的类型注解。
指导原则#
对于 JAX 类型注解,我们将遵循以下原则:
类型注解的目的#
我们希望尽可能地支持完整的Level 1、2 和 3 类型注解。特别是,这意味着我们应该对公共 API 函数的输入和输出都进行限制性的类型注解。
为意图注解#
JAX 类型注解通常应指示 API 的 意图,而不是实现,以便注解有助于传达 API 的契约。这意味着有时在运行时有效的输入可能不会被静态类型检查器识别为有效(一个例子可能是传递任意迭代器来代替注解为 Shape = Sequence[int]
的形状)。
输入应宽松类型化#
JAX 函数和方法的输入应尽可能宽松地进行类型化:例如,虽然形状通常是元组,但接受形状的函数应接受任意序列。同样,接受 dtype 的函数不需要类 np.dtype
的实例,而是任何可转换为 dtype 的对象。这可能包括字符串、内置标量类型或标量对象构造函数,例如 np.float64
和 jnp.float64
。为了使这在整个包中尽可能统一,我们将添加一个 jax.typing
模块,其中包含常见的类型规范,从广泛的类别开始,例如:
ArrayLike
将是可以隐式转换为数组的任何事物的联合:例如,jax 数组、numpy 数组、JAX 追踪器以及 python 或 numpy 标量。DTypeLike
将是可以隐式转换为 dtype 的任何事物的联合:例如,numpy dtypes、numpy dtype 对象、jax dtype 对象、字符串和内置类型。ShapeLike
将是可以转换为形状的任何事物的联合:例如,整数或类整数对象的序列。等等。
请注意,这些通常比 numpy.typing
中使用的等效协议更简单。例如,在 DTypeLike
的情况下,JAX 不支持结构化 dtypes,因此 JAX 可以使用更简单的实现。类似地,在 ArrayLike
中,JAX 通常不支持列表或元组输入来代替数组,因此类型定义将比 NumPy 类似物更简单。
输出应严格类型化#
相反,函数和方法的输出应尽可能严格地进行类型化:例如,对于返回数组的 JAX 函数,输出应使用类似于 jnp.ndarray
而不是 ArrayLike
的内容进行注解。返回 dtype 的函数应始终注解为 np.dtype
,而返回形状的函数应始终为 Tuple[int]
或严格类型化的 NamedShape 等效项。为此,我们将在 jax.typing
中实现上述宽松类型的几个严格类型化的类似物,即:
Array
或NDArray
(见下文)用于类型注解目的,实际上等同于Union[Tracer, jnp.ndarray]
,应用于注解数组输出。DType
是np.dtype
的别名,可能还能够表示 JAX 中使用的键类型和其他泛化。Shape
本质上是Tuple[int, ...]
,可能具有一些额外的灵活性来考虑动态形状。NamedShape
是Shape
的扩展,允许使用 JAX 内部使用的命名形状。等等。
我们还将探索是否可以删除当前 jax.numpy.ndarray
的实现,而改为使 ndarray
成为 Array
或类似物的别名。
倾向于简单#
除了在 jax.typing
中收集的通用类型协议外,我们应该倾向于简单。我们应该避免为传递给 API 函数的参数构建过于复杂的协议,而是在 API 的完整类型规范无法简洁指定的情况下,使用简单的联合,例如 Union[simple_type, Any]
。这是一种折衷方案,实现了 Level 1 和 Level 2 注解的目标,同时为了避免不必要的复杂性,放弃了 Level 3。
避免不稳定的类型机制#
为了不增加不必要的开发摩擦(由于内部/外部 CI 的差异),我们希望在我们使用的类型注解构造中保持保守:特别是,对于最近引入的机制,如 ParamSpec
(PEP 612) 和 Variadic Type Generics (PEP 646),我们希望等到 mypy 和其他工具中的支持成熟和稳定后再依赖它们。
这的一个影响是,目前,当函数被 JAX 转换(如 jit
、vmap
、grad
等)装饰时,JAX 将有效地剥离所有注解从被装饰的函数中。虽然这很不幸,但在撰写本文时,mypy 与 ParamSpec
提供的潜在解决方案存在大量不兼容之处(请参阅 ParamSpec
mypy 错误跟踪器),因此我们判断它目前尚未准备好在 JAX 中完全采用。一旦对此类功能的支持稳定下来,我们将在未来重新审视这个问题。
同样,目前我们将避免添加 jaxtyping 项目提供的更复杂和细粒度的数组类型注解。这是一个我们可以在未来重新考虑的决定。
Array
类型设计考虑因素#
如上所述,由于 JAX 广泛使用鸭子类型,即在 jax 转换中传递和返回 Tracer
对象来代替实际数组,因此 JAX 中数组的类型注解提出了独特的挑战。这变得越来越令人困惑,因为用于类型注解的对象经常与用于运行时实例检查的对象重叠,并且可能对应也可能不对应于相关对象的实际类型层次结构。对于 JAX,我们需要为两种上下文提供鸭子类型对象:静态类型注解 和 运行时实例检查。
以下讨论将假设 jax.Array
是设备上数组的运行时类型,这目前还不是这种情况,但一旦 #12016 中的工作完成,就会是这种情况。
静态类型注解#
我们需要提供一个对象,该对象可用于鸭子类型的类型注解。暂时假设我们将此对象称为 ArrayAnnotation
,我们需要一个解决方案,该解决方案对于以下情况满足 mypy
和 pytype
:
@jit
def f(x: ArrayAnnotation) -> ArrayAnnotation:
assert isinstance(x, core.Tracer)
return x
这可以通过多种方法实现,例如:
使用类型联合:
ArrayAnnotation = Union[Array, Tracer]
创建一个接口文件,声明
Tracer
和Array
应被视为ArrayAnnotation
的子类。重构
Array
和Tracer
,使ArrayAnnotation
成为两者的真正基类。
运行时实例检查#
我们还必须提供一个对象,该对象可用于鸭子类型的运行时 isinstance
检查。暂时假设我们将此对象称为 ArrayInstance
,我们需要一个解决方案,该解决方案通过以下运行时检查:
def f(x):
return isinstance(x, ArrayInstance)
x = jnp.array([1, 2, 3])
assert f(x) # x will be an array
assert jit(f)(x) # x will be a tracer
同样,有几种机制可以用于此:
重写
type(ArrayInstance).__instancecheck__
以对Array
和Tracer
对象都返回True
;这就是当前jnp.ndarray
的实现方式 (source)。将
ArrayInstance
定义为抽象基类,并将其动态注册到Array
和Tracer
。重构
Array
和Tracer
,使ArrayInstance
成为Array
和Tracer
两者的真正基类。
我们需要做出的决定是 ArrayAnnotation
和 ArrayInstance
应该是相同还是不同的对象。这里有一些先例;例如,在核心 Python 语言规范中,存在 typing.Dict
和 typing.List
用于注解,而内置的 dict
和 list
用于实例检查。然而,Dict
和 List
在较新的 Python 版本中 已弃用,转而使用 dict
和 list
进行注解和实例检查。
遵循 NumPy 的领导#
在 NumPy 的案例中,np.typing.NDArray
用于类型注解,而 np.ndarray
用于实例检查(以及数组类型标识)。鉴于此,遵循 NumPy 的先例并实现以下内容可能是合理的:
jax.Array
是设备上数组的实际类型。jax.typing.NDArray
是用于鸭子类型数组注解的对象。jax.numpy.ndarray
是用于鸭子类型数组实例检查的对象。
对于 NumPy 高级用户来说,这可能感觉很自然,但是这种三分法可能会引起混淆:选择哪个用于实例检查和注解并不立即清楚。
统一实例检查和注解#
另一种方法是通过上述重写机制统一类型检查和注解。
选项 1:部分统一#
部分统一可能如下所示:
jax.Array
是设备上数组的实际类型。jax.typing.Array
是用于鸭子类型数组注解的对象(通过Array
和Tracer
上的.pyi
接口)。jax.typing.Array
也是用于鸭子类型实例检查的对象(通过其元类中的__isinstance__
重写)。
在这种方法中,jax.numpy.ndarray
将成为向后兼容的简单别名 jax.typing.Array
。
选项 2:通过重写完全统一#
或者,我们可以选择通过重写完全统一:
jax.Array
是设备上数组的实际类型。jax.Array
也是用于鸭子类型数组注解的对象(通过Tracer
上的.pyi
接口)。jax.Array
也是用于鸭子类型实例检查的对象(通过其元类中的__isinstance__
重写)。
在这里,jax.numpy.ndarray
将成为向后兼容的简单别名 jax.Array
。
选项 3:通过类层次结构完全统一#
最后,我们可以选择通过重构类层次结构并将鸭子类型替换为 OOP 对象层次结构来实现完全统一:
jax.Array
是设备上数组的实际类型。jax.Array
也是用于数组类型注解的对象,通过确保Tracer
继承自jax.Array
。jax.Array
也是用于实例检查的对象,通过相同的机制。
在这里,jnp.ndarray
可以是 jax.Array
的别名。从某种意义上说,最后一种方法是最纯粹的,但从 OOP 设计的角度来看,它有些牵强(Tracer
是 Array
吗?)。
选项 4:通过类层次结构部分统一#
我们可以通过使 Tracer
和设备上数组的类继承自公共基类来使类层次结构更合理。例如:
jax.Array
是Tracer
以及设备上数组的实际类型的基类,设备上数组的实际类型可能是jax._src.ArrayImpl
或类似物。jax.Array
是用于数组类型注解的对象。jax.Array
也是用于实例检查的对象。
在这里,jnp.ndarray
将是 Array
的别名。从 OOP 的角度来看,这可能更纯粹,但与选项 2 和 3 相比,它放弃了 type(x) is jax.Array
将评估为 True 的概念。
评估#
考虑每种潜在方法的总体优缺点:
从用户的角度来看,统一的方法(选项 2 和 3)可以说是最好的,因为它们消除了记住哪个对象用于实例检查或注解所涉及的认知开销:
jax.Array
是您需要知道的一切。但是,选项 2 和 3 都引入了一些奇怪和/或令人困惑的行为。选项 2 依赖于可能令人困惑的实例检查重写,对于 pybind11 中定义的类,这些重写 支持不佳。选项 3 要求
Tracer
成为数组的子类。这打破了继承模型,因为它将要求Tracer
对象携带Array
对象的所有包袱(数据缓冲区、分片、设备等)。选项 4 在 OOP 意义上更纯粹,并且避免了对典型实例检查或类型注解行为进行任何重写的需要。权衡是设备上数组的实际类型变得独立(这里是
jax._src.ArrayImpl
)。但是绝大多数用户永远不必直接接触这个私有实现。
这里有不同的权衡,但在讨论之后,我们选择了选项 4 作为我们的前进方向。
实施计划#
为了推进类型注解,我们将执行以下操作:
迭代此 JEP 文档,直到开发人员和利益相关者都接受。
创建一个私有的
jax._src.typing
(目前不提供任何公共 API),并在其中放入上述第一级简单类型:暂时别名
Array = Any
,因为这需要更多思考。ArrayLike
:作为正常jax.numpy
函数输入的有效类型的联合。DType
/DTypeLike
(注意:numpy 使用驼峰式DType
;为了易于使用,我们应该遵循此约定)。Shape
/NamedShape
/ShapeLike
。
这方面的开始工作已在 #12300 中完成。
开始在遵循上一节选项 4 的
jax.Array
基类上工作。最初,这将以 Python 定义,并使用当前在jnp.ndarray
实现中找到的动态注册机制,以确保isinstance
检查的正确行为。每个追踪器和类数组类的pyi
重写将确保类型注解的正确行为。jnp.ndarray
然后可以被设为jax.Array
的别名。作为测试,根据上述指南,使用这些新的类型定义来全面注解
jax.lax
中的函数。继续一次添加一个模块的额外注解,重点关注公共 API 函数。
同时,开始在 pybind11 中重新实现
jax.Array
基类,以便ArrayImpl
和Tracer
可以从中继承。使用pyi
定义以确保静态类型检查器识别类的适当属性。一旦
jax.Array
和jax._src.ArrayImpl
完全落地,删除这些临时的 Python 实现。当一切最终确定后,创建一个公共
jax.typing
模块,使上述类型可供用户使用,并提供有关使用 JAX 的代码的注解最佳实践的文档。
我们将在 #12049 中跟踪这项工作,此 JEP 从中获得其编号。