JAX 类型提升语义设计#
Jake VanderPlas,2021 年 12 月
任何数值计算库在设计时都会面临一个挑战:如何处理不同类型值之间的操作。本文档概述了 JAX 所用类型提升语义背后的思考过程,具体内容总结于 JAX 类型提升语义。
JAX 类型提升的目标#
JAX 的数值计算 API 借鉴了 NumPy 的模型,并进行了一些增强,包括支持 GPU 和 TPU 等加速器。这使得采用 NumPy 的类型提升系统对 JAX 用户不利:NumPy 的类型提升规则倾向于生成 64 位输出,这在加速器上进行计算时会带来问题。GPU 和 TPU 等设备在使用 64 位浮点类型时通常会付出显著的性能代价,在某些情况下甚至根本不支持原生的 64 位浮点类型。
这种有问题的类型提升语义的一个简单例子可以在 32 位整数和浮点数之间的二进制操作中看到
import numpy as np
np.dtype(np.int32(1) + np.float32(1))
dtype('float64')
NumPy 倾向于生成 64 位值的特性是使用 NumPy API 进行加速器计算的一个长期存在的问题,目前还没有很好的解决方案。因此,JAX 试图在考虑加速器的情况下重新思考 NumPy 风格的类型提升。
回顾:表格与格#
在我们深入细节之前,让我们花点时间退一步思考一下如何思考类型提升的问题。考虑 Python 内置数值类型之间的算术运算,即 int
、float
和 complex
类型。用几行代码,我们可以生成 Python 用于这些类型值之间加法的类型提升表
import pandas as pd
types = [int, float, complex]
name = lambda t: t.__name__
pd.DataFrame([[name(type(t1(1) + t2(1))) for t1 in types] for t2 in types],
index=[name(t) for t in types], columns=[name(t) for t in types])
int | float | complex | |
---|---|---|---|
int | int | float | complex |
float | float | float | complex |
complex | complex | complex | complex |
这个表格列举了 Python 的数值类型提升行为,但事实证明,存在一种更紧凑的补充表示方法:一种格表示,其中任意两个节点之间的上确界就是它们提升到的类型。Python 提升表的格表示法要简单得多

这个格是上述提升表中信息的紧凑编码。你可以通过沿着图追溯到两个节点的第一个共同子节点(包括节点本身)来找到两个输入的类型提升结果;在数学上,这个共同子节点被称为格中一对元素的上确界,或最小上界,或连接;在这里我们将把这种操作称为 **连接**。
概念上,箭头表示源和目标之间允许隐式类型提升:例如,允许从整数隐式提升到浮点数,但不允许从浮点数隐式提升到整数。
请记住,通常并非每个有向无环图(DAG)都满足格的特性。格要求每对节点都存在唯一的最小上界;因此,例如下面两个 DAG 就不是格

左边的 DAG 不是格,因为它不存在节点 B
和 C
的上界;右边的 DAG 有两处失败:首先,节点 C
和 D
不存在上界,其次,对于节点 A
和 B
,最小上界无法唯一确定:C
和 D
都是候选者,但它们无法排序。
类型提升格的特性#
用格来指定类型提升确保了许多有用的特性。用 \(\vee\) 运算符表示格上的连接,我们有
存在性:根据定义,格要求每对元素都存在唯一的格连接:\(\forall (a, b): \exists !(a \vee b)\)
交换律:格连接是可交换的:\(\forall (a, b): a\vee b = b \vee a\)。
结合律:格连接是可结合的:\(\forall (a, b, c): a \vee (b \vee c) = (a \vee b) \vee c\)。
另一方面,这些特性意味着它们所能表示的类型提升系统存在限制;特别是**并非每个类型提升表都可以用格来表示**。一个现成的例子是 NumPy 的完整类型提升表;这可以通过反例迅速证明:这里有三种标量类型,它们在 NumPy 中的提升行为是非结合的
import numpy as np
a, b, c = np.int8(1), np.uint8(1), np.float16(1)
print(np.dtype((a + b) + c))
print(np.dtype(a + (b + c)))
float32
float16
这样的结果可能会让用户感到惊讶:我们通常期望数学表达式映射到数学概念,例如,a + b + c
应该等价于 c + b + a
;x * (y + z)
应该等价于 x * y + x * z
。如果类型提升是非结合的或非交换的,这些属性就不再适用。
此外,与基于表格的系统相比,基于格的类型提升系统更易于概念化和理解。例如,JAX 识别 18 种不同类型:一个由 18 个节点和它们之间稀疏、有充分理由的连接组成的提升格,比一个包含 324 个条目的表格更容易让人理解和记忆。
因此,我们选择为 JAX 使用基于格的类型提升系统。
类别内的类型提升#
数值计算库通常提供的不止是 int
、float
和 complex
;在每个类别中,都有各种可能的精度,由数值表示中使用的位数表示。我们在这里将考虑的类别有
无符号整数,包括
uint8
、uint16
、uint32
和uint64
(我们简称u8
、u16
、u32
、u64
)有符号整数,包括
int8
、int16
、int32
和int64
(我们简称i8
、i16
、i32
、i64
)浮点数,包括
float16
、float32
和float64
(我们简称f16
、f32
、f64
)复浮点数,包括
complex64
和complex128
(我们简称c64
、c128
)
NumPy 在这四个类别**内部**的类型提升语义相对简单:类型的有序层次结构直接转化为四个独立的格,代表类别内的类型提升规则

