jax.numpy 模块#

使用 jax.lax 中的原始操作实现 NumPy API。

虽然 JAX 尽可能地遵循 NumPy API,但有时 JAX 无法完全遵循 NumPy。

  • 值得注意的是,由于 JAX 数组是不可变的,因此无法在 JAX 中实现原地修改数组的 NumPy API。然而,JAX 通常能够提供一个纯函数式的替代 API。例如,JAX 提供了一个纯函数式的索引更新函数 x.at[i].set(y) 来替代原地数组更新(x[i] = y)(参见 ndarray.at)。

  • 同样地,一些 NumPy 函数在可能的情况下会返回数组的视图(例如 transpose()reshape())。这些函数的 JAX 版本将返回副本,尽管使用 jax.jit() 编译一系列操作时,这些副本通常会被 XLA 优化掉。

  • NumPy 在将值提升为 float64 类型方面非常积极。JAX 有时在类型提升方面不太积极(参见 类型提升语义)。

  • 一些 NumPy 例程具有依赖数据的输出形状(例如 unique()nonzero())。由于 XLA 编译器要求在编译时已知数组形状,因此此类操作与 JIT 不兼容。因此,JAX 在此类函数中添加了一个可选的 size 参数,可以在静态指定该参数以便与 JIT 一起使用。

几乎所有适用的 NumPy 函数都已在 jax.numpy 命名空间中实现;它们列在下面。

ndarray.at

用于索引更新功能的辅助属性。

abs(x, /)

jax.numpy.absolute() 的别名。

absolute(x, /)

逐元素计算绝对值。

acos(x, /)

jax.numpy.arccos() 的别名。

acosh(x, /)

jax.numpy.arccosh() 的别名。

加法

逐元素相加两个数组。

all(a[, axis, out, keepdims, where])

测试沿给定轴的所有数组元素是否评估为 True。

allclose(a, b[, rtol, atol, equal_nan])

在容差范围内检查两个数组是否逐元素近似相等。

amax(a[, axis, out, keepdims, initial, where])

jax.numpy.max() 的别名。

amin(a[, axis, out, keepdims, initial, where])

jax.numpy.min() 的别名。

angle(z[, deg])

返回复数或复数数组的角度。

any(a[, axis, out, keepdims, where])

测试数组沿给定轴的任何元素是否为 True。

append(arr, values[, axis])

返回一个新数组,将值附加到原始数组的末尾。

apply_along_axis(func1d, axis, arr, *args, ...)

沿轴将函数应用于一维数组切片。

apply_over_axes(func, a, axes)

反复将函数应用于指定的轴。

arange(start[, stop, step, dtype, device, ...])

创建等间距值的数组。

arccos(x, /)

计算输入数组的三角余弦的逐元素反正值。

arccosh(x, /)

计算输入数组的双曲余弦的逐元素反双曲值。

arcsin(x, /)

计算输入数组的三角正弦的逐元素反正值。

arcsinh(x, /)

计算输入数组的双曲正弦的逐元素反双曲值。

arctan(x, /)

计算输入数组的三角正切的逐元素反正值。

arctan2(x1, x2, /)

计算 x1/x2 的反正切,并选择正确的象限。

arctanh(x, /)

计算输入数组的双曲正切的逐元素反双曲值。

argmax(a[, axis, out, keepdims])

返回数组最大值的索引。

argmin(a[, axis, out, keepdims])

返回数组最小值的索引。

argpartition(a, kth[, axis])

返回部分排序数组的索引。

argsort(a[, axis, kind, order, stable, ...])

返回排序数组的索引。

argwhere(a, *[, size, fill_value])

查找非零数组元素的索引。

around(a[, decimals, out])

jax.numpy.round() 的别名。

array(object[, dtype, copy, order, ndmin, ...])

将对象转换为 JAX 数组。

array_equal(a1, a2[, equal_nan])

检查两个数组是否逐元素相等。

array_equiv(a1, a2)

检查两个数组是否逐元素相等。

array_repr(arr[, max_line_width, precision, ...])

返回数组的字符串表示。

array_split(ary, indices_or_sections[, axis])

将数组分割成子数组。

array_str(a[, max_line_width, precision, ...])

返回数组中数据的字符串表示。

asarray(a[, dtype, order, copy, device, ...])

将对象转换为 JAX 数组。

asin(x, /)

jax.numpy.arcsin() 的别名。

asinh(x, /)

jax.numpy.arcsinh() 的别名。

astype(x, dtype, /, *[, copy, device])

将数组转换为指定的数据类型。

atan(x, /)

jax.numpy.arctan() 的别名。

atanh(x, /)

jax.numpy.arctanh() 的别名。

atan2(x1, x2, /)

jax.numpy.arctan2() 的别名。

atleast_1d(*arys)

