jax.scipy.fft.dctn#
- jax.scipy.fft.dctn(x, type=2, s=None, axes=None, norm=None)[源代码]#
计算输入的 N 维离散余弦变换。
JAX 对
scipy.fft.dctn()的实现。- 参数:
- 返回:
包含 x 的离散余弦变换的数组。
- 返回类型:
另请参阅
jax.scipy.fft.dct(): 一维 DCT。jax.scipy.fft.idct(): 一维逆 DCT。jax.scipy.fft.idctn(): 多维逆 DCT
示例
当
axes参数为None时,jax.scipy.fft.dctn默认沿所有轴计算变换。>>> x = jax.random.normal(jax.random.key(0), (3, 3)) >>> with jnp.printoptions(precision=2, suppress=True): ... print(jax.scipy.fft.dctn(x)) [[ 12.01 6.2 -10.17] [ 8.84 9.65 -3.54] [ 11.25 -1.54 -0.88]]
当
s=[2]时,沿axis 0的变换维度将为2,沿axis 1的维度将与输入相同。>>> with jnp.printoptions(precision=2, suppress=True): ... print(jax.scipy.fft.dctn(x, s=[2])) [[ 9.36 10.22 -8.53] [11.57 2.85 -2.06]]
当
s=[2]且axes=[1]时,沿axis 1的变换维度将为2,沿axis 0的维度将与输入相同。同样,当axes=[1]时,变换将仅沿axis 1计算。>>> with jnp.printoptions(precision=2, suppress=True): ... print(jax.scipy.fft.dctn(x, s=[2], axes=[1])) [[ 7.3 -0.57] [ 0.19 -0.36] [-0. -1.4 ]]
当
s=[2, 4]时,变换的形状将为(2, 4)。>>> with jnp.printoptions(precision=2, suppress=True): ... print(jax.scipy.fft.dctn(x, s=[2, 4])) [[ 9.36 11.23 2.12 -10.97] [ 11.57 5.86 -1.37 -1.58]]