JAX 类型注释路线图#

  • 作者:jakevdp

  • 日期:2022 年 8 月

背景#

Python 3.0 引入了可选的函数注释(PEP 3107),后来在 Python 3.5 发布前后(PEP 484)被规范化用于静态类型检查。在某种程度上,类型注释和静态类型检查已成为许多 Python 开发工作流程不可或缺的一部分,为此,我们在 JAX API 的许多地方添加了注释。JAX 中类型注释的当前状态有些零散,并且由于更基本的设计问题,添加更多注释的努力受到了阻碍。本文档试图总结这些问题,并为 JAX 中类型注释的目标和非目标制定路线图。

为什么我们需要这样的路线图?更好/更全面的类型注释是用户(无论是内部还是外部)经常提出的要求。此外,我们经常收到外部用户的拉取请求(例如,PR #9917PR #10322),旨在改进 JAX 的类型注释:JAX 团队成员在审查代码时并不总是清楚这些贡献是否有益,特别是当它们引入复杂的协议来解决完全注释 JAX 对 Python 使用的固有挑战时。本文档详细说明了 JAX 在包内类型注释方面的目标和建议。

为什么需要类型注释?#

Python 项目希望对其代码库进行注释的原因有很多;我们将在本文档中将其总结为级别 1、级别 2 和级别 3。

级别 1:将注释作为文档#

当最初在PEP 3107中引入时,类型注释的部分动机是能够将其用作函数参数类型和返回类型的简洁内联文档。JAX 长期以来一直以这种方式利用注释;一个例子是创建别名为Any的类型名称的常见模式。一个例子可以在lax/slicing.py中找到[source]

Array = Any
Shape = core.Shape

def slice(operand: Array, start_indices: Sequence[int],
          limit_indices: Sequence[int],
          strides: Optional[Sequence[int]] = None) -> Array:
  ...

为了静态类型检查的目的,这种将 Array = Any 用于数组类型注释的做法对参数值不施加任何约束(Any 等同于不进行任何注释),但它确实为开发者提供了有用的代码内文档形式。

对于生成的文档而言,别名会被丢失(jax.lax.sliceHTML 文档报告操作数类型为Any),因此文档的好处不会超出源代码(尽管我们可以启用一些sphinx-autodoc选项来改善这一点:参见autodoc_type_aliases)。

这种级别的类型注释的一个好处是,用Any注释一个值永远不会出错,因此它将以文档的形式为开发者和用户提供具体的益处,而无需增加满足任何特定静态类型检查器更严格需求的复杂性。

级别 2:用于智能自动完成的注释#

许多现代 IDE 利用类型注释作为智能代码完成系统的输入。其中一个例子是 VSCode 的Pylance扩展,它使用微软的pyright静态类型检查器作为 VSCode IntelliSense 完成的信息源。

这种类型检查的使用要求比上面使用的简单别名更进一步;例如,知道slice函数返回一个名为ArrayAny的别名,并不会给代码完成引擎添加任何有用的信息。然而,如果我们将函数注释为DeviceArray返回类型,自动完成将知道如何填充结果的命名空间,从而能够在开发过程中建议更相关的自动完成。

JAX 已开始在一些地方添加此级别的类型注释;一个例子是jax.random包内的jnp.ndarray返回类型 [source]

def shuffle(key: KeyArray, x: Array, axis: int = 0) -> jnp.ndarray:
  ...

在这种情况下,jnp.ndarray是一个抽象基类,它预先声明了 JAX 数组的属性和方法(见源代码),因此 VSCode 中的 Pylance 可以提供该函数结果的完整自动完成集。以下是显示结果的屏幕截图

VSCode Intellisense Screenshot

自动完成字段中列出了抽象ndarray类声明的所有方法和属性。我们将在下面进一步讨论为什么有必要创建这个抽象类而不是直接使用DeviceArray进行注释。

级别 3:用于静态类型检查的注释#

如今,当考虑 Python 代码中类型注释的目的时,静态类型检查通常是人们首先想到的。虽然 Python 不进行任何运行时类型检查,但存在几种成熟的静态类型检查工具,可以在 CI 测试套件中完成此操作。对于 JAX,最重要的工具如下:

  • python/mypy或多或少是开放 Python 世界的标准。JAX 目前在 Github Actions CI 检查中对部分源文件运行 mypy。

  • google/pytype是 Google 的静态类型检查器,Google 内部依赖 JAX 的项目经常使用它。

  • microsoft/pyright很重要,它是 VSCode 中用于前面提到的 Pylance 完成的静态类型检查器。

全面的静态类型检查是所有类型注释应用中最严格的,因为它会在您的类型注释不完全正确时立即报错。一方面,这很好,因为您的静态类型分析可能会捕获错误的类型注释(例如,jnp.ndarray抽象类中缺少DeviceArray方法的情况)。

另一方面,这种严格性会使那些经常依赖鸭子类型而非严格类型安全 API 的软件包的类型检查过程变得非常脆弱。您目前会在 JAX 代码库中数百个地方找到类似#type: ignore(针对 mypy)或#pytype: disable(针对 pytype)的代码注释。这些通常表示出现了类型问题;它们可能是 JAX 类型注释中的不准确之处,或者是静态类型检查器无法正确遵循代码中控制流的不准确之处。偶尔,它们是由于 pytype 或 mypy 行为中真实而微妙的错误造成的。在极少数情况下,它们可能是由于 JAX 使用了 Python 模式,这些模式难以甚至不可能用 Python 的静态类型注释语法来表达。

JAX 的类型注释挑战#

JAX 当前的类型注释混合了不同的风格,并且旨在满足上述所有三个级别的类型注释。部分原因在于 JAX 的源代码对 Python 的类型注释系统提出了一些独特的挑战。我们在此概述它们。

挑战 1:pytype、mypy 和开发者摩擦#

JAX 目前面临的一个挑战是,包开发必须满足两种不同静态类型检查系统(pytype(内部 CI 和内部 Google 项目使用)和mypy(外部 CI 和外部依赖项使用))的约束。尽管这两种类型检查器在行为上存在广泛重叠,但每种都呈现其独特的边缘情况,JAX 代码库中大量的#type: ignore#pytype: disable语句就证明了这一点。

这在开发中造成了摩擦:内部贡献者可能会迭代直到测试通过,却发现导出的经过 pytype 批准的代码与 mypy 冲突。对于外部贡献者来说,情况往往相反:最近的一个例子是#9596,它在未能通过内部 Google pytype 检查后不得不回滚。每次我们将类型注释从级别 1(Any 无处不在)移动到级别 2 或 3(更严格的注释),就会增加这种令人沮丧的开发者体验的可能性。

挑战 2:数组鸭子类型化#

注释 JAX 代码的一个特殊挑战是其大量使用鸭子类型。函数中标记为Array的输入通常可以是许多不同类型之一:JAX DeviceArray、NumPy np.ndarray、NumPy 标量、Python 标量、Python 序列、具有__array__属性的对象、具有__jax_array__属性的对象,或任何类型的jax.Tracer。因此,简单的注释,如def func(x: DeviceArray)将不足够,并且会导致许多有效用法出现误报。这意味着 JAX 函数的类型注释不会简短或微不足道,但我们必须有效地开发一套 JAX 特定的类型扩展,类似于numpy.typing中的那些。

挑战 3:转换和装饰器#

JAX 的 Python API 严重依赖于函数转换(jit()vmap()grad() 等),这种类型的 API 对静态类型分析构成了特殊挑战。装饰器的灵活注释一直是 mypy 包中的一个长期存在的问题,直到最近才通过引入ParamSpec得到解决,该问题在PEP 612中讨论并添加到 Python 3.10 中。由于 JAX 遵循NEP 29,它在 2024 年年中之前无法依赖 Python 3.10 功能。在此期间,协议(Protocols)可以作为部分解决方案(JAX 在#9950中为 jit 和其他方法添加了此功能),并且 ParamSpec 可以通过typing_extensions包使用(一个原型在#9999中),尽管这目前揭示了 mypy 中的基本错误(参见python/mypy#12593)。所有这些都表明:目前尚不清楚 JAX 函数转换的 API 是否可以在当前 Python 类型注释工具的约束下进行适当注释。

挑战 4:数组注释缺乏粒度#

另一个挑战是 Python 中所有面向数组的 API 普遍存在的问题,并且在 JAX 的讨论中已持续数年(参见#943)。类型注释与对象的 Python 类或类型有关,而在基于数组的语言中,通常类的属性更为重要。对于 NumPy、JAX 和类似包,我们通常希望注释特定的数组形状和数据类型。

例如,jnp.linspace函数的参数必须是标量值,但在 JAX 中,标量由零维数组表示。因此,为了使注释不引发误报,我们必须允许这些参数是任意数组。另一个例子是jax.random.choice的第二个参数,当shape=()时,它必须具有dtype=int。Python 计划通过可变类型泛型(参见PEP 646,计划用于 Python 3.11)实现这种粒度的类型注释,但与ParamSpec一样,对此功能的支持需要一段时间才能稳定。

在此期间,一些第三方项目可能会有所帮助,特别是google/jaxtyping,但这使用了非标准注释,可能不适合注释 JAX 核心库本身。总而言之,数组类型粒度挑战不如其他挑战那么重要,因为主要影响是数组类注释的特异性将低于其本应有的水平。

挑战 5:从 NumPy 继承的不精确 API#

JAX 面向用户的 API 大部分是从jax.numpy子模块中的 NumPy 继承的。NumPy 的 API 在静态类型检查成为 Python 语言一部分之前多年就已经开发出来了,它遵循 Python 历史上推荐的鸭子类型/EAFP编码风格,其中不鼓励在运行时进行严格的类型检查。作为一个具体的例子,考虑numpy.tile()函数,它的定义如下:

def tile(A, reps):
  try:
    tup = tuple(reps)
  except TypeError:
    tup = (reps,)
  d = len(tup)
  ...

这里,意图reps应包含一个int或一个int值的序列,但实现允许tup是任何可迭代对象。在为这种鸭子类型代码添加注释时,我们可以采取两种途径之一:

  1. 我们可以选择注释函数 API 的意图,这里可能是reps: Union[int, Sequence[int]]之类的。

  2. 相反,我们可以选择注释函数的实现,这里可能看起来像reps: Union[ConvertibleToInt, Iterable[ConvertibleToInt]],其中ConvertibleToInt是一个特殊协议,涵盖了函数将输入转换为整数的确切机制(即通过__int__,通过__index__,通过__array__等)。这里还需要注意的是,严格来说,Iterable在这里是不够的,因为 Python 中有些对象在鸭子类型上是可迭代的,但不能通过Iterable进行静态类型检查(即,通过__getitem__而不是__iter__进行迭代的对象)。

选项 1(注释意图)的优点是,注释在传达 API 契约方面对用户更有用;而对于开发者来说,灵活性在必要时为重构留下了空间。缺点(特别是对于像 JAX 这样渐进类型化的 API)是,很可能存在用户代码运行正确,但会被类型检查器标记为不正确的情况。对现有鸭子类型 API 进行渐进类型化意味着当前注释隐式为Any,因此将其更改为更严格的类型可能会对用户造成破坏性更改。

广义上讲,注释意图更好地服务于级别 1 类型检查,而注释实现更好地服务于级别 3,而级别 2 则是一个混合体(意图和实现在 IDE 中的注释方面都很重要)。

JAX 类型注释路线图#

有了这个框架(级别 1/2/3)并考虑到 JAX 特有的挑战,我们就可以开始制定在 JAX 项目中实现一致类型注释的路线图。

指导原则#

对于 JAX 类型注释,我们将遵循以下原则:

类型注释的目的#

我们希望尽可能支持完整的级别 1、2 和 3 类型注释。特别是,这意味着我们应该对公共 API 函数的输入和输出都进行限制性类型注释。

为意图添加注释#

JAX 类型注释通常应表明 API 的意图,而不是实现,以便注释对于传达 API 契约变得有用。这意味着有时在运行时有效的输入可能不会被静态类型检查器识别为有效(一个例子可能是任意迭代器传递给被注释为Shape = Sequence[int]的形状)。

输入应采用宽松类型#

JAX 函数和方法的输入应尽可能宽松地类型化:例如,虽然形状通常是元组,但接受形状的函数应接受任意序列。同样,接受 dtype 的函数无需要求np.dtype类的实例,而是任何可转换为 dtype 的对象。这可能包括字符串、内置标量类型或标量对象构造函数,例如np.float64jnp.float64。为了使整个包尽可能统一,我们将添加一个jax.typing模块,其中包含常见的类型规范,从以下宽泛类别开始:

  • ArrayLike将是任何可以隐式转换为数组的类型的联合:例如,jax 数组、numpy 数组、JAX tracer,以及 python 或 numpy 标量

  • DTypeLike将是任何可以隐式转换为 dtype 的类型的联合:例如,numpy dtypes、numpy dtype 对象、jax dtype 对象、字符串和内置类型。

  • ShapeLike将是任何可以转换为形状的类型的联合:例如,整数或类整数对象的序列。

  • 等等

请注意,这些通常会比numpy.typing中使用的等效协议更简单。例如,在DTypeLike的情况下,JAX 不支持结构化 dtype,因此 JAX 可以使用更简单的实现。同样,在ArrayLike中,JAX 通常不支持列表或元组输入代替数组,因此类型定义将比 NumPy 对应的类型更简单。

输出应采用严格类型#

相反,函数和方法的输出应尽可能严格类型化:例如,对于返回数组的 JAX 函数,输出应使用类似于jnp.ndarray而不是ArrayLike进行注释。返回 dtype 的函数应始终注释为np.dtype,返回形状的函数应始终为Tuple[int]或严格类型化的 NamedShape 等效项。为此,我们将在jax.typing中实现几个上述宽松类型的严格类型化模拟,即:

  • ArrayNDArray (见下文) 用于类型注释目的,实际上等同于 Union[Tracer, jnp.ndarray],应用于注释数组输出。

  • DTypenp.dtype的别名,可能还具有表示 JAX 内部使用的键类型和其他泛化的能力。

  • Shape本质上是Tuple[int, ...],可能带有一些额外的灵活性以适应动态形状。

  • NamedShapeShape的扩展,允许使用 JAX 内部使用的命名形状。

  • 等等

我们还将探讨是否可以放弃jax.numpy.ndarray的当前实现,转而将ndarray作为Array或类似对象的别名。

倾向于简单性#

除了jax.typing中收集的常见类型协议之外,我们应该倾向于简单性。我们应该避免为传递给 API 函数的参数构建过于复杂的协议,而应使用简单的联合类型,例如在无法简洁指定 API 的完整类型规范的情况下使用Union[simple_type, Any]。这是一种折衷方案,它实现了级别 1 和级别 2 注释的目标,同时为了避免不必要的复杂性而放弃了级别 3。

避免不稳定的类型机制#

为了避免增加不必要的开发摩擦(由于内部/外部 CI 差异),我们希望在使用的类型注释构造方面保持保守:特别是对于最近引入的机制,如ParamSpecPEP 612)和可变类型泛型(PEP 646),我们希望等到 mypy 和其他工具中的支持成熟并稳定后再依赖它们。

其中一个影响是,暂时而言,当函数被 JAX 转换(如jitvmapgrad等)装饰时,JAX 将有效地剥离被装饰函数的所有注释。虽然这很不幸,但在撰写本文时,mypy 与ParamSpec提供的潜在解决方案存在一系列不兼容问题(参见ParamSpec mypy bug tracker),因此我们认为目前尚未准备好在 JAX 中全面采用。我们将在未来支持此类功能稳定后重新审视这个问题。

同样,目前我们将避免添加jaxtyping项目提供的更复杂和更细粒度的数组类型注释。这个决定我们可以在未来重新审视。

Array 类型设计考量#

如上所述,JAX 中数组的类型注释带来了独特的挑战,因为 JAX 大量使用鸭子类型,即在 JAX 转换中传递和返回Tracer对象来代替实际数组。这变得越来越令人困惑,因为用于类型注释的对象通常与用于运行时实例检查的对象重叠,并且可能与所讨论对象的实际类型层次结构对应或不对应。对于 JAX,我们需要在两种情况下提供鸭子类型对象:静态类型注释运行时实例检查

以下讨论将假设jax.Array是设备上数组的运行时类型,目前情况并非如此,但一旦#12016的工作完成,情况就会如此。

静态类型注释#

我们需要提供一个可用于鸭子类型注释的对象。假设我们暂时称此对象为ArrayAnnotation,我们需要一个满足mypypytype的解决方案,例如以下情况:

@jit
def f(x: ArrayAnnotation) -> ArrayAnnotation:
  assert isinstance(x, core.Tracer)
  return x

这可以通过多种方法实现,例如:

  • 使用类型联合:ArrayAnnotation = Union[Array, Tracer]

  • 创建一个接口文件,声明TracerArray应被视为ArrayAnnotation的子类。

  • 重构ArrayTracer,使ArrayAnnotation成为两者的真正基类。

运行时实例检查#

我们还必须提供一个可用于鸭子类型运行时isinstance检查的对象。假设我们暂时称此对象为ArrayInstance,我们需要一个通过以下运行时检查的解决方案:

def f(x):
  return isinstance(x, ArrayInstance)
x = jnp.array([1, 2, 3])
assert f(x)       # x will be an array
assert jit(f)(x)  # x will be a tracer

同样,有几种机制可以用于此:

  • 覆盖type(ArrayInstance).__instancecheck__,使其对ArrayTracer对象返回True;这是jnp.ndarray当前实现的方式(source)。

  • ArrayInstance定义为一个抽象基类,并将其动态注册到ArrayTracer

  • 重构ArrayTracer,使ArrayInstance成为ArrayTracer两者的真实基类。

我们需要决定ArrayAnnotationArrayInstance是否应该相同或不同。这里有一些先例;例如,在核心 Python 语言规范中,typing.Dicttyping.List是为了注释而存在,而内置的dictlist则用于实例检查。然而,DictList在较新的 Python 版本中已被弃用,转而使用dictlist进行注释和实例检查。

遵循 NumPy 的领先#

在 NumPy 的案例中,np.typing.NDArray用于类型注释,而np.ndarray用于实例检查(以及数组类型标识)。鉴于此,遵循 NumPy 的先例并实施以下方案可能是合理的:

  • jax.Array是设备上数组的实际类型。

  • jax.typing.NDArray是用于鸭子类型数组注释的对象。

  • jax.numpy.ndarray是用于鸭子类型数组实例检查的对象。

这可能让 NumPy 资深用户感到有些自然,但这种三分法很可能会造成混淆:选择哪一个用于实例检查和注释并不立即清楚。

统一实例检查和注释#

另一种方法是通过上述覆盖机制统一类型检查和注释。

选项 1:部分统一#

部分统一可能如下所示:

  • jax.Array是设备上数组的实际类型。

  • jax.typing.Array是用于鸭子类型数组注释的对象(通过ArrayTracer上的.pyi接口)。

  • jax.typing.Array也是用于鸭子类型实例检查的对象(通过其元类中的__isinstance__覆盖)。

在这种方法中,jax.numpy.ndarray将成为jax.typing.Array的简单别名,以实现向后兼容性。

选项 2:通过覆盖实现完全统一#

或者,我们可以选择通过覆盖实现完全统一:

  • jax.Array是设备上数组的实际类型。

  • jax.Array也用于鸭子类型数组注释(通过Tracer上的.pyi接口)

  • jax.Array也用于鸭子类型实例检查(通过其元类中的__isinstance__覆盖)

这里,jax.numpy.ndarray将成为jax.Array的简单别名,以实现向后兼容性。

选项 3:通过类层次结构实现完全统一#

最后,我们可以选择通过重构类层次结构并用面向对象编程(OOP)对象层次结构替换鸭子类型来实现完全统一:

  • jax.Array是设备上数组的实际类型。

  • jax.Array也用于数组类型注释,通过确保Tracer继承自jax.Array

  • jax.Array也通过相同的机制用于实例检查。

这里jnp.ndarray可以是jax.Array的别名。这种最终方法在某些意义上是最纯粹的,但从 OOP 设计的角度来看有些牵强(Tracer一个Array吗?)。

选项 4:通过类层次结构实现部分统一#

我们可以通过让Tracer和设备上数组的类继承自一个共同的基类,从而使类层次结构更加合理。例如:

  • jax.ArrayTracer以及设备上数组的实际类型(可能是jax._src.ArrayImpl或类似)的基类。

  • jax.Array是用于数组类型注释的对象。

  • jax.Array也用于实例检查。

这里jnp.ndarray将是Array的别名。这在 OOP 角度可能更纯粹,但与选项 2 和 3 相比,它放弃了type(x) is jax.Array将评估为 True 的概念。

评估#

考虑到每种潜在方法的总体优缺点:

  • 从用户的角度来看,统一方法(选项 2 和 3)可以说是最好的,因为它们消除了记住哪些对象用于实例检查或注释的认知开销:您只需要知道jax.Array

  • 然而,选项 2 和 3 都引入了一些奇怪和/或令人困惑的行为。选项 2 依赖于可能令人困惑的实例检查覆盖,这对于在 pybind11 中定义的类来说支持不佳。选项 3 要求Tracer成为子类数组。这打破了继承模型,因为它将要求Tracer对象携带Array对象的所有“负担”(数据缓冲区、分片、设备等)。

  • 选项 4 在 OOP 意义上更纯粹,并且避免了对典型实例检查或类型注释行为进行任何覆盖。权衡是设备上数组的实际类型变得独立(此处为jax._src.ArrayImpl)。但绝大多数用户永远不必直接接触这种私有实现。

这里有不同的权衡,但在讨论之后,我们选择了选项 4 作为我们的前进方向。

实施计划#

为了推进类型注释,我们将执行以下操作:

  1. 迭代此 JEP 文档,直到开发者和利益相关者达成共识。

  2. 创建私有的jax._src.typing(暂时不提供任何公共 API),并在其中放入上述第一层简单类型:

    • 暂时将Array = Any别名,因为这还需要一些思考。

    • ArrayLike:作为正常jax.numpy函数输入有效类型的联合。

    • DType / DTypeLike(注意:numpy 使用驼峰式DType;我们应该遵循此约定以方便使用)。

    • Shape / NamedShape / ShapeLike

    这项工作的开端已在#12300中完成。

  3. 开始着手创建jax.Array基类,该类将遵循前一节中的选项 4。最初这将用 Python 定义,并使用jnp.ndarray实现中当前发现的动态注册机制,以确保isinstance检查的正确行为。pyi为每个 tracer 和类数组类进行的覆盖将确保类型注释的正确行为。jnp.ndarray随后可以成为jax.Array的别名。

  4. 作为测试,使用这些新的类型定义,根据上述指南全面注释jax.lax中的函数。

  5. 继续逐个模块添加额外的注释,重点关注公共 API 函数。

  6. 同时,开始在 pybind11 中重新实现一个jax.Array基类,以便ArrayImplTracer可以继承它。使用pyi定义来确保静态类型检查器识别类的适当属性。

  7. 一旦jax.Arrayjax._src.ArrayImpl完全实现,移除这些临时的 Python 实现。

  8. 所有工作完成后,创建一个公共的jax.typing模块,将上述类型提供给用户,并附带使用 JAX 代码的注释最佳实践文档。

我们将在#12049中跟踪这项工作,本 JEP 的编号也由此而来。