jax.lax 模块#
jax.lax 是支撑 jax.numpy 等库的基础原始运算符库。转换规则,如 JVP 和批处理规则,通常定义为对 jax.lax 原始运算符的转换。
许多原始运算符只是对等效 XLA 运算符的薄封装,如 XLA 运算符语义 文档所述。在少数情况下,JAX 会偏离 XLA,通常是为了确保运算符集合在 JVP 和转置规则的作用下是闭合的。
在可能的情况下,请优先使用 jax.numpy 等库,而不是直接使用 jax.lax。 jax.numpy API 遵循 NumPy,因此比 jax.lax API 更稳定,更不容易发生变化。
运算符#
|
逐元素绝对值:\(|x|\)。 |
|
逐元素反余弦:\(\mathrm{acos}(x)\)。 |
|
逐元素反双曲余弦:\(\mathrm{acosh}(x)\)。 |
|
逐元素加法:\(x + y\)。 |
|
合并一个或多个 XLA 令牌值。 |
|
以近似方式返回 |
|
以近似方式返回 |
|
计算沿 |
|
计算沿 |
|
逐元素反正弦:\(\mathrm{asin}(x)\)。 |
|
逐元素反双曲正弦:\(\mathrm{asinh}(x)\)。 |
|
逐元素反正切:\(\mathrm{atan}(x)\)。 |
|
逐元素双变量反正切:\(\mathrm{atan}({x \over y})\)。 |
|
逐元素反双曲正切:\(\mathrm{atanh}(x)\)。 |
|
批处理矩阵乘法。 |
|
指数缩放的零阶修正贝塞尔函数:\(\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)\) |
|
指数缩放的一阶修正贝塞尔函数:\(\mathrm{i1e}(x) = e^{-|x|} \mathrm{i1}(x)\) |
|
逐元素正则化不完全 Beta 积分。 |
|
逐元素位转换。 |
|
逐元素 AND:\(x \wedge y\)。 |
|
逐元素 NOT:\(\neg x\)。 |
|
逐元素 OR:\(x \vee y\)。 |
|
逐元素异或:\(x \oplus y\)。 |
逐元素 population count,计算每个元素中置位比特的数量。 |
|
|
广播数组,添加新的前导维度 |
|
封装 XLA 的 BroadcastInDim 运算符。 |
|
返回 NumPy 广播 shapes 的结果形状。 |
|
添加前导维度 |
|
|
|
逐元素立方根:\(\sqrt[3]{x}\)。 |
|
逐元素向上取整:\(\left\lceil x \right\rceil\)。 |
|
逐元素夹紧。 |
|
逐元素计算前导零的数量。 |
|
将数组的维度折叠成一个维度。 |
|
逐元素构造复数:\(x + jy\)。 |
|
具有由分解函数定义的语义的复合体。 |
|
沿 dimension 连接一系列数组。 |
|
逐元素复共轭函数:\(\overline{x}\)。 |
|
围绕 conv_general_dilated 的便捷包装器。 |
|
逐元素类型转换。 |
|
将卷积 dimension_numbers 转换为 ConvDimensionNumbers。 |
|
通用 N 维卷积运算符,带可选的扩张。 |
|
通用 N 维非共享卷积运算符,带可选的扩张。 |
|
提取受 conv_general_dilated 感受野影响的块。 |
|
计算 N 维卷积“转置”的便利包装器。 |
|
围绕 conv_general_dilated 的便捷包装器。 |
|
逐元素余弦:\(\mathrm{cos}(x)\)。 |
|
逐元素双曲余弦:\(\mathrm{cosh}(x)\)。 |
|
沿 axis 计算累积 logsumexp。 |
|
沿 axis 计算累积最大值。 |
|
沿 axis 计算累积最小值。 |
|
沿 axis 计算累积乘积。 |
|
沿 axis 计算累积和。 |
|
逐元素双伽马函数:\(\psi(x)\)。 |
|
逐元素除法:\(x \over y\)。 |
|
通用点积/收缩运算符。 |
|
|
|
封装 |
|
封装 XLA 的 DynamicSlice 运算符。 |
|
封装 |
|
封装 |
|
封装 XLA 的 DynamicUpdateSlice 运算符。 |
|
封装 |
|
逐元素等于:\(x = y\)。 |
|
逐元素误差函数:\(\mathrm{erf}(x)\)。 |
|
逐元素互补误差函数:\(\mathrm{erfc}(x) = 1 - \mathrm{erf}(x)\)。 |
|
逐元素反误差函数:\(\mathrm{erf}^{-1}(x)\)。 |
|
逐元素指数:\(e^x\)。 |
|
逐元素以 2 为底的指数:\(2^x\)。 |
|
向数组插入任意数量大小为 1 的维度。 |
|
逐元素 \(e^{x} - 1\)。 |
|
|
|
逐元素向下取整:\(\left\lfloor x \right\rfloor\)。 |
|
返回一个填充了 fill_value 的 shape 数组。 |
|
根据示例数组 x 创建一个与 np.full 相似的数组。 |
|
Gather 运算符。 |
|
逐元素大于等于:\(x \geq y\)。 |
|
逐元素大于:\(x > y\)。 |
|
逐元素正则化不完全伽马函数。 |
|
逐元素互补正则化不完全伽马函数。 |
|
逐元素提取虚部:\(\mathrm{Im}(x)\)。 |
|
封装 |
|
|
|
逐元素幂:\(x^y\),其中 \(y\) 是静态整数。 |
|
封装 XLA 的 Iota 运算符。 |
|
逐元素 \(\mathrm{isfinite}\)。 |
|
逐元素小于等于:\(x \leq y\)。 |
|
逐元素对数伽马函数:\(\mathrm{log}(\Gamma(x))\)。 |
|
逐元素自然对数:\(\mathrm{log}(x)\)。 |
|
逐元素 \(\mathrm{log}(1 + x)\)。 |
|
逐元素 Logistic(Sigmoid)函数:\(\frac{1}{1 + e^{-x}}\)。 |
|
逐元素小于:\(x < y\)。 |
|
逐元素最大值:\(\mathrm{max}(x, y)\)。 |
|
逐元素最小值:\(\mathrm{min}(x, y)\) |
|
逐元素乘法:\(x \times y\)。 |
|
逐元素不等于:\(x \neq y\)。 |
|
逐元素取反:\(-x\)。 |
|
返回 |
|
阻止编译器将操作移到屏障之外。 |
|
对数组应用低、高和/或内部填充。 |
|
暂存平台特定的代码。 |
|
逐元素多伽马函数:\(\psi^{(m)}(x)\)。 |
逐元素 population count,计算每个元素中置位比特的数量。 |
|
|
逐元素幂:\(x^y\)。 |
|
|
|
逐元素提取实部:\(\mathrm{Re}(x)\)。 |
|
逐元素倒数:\(1 \over x\)。 |
|
封装 XLA 的 Reduce 运算符。 |
|
计算一个或多个数组轴上元素的按位 AND。 |
|
计算一个或多个数组轴上元素的最大值。 |
|
计算一个或多个数组轴上元素的最小值。 |
|
计算一个或多个数组轴上元素的按位 OR。 |
|
封装 XLA 的 ReducePrecision 运算符。 |
|
计算一个或多个数组轴上元素的乘积。 |
|
计算一个或多个数组轴上元素的和。 |
|
在填充的窗口上进行归约。 |
|
计算一个或多个数组轴上元素的按位 XOR。 |
|
逐元素余数:\(x \bmod y\)。 |
|
封装 XLA 的 Reshape 运算符。 |
|
封装 XLA 的 Rev 运算符。 |
|
无状态 PRNG 位生成器。 |
|
有状态 PRNG 生成器。 |
|
逐元素四舍五入。 |
|
逐元素倒数平方根:\(1 \over \sqrt{x}\)。 |
|
Scatter-update 运算符。 |
|
Scatter-add 运算符。 |
|
Scatter-apply 运算符。 |
|
Scatter-max 运算符。 |
|
Scatter-min 运算符。 |
|
Scatter-multiply 运算符。 |
|
逐元素左移:\(x \ll y\)。 |
|
逐元素算术右移:\(x \gg y\)。 |
|
逐元素逻辑右移:\(x \gg y\)。 |
|
逐元素符号。 |
|
逐元素正弦:\(\mathrm{sin}(x)\)。 |
|
逐元素双曲正弦:\(\mathrm{sinh}(x)\)。 |
|
封装 XLA 的 Slice 运算符。 |
|
封装 |
|
封装 XLA 的 Sort 运算符。 |
|
沿 |
|
沿 |
|
逐元素平方根:\(\sqrt{x}\)。 |
|
逐元素平方:\(x^2\)。 |
|
从数组中挤压掉任意数量的大小为 1 的维度。 |
|
逐元素减法:\(x - y\)。 |
|
逐元素正切:\(\mathrm{tan}(x)\)。 |
|
逐元素双曲正切:\(\mathrm{tanh}(x)\)。 |
|
沿 |
|
封装 XLA 的 Transpose 运算符。 |
|
逐元素赫维茨 zeta 函数:\(\zeta(x, q)\) |
控制流运算符#
|
并行执行具有关联二元运算的扫描。 |
|
根据谓词 |
|
通过归约到 |
|
将函数映射到前导数组轴。 |
|
在携带状态的同时将函数扫描到前导数组轴。 |
|
根据布尔谓词在两个分支之间进行选择。 |
|
从多个案例中选择数组值。 |
|
根据 |
|
当 |
自定义梯度运算符#
停止梯度计算。 |
|
|
执行具有隐式定义的梯度的无矩阵线性求解。 |
|
可微分地求解函数的根。 |
并行运算符#
|
跨所有副本收集 x 的值。 |
|
实现映射轴并映射另一个轴。 |
|
在 pmapped 轴 |
|
类似于 |
|
在 pmapped 轴 |
|
在 pmapped 轴 |
|
在 pmapped 轴 |
|
根据置换 |
|
使用替代置换编码的 jax.lax.ppermute 的便利包装器。 |
|
将 pmapped 轴 |
|
返回沿映射轴 |
|
返回映射轴 |
|
根据置换 |
|
根据置换 |
线性代数运算符 (jax.lax.linalg)#
|
乔列斯基分解。 |
|
乔列斯基秩-1 更新。 |
|
一般矩阵的特征值分解。 |
|
厄米特矩阵的特征值分解。 |
|
将方阵还原为上黑塞堡形式。 |
|
初等豪斯霍尔德反射的乘积。 |
|
带部分主元的 LU 分解。 |
|
将 LU 返回的主元(行交换)转换为置换。 |
|
用于极分解的基于 QR 的动态加权 Halley 迭代。 |
|
QR 分解。 |
|
Schur 分解。 |
|
奇异值分解。 |
|
SVD 算法的枚举。 |
|
对称乘积。 |
|
三角求解。 |
|
将对称/厄米特矩阵还原为三对角线形式。 |
|
计算三对角线线性系统的解。 |
参数类#
- class jax.lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)[source]#
描述卷积的批次、空间和特征维度。
- 参数:
lhs_spec (Sequence[int]) – 一个非负整数维度编号元组,包含 (batch dimension, feature dimension, spatial dimensions…)。
rhs_spec (Sequence[int]) – 一个非负整数维度编号元组,包含 (out feature dimension, in feature dimension, spatial dimensions…)。
out_spec (Sequence[int]) – 一个非负整数维度编号元组,包含 (batch dimension, feature dimension, spatial dimensions…)。
- jax.lax.ConvGeneralDilatedDimensionNumbers#
- class jax.lax.DotAlgorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count=1, rhs_component_count=1, num_primitive_operations=1, allow_imprecise_accumulation=False)[source]#
指定用于计算点积的算法。
当用于指定
dot()、dot_general()和其他点积函数的precision输入时,此数据结构用于控制用于计算点积的算法的属性。此 API 控制用于计算的精度,并允许用户访问特定于硬件的加速。这些算法的支持取决于平台,使用不受支持的算法将在编译计算时引发 Python 异常。在至少一些平台上已知支持的算法列在
DotAlgorithmPreset枚举中,这些是尝试使用此 API 的一个很好的起点。“点算法”由以下参数指定
lhs_precision_type和rhs_precision_type,操作的 LHS 和 RHS 的舍入数据类型。accumulation_type,用于累积的数据类型。lhs_component_count、rhs_component_count和num_primitive_operations适用于将 LHS 和/或 RHS 分解为多个组件并对这些值执行多个操作的算法,通常是为了模拟更高的精度。对于没有分解的算法,这些值应设置为1。allow_imprecise_accumulation,用于指定是否允许在某些步骤中使用较低精度的累积(例如CUBLASLT_MATMUL_DESC_FAST_ACCUM)。
dot 操作的 StableHLO 规范 不需要精度类型与输入或输出的存储类型相同,但某些平台可能要求这些类型匹配。此外,
dot_general()的返回类型始终由输入算法的accumulation_type参数定义(如果已指定)。示例
使用 32 位浮点累加器累加两个 16 位浮点数
>>> algorithm = DotAlgorithm( ... lhs_precision_type=np.float16, ... rhs_precision_type=np.float16, ... accumulation_type=np.float32, ... ) >>> lhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16) >>> rhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16) >>> dot(lhs, rhs, precision=algorithm) array([ 1., 4., 9., 16.], dtype=float16)
或者,等效地,使用预设
>>> algorithm = DotAlgorithmPreset.F16_F16_F32 >>> dot(lhs, rhs, precision=algorithm) array([ 1., 4., 9., 16.], dtype=float16)
还可以通过名称指定预设
>>> dot(lhs, rhs, precision="F16_F16_F32") array([ 1., 4., 9., 16.], dtype=float16)
可以使用
preferred_element_type参数在不向下转换累积类型的情况下返回输出>>> dot(lhs, rhs, precision="F16_F16_F32", preferred_element_type=np.float32) array([ 1., 4., 9., 16.], dtype=float32)
- class jax.lax.DotAlgorithmPreset(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#
已知点积计算算法的枚举。
这个
Enum提供了一组已知的、至少在 平台 上支持的DotAlgorithm对象。有关这些算法行为的更多详细信息,请参阅DotAlgorithm文档。在调用
dot()、dot_general()或大多数其他 JAX 点积函数时,可以通过将此Enum的成员或其名称作为字符串传递给precision参数来从此列表中选择算法。例如,用户可以直接通过此
Enum指定预设>>> lhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16) >>> rhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16) >>> algorithm = DotAlgorithmPreset.F16_F16_F32 >>> dot(lhs, rhs, precision=algorithm) array([ 1., 4., 9., 16.], dtype=float16)
或者,等效地,可以通过名称指定
>>> dot(lhs, rhs, precision="F16_F16_F32") array([ 1., 4., 9., 16.], dtype=float16)
预设的名称通常是
LHS_RHS_ACCUM,其中LHS和RHS分别是lhs和rhs输入的元素类型,而ACCUM是累加器的元素类型。一些预设有一个额外的后缀,这些后缀的含义将在下面文档中说明。支持的预设如下:- DEFAULT = 1#
将根据输入和输出类型选择算法。
- ANY_F8_ANY_F8_F32 = 2#
接受任何 float8 输入类型,并累加到 float32。
- ANY_F8_ANY_F8_F32_FAST_ACCUM = 3#
类似于
ANY_F8_ANY_F8_F32,但使用更快的累加,以牺牲精度为代价。
- ANY_F8_ANY_F8_ANY = 4#
类似于
ANY_F8_ANY_F8_F32,但累加类型由preferred_element_type控制。
- ANY_F8_ANY_F8_ANY_FAST_ACCUM = 5#
类似于
ANY_F8_ANY_F8_F32_FAST_ACCUM,但累加类型由preferred_element_type控制。
- F16_F16_F16 = 6#
- F16_F16_F32 = 7#
- BF16_BF16_BF16 = 8#
- BF16_BF16_F32 = 9#
- BF16_BF16_F32_X3 = 10#
后缀
_X3表示该算法使用 3 次运算来模拟更高的精度。
- BF16_BF16_F32_X6 = 11#
类似于
BF16_BF16_F32_X3,但使用 6 次运算而不是 3 次。
- BF16_BF16_F32_X9 = 12#
类似于
BF16_BF16_F32_X3,但使用 9 次运算而不是 3 次。
- TF32_TF32_F32 = 13#
- TF32_TF32_F32_X3 = 14#
后缀
_X3表示该算法使用 3 次运算来模拟更高的精度。
- F32_F32_F32 = 15#
- F64_F64_F64 = 16#
- class jax.lax.FftType(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#
描述要执行的 FFT 操作。
- FFT = 0#
正向复数到复数 FFT。
- IFFT = 1#
反向复数到复数 FFT。
- IRFFT = 3#
反向实数到复数 FFT。
- RFFT = 2#
正向实数到复数 FFT。
- class jax.lax.GatherDimensionNumbers(offset_dims, collapsed_slice_dims, start_index_map, operand_batching_dims=(), start_indices_batching_dims=())[source]#
描述 XLA 的 Gather 运算符的维度编号参数。有关维度编号含义的更多详细信息,请参阅 XLA 文档。
- 参数:
offset_dims (tuple[int, ...]) – gather 输出中用于偏移从 operand 切片出的数组的维度集合。必须是升序整数元组,每个整数代表输出的维度编号。
collapsed_slice_dims (tuple[int, ...]) – operand 中 slice_sizes[i] == 1 且不应在 gather 输出中具有相应维度的维度 i 的集合。必须是升序整数元组。
start_index_map (tuple[int, ...]) – 对于 start_indices 中的每个维度,给出要切片的 operand 中对应的维度。必须是大小等于 start_indices.shape[-1] 的整数元组。
operand_batching_dims (tuple[int, ...]) – operand 中 slice_sizes[i] == 1 且应在 start_indices(在 start_indices_batching_dims 的相同索引处)和 gather 输出中具有相应维度的批处理维度 i 的集合。必须是升序整数元组。
start_indices_batching_dims (tuple[int, ...]) – start_indices 中应在 operand(在 operand_batching_dims 的相同索引处)和 gather 输出中具有相应维度的批处理维度 i 的集合。必须是整数元组(顺序基于与 operand_batching_dims 的对应关系固定)。
与 XLA 的 GatherDimensionNumbers 结构不同,index_vector_dim 是隐式的;始终存在一个索引向量维度,并且它必须始终是最后一个维度。要收集标量索引,请添加大小为 1 的尾随维度。
- class jax.lax.GatherScatterMode(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#
描述如何处理 gather 或 scatter 中的越界索引。
可能的值是
- CLIP
索引将被裁剪到最接近的范围内值,即,使得要收集的整个窗口都在范围内。
- FILL_OR_DROP
如果收集窗口的任何部分越界,则返回的整个窗口(即使是那些原本在范围内的元素)都将被填充为常量。如果分散窗口的任何部分越界,则整个窗口将被丢弃。
- PROMISE_IN_BOUNDS
用户保证索引在范围内。将不执行额外的检查。实际上,使用当前的 XLA 实现,这意味着越界收集将被裁剪,但越界分散将被丢弃。如果索引越界,梯度将不正确。
- class jax.lax.Precision(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#
用于 lax 矩阵乘法相关函数的精度枚举。
JAX 函数的设备相关 precision 参数通常控制加速器后端(即 TPU 和 GPU)上数组计算的速度和准确性之间的权衡。对 CPU 后端没有影响。这仅对 float32 计算有效,并且不影响输入/输出数据类型。成员是
- DEFAULT
最快的模式,但精度最低。在 TPU 上:以 bfloat16 执行 float32 计算。在 GPU 上:如果可用,则使用 tensorfloat32(例如,在 A100 和 H100 GPU 上),否则使用标准 float32(例如,在 V100 GPU 上)。别名:
'default'、'fastest'。- HIGH
速度较慢但精度更高。在 TPU 上:以 3 次 bfloat16 传递执行 float32 计算。在 GPU 上:在可用时使用 tensorfloat32,否则使用 float32。别名:
'high'。- HIGHEST
最慢但最精确。在 TPU 上:以 6 次 bfloat16 传递执行 float32 计算。别名:
'highest'。在 GPU 上:使用 float32。
- jax.lax.PrecisionLike#
别名:
None|str|Precision|tuple[str,str] |tuple[Precision,Precision] |DotAlgorithm|DotAlgorithmPreset
- class jax.lax.RandomAlgorithm(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#
描述 `rng_bit_generator` 使用的 PRNG 算法。
- RNG_DEFAULT = 0#
平台的默认算法。
- RNG_THREE_FRY = 1#
Threefry-2x32 PRNG 算法。
- RNG_PHILOX = 2#
Philox-4x32 PRNG 算法。
- class jax.lax.RoundingMethod(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#
处理
jax.lax.round()中一半值(例如 0.5)的舍入策略。- AWAY_FROM_ZERO = 0#
将一半值舍入到远离零的方向(例如,0.5 -> 1,-0.5 -> -1)。
- TO_NEAREST_EVEN = 1#
将一半值舍入到最接近的偶数整数。这也被称为“银行家舍入”(例如,0.5 -> 0,1.5 -> 2)。
- class jax.lax.ScatterDimensionNumbers(update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims, operand_batching_dims=(), scatter_indices_batching_dims=())[source]#
描述 XLA 的 Scatter 运算符的维度编号参数。有关维度编号含义的更多详细信息,请参阅 XLA 文档。
- 参数:
update_window_dims (Sequence[int]) – updates 中作为窗口维度的维度集合。必须是升序整数元组,每个整数代表一个维度编号。
inserted_window_dims (Sequence[int]) – 必须插入到 updates 形状中的大小为 1 的窗口维度集合。必须是升序整数元组,每个整数代表输出的维度编号。这些是 gather 的 collapsed_slice_dims 的镜像。
scatter_dims_to_operand_dims (Sequence[int]) – 对于 scatter_indices 中的每个维度,给出 operand 中对应的维度。必须是大小等于 scatter_indices.shape[-1] 的整数序列。
operand_batching_dims (Sequence[int]) – operand 中应在 scatter_indices(在 scatter_indices_batching_dims 的相同索引处)和 updates 中具有相应维度的批处理维度 i 的集合。必须是升序整数元组。这些是 gather 的 operand_batching_dims 的镜像。
scatter_indices_batching_dims (Sequence[int]) – scatter_indices 中应在 operand(在 operand_batching_dims 的相同索引处)和 gather 输出中具有相应维度的批处理维度 i 的集合。必须是整数元组(顺序基于与 input_batching_dims 的对应关系固定)。这些是 gather 的 start_indices_batching_dims 的镜像。
与 XLA 的 ScatterDimensionNumbers 结构不同,index_vector_dim 是隐式的;始终存在一个索引向量维度,并且它必须始终是最后一个维度。要分散标量索引,请添加大小为 1 的尾随维度。