将输入转换为至少具有一维的数组。

atleast_2d(*arys)

将输入转换为至少具有二维的数组。

atleast_3d(*arys)

将输入转换为至少具有三维的数组。

average(a[, axis, weights, returned, keepdims])

计算加权平均值。

bartlett(M)

返回大小为 M 的 Bartlett 窗口。

bincount(x[, weights, minlength, length])

计算整数数组中每个值出现的次数。

bitwise_and

逐元素计算按位 AND 操作。

bitwise_count(x, /)

计算 x 中每个元素的绝对值的二进制表示中 1 的数量。

bitwise_invert(x, /)

jax.numpy.invert() 的别名。

bitwise_left_shift(x, y, /)

jax.numpy.left_shift() 的别名。

bitwise_not(x, /)

jax.numpy.invert() 的别名。

bitwise_or

逐元素计算按位 OR 操作。

bitwise_right_shift(x1, x2, /)

jax.numpy.right_shift() 的别名。

bitwise_xor

逐元素计算按位 XOR 操作。

blackman(M)

返回大小为 M 的 Blackman 窗口。

block(arrays)

从块列表中创建数组。

bool_

bool 的别名。

broadcast_arrays(*args)

将数组广播到公共形状。

broadcast_shapes(*shapes)

将输入形状广播到公共输出形状。

broadcast_to(array, shape, *[, out_sharding])

将数组广播到指定的形状。

c_

沿最后一个轴连接切片、标量和类数组对象。

can_cast(from_, to[, casting])

根据转换规则,返回数据类型之间是否可以发生转换。

cbrt(x, /)

计算输入数组的逐元素立方根。

cdouble

complex128 的别名。

ceil(x, /)

向上取整到最近的整数。

character()

所有字符字符串标量类型的抽象基类。

choose(a, choices[, out, mode])

通过堆叠选择数组的切片来构造数组。

clip([arr, min, max, a, a_min, a_max])

将数组值裁剪到指定的范围。

column_stack(tup)

按列堆叠数组。

complex_

complex128 的别名。

complex128(x)

complex128 类型的 JAX 标量构造函数。

complex64(x)

complex64 类型的 JAX 标量构造函数。

complexfloating()

由浮点数组成的所有复数标量类型的抽象基类。

ComplexWarning

将复数类型转换为实数类型时引发的警告。

compress(condition, a[, axis, size, ...])

使用布尔条件沿给定轴压缩数组。

concat(arrays, /, *[, axis])

沿现有轴连接数组。

concatenate(arrays[, axis, dtype])

沿现有轴连接数组。

conj(x, /)

jax.numpy.conjugate() 的别名。

conjugate(x, /)

返回输入数组的逐元素复共轭。

convolve(a, v[, mode, precision, ...])

两个一维数组的卷积。

copy(a[, order])

返回数组的副本。

copysign(x1, x2, /)

x2 中每个元素的符号复制到 x1 中对应的元素。

corrcoef(x[, y, rowvar, dtype])

计算皮尔逊相关系数。

correlate(a, v[, mode, precision, ...])

两个一维数组的互相关。

cos(x, /)

计算输入数组的逐元素三角余弦。

cosh(x, /)

计算输入数组的双曲余弦的逐元素值。

count_nonzero(a[, axis, keepdims])

返回沿给定轴的非零元素的数量。

cov(m[, y, rowvar, bias, ddof, fweights, ...])

估计加权样本协方差。

cross(a, b[, axisa, axisb, axisc, axis])

计算两个数组的(批量)叉积。

csingle

complex64 的别名。

cumprod(a[, axis, dtype, out])

沿轴的元素的累积乘积。

cumsum(a[, axis, dtype, out])

沿轴的元素的累积和。

cumulative_prod(x, /, *[, axis, dtype, ...])

沿数组轴的累积乘积。

cumulative_sum(x, /, *[, axis, dtype, ...])

沿数组轴的累积和。

deg2rad(x, /)

将角度从度转换为弧度。

degrees(x, /)

jax.numpy.rad2deg() 的别名。

delete(arr, obj[, axis, assume_unique_indices])

从数组中删除条目。

diag(v[, k])

返回指定的对角线或构造对角线数组。

diag_indices(n[, ndim])

返回访问多维数组主对角线的索引。

diag_indices_from(arr)

返回访问给定数组主对角线的索引。

diagflat(v[, k])

返回一个二维数组,其中展平的输入数组沿对角线排列。

diagonal(a[, offset, axis1, axis2])

返回数组的指定对角线。

diff(a[, n, axis, prepend, append])

沿给定轴计算数组元素之间的 n 阶差值。

digitize(x, bins[, right, method])

将数组转换为 bin 索引。

divide(x1, x2, /)

jax.numpy.true_divide() 的别名。

