jax.scipy.signal.detrend#

jax.scipy.signal.detrend(data, axis=-1, type='linear', bp=0, overwrite_data=None)[source]#

从数据中移除线性或分段线性趋势。

scipy.signal.detrend() 的 JAX 实现。

参数:
  • data (ArrayLike) – 包含要去除趋势的数据的输入数组。

  • axis (int) – 沿其去除趋势的轴。 默认为 -1 (最后一个轴)。

  • type (str) –

    去除趋势的类型。 可以是

    • 'linear': 为整个数据拟合一个线性趋势。

    • 'constant': 移除数据的平均值。

  • bp (int) – 一系列断点。 如果给定,则在这些断点之间拟合分段线性趋势。

  • overwrite_data (None) – JAX 实现不支持此参数。

返回:

去除趋势的数据数组。

返回类型:

Array

示例

一个维度中的简单趋势消除操作

>>> data = jnp.array([1., 4., 8., 8., 9.])

从数据中删除线性趋势

>>> detrended = jax.scipy.signal.detrend(data)
>>> with jnp.printoptions(precision=3, suppress=True):  # suppress float error
...   print("Detrended:", detrended)
...   print("Underlying trend:", data - detrended)
Detrended: [-1. -0.  2. -0. -1.]
Underlying trend: [ 2.  4.  6.  8. 10.]

从数据中删除常数趋势

>>> detrended = jax.scipy.signal.detrend(data, type='constant')
>>> with jnp.printoptions(precision=3):  # suppress float error
...   print("Detrended:", detrended)
...   print("Underlying trend:", data - detrended)
Detrended: [-5. -2.  2.  2.  3.]
Underlying trend: [6. 6. 6. 6. 6.]