jax.numpy 模块#
使用 jax.lax 中的原始操作实现 NumPy API。
虽然 JAX 尽可能地遵循 NumPy API,但有时 JAX 无法完全遵循 NumPy。
值得注意的是,由于 JAX 数组是不可变的,因此无法在 JAX 中实现原地修改数组的 NumPy API。然而,JAX 通常能够提供一个纯函数式的替代 API。例如,JAX 提供了一个纯函数式的索引更新函数
x.at[i].set(y)来替代原地数组更新(x[i] = y)(参见ndarray.at)。同样地,一些 NumPy 函数在可能的情况下会返回数组的视图(例如
transpose()和reshape())。这些函数的 JAX 版本将返回副本,尽管使用jax.jit()编译一系列操作时,这些副本通常会被 XLA 优化掉。NumPy 在将值提升为
float64类型方面非常积极。JAX 有时在类型提升方面不太积极(参见 类型提升语义)。一些 NumPy 例程具有依赖数据的输出形状(例如
unique()和nonzero())。由于 XLA 编译器要求在编译时已知数组形状,因此此类操作与 JIT 不兼容。因此,JAX 在此类函数中添加了一个可选的size参数,可以在静态指定该参数以便与 JIT 一起使用。
几乎所有适用的 NumPy 函数都已在 jax.numpy 命名空间中实现;它们列在下面。
用于索引更新功能的辅助属性。 |
|
|
是 |
|
逐元素计算绝对值。 |
|
是 |
|
是 |
逐元素相加两个数组。 |
|
|
测试沿给定轴的所有数组元素是否评估为 True。 |
|
在容差范围内检查两个数组是否逐元素近似相等。 |
|
是 |
|
是 |
|
返回复数或复数数组的角度。 |
|
测试数组沿给定轴的任何元素是否为 True。 |
|
返回一个新数组,将值附加到原始数组的末尾。 |
|
沿轴将函数应用于一维数组切片。 |
|
反复将函数应用于指定的轴。 |
|
创建等间距值的数组。 |
|
计算输入数组的三角余弦的逐元素反正值。 |
|
计算输入数组的双曲余弦的逐元素反双曲值。 |
|
计算输入数组的三角正弦的逐元素反正值。 |
|
计算输入数组的双曲正弦的逐元素反双曲值。 |
|
计算输入数组的三角正切的逐元素反正值。 |
|
计算 x1/x2 的反正切,并选择正确的象限。 |
|
计算输入数组的双曲正切的逐元素反双曲值。 |
|
返回数组最大值的索引。 |
|
返回数组最小值的索引。 |
|
返回部分排序数组的索引。 |
|
返回排序数组的索引。 |
|
查找非零数组元素的索引。 |
|
是 |
|
将对象转换为 JAX 数组。 |
|
检查两个数组是否逐元素相等。 |
|
检查两个数组是否逐元素相等。 |
|
返回数组的字符串表示。 |
|
将数组分割成子数组。 |
|
返回数组中数据的字符串表示。 |
|
将对象转换为 JAX 数组。 |
|
是 |
|
是 |
|
将数组转换为指定的数据类型。 |
|
是 |
|
是 |
|
是 |
|
将输入转换为至少具有一维的数组。 |
|
将输入转换为至少具有二维的数组。 |
|
将输入转换为至少具有三维的数组。 |
|
计算加权平均值。 |
|
返回大小为 M 的 Bartlett 窗口。 |
|
计算整数数组中每个值出现的次数。 |
逐元素计算按位 AND 操作。 |
|
|
计算 |
|
是 |
|
是 |
|
是 |
逐元素计算按位 OR 操作。 |
|
|
是 |
逐元素计算按位 XOR 操作。 |
|
|
返回大小为 M 的 Blackman 窗口。 |
|
从块列表中创建数组。 |
是 |
|
|
将数组广播到公共形状。 |
|
将输入形状广播到公共输出形状。 |
|
将数组广播到指定的形状。 |
沿最后一个轴连接切片、标量和类数组对象。 |
|
|
根据转换规则,返回数据类型之间是否可以发生转换。 |
|
计算输入数组的逐元素立方根。 |
是 |
|
|
向上取整到最近的整数。 |
所有字符字符串标量类型的抽象基类。 |
|
|
通过堆叠选择数组的切片来构造数组。 |
|
将数组值裁剪到指定的范围。 |
|
按列堆叠数组。 |
是 |
|
|
complex128 类型的 JAX 标量构造函数。 |
|
complex64 类型的 JAX 标量构造函数。 |
由浮点数组成的所有复数标量类型的抽象基类。 |
|
将复数类型转换为实数类型时引发的警告。 |
|
|
使用布尔条件沿给定轴压缩数组。 |
|
沿现有轴连接数组。 |
|
沿现有轴连接数组。 |
|
是 |
|
返回输入数组的逐元素复共轭。 |
|
两个一维数组的卷积。 |
|
返回数组的副本。 |
|
将 |
|
计算皮尔逊相关系数。 |
|
两个一维数组的互相关。 |
|
计算输入数组的逐元素三角余弦。 |
|
计算输入数组的双曲余弦的逐元素值。 |
|
返回沿给定轴的非零元素的数量。 |
|
估计加权样本协方差。 |
|
计算两个数组的(批量)叉积。 |
是 |
|
|
沿轴的元素的累积乘积。 |
|
沿轴的元素的累积和。 |
|
沿数组轴的累积乘积。 |
|
沿数组轴的累积和。 |
|
将角度从度转换为弧度。 |
|
是 |
|
从数组中删除条目。 |
|
返回指定的对角线或构造对角线数组。 |
|
返回访问多维数组主对角线的索引。 |
|
返回访问给定数组主对角线的索引。 |
|
返回一个二维数组,其中展平的输入数组沿对角线排列。 |
|
返回数组的指定对角线。 |
|
沿给定轴计算数组元素之间的 n 阶差值。 |
|
将数组转换为 bin 索引。 |
|
是 |
|
逐元素计算 x1 除以 x2 的整数商和余数。 |
|
计算两个数组的点积。 |
是 |
|
|
按深度分割数组。 |
|
按深度堆叠数组。 |
|
创建数据类型对象。 |
|
计算展平数组元素的差值。 |
|
爱因斯坦求和。 |
|
在不计算 einsum 的情况下评估最优收缩路径。 |
|
创建空数组。 |
|
创建具有与数组相同形状和数据类型的空数组。 |
|
返回 |
|
计算输入数组的逐元素指数。 |
|
计算输入数组的逐元素以 2 为底的指数。 |
|
在数组中插入长度为 1 的维度。 |
|
计算输入数组的逐元素 |
|
返回满足条件的数组元素。 |
|
创建方阵或长方形单位矩阵。 |
|
计算实值输入数组的逐元素绝对值。 |
|
返回一个对角线被覆盖的数组的副本。 |
|
浮点类型的机器限制。 |
|
将输入向零舍入到最接近的整数。 |
|
返回展平数组中非零元素的索引。 |
|
没有预定义长度的所有标量类型的抽象基类。 |
|
沿给定轴反转数组元素的顺序。 |
|
沿轴 1 反转数组元素的顺序。 |
|
沿轴 0 反转数组元素的顺序。 |
是 |
|
|
计算以 |
|
float16 类型的 JAX 标量构造函数。 |
|
float32 类型的 JAX 标量构造函数。 |
|
float64 类型的 JAX 标量构造函数。 |
|
所有浮点数标量类型的抽象基类。 |
|
向下取整到最近的整数。 |
|
逐元素计算 x1 除以 x2 的地板除法。 |
|
返回输入数组的逐元素最大值。 |
|
返回输入数组的逐元素最小值。 |
|
逐元素计算浮点数取模操作。 |
|
将浮点数值分解为尾数和 2 的指数。 |
|
将缓冲区转换为一维 JAX 数组。 |
|
jnp.fromfile 的未实现 JAX 包装器。 |
|
通过对索引应用函数来创建数组。 |
|
jnp.fromiter 的未实现 JAX 包装器。 |
|
从任意 JAX 兼容的标量函数创建 JAX ufunc。 |
|
将文本字符串转换为一维 JAX 数组。 |
|
通过 DLPack 构建 JAX 数组。 |
|
创建充满指定值的数组。 |
|
创建具有与数组相同形状和数据类型的、充满指定值的数组。 |
|
计算两个数组的最大公约数。 |
|
NumPy 标量类型的基类。 |
|
生成几何间隔的值。 |
是 |
|
|
计算采样函数的数值梯度。 |
|
返回 |
|
返回 |
|
返回大小为 M 的 Hamming 窗口。 |
|
返回大小为 M 的 Hanning 窗口。 |
|
计算阶跃函数(Heaviside 函数)。 |
|
计算一维直方图。 |
|
计算直方图的 bin 边缘。 |
|
计算二维直方图。 |
|
计算 N 维直方图。 |
|
将数组水平分割成子数组。 |
|
水平堆叠数组。 |
|
对于给定直角三角形的两个直角边,返回逐元素的斜边长。 |
|
计算第一类零阶修正贝塞尔函数。 |
|
创建方阵单位矩阵。 |
|
|
|
返回复数参数的逐元素虚部。 |
一种更方便地为数组构建索引元组的方法。 |
|
|
生成网格索引数组。 |
|
在其范围内可能包含不精确表示值的数字标量类型的抽象基类,例如浮点数。 |
|
计算两个数组的内积。 |
|
在指定索引处将条目插入数组。 |
是 |
|
|
int16 类型的 JAX 标量构造函数。 |
|
int32 类型的 JAX 标量构造函数。 |
|
int64 类型的 JAX 标量构造函数。 |
|
int8 类型的 JAX 标量构造函数。 |
|
所有整数标量类型的抽象基类。 |
|
一维线性插值。 |
|
计算两个一维数组的集合交集。 |
|
计算输入数组的按位取反。 |
|
在容差范围内检查两个数组的元素是否近似相等。 |
|
返回一个布尔数组,显示输入是复数的位置。 |
|
检查输入是否为复数或包含复数元素的数组。 |
|
返回一个布尔值,指示提供的 dtype 是否属于指定种类。 |
|
返回一个布尔数组,指示输入中的每个元素是否为有限值。 |
|
确定 |
|
返回一个布尔数组,指示输入中的每个元素是否为无穷大。 |
|
返回一个布尔数组,指示输入中的每个元素是否为 |
|
返回一个布尔数组,指示输入中的每个元素是否为负无穷大。 |
|
返回一个布尔数组,指示输入中的每个元素是否为正无穷大。 |
|
返回一个布尔数组,显示输入是实数的位置。 |
|
检查输入是否不是复数或包含复数元素的数组。 |
|
如果输入是标量,则返回 True。 |
|
如果 arg1 在类型层次结构中等于或低于 arg2,则返回 True。 |
|
检查一个对象是否可以迭代。 |
|
从 N 个一维序列返回多维网格(开放网格)。 |
|
返回大小为 M 的 Kaiser 窗口。 |
|
计算两个输入数组的 Kronecker 乘积。 |
|
计算两个数组的最小公倍数。 |
|
计算 |
|
将 |
|
返回 |
|
返回 |
|
按字典序对键序列进行排序。 |
|
在区间内返回等间距的数字。 |
|
从 npy 文件加载 JAX 数组。 |
|
计算输入数组的逐元素自然对数。 |
|
逐元素计算 x 的以 10 为底的对数。 |
|
逐元素计算 |
|
逐元素计算 |
计算 |
|
以 2 为底,计算 |
|
逐元素计算逻辑 AND 操作。 |
|
|
逐元素计算逻辑 NOT bool(x)。 |
逐元素计算逻辑 OR 操作。 |
|
逐元素计算逻辑 XOR 操作。 |
|
|
生成对数间隔的值。 |
|
返回 (n, n) 数组掩码的索引。 |
|
执行矩阵乘法。 |
|
转置数组的最后两个维度。 |
|
批量矩阵-向量乘积。 |
|
返回数组元素沿给定轴的最大值。 |
返回输入数组的逐元素最大值。 |
|
|
返回给定轴上数组元素的平均值。 |
|
返回数组元素沿给定轴的中位数。 |
|
从 N 个一维向量构建 N 维网格数组。 |
返回密集的多维“网格”。 |
|
|
返回给定轴上数组元素的最小值。 |
返回输入数组的逐元素最小值。 |
|
|
是 |
|
返回输入数组的逐元素小数部分和整数部分。 |
|
将数组轴移动到新位置。 |
逐元素相乘两个数组。 |
|
|
替换数组中的 NaN 和无穷大条目。 |
|
返回数组最大值的索引,忽略 NaN。 |
|
返回数组最小值的索引,忽略 NaN。 |
|
沿轴的累积乘积,忽略 NaN 值。 |
|
沿轴的累积和,忽略 NaN 值。 |
|
返回数组元素沿给定轴的最大值,忽略 NaN。 |
|
返回数组元素沿给定轴的平均值,忽略 NaN。 |
|
返回数组元素沿给定轴的中位数,忽略 NaN。 |
|
返回数组元素沿给定轴的最小值,忽略 NaN。 |
|
沿指定轴计算数据的百分位数,忽略 NaN 值。 |
|
返回数组元素沿给定轴的乘积,忽略 NaN。 |
|
沿指定轴计算数据的分位数,忽略 NaN 值。 |
|
沿给定轴计算标准差,忽略 NaN。 |
|
返回数组元素沿给定轴的和,忽略 NaN。 |
|
沿给定轴计算方差,忽略 NaN。 |
是 |
|
|
返回数组的维度数。 |
返回输入数组的逐元素负值。 |
|
|
返回 |
|
返回数组中非零元素的索引。 |
|
返回 |
|
所有数字标量类型的抽象基类。 |
任何 Python 对象。 |
|
返回开放的多维“网格”。 |
|
|
创建充满 1 的数组。 |
|
创建具有与数组相同形状和数据类型的、充满 1 的数组。 |
|
计算两个数组的外积。 |
|
将位数组打包成 uint8 数组。 |
|
向数组添加填充。 |
|
返回数组的部分排序副本。 |
|
沿指定轴计算数据的百分位数。 |
|
置换数组的轴/维度。 |
|
在域上分段求值函数。 |
|
根据掩码更新数组元素。 |
|
返回给定根序列的多项式系数。 |
|
返回两个多项式的和。 |
|
返回指定阶数多项式导数的系数。 |
|
返回多项式除法的商和余数。 |
|
最小二乘多项式拟合数据。 |
|
返回指定阶数多项式积分的系数。 |
|
返回两个多项式的乘积。 |
|
返回两个多项式的差。 |
|
在特定值处评估多项式。 |
|
返回输入的逐元素正值。 |
|
是 |
|
计算 |
|
是 |
|
返回给定轴上数组元素的乘积。 |
|
返回二进制运算应将参数强制转换为的类型。 |
|
返回沿给定轴的峰峰值范围。 |
|
将元素按指定索引放入数组。 |
|
通过匹配一维索引和数据切片将值放入目标数组。 |
|
计算指定轴上数据的分位数。 |
沿第一个轴连接切片、标量和类数组对象。 |
|
|
将弧度转换为度。 |
|
是 |
|
将数组展平为一维形状。 |
|
将多维索引转换为扁平索引。 |
|
返回复数参数的逐元素实部。 |
|
计算输入的逐元素倒数。 |
|
返回除法的逐元素余数。 |
|
从重复元素构造数组。 |
|
返回数组的重塑副本。 |
|
返回具有指定形状的新数组。 |
|
返回将 JAX 提升规则应用于输入的类型。 |
|
将 |
|
将 x 的元素四舍五入到最接近的整数 |
|
沿指定轴滚动数组的元素。 |
|
将指定轴滚动到给定位置。 |
|
返回给定系数 |
|
在指定的坐标轴所定义的平面内,将数组逆时针旋转 90 度。 |
|
将输入四舍五入到指定的小数位数。 |
一种更方便地为数组构建索引元组的方法。 |
|
|
将数组保存到 NumPy |
|
将多个数组保存到未压缩的 |
|
在已排序数组中执行二分查找。 |
|
根据一系列条件选择值。 |
|
是 |
|
计算两个一维数组的集合差集。 |
|
计算两个数组中元素的集合异或。 |
|
返回数组的形状。 |
|
返回输入的逐元素符号指示。 |
|
返回数组元素的符号位。 |
所有有符号整数标量类型的抽象基类。 |
|
|
计算输入的每个元素的三角正弦。 |
|
计算归一化 sinc 函数。 |
alias of |
|
|
计算输入的逐元素双曲正弦。 |
|
返回沿给定轴的元素数量。 |
|
返回数组的排序副本。 |
|
返回复数数组的排序副本。 |
|
返回 |
|
将数组分割成子数组。 |
|
计算数组的逐元素非负平方根。 |
|
计算数组的逐元素平方。 |
|
移除数组的一个或多个长度为 1 的轴 |
|
沿新轴连接数组。 |
|
计算沿给定轴的标准差。 |
逐元素减去两个数组。 |
|
|
在给定轴上对数组元素求和。 |
|
交换数组的两个轴。 |
|
从数组中选取元素。 |
|
从数组中选取元素。 |
|
计算输入的每个元素的三角正切。 |
|
计算输入的逐元素双曲正切。 |
|
计算两个 N 维数组的张量点积。 |
|
沿指定维度重复 |
|
计算沿给定轴的输入的对角线之和。 |
|
使用复合梯形法则沿给定轴积分。 |
|
返回 N 维数组的转置版本。 |
|
返回对角线及其以下为 1,其他地方为 0 的数组。 |
|
返回数组的下三角部分。 |
|
返回大小为 |
|
返回给定数组的下三角索引。 |
|
修剪输入数组的前导和/或尾随零。 |
|
返回数组的上三角部分。 |
|
返回大小为 |
|
返回给定数组的上三角索引。 |
|
逐元素计算 x1 除以 x2 的结果 |
|
将输入向零舍入到最接近的整数。 |
|
对数组进行逐元素操作的通用函数。 |
alias of |
|
|
JAX uint16 类型的标量构造函数。 |
|
JAX uint32 类型的标量构造函数。 |
|
JAX uint64 类型的标量构造函数。 |
|
JAX uint8 类型的标量构造函数。 |
|
计算两个一维数组的集合并集。 |
|
返回数组中的唯一值。 |
|
从 x 返回唯一值,以及索引、逆索引和计数。 |
|
返回 x 中的唯一值以及它们的计数。 |
|
从 x 返回唯一值,以及索引、逆索引和计数。 |
|
从 x 返回唯一值,以及索引、逆索引和计数。 |
|
解包 uint8 数组中的位。 |
|
将扁平索引转换为多维索引。 |
|
沿轴解堆栈数组。 |
所有无符号整数标量类型的抽象基类。 |
|
|
解开周期性信号。 |
|
生成范德蒙德矩阵。 |
|
计算给定轴上的方差。 |
|
执行两个一维向量的共轭乘积。 |
|
执行两个批量向量的共轭乘积。 |
|
批量共轭向量-矩阵乘积。 |
|
定义一个支持广播的向量化函数。 |
|
垂直分割数组为子数组。 |
|
垂直堆叠数组。 |
|
根据条件从两个数组中选择元素。 |
|
创建全零数组。 |
|
创建与数组形状和 dtype 相同的全零数组。 |
jax.numpy.fft#
|
计算给定轴上的一个一维离散傅里叶变换。 |
|
计算给定轴上的二维离散傅里叶变换。 |
|
返回离散傅里叶变换的采样频率。 |
|
计算给定轴上的多维离散傅里叶变换。 |
|
将零频率 FFT 分量移动到频谱中心。 |
|
计算具有厄米对称频谱的数组的一维 FFT。 |
|
计算一维离散傅里叶逆变换。 |
|
计算二维离散傅里叶逆变换。 |
|
计算多维离散傅里叶逆变换。 |
|
是 |
|
计算具有厄米对称频谱的数组的一维逆 FFT。 |
|
计算实值的一维离散傅里叶逆变换。 |
|
计算实值的二维离散傅里叶逆变换。 |
|
计算实值多维离散傅里叶逆变换。 |
|
计算实值数组的一维离散傅里叶变换。 |
|
计算实值数组的二维离散傅里叶变换。 |
|
返回离散傅里叶变换的采样频率。 |
|
计算实值数组的多维离散傅里叶变换。 |
jax.numpy.linalg#
|
计算矩阵的 Cholesky 分解。 |
|
计算矩阵的条件数。 |
|
计算两个三维向量的叉积 |
|
计算数组的行列式。 |
|
提取矩阵或矩阵堆的对角线。 |
|
计算方阵的特征值和特征向量。 |
|
计算厄米矩阵的特征值和特征向量。 |
|
计算一般矩阵的特征值。 |
|
计算厄米矩阵的特征值。 |
|
返回方阵的逆。 |
|
返回线性方程的最小二乘解。 |
|
执行矩阵乘法。 |
|
计算矩阵或矩阵堆的范数。 |
|
将方阵提升到整数幂。 |
|
计算矩阵的秩。 |
|
转置矩阵或矩阵堆。 |
|
高效计算一系列数组之间的矩阵乘积。 |
|
计算矩阵或向量的范数。 |
|
计算两个一维数组的外积。 |
|
计算矩阵的(摩尔-彭罗斯)伪逆。 |
|
计算数组的 QR 分解 |
|
计算数组的行列式的符号和(自然)对数。 |
|
求解线性方程组。 |
|
计算奇异值分解。 |
|
计算矩阵的奇异值。 |
|
计算两个 N 维数组的张量点积。 |
|
计算张量的逆。 |
|
求解张量方程 a x = b 中的 x。 |
|
计算矩阵的迹。 |
|
计算向量或向量批的向量范数。 |
|
计算两个数组的(批量)向量共轭点积。 |
JAX Array#
JAX 的 Array (及其别名 jax.numpy.ndarray) 是 JAX 中的核心数组对象:您可以将其视为 JAX 中等同于 numpy.ndarray 的对象。与 numpy.ndarray 类似,大多数用户不需要手动实例化 Array 对象,而是通过 jax.numpy 函数(如 array()、arange()、linspace() 以及上面列出的其他函数)来创建它们。
复制和序列化#
JAX Array 对象旨在在适当的情况下与 Python 标准库工具无缝集成。
使用内置的 copy 模块时,当 copy.copy() 或 copy.deepcopy() 遇到 Array 时,它等同于调用 copy() 方法,该方法将会在与原始数组相同的设备上创建缓冲区副本。这在经过跟踪/JIT 编译的代码中也能正常工作,尽管编译器在此上下文中可能会省略复制操作。
当内置的 pickle 模块遇到 Array 时,它将以类似于被 pickle 的 numpy.ndarray 对象的方式,通过紧凑的位表示来序列化。反序列化时,结果将是一个 *在默认设备上的* 新的 Array 对象。这是因为通常情况下,序列化和反序列化可能在不同的运行时环境中进行,并且无法将一个运行时的设备 ID 映射到另一个运行时的设备 ID。如果 pickle 用于经过跟踪/JIT 编译的代码,将会导致 ConcretizationTypeError。
Python Array API 标准#
注意
在 JAX v0.4.32 之前,您必须 import jax.experimental.array_api 才能为 JAX 数组启用 array API。在 JAX v0.4.32 之后,导入此模块不再需要,并且会引发弃用警告。在 JAX v0.5.0 之后,此导入将引发错误。
从 JAX v0.4.32 开始,jax.Array 和 jax.numpy 与 Python Array API Standard 兼容。您可以通过 jax.Array.__array_namespace__() 访问 Array API 命名空间。
>>> def f(x):
... nx = x.__array_namespace__()
... return nx.sin(x) ** 2 + nx.cos(x) ** 2
>>> import jax.numpy as jnp
>>> x = jnp.arange(5)
>>> f(x).round()
Array([1., 1., 1., 1., 1.], dtype=float32)
JAX 在某些地方偏离了标准,主要是因为 JAX 数组是不可变的,不支持原地更新。其中一些不兼容性正通过 array-api-compat 模块得到解决。
有关更多信息,请参阅 Python Array API Standard 文档。