jax.lax
模块#
jax.lax
是一个原语操作库,它是 jax.numpy
等库的基础。转换规则,例如 JVP 和批处理规则,通常被定义为对 jax.lax
原语的转换。
许多原语是对等效 XLA 操作的薄封装,这些操作在 XLA 操作语义文档中描述。在少数情况下,JAX 会偏离 XLA,通常是为了确保操作集在 JVP 和转置规则的操作下是闭合的。
在可能的情况下,请优先使用 jax.numpy
等库,而不是直接使用 jax.lax
。 jax.numpy
API 遵循 NumPy,因此比 jax.lax
API 更稳定且更不易更改。
运算符#
|
逐元素绝对值:\(|x|\)。 |
|
逐元素反余弦:\(\mathrm{acos}(x)\)。 |
|
逐元素反双曲余弦:\(\mathrm{acosh}(x)\)。 |
|
逐元素加法:\(x + y\)。 |
|
合并一个或多个 XLA 令牌值。 |
|
以近似方式返回 |
|
以近似方式返回 |
|
计算沿 |
|
计算沿 |
|
逐元素反正弦:\(\mathrm{asin}(x)\)。 |
|
逐元素反双曲正弦:\(\mathrm{asinh}(x)\)。 |
|
逐元素反正切:\(\mathrm{atan}(x)\)。 |
|
逐元素双参数反正切:\(\mathrm{atan}({x \over y})\)。 |
|
逐元素反双曲正切:\(\mathrm{atanh}(x)\)。 |
|
批量矩阵乘法。 |
|
指数缩放的 0 阶修正贝塞尔函数:\(\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)\) |
|
指数缩放的 1 阶修正贝塞尔函数:\(\mathrm{i1e}(x) = e^{-|x|} \mathrm{i1}(x)\) |
|
逐元素正则不完全贝塔积分。 |
|
逐元素位播。 |
|
逐元素与:\(x \wedge y\)。 |
|
逐元素非:\(\neg x\)。 |
|
逐元素或:\(x \vee y\)。 |
|
逐元素异或:\(x \oplus y\)。 |
逐元素 population count,计算每个元素中设置的位数。 |
|
|
广播数组,添加新的前导维度 |
|
包装 XLA 的 BroadcastInDim 运算符。 |
|
返回 NumPy 广播 shapes 产生的形状。 |
|
添加 |
|
|
|
逐元素立方根:\(\sqrt[3]{x}\)。 |
|
逐元素向上取整:\(\left\lceil x \right\rceil\)。 |
|
逐元素夹紧。 |
|
逐元素前导零计数。 |
|
将数组的维度折叠成单个维度。 |
|
逐元素创建复数:\(x + jy\)。 |
|
语义由分解函数定义的 Composite。 |
|
沿 dimension 连接一系列数组。 |
|
逐元素复共轭函数:\(\overline{x}\)。 |
|
conv_general_dilated 的便捷包装器。 |
|
逐元素类型转换。 |
|
将卷积 dimension_numbers 转换为 ConvDimensionNumbers。 |
|
通用 n 维卷积运算符,具有可选的空洞。 |
|
通用 n 维非共享卷积运算符,具有可选的空洞。 |
|
提取受 conv_general_dilated 感受野约束的补丁。 |
|
用于计算 N 维卷积“转置”的便捷包装器。 |
|
conv_general_dilated 的便捷包装器。 |
|
逐元素余弦:\(\mathrm{cos}(x)\)。 |
|
逐元素双曲余弦:\(\mathrm{cosh}(x)\)。 |
|
计算沿 axis 的累积 logsumexp。 |
|
计算沿 axis 的累积最大值。 |
|
计算沿 axis 的累积最小值。 |
|
计算沿 axis 的累积积。 |
|
计算沿 axis 的累积和。 |
|
逐元素 digamma:\(\psi(x)\)。 |
|
逐元素除法:\(x \over y\)。 |
|
向量/向量、矩阵/向量和矩阵/矩阵乘法。 |
|
通用点积/缩并运算符。 |
|
dynamic_slice 的便捷包装器,用于执行整数索引。 |
|
包装 XLA 的 DynamicSlice 运算符。 |
|
|
|
|
|
包装 XLA 的 DynamicUpdateSlice 运算符。 |
|
|
|
逐元素等于:\(x = y\)。 |
|
逐元素误差函数:\(\mathrm{erf}(x)\)。 |
|
逐元素互补误差函数:\(\mathrm{erfc}(x) = 1 - \mathrm{erf}(x)\)。 |
|
逐元素逆误差函数:\(\mathrm{erf}^{-1}(x)\)。 |
|
逐元素指数:\(e^x\)。 |
|
逐元素以 2 为底的指数:\(2^x\)。 |
|
在数组中插入任意数量的大小为 1 的维度。 |
|
逐元素 \(e^{x} - 1\)。 |
|
|
|
逐元素向下取整:\(\left\lfloor x \right\rfloor\)。 |
|
返回填充了 fill_value 的 shape 数组。 |
|
基于示例数组 x 创建类似 np.full 的完整数组。 |
|
Gather 操作符。 |
|
逐元素大于等于:\(x \geq y\)。 |
|
逐元素大于:\(x > y\)。 |
|
逐元素正则化不完全伽玛函数。 |
|
逐元素互补正则化不完全伽玛函数。 |
|
逐元素提取虚部:\(\mathrm{Im}(x)\)。 |
|
|
|
|
|
逐元素幂运算:\(x^y\),其中 \(y\) 是一个静态整数。 |
|
包装了 XLA 的 Iota 操作符。 |
|
逐元素 \(\mathrm{isfinite}\)。 |
|
逐元素小于等于:\(x \leq y\)。 |
|
逐元素 log 伽玛函数:\(\mathrm{log}(\Gamma(x))\)。 |
|
逐元素自然对数:\(\mathrm{log}(x)\)。 |
|
逐元素 \(\mathrm{log}(1 + x)\)。 |
|
逐元素 logistic(sigmoid)函数:\(\frac{1}{1 + e^{-x}}\)。 |
|
逐元素小于:\(x < y\)。 |
|
逐元素最大值:\(\mathrm{max}(x, y)\)。 |
|
逐元素最小值:\(\mathrm{min}(x, y)\) |
|
逐元素乘法:\(x \times y\)。 |
|
逐元素不等于:\(x \neq y\)。 |
|
逐元素取反:\(-x\)。 |
|
返回 |
|
阻止编译器跨越 barrier 移动操作。 |
|
对数组应用低、高和/或内部填充。 |
|
分阶段输出特定于平台代码。 |
|
逐元素多伽玛函数:\(\psi^{(m)}(x)\)。 |
逐元素 population count,计算每个元素中设置的位数。 |
|
|
逐元素幂运算:\(x^y\)。 |
|
|
|
逐元素提取实部:\(\mathrm{Re}(x)\)。 |
|
逐元素倒数:\(1 \over x\)。 |
|
包装了 XLA 的 Reduce 操作符。 |
|
计算一个或多个数组轴上元素的按位 AND。 |
|
计算一个或多个数组轴上元素的最大值。 |
|
计算一个或多个数组轴上元素的最小值。 |
|
计算一个或多个数组轴上元素的按位 OR。 |
|
包装了 XLA 的 ReducePrecision 操作符。 |
|
计算一个或多个数组轴上元素的乘积。 |
|
计算一个或多个数组轴上元素的总和。 |
|
包装了 XLA 的 ReduceWindowWithGeneralPadding 操作符。 |
|
计算一个或多个数组轴上元素的按位 XOR。 |
|
逐元素余数:\(x \bmod y\)。 |
|
包装了 XLA 的 Reshape 操作符。 |
|
包装了 XLA 的 Rev 操作符。 |
|
无状态 PRNG 位生成器。 |
|
有状态 PRNG 生成器。 |
|
逐元素四舍五入。 |
|
逐元素倒数平方根:\(1 \over \sqrt{x}\)。 |
|
Scatter-update 操作符。 |
|
Scatter-add 操作符。 |
|
Scatter-apply 操作符。 |
|
Scatter-max 操作符。 |
|
Scatter-min 操作符。 |
|
Scatter-multiply 操作符。 |
|
逐元素左移:\(x \ll y\)。 |
|
逐元素算术右移:\(x \gg y\)。 |
|
逐元素逻辑右移:\(x \gg y\)。 |
|
逐元素符号函数。 |
|
逐元素正弦:\(\mathrm{sin}(x)\)。 |
|
逐元素双曲正弦:\(\mathrm{sinh}(x)\)。 |
|
包装了 XLA 的 Slice 操作符。 |
|
|
|
包装了 XLA 的 Sort 操作符。 |
|
沿 |
|
沿 |
|
逐元素平方根:\(\sqrt{x}\)。 |
|
逐元素平方:\(x^2\)。 |
|
从数组中挤压任意数量的大小为 1 的维度。 |
|
逐元素减法:\(x - y\)。 |
|
逐元素正切:\(\mathrm{tan}(x)\)。 |
|
逐元素双曲正切:\(\mathrm{tanh}(x)\)。 |
|
返回 |
|
包装了 XLA 的 Transpose 操作符。 |
|
逐元素 Hurwitz zeta 函数:\(\zeta(x, q)\) |
控制流操作符#
|
使用结合二元运算并行执行扫描。 |
|
有条件地应用 |
|
通过归约到 |
|
在先导数组轴上映射函数。 |
|
在先导数组轴上扫描函数,同时携带状态。 |
|
根据布尔谓词在两个分支之间进行选择。 |
|
从多个 cases 中选择数组值。 |
|
应用由 |
|
当 |
自定义梯度操作符#
停止梯度计算。 |
|
|
执行具有隐式定义梯度的无矩阵线性求解。 |
|
可微地求解函数的根。 |
并行操作符#
|
跨所有副本收集 x 的值。 |
|
物化映射轴并映射不同的轴。 |
|
在 pmapped 轴 |
|
类似于 |
|
在 pmapped 轴 |
|
在 pmapped 轴 |
|
在 pmapped 轴 |
|
根据排列 |
|
jax.lax.ppermute 的便捷包装器,具有备用排列编码 |
|
将 pmapped 轴 |
|
返回沿映射轴 |
|
返回映射轴 |
线性代数操作符 (jax.lax.linalg)#
|
Cholesky 分解。 |
|
Cholesky 秩 1 更新。 |
|
一般矩阵的特征分解。 |
|
Hermitian 矩阵的特征分解。 |
|
将方阵简化为上 Hessenberg 形式。 |
|
初等 Householder 反射器的乘积。 |
|
具有部分旋转的 LU 分解。 |
|
将 LU 返回的 pivots(行交换)转换为排列。 |
|
基于 QR 的动态加权 Halley 迭代,用于极分解。 |
|
QR 分解。 |
|
Schur 分解。 |
|
奇异值分解。 |
|
SVD 算法的枚举。 |
|
对称积。 |
|
三角求解。 |
|
将对称/Hermitian 矩阵简化为三对角形式。 |
|
计算三对角线性系统的解。 |
参数类#
- class jax.lax.DotAlgorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count=1, rhs_component_count=1, num_primitive_operations=1, allow_imprecise_accumulation=False)[source]#
指定用于计算点积的算法。
当用于指定
precision
输入到dot()
,dot_general()
和其他点积函数时,此数据结构用于控制用于计算点积的算法的属性。此 API 控制用于计算的精度,并允许用户访问硬件特定的加速。对这些算法的支持取决于平台,并且使用不受支持的算法将在编译计算时引发 Python 异常。至少在某些平台上已知支持的算法在
DotAlgorithmPreset
枚举中列出,这些是试验此 API 的良好起点。“点积算法”由以下参数指定
lhs_precision_type
和rhs_precision_type
,操作的 LHS 和 RHS 四舍五入到的数据类型。accumulation_type
用于累加的数据类型。lhs_component_count
,rhs_component_count
,和num_primitive_operations
应用于将 LHS 和/或 RHS 分解为多个组件并在这些值上执行多个操作的算法,通常是为了模拟更高的精度。对于没有分解的算法,这些值应设置为1
。allow_imprecise_accumulation
用于指定是否允许在某些步骤中进行较低精度的累加 (例如CUBLASLT_MATMUL_DESC_FAST_ACCUM
)。
StableHLO 规范 对于点积操作,不要求精度类型与输入或输出的存储类型相同,但某些平台可能要求这些类型匹配。此外,
dot_general()
的返回类型始终由输入算法的accumulation_type
参数定义(如果已指定)。示例
使用 32 位浮点累加器累加两个 16 位浮点数
>>> algorithm = DotAlgorithm( ... lhs_precision_type=np.float16, ... rhs_precision_type=np.float16, ... accumulation_type=np.float32, ... ) >>> lhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16) >>> rhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16) >>> dot(lhs, rhs, precision=algorithm) array([ 1., 4., 9., 16.], dtype=float16)
或者,等效地,使用预设
>>> algorithm = DotAlgorithmPreset.F16_F16_F32 >>> dot(lhs, rhs, precision=algorithm) array([ 1., 4., 9., 16.], dtype=float16)
预设也可以按名称指定
>>> dot(lhs, rhs, precision="F16_F16_F32") array([ 1., 4., 9., 16.], dtype=float16)
preferred_element_type
参数可用于返回输出,而无需向下转换累加类型>>> dot(lhs, rhs, precision="F16_F16_F32", preferred_element_type=np.float32) array([ 1., 4., 9., 16.], dtype=float32)
- class jax.lax.DotAlgorithmPreset(value)[source]#
用于计算点积的已知算法的枚举。
此
Enum
提供了一组命名的DotAlgorithm
对象,这些对象已知在至少一个平台上受支持。有关这些算法行为的更多详细信息,请参阅DotAlgorithm
文档。在调用
dot()
,dot_general()
或大多数其他 JAX 点积函数时,可以从此列表中选择一个算法,方法是使用precision
参数传递此Enum
的成员或其名称作为字符串。例如,用户可以直接使用此
Enum
指定预设>>> lhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16) >>> rhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16) >>> algorithm = DotAlgorithmPreset.F16_F16_F32 >>> dot(lhs, rhs, precision=algorithm) array([ 1., 4., 9., 16.], dtype=float16)
或者,等效地,可以通过名称指定它们
>>> dot(lhs, rhs, precision="F16_F16_F32") array([ 1., 4., 9., 16.], dtype=float16)
预设的名称通常为
LHS_RHS_ACCUM
,其中LHS
和RHS
分别是lhs
和rhs
输入的元素类型,而ACCUM
是累加器的元素类型。一些预设有一个额外的后缀,每个后缀的含义在下面文档中说明。支持的预设是- DEFAULT = 1#
将根据输入和输出类型选择算法。
- ANY_F8_ANY_F8_F32 = 2#
接受任何 float8 输入类型并累加到 float32。
- ANY_F8_ANY_F8_F32_FAST_ACCUM = 3#
类似于
ANY_F8_ANY_F8_F32
,但使用更快的累加,但以较低的精度为代价。
- ANY_F8_ANY_F8_ANY = 4#
类似于
ANY_F8_ANY_F8_F32
,但累加类型由preferred_element_type
控制。
- ANY_F8_ANY_F8_ANY_FAST_ACCUM = 5#
类似于
ANY_F8_ANY_F8_F32_FAST_ACCUM
,但累加类型由preferred_element_type
控制。
- F16_F16_F16 = 6#
- F16_F16_F32 = 7#
- BF16_BF16_BF16 = 8#
- BF16_BF16_F32 = 9#
- BF16_BF16_F32_X3 = 10#
_X3
后缀表示该算法使用 3 个操作来模拟更高的精度。
- BF16_BF16_F32_X6 = 11#
类似于
BF16_BF16_F32_X3
,但使用 6 个操作而不是 3 个。
- BF16_BF16_F32_X9 = 12#
类似于
BF16_BF16_F32_X3
,但使用 9 个操作而不是 3 个。
- TF32_TF32_F32 = 13#
- TF32_TF32_F32_X3 = 14#
_X3
后缀表示该算法使用 3 个操作来模拟更高的精度。
- F32_F32_F32 = 15#
- F64_F64_F64 = 16#
- class jax.lax.FftType(value)[source]#
描述要执行的 FFT 操作的类型。
- FFT = 0#
正向复数到复数 FFT。
- IFFT = 1#
逆向复数到复数 FFT。
- IRFFT = 3#
逆向实数到复数 FFT。
- RFFT = 2#
正向实数到复数 FFT。
- class jax.lax.GatherDimensionNumbers(offset_dims, collapsed_slice_dims, start_index_map, operand_batching_dims=(), start_indices_batching_dims=())[source]#
描述 XLA 的 Gather 运算符 的维度编号参数。有关维度编号含义的更多详细信息,请参阅 XLA 文档。
- 参数:
offset_dims (tuple[int, ...]) – gather 输出中偏移到从 operand 切片的数组中的维度集合。 必须是一个升序排列的整数元组,每个整数代表输出的维度编号。
collapsed_slice_dims (tuple[int, ...]) – operand 中 i 维度的集合,这些维度的 slice_sizes[i] == 1,并且在 gather 的输出中不应具有相应的维度。 必须是一个升序排列的整数元组。
start_index_map (tuple[int, ...]) – 对于 start_indices 中的每个维度,给出 operand 中要切片的相应维度。 必须是一个整数元组,大小等于 start_indices.shape[-1]。
operand_batching_dims (tuple[int, ...]) – operand 中的批处理维度 i 的集合,这些维度的 slice_sizes[i] == 1,并且应该在 start_indices (在 start_indices_batching_dims 中的相同索引处) 和 gather 的输出中都具有相应的维度。 必须是一个升序排列的整数元组。
start_indices_batching_dims (tuple[int, ...]) – start_indices 中的批处理维度 i 的集合,这些维度应该在 operand (在 operand_batching_dims 中的相同索引处) 和 gather 的输出中都具有相应的维度。 必须是一个整数元组(顺序根据与 operand_batching_dims 的对应关系固定)。
与 XLA 的 GatherDimensionNumbers 结构不同,index_vector_dim 是隐式的;始终存在索引向量维度,并且它必须始终是最后一个维度。 要收集标量索引,请添加大小为 1 的尾部维度。
- class jax.lax.GatherScatterMode(value)[source]#
描述如何在 gather 或 scatter 中处理越界索引。
可能的值为
- CLIP
索引将被钳制到最近的范围内值,即,使得要 gather 的整个窗口都在范围内。
- FILL_OR_DROP
如果 gather 的窗口的任何部分超出范围,则返回的整个窗口(即使是那些原本在范围内的元素)都将填充一个常量。 如果 scatter 的窗口的任何部分超出范围,则整个窗口将被丢弃。
- PROMISE_IN_BOUNDS
用户承诺索引在范围内。 将不会执行额外的检查。 实际上,对于当前的 XLA 实现,这意味着越界 gather 将被钳制,但越界 scatter 将被丢弃。 如果索引超出范围,则梯度将不正确。
- class jax.lax.Precision(value)[source]#
用于 lax 矩阵乘法相关函数的精度枚举。
JAX 函数的设备相关的 precision 参数通常控制加速器后端(即 TPU 和 GPU)上数组计算的速度和精度之间的权衡。 对 CPU 后端没有影响。 这仅对 float32 计算有效,并且不影响输入/输出数据类型。 成员包括
- DEFAULT
最快模式,但精度最低。 在 TPU 上:以 bfloat16 执行 float32 计算。 在 GPU 上:如果可用(例如在 A100 和 H100 GPU 上),则使用 tensorfloat32,否则使用标准 float32(例如在 V100 GPU 上)。 别名:
'default'
,'fastest'
。- HIGH
较慢但更精确。 在 TPU 上:以 3 个 bfloat16 过程执行 float32 计算。 在 GPU 上:在可用时使用 tensorfloat32,否则使用 float32。 别名:
'high'
。- HIGHEST
最慢但最精确。 在 TPU 上:以 6 个 bfloat16 过程执行 float32 计算。 别名:
'highest'
。 在 GPU 上:使用 float32。
- jax.lax.PrecisionLike#
别名:
None
|str
|Precision
|tuple
[str
,str
] |tuple
[Precision
,Precision
] |DotAlgorithm
|DotAlgorithmPreset
- class jax.lax.RandomAlgorithm(value)[source]#
描述用于 rng_bit_generator 的 PRNG 算法。
- RNG_DEFAULT = 0#
平台的默认算法。
- RNG_THREE_FRY = 1#
Threefry-2x32 PRNG 算法。
- RNG_PHILOX = 2#
Philox-4x32 PRNG 算法。
- class jax.lax.RoundingMethod(value)[source]#
jax.lax.round()
中用于处理中间值(例如,0.5)的舍入策略。- AWAY_FROM_ZERO = 0#
将中间值四舍五入到远离零的方向(例如,0.5 -> 1,-0.5 -> -1)。
- TO_NEAREST_EVEN = 1#
将中间值四舍五入到最接近的偶数整数。 这也称为“银行家舍入法”(例如,0.5 -> 0,1.5 -> 2)。
- class jax.lax.ScatterDimensionNumbers(update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims, operand_batching_dims=(), scatter_indices_batching_dims=())[source]#
描述 XLA 的 Scatter 运算符 的维度编号参数。 有关维度编号含义的更多详细信息,请参阅 XLA 文档。
- 参数:
update_window_dims (Sequence[int]) – updates 中作为窗口维度的维度集合。 必须是一个升序排列的整数元组,每个整数代表一个维度编号。
inserted_window_dims (Sequence[int]) – 必须插入到 updates 形状中的大小为 1 的窗口维度集合。必须是整数元组,并按升序排列,每个整数代表输出的维度编号。在 gather 的情况下,这些是 collapsed_slice_dims 的镜像。
scatter_dims_to_operand_dims (Sequence[int]) – 对于 scatter_indices 中的每个维度,给出 operand 中对应的维度。必须是整数序列,其大小等于 scatter_indices.shape[-1]。
operand_batching_dims (Sequence[int]) – operand 中批处理维度 i 的集合,这些维度应该在 scatter_indices (在 scatter_indices_batching_dims 中的相同索引处) 和 updates 中都有对应的维度。必须是整数元组,并按升序排列。在 gather 的情况下,这些是 operand_batching_dims 的镜像。
scatter_indices_batching_dims (Sequence[int]) – scatter_indices 中批处理维度 i 的集合,这些维度应该在 operand (在 operand_batching_dims 中的相同索引处) 和 gather 的输出中都有对应的维度。必须是整数元组 (顺序根据与 input_batching_dims 的对应关系固定)。在 gather 的情况下,这些是 start_indices_batching_dims 的镜像。
与 XLA 的 ScatterDimensionNumbers 结构不同,index_vector_dim 是隐式的;始终存在索引向量维度,并且它必须始终是最后一个维度。要分散标量索引,请添加大小为 1 的尾部维度。