jax.lax
模块#
jax.lax
是一个原始操作库,支持 jax.numpy
等库。转换规则(如 JVP 和批处理规则)通常定义为 jax.lax
原始操作的转换。
许多原始操作是等效 XLA 操作的薄封装,这些操作由 XLA 操作语义 文档描述。在少数情况下,JAX 与 XLA 不同,通常是为了确保操作集在 JVP 和转置规则的操作下是封闭的。
在可能的情况下,优先使用 jax.numpy
等库,而不是直接使用 jax.lax
。jax.numpy
API 遵循 NumPy,因此比 jax.lax
API 更稳定,更不容易改变。
运算符#
|
逐元素的绝对值: \(|x|\)。 |
|
逐元素的反余弦: \(\mathrm{acos}(x)\)。 |
|
逐元素的反双曲余弦: \(\mathrm{acosh}(x)\)。 |
|
逐元素的加法: \(x + y\)。 |
|
合并一个或多个 XLA 令牌值。 |
|
近似返回 |
|
近似返回 |
|
计算沿 |
|
计算沿 |
|
逐元素的反正弦: \(\mathrm{asin}(x)\)。 |
|
逐元素的反双曲正弦: \(\mathrm{asinh}(x)\)。 |
|
逐元素的反正切: \(\mathrm{atan}(x)\)。 |
|
逐元素的双参数反正切: \(\mathrm{atan}({x \over y})\)。 |
|
逐元素的反双曲正切: \(\mathrm{atanh}(x)\)。 |
|
批量矩阵乘法。 |
|
0阶指数缩放修正贝塞尔函数: \(\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)\) |
|
1阶指数缩放修正贝塞尔函数: \(\mathrm{i1e}(x) = e^{-|x|} \mathrm{i1}(x)\) |
|
逐元素的正则不完全贝塔积分。 |
|
逐位转换。 |
|
逐位与: \(x \wedge y\)。 |
|
逐位非: \(\neg x\)。 |
|
逐位或: \(x \vee y\)。 |
|
逐位异或: \(x \oplus y\)。 |
逐元素的位计数,计算每个元素中设置的位数。 |
|
|
广播一个数组,添加新的前导维度 |
|
封装了 XLA 的 BroadcastInDim 运算符。 |
|
返回 NumPy 广播 shapes 后的形状。 |
|
添加前导维度 |
|
|
|
逐元素的立方根: \(\sqrt[3]{x}\)。 |
|
逐元素的上取整: \(\left\lceil x \right\rceil\)。 |
|
逐元素的钳位。 |
|
逐元素的零前缀计数。 |
|
将数组的维度折叠成一个维度。 |
|
逐元素创建复数: \(x + jy\)。 |
|
具有分解函数定义的语义的复合函数。 |
|
沿 dimension 拼接一系列数组。 |
|
逐元素的复共轭函数: \(\overline{x}\)。 |
|
围绕 conv_general_dilated 的便捷包装器。 |
|
逐元素的类型转换。 |
|
将卷积 dimension_numbers 转换为 ConvDimensionNumbers。 |
|
通用 n 维卷积运算符,带可选的膨胀。 |
|
通用 n 维非共享卷积运算符,带可选的膨胀。 |
|
根据 conv_general_dilated 的感受野提取补丁。 |
|
用于计算 N 维卷积“转置”的便捷封装。 |
|
围绕 conv_general_dilated 的便捷包装器。 |
|
逐元素的余弦: \(\mathrm{cos}(x)\)。 |
|
逐元素的双曲余弦: \(\mathrm{cosh}(x)\)。 |
|
计算沿 axis 的累积 logsumexp。 |
|
计算沿 axis 的累积最大值。 |
|
计算沿 axis 的累积最小值。 |
|
计算沿 axis 的累积积。 |
|
计算沿 axis 的累积和。 |
|
逐元素的双伽马函数: \(\psi(x)\)。 |
|
逐元素的除法: \(x \over y\)。 |
|
向量/向量、矩阵/向量和矩阵/矩阵乘法。 |
|
通用点积/收缩运算符。 |
|
围绕 dynamic_slice 执行整数索引的便捷封装。 |
|
封装了 XLA 的 DynamicSlice 运算符。 |
|
应用于单个维度的 |
|
围绕 |
|
封装了 XLA 的 DynamicUpdateSlice 运算符。 |
|
围绕 |
|
逐元素的相等: \(x = y\)。 |
|
逐元素的误差函数: \(\mathrm{erf}(x)\)。 |
|
逐元素的余误差函数: \(\mathrm{erfc}(x) = 1 - \mathrm{erf}(x)\)。 |
|
逐元素的反误差函数: \(\mathrm{erf}^{-1}(x)\)。 |
|
逐元素的指数: \(e^x\)。 |
|
逐元素的以 2 为底的指数: \(2^x\)。 |
|
将任意数量的大小为 1 的维度插入到数组中。 |
|
逐元素的 \(e^{x} - 1\)。 |
|
|
|
逐元素的向下取整: \(\left\lfloor x \right\rfloor\)。 |
|
返回一个填充了 fill_value 的 shape 数组。 |
|
基于示例数组 x 创建一个类似于 np.full 的完整数组。 |
|
收集运算符。 |
|
逐元素的大于或等于: \(x \geq y\)。 |
|
逐元素的大于: \(x > y\)。 |
|
逐元素的正则不完全伽马函数。 |
|
逐元素的余正则不完全伽马函数。 |
|
逐元素提取虚部: \(\mathrm{Im}(x)\)。 |
|
围绕 |
|
|
|
逐元素的幂: \(x^y\),其中 \(y\) 是静态整数。 |
|
封装了 XLA 的 Iota 运算符。 |
|
逐元素的 \(\mathrm{isfinite}\)。 |
|
逐元素的小于或等于: \(x \leq y\)。 |
|
逐元素的对数伽马: \(\mathrm{log}(\Gamma(x))\)。 |
|
逐元素的自然对数: \(\mathrm{log}(x)\)。 |
|
逐元素的 \(\mathrm{log}(1 + x)\)。 |
|
逐元素的逻辑(sigmoid)函数: \(\frac{1}{1 + e^{-x}}\)。 |
|
逐元素的小于: \(x < y\)。 |
|
逐元素的最大值: \(\mathrm{max}(x, y)\)。 |
|
逐元素的最小值: \(\mathrm{min}(x, y)\) |
|
逐元素的乘法: \(x \times y\)。 |
|
逐元素的不等于: \(x \neq y\)。 |
|
逐元素的取反: \(-x\)。 |
|
返回 |
|
阻止编译器在障碍物之间移动操作。 |
|
对数组应用低、高和/或内部填充。 |
|
分阶段地输出平台相关的代码。 |
|
逐多元伽马函数: \(\psi^{(m)}(x)\)。 |
逐元素的位计数,计算每个元素中设置的位数。 |
|
|
逐元素的幂: \(x^y\)。 |
|
|
|
逐元素提取实部: \(\mathrm{Re}(x)\)。 |
|
逐元素的倒数: \(1 \over x\)。 |
|
封装了 XLA 的 Reduce 运算符。 |
|
计算一个或多个数组轴上元素的按位与。 |
|
计算一个或多个数组轴上元素的最大值。 |
|
计算一个或多个数组轴上元素的最小值。 |
|
计算一个或多个数组轴上元素的按位或。 |
|
封装了 XLA 的 ReducePrecision 运算符。 |
|
计算一个或多个数组轴上元素的积。 |
|
计算一个或多个数组轴上元素的和。 |
|
填充窗口上的归约。 |
|
计算一个或多个数组轴上元素的按位异或。 |
|
逐元素的余数: \(x \bmod y\)。 |
|
封装了 XLA 的 Reshape 运算符。 |
|
封装了 XLA 的 Rev 运算符。 |
|
无状态 PRNG 位生成器。 |
|
有状态 PRNG 生成器。 |
|
逐元素的四舍五入。 |
|
逐元素的平方根倒数: \(1 \over \sqrt{x}\)。 |
|
散布更新运算符。 |
|
散布加法运算符。 |
|
散布应用运算符。 |
|
散布最大值运算符。 |
|
散布最小值运算符。 |
|
散布乘法运算符。 |
|
逐元素的左移: \(x \ll y\)。 |
|
逐元素的算术右移: \(x \gg y\)。 |
|
逐元素的逻辑右移: \(x \gg y\)。 |
|
逐元素的符号。 |
|
逐元素的正弦: \(\mathrm{sin}(x)\)。 |
|
逐元素的双曲正弦: \(\mathrm{sinh}(x)\)。 |
|
封装了 XLA 的 Slice 运算符。 |
|
应用于单个维度的 |
|
封装了 XLA 的 Sort 运算符。 |
|
沿 |
|
沿 |
|
逐元素的平方根: \(\sqrt{x}\)。 |
|
逐元素的平方: \(x^2\)。 |
|
从数组中挤压掉任意数量的大小为 1 的维度。 |
|
逐元素的减法: \(x - y\)。 |
|
逐元素的正切: \(\mathrm{tan}(x)\)。 |
|
逐元素的双曲正切: \(\mathrm{tanh}(x)\)。 |
|
返回 |
|
封装了 XLA 的 Transpose 运算符。 |
|
逐元素的赫尔维茨 zeta 函数: \(\zeta(x, q)\) |
控制流运算符#
|
并行地使用结合二元操作执行扫描。 |
|
有条件地应用 |
|
通过归约为 |
|
将函数映射到前导数组轴上。 |
|
在携带状态的同时,在前导数组轴上扫描函数。 |
|
根据布尔谓词在两个分支之间进行选择。 |
|
从多个案例中选择数组值。 |
|
应用由 |
|
当 |
自定义梯度运算符#
停止梯度计算。 |
|
|
执行带有隐式定义梯度的无矩阵线性求解。 |
|
可微分地求解函数的根。 |
并行运算符#
|
在所有副本中收集 x 的值。 |
|
具体化映射轴并映射不同的轴。 |
|
在 pmap 轴 |
|
类似于 |
|
在 pmap 轴 |
|
在 pmap 轴 |
|
在 pmap 轴 |
|
根据置换 |
|
jax.lax.ppermute 的便捷封装,采用替代置换编码。 |
|
将 pmap 轴 |
|
返回映射轴 |
|
返回映射轴 |
|
根据置换 |
|
根据置换 |
线性代数运算符 (jax.lax.linalg)#
|
Cholesky 分解。 |
|
Cholesky 秩 1 更新。 |
|
一般矩阵的特征分解。 |
|
Hermitian 矩阵的特征分解。 |
|
将方阵约化为上 Hessenberg 形式。 |
|
基本 Householder 反射的乘积。 |
|
带部分选主元的 LU 分解。 |
|
将 LU 返回的枢轴(行交换)转换为置换。 |
|
用于极分解的基于 QR 的动态加权 Halley 迭代。 |
|
QR 分解。 |
|
Schur 分解。 |
|
奇异值分解。 |
|
SVD 算法的枚举。 |
|
对称积。 |
|
三角求解。 |
|
将对称/ Hermitian 矩阵约化为三对角形式。 |
|
计算三对角线性系统的解。 |
参数类#
- jax.lax.ConvGeneralDilatedDimensionNumbers#
- class jax.lax.DotAlgorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count=1, rhs_component_count=1, num_primitive_operations=1, allow_imprecise_accumulation=False)[source]#
指定用于计算点积的算法。
当用于指定
precision
输入到dot()
、dot_general()
和其他点积函数时,此数据结构用于控制计算点积所用算法的属性。此 API 控制计算中使用的精度,并允许用户访问硬件特定的加速功能。对这些算法的支持依赖于平台,使用不支持的算法在计算编译时会引发 Python 异常。已知在至少某些平台上支持的算法列在
DotAlgorithmPreset
枚举中,这些是尝试使用此 API 的良好起点。“点算法”由以下参数指定:
lhs_precision_type
和rhs_precision_type
,操作的左侧(LHS)和右侧(RHS)四舍五入到的数据类型。accumulation_type
用于累积的数据类型。lhs_component_count
、rhs_component_count
和num_primitive_operations
适用于将操作的左侧(LHS)和/或右侧(RHS)分解为多个组件,并对这些值执行多次操作的算法,通常是为了模拟更高的精度。对于没有分解的算法,这些值应设置为1
。allow_imprecise_accumulation
用于指定某些步骤是否允许使用较低精度进行累积(例如CUBLASLT_MATMUL_DESC_FAST_ACCUM
)。
点操作的 StableHLO 规范 不要求精度类型与输入或输出的存储类型相同,但某些平台可能要求这些类型匹配。此外,如果指定了输入算法的
accumulation_type
参数,dot_general()
的返回类型总是由该参数定义。示例
使用 32 位浮点累加器累积两个 16 位浮点数
>>> algorithm = DotAlgorithm( ... lhs_precision_type=np.float16, ... rhs_precision_type=np.float16, ... accumulation_type=np.float32, ... ) >>> lhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16) >>> rhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16) >>> dot(lhs, rhs, precision=algorithm) array([ 1., 4., 9., 16.], dtype=float16)
或者,等效地,使用预设
>>> algorithm = DotAlgorithmPreset.F16_F16_F32 >>> dot(lhs, rhs, precision=algorithm) array([ 1., 4., 9., 16.], dtype=float16)
预设也可以通过名称指定
>>> dot(lhs, rhs, precision="F16_F16_F32") array([ 1., 4., 9., 16.], dtype=float16)
参数
preferred_element_type
可用于在不降级累积类型的情况下返回输出。>>> dot(lhs, rhs, precision="F16_F16_F32", preferred_element_type=np.float32) array([ 1., 4., 9., 16.], dtype=float32)
- class jax.lax.DotAlgorithmPreset(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#
一个用于计算点积的已知算法的枚举。
此
Enum
提供了一组命名的DotAlgorithm
对象,已知这些对象至少在某个平台上受支持。有关这些算法行为的更多详细信息,请参阅DotAlgorithm
文档。通过将此
Enum
的成员或其名称字符串作为precision
参数传递,可以在调用dot()
、dot_general()
或大多数其他 JAX 点积函数时,从此列表中选择一个算法。例如,用户可以直接使用此
Enum
指定预设>>> lhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16) >>> rhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16) >>> algorithm = DotAlgorithmPreset.F16_F16_F32 >>> dot(lhs, rhs, precision=algorithm) array([ 1., 4., 9., 16.], dtype=float16)
或者,等效地,它们也可以通过名称指定
>>> dot(lhs, rhs, precision="F16_F16_F32") array([ 1., 4., 9., 16.], dtype=float16)
预设的名称通常是
LHS_RHS_ACCUM
,其中LHS
和RHS
分别是lhs
和rhs
输入的元素类型,ACCUM
是累加器的元素类型。一些预设有一个额外的后缀,它们的含义在下面有说明。支持的预设包括:- DEFAULT = 1#
将根据输入和输出类型选择算法。
- ANY_F8_ANY_F8_F32 = 2#
接受任意 float8 输入类型并累积到 float32。
- ANY_F8_ANY_F8_F32_FAST_ACCUM = 3#
与
ANY_F8_ANY_F8_F32
类似,但使用更快的累积,代价是降低精度。
- ANY_F8_ANY_F8_ANY = 4#
与
ANY_F8_ANY_F8_F32
类似,但累积类型由preferred_element_type
控制。
- ANY_F8_ANY_F8_ANY_FAST_ACCUM = 5#
与
ANY_F8_ANY_F8_F32_FAST_ACCUM
类似,但累积类型由preferred_element_type
控制。
- F16_F16_F16 = 6#
- F16_F16_F32 = 7#
- BF16_BF16_BF16 = 8#
- BF16_BF16_F32 = 9#
- BF16_BF16_F32_X3 = 10#
后缀
_X3
表示算法使用 3 次操作来模拟更高的精度。
- BF16_BF16_F32_X6 = 11#
与
BF16_BF16_F32_X3
类似,但使用 6 次操作而非 3 次。
- BF16_BF16_F32_X9 = 12#
与
BF16_BF16_F32_X3
类似,但使用 9 次操作而非 3 次。
- TF32_TF32_F32 = 13#
- TF32_TF32_F32_X3 = 14#
后缀
_X3
表示算法使用 3 次操作来模拟更高的精度。
- F32_F32_F32 = 15#
- F64_F64_F64 = 16#
- class jax.lax.FftType(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#
描述要执行的 FFT 操作类型。
- FFT = 0#
正向复数到复数 FFT。
- IFFT = 1#
逆向复数到复数 FFT。
- IRFFT = 3#
逆向实数到复数 FFT。
- RFFT = 2#
正向实数到复数 FFT。
- class jax.lax.GatherDimensionNumbers(offset_dims, collapsed_slice_dims, start_index_map, operand_batching_dims=(), start_indices_batching_dims=())[source]#
描述 XLA Gather 算子的维度编号参数。有关维度编号含义的更多详细信息,请参阅 XLA 文档。
- 参数:
offset_dims (tuple[int, ...]) – “gather”输出中偏移到从“operand”切片数组中的维度集合。必须是按升序排列的整数元组,每个整数代表输出的一个维度编号。
collapsed_slice_dims (tuple[int, ...]) – “operand”中维度 i 的集合,其中 slice_sizes[i] == 1 并且在 gather 的输出中不应有对应的维度。必须是按升序排列的整数元组。
start_index_map (tuple[int, ...]) – 对于 start_indices 中的每个维度,给出 operand 中要进行切片的相应维度。必须是一个整数元组,其大小等于 start_indices.shape[-1]。
operand_batching_dims (tuple[int, ...]) – “operand”中批处理维度 i 的集合,其中 slice_sizes[i] == 1,并且在 start_indices(在 start_indices_batching_dims 中的相同索引处)和 gather 的输出中都应具有对应的维度。必须是按升序排列的整数元组。
start_indices_batching_dims (tuple[int, ...]) – “start_indices”中批处理维度 i 的集合,该集合在 operand(在 operand_batching_dims 中的相同索引处)和 gather 的输出中都应具有对应的维度。必须是整数元组(顺序根据与 operand_batching_dims 的对应关系固定)。
与 XLA 的 GatherDimensionNumbers 结构不同,index_vector_dim 是隐式的;总是存在一个索引向量维度,并且它必须始终是最后一个维度。要收集标量索引,请添加一个大小为 1 的末尾维度。
- class jax.lax.GatherScatterMode(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#
描述如何在 gather 或 scatter 操作中处理越界索引。
可能的值包括:
- CLIP
索引将被限制到最近的在范围内值,即,使得要收集的整个窗口都在范围内。
- FILL_OR_DROP
如果收集窗口的任何部分越界,则返回的整个窗口(即使是那些原本在范围内的元素)都将用常量填充。如果散射窗口的任何部分越界,则整个窗口将被丢弃。
- PROMISE_IN_BOUNDS
用户承诺索引在界限内。将不会执行额外检查。实际上,在当前的 XLA 实现中,这意味着越界 gather 操作将被钳制,但越界 scatter 操作将被丢弃。如果索引越界,梯度将不正确。
- class jax.lax.Precision(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#
lax 矩阵乘法相关函数的精度枚举。
JAX 函数中与设备相关的 precision 参数通常控制加速器后端(即 TPU 和 GPU)上数组计算的速度和精度之间的权衡。对 CPU 后端没有影响。这仅对 float32 计算有效,并且不影响输入/输出数据类型。成员包括:
- DEFAULT
最快模式,但精度最低。在 TPU 上:以 bfloat16 执行 float32 计算。在 GPU 上:如果可用,则使用 tensorfloat32(例如在 A100 和 H100 GPU 上),否则使用标准 float32(例如在 V100 GPU 上)。别名:
'default'
、'fastest'
。- HIGH
较慢但更精确。在 TPU 上:以 3 次 bfloat16 通道执行 float32 计算。在 GPU 上:如果可用,则使用 tensorfloat32,否则使用 float32。别名:
'high'
。- HIGHEST
最慢但最精确。在 TPU 上:以 6 次 bfloat16 执行 float32 计算。别名:
'highest'
。在 GPU 上:使用 float32。
- jax.lax.PrecisionLike#
别名
None
|str
|Precision
|tuple
[str
,str
] |tuple
[Precision
,Precision
] |DotAlgorithm
|DotAlgorithmPreset
- class jax.lax.RandomAlgorithm(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#
描述用于 rng_bit_generator 的 PRNG 算法。
- RNG_DEFAULT = 0#
平台的默认算法。
- RNG_THREE_FRY = 1#
Threefry-2x32 PRNG 算法。
- RNG_PHILOX = 2#
Philox-4x32 PRNG 算法。
- class jax.lax.RoundingMethod(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#
处理
jax.lax.round()
中半值(例如 0.5)的舍入策略。- AWAY_FROM_ZERO = 0#
将半值远离零舍入(例如,0.5 -> 1,-0.5 -> -1)。
- TO_NEAREST_EVEN = 1#
将半值舍入到最近的偶数整数。这也被称为“银行家舍入法”(例如,0.5 -> 0,1.5 -> 2)。
- class jax.lax.ScatterDimensionNumbers(update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims, operand_batching_dims=(), scatter_indices_batching_dims=())[source]#
描述 XLA Scatter 算子的维度编号参数。有关维度编号含义的更多详细信息,请参阅 XLA 文档。
- 参数:
update_window_dims (序列[int]) – “updates”中作为窗口维度的维度集合。必须是按升序排列的整数元组,每个整数代表一个维度编号。
inserted_window_dims (序列[int]) – 必须插入到 updates 形状中的大小为 1 的窗口维度集合。必须是按升序排列的整数元组,每个整数代表输出的一个维度编号。它们是 gather 操作中 collapsed_slice_dims 的镜像。
scatter_dims_to_operand_dims (序列[int]) – 对于 scatter_indices 中的每个维度,给出 operand 中对应的维度。必须是整数序列,其大小等于 scatter_indices.shape[-1]。
operand_batching_dims (序列[int]) – “operand”中批处理维度 i 的集合,该集合在 scatter_indices(在 scatter_indices_batching_dims 中的相同索引处)和 updates 中都应具有对应的维度。必须是按升序排列的整数元组。它们是 gather 操作中 operand_batching_dims 的镜像。
scatter_indices_batching_dims (序列[int]) – “scatter_indices”中批处理维度 i 的集合,该集合在 operand(在 operand_batching_dims 中的相同索引处)和 gather 的输出中都应具有对应的维度。必须是整数元组(顺序根据与 input_batching_dims 的对应关系固定)。它们是 gather 操作中 start_indices_batching_dims 的镜像。
与 XLA 的 ScatterDimensionNumbers 结构不同,index_vector_dim 是隐式的;总是存在一个索引向量维度,并且它必须始终是最后一个维度。要散射标量索引,请添加一个大小为 1 的末尾维度。