默认数据类型和X64标志#
JAX 致力于满足各种数值计算从业者的需求,他们有时有相互冲突的偏好。当谈到默认数据类型时,存在两种不同的阵营
传统的科学计算从业者(即像
numpy
或scipy
这样的工具的用户)倾向于将计算精度放在首位:这类用户希望计算默认使用**最宽的可用表示**:例如,浮点值应默认为float64
,整数应默认为int64
等。AI 研究人员(即实现和训练神经网络的人员)倾向于将速度置于精度之上,甚至为此开发了像 bfloat16 等特殊数据类型,这些数据类型有意舍弃最低有效位以加速计算。对于这些用户来说,计算中仅存在一个 float64 值就可能导致程序运行缓慢(往好了说),或者与他们的硬件不兼容(往坏了说)!这些用户希望计算默认使用
float32
或int32
。
JAX 为此提供的主要机制是 jax_enable_x64
标志,它控制是否可以创建64位值。默认情况下,此标志设置为 False
(以满足AI研究人员和从业者的需求),但看重计算精度而非速度的用户可以将其设置为 True
。
默认设置:全局32位#
默认情况下 jax_enable_x64
设置为 False,因此 jax.numpy
数组创建函数将默认返回32位值。
例如
>>> import jax.numpy as jnp
>>> jnp.arange(5)
Array([0, 1, 2, 3, 4], dtype=int32)
>>> jnp.zeros(5)
Array([0., 0., 0., 0., 0.], dtype=float32)
>>> jnp.ones(5, dtype=int)
Array([1, 1, 1, 1, 1], dtype=int32)
除默认设置外,因为64位值对AI工作流可能有害,将此标志设置为 False 会阻止你创建任何64位数组!例如
>>> jnp.arange(5, dtype='float64')
UserWarning: Explicitly requested dtype float64 requested in arange is not available, and will be
truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the
JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
Array([0., 1., 2., 3., 4.], dtype=float32)
X64标志:启用64位值#
要在一个函数默认生成64位值的“其他模式”下工作,你可以将 jax_enable_x64
标志设置为 True
import jax
import jax.numpy as jnp
jax.config.update('jax_enable_x64', True)
print(repr(jnp.arange(5)))
print(repr(jnp.zeros(5)))
print(repr(jnp.ones(5, dtype=int)))
Array([0, 1, 2, 3, 4], dtype=int64)
Array([0., 0., 0., 0., 0.], dtype=float64)
Array([1, 1, 1, 1, 1], dtype=int64)
X64配置也可以通过 JAX_ENABLE_X64
shell 环境变量设置,例如
$ JAX_ENABLE_X64=1 python main.py
X64标志旨在作为**全局设置**,它应该在整个程序中只有一个值,并在主文件的顶部设置。一个常见的特性请求是使该标志能够上下文配置(例如,仅在长程序的一个部分启用X64):这在JAX的编程模型中实现起来很困难,因为代码执行可能发生在与代码编译不同的上下文中。目前正在进行探索放松此限制可行性的工作,敬请关注!