jax.numpy 模块#

使用 jax.lax 中的原语实现 NumPy API。

虽然 JAX 努力尽可能紧密地遵循 NumPy API,但有时 JAX 无法完全遵循 NumPy。

  • 值得注意的是,由于 JAX 数组是不可变的,因此 NumPy 中原地修改数组的 API 无法在 JAX 中实现。然而,JAX 通常能够提供一个纯函数式的替代 API。例如,JAX 不提供原地数组更新(x[i] = y),而是提供一个替代的纯索引更新函数 x.at[i].set(y)(参见 ndarray.at)。

  • 此外,一些 NumPy 函数在可能的情况下通常返回数组的视图(例如 transpose()reshape())。这些函数的 JAX 版本将返回副本,不过当使用 jax.jit() 编译操作序列时,XLA 通常会优化掉这些副本。

  • NumPy 在将值提升为 float64 类型方面非常积极。JAX 有时在类型提升方面不那么积极(参见 类型提升语义)。

  • 一些 NumPy 例程的输出形状取决于数据(例如 unique()nonzero())。由于 XLA 编译器要求数组形状在编译时已知,因此此类操作与 JIT 不兼容。为此,JAX 为此类函数添加了一个可选的 size 参数,该参数可以静态指定,以便与 JIT 一起使用。

几乎所有适用的 NumPy 函数都在 jax.numpy 命名空间中实现;它们列在下面。

ndarray.at

索引更新功能的辅助属性。

abs(x, /)

jax.numpy.absolute() 的别名。

absolute(x, /)

计算逐元素的绝对值。

acos(x, /)

jax.numpy.arccos() 的别名

acosh(x, /)

jax.numpy.arccosh() 的别名

加法

逐元素地添加两个数组。

all(a[, axis, out, keepdims, where])

测试沿给定轴的所有数组元素是否评估为 True。

allclose(a, b[, rtol, atol, equal_nan])

检查两个数组是否在容差范围内逐元素近似相等。

amax(a[, axis, out, keepdims, initial, where])

jax.numpy.max() 的别名。

amin(a[, axis, out, keepdims, initial, where])

jax.numpy.min() 的别名。

angle(z[, deg])

返回复数值或数组的角度。

any(a[, axis, out, keepdims, where])

测试给定轴上是否有任何数组元素评估为 True。

append(arr, values[, axis])

返回一个新数组,其中值附加到原始数组的末尾。

apply_along_axis(func1d, axis, arr, *args, ...)

沿轴对一维数组切片应用函数。

apply_over_axes(func, a, axes)

在指定轴上重复应用函数。

arange(start[, stop, step, dtype, device, ...])

创建均匀分布值的数组。

arccos(x, /)

计算输入逐元素的反三角余弦。

arccosh(x, /)

计算输入逐元素的反双曲余弦。

arcsin(x, /)

计算输入逐元素的反正弦。

arcsinh(x, /)

计算输入逐元素的反双曲正弦。

arctan(x, /)

计算输入逐元素的反正切。

arctan2(x1, x2, /)

计算 x1/x2 的反正切,并选择正确的象限。

arctanh(x, /)

计算输入逐元素的反双曲正切。

argmax(a[, axis, out, keepdims])

返回数组中最大值的索引。

argmin(a[, axis, out, keepdims])

返回数组中最小值的索引。

argpartition(a, kth[, axis])

返回部分排序数组的索引。

argsort(a[, axis, kind, order, stable, ...])

返回排序数组的索引。

argwhere(a, *[, size, fill_value])

查找非零数组元素的索引

around(a[, decimals, out])

jax.numpy.round() 的别名

array(object[, dtype, copy, order, ndmin, ...])

将对象转换为 JAX 数组。

array_equal(a1, a2[, equal_nan])

检查两个数组是否逐元素相等。

array_equiv(a1, a2)

检查两个数组是否逐元素相等。

array_repr(arr[, max_line_width, precision, ...])

返回数组的字符串表示。

array_split(ary, indices_or_sections[, axis])

将数组分割成子数组。

array_str(a[, max_line_width, precision, ...])

返回数组中数据的字符串表示。

asarray(a[, dtype, order, copy, device])

将对象转换为 JAX 数组。

asin(x, /)

jax.numpy.arcsin() 的别名

asinh(x, /)

jax.numpy.arcsinh() 的别名

astype(x, dtype, /, *[, copy, device])

将数组转换为指定的 dtype。

atan(x, /)

jax.numpy.arctan() 的别名

atanh(x, /)

jax.numpy.arctanh() 的别名

atan2(x1, x2, /)

jax.numpy.arctan2() 的别名

atleast_1d(*arys)

将输入转换为至少一维的数组。

atleast_2d(*arys)

将输入转换为至少二维的数组。

