类型提升语义#

本文档描述了 JAX 的类型提升规则 — 即对于每对类型,jax.numpy.promote_types() 的结果。有关下面描述的设计考虑因素的背景信息,请参见JAX 类型提升语义设计

JAX 的类型提升行为由以下类型提升格确定

_images/type_lattice.svg

其中,例如

  • b1 表示 np.bool_

  • i2 表示 np.int16

  • u4 表示 np.uint32

  • bf 表示 np.bfloat16

  • f2 表示 np.float16

  • c8 表示 np.complex64

  • i* 表示 Python int 或弱类型 int

  • f* 表示 Python float 或弱类型 float,以及

  • c* 表示 Python complex 或弱类型 complex

(有关弱类型的更多信息,请参见下文JAX 中的弱类型值)。

任意两种类型之间的提升由它们在该格上的join(并)决定,这会生成以下二进制提升表

b1u1u2u4u8i1i2i4i8bff2f4f8c8c16i*f*c*
b1b1u1u2u4u8i1i2i4i8bff2f4f8c8c16i*f*c*
u1u1u1u2u4u8i2i2i4i8bff2f4f8c8c16u1f*c*
u2u2u2u2u4u8i4i4i4i8bff2f4f8c8c16u2f*c*
u4u4u4u4u4u8i8i8i8i8bff2f4f8c8c16u4f*c*
u8u8u8u8u8u8f*f*f*f*bff2f4f8c8c16u8f*c*
i1i1i2i4i8f*i1i2i4i8bff2f4f8c8c16i1f*c*
i2i2i2i4i8f*i2i2i4i8bff2f4f8c8c16i2f*c*
i4i4i4i4i8f*i4i4i4i8bff2f4f8c8c16i4f*c*
i8i8i8i8i8f*i8i8i8i8bff2f4f8c8c16i8f*c*
bfbfbfbfbfbfbfbfbfbfbff4f4f8c8c16bfbfc8
f2f2f2f2f2f2f2f2f2f2f4f2f4f8c8c16f2f2c8
f4f4f4f4f4f4f4f4f4f4f4f4f4f8c8c16f4f4c8
f8f8f8f8f8f8f8f8f8f8f8f8f8f8c16c16f8f8c16
c8c8c8c8c8c8c8c8c8c8c8c8c8c16c8c16c8c8c8
c16c16c16c16c16c16c16c16c16c16c16c16c16c16c16c16c16c16c16
i*i*u1u2u4u8i1i2i4i8bff2f4f8c8c16i*f*c*
f*f*f*f*f*f*f*f*f*f*bff2f4f8c8c16f*f*c*
c*c*c*c*c*c*c*c*c*c*c8c8c8c16c8c16c*c*c*

JAX 的类型提升规则与 NumPy 的规则不同,后者由numpy.promote_types() 给出,在上表中以绿色背景突出显示的单元格中有所体现。主要有三类区别:

  • 当将一个弱类型值与一个相同类别的有类型 JAX 值进行提升时,JAX 总是优先使用 JAX 值的精度。例如,jnp.int16(1) + 1 将返回 int16 而不是像 NumPy 中那样提升到 int64。请注意,这仅适用于 Python 标量值;如果常量是 NumPy 数组,则使用上述格进行类型提升。例如,jnp.int16(1) + np.array(1) 将返回 int64

  • 当将整数或布尔类型与浮点或复数类型进行提升时,JAX 总是优先使用浮点或复数类型。

  • JAX 支持bfloat16非标准 16 位浮点类型(jax.numpy.bfloat16),这对于神经网络训练很有用。唯一值得注意的提升行为是与 IEEE-754 float16 相关的,`bfloat16` 会提升到 float32

NumPy 和 JAX 之间的差异是由于以下事实:加速器设备(例如 GPU 和 TPU)在使用 64 位浮点类型时会付出显著的性能代价(GPU),或者根本不支持 64 位浮点类型(TPU)。经典的 NumPy 提升规则过于倾向于提升到 64 位类型,这对于旨在加速器上运行的系统来说是个问题。

JAX 使用的浮点提升规则更适合现代加速器设备,并且在浮点类型提升方面不那么激进。JAX 用于浮点类型的提升规则类似于 PyTorch 使用的规则。

Python 运算符分派的影响#

请记住,像 + 这样的 Python 运算符会根据两个相加值的 Python 类型进行分派。这意味着,例如,np.int16(1) + 1 将使用 NumPy 规则进行提升,而 jnp.int16(1) + 1 将使用 JAX 规则进行提升。当两种类型的提升结合使用时,这可能导致潜在的令人困惑的非关联性提升语义;例如 np.int16(1) + 1 + jnp.int16(1)

JAX 中的弱类型值#

在大多数情况下,JAX 中的弱类型值可以被认为是具有与 Python 标量等效的提升行为,例如以下示例中的整数标量 2

>>> x = jnp.arange(5, dtype='int8')
>>> 2 * x
Array([0, 2, 4, 6, 8], dtype=int8)

JAX 的弱类型框架旨在防止 JAX 值与没有显式用户指定类型的值(例如 Python 标量字面量)之间的二元操作中发生不必要的类型提升。例如,如果 2 未被视为弱类型,则上述表达式将导致隐式类型提升

>>> jnp.int32(2) * x
Array([0, 2, 4, 6, 8], dtype=int32)

在 JAX 中使用时,Python 标量有时会被提升为 DeviceArray 对象,例如在 JIT 编译期间。为了在这种情况下保持所需的提升语义,DeviceArray 对象带有一个 weak_type 标志,可以在数组的字符串表示中看到它

>>> jnp.asarray(2)
Array(2, dtype=int32, weak_type=True)

如果 dtype 被显式指定,它将改为生成一个标准的强类型数组值

>>> jnp.asarray(2, dtype='int32')
Array(2, dtype=int32)

严格的 dtype 提升#

在某些情况下,禁用隐式类型提升行为并要求所有提升都是显式的可能很有用。这可以通过将 jax_numpy_dtype_promotion 标志设置为 'strict' 在 JAX 中完成。在局部范围内,可以使用上下文管理器来完成

>>> x = jnp.float32(1)
>>> y = jnp.int32(1)
>>> with jax.numpy_dtype_promotion('strict'):
...   z = x + y  
...
Traceback (most recent call last):
TypePromotionError: Input dtypes ('float32', 'int32') have no available implicit
dtype promotion path when jax_numpy_dtype_promotion=strict. Try explicitly casting
inputs to the desired output type, or set jax_numpy_dtype_promotion=standard.

为了方便起见,严格提升模式仍将允许安全的弱类型提升,因此您仍然可以编写混合 JAX 数组和 Python 标量的代码

>>> with jax.numpy_dtype_promotion('strict'):
...   z = x + 1
>>> print(z)
2.0

如果您希望全局设置配置,可以使用标准配置更新来完成

jax.config.update('jax_numpy_dtype_promotion', 'strict')

要恢复默认的标准类型提升,请将此配置设置为 'standard'

jax.config.update('jax_numpy_dtype_promotion', 'standard')