JAX 类型注释路线图#
作者:jakevdp
日期:2022 年 8 月
背景#
Python 3.0 引入了可选的函数注释(PEP 3107),后来在 Python 3.5 发布前后(PEP 484)被规范化用于静态类型检查。在某种程度上,类型注释和静态类型检查已成为许多 Python 开发工作流程不可或缺的一部分,为此,我们在 JAX API 的许多地方添加了注释。JAX 中类型注释的当前状态有些零散,并且由于更基本的设计问题,添加更多注释的努力受到了阻碍。本文档试图总结这些问题,并为 JAX 中类型注释的目标和非目标制定路线图。
为什么我们需要这样的路线图?更好/更全面的类型注释是用户(无论是内部还是外部)经常提出的要求。此外,我们经常收到外部用户的拉取请求(例如,PR #9917,PR #10322),旨在改进 JAX 的类型注释:JAX 团队成员在审查代码时并不总是清楚这些贡献是否有益,特别是当它们引入复杂的协议来解决完全注释 JAX 对 Python 使用的固有挑战时。本文档详细说明了 JAX 在包内类型注释方面的目标和建议。
为什么需要类型注释?#
Python 项目希望对其代码库进行注释的原因有很多;我们将在本文档中将其总结为级别 1、级别 2 和级别 3。
级别 1:将注释作为文档#
当最初在PEP 3107中引入时,类型注释的部分动机是能够将其用作函数参数类型和返回类型的简洁内联文档。JAX 长期以来一直以这种方式利用注释;一个例子是创建别名为Any
的类型名称的常见模式。一个例子可以在lax/slicing.py
中找到[source]
Array = Any
Shape = core.Shape
def slice(operand: Array, start_indices: Sequence[int],
limit_indices: Sequence[int],
strides: Optional[Sequence[int]] = None) -> Array:
...
为了静态类型检查的目的,这种将 Array = Any
用于数组类型注释的做法对参数值不施加任何约束(Any
等同于不进行任何注释),但它确实为开发者提供了有用的代码内文档形式。
对于生成的文档而言,别名会被丢失(jax.lax.slice
的HTML 文档报告操作数类型为Any
),因此文档的好处不会超出源代码(尽管我们可以启用一些sphinx-autodoc
选项来改善这一点:参见autodoc_type_aliases)。
这种级别的类型注释的一个好处是,用Any
注释一个值永远不会出错,因此它将以文档的形式为开发者和用户提供具体的益处,而无需增加满足任何特定静态类型检查器更严格需求的复杂性。
级别 2:用于智能自动完成的注释#
许多现代 IDE 利用类型注释作为智能代码完成系统的输入。其中一个例子是 VSCode 的Pylance扩展,它使用微软的pyright静态类型检查器作为 VSCode IntelliSense 完成的信息源。
这种类型检查的使用要求比上面使用的简单别名更进一步;例如,知道slice
函数返回一个名为Array
的Any
的别名,并不会给代码完成引擎添加任何有用的信息。然而,如果我们将函数注释为DeviceArray
返回类型,自动完成将知道如何填充结果的命名空间,从而能够在开发过程中建议更相关的自动完成。
JAX 已开始在一些地方添加此级别的类型注释;一个例子是jax.random
包内的jnp.ndarray
返回类型 [source]
def shuffle(key: KeyArray, x: Array, axis: int = 0) -> jnp.ndarray:
...
在这种情况下,jnp.ndarray
是一个抽象基类,它预先声明了 JAX 数组的属性和方法(见源代码),因此 VSCode 中的 Pylance 可以提供该函数结果的完整自动完成集。以下是显示结果的屏幕截图
自动完成字段中列出了抽象ndarray
类声明的所有方法和属性。我们将在下面进一步讨论为什么有必要创建这个抽象类而不是直接使用DeviceArray
进行注释。
级别 3:用于静态类型检查的注释#
如今,当考虑 Python 代码中类型注释的目的时,静态类型检查通常是人们首先想到的。虽然 Python 不进行任何运行时类型检查,但存在几种成熟的静态类型检查工具,可以在 CI 测试套件中完成此操作。对于 JAX,最重要的工具如下:
python/mypy或多或少是开放 Python 世界的标准。JAX 目前在 Github Actions CI 检查中对部分源文件运行 mypy。
google/pytype是 Google 的静态类型检查器,Google 内部依赖 JAX 的项目经常使用它。
microsoft/pyright很重要,它是 VSCode 中用于前面提到的 Pylance 完成的静态类型检查器。
全面的静态类型检查是所有类型注释应用中最严格的,因为它会在您的类型注释不完全正确时立即报错。一方面,这很好,因为您的静态类型分析可能会捕获错误的类型注释(例如,jnp.ndarray
抽象类中缺少DeviceArray
方法的情况)。
另一方面,这种严格性会使那些经常依赖鸭子类型而非严格类型安全 API 的软件包的类型检查过程变得非常脆弱。您目前会在 JAX 代码库中数百个地方找到类似#type: ignore
(针对 mypy)或#pytype: disable
(针对 pytype)的代码注释。这些通常表示出现了类型问题;它们可能是 JAX 类型注释中的不准确之处,或者是静态类型检查器无法正确遵循代码中控制流的不准确之处。偶尔,它们是由于 pytype 或 mypy 行为中真实而微妙的错误造成的。在极少数情况下,它们可能是由于 JAX 使用了 Python 模式,这些模式难以甚至不可能用 Python 的静态类型注释语法来表达。
JAX 的类型注释挑战#
JAX 当前的类型注释混合了不同的风格,并且旨在满足上述所有三个级别的类型注释。部分原因在于 JAX 的源代码对 Python 的类型注释系统提出了一些独特的挑战。我们在此概述它们。
挑战 1:pytype、mypy 和开发者摩擦#
JAX 目前面临的一个挑战是,包开发必须满足两种不同静态类型检查系统(pytype
(内部 CI 和内部 Google 项目使用)和mypy
(外部 CI 和外部依赖项使用))的约束。尽管这两种类型检查器在行为上存在广泛重叠,但每种都呈现其独特的边缘情况,JAX 代码库中大量的#type: ignore
和#pytype: disable
语句就证明了这一点。
这在开发中造成了摩擦:内部贡献者可能会迭代直到测试通过,却发现导出的经过 pytype 批准的代码与 mypy 冲突。对于外部贡献者来说,情况往往相反:最近的一个例子是#9596,它在未能通过内部 Google pytype 检查后不得不回滚。每次我们将类型注释从级别 1(Any
无处不在)移动到级别 2 或 3(更严格的注释),就会增加这种令人沮丧的开发者体验的可能性。
挑战 2:数组鸭子类型化#
注释 JAX 代码的一个特殊挑战是其大量使用鸭子类型。函数中标记为Array
的输入通常可以是许多不同类型之一:JAX DeviceArray
、NumPy np.ndarray
、NumPy 标量、Python 标量、Python 序列、具有__array__
属性的对象、具有__jax_array__
属性的对象,或任何类型的jax.Tracer
。因此,简单的注释,如def func(x: DeviceArray)
将不足够,并且会导致许多有效用法出现误报。这意味着 JAX 函数的类型注释不会简短或微不足道,但我们必须有效地开发一套 JAX 特定的类型扩展,类似于numpy.typing
包中的那些。
挑战 3:转换和装饰器#
JAX 的 Python API 严重依赖于函数转换(jit()
、vmap()
、grad()
等),这种类型的 API 对静态类型分析构成了特殊挑战。装饰器的灵活注释一直是 mypy 包中的一个长期存在的问题,直到最近才通过引入ParamSpec
得到解决,该问题在PEP 612中讨论并添加到 Python 3.10 中。由于 JAX 遵循NEP 29,它在 2024 年年中之前无法依赖 Python 3.10 功能。在此期间,协议(Protocols)可以作为部分解决方案(JAX 在#9950中为 jit 和其他方法添加了此功能),并且 ParamSpec 可以通过typing_extensions
包使用(一个原型在#9999中),尽管这目前揭示了 mypy 中的基本错误(参见python/mypy#12593)。所有这些都表明:目前尚不清楚 JAX 函数转换的 API 是否可以在当前 Python 类型注释工具的约束下进行适当注释。
挑战 4:数组注释缺乏粒度#
另一个挑战是 Python 中所有面向数组的 API 普遍存在的问题,并且在 JAX 的讨论中已持续数年(参见#943)。类型注释与对象的 Python 类或类型有关,而在基于数组的语言中,通常类的属性更为重要。对于 NumPy、JAX 和类似包,我们通常希望注释特定的数组形状和数据类型。
例如,jnp.linspace
函数的参数必须是标量值,但在 JAX 中,标量由零维数组表示。因此,为了使注释不引发误报,我们必须允许这些参数是任意数组。另一个例子是jax.random.choice
的第二个参数,当shape=()
时,它必须具有dtype=int
。Python 计划通过可变类型泛型(参见PEP 646,计划用于 Python 3.11)实现这种粒度的类型注释,但与ParamSpec
一样,对此功能的支持需要一段时间才能稳定。
在此期间,一些第三方项目可能会有所帮助,特别是google/jaxtyping,但这使用了非标准注释,可能不适合注释 JAX 核心库本身。总而言之,数组类型粒度挑战不如其他挑战那么重要,因为主要影响是数组类注释的特异性将低于其本应有的水平。
挑战 5:从 NumPy 继承的不精确 API#
JAX 面向用户的 API 大部分是从jax.numpy
子模块中的 NumPy 继承的。NumPy 的 API 在静态类型检查成为 Python 语言一部分之前多年就已经开发出来了,它遵循 Python 历史上推荐的鸭子类型/EAFP编码风格,其中不鼓励在运行时进行严格的类型检查。作为一个具体的例子,考虑numpy.tile()
函数,它的定义如下:
def tile(A, reps):
try:
tup = tuple(reps)
except TypeError:
tup = (reps,)
d = len(tup)
...
这里,意图是reps
应包含一个int
或一个int
值的序列,但实现允许tup
是任何可迭代对象。在为这种鸭子类型代码添加注释时,我们可以采取两种途径之一:
我们可以选择注释函数 API 的意图,这里可能是
reps: Union[int, Sequence[int]]
之类的。相反,我们可以选择注释函数的实现,这里可能看起来像
reps: Union[ConvertibleToInt, Iterable[ConvertibleToInt]]
,其中ConvertibleToInt
是一个特殊协议,涵盖了函数将输入转换为整数的确切机制(即通过__int__
,通过__index__
,通过__array__
等)。这里还需要注意的是,严格来说,Iterable
在这里是不够的,因为 Python 中有些对象在鸭子类型上是可迭代的,但不能通过Iterable
进行静态类型检查(即,通过__getitem__
而不是__iter__
进行迭代的对象)。
选项 1(注释意图)的优点是,注释在传达 API 契约方面对用户更有用;而对于开发者来说,灵活性在必要时为重构留下了空间。缺点(特别是对于像 JAX 这样渐进类型化的 API)是,很可能存在用户代码运行正确,但会被类型检查器标记为不正确的情况。对现有鸭子类型 API 进行渐进类型化意味着当前注释隐式为Any
,因此将其更改为更严格的类型可能会对用户造成破坏性更改。
广义上讲,注释意图更好地服务于级别 1 类型检查,而注释实现更好地服务于级别 3,而级别 2 则是一个混合体(意图和实现在 IDE 中的注释方面都很重要)。
JAX 类型注释路线图#
有了这个框架(级别 1/2/3)并考虑到 JAX 特有的挑战,我们就可以开始制定在 JAX 项目中实现一致类型注释的路线图。
指导原则#
对于 JAX 类型注释,我们将遵循以下原则:
类型注释的目的#
我们希望尽可能支持完整的级别 1、2 和 3 类型注释。特别是,这意味着我们应该对公共 API 函数的输入和输出都进行限制性类型注释。
为意图添加注释#
JAX 类型注释通常应表明 API 的意图,而不是实现,以便注释对于传达 API 契约变得有用。这意味着有时在运行时有效的输入可能不会被静态类型检查器识别为有效(一个例子可能是任意迭代器传递给被注释为Shape = Sequence[int]
的形状)。
输入应采用宽松类型#
JAX 函数和方法的输入应尽可能宽松地类型化:例如,虽然形状通常是元组,但接受形状的函数应接受任意序列。同样,接受 dtype 的函数无需要求np.dtype
类的实例,而是任何可转换为 dtype 的对象。这可能包括字符串、内置标量类型或标量对象构造函数,例如np.float64
和jnp.float64
。为了使整个包尽可能统一,我们将添加一个jax.typing
模块,其中包含常见的类型规范,从以下宽泛类别开始:
ArrayLike
将是任何可以隐式转换为数组的类型的联合:例如,jax 数组、numpy 数组、JAX tracer,以及 python 或 numpy 标量DTypeLike
将是任何可以隐式转换为 dtype 的类型的联合:例如,numpy dtypes、numpy dtype 对象、jax dtype 对象、字符串和内置类型。ShapeLike
将是任何可以转换为形状的类型的联合:例如,整数或类整数对象的序列。等等
请注意,这些通常会比numpy.typing
中使用的等效协议更简单。例如,在DTypeLike
的情况下,JAX 不支持结构化 dtype,因此 JAX 可以使用更简单的实现。同样,在ArrayLike
中,JAX 通常不支持列表或元组输入代替数组,因此类型定义将比 NumPy 对应的类型更简单。
输出应采用严格类型#
相反,函数和方法的输出应尽可能严格类型化:例如,对于返回数组的 JAX 函数,输出应使用类似于jnp.ndarray
而不是ArrayLike
进行注释。返回 dtype 的函数应始终注释为np.dtype
,返回形状的函数应始终为Tuple[int]
或严格类型化的 NamedShape 等效项。为此,我们将在jax.typing
中实现几个上述宽松类型的严格类型化模拟,即:
Array
或NDArray
(见下文) 用于类型注释目的,实际上等同于Union[Tracer, jnp.ndarray]
,应用于注释数组输出。DType
是np.dtype
的别名,可能还具有表示 JAX 内部使用的键类型和其他泛化的能力。Shape
本质上是Tuple[int, ...]
,可能带有一些额外的灵活性以适应动态形状。NamedShape
是Shape
的扩展,允许使用 JAX 内部使用的命名形状。等等
我们还将探讨是否可以放弃jax.numpy.ndarray
的当前实现,转而将ndarray
作为Array
或类似对象的别名。
倾向于简单性#
除了jax.typing
中收集的常见类型协议之外,我们应该倾向于简单性。我们应该避免为传递给 API 函数的参数构建过于复杂的协议,而应使用简单的联合类型,例如在无法简洁指定 API 的完整类型规范的情况下使用Union[simple_type, Any]
。这是一种折衷方案,它实现了级别 1 和级别 2 注释的目标,同时为了避免不必要的复杂性而放弃了级别 3。
避免不稳定的类型机制#
为了避免增加不必要的开发摩擦(由于内部/外部 CI 差异),我们希望在使用的类型注释构造方面保持保守:特别是对于最近引入的机制,如ParamSpec
(PEP 612)和可变类型泛型(PEP 646),我们希望等到 mypy 和其他工具中的支持成熟并稳定后再依赖它们。
其中一个影响是,暂时而言,当函数被 JAX 转换(如jit
、vmap
、grad
等)装饰时,JAX 将有效地剥离被装饰函数的所有注释。虽然这很不幸,但在撰写本文时,mypy 与ParamSpec
提供的潜在解决方案存在一系列不兼容问题(参见ParamSpec
mypy bug tracker),因此我们认为目前尚未准备好在 JAX 中全面采用。我们将在未来支持此类功能稳定后重新审视这个问题。
同样,目前我们将避免添加jaxtyping项目提供的更复杂和更细粒度的数组类型注释。这个决定我们可以在未来重新审视。
Array
类型设计考量#
如上所述,JAX 中数组的类型注释带来了独特的挑战,因为 JAX 大量使用鸭子类型,即在 JAX 转换中传递和返回Tracer
对象来代替实际数组。这变得越来越令人困惑,因为用于类型注释的对象通常与用于运行时实例检查的对象重叠,并且可能与所讨论对象的实际类型层次结构对应或不对应。对于 JAX,我们需要在两种情况下提供鸭子类型对象:静态类型注释和运行时实例检查。
以下讨论将假设jax.Array
是设备上数组的运行时类型,目前情况并非如此,但一旦#12016的工作完成,情况就会如此。
静态类型注释#
我们需要提供一个可用于鸭子类型注释的对象。假设我们暂时称此对象为ArrayAnnotation
,我们需要一个满足mypy
和pytype
的解决方案,例如以下情况:
@jit
def f(x: ArrayAnnotation) -> ArrayAnnotation:
assert isinstance(x, core.Tracer)
return x
这可以通过多种方法实现,例如:
使用类型联合:
ArrayAnnotation = Union[Array, Tracer]
创建一个接口文件,声明
Tracer
和Array
应被视为ArrayAnnotation
的子类。重构
Array
和Tracer
,使ArrayAnnotation
成为两者的真正基类。
运行时实例检查#
我们还必须提供一个可用于鸭子类型运行时isinstance
检查的对象。假设我们暂时称此对象为ArrayInstance
,我们需要一个通过以下运行时检查的解决方案:
def f(x):
return isinstance(x, ArrayInstance)
x = jnp.array([1, 2, 3])
assert f(x) # x will be an array
assert jit(f)(x) # x will be a tracer
同样,有几种机制可以用于此:
覆盖
type(ArrayInstance).__instancecheck__
,使其对Array
和Tracer
对象返回True
;这是jnp.ndarray
当前实现的方式(source)。将
ArrayInstance
定义为一个抽象基类,并将其动态注册到Array
和Tracer
。重构
Array
和Tracer
,使ArrayInstance
成为Array
和Tracer
两者的真实基类。
我们需要决定ArrayAnnotation
和ArrayInstance
是否应该相同或不同。这里有一些先例;例如,在核心 Python 语言规范中,typing.Dict
和typing.List
是为了注释而存在,而内置的dict
和list
则用于实例检查。然而,Dict
和List
在较新的 Python 版本中已被弃用,转而使用dict
和list
进行注释和实例检查。
遵循 NumPy 的领先#
在 NumPy 的案例中,np.typing.NDArray
用于类型注释,而np.ndarray
用于实例检查(以及数组类型标识)。鉴于此,遵循 NumPy 的先例并实施以下方案可能是合理的:
jax.Array
是设备上数组的实际类型。jax.typing.NDArray
是用于鸭子类型数组注释的对象。jax.numpy.ndarray
是用于鸭子类型数组实例检查的对象。
这可能让 NumPy 资深用户感到有些自然,但这种三分法很可能会造成混淆:选择哪一个用于实例检查和注释并不立即清楚。
统一实例检查和注释#
另一种方法是通过上述覆盖机制统一类型检查和注释。
选项 1:部分统一#
部分统一可能如下所示:
jax.Array
是设备上数组的实际类型。jax.typing.Array
是用于鸭子类型数组注释的对象(通过Array
和Tracer
上的.pyi
接口)。jax.typing.Array
也是用于鸭子类型实例检查的对象(通过其元类中的__isinstance__
覆盖)。
在这种方法中,jax.numpy.ndarray
将成为jax.typing.Array
的简单别名,以实现向后兼容性。
选项 2:通过覆盖实现完全统一#
或者,我们可以选择通过覆盖实现完全统一:
jax.Array
是设备上数组的实际类型。jax.Array
也用于鸭子类型数组注释(通过Tracer
上的.pyi
接口)jax.Array
也用于鸭子类型实例检查(通过其元类中的__isinstance__
覆盖)
这里,jax.numpy.ndarray
将成为jax.Array
的简单别名,以实现向后兼容性。
选项 3:通过类层次结构实现完全统一#
最后,我们可以选择通过重构类层次结构并用面向对象编程(OOP)对象层次结构替换鸭子类型来实现完全统一:
jax.Array
是设备上数组的实际类型。jax.Array
也用于数组类型注释,通过确保Tracer
继承自jax.Array
。jax.Array
也通过相同的机制用于实例检查。
这里jnp.ndarray
可以是jax.Array
的别名。这种最终方法在某些意义上是最纯粹的,但从 OOP 设计的角度来看有些牵强(Tracer
是一个Array
吗?)。
选项 4:通过类层次结构实现部分统一#
我们可以通过让Tracer
和设备上数组的类继承自一个共同的基类,从而使类层次结构更加合理。例如:
jax.Array
是Tracer
以及设备上数组的实际类型(可能是jax._src.ArrayImpl
或类似)的基类。jax.Array
是用于数组类型注释的对象。jax.Array
也用于实例检查。
这里jnp.ndarray
将是Array
的别名。这在 OOP 角度可能更纯粹,但与选项 2 和 3 相比,它放弃了type(x) is jax.Array
将评估为 True 的概念。
评估#
考虑到每种潜在方法的总体优缺点:
从用户的角度来看,统一方法(选项 2 和 3)可以说是最好的,因为它们消除了记住哪些对象用于实例检查或注释的认知开销:您只需要知道
jax.Array
。然而,选项 2 和 3 都引入了一些奇怪和/或令人困惑的行为。选项 2 依赖于可能令人困惑的实例检查覆盖,这对于在 pybind11 中定义的类来说支持不佳。选项 3 要求
Tracer
成为子类数组。这打破了继承模型,因为它将要求Tracer
对象携带Array
对象的所有“负担”(数据缓冲区、分片、设备等)。选项 4 在 OOP 意义上更纯粹,并且避免了对典型实例检查或类型注释行为进行任何覆盖。权衡是设备上数组的实际类型变得独立(此处为
jax._src.ArrayImpl
)。但绝大多数用户永远不必直接接触这种私有实现。
这里有不同的权衡,但在讨论之后,我们选择了选项 4 作为我们的前进方向。
实施计划#
为了推进类型注释,我们将执行以下操作:
迭代此 JEP 文档,直到开发者和利益相关者达成共识。
创建私有的
jax._src.typing
(暂时不提供任何公共 API),并在其中放入上述第一层简单类型:暂时将
Array = Any
别名,因为这还需要一些思考。ArrayLike
:作为正常jax.numpy
函数输入有效类型的联合。DType
/DTypeLike
(注意:numpy 使用驼峰式DType
;我们应该遵循此约定以方便使用)。Shape
/NamedShape
/ShapeLike
这项工作的开端已在#12300中完成。
开始着手创建
jax.Array
基类,该类将遵循前一节中的选项 4。最初这将用 Python 定义,并使用jnp.ndarray
实现中当前发现的动态注册机制,以确保isinstance
检查的正确行为。pyi
为每个 tracer 和类数组类进行的覆盖将确保类型注释的正确行为。jnp.ndarray
随后可以成为jax.Array
的别名。作为测试,使用这些新的类型定义,根据上述指南全面注释
jax.lax
中的函数。继续逐个模块添加额外的注释,重点关注公共 API 函数。
同时,开始在 pybind11 中重新实现一个
jax.Array
基类,以便ArrayImpl
和Tracer
可以继承它。使用pyi
定义来确保静态类型检查器识别类的适当属性。一旦
jax.Array
和jax._src.ArrayImpl
完全实现,移除这些临时的 Python 实现。所有工作完成后,创建一个公共的
jax.typing
模块,将上述类型提供给用户,并附带使用 JAX 代码的注释最佳实践文档。
我们将在#12049中跟踪这项工作,本 JEP 的编号也由此而来。