就 JAX 力求避免的值提升到 64 位而言,这些同类类型类别内的提升语义没有问题:产生 64 位输出的唯一方法是拥有 64 位输入。
引入 Python 标量#
现在我们来思考一下 Python 标量是如何融入其中的。
在 NumPy 中,提升行为因输入是数组还是标量而异。例如,当对两个标量进行操作时,适用正常的提升规则
x = np.int8(0) # int8 scalar
y = 1 # Python int = int64 scalar
(x + y).dtype
dtype('int64')
这里 Python 值 1
被视为 int64
,简单的类别内规则会得到 int64
结果。
然而,在 Python 标量和 NumPy 数组之间的操作中,标量会遵从数组的 dtype。例如
x = np.zeros(1, dtype='int8') # int8 array
y = 1 # Python int = int64 scalar
(x + y).dtype
dtype('int8')
这里 int64
标量的位宽被忽略,而遵从数组的位宽。
这里还有一个细节:当 NumPy 类型提升涉及标量时,输出 dtype 是值相关的:如果 Python 标量对于给定的 dtype 太大,它会被提升到一个兼容的类型
x = np.zeros(1, dtype='int8') # int8 array
y = 1000 # int64 scalar
(x + y).dtype
dtype('int16')
对于 JAX 而言,**值相关的提升是不可接受的**,因为 JIT 编译和其他转换的性质,它们在不参照数据值的情况下对数据的抽象表示进行操作。
忽略值相关的影响,NumPy 类型提升的有符号整数分支可以用以下格表示,其中我们用 *
标记标量 dtype

在 uint
、float
和 complex
格中也存在类似的模式。
为了简单起见,我们将每个类别的标量类型合并为一个节点,分别用 u*
、i*
、f*
和 c*
表示。现在我们的类别内格集合可以这样表示

从某种意义上说,将标量放在左侧是一个奇怪的选择:标量类型可以包含任意宽度的数据,但当与给定类型的数组交互时,提升结果会遵从数组类型。这样做的好处是,当你对数组 x
执行 x + 2
这样的操作时,无论 x
的位宽如何,其类型都会传递给结果。
for dtype in [np.int8, np.int16, np.int32, np.int64]:
x = np.arange(10, dtype=dtype)
assert (x + 2).dtype == dtype
这种行为为我们对标量值使用 *
符号提供了依据:*
让人联想到可以取任何所需值的通配符。
这些语义的好处是,你可以轻松地用简洁的 Python 代码表达一系列操作,而无需显式地将标量转换为适当的类型。想象一下,如果不是这样写
3 * (x + 1) ** 2
你必须这样写
np.int32(3) * (x + np.int32(1)) ** np.int32(2)
尽管它很明确,但数值代码会变得难以阅读或编写。有了上述的标量提升语义,给定一个 int32
类型的数组 x
,第二条语句中的类型在第一条语句中是隐式的。
组合格#
回想一下,我们最初通过引入表示 Python 内类型提升的格开始讨论:int ->
float
-> complex
。让我们将其重写为 i* ->
f*
-> c*
,并进一步允许 i*
包含 u*
(毕竟,Python 中没有无符号整数标量类型)。
将所有这些放在一起,我们得到了以下表示 Python 标量和 NumPy 数组之间类型提升的部分格

