jax.numpy
模块#
使用 jax.lax
中的原语实现 NumPy API。
虽然 JAX 努力尽可能紧密地遵循 NumPy API,但有时 JAX 无法完全遵循 NumPy。
值得注意的是,由于 JAX 数组是不可变的,因此 NumPy 中原地修改数组的 API 无法在 JAX 中实现。然而,JAX 通常能够提供一个纯函数式的替代 API。例如,JAX 不提供原地数组更新(
x[i] = y
),而是提供一个替代的纯索引更新函数x.at[i].set(y)
(参见ndarray.at
)。此外,一些 NumPy 函数在可能的情况下通常返回数组的视图(例如
transpose()
和reshape()
)。这些函数的 JAX 版本将返回副本,不过当使用jax.jit()
编译操作序列时,XLA 通常会优化掉这些副本。NumPy 在将值提升为
float64
类型方面非常积极。JAX 有时在类型提升方面不那么积极(参见 类型提升语义)。一些 NumPy 例程的输出形状取决于数据(例如
unique()
和nonzero()
)。由于 XLA 编译器要求数组形状在编译时已知,因此此类操作与 JIT 不兼容。为此,JAX 为此类函数添加了一个可选的size
参数,该参数可以静态指定,以便与 JIT 一起使用。
几乎所有适用的 NumPy 函数都在 jax.numpy
命名空间中实现;它们列在下面。
索引更新功能的辅助属性。 |
|
|
是 |
|
计算逐元素的绝对值。 |
|
是 |
|
是 |
逐元素地添加两个数组。 |
|
|
测试沿给定轴的所有数组元素是否评估为 True。 |
|
检查两个数组是否在容差范围内逐元素近似相等。 |
|
是 |
|
是 |
|
返回复数值或数组的角度。 |
|
测试给定轴上是否有任何数组元素评估为 True。 |
|
返回一个新数组,其中值附加到原始数组的末尾。 |
|
沿轴对一维数组切片应用函数。 |
|
在指定轴上重复应用函数。 |
|
创建均匀分布值的数组。 |
|
计算输入逐元素的反三角余弦。 |
|
计算输入逐元素的反双曲余弦。 |
|
计算输入逐元素的反正弦。 |
|
计算输入逐元素的反双曲正弦。 |
|
计算输入逐元素的反正切。 |
|
计算 x1/x2 的反正切,并选择正确的象限。 |
|
计算输入逐元素的反双曲正切。 |
|
返回数组中最大值的索引。 |
|
返回数组中最小值的索引。 |
|
返回部分排序数组的索引。 |
|
返回排序数组的索引。 |
|
查找非零数组元素的索引 |
|
是 |
|
将对象转换为 JAX 数组。 |
|
检查两个数组是否逐元素相等。 |
|
检查两个数组是否逐元素相等。 |
|
返回数组的字符串表示。 |
|
将数组分割成子数组。 |
|
返回数组中数据的字符串表示。 |
|
将对象转换为 JAX 数组。 |
|
是 |
|
是 |
|
将数组转换为指定的 dtype。 |
|
是 |
|
是 |
|
是 |
|
将输入转换为至少一维的数组。 |
|
将输入转换为至少二维的数组。 |
|
将输入转换为至少三维的数组。 |
|
计算加权平均值。 |
|
返回大小为 M 的 Bartlett 窗。 |
|
计算整数数组中每个值的出现次数。 |
逐元素计算按位与操作。 |
|
|
计算 |
|
是 |
|
是 |
|
是 |
逐元素计算按位或操作。 |
|
|
是 |
逐元素计算按位异或操作。 |
|
|
返回大小为 M 的 Blackman 窗。 |
|
从块列表创建数组。 |
是 |
|
|
将数组广播到共同的形状。 |
|
将输入形状广播到共同的输出形状。 |
|
将数组广播到指定形状。 |
沿最后一个轴连接切片、标量和类数组对象。 |
|
|
如果数据类型之间的转换可以根据转换规则发生,则返回 True。 |
|
计算输入数组的逐元素立方根。 |
是 |
|
|
将输入向上舍入到最接近的整数。 |
|
所有字符字符串标量类型的抽象基类。 |
|
通过堆叠选择数组的切片来构造数组。 |
|
将数组值限制在指定范围。 |
|
按列堆叠数组。 |
是 |
|
|
类型为 complex128 的 JAX 标量构造函数。 |
|
类型为 complex64 的 JAX 标量构造函数。 |
|
由浮点数组成的所有复数标量类型的抽象基类。 |
将复数 dtype 转换为实数 dtype 时引发的警告。 |
|
|
使用布尔条件沿给定轴压缩数组。 |
|
沿现有轴连接数组。 |
|
沿现有轴连接数组。 |
|
是 |
|
返回输入逐元素的复共轭。 |
|
两个一维数组的卷积。 |
|
返回数组的副本。 |
|
将 |
|
计算 Pearson 相关系数。 |
|
两个一维数组的相关性。 |
|
计算输入每个元素的三角余弦。 |
|
计算输入逐元素的双曲余弦。 |
|
返回给定轴上非零元素的数量。 |
|
估计加权样本协方差。 |
|
计算两个数组的(批量)叉积。 |
是 |
|
|
沿轴元素的累积积。 |
|
沿轴元素的累积和。 |
|
沿数组轴的累积积。 |
|
沿数组轴的累积和。 |
|
将角度从度转换为弧度。 |
|
是 |
|
从数组中删除一个或多个条目。 |
|
返回指定对角线或构造对角数组。 |
|
返回用于访问多维数组主对角线的索引。 |
|
返回用于访问给定数组主对角线的索引。 |
|
返回一个二维数组,其中展平的输入数组沿对角线排列。 |
|
返回数组的指定对角线。 |
|
计算给定轴上数组元素之间的 n 阶差分。 |
|
将数组转换为 bin 索引。 |
|
是 |
|
逐元素计算 x1 除以 x2 的整数商和余数 |
|
计算两个数组的点积。 |
是 |
|
|
将数组深度拆分为子数组。 |
|
深度堆叠数组。 |
|
创建数据类型对象。 |
|
计算展平数组元素的差异。 |
|
爱因斯坦求和 |
|
评估最优收缩路径而不评估爱因斯坦求和。 |
|
创建一个空数组。 |
|
创建一个与现有数组形状和 dtype 相同的空数组。 |
|
返回 |
|
计算输入逐元素的指数。 |
|
计算输入逐元素的以 2 为底的指数。 |
|
在数组中插入长度为 1 的维度 |
|
计算输入每个元素的 |
|
返回满足条件的数组元素。 |
|
创建方阵或矩形单位矩阵 |
|
计算实值输入逐元素的绝对值。 |
|
返回对角线被覆盖的数组副本。 |
|
浮点类型的机器限制。 |
|
将输入向零舍入到最接近的整数。 |
|
返回展平数组中非零元素的索引 |
|
所有没有预定义长度的标量类型的抽象基类。 |
|
沿给定轴反转数组元素的顺序。 |
|
沿轴 1 反转数组元素的顺序。 |
|
沿轴 0 反转数组元素的顺序。 |
是 |
|
|
计算逐元素的以 |
|
类型为 float16 的 JAX 标量构造函数。 |
|
类型为 float32 的 JAX 标量构造函数。 |
|
类型为 float64 的 JAX 标量构造函数。 |
|
所有浮点标量类型的抽象基类。 |
|
将输入向下舍入到最接近的整数。 |
|
逐元素计算 x1 除以 x2 的向下取整除法 |
|
返回输入数组的逐元素最大值。 |
|
返回输入数组的逐元素最小值。 |
|
计算逐元素浮点模运算。 |
|
将浮点值拆分为尾数和二次幂。 |
|
将缓冲区转换为一维 JAX 数组。 |
|
jnp.fromfile 的未实现 JAX 封装。 |
|
从应用于索引的函数创建数组。 |
|
jnp.fromiter 的未实现 JAX 封装。 |
|
从任意 JAX 兼容的标量函数创建 JAX ufunc。 |
|
将文本字符串转换为一维 JAX 数组。 |
|
通过 DLPack 构造 JAX 数组。 |
|
创建填充指定值的数组。 |
|
创建一个与现有数组形状和 dtype 相同且填充指定值的数组。 |
|
计算两个数组的最大公约数。 |
|
NumPy 标量类型的基类。 |
|
生成几何间隔值。 |
是 |
|
|
计算采样函数的数值梯度。 |
|
返回 |
|
返回 |
|
返回大小为 M 的 Hamming 窗。 |
|
返回大小为 M 的 Hanning 窗。 |
|
计算 Heaviside 阶跃函数。 |
|
计算一维直方图。 |
|
计算直方图的 bin 边缘。 |
|
计算二维直方图。 |
|
计算 N 维直方图。 |
|
将数组水平拆分为子数组。 |
|
水平堆叠数组。 |
|
返回给定直角三角形边长的逐元素斜边。 |
|
计算第一类零阶修正贝塞尔函数。 |
|
创建方单位矩阵 |
|
|
|
返回复数参数逐元素的虚部。 |
一种更方便地为数组构建索引元组的方法。 |
|
|
生成网格索引数组。 |
|
所有数值标量类型的抽象基类,其范围内值的表示可能不精确,例如浮点数。 |
|
计算两个数组的内积。 |
|
在指定索引处将条目插入数组。 |
是 |
|
|
类型为 int16 的 JAX 标量构造函数。 |
|
类型为 int32 的 JAX 标量构造函数。 |
|
类型为 int64 的 JAX 标量构造函数。 |
|
类型为 int8 的 JAX 标量构造函数。 |
|
所有整数标量类型的抽象基类。 |
|
一维线性插值。 |
|
计算两个一维数组的集合交集。 |
|
计算输入的按位反转。 |
|
检查两个数组的元素是否在容差范围内近似相等。 |
|
返回布尔数组,指示输入何处为复数。 |
|
检查输入是否为复数或包含复数元素的数组。 |
|
返回一个布尔值,指示所提供的 dtype 是否为指定类型。 |
|
返回一个布尔数组,指示输入中的每个元素是否为有限值。 |
|
确定 |
|
返回一个布尔数组,指示输入中的每个元素是否为无穷大。 |
|
返回一个布尔数组,指示输入中的每个元素是否为 |
|
返回一个布尔数组,指示输入中的每个元素是否为负无穷大。 |
|
返回一个布尔数组,指示输入中的每个元素是否为正无穷大。 |
|
返回布尔数组,指示输入何处为实数。 |
|
检查输入是否不是复数或不包含复数元素的数组。 |
|
如果输入是标量,则返回 True。 |
|
如果在类型层次结构中 arg1 等于或低于 arg2,则返回 True。 |
|
检查对象是否可以迭代。 |
|
从 N 个一维序列返回多维网格(开放网格)。 |
|
返回大小为 M 的 Kaiser 窗。 |
|
计算两个输入数组的克罗内克积。 |
|
计算两个数组的最小公倍数。 |
|
计算 x1 * 2 ** x2 |
|
逐元素将 |
|
返回 |
|
返回 |
|
按字典顺序对键序列进行排序。 |
|
返回一个区间内均匀分布的数字。 |
|
从 npy 文件加载 JAX 数组。 |
|
计算输入逐元素的自然对数。 |
|
逐元素计算 x 的以 10 为底的对数 |
|
逐元素计算一加输入的对数,即 |
|
逐元素计算 |
计算 |
|
避免溢出的输入指数和的以 2 为底的对数。 |
|
逐元素计算逻辑与操作。 |
|
|
逐元素计算 NOT bool(x)。 |
逐元素计算逻辑或操作。 |
|
逐元素计算逻辑异或操作。 |
|
|
生成对数间隔值。 |
|
返回 (n, n) 数组掩码的索引。 |
|
执行矩阵乘法。 |
|
转置数组的最后两个维度。 |
|
批量矩阵-向量乘积。 |
|
返回给定轴上数组元素的最大值。 |
返回输入数组的逐元素最大值。 |
|
|
返回给定轴上数组元素的平均值。 |
|
返回沿给定轴的数组元素的中间值。 |
|
从 N 个一维向量构造 N 维网格数组。 |
返回密集多维“网格”。 |
|
|
返回沿给定轴的数组元素的最小值。 |
返回输入数组的逐元素最小值。 |
|
|
|
|
返回输入数组的逐元素小数部分和整数部分。 |
|
将数组轴移动到新位置 |
逐元素相乘两个数组。 |
|
|
替换数组中的 NaN 和无限值条目。 |
|
返回数组中最大值的索引,忽略 NaN 值。 |
|
返回数组中最小值的索引,忽略 NaN 值。 |
|
沿轴的元素累积乘积,忽略 NaN 值。 |
|
沿轴的元素累积和,忽略 NaN 值。 |
|
返回沿给定轴的数组元素的最大值,忽略 NaN 值。 |
|
返回沿给定轴的数组元素的平均值,忽略 NaN 值。 |
|
返回沿给定轴的数组元素的中间值,忽略 NaN 值。 |
|
返回沿给定轴的数组元素的最小值,忽略 NaN 值。 |
|
计算沿指定轴的数据百分位数,忽略 NaN 值。 |
|
返回沿给定轴的数组元素的乘积,忽略 NaN 值。 |
|
计算沿指定轴的数据分位数,忽略 NaN 值。 |
|
计算沿给定轴的标准差,忽略 NaN 值。 |
|
返回沿给定轴的数组元素的总和,忽略 NaN 值。 |
|
计算沿给定轴的数组元素的方差,忽略 NaN 值。 |
|
|
|
返回数组的维数。 |
返回输入的逐元素负值。 |
|
|
返回 |
|
返回数组中非零元素的索引。 |
|
返回 |
|
所有数字标量类型的抽象基类。 |
任何 Python 对象。 |
|
返回开放多维“网格”。 |
|
|
创建全为一的数组。 |
|
创建与数组具有相同形状和数据类型,且全为一的数组。 |
|
计算两个数组的外积。 |
|
将位数组打包成 uint8 数组。 |
|
向数组添加填充。 |
|
返回数组的部分排序副本。 |
|
计算沿指定轴的数据百分位数。 |
|
置换数组的轴/维度。 |
|
在域上分段计算函数。 |
|
根据掩码更新数组元素。 |
|
返回给定根序列的多项式系数。 |
|
返回两个多项式的和。 |
|
返回多项式指定阶导数的系数。 |
|
返回多项式除法的商和余数。 |
|
对数据进行最小二乘多项式拟合。 |
|
返回多项式指定阶积分的系数。 |
|
返回两个多项式的乘积。 |
|
返回两个多项式的差。 |
|
计算多项式在特定值处的值。 |
|
返回输入的逐元素正值。 |
|
|
|
计算逐元素以 |
|
|
|
返回数组元素沿给定轴的乘积。 |
|
返回二进制操作应将其参数转换为的类型。 |
|
返回沿给定轴的峰峰值范围。 |
|
将元素放入数组的给定索引处。 |
|
通过匹配一维索引和数据切片,将值放入目标数组。 |
|
计算沿指定轴的数据分位数。 |
沿第一个轴连接切片、标量和类数组对象。 |
|
|
将角度从弧度转换为度。 |
|
|
|
将数组展平为一维形状。 |
|
将多维索引转换为扁平索引。 |
|
返回复数参数的逐元素实部。 |
|
计算输入的逐元素倒数。 |
|
返回逐元素除法的余数。 |
|
从重复元素构造数组。 |
|
返回数组的重塑副本。 |
|
返回具有指定形状的新数组。 |
|
返回将 JAX 提升规则应用于输入的结果。 |
|
将 |
|
将 x 的元素四舍五入到最近的整数。 |
|
沿指定轴滚动数组的元素。 |
|
将指定轴滚动到给定位置。 |
|
返回给定系数 |
|
在由轴指定的平面中,将数组逆时针旋转 90 度。 |
|
将输入四舍五入到给定的小数位数。 |
一种更方便地为数组构建索引元组的方法。 |
|
|
将数组保存到 NumPy |
|
将多个数组保存到未压缩的 |
|
在排序数组中执行二分查找。 |
|
根据一系列条件选择值。 |
|
|
|
计算两个一维数组的集合差。 |
|
计算两个数组中元素的集合异或。 |
|
返回数组的形状。 |
|
返回输入的逐元素符号指示。 |
|
返回数组元素的符号位。 |
所有有符号整数标量类型的抽象基类。 |
|
|
计算输入的每个元素的三角正弦。 |
|
计算归一化的 sinc 函数。 |
|
|
|
计算输入的逐元素双曲正弦。 |
|
返回沿给定轴的元素数量。 |
|
返回数组的排序副本。 |
|
返回复杂数组的排序副本。 |
|
返回 |
|
将数组分割成子数组。 |
|
计算输入数组的逐元素非负平方根。 |
|
计算输入数组的逐元素平方。 |
|
从数组中移除一个或多个长度为 1 的轴。 |
|
沿新轴连接数组。 |
|
计算沿给定轴的标准差。 |
逐元素相减两个数组。 |
|
|
在给定轴上对数组元素求和。 |
|
交换数组的两个轴。 |
|
从数组中取出元素。 |
|
从数组中取出元素。 |
|
计算输入的每个元素的三角正切。 |
|
计算输入的逐元素双曲正切。 |
|
计算两个 N 维数组的张量点积。 |
|
通过沿指定维度重复 |
|
计算输入沿给定轴的对角线和。 |
|
使用复合梯形法则沿给定轴积分。 |
|
返回 N 维数组的转置版本。 |
|
返回一个对角线及以下为 1,其他为 0 的数组。 |
|
返回数组的下三角。 |
|
返回大小为 |
|
返回给定数组的下三角索引。 |
|
修剪输入数组的前导和/或尾随零。 |
|
返回数组的上三角。 |
|
返回大小为 |
|
返回给定数组的上三角索引。 |
|
逐元素计算 x1 除以 x2 的结果。 |
|
将输入向零舍入到最接近的整数。 |
|
对数组逐元素操作的通用函数。 |
|
|
|
类型为 uint16 的 JAX 标量构造函数。 |
|
类型为 uint32 的 JAX 标量构造函数。 |
|
类型为 uint64 的 JAX 标量构造函数。 |
|
类型为 uint8 的 JAX 标量构造函数。 |
|
计算两个一维数组的集合并集。 |
|
返回数组中的唯一值。 |
|
从 x 返回唯一值,以及索引、逆索引和计数。 |
|
返回 x 中的唯一值及其计数。 |
|
从 x 返回唯一值,以及索引、逆索引和计数。 |
|
从 x 返回唯一值,以及索引、逆索引和计数。 |
|
解包 uint8 数组中的位。 |
|
将扁平索引转换为多维索引。 |
|
沿轴解堆数组。 |
所有无符号整数标量类型的抽象基类。 |
|
|
解卷周期信号。 |
|
生成范德蒙矩阵。 |
|
计算沿给定轴的方差。 |
|
执行两个一维向量的共轭乘法。 |
|
执行两个批处理向量的共轭乘法。 |
|
批处理共轭向量-矩阵乘积。 |
|
定义具有广播功能的向量化函数。 |
|
将数组垂直拆分为子数组。 |
|
垂直堆叠数组。 |
|
根据条件从两个数组中选择元素。 |
|
创建全为零的数组。 |
|
创建与数组具有相同形状和数据类型,且全为零的数组。 |
jax.numpy.fft#
|
计算沿给定轴的一维离散傅里叶变换。 |
|
计算沿给定轴的二维离散傅里叶变换。 |
|
返回离散傅里叶变换的采样频率。 |
|
计算沿给定轴的多维离散傅里叶变换。 |
|
将零频率 FFT 分量移到频谱中心。 |
|
计算频谱具有 Hermitian 对称性的数组的一维 FFT。 |
|
计算一维逆离散傅里叶变换。 |
|
计算二维逆离散傅里叶变换。 |
|
计算多维逆离散傅里叶变换。 |
|
|
|
计算频谱具有 Hermitian 对称性数组的一维逆 FFT。 |
|
计算实值一维逆离散傅里叶变换。 |
|
计算实值二维逆离散傅里叶变换。 |
|
计算实值多维逆离散傅里叶变换。 |
|
计算实值数组的一维离散傅里叶变换。 |
|
计算实值数组的二维离散傅里叶变换。 |
|
返回离散傅里叶变换的采样频率。 |
|
计算实值数组的多维离散傅里叶变换。 |
jax.numpy.linalg#
|
计算矩阵的 Cholesky 分解。 |
|
计算矩阵的条件数。 |
|
计算两个三维向量的叉积。 |
|
计算数组的行列式。 |
|
提取矩阵或矩阵堆栈的对角线。 |
|
计算方阵的特征值和特征向量。 |
|
计算 Hermitian 矩阵的特征值和特征向量。 |
|
计算一般矩阵的特征值。 |
|
计算 Hermitian 矩阵的特征值。 |
|
返回方阵的逆。 |
|
返回线性方程的最小二乘解。 |
|
执行矩阵乘法。 |
|
计算矩阵或矩阵堆栈的范数。 |
|
将方阵提升到整数幂。 |
|
计算矩阵的秩。 |
|
转置矩阵或矩阵堆栈。 |
|
高效计算数组序列之间的矩阵乘积。 |
|
计算矩阵或向量的范数。 |
|
计算两个一维数组的外积。 |
|
计算矩阵的(Moore-Penrose)伪逆。 |
|
计算数组的 QR 分解 |
|
计算数组行列式的符号和(自然)对数。 |
|
求解线性方程组。 |
|
计算奇异值分解。 |
|
计算矩阵的奇异值。 |
|
计算两个 N 维数组的张量点积。 |
|
计算数组的张量逆。 |
|
求解张量方程 a x = b 中的 x。 |
|
计算矩阵的迹。 |
|
计算向量或向量批次的向量范数。 |
|
计算两个数组的(批处理)向量共轭点积。 |
JAX Array#
JAX Array
对象(及其别名 jax.numpy.ndarray
)是 JAX 中的核心数组对象:您可以将其视为 JAX 中 numpy.ndarray
的等效项。与 numpy.ndarray
类似,大多数用户无需手动实例化 Array
对象,而是通过 jax.numpy
函数(如 array()
、arange()
、linspace()
以及上面列出的其他函数)来创建它们。
复制和序列化#
JAX Array
对象旨在在适当的情况下与 Python 标准库工具无缝协作。
使用内置的 copy
模块,当 copy.copy()
或 copy.deepcopy()
遇到 Array
时,它等同于调用 copy()
方法,该方法将在与原始数组相同的设备上创建缓冲区副本。这在追踪/JIT 编译的代码中将正常工作,尽管在此上下文中复制操作可能会被编译器省略。
当内置的 pickle
模块遇到 Array
时,它将通过紧凑的位表示进行序列化,方式类似于 pickled numpy.ndarray
对象。解压后,结果将是一个新 Array
对象,位于**默认设备上**。这是因为通常情况下,pickling 和 unpickling 可能会在不同的运行时环境中进行,并且没有通用的方法将一个运行时的设备 ID 映射到另一个运行时的设备 ID。如果在追踪/JIT 编译的代码中使用 pickle
,则将导致 ConcretizationTypeError
。
Python 数组 API 标准#
注意
在 JAX v0.4.32 之前,您必须 import jax.experimental.array_api
才能为 JAX 数组启用数组 API。JAX v0.4.32 之后,不再需要导入此模块,并且会引发弃用警告。JAX v0.5.0 之后,此导入将引发错误。
从 JAX v0.4.32 开始,jax.Array
和 jax.numpy
与 Python 数组 API 标准兼容。您可以通过 jax.Array.__array_namespace__()
访问数组 API 命名空间。
>>> def f(x):
... nx = x.__array_namespace__()
... return nx.sin(x) ** 2 + nx.cos(x) ** 2
>>> import jax.numpy as jnp
>>> x = jnp.arange(5)
>>> f(x).round()
Array([1., 1., 1., 1., 1.], dtype=float32)
JAX 在某些方面偏离了标准,主要原因是 JAX 数组是不可变的,不支持就地更新。其中一些不兼容性正在通过 array-api-compat 模块得到解决。
欲了解更多信息,请参阅 Python 数组 API 标准文档。