divmod(x1, x2, /)

逐元素计算 x1 除以 x2 的整数商和余数。

dot(a, b, *[, precision, ...])

计算两个数组的点积。

double

float64 的别名。

dsplit(ary, indices_or_sections)

按深度分割数组。

dstack(tup[, dtype])

按深度堆叠数组。

dtype(dtype[, align, copy])

创建数据类型对象。

ediff1d(ary[, to_end, to_begin])

计算展平数组元素的差值。

einsum(subscripts, /, *operands[, out, ...])

爱因斯坦求和。

einsum_path(subscripts, /, *operands[, optimize])

在不计算 einsum 的情况下评估最优收缩路径。

empty(shape[, dtype, device, out_sharding])

创建空数组。

empty_like(prototype[, dtype, shape, device])

创建具有与数组相同形状和数据类型的空数组。

equal(x, y, /)

返回 x == y 的逐元素真值。

exp(x, /)

计算输入数组的逐元素指数。

exp2(x, /)

计算输入数组的逐元素以 2 为底的指数。

expand_dims(a, axis)

在数组中插入长度为 1 的维度。

expm1(x, /)

计算输入数组的逐元素 exp(x)-1

extract(condition, arr, *[, size, fill_value])

返回满足条件的数组元素。

eye(N[, M, k, dtype, device])

创建方阵或长方形单位矩阵。

fabs(x, /)

计算实值输入数组的逐元素绝对值。

fill_diagonal(a, val[, wrap, inplace])

返回一个对角线被覆盖的数组的副本。

finfo(dtype)

浮点类型的机器限制。

fix(x[, out])

将输入向零舍入到最接近的整数。

flatnonzero(a, *[, size, fill_value])

返回展平数组中非零元素的索引。

flexible()

没有预定义长度的所有标量类型的抽象基类。

flip(m[, axis])

沿给定轴反转数组元素的顺序。

fliplr(m)

沿轴 1 反转数组元素的顺序。

flipud(m)

沿轴 0 反转数组元素的顺序。

float_

float64 的别名。

float_power(x, y, /)

计算以 x 为底,y 为指数的逐元素幂。

float16(x)

float16 类型的 JAX 标量构造函数。

float32(x)

float32 类型的 JAX 标量构造函数。

float64(x)

float64 类型的 JAX 标量构造函数。

floating()

所有浮点数标量类型的抽象基类。

floor(x, /)

向下取整到最近的整数。

floor_divide(x1, x2, /)

逐元素计算 x1 除以 x2 的地板除法。

fmax(x1, x2)

返回输入数组的逐元素最大值。

fmin(x1, x2)

返回输入数组的逐元素最小值。

fmod(x1, x2, /)

逐元素计算浮点数取模操作。

frexp(x, /)

将浮点数值分解为尾数和 2 的指数。

frombuffer(buffer[, dtype, count, offset])

将缓冲区转换为一维 JAX 数组。

fromfile(*args, **kwargs)

jnp.fromfile 的未实现 JAX 包装器。

fromfunction(function, shape, *[, dtype])

通过对索引应用函数来创建数组。

fromiter(*args, **kwargs)

jnp.fromiter 的未实现 JAX 包装器。

frompyfunc(func, /, nin, nout, *[, identity])

从任意 JAX 兼容的标量函数创建 JAX ufunc。

fromstring(string[, dtype, count])

将文本字符串转换为一维 JAX 数组。

from_dlpack(x, /, *[, device, copy])

通过 DLPack 构建 JAX 数组。

full(shape, fill_value[, dtype, device])

创建充满指定值的数组。

full_like(a, fill_value[, dtype, shape, device])

创建具有与数组相同形状和数据类型的、充满指定值的数组。

gcd(x1, x2)

计算两个数组的最大公约数。

generic()

NumPy 标量类型的基类。

geomspace(start, stop[, num, endpoint, ...])

生成几何间隔的值。

get_printoptions()

numpy.get_printoptions() 的别名。

gradient(f, *varargs[, axis, edge_order])

计算采样函数的数值梯度。

greater(x, y, /)

返回 x > y 的逐元素真值。

greater_equal(x, y, /)

返回 x >= y 的逐元素真值。

hamming(M)

返回大小为 M 的 Hamming 窗口。

hanning(M)

返回大小为 M 的 Hanning 窗口。

heaviside(x1, x2, /)

计算阶跃函数(Heaviside 函数)。

histogram(a[, bins, range, weights, density])

计算一维直方图。

histogram_bin_edges(a[, bins, range, weights])

计算直方图的 bin 边缘。

histogram2d(x, y[, bins, range, weights, ...])

计算二维直方图。

histogramdd(sample[, bins, range, weights, ...])

计算 N 维直方图。

hsplit(ary, indices_or_sections)

