jax.numpy.diff#
- jax.numpy.diff(a, n=1, axis=-1, prepend=None, append=None)[源代码]#
计算给定轴上数组元素之间的 n 阶差值。
JAX 对
numpy.diff()的实现。一阶差值计算方式为
a[i+1] - a[i],n 阶差值则通过递归计算n次得到。- 参数:
- 返回:
包含
a元素之间 n 阶差值的数组。- 返回类型:
另请参阅
jax.numpy.ediff1d():计算数组相邻元素之间的差值。jax.numpy.cumsum():计算数组沿给定轴的累积和。jax.numpy.gradient():计算 N 维数组的梯度。
示例
jnp.diff默认沿axis计算一阶差值。>>> a = jnp.array([[1, 5, 2, 9], ... [3, 8, 7, 4]]) >>> jnp.diff(a) Array([[ 4, -3, 7], [ 5, -1, -3]], dtype=int32)
当
n = 2时,沿axis计算二阶差值。>>> jnp.diff(a, n=2) Array([[-7, 10], [-6, -2]], dtype=int32)
当
prepend = 2时,在计算差值之前,将 2 沿axis插入到a的前面。>>> jnp.diff(a, prepend=2) Array([[-1, 4, -3, 7], [ 1, 5, -1, -3]], dtype=int32)
当
append = jnp.array([[3],[1]])时,在计算差值之前,将 `jnp.array([[3],[1]])` 沿axis追加到a的后面。>>> jnp.diff(a, append=jnp.array([[3],[1]])) Array([[ 4, -3, 7, -6], [ 5, -1, -3, -3]], dtype=int32)