jax.numpy.convolve#

jax.numpy.convolve(a, v, mode='full', *, precision=None, preferred_element_type=None)[来源]#

两个一维数组的卷积。

numpy.convolve() 的 JAX 实现。

一维数组的卷积定义如下:

\[c_k = \sum_j a_{k - j} v_j\]
参数:
  • a (类数组) – 卷积的左侧输入。必须满足 a.ndim == 1

  • v (类数组) – 卷积的右侧输入。必须满足 v.ndim == 1

  • mode (str) –

    控制输出的大小。可用操作包括

    • "full": (默认)输出输入的完整卷积。

    • "same":返回 "full" 输出的中心部分,其大小与 a 相同。

    • "valid": 返回 "full" 输出中不依赖于数组边缘填充的部分。

  • precision (lax.PrecisionLike) – 指定计算精度。有关可用值的描述,请参阅 jax.lax.Precision

  • preferred_element_type (类数据类型 | None) – 一个数据类型,指示将结果累积到该数据类型并返回该数据类型的结果。默认值为 None,表示使用输入类型的默认累积类型。

返回:

包含卷积结果的数组。

返回类型:

数组

另请参阅

示例

几个一维卷积示例

>>> x = jnp.array([1, 2, 3, 2, 1])
>>> y = jnp.array([4, 1, 2])

jax.numpy.convolve 默认返回完全卷积,并在边缘使用隐式零填充。

>>> jnp.convolve(x, y)
Array([ 4.,  9., 16., 15., 12.,  5.,  2.], dtype=float32)

指定 mode = 'same' 返回与第一个输入大小相同的中心卷积。

>>> jnp.convolve(x, y, mode='same')
Array([ 9., 16., 15., 12.,  5.], dtype=float32)

指定 mode = 'valid' 仅返回两个数组完全重叠的部分。

>>> jnp.convolve(x, y, mode='valid')
Array([16., 15., 12.], dtype=float32)

对于复数值输入

>>> x1 = jnp.array([3+1j, 2, 4-3j])
>>> y1 = jnp.array([1, 2-3j, 4+5j])
>>> jnp.convolve(x1, y1)
Array([ 3. +1.j, 11. -7.j, 15.+10.j,  7. -8.j, 31. +8.j], dtype=complex64)