将数组水平分割成子数组。

hstack(tup[, dtype])

水平堆叠数组。

hypot(x1, x2, /)

对于给定直角三角形的两个直角边,返回逐元素的斜边长。

i0(x)

计算第一类零阶修正贝塞尔函数。

identity(n[, dtype])

创建方阵单位矩阵。

iinfo(int_type)

imag(val, /)

返回复数参数的逐元素虚部。

index_exp

一种更方便地为数组构建索引元组的方法。

indices(dimensions[, dtype, sparse])

生成网格索引数组。

inexact()

在其范围内可能包含不精确表示值的数字标量类型的抽象基类,例如浮点数。

inner(a, b, *[, precision, ...])

计算两个数组的内积。

insert(arr, obj, values[, axis])

在指定索引处将条目插入数组。

int_

int64 的别名。

int16(x)

int16 类型的 JAX 标量构造函数。

int32(x)

int32 类型的 JAX 标量构造函数。

int64(x)

int64 类型的 JAX 标量构造函数。

int8(x)

int8 类型的 JAX 标量构造函数。

integer()

所有整数标量类型的抽象基类。

interp(x, xp, fp[, left, right, period])

一维线性插值。

intersect1d(ar1, ar2[, assume_unique, ...])

计算两个一维数组的集合交集。

invert(x, /)

计算输入数组的按位取反。

isclose(a, b[, rtol, atol, equal_nan])

在容差范围内检查两个数组的元素是否近似相等。

iscomplex(x)

返回一个布尔数组,显示输入是复数的位置。

iscomplexobj(x)

检查输入是否为复数或包含复数元素的数组。

isdtype(dtype, kind)

返回一个布尔值,指示提供的 dtype 是否属于指定种类。

isfinite(x, /)

返回一个布尔数组,指示输入中的每个元素是否为有限值。

isin(element, test_elements[, ...])

确定 element 中的元素是否出现在 test_elements 中。

isinf(x, /)

返回一个布尔数组,指示输入中的每个元素是否为无穷大。

isnan(x, /)

返回一个布尔数组,指示输入中的每个元素是否为 NaN

isneginf(x, /[, out])

返回一个布尔数组,指示输入中的每个元素是否为负无穷大。

isposinf(x, /[, out])

返回一个布尔数组,指示输入中的每个元素是否为正无穷大。

isreal(x)

返回一个布尔数组,显示输入是实数的位置。

isrealobj(x)

检查输入是否不是复数或包含复数元素的数组。

isscalar(element)

如果输入是标量,则返回 True。

issubdtype(arg1, arg2)

如果 arg1 在类型层次结构中等于或低于 arg2,则返回 True。

iterable(y)

检查一个对象是否可以迭代。

ix_(*args)

从 N 个一维序列返回多维网格(开放网格)。

kaiser(M, beta)

返回大小为 M 的 Kaiser 窗口。

kron(a, b)

计算两个输入数组的 Kronecker 乘积。

lcm(x1, x2)

计算两个数组的最小公倍数。

ldexp(x1, x2, /)

计算 x1 * 2 ** x2

left_shift(x, y, /)

x 的位向左移动 y 指定的数量,逐元素进行。

less(x, y, /)

返回 x < y 的逐元素真值。

less_equal(x, y, /)

返回 x <= y 的逐元素真值。

lexsort(keys[, axis])

按字典序对键序列进行排序。

linspace(start, stop[, num, endpoint, ...])

在区间内返回等间距的数字。

load(file, *args, **kwargs)

从 npy 文件加载 JAX 数组。

log(x, /)

计算输入数组的逐元素自然对数。

log10(x, /)

逐元素计算 x 的以 10 为底的对数。

log1p(x, /)

逐元素计算 log(x+1)

log2(x, /)

逐元素计算 x 的以 2 为底的对数。

logaddexp

计算 log(exp(x1) + exp(x2)),避免溢出。

logaddexp2

以 2 为底,计算 log(exp(x1) + exp(x2)),避免溢出。

logical_and

逐元素计算逻辑 AND 操作。

logical_not(x, /)

逐元素计算逻辑 NOT bool(x)。

logical_or

逐元素计算逻辑 OR 操作。

logical_xor

逐元素计算逻辑 XOR 操作。

logspace(start, stop[, num, endpoint, base, ...])

生成对数间隔的值。

mask_indices(n, mask_func[, k, size])

返回 (n, n) 数组掩码的索引。

matmul(a, b, *[, precision, ...])

执行矩阵乘法。

matrix_transpose(x, /)

转置数组的最后两个维度。

matvec(x1, x2, /)

批量矩阵-向量乘积。

max(a[, axis, out, keepdims, initial, where])

返回数组元素沿给定轴的最大值。

maximum

返回输入数组的逐元素最大值。