atleast_3d(*arys)

将输入转换为至少三维的数组。

average(a[, axis, weights, returned, keepdims])

计算加权平均值。

bartlett(M)

返回大小为 M 的 Bartlett 窗。

bincount(x[, weights, minlength, length])

计算整数数组中每个值的出现次数。

按位与

逐元素计算按位与操作。

bitwise_count(x, /)

计算 x 中每个元素的绝对值的二进制表示中 1 的位数。

bitwise_invert(x, /)

jax.numpy.invert() 的别名。

bitwise_left_shift(x, y, /)

jax.numpy.left_shift() 的别名。

bitwise_not(x, /)

jax.numpy.invert() 的别名。

按位或

逐元素计算按位或操作。

bitwise_right_shift(x1, x2, /)

jax.numpy.right_shift() 的别名。

按位异或

逐元素计算按位异或操作。

blackman(M)

返回大小为 M 的 Blackman 窗。

block(arrays)

从块列表创建数组。

bool_

bool 的别名

broadcast_arrays(*args)

将数组广播到共同的形状。

broadcast_shapes(*shapes)

将输入形状广播到共同的输出形状。

broadcast_to(array, shape, *[, out_sharding])

将数组广播到指定形状。

c_

沿最后一个轴连接切片、标量和类数组对象。

can_cast(from_, to[, casting])

如果数据类型之间的转换可以根据转换规则发生,则返回 True。

cbrt(x, /)

计算输入数组的逐元素立方根。

cdouble

complex128 的别名

ceil(x, /)

将输入向上舍入到最接近的整数。

字符()

所有字符字符串标量类型的抽象基类。

choose(a, choices[, out, mode])

通过堆叠选择数组的切片来构造数组。

clip([arr, min, max, a, a_min, a_max])

将数组值限制在指定范围。

column_stack(tup)

按列堆叠数组。

complex_

complex128 的别名

complex128(x)

类型为 complex128 的 JAX 标量构造函数。

complex64(x)

类型为 complex64 的 JAX 标量构造函数。

复浮点()

由浮点数组成的所有复数标量类型的抽象基类。

复数警告

将复数 dtype 转换为实数 dtype 时引发的警告。

compress(condition, a[, axis, size, ...])

使用布尔条件沿给定轴压缩数组。

concat(arrays, /, *[, axis])

沿现有轴连接数组。

concatenate(arrays[, axis, dtype])

沿现有轴连接数组。

conj(x, /)

jax.numpy.conjugate() 的别名

conjugate(x, /)

返回输入逐元素的复共轭。

convolve(a, v[, mode, precision, ...])

两个一维数组的卷积。

copy(a[, order])

返回数组的副本。

copysign(x1, x2, /)

x2 中每个元素的符号复制到 x1 中的相应元素。

corrcoef(x[, y, rowvar])

计算 Pearson 相关系数。

correlate(a, v[, mode, precision, ...])

两个一维数组的相关性。

cos(x, /)

计算输入每个元素的三角余弦。

cosh(x, /)

计算输入逐元素的双曲余弦。

count_nonzero(a[, axis, keepdims])

返回给定轴上非零元素的数量。

cov(m[, y, rowvar, bias, ddof, fweights, ...])

估计加权样本协方差。

cross(a, b[, axisa, axisb, axisc, axis])

计算两个数组的(批量)叉积。

csingle

complex64 的别名

cumprod(a[, axis, dtype, out])

沿轴元素的累积积。

cumsum(a[, axis, dtype, out])

沿轴元素的累积和。

cumulative_prod(x, /, *[, axis, dtype, ...])

沿数组轴的累积积。

cumulative_sum(x, /, *[, axis, dtype, ...])

沿数组轴的累积和。

deg2rad(x, /)

将角度从度转换为弧度。

degrees(x, /)

jax.numpy.rad2deg() 的别名

delete(arr, obj[, axis, assume_unique_indices])

从数组中删除一个或多个条目。

diag(v[, k])

返回指定对角线或构造对角数组。

diag_indices(n[, ndim])

返回用于访问多维数组主对角线的索引。

diag_indices_from(arr)

返回用于访问给定数组主对角线的索引。

diagflat(v[, k])

返回一个二维数组,其中展平的输入数组沿对角线排列。

diagonal(a[, offset, axis1, axis2])

返回数组的指定对角线。

diff(a[, n, axis, prepend, append])

计算给定轴上数组元素之间的 n 阶差分。

digitize(x, bins[, right, method])

将数组转换为 bin 索引。

divide(x1, x2, /)

jax.numpy.true_divide() 的别名。

divmod(x1, x2, /)

逐元素计算 x1 除以 x2 的整数商和余数

dot(a, b, *[, precision, ...])

计算两个数组的点积。

double

float64 的别名

