JAX 类型提升语义设计#

Open in Colab Open in Kaggle

Jake VanderPlas,2021 年 12 月

任何数值计算库在设计时面临的挑战之一,是如何处理不同类型值之间的运算。本文概述了 JAX 所使用的提升语义背后的思考过程,并在 JAX 类型提升语义 中进行了总结。

JAX 类型提升的目标#

JAX 的数值计算 API 是参照 NumPy 建模的,并进行了一些增强,包括能够针对 GPU 和 TPU 等加速器进行优化。这使得采用 NumPy 的类型提升系统对 JAX 用户而言存在不利之处:NumPy 的类型提升规则严重偏向于 64 位输出,这对加速器上的计算来说是有问题的。GPU 和 TPU 等设备在使用 64 位浮点类型时往往会付出显著的性能代价,在某些情况下甚至根本不支持原生 64 位浮点类型。

在 32 位整数和浮点数之间的二元运算中,可以看到这种有问题的类型提升语义的一个简单例子

import numpy as np
np.dtype(np.int32(1) + np.float32(1))
dtype('float64')

NumPy 倾向于产生 64 位值的特性,是在使用 NumPy API 进行加速计算时的一个 长期存在的问题,目前尚无完美的解决方案。出于这个原因,JAX 试图从加速器的角度重新思考 NumPy 风格的类型提升。

回顾:表格与格(Lattices)#

在深入细节之前,让我们花点时间回顾并思考一下该如何看待“类型提升”这个问题。考虑 Python 内置数值类型(即 intfloatcomplex)之间的算术运算。通过几行代码,我们可以生成 Python 在这些类型值之间进行加法运算时所使用的类型提升表

import pandas as pd
types = [int, float, complex]
name = lambda t: t.__name__
pd.DataFrame([[name(type(t1(1) + t2(1))) for t1 in types] for t2 in types],
             index=[name(t) for t in types], columns=[name(t) for t in types])
int float complex
int int float complex
float float float complex
complex complex complex complex

该表列举了 Python 的数值类型提升行为,但事实证明,有一种更简洁的补充表示法:格(Lattice)表示法。在格中,任意两个节点的上确界(supremum)即为它们提升后的类型。Python 提升表的格表示法要简单得多

隐藏代码单元源

#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {'int': ['float'], 'float': ['complex']}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {'int': [0, 0], 'float': [1, 0], 'complex': [2, 0]}
fig, ax = plt.subplots(figsize=(8, 2))
nx.draw(graph, with_labels=True, node_size=4000, node_color='lightgray', pos=pos, ax=ax, arrowsize=20)
../_images/818a3cf499d15c3be1d4c116db142da0418c174873f21e1ffcde679c6058f918.png

此格是对上述提升表中所含信息的紧凑编码。你可以通过追踪图表找到两个输入的第一个共同子节点(包括节点本身),从而得出类型提升的结果;在数学上,这个共同子节点被称为这对节点在格上的上确界最小上界并(join);在这里,我们将此操作称为并(join)

从概念上讲,箭头表示源类型和目标类型之间允许隐式类型提升:例如,允许从整数到浮点数的隐式提升,但不允许从浮点数到整数的隐式提升。

请记住,通常并非每个有向无环图(DAG)都能满足格的属性。格要求每一对节点都必须存在唯一的最小上界;因此,例如以下两个 DAG 就不是格

隐藏代码单元源

#@title
import networkx as nx
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 2, figsize=(10, 2))

lattice = {'A': ['B', 'C']}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {'A': [0, 0], 'B': [1, 0.5], 'C': [1, -0.5]}
nx.draw(graph, with_labels=True, node_size=2000, node_color='lightgray', pos=pos, ax=ax[0], arrowsize=20)
ax[0].set(xlim=[-0.5, 1.5], ylim=[-1, 1])

lattice = {'A': ['C', 'D'], 'B': ['C', 'D']}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {'A': [0, 0.5], 'B': [0, -0.5], 'C': [1, 0.5], 'D': [1, -0.5]}
nx.draw(graph, with_labels=True, node_size=2000, node_color='lightgray', pos=pos, ax=ax[1], arrowsize=20)
ax[1].set(xlim=[-0.5, 1.5], ylim=[-1, 1]);
../_images/a0acbd07f9486d95c10a36c11301d528fb7e65d671d622226151c431b3e36c62.png

左侧的 DAG 不是格,因为节点 BC 不存在上界;右侧的 DAG 在两点上不成立:首先,节点 CD 不存在上界,并且对于节点 AB,其最小上界无法唯一确定:CD 都是候选者,但它们之间无法排序。

类型提升格的属性#

以格的形式指定类型提升可确保许多有用的属性。用 \(\vee\) 运算符表示格上的并运算,我们有

存在性: 根据定义,格要求每一对元素都存在唯一的格并运算:\(\forall (a, b): \exists !(a \vee b)\)

交换律: 格并运算满足交换律:\(\forall (a, b): a\vee b = b \vee a\)

结合律: 格并运算满足结合律:\(\forall (a, b, c): a \vee (b \vee c) = (a \vee b) \vee c\)

另一方面,这些属性对它们所能表示的类型提升系统提出了限制;特别是并非每个类型提升表都可以由格来表示。一个现成的例子是 NumPy 的完整类型提升表;这可以通过反例迅速证明:以下是三个标量类型,它们在 NumPy 中的提升行为是非结合的