mean(a[, axis, dtype, out, keepdims, where])

返回给定轴上数组元素的平均值。

median(a[, axis, out, overwrite_input, keepdims])

返回数组元素沿给定轴的中位数。

meshgrid(*xi[, copy, sparse, indexing])

从 N 个一维向量构建 N 维网格数组。

mgrid

返回密集的多维“网格”。

min(a[, axis, out, keepdims, initial, where])

返回给定轴上数组元素的最小值。

minimum

返回输入数组的逐元素最小值。

mod(x1, x2, /)

jax.numpy.remainder() 的别名。

modf(x, /[, out])

返回输入数组的逐元素小数部分和整数部分。

moveaxis(a, source, destination)

将数组轴移动到新位置。

multiply

逐元素相乘两个数组。

nan_to_num(x[, copy, nan, posinf, neginf])

替换数组中的 NaN 和无穷大条目。

nanargmax(a[, axis, out, keepdims])

返回数组最大值的索引,忽略 NaN。

nanargmin(a[, axis, out, keepdims])

返回数组最小值的索引,忽略 NaN。

nancumprod(a[, axis, dtype, out])

沿轴的累积乘积,忽略 NaN 值。

nancumsum(a[, axis, dtype, out])

沿轴的累积和,忽略 NaN 值。

nanmax(a[, axis, out, keepdims, initial, where])

返回数组元素沿给定轴的最大值,忽略 NaN。

nanmean(a[, axis, dtype, out, keepdims, where])

返回数组元素沿给定轴的平均值,忽略 NaN。

nanmedian(a[, axis, out, overwrite_input, ...])

返回数组元素沿给定轴的中位数,忽略 NaN。

nanmin(a[, axis, out, keepdims, initial, where])

返回数组元素沿给定轴的最小值,忽略 NaN。

nanpercentile(a, q[, axis, out, ...])

沿指定轴计算数据的百分位数,忽略 NaN 值。

nanprod(a[, axis, dtype, out, keepdims, ...])

返回数组元素沿给定轴的乘积,忽略 NaN。

nanquantile(a, q[, axis, out, ...])

沿指定轴计算数据的分位数,忽略 NaN 值。

nanstd(a[, axis, dtype, out, ddof, ...])

沿给定轴计算标准差,忽略 NaN。

nansum(a[, axis, dtype, out, keepdims, ...])

返回数组元素沿给定轴的和,忽略 NaN。

nanvar(a[, axis, dtype, out, ddof, ...])

沿给定轴计算方差,忽略 NaN。

ndarray

Array 的别名。

ndim(a)

返回数组的维度数。

negative

返回输入数组的逐元素负值。

nextafter(x, y, /)

返回 x 沿 y 方向的下一个浮点数,逐元素进行。

nonzero(a, *[, size, fill_value])

返回数组中非零元素的索引。

not_equal(x, y, /)

返回 x != y 的逐元素真值。

number()

所有数字标量类型的抽象基类。

object_

任何 Python 对象。

ogrid

返回开放的多维“网格”。

ones(shape[, dtype, device, out_sharding])

创建充满 1 的数组。

ones_like(a[, dtype, shape, device, ...])

创建具有与数组相同形状和数据类型的、充满 1 的数组。

outer(a, b[, out])

计算两个数组的外积。

packbits(a[, axis, bitorder])

将位数组打包成 uint8 数组。

pad(array, pad_width[, mode])

向数组添加填充。

partition(a, kth[, axis])

返回数组的部分排序副本。

percentile(a, q[, axis, out, ...])

沿指定轴计算数据的百分位数。

permute_dims(a, /, axes)

置换数组的轴/维度。

piecewise(x, condlist, funclist, *args, **kw)

在域上分段求值函数。

place(arr, mask, vals, *[, inplace])

根据掩码更新数组元素。

poly(seq_of_zeros)

返回给定根序列的多项式系数。

polyadd(a1, a2)

返回两个多项式的和。

polyder(p[, m])

返回指定阶数多项式导数的系数。

polydiv(u, v, *[, trim_leading_zeros])

返回多项式除法的商和余数。

polyfit(x, y, deg[, rcond, full, w, cov])

最小二乘多项式拟合数据。

polyint(p[, m, k])

返回指定阶数多项式积分的系数。

polymul(a1, a2, *[, trim_leading_zeros])

返回两个多项式的乘积。

polysub(a1, a2)

返回两个多项式的差。

polyval(p, x, *[, unroll])

在特定值处评估多项式。

positive(x, /)

返回输入的逐元素正值。

pow(x1, x2, /)

jax.numpy.power() 的别名

power(x1, x2, /)

计算 x1 的逐元素指数 x2

printoptions(*args, **kwargs)

numpy.printoptions() 的别名。

prod(a[, axis, dtype, out, keepdims, ...])