dsplit(ary, indices_or_sections)

将数组深度拆分为子数组。

dstack(tup[, dtype])

深度堆叠数组。

dtype(dtype[, align, copy])

创建数据类型对象。

ediff1d(ary[, to_end, to_begin])

计算展平数组元素的差异。

einsum(subscripts, /, *operands[, out, ...])

爱因斯坦求和

einsum_path(subscripts, /, *operands[, optimize])

评估最优收缩路径而不评估爱因斯坦求和。

empty(shape[, dtype, device, out_sharding])

创建一个空数组。

empty_like(prototype[, dtype, shape, device])

创建一个与现有数组形状和 dtype 相同的空数组。

equal(x, y, /)

返回 x == y 的逐元素真值。

exp(x, /)

计算输入逐元素的指数。

exp2(x, /)

计算输入逐元素的以 2 为底的指数。

expand_dims(a, axis)

在数组中插入长度为 1 的维度

expm1(x, /)

计算输入每个元素的 exp(x)-1

extract(condition, arr, *[, size, fill_value])

返回满足条件的数组元素。

eye(N[, M, k, dtype, device])

创建方阵或矩形单位矩阵

fabs(x, /)

计算实值输入逐元素的绝对值。

fill_diagonal(a, val[, wrap, inplace])

返回对角线被覆盖的数组副本。

finfo(dtype)

浮点类型的机器限制。

fix(x[, out])

将输入向零舍入到最接近的整数。

flatnonzero(a, *[, size, fill_value])

返回展平数组中非零元素的索引

柔性()

所有没有预定义长度的标量类型的抽象基类。

flip(m[, axis])

沿给定轴反转数组元素的顺序。

fliplr(m)

沿轴 1 反转数组元素的顺序。

flipud(m)

沿轴 0 反转数组元素的顺序。

float_

float64 的别名

float_power(x, y, /)

计算逐元素的以 x 为底的 y 的指数。

float16(x)

类型为 float16 的 JAX 标量构造函数。

float32(x)

类型为 float32 的 JAX 标量构造函数。

float64(x)

类型为 float64 的 JAX 标量构造函数。

浮点()

所有浮点标量类型的抽象基类。

floor(x, /)

将输入向下舍入到最接近的整数。

floor_divide(x1, x2, /)

逐元素计算 x1 除以 x2 的向下取整除法

fmax(x1, x2)

返回输入数组的逐元素最大值。

fmin(x1, x2)

返回输入数组的逐元素最小值。

fmod(x1, x2, /)

计算逐元素浮点模运算。

frexp(x, /)

将浮点值拆分为尾数和二次幂。

frombuffer(buffer[, dtype, count, offset])

将缓冲区转换为一维 JAX 数组。

fromfile(*args, **kwargs)

jnp.fromfile 的未实现 JAX 封装。

fromfunction(function, shape, *[, dtype])

从应用于索引的函数创建数组。

fromiter(*args, **kwargs)

jnp.fromiter 的未实现 JAX 封装。

frompyfunc(func, /, nin, nout, *[, identity])

从任意 JAX 兼容的标量函数创建 JAX ufunc。

fromstring(string[, dtype, count])

将文本字符串转换为一维 JAX 数组。

from_dlpack(x, /, *[, device, copy])

通过 DLPack 构造 JAX 数组。

full(shape, fill_value[, dtype, device])

创建填充指定值的数组。

full_like(a, fill_value[, dtype, shape, device])

创建一个与现有数组形状和 dtype 相同且填充指定值的数组。

gcd(x1, x2)

计算两个数组的最大公约数。

通用()

NumPy 标量类型的基类。

geomspace(start, stop[, num, endpoint, ...])

生成几何间隔值。

get_printoptions()

numpy.get_printoptions() 的别名。

gradient(f, *varargs[, axis, edge_order])

计算采样函数的数值梯度。

greater(x, y, /)

返回 x > y 的逐元素真值。

greater_equal(x, y, /)

返回 x >= y 的逐元素真值。

hamming(M)

返回大小为 M 的 Hamming 窗。

hanning(M)

返回大小为 M 的 Hanning 窗。

heaviside(x1, x2, /)

计算 Heaviside 阶跃函数。

histogram(a[, bins, range, weights, density])

计算一维直方图。

histogram_bin_edges(a[, bins, range, weights])

计算直方图的 bin 边缘。

histogram2d(x, y[, bins, range, weights, ...])

计算二维直方图。

histogramdd(sample[, bins, range, weights, ...])

计算 N 维直方图。

hsplit(ary, indices_or_sections)

将数组水平拆分为子数组。

hstack(tup[, dtype])

水平堆叠数组。

hypot(x1, x2, /)

返回给定直角三角形边长的逐元素斜边。

i0(x)

