JAX 类型注解路线图#

  • 作者:jakevdp

  • 日期:2022 年 8 月

背景#

Python 3.0 引入了可选的函数注解(PEP 3107),后来在 Python 3.5 发布时将其规范化用于静态类型检查(PEP 484)。在某种程度上,类型注解和静态类型检查已成为许多 Python 开发工作流程的组成部分,为此,我们在 JAX API 的许多地方都添加了注解。JAX 中类型注解的当前状态有些零散,而添加更多注解的努力则因更基本的设计问题而受阻。本文档试图总结这些问题,并为 JAX 中类型注解的目标和非目标生成一个路线图。

为什么我们需要这样的路线图?改进/更全面的类型注解是用户(内部和外部)频繁提出的请求。此外,我们经常收到外部用户的拉取请求(例如,PR #9917PR #10322),希望改进 JAX 的类型注解:JAX 团队成员在审查代码时,并不总是清楚此类贡献是否会有益,尤其是在它们引入复杂的协议来解决 JAX 使用 Python 的完全注解所固有的挑战时。本文档详细介绍了 JAX 在该包内类型注解的目标和建议。

为什么需要类型注解?#

Python 项目可能希望注解其代码库的原因有很多;我们将在本文档中将其总结为级别 1、级别 2 和级别 3。

级别 1:注解作为文档#

最初在 PEP 3107 中引入时,类型注解的动机部分是能够将它们用作函数参数类型和返回类型的简洁、内联文档。JAX 长期以来一直以这种方式利用注解;一个例子是创建别名为 Any 的类型名称的常见模式。一个例子可以在 lax/slicing.py 中找到 [来源]

Array = Any
Shape = core.Shape

def slice(operand: Array, start_indices: Sequence[int],
          limit_indices: Sequence[int],
          strides: Optional[Sequence[int]] = None) -> Array:
  ...

对于静态类型检查的目的,将 Array = Any 用于数组类型注解不会对参数值施加任何约束(Any 等同于没有注解),但它确实为开发人员提供了有用的代码内文档形式。

为了生成的文档,别名的名称会丢失(jax.lax.sliceHTML 文档将操作数报告为 Any 类型),因此文档的好处仅限于源代码(尽管我们可以启用一些 sphinx-autodoc 选项来改进这一点:请参阅 autodoc_type_aliases)。

这种级别的类型注解的一个好处是,用 Any 注解一个值永远不会出错,因此它将以文档的形式为开发人员和用户提供切实的益处,而不会增加满足任何特定静态类型检查器更严格需求的复杂性。

级别 2:注解用于智能自动补全#

许多现代 IDE 利用类型注解作为 智能代码补全 系统的输入。其中一个例子是 VSCode 的 Pylance 扩展,它使用 Microsoft 的 pyright 静态类型检查器作为 VSCode IntelliSense 补全的信息来源。

这种类型的检查使用需要比上面简单的别名更进一步;例如,知道 slice 函数返回一个名为 ArrayAny 的别名,并不会为代码补全引擎增加任何有用信息。然而,如果我们使用 DeviceArray 返回类型注解该函数,自动补全将知道如何填充结果的命名空间,从而能够在开发过程中提供更相关的自动补全。

JAX 已开始在一些地方添加这种级别的类型注解;一个例子是 jax.random 包内的 jnp.ndarray 返回类型 [来源]

def shuffle(key: KeyArray, x: Array, axis: int = 0) -> jnp.ndarray:
  ...

在这种情况下,jnp.ndarray 是一个抽象基类,它前向声明了 JAX 数组的属性和方法(参见来源),因此 VSCode 中的 Pylance 可以为该函数的返回结果提供完整的自动补全集。下面是显示结果的屏幕截图

VSCode Intellisense Screenshot

在自动补全字段中列出了抽象 ndarray 类声明的所有方法和属性。我们将在下文进一步讨论为什么有必要创建这个抽象类而不是直接用 DeviceArray 进行注解。

级别 3:注解用于静态类型检查#

如今,静态类型检查通常是人们在考虑 Python 代码中类型注解目的时首先想到的。虽然 Python 不会进行任何运行时类型检查,但存在几种成熟的静态类型检查工具,它们可以在 CI 测试套件中进行此操作。对 JAX 最重要的工具如下:

  • python/mypy 在开放的 Python 世界中或多或少是标准。JAX 目前在 Github Actions CI 检查中使用 mypy 来检查部分源代码文件。

  • google/pytype 是 Google 的静态类型检查器,依赖 JAX 的 Google 内部项目经常使用它。

  • microsoft/pyright 很重要,因为它是前面提到的 Pylance 补全中使用的 VSCode 的静态类型检查器。

完整的静态类型检查是所有类型注解应用中最严格的,因为它会在类型注解不完全正确时报告错误。一方面,这很好,因为静态类型分析可以捕获错误的类型注解(例如,jnp.ndarray 抽象类中缺少 DeviceArray 方法的情况)。

另一方面,这种严格性可能会使类型检查过程在经常依赖“鸭子类型”而不是严格类型安全 API 的包中变得非常脆弱。您目前会在 JAX 代码库的数百个地方找到类似 #type: ignore(针对 mypy)或 #pytype: disable(针对 pytype)的代码注释。这些通常代表了出现类型问题的 cases;它们可能是 JAX 类型注解中的不准确之处,或者是静态类型检查器正确跟踪代码控制流能力的不准确之处。偶尔,它们是由于 pytype 或 mypy 行为中的实际且微妙的错误。在极少数情况下,它们可能是由于 JAX 使用了难以甚至不可能用 Python 的静态类型注解语法表达的 Python 模式。

JAX 类型注解面临的挑战#

JAX 目前的类型注解混合了不同的样式,并针对上述所有三个级别的类型注解。部分原因是 JAX 的源代码对 Python 的类型注解系统带来了一些独特的挑战。我们在此概述它们。

挑战 1:pytype、mypy 和开发者摩擦#

JAX 目前面临的一个挑战是,包开发必须满足两个不同静态类型检查系统的约束,pytype(由内部 CI 和 Google 内部项目使用)和 mypy(由外部 CI 和外部依赖项使用)。尽管这两个类型检查器在行为上广泛重叠,但每个都有其独特的边界情况,正如 JAX 代码库中大量的 #type: ignore#pytype: disable 语句所证明的。

这会在开发中造成摩擦:内部贡献者可能会一直迭代直到测试通过,却发现在导出后,他们 pytype 批准的代码会触犯 mypy。对于外部贡献者,情况通常相反:最近的一个例子是 #9596,在内部 Google pytype 检查失败后不得不回滚。每次我们将一个类型注解从级别 1(到处是 Any)移到级别 2 或 3(更严格的注解)时,都会增加发生这种令人沮丧的开发体验的潜在可能性。

挑战 2:数组的“鸭子类型”#

注解 JAX 代码的一个特定挑战是其大量使用“鸭子类型”。函数(标记为 Array)的输入通常可以是多种类型之一:JAX 的 DeviceArray、NumPy 的 np.ndarray、NumPy 标量、Python 标量、Python 序列、带有 __array__ 属性的对象、带有 __jax_array__ 属性的对象,或者任何一种 jax.Tracer。因此,像 def func(x: DeviceArray) 这样的简单注解将不足够,并且会导致许多有效用法出现误报。这意味着 JAX 函数的类型注解将不会简短或琐碎,但我们实际上需要开发一套 JAX 特定的类型扩展,类似于 numpy.typing 中的。

挑战 3:转换和装饰器#

JAX 的 Python API 严重依赖函数转换(jit()vmap()grad() 等),而这种 API 类型对静态类型分析构成了特殊挑战。装饰器的灵活注解一直是 mypy 包中一个长期存在的问题,直到最近通过引入 ParamSpec 得到了解决,该问题在 PEP 612 中进行了讨论,并在 Python 3.10 中添加。由于 JAX 遵循 NEP 29,它不能依赖 Python 3.10 的功能,直到 2024 年中旬之后。在此期间,协议可以作为一种部分解决方案(JAX 在 #9950 中为此类方法和 jit 添加了此功能),并且可以通过 typing_extensions 包使用 ParamSpec(一个原型在 #9999 中),尽管这目前暴露了 mypy 中的根本性错误(参见 python/mypy#12593)。所有这些都说明:目前尚不清楚 JAX 函数转换的 API 是否能在当前 Python 类型注解工具的约束下得到充分注解。

挑战 4:数组注解的粒度不足#

这里的另一个挑战是 Python 中所有面向数组的 API 的常见挑战,并且已成为 JAX 讨论的一部分多年(参见 #943)。类型注解与对象的 Python 类或类型有关,而在面向数组的语言中,类的属性通常更重要。在 NumPy、JAX 和类似包的情况下,我们经常希望注解特定的数组形状和数据类型。

例如,jnp.linspace 函数的参数必须是标量值,但在 JAX 中标量由零维数组表示。因此,为了使注解不引发误报,我们必须允许这些参数为任意数组。另一个例子是 jax.random.choice 的第二个参数,当 shape=() 时,它必须具有 dtype=int。Python 有一个计划通过可变类型泛型(参见 PEP 646,计划在 Python 3.11 中发布)来实现具有这种粒度的类型注解,但与 ParamSpec 一样,对该功能的支持需要一段时间才能稳定。

在此期间,有一些第三方项目可能会有所帮助,特别是 google/jaxtyping,但它使用非标准注解,并且可能不适合注解核心 JAX 库本身。总而言之,数组类型粒度挑战不如其他挑战那么严重,因为主要影响是类数组注解的特异性不如它们原本可能的那样。

挑战 5:继承自 NumPy 的不精确 API#

JAX 的大部分用户可见 API 是在 jax.numpy 子模块中从 NumPy 继承的。NumPy 的 API 在静态类型检查成为 Python 语言一部分之前多年就已开发,并且遵循 Python 早期关于使用“鸭子类型”/EAFP 编码风格的建议,其中不鼓励运行时严格类型检查。以 numpy.tile() 函数为例,它的定义如下:

def tile(A, reps):
  try:
    tup = tuple(reps)
  except TypeError:
    tup = (reps,)
  d = len(tup)
  ...

这里,意图reps 将包含一个 int 或一个 int 值序列,但实现允许 tup 是任何可迭代对象。在为这种“鸭子类型”代码添加注解时,我们可以选择以下两种方式之一:

  1. 我们可以选择注解 API 的意图,这里可能是 reps: Union[int, Sequence[int]]

  2. 相反,我们可以选择注解函数的实现,这里可能是 reps: Union[ConvertibleToInt, Iterable[ConvertibleToInt]],其中 ConvertibleToInt 是一个特殊的协议,涵盖了我们的函数将输入转换为整数的精确机制(即通过 __int__、通过 __index__、通过 __array__ 等)。另请注意,严格来说,Iterable 在这里是不够的,因为 Python 中存在一些对象,它们在“鸭子类型”上是可迭代的,但在与 Iterable 的静态类型检查中不满足(即,一个对象可以通过 __getitem__ 而不是 __iter__ 来迭代)。

选择 #1(注解意图)的优点是注解对于用户来说更能传达 API 合约;而对于开发人员,这种灵活性在必要时为重构留下了空间。缺点(特别是对于 JAX 这样的渐进式类型 API)是,很可能存在实际上能运行但会被类型检查器标记为不正确的用户代码。对现有“鸭子类型” API 进行渐进式类型化意味着当前的注解是隐式 Any,因此将其更改为更严格的类型可能会被用户视为破坏性更改。

广义而言,注解意图更好地服务于级别 1 类型检查,而注解实现则更好地服务于级别 3,而级别 2 则更加混合(在 IDE 注解方面,意图和实现都很重要)。

JAX 类型注解路线图#

有了这个框架(级别 1/2/3)以及 JAX 特有的挑战,我们可以开始制定在整个 JAX 项目中实现一致类型注解的路线图。

指导原则#

对于 JAX 类型注解,我们将遵循以下原则:

类型注解的目的#

我们希望尽可能支持完整的级别 1、2 和 3 类型注解。特别是,这意味着我们应该对公共 API 函数的输入和输出都有严格的类型注解。

为意图注解#

JAX 类型注解通常应指示 API 的**意图**,而不是实现,以便注解能够用于传达 API 的合约。这意味着有时在运行时有效的输入可能不会被静态类型检查器识别为有效(一个例子可能是将任意迭代器作为形状的替代,该形状注解为 Shape = Sequence[int])。

输入应具有宽松的类型#

JAX 函数和方法的输入应具有合理宽松的类型:例如,虽然形状通常是元组,但接受形状的函数应接受任意序列。类似地,接受 dtype 的函数不一定需要 np.dtype 类的实例,而是任何可转换为 dtype 的对象。这可能包括字符串、内置标量类型或标量对象构造函数,如 np.float64jnp.float64。为了使整个包的这种一致性尽可能高,我们将添加一个 jax.typing 模块,其中包含常见的类型规范,首先是广泛的类别,例如:

  • ArrayLike 将是任何可以隐式转换为数组的内容的联合:例如,jax 数组、numpy 数组、JAX 追踪器以及 Python 或 numpy 标量。

  • DTypeLike 将是任何可以隐式转换为 dtype 的内容(例如,numpy dtypes、numpy dtype 对象、jax dtype 对象、字符串和内置类型)的联合。

  • ShapeLike 将是任何可以转换为形状的内容的联合:例如,整数或类整数对象的序列。

  • 等等。

请注意,这些通常会比 numpy.typing 中使用的等效协议更简单。例如,对于 DTypeLike,JAX 不支持结构化 dtype,因此 JAX 可以使用更简单的实现。同样,在 ArrayLike 中,JAX 通常不支持列表或元组输入代替数组,因此类型定义将比 NumPy 类似物更简单。

输出应具有严格的类型#

相反,函数和方法的输出应尽可能严格地进行类型注解:例如,对于返回数组的 JAX 函数,输出应注解为类似于 jnp.ndarray 而不是 ArrayLike。返回 dtype 的函数应始终注解为 np.dtype,返回形状的函数应始终为 Tuple[int] 或严格类型的 NamedShape 等效项。为此,我们将在 jax.typing 中实现上述宽松类型的几个严格类型等效项,即:

  • ArrayNDArray(见下文)用于类型注解,实际上等同于 Union[Tracer, jnp.ndarray],并应用于注解数组输出。

  • DTypenp.dtype 的别名,可能还可以表示 JAX 中使用的键类型和其他泛化。

  • Shape 基本上是 Tuple[int, ...],可能具有一些额外的灵活性来适应动态形状。

  • NamedShapeShape 的扩展,允许使用 JAX 内部使用的命名形状。

  • 等等。

我们将探讨 jax.numpy.ndarray 的当前实现是否可以被删除,以便使 ndarray 成为 Array 或类似的别名。

倾向于简洁#

除了 jax.typing 中收集的常见类型协议外,我们应倾向于简洁。我们应避免为传递给 API 函数的参数构建过于复杂的协议,而应使用简单的联合,如 Union[simple_type, Any],在 API 的完整类型规范无法简洁指定的情况下。这是一种折衷,可以实现级别 1 和 2 注解的目标,同时通过避免不必要的复杂性来推迟级别 3。

避免不稳定的类型机制#

为了不增加不当的开发摩擦(由于内部/外部 CI 的差异),我们希望在使用的类型注解构造方面保持保守:特别是,对于最近引入的机制,如 ParamSpecPEP 612)和可变类型泛型(PEP 646),我们希望等到 mypy 和其他工具的支持成熟并稳定后再依赖它们。

这其中一个影响是,目前,当函数被 jitvmapgrad 等 JAX 转换装饰时,JAX 将有效地**剥离被装饰函数的所有注解**。虽然这很遗憾,但在撰写本文时,mypy 对 ParamSpec 提供的潜在解决方案有一长串不兼容问题(参见 ParamSpec mypy bug 跟踪器),因此我们认为目前不适合在 JAX 中完全采用它。一旦对这些功能的支持稳定下来,我们将重新审视这个问题。

同样,目前我们将避免添加 jaxtyping 项目提供的更复杂和更精细的数组类型注解。这是一个我们可以以后重新考虑的决定。

Array 类型设计考虑因素#

如上所述,JAX 中数组的类型注解由于 JAX 广泛使用“鸭子类型”而带来独特的挑战,即在 jax 转换中,在实际数组的位置传递和返回 Tracer 对象。这变得越来越令人困惑,因为用于类型注解的对象经常与用于运行时实例检查的对象重叠,并且可能与对象的实际类型层次结构不对应,也可能对应。对于 JAX,我们需要为两种上下文提供“鸭子类型”对象:**静态类型注解**和**运行时实例检查**。

接下来的讨论将假定 jax.Array 是设备上数组的运行时类型,目前尚未实现,但一旦 #12016 中的工作完成,它就会是。

静态类型注解#

我们需要提供一个可用于“鸭子类型”的类型注解的对象。假设我们暂时称该对象为 ArrayAnnotation,我们需要一个解决方案来满足 mypypytype 在以下情况下的需求:

@jit
def f(x: ArrayAnnotation) -> ArrayAnnotation:
  assert isinstance(x, core.Tracer)
  return x

这可以通过多种方法实现,例如:

  • 使用类型联合:ArrayAnnotation = Union[Array, Tracer]

  • 创建一个接口文件,声明 TracerArray 应被视为 ArrayAnnotation 的子类。

  • 重构 ArrayTracer,使 ArrayAnnotation 成为两者的真正基类。

运行时实例检查#

我们还必须提供一个可用于“鸭子类型”的运行时 isinstance 检查的对象。假设我们暂时称该对象为 ArrayInstance,我们需要一个解决方案来通过以下运行时检查:

def f(x):
  return isinstance(x, ArrayInstance)
x = jnp.array([1, 2, 3])
assert f(x)       # x will be an array
assert jit(f)(x)  # x will be a tracer

同样,有几种机制可以用于此:

  • 重写 type(ArrayInstance).__instancecheck__ 以返回 True,用于 ArrayTracer 对象;jnp.ndarray 目前就是这样实现的(来源)。

  • ArrayInstance 定义为抽象基类,并动态将其注册到 ArrayTracer

  • 重构 ArrayTracer,使 ArrayInstance 成为 ArrayTracer 的真正基类。

我们需要做出一个决定,即 ArrayAnnotationArrayInstance 应该是相同的对象还是不同的对象。这里有一些先例;例如,在核心 Python 语言规范中,typing.Dicttyping.List 是为了注解而存在的,而内置的 dictlist 则用于实例检查。然而,DictList 在较新版本的 Python 中已被弃用,转而使用 dictlist 进行注解和实例检查。

遵循 NumPy 的做法#

在 NumPy 的情况下,np.typing.NDArray 用于类型注解,而 np.ndarray 用于实例检查(以及数组类型标识)。鉴于此,遵循 NumPy 的先例并实现以下内容可能是合理的:

  • jax.Array 是设备上数组的实际类型。

  • jax.typing.NDArray 是用于“鸭子类型”数组注解的对象。

  • jax.numpy.ndarray 是用于“鸭子类型”数组实例检查的对象。

这可能对 NumPy 的重度用户来说感觉有些自然,但这种三叉分叉很可能会引起混淆:选择哪一个用于实例检查和注解并不立即清楚。

统一实例检查和注解#

另一种方法是通过上述重写机制来统一类型检查和注解。

选项 1:部分统一#

部分统一可能看起来像这样:

  • jax.Array 是设备上数组的实际类型。

  • jax.typing.Array 是用于“鸭子类型”数组注解的对象(通过 ArrayTracer 上的 .pyi 接口)。

  • jax.typing.Array 也是用于“鸭子类型”实例检查的对象(通过其元类中的 __isinstance__ 重写)。

在此方法中,jax.numpy.ndarray 将成为 jax.typing.Array 的简单别名,以兼容以前的版本。

选项 2:通过重写完全统一#

或者,我们可以选择通过重写完全统一:

  • jax.Array 是设备上数组的实际类型。

  • jax.Array 也是用于“鸭子类型”数组注解的对象(通过 Tracer 上的 .pyi 接口)。

  • jax.Array 也是用于“鸭子类型”实例检查的对象(通过其元类中的 __isinstance__ 重写)。

在这里,jax.numpy.ndarray 将成为 jax.Array 的简单别名,以兼容以前的版本。

选项 3:通过类层次结构完全统一#

最后,我们可以选择通过重构类层次结构并用 OOP 对象层次结构替换“鸭子类型”来完全统一:

  • jax.Array 是设备上数组的实际类型。

  • jax.Array 也是用于数组类型注解的对象,通过确保 Tracer 继承自 jax.Array

  • jax.Array 也是用于实例检查的对象,通过相同的机制。

在这里,jnp.ndarray 可以是 jax.Array 的别名。这个最终方法在某种意义上是最纯粹的,但它在 OOP 设计方面有些牵强(Tracer *是*一个 Array?)。

选项 4:通过类层次结构部分统一#

我们可以通过让 Tracer 和设备上数组的类继承自一个公共基类来使类层次结构更合理。例如:

  • jax.ArrayTracer 以及设备上数组的实际类型的基类,该类可能是 jax._src.ArrayImpl 或类似的。

  • jax.Array 是用于数组类型注解的对象。

  • jax.Array 也是用于实例检查的对象。

在这里,jnp.ndarray 将是 Array 的别名。这在 OOP 角度可能更纯粹,但与选项 2 和 3 相比,它放弃了 type(x) is jax.Array 评估为 True 的概念。

评估#

考虑到每种潜在方法的整体优缺点:

  • 从用户的角度来看,统一的方法(选项 2 和 3)可能是最好的,因为它们消除了记住使用哪个对象进行实例检查或注解的认知负担:jax.Array 是您需要知道的一切。

  • 然而,选项 2 和 3 都引入了一些奇怪和/或令人困惑的行为。选项 2 依赖于可能令人困惑的实例检查重写,而这些重写对于 pybind11 定义的类支持不佳。选项 3 要求 Tracer 成为数组的子类。这破坏了继承模型,因为它要求 Tracer 对象携带 Array 对象的所有负担(数据缓冲区、分片、设备等)。

  • 选项 4 在 OOP 方面更纯粹,并且避免了对典型实例检查或类型注解行为进行任何重写的需要。权衡是设备上数组的实际类型成为一个独立的东西(这里是 jax._src.ArrayImpl)。但绝大多数用户永远不需要直接接触这个私有实现。

这里有不同的权衡,但在讨论后,我们已决定 Option 4 是我们前进的方向。

实施计划#

为了推进类型注解,我们将执行以下操作:

  1. 迭代此 JEP 文档,直到开发人员和利益相关者达成一致。

  2. 创建一个私有的 jax._src.typing(目前不提供任何公共 API),并将上面提到的第一组简单类型放入其中。

    • 暂时将 Array = Any 作为别名,这需要更多思考。

    • ArrayLike:普通 jax.numpy 函数有效输入的类型联合。

    • DType / DTypeLike(注意:numpy 使用驼峰式 DType;为了方便使用,我们应遵循此约定)。

    • Shape / NamedShape / ShapeLike

    这方面的一些基础工作已在 #12300 中完成。

  3. 开始构建一个 jax.Array 基类,该基类遵循上一节的选项 4。最初将在 Python 中定义,并使用当前 jnp.ndarray 实现中的动态注册机制,以确保 isinstance 检查的正确行为。一个 pyi 覆盖(针对每个 tracer 和类数组类)将确保类型注解的正确行为。然后,jnp.ndarray 可以成为 jax.Array 的别名。

  4. 作为测试,使用这些新的类型定义,根据上述指南全面注解 jax.lax 中的函数。

  5. 一次一个模块地继续添加更多注解,重点关注公共 API 函数。

  6. 同时,开始用 pybind11 重写 jax.Array 基类,以便 ArrayImplTracer 可以继承它。使用 pyi 定义来确保静态类型检查器识别类的适当属性。

  7. 一旦 jax.Arrayjax._src.ArrayImpl 完全落地,则移除这些临时的 Python 实现。

  8. 所有内容最终确定后,创建一个公共 jax.typing 模块,使上述类型可供用户使用,并提供有关使用 JAX 的代码的注解最佳实践的文档。

我们将在 #12049 中跟踪此工作,JEP 号码也由此而来。