返回给定轴上数组元素的乘积。

promote_types(a, b)

返回二进制运算应将参数强制转换为的类型。

ptp(a[, axis, out, keepdims])

返回沿给定轴的峰峰值范围。

put(a, ind, v[, mode, inplace])

将元素按指定索引放入数组。

put_along_axis(arr, indices, values, axis[, ...])

通过匹配一维索引和数据切片将值放入目标数组。

quantile(a, q[, axis, out, overwrite_input, ...])

计算指定轴上数据的分位数。

r_

沿第一个轴连接切片、标量和类数组对象。

rad2deg(x, /)

将弧度转换为度。

radians(x, /)

jax.numpy.deg2rad() 的别名

ravel(a[, order, out_sharding])

将数组展平为一维形状。

ravel_multi_index(multi_index, dims[, mode, ...])

将多维索引转换为扁平索引。

real(val, /)

返回复数参数的逐元素实部。

reciprocal(x, /)

计算输入的逐元素倒数。

remainder(x1, x2, /)

返回除法的逐元素余数。

repeat(a, repeats[, axis, ...])

从重复元素构造数组。

reshape(a, shape[, order, copy, out_sharding])

返回数组的重塑副本。

resize(a, new_shape)

返回具有指定形状的新数组。

result_type(*args)

返回将 JAX 提升规则应用于输入的类型。

right_shift(x1, x2, /)

x1 的位右移指定的量 x2

rint(x, /)

将 x 的元素四舍五入到最接近的整数

roll(a, shift[, axis])

沿指定轴滚动数组的元素。

rollaxis(a, axis[, start])

将指定轴滚动到给定位置。

roots(p[, strip_zeros])

返回给定系数 p 的多项式的根。

rot90(m[, k, axes])

在指定的坐标轴所定义的平面内,将数组逆时针旋转 90 度。

round(a[, decimals, out])

将输入四舍五入到指定的小数位数。

s_

一种更方便地为数组构建索引元组的方法。

save(file, arr[, allow_pickle, fix_imports])

将数组保存到 NumPy .npy 格式的二进制文件中。

savez(file, *args[, allow_pickle])

将多个数组保存到未压缩的 .npz 格式的单个文件中。

searchsorted(a, v[, side, sorter, method])

在已排序数组中执行二分查找。

select(condlist, choicelist[, default])

根据一系列条件选择值。

set_printoptions(*args, **kwargs)

numpy.set_printoptions() 的别名。

setdiff1d(ar1, ar2[, assume_unique, size, ...])

计算两个一维数组的集合差集。

setxor1d(ar1, ar2[, assume_unique, size, ...])

计算两个数组中元素的集合异或。

shape(a)

返回数组的形状。

sign(x, /)

返回输入的逐元素符号指示。

signbit(x, /)

返回数组元素的符号位。

signedinteger()

所有有符号整数标量类型的抽象基类。

sin(x, /)

计算输入的每个元素的三角正弦。

sinc(x, /)

计算归一化 sinc 函数。

single

alias of float32

sinh(x, /)

计算输入的逐元素双曲正弦。

size(a[, axis])

返回沿给定轴的元素数量。

sort(a[, axis, kind, order, stable, descending])

返回数组的排序副本。

sort_complex(a)

返回复数数组的排序副本。

spacing(x, /)

返回 x 与下一个相邻数字之间的间距。

split(ary, indices_or_sections[, axis])

将数组分割成子数组。

sqrt(x, /)

计算数组的逐元素非负平方根。

square(x, /)

计算数组的逐元素平方。

squeeze(a[, axis])

移除数组的一个或多个长度为 1 的轴

stack(arrays[, axis, out, dtype])

沿新轴连接数组。

std(a[, axis, dtype, out, ddof, keepdims, ...])

计算沿给定轴的标准差。

subtract

逐元素减去两个数组。

sum(a[, axis, dtype, out, keepdims, ...])

在给定轴上对数组元素求和。

swapaxes(a, axis1, axis2)

交换数组的两个轴。

take(a, indices[, axis, out, mode, ...])

从数组中选取元素。

take_along_axis(arr, indices[, axis, mode, ...])

从数组中选取元素。

tan(x, /)

计算输入的每个元素的三角正切。

tanh(x, /)

计算输入的逐元素双曲正切。

tensordot(a, b[, axes, precision, ...])

计算两个 N 维数组的张量点积。

tile(A, reps)

沿指定维度重复 A 来构造数组。

trace(a[, offset, axis1, axis2, dtype, out])

计算沿给定轴的输入的对角线之和。

trapezoid(y[, x, dx, axis])

使用复合梯形法则沿给定轴积分。

transpose(a[, axes])

返回 N 维数组的转置版本。

tri(N[, M, k, dtype])