计算第一类零阶修正贝塞尔函数。

identity(n[, dtype])

创建方单位矩阵

iinfo(int_type)

imag(val, /)

返回复数参数逐元素的虚部。

索引表达式

一种更方便地为数组构建索引元组的方法。

indices(dimensions[, dtype, sparse])

生成网格索引数组。

不精确()

所有数值标量类型的抽象基类,其范围内值的表示可能不精确,例如浮点数。

inner(a, b, *[, precision, ...])

计算两个数组的内积。

insert(arr, obj, values[, axis])

在指定索引处将条目插入数组。

int_

int64 的别名

int16(x)

类型为 int16 的 JAX 标量构造函数。

int32(x)

类型为 int32 的 JAX 标量构造函数。

int64(x)

类型为 int64 的 JAX 标量构造函数。

int8(x)

类型为 int8 的 JAX 标量构造函数。

整数()

所有整数标量类型的抽象基类。

interp(x, xp, fp[, left, right, period])

一维线性插值。

intersect1d(ar1, ar2[, assume_unique, ...])

计算两个一维数组的集合交集。

invert(x, /)

计算输入的按位反转。

isclose(a, b[, rtol, atol, equal_nan])

检查两个数组的元素是否在容差范围内近似相等。

iscomplex(x)

返回布尔数组,指示输入何处为复数。

iscomplexobj(x)

检查输入是否为复数或包含复数元素的数组。

isdtype(dtype, kind)

返回一个布尔值,指示所提供的 dtype 是否为指定类型。

isfinite(x, /)

返回一个布尔数组,指示输入中的每个元素是否为有限值。

isin(element, test_elements[, ...])

确定 element 中的元素是否出现在 test_elements 中。

isinf(x, /)

返回一个布尔数组,指示输入中的每个元素是否为无穷大。

isnan(x, /)

返回一个布尔数组,指示输入中的每个元素是否为 NaN

isneginf(x, /[, out])

返回一个布尔数组,指示输入中的每个元素是否为负无穷大。

isposinf(x, /[, out])

返回一个布尔数组,指示输入中的每个元素是否为正无穷大。

isreal(x)

返回布尔数组,指示输入何处为实数。

isrealobj(x)

检查输入是否不是复数或不包含复数元素的数组。

isscalar(element)

如果输入是标量,则返回 True。

issubdtype(arg1, arg2)

如果在类型层次结构中 arg1 等于或低于 arg2,则返回 True。

iterable(y)

检查对象是否可以迭代。

ix_(*args)

从 N 个一维序列返回多维网格(开放网格)。

kaiser(M, beta)

返回大小为 M 的 Kaiser 窗。

kron(a, b)

计算两个输入数组的克罗内克积。

lcm(x1, x2)

计算两个数组的最小公倍数。

ldexp(x1, x2, /)

计算 x1 * 2 ** x2

left_shift(x, y, /)

逐元素将 x 的位左移 y 中指定的量。

less(x, y, /)

返回 x < y 的逐元素真值。

less_equal(x, y, /)

返回 x <= y 的逐元素真值。

lexsort(keys[, axis])

按字典顺序对键序列进行排序。

linspace(start, stop[, num, endpoint, ...])

返回一个区间内均匀分布的数字。

load(file, *args, **kwargs)

从 npy 文件加载 JAX 数组。

log(x, /)

计算输入逐元素的自然对数。

log10(x, /)

逐元素计算 x 的以 10 为底的对数

log1p(x, /)

逐元素计算一加输入的对数,即 log(x+1)

log2(x, /)

逐元素计算 x 的以 2 为底的对数。

logaddexp

计算 log(exp(x1) + exp(x2)) 以避免溢出。

logaddexp2

避免溢出的输入指数和的以 2 为底的对数。

逻辑与

逐元素计算逻辑与操作。

logical_not(x, /)

逐元素计算 NOT bool(x)。

逻辑或

逐元素计算逻辑或操作。

逻辑异或

逐元素计算逻辑异或操作。

logspace(start, stop[, num, endpoint, base, ...])

生成对数间隔值。

mask_indices(n, mask_func[, k, size])

返回 (n, n) 数组掩码的索引。

matmul(a, b, *[, precision, ...])

执行矩阵乘法。

matrix_transpose(x, /)

转置数组的最后两个维度。

matvec(x1, x2, /)

批量矩阵-向量乘积。

max(a[, axis, out, keepdims, initial, where])

返回给定轴上数组元素的最大值。

最大值

返回输入数组的逐元素最大值。

mean(a[, axis, dtype, out, keepdims, where])

返回给定轴上数组元素的平均值。

median(a[, axis, out, overwrite_input, keepdims])

返回沿给定轴的数组元素的中间值。

meshgrid(*xi[, copy, sparse, indexing])