请注意,这(还)不是一个真正的格:有许多节点对没有连接。但是,我们可以将其视为一个部分格,其中某些节点对没有定义的提升行为,而这个部分格的已定义部分确实正确描述了 NumPy 的数组提升行为(暂且不提上述值相关的语义)。
这建立了一个很好的框架,我们可以通过在此图中添加连接来填充这些未定义的提升规则。但是要添加哪些连接呢?广义地说,我们希望任何额外的连接都能满足以下几个特性
提升应满足交换律和结合律:换句话说,图应该保持一个(部分)格。
提升绝不应该允许丢失数据的完整组成部分:例如,我们绝不应该将
complex
提升为float
,因为它会丢弃任何虚部。提升绝不应该导致未处理的溢出。例如,
uint32
的最大值是int32
最大值的两倍,因此我们不应该隐式地将uint32
提升到int32
。在可能的情况下,提升应避免精度损失。例如,一个
int64
值可能有 64 位尾数,因此将int64
提升到float64
可能导致精度损失。然而,float64
的最大可表示值大于int64
的最大可表示值,所以在这种情况下,标准 #3 仍然得到满足。在可能的情况下,二进制提升应避免产生比输入更宽的类型。这是为了确保 JAX 的隐式提升对基于加速器的工作流保持友好,在这种工作流中,用户通常希望将类型限制为 32 位(或在某些情况下为 16 位)值。
格上的每个新连接都为用户带来了一定程度的便利(一组无需显式类型转换即可交互的新类型),但如果违反了上述任何一个标准,这种便利可能会变得代价过高。开发一个完整的提升格涉及在便利性和成本之间取得平衡。
混合提升:浮点数与复数#
让我们从最简单的情况开始,即浮点数和复数值之间的提升。
复数由成对的浮点数组成,因此它们之间存在一个自然的提升路径:将浮点数转换为复数,同时保持实部的宽度。就我们的部分格表示法而言,它将是这样的

这结果正好代表了 NumPy 在混合浮点数/复数类型提升中使用的语义。
混合提升:有符号与无符号整数#
对于下一个案例,让我们考虑一个稍微困难一点的问题:有符号整数和无符号整数之间的提升。例如,当将 uint8
提升为有符号整数时,我们需要多少位?
乍一看,你可能会认为将 uint8
提升为 int8
是很自然的;但是最大的 uint8
数字无法在 int8
中表示。因此,将无符号整数提升为位数为其两倍的整数更有意义;这种提升行为可以通过在提升格中添加以下连接来表示

同样,这里添加的连接正是 NumPy 为混合整数提升所实现的提升语义。
如何处理 uint64
?#
混合有符号/无符号整数提升的方法遗漏了一种类型:uint64
。按照上述模式,涉及 uint64
的混合整数操作的输出应该是 int128
,但这并不是一个标准的可用 dtype。
NumPy 在此的选择是提升到 float64
(np.uint64(1) + np.int64(1)).dtype
dtype('float64')
然而,这可能是一个令人惊讶的惯例:这是整数类型提升不会产生整数的唯一情况。目前,我们将 uint64
的提升保留为未定义,稍后再讨论。
混合提升:整数与浮点数#
当将整数提升为浮点数时,我们可能会从与有符号和无符号整数之间的混合提升相同的思考过程开始。一个 16 位有符号或无符号整数无法通过 16 位浮点数以完整精度表示,因为 16 位浮点数只有 10 位尾数。因此,将整数提升为位数为其两倍的浮点数可能更有意义

这实际上是 NumPy 类型提升所做的事情,但这样做破坏了图的格特性:例如,{i8, u8} 这对不再有唯一的最小上界:可能的选项是 i16 和 f16,它们在图中无法排序。这正是上述 NumPy 非结合类型提升的根源。
我们能否修改 NumPy 的提升规则,使其满足格的特性,同时为混合类型提升提供合理的结果?这里有几种方法可以尝试。
选项 0:将整数/浮点数混合精度提升保留为未定义#
为了使行为完全可预测(但会牺牲一些用户便利性),一个可取的选择是将除 Python 标量之外的任何混合整数/浮点提升保留为未定义,并停止在上一节中的部分格。缺点是用户在整数和浮点数之间操作时需要显式地进行类型转换。
选项 1:避免所有精度损失#
如果我们的重点是无论如何都要避免精度损失,我们可以通过将无符号整数通过其现有的有符号整数路径提升为浮点数来恢复格的特性