返回对角线及其以下为 1,其他地方为 0 的数组。

tril(m[, k])

返回数组的下三角部分。

tril_indices(n[, k, m])

返回大小为 (n, m) 的数组的下三角索引。

tril_indices_from(arr[, k])

返回给定数组的下三角索引。

trim_zeros(filt[, trim, axis])

修剪输入数组的前导和/或尾随零。

triu(m[, k])

返回数组的上三角部分。

triu_indices(n[, k, m])

返回大小为 (n, m) 的数组的上三角索引。

triu_indices_from(arr[, k])

返回给定数组的上三角索引。

true_divide(x1, x2, /)

逐元素计算 x1 除以 x2 的结果

trunc(x)

将输入向零舍入到最接近的整数。

ufunc(func, /, nin, nout, *[, name, nargs, ...])

对数组进行逐元素操作的通用函数。

uint

alias of uint64

uint16(x)

JAX uint16 类型的标量构造函数。

uint32(x)

JAX uint32 类型的标量构造函数。

uint64(x)

JAX uint64 类型的标量构造函数。

uint8(x)

JAX uint8 类型的标量构造函数。

union1d(ar1, ar2, *[, size, fill_value])

计算两个一维数组的集合并集。

unique(ar[, return_index, return_inverse, ...])

返回数组中的唯一值。

unique_all(x, /, *[, size, fill_value])

从 x 返回唯一值,以及索引、逆索引和计数。

unique_counts(x, /, *[, size, fill_value])

返回 x 中的唯一值以及它们的计数。

unique_inverse(x, /, *[, size, fill_value])

从 x 返回唯一值,以及索引、逆索引和计数。

unique_values(x, /, *[, size, fill_value])

从 x 返回唯一值,以及索引、逆索引和计数。

unpackbits(a[, axis, count, bitorder])

解包 uint8 数组中的位。

unravel_index(indices, shape)

将扁平索引转换为多维索引。

unstack(x, /, *[, axis])

沿轴解堆栈数组。

unsignedinteger()

所有无符号整数标量类型的抽象基类。

unwrap(p[, discont, axis, period])

解开周期性信号。

vander(x[, N, increasing])

生成范德蒙德矩阵。

var(a[, axis, dtype, out, ddof, keepdims, ...])

计算给定轴上的方差。

vdot(a, b, *[, precision, ...])

执行两个一维向量的共轭乘积。

vecdot(x1, x2, /, *[, axis, precision, ...])

执行两个批量向量的共轭乘积。

vecmat(x1, x2, /)

批量共轭向量-矩阵乘积。

vectorize(pyfunc, *[, excluded, signature])

定义一个支持广播的向量化函数。

vsplit(ary, indices_or_sections)

垂直分割数组为子数组。

vstack(tup[, dtype])

垂直堆叠数组。

where(condition[, x, y, size, fill_value])

根据条件从两个数组中选择元素。

zeros(shape[, dtype, device, out_sharding])

创建全零数组。

zeros_like(a[, dtype, shape, device, ...])

创建与数组形状和 dtype 相同的全零数组。

jax.numpy.fft#

fft(a[, n, axis, norm])

计算给定轴上的一个一维离散傅里叶变换。

fft2(a[, s, axes, norm])

计算给定轴上的二维离散傅里叶变换。

fftfreq(n[, d, dtype, device])

返回离散傅里叶变换的采样频率。

fftn(a[, s, axes, norm])

计算给定轴上的多维离散傅里叶变换。

fftshift(x[, axes])

将零频率 FFT 分量移动到频谱中心。

hfft(a[, n, axis, norm])

计算具有厄米对称频谱的数组的一维 FFT。

ifft(a[, n, axis, norm])

计算一维离散傅里叶逆变换。

ifft2(a[, s, axes, norm])

计算二维离散傅里叶逆变换。

ifftn(a[, s, axes, norm])

计算多维离散傅里叶逆变换。

ifftshift(x[, axes])

jax.numpy.fft.fftshift() 的逆运算。

ihfft(a[, n, axis, norm])

计算具有厄米对称频谱的数组的一维逆 FFT。

irfft(a[, n, axis, norm])

计算实值的一维离散傅里叶逆变换。

irfft2(a[, s, axes, norm])

计算实值的二维离散傅里叶逆变换。

irfftn(a[, s, axes, norm])

计算实值多维离散傅里叶逆变换。

rfft(a[, n, axis, norm])

计算实值数组的一维离散傅里叶变换。

rfft2(a[, s, axes, norm])

计算实值数组的二维离散傅里叶变换。

rfftfreq(n[, d, dtype, device])

返回离散傅里叶变换的采样频率。

rfftn(a[, s, axes, norm])

计算实值数组的多维离散傅里叶变换。

jax.numpy.linalg#

