jax.numpy.diff#

jax.numpy.diff(a, n=1, axis=-1, prepend=None, append=None)[源代码]#

计算给定轴上数组元素之间的 n 阶差值。

JAX 对 numpy.diff() 的实现。

一阶差值计算方式为 a[i+1] - a[i],n 阶差值则通过递归计算 n 次得到。

参数:
  • a (ArrayLike) – 输入数组。必须满足 a.ndim >= 1

  • n (int) – int,可选,默认为 1。差值阶数。指定计算差值的次数。如果 n=0,则不计算差值,直接返回输入。

  • axis (int) – int,可选,默认为 -1。指定计算差值的轴。默认沿 axis -1 计算差值。

  • prepend (ArrayLike | None) – 标量或数组,可选,默认为 None。指定在计算差值之前沿 axis 插入的值。

  • append (ArrayLike | None) – 标量或数组,可选,默认为 None。指定在计算差值之前沿 axis 追加的值。

返回:

包含 a 元素之间 n 阶差值的数组。

返回类型:

Array

另请参阅

示例

jnp.diff 默认沿 axis 计算一阶差值。

>>> a = jnp.array([[1, 5, 2, 9],
...                [3, 8, 7, 4]])
>>> jnp.diff(a)
Array([[ 4, -3,  7],
       [ 5, -1, -3]], dtype=int32)

n = 2 时,沿 axis 计算二阶差值。

>>> jnp.diff(a, n=2)
Array([[-7, 10],
       [-6, -2]], dtype=int32)

prepend = 2 时,在计算差值之前,将 2 沿 axis 插入到 a 的前面。

>>> jnp.diff(a, prepend=2)
Array([[-1,  4, -3,  7],
       [ 1,  5, -1, -3]], dtype=int32)

append = jnp.array([[3],[1]]) 时,在计算差值之前,将 `jnp.array([[3],[1]])` 沿 axis 追加到 a 的后面。

>>> jnp.diff(a, append=jnp.array([[3],[1]]))
Array([[ 4, -3,  7, -6],
       [ 5, -1, -3, -3]], dtype=int32)