这种方法的一个缺点是它仍然将 int64
和 uint64
的提升保留为未定义,因为没有足够尾数位的标准浮点类型来表示它们的完整值范围。我们可以放宽精度限制,通过从 i64->f64
和 u64->f64
建立连接来完成格,但这些连接将与此提升方案的初衷背道而驰。
第二个缺点是,这个格使得在保持格特性的同时,很难找到一个合理的位置来插入 bfloat16
(见下文)。
这种方法的第三个缺点,对于 JAX 的加速器后端来说更为重要,就是某些操作会导致比必要更宽的类型;例如,uint16
和 float16
之间的混合操作会将类型一直提升到 float64
,这并不理想。
选项 2:避免大部分不必要的宽度提升#
为了解决不必要的向更宽类型的提升,我们可以接受整数/浮点数提升中可能出现的精度损失,将有符号整数提升为相同宽度的浮点数

虽然这确实允许整数和浮点数之间进行损失精度的提升,但这些提升不会错误地表示结果的大小:尽管浮点数的尾数不足以表示所有值,但指数足够宽以近似它们。
这种方法也提供了一条从 int64
到 float64
的自然提升路径,尽管在此方案中 uint64
仍然无法提升。话虽如此,在这里从 u64
到 f64
的连接比以前更容易被证明是合理的。
这种提升方案仍然导致一些不必要的宽度提升路径;例如,float32
和 uint32
之间的操作会产生 float64
。此外,这个格使得在保持格特性的同时,很难找到一个合理的位置来插入 bfloat16
(见下文)。
选项 3:避免所有不必要的宽度提升#
如果我们愿意从根本上改变对整数和浮点数提升的思考方式,我们就可以避免所有非理想的 64 位提升。正如标量总是遵从数组类型的宽度一样,我们也可以让整数总是遵从浮点数类型的宽度

这需要一点巧妙的处理:之前我们用 f*
来指代标量类型。在这个格中,f*
可能会应用于混合计算的数组输出。与其将 f*
视为标量,不如将其视为一种具有独特提升规则的特殊 float
值:在 JAX 中,我们将其称为弱浮点数;详见下文。
这种方法的优点是,除了无符号整数外,它避免了所有不必要的宽度提升:没有 64 位输入,你永远不会得到 f64 输出;没有 32 位输入,你永远不会得到 f32 输出:这为在加速器上工作提供了便利的语义,同时避免了意外的 64 位值。
这种将浮点类型置于优先地位的特性类似于 PyTorch 的类型提升行为。这个格也恰好生成了一个与 JAX 原始的特殊类型提升方案非常相似的提升表,该方案并非基于格,但具有将浮点类型置于优先地位的特性。
此外,这个格提供了一个自然的位置来插入 bfloat16
,而无需在 bf16
和 f16
之间强制规定顺序

这很重要,因为 f16
和 bf16
不可比较,因为它们利用其位的方式不同:bf16
以较低的精度表示更大的范围,而 f16
以较高的精度表示较小的范围。
然而,这些优点也伴随着一些权衡
混合浮点数/整数提升非常容易导致精度损失:例如,
int64
(最大值为 \(9.2 \times 10^{18}\))可以提升到float16
(最大值为 \(6.5 \times 10^4\)),这意味着大多数可表示的值将变为inf
。如上所述,
f*
不能再被视为“标量类型”,而是 float64 的一种不同“风味”。在 JAX 的术语中,这被称为弱类型,因为它表示为 64 位,但在与其他值提升时,这种位宽只被弱持有。
另请注意,这种方法仍然没有回答 uint64
的提升问题,尽管通过连接 u64
到 f*
来封闭格或许是合理的。
JAX 中的类型提升#
在设计 JAX 的类型提升语义时,我们牢记了许多这些想法,并着重依赖于以下几点
我们选择将 JAX 的类型提升语义限制在满足格特性的图上:这既是为了确保结合律和交换律,也是为了让语义可以紧凑地在 DAG 中描述,而不是需要一个大型表格。
我们倾向于采用避免无意中提升到更宽类型的语义,尤其是在浮点值方面,以利于加速器上的计算。
如果需要保持(1)和(2),我们愿意接受混合类型提升中可能出现的精度损失(但不是量级损失)
考虑到这一点,JAX 采用了选项 3。或者更确切地说,是选项 3 的一个略微修改版本,它在 u64
和 f*
之间建立了连接,以创建一个真正的格。为了清晰起见,重新排列节点后,JAX 的类型提升格如下所示