从 N 个一维向量构造 N 维网格数组。

mgrid

返回密集多维“网格”。

min(a[, axis, out, keepdims, initial, where])

返回沿给定轴的数组元素的最小值。

minimum

返回输入数组的逐元素最小值。

mod(x1, x2, /)

jax.numpy.remainder() 的别名

modf(x, /[, out])

返回输入数组的逐元素小数部分和整数部分。

moveaxis(a, source, destination)

将数组轴移动到新位置

multiply

逐元素相乘两个数组。

nan_to_num(x[, copy, nan, posinf, neginf])

替换数组中的 NaN 和无限值条目。

nanargmax(a[, axis, out, keepdims])

返回数组中最大值的索引,忽略 NaN 值。

nanargmin(a[, axis, out, keepdims])

返回数组中最小值的索引,忽略 NaN 值。

nancumprod(a[, axis, dtype, out])

沿轴的元素累积乘积,忽略 NaN 值。

nancumsum(a[, axis, dtype, out])

沿轴的元素累积和,忽略 NaN 值。

nanmax(a[, axis, out, keepdims, initial, where])

返回沿给定轴的数组元素的最大值,忽略 NaN 值。

nanmean(a[, axis, dtype, out, keepdims, where])

返回沿给定轴的数组元素的平均值,忽略 NaN 值。

nanmedian(a[, axis, out, overwrite_input, ...])

返回沿给定轴的数组元素的中间值,忽略 NaN 值。

nanmin(a[, axis, out, keepdims, initial, where])

返回沿给定轴的数组元素的最小值,忽略 NaN 值。

nanpercentile(a, q[, axis, out, ...])

计算沿指定轴的数据百分位数,忽略 NaN 值。

nanprod(a[, axis, dtype, out, keepdims, ...])

返回沿给定轴的数组元素的乘积,忽略 NaN 值。

nanquantile(a, q[, axis, out, ...])

计算沿指定轴的数据分位数,忽略 NaN 值。

nanstd(a[, axis, dtype, out, ddof, ...])

计算沿给定轴的标准差,忽略 NaN 值。

nansum(a[, axis, dtype, out, keepdims, ...])

返回沿给定轴的数组元素的总和,忽略 NaN 值。

nanvar(a[, axis, dtype, out, ddof, ...])

计算沿给定轴的数组元素的方差,忽略 NaN 值。

ndarray

Array 的别名

ndim(a)

返回数组的维数。

negative

返回输入的逐元素负值。

nextafter(x, y, /)

返回 x 朝向 y 方向的逐元素下一个浮点值。

nonzero(a, *[, size, fill_value])

返回数组中非零元素的索引。

not_equal(x, y, /)

返回 x != y 的逐元素真值。

number()

所有数字标量类型的抽象基类。

object_

任何 Python 对象。

ogrid

返回开放多维“网格”。

ones(shape[, dtype, device, out_sharding])

创建全为一的数组。

ones_like(a[, dtype, shape, device])

创建与数组具有相同形状和数据类型,且全为一的数组。

outer(a, b[, out])

计算两个数组的外积。

packbits(a[, axis, bitorder])

将位数组打包成 uint8 数组。

pad(array, pad_width[, mode])

向数组添加填充。

partition(a, kth[, axis])

返回数组的部分排序副本。

percentile(a, q[, axis, out, ...])

计算沿指定轴的数据百分位数。

permute_dims(a, /, axes)

置换数组的轴/维度。

piecewise(x, condlist, funclist, *args, **kw)

在域上分段计算函数。

place(arr, mask, vals, *[, inplace])

根据掩码更新数组元素。

poly(seq_of_zeros)

返回给定根序列的多项式系数。

polyadd(a1, a2)

返回两个多项式的和。

polyder(p[, m])

返回多项式指定阶导数的系数。

polydiv(u, v, *[, trim_leading_zeros])

返回多项式除法的商和余数。

polyfit(x, y, deg[, rcond, full, w, cov])

对数据进行最小二乘多项式拟合。

polyint(p[, m, k])

返回多项式指定阶积分的系数。

polymul(a1, a2, *[, trim_leading_zeros])

返回两个多项式的乘积。

polysub(a1, a2)

返回两个多项式的差。

polyval(p, x, *[, unroll])

计算多项式在特定值处的值。

positive(x, /)

返回输入的逐元素正值。

pow(x1, x2, /)

jax.numpy.power() 的别名

power(x1, x2, /)

计算逐元素以 x1 为底,x2 为指数的幂。

printoptions(*args, **kwargs)

numpy.printoptions() 的别名。

prod(a[, axis, dtype, out, keepdims, ...])

返回数组元素沿给定轴的乘积。

promote_types(a, b)

返回二进制操作应将其参数转换为的类型。

