默认 dtype 和 X64 标志#

JAX 努力满足各种数值计算从业者的需求,他们有时会有冲突的偏好。在默认 dtype 方面,有两个不同的阵营

  • 经典的科学计算从业者(即,numpyscipy 等工具的用户)往往最看重计算的准确性:这些用户更希望计算默认使用**最宽可用的表示**:例如,浮点值应默认为 float64,整数默认为 int64 等。

  • AI 研究人员(即,实施和训练神经网络的人员)往往更看重速度而不是准确性,以至于他们开发了像 bfloat16 和其他特殊数据类型,这些数据类型故意丢弃最低有效位以加快计算速度。对于这些用户来说,即使在他们的计算中存在一个 float64 值也可能导致程序充其量运行缓慢,最坏的情况是与他们的硬件不兼容!这些用户更希望计算默认使用 float32int32

JAX 为此提供的主要机制是 jax_enable_x64 标志,它控制是否可以创建 64 位值。默认情况下,此标志设置为 False(满足 AI 研究人员和从业者的需求),但可以由重视准确性而非计算速度的用户设置为 True

默认设置:所有地方都是 32 位#

默认情况下,jax_enable_x64 设置为 False,因此 jax.numpy 数组创建函数将默认返回 32 位值。

例如

>>> import jax.numpy as jnp

>>> jnp.arange(5)
Array([0, 1, 2, 3, 4], dtype=int32)

>>> jnp.zeros(5)
Array([0., 0., 0., 0., 0.], dtype=float32)

>>> jnp.ones(5, dtype=int)
Array([1, 1, 1, 1, 1], dtype=int32)

除了默认值之外,由于 64 位值可能对 AI 工作流程有害,因此将此标志设置为 False 会阻止您创建任何 64 位数组!例如

>>> jnp.arange(5, dtype='float64')  
UserWarning: Explicitly requested dtype float64 requested in arange is not available, and will be 
truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the 
JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
Array([0., 1., 2., 3., 4.], dtype=float32)

X64 标志:启用 64 位值#

要在函数默认生成 64 位值的“另一种模式”下工作,您可以将 jax_enable_x64 标志设置为 True

import jax
import jax.numpy as jnp

jax.config.update('jax_enable_x64', True)

print(repr(jnp.arange(5)))
print(repr(jnp.zeros(5)))
print(repr(jnp.ones(5, dtype=int)))
Array([0, 1, 2, 3, 4], dtype=int64)
Array([0., 0., 0., 0., 0.], dtype=float64)
Array([1, 1, 1, 1, 1], dtype=int64)

X64 配置也可以通过 JAX_ENABLE_X64 shell 环境变量设置,例如

$ JAX_ENABLE_X64=1 python main.py

X64 标志旨在作为**全局设置**,对于您的整个程序应具有一个值,在主文件的顶部设置。一个常见的特性请求是标志可以上下文配置(例如,仅为一个长程序的某个部分启用 X64):事实证明,这在 JAX 的编程模型中很难实现,因为代码执行可能发生在与代码编译不同的上下文中。目前正在进行探索放宽此约束可行性的工作,敬请期待!