cholesky(a, *[, upper, symmetrize_input])

计算矩阵的 Cholesky 分解。

cond(x[, p])

计算矩阵的条件数。

cross(x1, x2, /, *[, axis])

计算两个三维向量的叉积

det(a)

计算数组的行列式。

diagonal(x, /, *[, offset])

提取矩阵或矩阵堆的对角线。

eig(a)

计算方阵的特征值和特征向量。

eigh(a[, UPLO, symmetrize_input])

计算厄米矩阵的特征值和特征向量。

eigvals(a)

计算一般矩阵的特征值。

eigvalsh(a[, UPLO, symmetrize_input])

计算厄米矩阵的特征值。

inv(a)

返回方阵的逆。

lstsq(a, b[, rcond, numpy_resid])

返回线性方程的最小二乘解。

matmul(x1, x2, /, *[, precision, ...])

执行矩阵乘法。

matrix_norm(x, /, *[, keepdims, ord])

计算矩阵或矩阵堆的范数。

matrix_power(a, n)

将方阵提升到整数幂。

matrix_rank(M[, rtol, tol])

计算矩阵的秩。

matrix_transpose(x, /)

转置矩阵或矩阵堆。

multi_dot(arrays, *[, precision])

高效计算一系列数组之间的矩阵乘积。

norm(x[, ord, axis, keepdims])

计算矩阵或向量的范数。

outer(x1, x2, /)

计算两个一维数组的外积。

pinv(a[, rtol, hermitian, rcond])

计算矩阵的(摩尔-彭罗斯)伪逆。

qr(a[, mode])

计算数组的 QR 分解

slogdet(a, *[, method])

计算数组的行列式的符号和(自然)对数。

solve(a, b)

求解线性方程组。

svd(a[, full_matrices, compute_uv, ...])

计算奇异值分解。

svdvals(x, /)

计算矩阵的奇异值。

tensordot(x1, x2, /, *[, axes, precision, ...])

计算两个 N 维数组的张量点积。

tensorinv(a[, ind])

计算张量的逆。

tensorsolve(a, b[, axes])

求解张量方程 a x = b 中的 x。

trace(x, /, *[, offset, dtype])

计算矩阵的迹。

vector_norm(x, /, *[, axis, keepdims, ord])

计算向量或向量批的向量范数。

vecdot(x1, x2, /, *[, axis, precision, ...])

计算两个数组的(批量)向量共轭点积。

JAX Array#

JAX 的 Array (及其别名 jax.numpy.ndarray) 是 JAX 中的核心数组对象:您可以将其视为 JAX 中等同于 numpy.ndarray 的对象。与 numpy.ndarray 类似,大多数用户不需要手动实例化 Array 对象,而是通过 jax.numpy 函数(如 array()arange()linspace() 以及上面列出的其他函数)来创建它们。

复制和序列化#

JAX Array 对象旨在在适当的情况下与 Python 标准库工具无缝集成。

使用内置的 copy 模块时,当 copy.copy()copy.deepcopy() 遇到 Array 时,它等同于调用 copy() 方法,该方法将会在与原始数组相同的设备上创建缓冲区副本。这在经过跟踪/JIT 编译的代码中也能正常工作,尽管编译器在此上下文中可能会省略复制操作。

当内置的 pickle 模块遇到 Array 时,它将以类似于被 pickle 的 numpy.ndarray 对象的方式,通过紧凑的位表示来序列化。反序列化时,结果将是一个 *在默认设备上的* 新的 Array 对象。这是因为通常情况下,序列化和反序列化可能在不同的运行时环境中进行,并且无法将一个运行时的设备 ID 映射到另一个运行时的设备 ID。如果 pickle 用于经过跟踪/JIT 编译的代码,将会导致 ConcretizationTypeError

Python Array API 标准#

注意

在 JAX v0.4.32 之前,您必须 import jax.experimental.array_api 才能为 JAX 数组启用 array API。在 JAX v0.4.32 之后,导入此模块不再需要,并且会引发弃用警告。在 JAX v0.5.0 之后,此导入将引发错误。

从 JAX v0.4.32 开始,jax.Arrayjax.numpyPython Array API Standard 兼容。您可以通过 jax.Array.__array_namespace__() 访问 Array API 命名空间。

>>> def f(x):
...   nx = x.__array_namespace__()
...   return nx.sin(x) ** 2 + nx.cos(x) ** 2

>>> import jax.numpy as jnp
>>> x = jnp.arange(5)
>>> f(x).round()
Array([1., 1., 1., 1., 1.], dtype=float32)

JAX 在某些地方偏离了标准,主要是因为 JAX 数组是不可变的,不支持原地更新。其中一些不兼容性正通过 array-api-compat 模块得到解决。

有关更多信息,请参阅 Python Array API Standard 文档。