ptp(a[, axis, out, keepdims])

返回沿给定轴的峰峰值范围。

put(a, ind, v[, mode, inplace])

将元素放入数组的给定索引处。

put_along_axis(arr, indices, values, axis[, ...])

通过匹配一维索引和数据切片,将值放入目标数组。

quantile(a, q[, axis, out, overwrite_input, ...])

计算沿指定轴的数据分位数。

r_

沿第一个轴连接切片、标量和类数组对象。

rad2deg(x, /)

将角度从弧度转换为度。

radians(x, /)

jax.numpy.deg2rad() 的别名

ravel(a[, order, out_sharding])

将数组展平为一维形状。

ravel_multi_index(multi_index, dims[, mode, ...])

将多维索引转换为扁平索引。

real(val, /)

返回复数参数的逐元素实部。

reciprocal(x, /)

计算输入的逐元素倒数。

remainder(x1, x2, /)

返回逐元素除法的余数。

repeat(a, repeats[, axis, ...])

从重复元素构造数组。

reshape(a, shape[, order, copy, out_sharding])

返回数组的重塑副本。

resize(a, new_shape)

返回具有指定形状的新数组。

result_type(*args)

返回将 JAX 提升规则应用于输入的结果。

right_shift(x1, x2, /)

x1 的位向右移 x2 指定的量。

rint(x, /)

将 x 的元素四舍五入到最近的整数。

roll(a, shift[, axis])

沿指定轴滚动数组的元素。

rollaxis(a, axis[, start])

将指定轴滚动到给定位置。

roots(p, *[, strip_zeros])

返回给定系数 p 的多项式的根。

rot90(m[, k, axes])

在由轴指定的平面中,将数组逆时针旋转 90 度。

round(a[, decimals, out])

将输入四舍五入到给定的小数位数。

s_

一种更方便地为数组构建索引元组的方法。

save(file, arr[, allow_pickle, fix_imports])

将数组保存到 NumPy .npy 格式的二进制文件。

savez(file, *args[, allow_pickle])

将多个数组保存到未压缩的 .npz 格式的单个文件中。

searchsorted(a, v[, side, sorter, method])

在排序数组中执行二分查找。

select(condlist, choicelist[, default])

根据一系列条件选择值。

set_printoptions(*args, **kwargs)

numpy.set_printoptions() 的别名。

setdiff1d(ar1, ar2[, assume_unique, size, ...])

计算两个一维数组的集合差。

setxor1d(ar1, ar2[, assume_unique, size, ...])

计算两个数组中元素的集合异或。

shape(a)

返回数组的形状。

sign(x, /)

返回输入的逐元素符号指示。

signbit(x, /)

返回数组元素的符号位。

signedinteger()

所有有符号整数标量类型的抽象基类。

sin(x, /)

计算输入的每个元素的三角正弦。

sinc(x, /)

计算归一化的 sinc 函数。

single

float32 的别名

sinh(x, /)

计算输入的逐元素双曲正弦。

size(a[, axis])

返回沿给定轴的元素数量。

sort(a[, axis, kind, order, stable, descending])

返回数组的排序副本。

sort_complex(a)

返回复杂数组的排序副本。

spacing(x, /)

返回 x 与下一个相邻数字之间的间距。

split(ary, indices_or_sections[, axis])

将数组分割成子数组。

sqrt(x, /)

计算输入数组的逐元素非负平方根。

square(x, /)

计算输入数组的逐元素平方。

squeeze(a[, axis])

从数组中移除一个或多个长度为 1 的轴。

stack(arrays[, axis, out, dtype])

沿新轴连接数组。

std(a[, axis, dtype, out, ddof, keepdims, ...])

计算沿给定轴的标准差。

subtract

逐元素相减两个数组。

sum(a[, axis, dtype, out, keepdims, ...])

在给定轴上对数组元素求和。

swapaxes(a, axis1, axis2)

交换数组的两个轴。

take(a, indices[, axis, out, mode, ...])

从数组中取出元素。

take_along_axis(arr, indices, axis[, mode, ...])

从数组中取出元素。

tan(x, /)

计算输入的每个元素的三角正切。

tanh(x, /)

计算输入的逐元素双曲正切。

tensordot(a, b[, axes, precision, ...])

计算两个 N 维数组的张量点积。

tile(A, reps)

通过沿指定维度重复 A 来构造数组。

trace(a[, offset, axis1, axis2, dtype, out])

计算输入沿给定轴的对角线和。

trapezoid(y[, x, dx, axis])

使用复合梯形法则沿给定轴积分。

transpose(a[, axes])

返回 N 维数组的转置版本。

tri(N[, M, k, dtype])

返回一个对角线及以下为 1,其他为 0 的数组。

tril(m[, k])

返回数组的下三角。