import numpy as np
a, b, c = np.int8(1), np.uint8(1), np.float16(1)
print(np.dtype((a + b) + c))
print(np.dtype(a + (b + c)))
float32
float16

这样的结果可能会让用户感到惊讶:我们通常期望数学表达式映射到数学概念,因此,例如 a + b + c 应该等同于 c + b + ax * (y + z) 应该等同于 x * y + x * z。如果类型提升不满足结合律或交换律,这些属性将不再适用。

此外,与基于表格的系统相比,基于格的类型提升系统在概念上更简单且更易于理解。例如,JAX 识别 18 种不同的类型:一个包含 18 个节点且节点间连接稀疏、理由充分的提升格,远比一个包含 324 个条目的表格更容易记忆。

因此,我们选择为 JAX 使用基于格的类型提升系统。

类别内的类型提升#

数值计算库通常不仅提供 intfloatcomplex;在这些类别中的每一个内部,都有各种可能的精度,用数值表示中所使用的位数来表示。我们在此考虑的类别包括

  • 无符号整数,包括 uint8uint16uint32uint64(简称为 u8, u16, u32, u64

  • 有符号整数,包括 int8int16int32int64(简称为 i8, i16, i32, i64

  • 浮点数,包括 float16float32float64(简称为 f16, f32, f64

  • 复数浮点数,包括 complex64complex128(简称为 c64, c128

Numpy 在这四个类别内部的类型提升语义相对简单:有序的类型层级直接转化为四个独立的格,分别表示类别内的类型提升规则

隐藏代码单元源

#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'u8': ['u16'], 'u16': ['u32'], 'u32': ['u64'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],
  'f16': ['f32'], 'f32': ['f64'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'u8': [0, 0], 'u16': [1, 0], 'u32': [2, 0], 'u64': [3, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [1, 2], 'f32': [2, 2], 'f64': [3, 2],
  'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 4))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
../_images/2d8495bcb006c34b42eeb4f3e0c6530fdef0bd7364c56184993925f0cf157abc.png

就 JAX 试图避免的值提升至 64 位的情况而言,这些类型类别内部的同类提升语义是没有问题的:产生 64 位输出的唯一途径是拥有 64 位输入。

Python 标量登场#

现在让我们思考一下 Python 标量在其中处于什么位置。

在 NumPy 中,提升行为取决于输入是数组还是标量。例如,当对两个标量进行运算时,适用普通的提升规则

x = np.int8(0)  # int8 scalar
y = 1  # Python int = int64 scalar
(x + y).dtype
dtype('int64')

这里 Python 值 1 被视为 int64,简单的类别内规则导致了 int64 的结果。

然而,在 Python 标量与 NumPy 数组之间的运算中,标量会遵循数组的数据类型(dtype)。例如

x = np.zeros(1, dtype='int8')  # int8 array
y = 1  # Python int = int64 scalar
(x + y).dtype
dtype('int8')

这里 int64 标量的位宽被忽略,转而遵循数组的位宽。

这里还有一个细节:当 NumPy 类型提升涉及标量时,输出 dtype 取决于值:如果 Python 标量对于给定的 dtype 而言过大,它会被提升为兼容的类型

x = np.zeros(1, dtype='int8')  # int8 array
y = 1000  # int64 scalar
(x + y).dtype
dtype('int16')

对于 JAX 而言,值依赖的提升是不可接受的,因为 JIT 编译及其他转换的性质决定了它们作用于数据的抽象表示,而不引用其具体值。

忽略值依赖的影响,NumPy 类型提升的有符号整数分支可以用以下格表示,其中我们使用 * 来标记标量 dtype

隐藏代码单元源

#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i8*': ['i16*'], 'i16*': ['i32*'], 'i32*': ['i64*'], 'i64*': ['i8'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i8*': [0, 1], 'i16*': [2, 1], 'i32*': [4, 1], 'i64*': [6, 1],
  'i8': [9, 1], 'i16': [11, 1], 'i32': [13, 1], 'i64': [15, 1],
}
fig, ax = plt.subplots(figsize=(12, 4))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
ax.text(3, 1.6, "Scalar Types", ha='center', fontsize=14)
ax.text(12, 1.6, "Array Types", ha='center', fontsize=14)
ax.set_ylim(-1, 3);
../_images/7e8c3295e403209560d8e142c5c830d79456a4e6d207dd1a7e4d15b55c56006b.png

类似的模式也存在于 uintfloatcomplex 格中。

为了简单起见,我们将每一类标量类型折叠为一个单一节点,分别记为 u*i*f*c*。我们现在的类别内格集合可以表示为

隐藏代码单元源

#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'u*': ['u8'], 'u8': ['u16'], 'u16': ['u32'], 'u32': ['u64'],
  'i*': ['i8'], 'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],
  'f*': ['f16'], 'f16': ['f32'], 'f32': ['f64'],
  'c*': ['c64'], 'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'u*': [0, 0], 'u8': [3, 0], 'u16': [5, 0], 'u32': [7, 0], 'u64': [9, 0],
  'i*': [0, 1], 'i8': [3, 1], 'i16': [5, 1], 'i32': [7, 1], 'i64': [9, 1],
  'f*': [0, 2], 'f16': [5, 2], 'f32': [7, 2], 'f64': [9, 2],
  'c*': [0, 3], 'c64': [7, 3], 'c128': [9, 3],
}
fig, ax = plt.subplots(figsize=(6, 4))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
../_images/0fbe0c20cd350821e64f3742aa7864ec729565572b136950042095881672fdb9.png

从某种意义上说,将标量放在左侧是一个奇怪的选择:标量类型可能包含任何宽度的值,但当与特定类型的数组交互时,提升结果会遵循数组类型。这样做的好处是,当你对数组 x 执行类似 x + 2 的操作时,无论 x 的宽度如何,其类型都会延续到结果中

for dtype in [np.int8, np.int16, np.int32, np.int64]:
  x = np.arange(10, dtype=dtype)
  assert (x + 2).dtype == dtype

这种行为激励了我们对标量值使用 * 符号:* 会让人联想到可以取任何所需值的通配符。

这种语义的好处在于,你可以用简洁的 Python 代码轻松表达一系列操作,而无需显式地将标量转换为适当的类型。试想一下,如果不用写成这样

3 * (x + 1) ** 2

而是不得不写成这样

np.int32(3) * (x + np.int32(1)) ** np.int32(2)

尽管这很明确,但数值代码读写起来会变得非常冗长。有了上述标量提升语义,对于类型为 int32 的数组 x,第二条语句中的类型在第一条语句中就是隐含的。

组合格#

回想一下,我们通过引入表示 Python 内部类型提升的格 int -> float -> complex 开始了我们的讨论。让我们将其重写为 i* -> f* -> c*,并进一步允许 i* 包含 u*(毕竟 Python 中没有无符号整数标量类型)。

将它们整合在一起,我们得到以下部分格,它表示 Python 标量与 numpy 数组之间的类型提升

隐藏代码单元源

#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16'], 'u16': ['u32'], 'u32': ['u64'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],
  'f16': ['f32'], 'f32': ['f64'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],
  'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
../_images/796586be87180b0de3171d39763f2d33a80a641b72d82c00f0c0e352f754f201.png

请注意,这(目前)还不是一个真正的格:有许多节点对不存在并运算(join)。然而,我们可以将其视为一个部分格(partial lattice),其中某些节点对没有定义提升行为,而这个部分格的已定义部分确实正确描述了 NumPy 的数组提升行为(撇开上述提到的值依赖语义不谈)。

这建立了一个很好的框架,我们可以通过在该图上增加连接来思考如何完善这些未定义的提升规则。但应该增加哪些连接呢?广义上讲,我们希望任何额外的连接都能满足几个属性

  1. 提升应满足交换律和结合律:换句话说,该图应保持为一个(部分)格。

  2. 提升绝不应导致丢失数据的整体组件:例如,我们绝不应将 complex 提升为 float,因为它会丢弃任何虚部。

  3. 提升绝不应导致未处理的溢出。例如,最大的 uint32 是最大 int32 的两倍,因此我们不应隐式地将 uint32 提升为 int32

  4. 尽可能地,提升应避免精度损失。例如,int64 值可能拥有 64 位的尾数,因此将 int64 提升为 float64 代表了潜在的精度损失。然而,可表示的最大 float64 大于可表示的最大 int64,因此在这种情况下,准则 #3 仍然满足。

  5. 尽可能地,二元提升应避免产生比输入更宽的类型。这是为了确保 JAX 的隐式提升对基于加速器的工作流保持友好,用户通常希望将类型限制为 32 位(或在某些情况下为 16 位)值。

格上的每个新连接都会给用户带来一定程度的便利(无需显式转换即可交互的一组新类型),但如果违反了上述任何准则,这种便利可能会变得代价高昂。开发一个完整的提升格需要在这种便利性和代价之间取得平衡。

混合提升:浮点数与复数#

让我们从也许是最简单的情况开始,即浮点数和复数之间的提升。

复数由一对浮点数组成,因此我们拥有它们之间自然的提升路径:将浮点数转换为复数,同时保持实部的宽度。在我们的部分格表示法中,它看起来像这样

隐藏代码单元源

#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16'], 'u16': ['u32'], 'u32': ['u64'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],
  'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],
  'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
../_images/bf87909b2344aed80590d1c6d91585a02b25898ac217526cb49948d91205318f.png

事实证明,这准确地代表了 Numpy 在混合浮点数/复数类型提升中所使用的语义。

混合提升:有符号与无符号整数#

对于下一个情况,让我们考虑一些更困难的问题:有符号和无符号整数之间的提升。例如,当将 uint8 提升为有符号整数时,我们需要多少位?

乍一看,你可能觉得将 uint8 提升为 int8 很自然;但最大的 uint8 数字无法在 int8 中表示。因此,将无符号整数提升为具有两倍位数的整数更有意义;这种提升行为可以通过在提升格中添加以下连接来表示

隐藏代码单元源

#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],
  'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],
  'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
../_images/3be7e17889458ac823bb5dacf31525c0d96578c6854962f45dcc60ec987a30bd.png

同样,这里添加的连接正是 Numpy 为混合整数提升所实现的提升语义。

如何处理 uint64#

混合有符号/无符号整数提升的方法遗漏了一种类型:uint64。遵循上述模式,涉及 uint64 的混合整数运算的输出应导致 int128,但这并不是一种标准的可用 dtype。

Numpy 在此处的选择是提升为 float64

(np.uint64(1) + np.int64(1)).dtype
dtype('float64')

然而,这可能是一个令人惊讶的惯例:这是唯一一种整数类型的提升不产生整数的情况。目前,我们将 uint64 的提升保持为未定义,稍后再回到这个问题。

混合提升:整数与浮点数#

在将整数提升至浮点数时,我们可能会从与有符号和无符号整数混合提升相同的思考过程开始。16 位有符号或无符号整数无法以全精度由 16 位浮点数(仅有 10 位尾数)表示。因此,将整数提升为具有两倍位数表示的浮点数可能是有意义的

隐藏代码单元源

#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16', 'i16', 'f16'], 'u16': ['u32', 'i32', 'f32'], 'u32': ['u64', 'i64', 'f64'],
  'i8': ['i16', 'f16'], 'i16': ['i32', 'f32'], 'i32': ['i64', 'f64'],
  'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],
  'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
../_images/8b3247e8189fbfad46a7e5583b636866fc45576e07c9bfd904457926306299d1.png

这实际上就是 Numpy 类型提升所做的,但在这样做时,它打破了图的格属性:例如,对 {i8, u8} 而言不再存在唯一的最小上界:可能性是 i16f16,它们在图上无法排序。事实证明,这就是上述 NumPy 非结合类型提升的来源。

我们能否对 NumPy 的提升规则进行修改,使其在满足格属性的同时,又能为混合类型提升给出合理的结果?我们可以在这里采取几种方法。

方案 0:保持整数/浮点数混合精度未定义#

为了使行为完全可预测(以牺牲一定的用户便利性为代价),一个可辩护的选择是,将除 Python 标量之外的任何混合整数/浮点数提升保持为未定义,止步于上一节的部分格。缺点是用户在操作整数和浮点数时需要显式地进行类型转换。

方案 1:避免所有精度损失#

如果我们专注于不惜一切代价避免精度损失,我们可以通过将无符号整数通过其现有的有符号整数路径提升为浮点数来恢复格属性

隐藏代码单元源

#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],
  'i8': ['i16', 'f16'], 'i16': ['i32', 'f32'], 'i32': ['i64', 'f64'],
  'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],
  'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
../_images/1eda89d008a8c6dadf926229bf9f2245722006c5bc1c42961c555a2595c95117.png

这种方法的一个缺点是它仍然将 int64uint64 的提升保持为未定义,因为没有标准浮点类型具有足够的尾数位来表示它们的全值范围。我们可以放宽精度约束,通过从 i64->f64u64->f64 绘制连接来补全格,但这些链接将与该提升方案的动机背道而驰。

第二个缺点是,这个格使得在保持格属性的同时,难以找到一个合理的位置来插入 bfloat16(见下文)。

这种方法的第三个缺点(对 JAX 的加速器后端更为重要)是,某些操作产生的结果类型比必要的宽得多;例如 uint16float16 之间的混合操作将一直提升到 float64,这并不理想。

方案 2:避免大多数不必要的更宽类型提升#

为了解决向更宽类型不必要提升的问题,我们可以接受在整数/浮点数提升中可能出现的一些精度损失,将有符号整数提升为相同宽度的浮点数

隐藏代码单元源

#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],
  'i8': ['i16'], 'i16': ['f16', 'i32'], 'i32': ['f32', 'i64'], 'i64': ['f64'],
  'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [1.5, 2], 'f32': [2.5, 2], 'f64': [3.5, 2],
  'c64': [3, 3], 'c128': [4, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
../_images/f41cee38a476bf636be901e7f64a5dc3687002f9d12532ab706b9077d602b175.png

虽然这确实允许在整数和浮点数之间进行精度损失型的提升,但这些提升不会误导结果的量级(magnitude):尽管浮点尾数不够宽以表示所有值,但指数足够宽以对它们进行近似。

这种方法还允许从 int64float64 的自然提升路径,尽管 uint64 在此方案中仍然不可提升。话虽如此,与之前相比,从 u64f64 的连接在这里更容易得到证明。

这种提升方案仍然导致了一些不必要宽的提升路径;例如 float32uint32 之间的运算结果为 float64。此外,这个格使得在保持格属性的同时,难以找到一个合理的位置来插入 bfloat16(见下文)。

方案 3:避免所有不必要的更宽类型提升#

如果我们愿意彻底改变对整数和浮点数提升的看法,我们就可以避免所有不理想的 64 位提升。正如标量总是遵循数组类型的宽度一样,我们可以使整数总是遵循浮点类型的宽度

隐藏代码单元源

#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'], 'i64': ['f*'],
  'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [1.5, 2], 'f32': [2.5, 2], 'f64': [3.5, 2],
  'c64': [3, 3], 'c128': [4, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
../_images/d3f5e5be4354238a60698cb4f228d4e1f75a665577343c36b2c1ade1207783a0.png

这涉及一个小小的技巧:以前我们使用 f* 来指代标量类型。在这个格中,f* 可能被应用于混合计算的数组输出。我们不是将 f* 看作一个标量,而是将其看作一种具有独特提升规则的特殊 float 值:在 JAX 中,我们将其称为弱浮点数(weak float);见下文。

这种方法的好处是,除了无符号整数外,它避免了所有不必要更宽的提升:你不可能在没有 64 位输入的情况下获得 f64 输出,也不可能在没有 32 位输入的情况下获得 f32 输出:这为在加速器上工作提供了便利的语义,同时避免了意外的 64 位值。

这种赋予浮点类型主要地位的特性类似于 PyTorch 的类型提升行为。这个格碰巧也生成了一个非常接近 JAX 最初临时(ad hoc)类型提升方案的提升表,该方案虽然不是基于格的,但具有赋予浮点类型主要地位的属性。

这个格还提供了一个自然的插入 bfloat16 的位置,而无需在 bf16f16 之间强制排序

隐藏代码单元源

#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['u8', 'i8'], 'f*': ['c*', 'f16', 'bf16'], 'c*': ['c64'],
  'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'], 'i64': ['f*'],
  'f16': ['f32'], 'bf16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [1.8, 1.7], 'bf16': [1.8, 2.3], 'f32': [3.0, 2], 'f64': [4.0, 2],
  'c64': [3.5, 3], 'c128': [4.5, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
../_images/aa73688b580b02776fce218d6efe58792ae3b0976160a4b0c130b797780578af.png

这一点很重要,因为 f16bf16 不可比较,因为它们利用位的方式不同:bf16 以较低的精度表示较大的范围,而 f16 以较高的精度表示较小的范围。

然而,这些优点带来了一些权衡

  • 混合浮点数/整数提升非常容易导致精度损失:例如,int64(最大值为 \(9.2 \times 10^{18}\))可以提升为 float16(最大值为 \(6.5 \times 10^4\)),这意味着大多数可表示的值将变为 inf

  • 如上所述,f* 不能再被视为“标量类型”,而是 float64 的不同变体。在 JAX 的用语中,这被称为弱类型(weak type),因为它被表示为 64 位,但在与其他值进行提升时,仅微弱地保持此位宽。

请注意,这种方法仍然没有回答 uint64 的提升问题,尽管通过将 u64 连接到 f* 来闭合该格也许是合理的。

JAX 中的类型提升#

在设计 JAX 的类型提升语义时,我们牢记了其中许多想法,并很大程度上依赖于几点

  1. 我们选择将 JAX 的类型提升语义限制在满足格属性的图上:这是为了确保结合律和交换律,同时也为了允许语义在 DAG 中紧凑地描述,而不是需要一个巨大的表格。

  2. 我们倾向于避免意外提升到更宽类型的语义,特别是在涉及浮点值时,以便有利于加速器上的计算。

  3. 我们愿意接受混合类型提升中的潜在精度损失(但不接受量级损失),如果这是满足 (1) 和 (2) 所必需的话

考虑到这一点,JAX 采用了方案 3。或者更准确地说,是一个稍微修改过的方案 3 版本,它在 u64f* 之间绘制了连接,以创建一个真正的格。为了清晰起见,重新排列节点后,JAX 的类型提升格看起来像这样

隐藏代码单元源

#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['u8', 'i8'], 'f*': ['c*', 'f16', 'bf16'], 'c*': ['c64'],
  'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'], 'u64': ['f*'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'], 'i64': ['f*'],
  'f16': ['f32'], 'bf16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [4.5, 0.5], 'c*': [5, 1.5],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [5.75, 0.8], 'bf16': [5.75, 0.2], 'f32': [7, 0.5], 'f64': [8, 0.5],
  'c64': [7.5, 1.5], 'c128': [8.5, 1.5],
}
fig, ax = plt.subplots(figsize=(10, 4))
ax.set_ylim(-0.5, 2)
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
# ax.patches[12].set_linestyle((0, (2, 4)))
../_images/d261add493a579484d9772634ce146f1240af3966d0845839c354417a3de2e53.png

由此选择产生的行为总结在 JAX 类型提升语义 中。值得注意的是,除了包含较大的无符号类型(u16, u32, u64)以及关于标量/弱类型(i*, f*, c*)的一些行为细节外,该类型提升方案最终与 PyTorch 所选择的方案非常接近。

对于有兴趣的人,下方的附录打印了 NumPy、Tensorflow、PyTorch 和 JAX 使用的完整提升表。

附录:类型提升表示例#

以下是各种 Python 数组计算库实现的一些隐式类型提升表示例。

NumPy 类型提升#

请注意,NumPy 不包含 bfloat16 dtype,下表忽略了值依赖的影响。

隐藏代码单元源

# @title

import numpy as np
import pandas as pd
from IPython import display

np_dtypes = {
  'b': np.bool_,
  'u8': np.uint8, 'u16': np.uint16, 'u32': np.uint32, 'u64': np.uint64,
  'i8': np.int8, 'i16': np.int16, 'i32': np.int32, 'i64': np.int64,
  'bf16': 'invalid', 'f16': np.float16, 'f32': np.float32, 'f64': np.float64,
  'c64': np.complex64, 'c128': np.complex128,
  'i*': int, 'f*': float, 'c*': complex}

np_dtype_to_code = {val: key for key, val in np_dtypes.items()}

def make_np_zero(dtype):
  if dtype in {int, float, complex}:
    return dtype(0)
  else:
    return np.zeros(1, dtype=dtype)

def np_result_code(dtype1, dtype2):
  try:
    out = np.add(make_np_zero(dtype1), make_np_zero(dtype2))
  except TypeError:
    return '-'
  else:
    if type(out) in {int, float, complex}:
      return np_dtype_to_code[type(out)]
    else:
      return np_dtype_to_code[out.dtype.type]


grid = [[np_result_code(dtype1, dtype2)
         for dtype2 in np_dtypes.values()]
        for dtype1 in np_dtypes.values()]
table = pd.DataFrame(grid, index=np_dtypes.keys(), columns=np_dtypes.keys())
display.HTML(table.to_html())
b u8 u16 u32 u64 i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i* f* c*
b b u8 u16 u32 u64 i8 i16 i32 i64 - f16 f32 f64 c64 c128 i64 f64 c128
u8 u8 u8 u16 u32 u64 i16 i16 i32 i64 - f16 f32 f64 c64 c128 u8 f64 c128
u16 u16 u16 u16 u32 u64 i32 i32 i32 i64 - f32 f32 f64 c64 c128 u16 f64 c128
u32 u32 u32 u32 u32 u64 i64 i64 i64 i64 - f64 f64 f64 c128 c128 u32 f64 c128
u64 u64 u64 u64 u64 u64 f64 f64 f64 f64 - f64 f64 f64 c128 c128 u64 f64 c128
i8 i8 i16 i32 i64 f64 i8 i16 i32 i64 - f16 f32 f64 c64 c128 i8 f64 c128
i16 i16 i16 i32 i64 f64 i16 i16 i32 i64 - f32 f32 f64 c64 c128 i16 f64 c128
i32 i32 i32 i32 i64 f64 i32 i32 i32 i64 - f64 f64 f64 c128 c128 i32 f64 c128
i64 i64 i64 i64 i64 f64 i64 i64 i64 i64 - f64 f64 f64 c128 c128 i64 f64 c128
bf16 - - - - - - - - - - - - - - - - - -
f16 f16 f16 f32 f64 f64 f16 f32 f64 f64 - f16 f32 f64 c64 c128 f16 f16 c64
f32 f32 f32 f32 f64 f64 f32 f32 f64 f64 - f32 f32 f64 c64 c128 f32 f32 c64
f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 - f64 f64 f64 c128 c128 f64 f64 c128
c64 c64 c64 c64 c128 c128 c64 c64 c128 c128 - c64 c64 c128 c64 c128 c64 c64 c64
c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 - c128 c128 c128 c128 c128 c128 c128 c128
i* i64 u8 u16 u32 u64 i8 i16 i32 i64 - f16 f32 f64 c64 c128 i64 f64 c128
f* f64 f64 f64 f64 f64 f64 f64 f64 f64 - f16 f32 f64 c64 c128 f64 f64 c128
c* c128 c128 c128 c128 c128 c128 c128 c128 c128 - c64 c64 c128 c64 c128 c128 c128 c128

Tensorflow 类型提升#

除了在有限情况下的 Python 标量外,Tensorflow 避免定义隐式类型提升。该表是不对称的,因为在 tf.add(x, y) 中,y 的类型必须能够强制转换为 x 的类型。

隐藏代码单元源

# @title

import tensorflow as tf
import pandas as pd
from IPython import display

tf_dtypes = {
  'b': tf.bool,
  'u8': tf.uint8, 'u16': tf.uint16, 'u32': tf.uint32, 'u64': tf.uint64,
  'i8': tf.int8, 'i16': tf.int16, 'i32': tf.int32, 'i64': tf.int64,
  'bf16': tf.bfloat16, 'f16': tf.float16, 'f32': tf.float32, 'f64': tf.float64,
  'c64': tf.complex64, 'c128': tf.complex128,
  'i*': int, 'f*': float, 'c*': complex}

tf_dtype_to_code = {val: key for key, val in tf_dtypes.items()}

def make_tf_zero(dtype):
  if dtype in {int, float, complex}:
    return dtype(0)
  else:
    return tf.zeros(1, dtype=dtype)

def result_code(dtype1, dtype2):
  try:
    out = tf.add(make_tf_zero(dtype1), make_tf_zero(dtype2))
  except (TypeError, tf.errors.InvalidArgumentError):
    return '-'
  else:
    if type(out) in {int, float, complex}:
      return tf_dtype_to_code[type(out)]
    else:
      return tf_dtype_to_code[out.dtype]


grid = [[result_code(dtype1, dtype2)
         for dtype2 in tf_dtypes.values()]
        for dtype1 in tf_dtypes.values()]
table = pd.DataFrame(grid, index=tf_dtypes.keys(), columns=tf_dtypes.keys())
display.HTML(table.to_html())
b u8 u16 u32 u64 i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i* f* c*
b - - - - - - - - - - - - - - - - - -
u8 - u8 - - - - - - - - - - - - - u8 - -
u16 - - u16 - - - - - - - - - - - - u16 - -
u32 - - - u32 - - - - - - - - - - - u32 - -
u64 - - - - u64 - - - - - - - - - - u64 - -
i8 - - - - - i8 - - - - - - - - - i8 - -
i16 - - - - - - i16 - - - - - - - - i16 - -
i32 - - - - - - - i32 - - - - - - - i32 - -
i64 - - - - - - - - i64 - - - - - - i64 - -
bf16 - - - - - - - - - bf16 - - - - - bf16 bf16 -
f16 - - - - - - - - - - f16 - - - - f16 f16 -
f32 - - - - - - - - - - - f32 - - - f32 f32 -
f64 - - - - - - - - - - - - f64 - - f64 f64 -
c64 - - - - - - - - - - - - - c64 - c64 c64 c64
c128 - - - - - - - - - - - - - - c128 c128 c128 c128
i* - - - - - - - i32 - - - - - - - i32 - -
f* - - - - - - - - - - - f32 - - - f32 f32 -
c* - - - - - - - - - - - - - - c128 c128 c128 c128

PyTorch 类型提升#

注意,torch 不包含大于 uint8 的无符号整数类型。除了这一点以及关于与标量/弱类型提升的一些细节外,该表与 jax.numpy 使用的非常接近。

隐藏代码单元源

# @title
import torch
import pandas as pd
from IPython import display

torch_dtypes = {
  'b': torch.bool,
  'u8': torch.uint8, 'u16': 'invalid', 'u32': 'invalid', 'u64': 'invalid',
  'i8': torch.int8, 'i16': torch.int16, 'i32': torch.int32, 'i64': torch.int64,
  'bf16': torch.bfloat16, 'f16': torch.float16, 'f32': torch.float32, 'f64': torch.float64,
  'c64': torch.complex64, 'c128': torch.complex128,
  'i*': int, 'f*': float, 'c*': complex}

torch_dtype_to_code = {val: key for key, val in torch_dtypes.items()}

def make_torch_zero(dtype):
  if dtype in {int, float, complex}:
    return dtype(0)
  else:
    return torch.zeros(1, dtype=dtype)

def torch_result_code(dtype1, dtype2):
  try:
    out = torch.add(make_torch_zero(dtype1), make_torch_zero(dtype2))
  except TypeError:
    return '-'
  else:
    if type(out) in {int, float, complex}:
      return torch_dtype_to_code[type(out)]
    else:
      return torch_dtype_to_code[out.dtype]


grid = [[torch_result_code(dtype1, dtype2)
         for dtype2 in torch_dtypes.values()]
        for dtype1 in torch_dtypes.values()]
table = pd.DataFrame(grid, index=torch_dtypes.keys(), columns=torch_dtypes.keys())
display.HTML(table.to_html())
b u8 u16 u32 u64 i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i* f* c*
b b u8 - - - i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i64 f32 c64
u8 u8 u8 - - - i16 i16 i32 i64 bf16 f16 f32 f64 c64 c128 u8 f32 c64
u16 - - - - - - - - - - - - - - - - - -
u32 - - - - - - - - - - - - - - - - - -
u64 - - - - - - - - - - - - - - - - - -
i8 i8 i16 - - - i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i8 f32 c64
i16 i16 i16 - - - i16 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i16 f32 c64
i32 i32 i32 - - - i32 i32 i32 i64 bf16 f16 f32 f64 c64 c128 i32 f32 c64
i64 i64 i64 - - - i64 i64 i64 i64 bf16 f16 f32 f64 c64 c128 i64 f32 c64
bf16 bf16 bf16 - - - bf16 bf16 bf16 bf16 bf16 f32 f32 f64 c64 c128 bf16 bf16 c64
f16 f16 f16 - - - f16 f16 f16 f16 f32 f16 f32 f64 c64 c128 f16 f16 c64
f32 f32 f32 - - - f32 f32 f32 f32 f32 f32 f32 f64 c64 c128 f32 f32 c64
f64 f64 f64 - - - f64 f64 f64 f64 f64 f64 f64 f64 c128 c128 f64 f64 c128
c64 c64 c64 - - - c64 c64 c64 c64 c64 c64 c64 c128 c64 c128 c64 c64 c64
c128 c128 c128 - - - c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128
i* i64 u8 - - - i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i64 f32 c64
f* f32 f32 - - - f32 f32 f32 f32 bf16 f16 f32 f64 c64 c128 f32 f64 c64
c* c64 c64 - - - c64 c64 c64 c64 c64 c64 c64 c128 c64 c128 c64 c64 c128

JAX 类型提升:jax.numpy#

jax.numpy 遵循 https://jax.net.cn/en/latest/type_promotion.html 中制定的类型提升规则。此处我们使用 i*, f*, c* 来表示 Python 标量和弱类型数组。

隐藏代码单元源

# @title
import jax
import jax.numpy as jnp
import pandas as pd
from IPython import display
jax.config.update('jax_enable_x64', True)

jnp_dtypes = {
  'b': jnp.bool_.dtype,
  'u8': jnp.uint8.dtype, 'u16': jnp.uint16.dtype, 'u32': jnp.uint32.dtype, 'u64': jnp.uint64.dtype,
  'i8': jnp.int8.dtype, 'i16': jnp.int16.dtype, 'i32': jnp.int32.dtype, 'i64': jnp.int64.dtype,
  'bf16': jnp.bfloat16.dtype, 'f16': jnp.float16.dtype, 'f32': jnp.float32.dtype, 'f64': jnp.float64.dtype,
  'c64': jnp.complex64.dtype, 'c128': jnp.complex128.dtype,
  'i*': int, 'f*': float, 'c*': complex}


jnp_dtype_to_code = {val: key for key, val in jnp_dtypes.items()}

def make_jnp_zero(dtype):
  if dtype in {int, float, complex}:
    return dtype(0)
  else:
    return jnp.zeros((), dtype=dtype)

def jnp_result_code(dtype1, dtype2):
  try:
    out = jnp.add(make_jnp_zero(dtype1), make_jnp_zero(dtype2))
  except TypeError:
    return '-'
  else:
    if hasattr(out, 'aval') and out.aval.weak_type:
      return out.dtype.kind + '*'
    elif type(out) in {int, float, complex}:
      return jnp_dtype_to_code[type(out)]
    else:
      return jnp_dtype_to_code[out.dtype]

grid = [[jnp_result_code(dtype1, dtype2)
         for dtype2 in jnp_dtypes.values()]
        for dtype1 in jnp_dtypes.values()]
table = pd.DataFrame(grid, index=jnp_dtypes.keys(), columns=jnp_dtypes.keys())
display.HTML(table.to_html())
b u8 u16 u32 u64 i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i* f* c*
b b u8 u16 u32 u64 i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i* f* c*
u8 u8 u8 u16 u32 u64 i16 i16 i32 i64 bf16 f16 f32 f64 c64 c128 u8 f* c*
u16 u16 u16 u16 u32 u64 i32 i32 i32 i64 bf16 f16 f32 f64 c64 c128 u16 f* c*
u32 u32 u32 u32 u32 u64 i64 i64 i64 i64 bf16 f16 f32 f64 c64 c128 u32 f* c*
u64 u64 u64 u64 u64 u64 f* f* f* f* bf16 f16 f32 f64 c64 c128 u64 f* c*
i8 i8 i16 i32 i64 f* i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i8 f* c*
i16 i16 i16 i32 i64 f* i16 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i16 f* c*
i32 i32 i32 i32 i64 f* i32 i32 i32 i64 bf16 f16 f32 f64 c64 c128 i32 f* c*
i64 i64 i64 i64 i64 f* i64 i64 i64 i64 bf16 f16 f32 f64 c64 c128 i64 f* c*
bf16 bf16 bf16 bf16 bf16 bf16 bf16 bf16 bf16 bf16 bf16 f32 f32 f64 c64 c128 bf16 bf16 c64
f16 f16 f16 f16 f16 f16 f16 f16 f16 f16 f32 f16 f32 f64 c64 c128 f16 f16 c64
f32 f32 f32 f32 f32 f32 f32 f32 f32 f32 f32 f32 f32 f64 c64 c128 f32 f32 c64
f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 c128 c128 f64 f64 c128
c64 c64 c64 c64 c64 c64 c64 c64 c64 c64 c64 c64 c64 c128 c64 c128 c64 c64 c64
c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128
i* i* u8 u16 u32 u64 i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i* f* c*
f* f* f* f* f* f* f* f* f* f* bf16 f16 f32 f64 c64 c128 f* f* c*
c* c* c* c* c* c* c* c* c* c* c64 c64 c64 c128 c64 c128 c* c* c*

JAX 类型提升:jax.lax#

jax.lax 是更底层的,不做任何隐式提升。此处我们使用 i*, f*, c* 来表示 Python 标量和弱类型数组。

隐藏代码单元源

# @title
import jax
import jax.numpy as jnp
import pandas as pd
from IPython import display
jax.config.update('jax_enable_x64', True)

jnp_dtypes = {
  'b': jnp.bool_.dtype,
  'u8': jnp.uint8.dtype, 'u16': jnp.uint16.dtype, 'u32': jnp.uint32.dtype, 'u64': jnp.uint64.dtype,
  'i8': jnp.int8.dtype, 'i16': jnp.int16.dtype, 'i32': jnp.int32.dtype, 'i64': jnp.int64.dtype,
  'bf16': jnp.bfloat16.dtype, 'f16': jnp.float16.dtype, 'f32': jnp.float32.dtype, 'f64': jnp.float64.dtype,
  'c64': jnp.complex64.dtype, 'c128': jnp.complex128.dtype,
  'i*': int, 'f*': float, 'c*': complex}


jnp_dtype_to_code = {val: key for key, val in jnp_dtypes.items()}

def make_jnp_zero(dtype):
  if dtype in {int, float, complex}:
    return dtype(0)
  else:
    return jnp.zeros((), dtype=dtype)

def jnp_result_code(dtype1, dtype2):
  try:
    out = jax.lax.add(make_jnp_zero(dtype1), make_jnp_zero(dtype2))
  except TypeError:
    return '-'
  else:
    if hasattr(out, 'aval') and out.aval.weak_type:
      return out.dtype.kind + '*'
    elif type(out) in {int, float, complex}:
      return jnp_dtype_to_code[type(out)]
    else:
      return jnp_dtype_to_code[out.dtype]

grid = [[jnp_result_code(dtype1, dtype2)
         for dtype2 in jnp_dtypes.values()]
        for dtype1 in jnp_dtypes.values()]
table = pd.DataFrame(grid, index=jnp_dtypes.keys(), columns=jnp_dtypes.keys())
display.HTML(table.to_html())
b u8 u16 u32 u64 i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i* f* c*
b - - - - - - - - - - - - - - - - - -
u8 - u8 - - - - - - - - - - - - - - - -
u16 - - u16 - - - - - - - - - - - - - - -
u32 - - - u32 - - - - - - - - - - - - - -
u64 - - - - u64 - - - - - - - - - - - - -
i8 - - - - - i8 - - - - - - - - - - - -
i16 - - - - - - i16 - - - - - - - - - - -
i32 - - - - - - - i32 - - - - - - - - - -
i64 - - - - - - - - i64 - - - - - - i64 - -
bf16 - - - - - - - - - bf16 - - - - - - - -
f16 - - - - - - - - - - f16 - - - - - - -
f32 - - - - - - - - - - - f32 - - - - - -
f64 - - - - - - - - - - - - f64 - - - f64 -
c64 - - - - - - - - - - - - - c64 - - - -
c128 - - - - - - - - - - - - - - c128 - - c128
i* - - - - - - - - i64 - - - - - - i* - -
f* - - - - - - - - - - - - f64 - - - f* -
c* - - - - - - - - - - - - - - c128 - - c*