默认数据类型和 X64 标志#
JAX 致力于满足各种数值计算从业者的需求,而这些从业者有时会有冲突的偏好。在默认数据类型方面,存在两种不同的阵营。
经典的科学计算从业者(即
numpy
或scipy
等工具的用户)倾向于将计算的准确性放在首位:这些用户希望计算默认使用最宽的可用表示:例如,浮点值默认应为float64
,整数默认应为int64
等。人工智能研究人员(即从事神经网络实现和训练的人员)倾向于将速度置于准确性之上,以至于他们开发了 bfloat16 等特殊数据类型,这些数据类型会故意丢弃最低有效位以加快计算速度。对这些用户而言,计算中仅仅出现一个 float64 值就可能导致程序在最好情况下运行缓慢,在最坏情况下与硬件不兼容!这些用户希望计算默认使用
float32
或int32
。
JAX 提供的这个主要机制是 jax_enable_x64
标志,它控制是否可以创建 64 位值。默认情况下,此标志设置为 False
(以满足人工智能研究人员和从业者的需求),但重视计算速度超过准确性的用户可以将其设置为 True
。
默认设置:全部为 32 位#
默认情况下,jax_enable_x64
设置为 False,因此 jax.numpy
数组创建函数将默认返回 32 位值。
例如
>>> import jax.numpy as jnp
>>> jnp.arange(5)
Array([0, 1, 2, 3, 4], dtype=int32)
>>> jnp.zeros(5)
Array([0., 0., 0., 0., 0.], dtype=float32)
>>> jnp.ones(5, dtype=int)
Array([1, 1, 1, 1, 1], dtype=int32)
除了默认值之外,由于 64 位值对 AI 工作流可能产生严重负面影响,因此将此标志设置为 False 会阻止您创建 64 位数组!例如:
>>> jnp.arange(5, dtype='float64')
UserWarning: Explicitly requested dtype float64 requested in arange is not available, and will be
truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the
JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
Array([0., 1., 2., 3., 4.], dtype=float32)
X64 标志:启用 64 位值#
要使用默认产生 64 位值的“另一种模式”,您可以将 jax_enable_x64
标志设置为 True
。
import jax
import jax.numpy as jnp
jax.config.update('jax_enable_x64', True)
print(repr(jnp.arange(5)))
print(repr(jnp.zeros(5)))
print(repr(jnp.ones(5, dtype=int)))
Array([0, 1, 2, 3, 4], dtype=int64)
Array([0., 0., 0., 0., 0.], dtype=float64)
Array([1, 1, 1, 1, 1], dtype=int64)
X64 配置也可以通过 JAX_ENABLE_X64
shell 环境变量来设置,例如:
$ JAX_ENABLE_X64=1 python main.py
X64 标志旨在作为一项全局设置,应在您的主文件的顶部为整个程序设置一个值。一项常见的需求是允许按上下文配置该标志(例如,仅为长程序的一段启用 X64):事实证明,这在 JAX 的编程模型中很难实现,因为代码执行可能发生在与代码编译不同的上下文中。我们正在积极探索放宽此限制的可行性,敬请期待!