这种选择所产生的行为总结在 JAX 类型提升语义中。值得注意的是,除了包含更大的无符号类型(u16
、u32
、u64
)以及一些关于标量/弱类型(i*
、f*
、c*
)行为的细节外,这种类型提升方案与 PyTorch 所选择的方案非常接近。
感兴趣的读者,下方的附录中列出了 NumPy、TensorFlow、PyTorch 和 JAX 所使用的完整提升表。
附录:类型提升示例表#
以下是各种 Python 数组计算库实现的隐式类型提升表的一些示例。
NumPy 类型提升#
请注意,NumPy 不包含 bfloat16
dtype,并且下表忽略了值相关的影响。
b | u8 | u16 | u32 | u64 | i8 | i16 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i* | f* | c* | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
b | b | u8 | u16 | u32 | u64 | i8 | i16 | i32 | i64 | - | f16 | f32 | f64 | c64 | c128 | i64 | f64 | c128 |
u8 | u8 | u8 | u16 | u32 | u64 | i16 | i16 | i32 | i64 | - | f16 | f32 | f64 | c64 | c128 | u8 | f64 | c128 |
u16 | u16 | u16 | u16 | u32 | u64 | i32 | i32 | i32 | i64 | - | f32 | f32 | f64 | c64 | c128 | u16 | f64 | c128 |
u32 | u32 | u32 | u32 | u32 | u64 | i64 | i64 | i64 | i64 | - | f64 | f64 | f64 | c128 | c128 | u32 | f64 | c128 |
u64 | u64 | u64 | u64 | u64 | u64 | f64 | f64 | f64 | f64 | - | f64 | f64 | f64 | c128 | c128 | u64 | f64 | c128 |
i8 | i8 | i16 | i32 | i64 | f64 | i8 | i16 | i32 | i64 | - | f16 | f32 | f64 | c64 | c128 | i8 | f64 | c128 |
i16 | i16 | i16 | i32 | i64 | f64 | i16 | i16 | i32 | i64 | - | f32 | f32 | f64 | c64 | c128 | i16 | f64 | c128 |
i32 | i32 | i32 | i32 | i64 | f64 | i32 | i32 | i32 | i64 | - | f64 | f64 | f64 | c128 | c128 | i32 | f64 | c128 |
i64 | i64 | i64 | i64 | i64 | f64 | i64 | i64 | i64 | i64 | - | f64 | f64 | f64 | c128 | c128 | i64 | f64 | c128 |
bf16 | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - |
f16 | f16 | f16 | f32 | f64 | f64 | f16 | f32 | f64 | f64 | - | f16 | f32 | f64 | c64 | c128 | f16 | f16 | c64 |
f32 | f32 | f32 | f32 | f64 | f64 | f32 | f32 | f64 | f64 | - | f32 | f32 | f64 | c64 | c128 | f32 | f32 | c64 |
f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | - | f64 | f64 | f64 | c128 | c128 | f64 | f64 | c128 |
c64 | c64 | c64 | c64 | c128 | c128 | c64 | c64 | c128 | c128 | - | c64 | c64 | c128 | c64 | c128 | c64 | c64 | c64 |
c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | - | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 |
i* | i64 | u8 | u16 | u32 | u64 | i8 | i16 | i32 | i64 | - | f16 | f32 | f64 | c64 | c128 | i64 | f64 | c128 |
f* | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | - | f16 | f32 | f64 | c64 | c128 | f64 | f64 | c128 |
c* | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | - | c64 | c64 | c128 | c64 | c128 | c128 | c128 | c128 |
TensorFlow 类型提升#
TensorFlow 避免定义隐式类型提升,除了在有限情况下对 Python 标量进行提升。该表是不对称的,因为在 tf.add(x, y)
中,y
的类型必须能够强制转换为 x
的类型。
b | u8 | u16 | u32 | u64 | i8 | i16 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i* | f* | c* | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
b | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - |
u8 | - | u8 | - | - | - | - | - | - | - | - | - | - | - | - | - | u8 | - | - |
u16 | - | - | u16 | - | - | - | - | - | - | - | - | - | - | - | - | u16 | - | - |
u32 | - | - | - | u32 | - | - | - | - | - | - | - | - | - | - | - | u32 | - | - |
u64 | - | - | - | - | u64 | - | - | - | - | - | - | - | - | - | - | u64 | - | - |
i8 | - | - | - | - | - | i8 | - | - | - | - | - | - | - | - | - | i8 | - | - |
i16 | - | - | - | - | - | - | i16 | - | - | - | - | - | - | - | - | i16 | - | - |
i32 | - | - | - | - | - | - | - | i32 | - | - | - | - | - | - | - | i32 | - | - |
i64 | - | - | - | - | - | - | - | - | i64 | - | - | - | - | - | - | i64 | - | - |
bf16 | - | - | - | - | - | - | - | - | - | bf16 | - | - | - | - | - | bf16 | bf16 | - |
f16 | - | - | - | - | - | - | - | - | - | - | f16 | - | - | - | - | f16 | f16 | - |
f32 | - | - | - | - | - | - | - | - | - | - | - | f32 | - | - | - | f32 | f32 | - |
f64 | - | - | - | - | - | - | - | - | - | - | - | - | f64 | - | - | f64 | f64 | - |
c64 | - | - | - | - | - | - | - | - | - | - | - | - | - | c64 | - | c64 | c64 | c64 |
c128 | - | - | - | - | - | - | - | - | - | - | - | - | - | - | c128 | c128 | c128 | c128 |
i* | - | - | - | - | - | - | - | i32 | - | - | - | - | - | - | - | i32 | - | - |
f* | - | - | - | - | - | - | - | - | - | - | - | f32 | - | - | - | f32 | f32 | - |
c* | - | - | - | - | - | - | - | - | - | - | - | - | - | - | c128 | c128 | c128 | c128 |
PyTorch 类型提升#
请注意,PyTorch 不包含大于 uint8
的无符号整数类型。除了这一点以及一些关于与标量/弱类型提升的细节之外,该表与 jax.numpy
使用的表非常接近。
b | u8 | u16 | u32 | u64 | i8 | i16 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i* | f* | c* | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
b | b | u8 | - | - | - | i8 | i16 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i64 | f32 | c64 |
u8 | u8 | u8 | - | - | - | i16 | i16 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | u8 | f32 | c64 |
u16 | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - |
u32 | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - |
u64 | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - |
i8 | i8 | i16 | - | - | - | i8 | i16 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i8 | f32 | c64 |
i16 | i16 | i16 | - | - | - | i16 | i16 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i16 | f32 | c64 |
i32 | i32 | i32 | - | - | - | i32 | i32 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i32 | f32 | c64 |
i64 | i64 | i64 | - | - | - | i64 | i64 | i64 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i64 | f32 | c64 |
bf16 | bf16 | bf16 | - | - | - | bf16 | bf16 | bf16 | bf16 | bf16 | f32 | f32 | f64 | c64 | c128 | bf16 | bf16 | c64 |
f16 | f16 | f16 | - | - | - | f16 | f16 | f16 | f16 | f32 | f16 | f32 | f64 | c64 | c128 | f16 | f16 | c64 |
f32 | f32 | f32 | - | - | - | f32 | f32 | f32 | f32 | f32 | f32 | f32 | f64 | c64 | c128 | f32 | f32 | c64 |
f64 | f64 | f64 | - | - | - | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | c128 | c128 | f64 | f64 | c128 |
c64 | c64 | c64 | - | - | - | c64 | c64 | c64 | c64 | c64 | c64 | c64 | c128 | c64 | c128 | c64 | c64 | c64 |
c128 | c128 | c128 | - | - | - | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 |
i* | i64 | u8 | - | - | - | i8 | i16 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i64 | f32 | c64 |
f* | f32 | f32 | - | - | - | f32 | f32 | f32 | f32 | bf16 | f16 | f32 | f64 | c64 | c128 | f32 | f64 | c64 |
c* | c64 | c64 | - | - | - | c64 | c64 | c64 | c64 | c64 | c64 | c64 | c128 | c64 | c128 | c64 | c64 | c128 |
JAX 类型提升:jax.numpy
#
jax.numpy
遵循 https://jax.net.cn/en/latest/type_promotion.html 中列出的类型提升规则。这里我们使用 i*
、f*
、c*
来表示 Python 标量和弱类型数组。
b | u8 | u16 | u32 | u64 | i8 | i16 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i* | f* | c* | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
b | b | u8 | u16 | u32 | u64 | i8 | i16 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i* | f* | c* |
u8 | u8 | u8 | u16 | u32 | u64 | i16 | i16 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | u8 | f* | c* |
u16 | u16 | u16 | u16 | u32 | u64 | i32 | i32 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | u16 | f* | c* |
u32 | u32 | u32 | u32 | u32 | u64 | i64 | i64 | i64 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | u32 | f* | c* |
u64 | u64 | u64 | u64 | u64 | u64 | f* | f* | f* | f* | bf16 | f16 | f32 | f64 | c64 | c128 | u64 | f* | c* |
i8 | i8 | i16 | i32 | i64 | f* | i8 | i16 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i8 | f* | c* |
i16 | i16 | i16 | i32 | i64 | f* | i16 | i16 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i16 | f* | c* |
i32 | i32 | i32 | i32 | i64 | f* | i32 | i32 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i32 | f* | c* |
i64 | i64 | i64 | i64 | i64 | f* | i64 | i64 | i64 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i64 | f* | c* |
bf16 | bf16 | bf16 | bf16 | bf16 | bf16 | bf16 | bf16 | bf16 | bf16 | bf16 | f32 | f32 | f64 | c64 | c128 | bf16 | bf16 | c64 |
f16 | f16 | f16 | f16 | f16 | f16 | f16 | f16 | f16 | f16 | f32 | f16 | f32 | f64 | c64 | c128 | f16 | f16 | c64 |
f32 | f32 | f32 | f32 | f32 | f32 | f32 | f32 | f32 | f32 | f32 | f32 | f32 | f64 | c64 | c128 | f32 | f32 | c64 |
f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | c128 | c128 | f64 | f64 | c128 |
c64 | c64 | c64 | c64 | c64 | c64 | c64 | c64 | c64 | c64 | c64 | c64 | c64 | c128 | c64 | c128 | c64 | c64 | c64 |
c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 | c128 |
i* | i* | u8 | u16 | u32 | u64 | i8 | i16 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i* | f* | c* |
f* | f* | f* | f* | f* | f* | f* | f* | f* | f* | bf16 | f16 | f32 | f64 | c64 | c128 | f* | f* | c* |
c* | c* | c* | c* | c* | c* | c* | c* | c* | c* | c64 | c64 | c64 | c128 | c64 | c128 | c* | c* | c* |
JAX 类型提升:jax.lax
#
jax.lax
是更底层的,不进行任何隐式提升。这里我们使用 i*
、f*
、c*
来表示 Python 标量和弱类型数组。
b | u8 | u16 | u32 | u64 | i8 | i16 | i32 | i64 | bf16 | f16 | f32 | f64 | c64 | c128 | i* | f* | c* | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
b | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - |
u8 | - | u8 | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - |
u16 | - | - | u16 | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - |
u32 | - | - | - | u32 | - | - | - | - | - | - | - | - | - | - | - | - | - | - |
u64 | - | - | - | - | u64 | - | - | - | - | - | - | - | - | - | - | - | - | - |
i8 | - | - | - | - | - | i8 | - | - | - | - | - | - | - | - | - | - | - | - |
i16 | - | - | - | - | - | - | i16 | - | - | - | - | - | - | - | - | - | - | - |
i32 | - | - | - | - | - | - | - | i32 | - | - | - | - | - | - | - | - | - | - |
i64 | - | - | - | - | - | - | - | - | i64 | - | - | - | - | - | - | i64 | - | - |
bf16 | - | - | - | - | - | - | - | - | - | bf16 | - | - | - | - | - | - | - | - |
f16 | - | - | - | - | - | - | - | - | - | - | f16 | - | - | - | - | - | - | - |
f32 | - | - | - | - | - | - | - | - | - | - | - | f32 | - | - | - | - | - | - |
f64 | - | - | - | - | - | - | - | - | - | - | - | - | f64 | - | - | - | f64 | - |
c64 | - | - | - | - | - | - | - | - | - | - | - | - | - | c64 | - | - | - | - |
c128 | - | - | - | - | - | - | - | - | - | - | - | - | - | - | c128 | - | - | c128 |
i* | - | - | - | - | - | - | - | - | i64 | - | - | - | - | - | - | i* | - | - |
f* | - | - | - | - | - | - | - | - | - | - | - | - | f64 | - | - | - | f* | - |
c* | - | - | - | - | - | - | - | - | - | - | - | - | - | - | c128 | - | - | c* |