jax.numpy.diff#

jax.numpy.diff(a, n=1, axis=-1, prepend=None, append=None)[源代码]#

计算沿给定轴的数组元素之间的 n 阶差分。

numpy.diff() 的 JAX 实现。

一阶差分通过 a[i+1] - a[i] 计算,而 n 阶差分递归计算 n 次。

参数:
  • a (ArrayLike) – 输入数组。必须有 a.ndim >= 1

  • n (int) – int,可选,默认值=1。差分的阶数。指定计算差分的次数。如果 n=0,则不计算差分,并按原样返回输入。

  • axis (int) – int,可选,默认值=-1。指定计算差分的轴。默认沿 axis -1 计算差分。

  • prepend (ArrayLike | None) – 标量或数组,可选,默认值=None。指定在计算差分之前沿 axis 预先添加的值。

  • append (ArrayLike | None) – 标量或数组,可选,默认值=None。指定在计算差分之前沿 axis 追加的值。

返回:

一个包含 a 的元素之间 n 阶差分的数组。

返回类型:

数组

另请参阅

示例

jnp.diff 默认沿 axis 计算一阶差分。

>>> a = jnp.array([[1, 5, 2, 9],
...                [3, 8, 7, 4]])
>>> jnp.diff(a)
Array([[ 4, -3,  7],
       [ 5, -1, -3]], dtype=int32)

n = 2 时,沿 axis 计算二阶差分。

>>> jnp.diff(a, n=2)
Array([[-7, 10],
       [-6, -2]], dtype=int32)

prepend = 2 时,在计算差分之前,它会沿 axis 预先添加到 a

>>> jnp.diff(a, prepend=2)
Array([[-1,  4, -3,  7],
       [ 1,  5, -1, -3]], dtype=int32)

append = jnp.array([[3],[1]]) 时,在计算差分之前,它会沿 axis 追加到 a

>>> jnp.diff(a, append=jnp.array([[3],[1]]))
Array([[ 4, -3,  7, -6],
       [ 5, -1, -3, -3]], dtype=int32)