tril_indices(n[, k, m])

返回大小为 (n, m) 的数组的下三角索引。

tril_indices_from(arr[, k])

返回给定数组的下三角索引。

trim_zeros(filt[, trim])

修剪输入数组的前导和/或尾随零。

triu(m[, k])

返回数组的上三角。

triu_indices(n[, k, m])

返回大小为 (n, m) 的数组的上三角索引。

triu_indices_from(arr[, k])

返回给定数组的上三角索引。

true_divide(x1, x2, /)

逐元素计算 x1 除以 x2 的结果。

trunc(x)

将输入向零舍入到最接近的整数。

ufunc(func, /, nin, nout, *[, name, nargs, ...])

对数组逐元素操作的通用函数。

uint

uint64 的别名

uint16(x)

类型为 uint16 的 JAX 标量构造函数。

uint32(x)

类型为 uint32 的 JAX 标量构造函数。

uint64(x)

类型为 uint64 的 JAX 标量构造函数。

uint8(x)

类型为 uint8 的 JAX 标量构造函数。

union1d(ar1, ar2, *[, size, fill_value])

计算两个一维数组的集合并集。

unique(ar[, return_index, return_inverse, ...])

返回数组中的唯一值。

unique_all(x, /, *[, size, fill_value])

从 x 返回唯一值,以及索引、逆索引和计数。

unique_counts(x, /, *[, size, fill_value])

返回 x 中的唯一值及其计数。

unique_inverse(x, /, *[, size, fill_value])

从 x 返回唯一值,以及索引、逆索引和计数。

unique_values(x, /, *[, size, fill_value])

从 x 返回唯一值,以及索引、逆索引和计数。

unpackbits(a[, axis, count, bitorder])

解包 uint8 数组中的位。

unravel_index(indices, shape)

将扁平索引转换为多维索引。

unstack(x, /, *[, axis])

沿轴解堆数组。

unsignedinteger()

所有无符号整数标量类型的抽象基类。

unwrap(p[, discont, axis, period])

解卷周期信号。

vander(x[, N, increasing])

生成范德蒙矩阵。

var(a[, axis, dtype, out, ddof, keepdims, ...])

计算沿给定轴的方差。

vdot(a, b, *[, precision, ...])

执行两个一维向量的共轭乘法。

vecdot(x1, x2, /, *[, axis, precision, ...])

执行两个批处理向量的共轭乘法。

vecmat(x1, x2, /)

批处理共轭向量-矩阵乘积。

vectorize(pyfunc, *[, excluded, signature])

定义具有广播功能的向量化函数。

vsplit(ary, indices_or_sections)

将数组垂直拆分为子数组。

vstack(tup[, dtype])

垂直堆叠数组。

where(condition[, x, y, size, fill_value])

根据条件从两个数组中选择元素。

zeros(shape[, dtype, device, out_sharding])

创建全为零的数组。

zeros_like(a[, dtype, shape, device])

创建与数组具有相同形状和数据类型,且全为零的数组。

jax.numpy.fft#

fft(a[, n, axis, norm])

计算沿给定轴的一维离散傅里叶变换。

fft2(a[, s, axes, norm])

计算沿给定轴的二维离散傅里叶变换。

fftfreq(n[, d, dtype, device])

返回离散傅里叶变换的采样频率。

fftn(a[, s, axes, norm])

计算沿给定轴的多维离散傅里叶变换。

fftshift(x[, axes])

将零频率 FFT 分量移到频谱中心。

hfft(a[, n, axis, norm])

计算频谱具有 Hermitian 对称性的数组的一维 FFT。

ifft(a[, n, axis, norm])

计算一维逆离散傅里叶变换。

ifft2(a[, s, axes, norm])

计算二维逆离散傅里叶变换。

ifftn(a[, s, axes, norm])

计算多维逆离散傅里叶变换。

ifftshift(x[, axes])

jax.numpy.fft.fftshift() 的逆运算。

ihfft(a[, n, axis, norm])

计算频谱具有 Hermitian 对称性数组的一维逆 FFT。

irfft(a[, n, axis, norm])

计算实值一维逆离散傅里叶变换。

irfft2(a[, s, axes, norm])

计算实值二维逆离散傅里叶变换。

irfftn(a[, s, axes, norm])

计算实值多维逆离散傅里叶变换。

rfft(a[, n, axis, norm])

计算实值数组的一维离散傅里叶变换。

rfft2(a[, s, axes, norm])

计算实值数组的二维离散傅里叶变换。

rfftfreq(n[, d, dtype, device])

返回离散傅里叶变换的采样频率。

rfftn(a[, s, axes, norm])

计算实值数组的多维离散傅里叶变换。

jax.numpy.linalg#

cholesky(a, *[, upper, symmetrize_input])

