jax.numpy.ediff1d#

jax.numpy.ediff1d(ary, to_end=None, to_begin=None)[源代码]#

计算扁平化数组中元素的差值。

JAX 中 numpy.ediff1d() 的实现。

参数:
  • ary (ArrayLike) – 输入数组或标量。

  • to_end (ArrayLike | None) – 标量或数组,可选,默认值为 None。指定要追加到结果数组的数字。

  • to_begin (ArrayLike | None) – 标量或数组,可选,默认值为 None。指定要添加到结果数组开头的数字。

返回:

包含输入数组元素之间差值的数组。

返回类型:

数组

注意

与 NumPy 的 ediff1d 实现不同,如果将 to_endto_begin 转换为 ary 的类型时丢失精度,jax.numpy.ediff1d() 不会报错。

另请参阅

示例

>>> a = jnp.array([2, 3, 5, 9, 1, 4])
>>> jnp.ediff1d(a)
Array([ 1,  2,  4, -8,  3], dtype=int32)
>>> jnp.ediff1d(a, to_begin=-10)
Array([-10,   1,   2,   4,  -8,   3], dtype=int32)
>>> jnp.ediff1d(a, to_end=jnp.array([20, 30]))
Array([ 1,  2,  4, -8,  3, 20, 30], dtype=int32)
>>> jnp.ediff1d(a, to_begin=-10, to_end=jnp.array([20, 30]))
Array([-10,   1,   2,   4,  -8,   3,  20,  30], dtype=int32)

对于 ndim > 1 的数组,在展平输入数组后计算差值。

>>> a1 = jnp.array([[2, -1, 4, 7],
...                 [3, 5, -6, 9]])
>>> jnp.ediff1d(a1)
Array([ -3,   5,   3,  -4,   2, -11,  15], dtype=int32)
>>> a2 = jnp.array([2, -1, 4, 7, 3, 5, -6, 9])
>>> jnp.ediff1d(a2)
Array([ -3,   5,   3,  -4,   2, -11,  15], dtype=int32)