默认 dtype 和 X64 标志#
JAX 努力满足各种数值计算从业者的需求,他们有时会有冲突的偏好。在默认 dtype 方面,有两个不同的阵营
经典的科学计算从业者(即,
numpy
或scipy
等工具的用户)往往最看重计算的准确性:这些用户更希望计算默认使用**最宽可用的表示**:例如,浮点值应默认为float64
,整数默认为int64
等。AI 研究人员(即,实施和训练神经网络的人员)往往更看重速度而不是准确性,以至于他们开发了像 bfloat16 和其他特殊数据类型,这些数据类型故意丢弃最低有效位以加快计算速度。对于这些用户来说,即使在他们的计算中存在一个 float64 值也可能导致程序充其量运行缓慢,最坏的情况是与他们的硬件不兼容!这些用户更希望计算默认使用
float32
或int32
。
JAX 为此提供的主要机制是 jax_enable_x64
标志,它控制是否可以创建 64 位值。默认情况下,此标志设置为 False
(满足 AI 研究人员和从业者的需求),但可以由重视准确性而非计算速度的用户设置为 True
。
默认设置:所有地方都是 32 位#
默认情况下,jax_enable_x64
设置为 False,因此 jax.numpy
数组创建函数将默认返回 32 位值。
例如
>>> import jax.numpy as jnp
>>> jnp.arange(5)
Array([0, 1, 2, 3, 4], dtype=int32)
>>> jnp.zeros(5)
Array([0., 0., 0., 0., 0.], dtype=float32)
>>> jnp.ones(5, dtype=int)
Array([1, 1, 1, 1, 1], dtype=int32)
除了默认值之外,由于 64 位值可能对 AI 工作流程有害,因此将此标志设置为 False 会阻止您创建任何 64 位数组!例如
>>> jnp.arange(5, dtype='float64')
UserWarning: Explicitly requested dtype float64 requested in arange is not available, and will be
truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the
JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
Array([0., 1., 2., 3., 4.], dtype=float32)
X64 标志:启用 64 位值#
要在函数默认生成 64 位值的“另一种模式”下工作,您可以将 jax_enable_x64
标志设置为 True
import jax
import jax.numpy as jnp
jax.config.update('jax_enable_x64', True)
print(repr(jnp.arange(5)))
print(repr(jnp.zeros(5)))
print(repr(jnp.ones(5, dtype=int)))
Array([0, 1, 2, 3, 4], dtype=int64)
Array([0., 0., 0., 0., 0.], dtype=float64)
Array([1, 1, 1, 1, 1], dtype=int64)
X64 配置也可以通过 JAX_ENABLE_X64
shell 环境变量设置,例如
$ JAX_ENABLE_X64=1 python main.py
X64 标志旨在作为**全局设置**,对于您的整个程序应具有一个值,在主文件的顶部设置。一个常见的特性请求是标志可以上下文配置(例如,仅为一个长程序的某个部分启用 X64):事实证明,这在 JAX 的编程模型中很难实现,因为代码执行可能发生在与代码编译不同的上下文中。目前正在进行探索放宽此约束可行性的工作,敬请期待!