计算矩阵的 Cholesky 分解。

cond(x[, p])

计算矩阵的条件数。

cross(x1, x2, /, *[, axis])

计算两个三维向量的叉积。

det(a)

计算数组的行列式。

diagonal(x, /, *[, offset])

提取矩阵或矩阵堆栈的对角线。

eig(a)

计算方阵的特征值和特征向量。

eigh(a[, UPLO, symmetrize_input])

计算 Hermitian 矩阵的特征值和特征向量。

eigvals(a)

计算一般矩阵的特征值。

eigvalsh(a[, UPLO, symmetrize_input])

计算 Hermitian 矩阵的特征值。

inv(a)

返回方阵的逆。

lstsq(a, b[, rcond, numpy_resid])

返回线性方程的最小二乘解。

matmul(x1, x2, /, *[, precision, ...])

执行矩阵乘法。

matrix_norm(x, /, *[, keepdims, ord])

计算矩阵或矩阵堆栈的范数。

matrix_power(a, n)

将方阵提升到整数幂。

matrix_rank(M[, rtol, tol])

计算矩阵的秩。

matrix_transpose(x, /)

转置矩阵或矩阵堆栈。

multi_dot(arrays, *[, precision])

高效计算数组序列之间的矩阵乘积。

norm(x[, ord, axis, keepdims])

计算矩阵或向量的范数。

outer(x1, x2, /)

计算两个一维数组的外积。

pinv(a[, rtol, hermitian, rcond])

计算矩阵的(Moore-Penrose)伪逆。

qr(a[, mode])

计算数组的 QR 分解

slogdet(a, *[, method])

计算数组行列式的符号和(自然)对数。

solve(a, b)

求解线性方程组。

svd(a[, full_matrices, compute_uv, ...])

计算奇异值分解。

svdvals(x, /)

计算矩阵的奇异值。

tensordot(x1, x2, /, *[, axes, precision, ...])

计算两个 N 维数组的张量点积。

tensorinv(a[, ind])

计算数组的张量逆。

tensorsolve(a, b[, axes])

求解张量方程 a x = b 中的 x。

trace(x, /, *[, offset, dtype])

计算矩阵的迹。

vector_norm(x, /, *[, axis, keepdims, ord])

计算向量或向量批次的向量范数。

vecdot(x1, x2, /, *[, axis, precision, ...])

计算两个数组的(批处理)向量共轭点积。

JAX Array#

JAX Array 对象(及其别名 jax.numpy.ndarray)是 JAX 中的核心数组对象:您可以将其视为 JAX 中 numpy.ndarray 的等效项。与 numpy.ndarray 类似,大多数用户无需手动实例化 Array 对象,而是通过 jax.numpy 函数(如 array()arange()linspace() 以及上面列出的其他函数)来创建它们。

复制和序列化#

JAX Array 对象旨在在适当的情况下与 Python 标准库工具无缝协作。

使用内置的 copy 模块,当 copy.copy()copy.deepcopy() 遇到 Array 时,它等同于调用 copy() 方法,该方法将在与原始数组相同的设备上创建缓冲区副本。这在追踪/JIT 编译的代码中将正常工作,尽管在此上下文中复制操作可能会被编译器省略。

当内置的 pickle 模块遇到 Array 时,它将通过紧凑的位表示进行序列化,方式类似于 pickled numpy.ndarray 对象。解压后,结果将是一个新 Array 对象,位于**默认设备上**。这是因为通常情况下,pickling 和 unpickling 可能会在不同的运行时环境中进行,并且没有通用的方法将一个运行时的设备 ID 映射到另一个运行时的设备 ID。如果在追踪/JIT 编译的代码中使用 pickle,则将导致 ConcretizationTypeError

Python 数组 API 标准#

注意

在 JAX v0.4.32 之前,您必须 import jax.experimental.array_api 才能为 JAX 数组启用数组 API。JAX v0.4.32 之后,不再需要导入此模块,并且会引发弃用警告。JAX v0.5.0 之后,此导入将引发错误。

从 JAX v0.4.32 开始,jax.Arrayjax.numpyPython 数组 API 标准兼容。您可以通过 jax.Array.__array_namespace__() 访问数组 API 命名空间。

>>> def f(x):
...   nx = x.__array_namespace__()
...   return nx.sin(x) ** 2 + nx.cos(x) ** 2

>>> import jax.numpy as jnp
>>> x = jnp.arange(5)
>>> f(x).round()
Array([1., 1., 1., 1., 1.], dtype=float32)

JAX 在某些方面偏离了标准,主要原因是 JAX 数组是不可变的,不支持就地更新。其中一些不兼容性正在通过 array-api-compat 模块得到解决。

欲了解更多信息,请参阅 Python 数组 API 标准文档。