jax.numpy.fft.fftn#

jax.numpy.fft.fftn(a, s=None, axes=None, norm=None)[源代码]#

在给定轴上计算多维离散傅里叶变换。

JAX 对 numpy.fft.fftn() 的实现。

参数:
  • a (ArrayLike) – 输入数组

  • s (形状 | None) – 整数序列。指定结果的形状。如果未指定,则默认沿指定的 axesa 的形状。

  • axes (序列[int] | None) – 整数序列,默认为 None。指定计算变换的轴。

  • norm (str | None) – 字符串。归一化模式。支持 “backward”、“ortho” 和 “forward”。

返回:

包含 a 的多维离散傅里叶变换的数组。

返回类型:

Array

另请参阅

示例

axes 参数为 None 时,jnp.fft.fftn 默认沿所有轴计算变换。

>>> x = jnp.array([[1, 2, 5, 6],
...                [4, 1, 3, 7],
...                [5, 9, 2, 1]])
>>> with jnp.printoptions(precision=2, suppress=True):
...   jnp.fft.fftn(x)
Array([[ 46.  +0.j  ,   0.  +2.j  ,  -6.  +0.j  ,   0.  -2.j  ],
       [ -2.  +1.73j,   6.12+6.73j,   0.  -1.73j, -18.12-3.27j],
       [ -2.  -1.73j, -18.12+3.27j,   0.  +1.73j,   6.12-6.73j]],      dtype=complex64)

s=[2] 时,沿 axis -1 的变换维度将是 2,而沿其他轴的维度将与输入相同。

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.numpy.fft.fftn(x, s=[2]))
[[ 3.+0.j -1.+0.j]
 [ 5.+0.j  3.+0.j]
 [14.+0.j -4.+0.j]]

s=[2]axes=[0] 时,沿 axis 0 的变换维度将是 2,而沿其他轴的维度将与输入相同。

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.numpy.fft.fftn(x, s=[2], axes=[0]))
[[ 5.+0.j  3.+0.j  8.+0.j 13.+0.j]
 [-3.+0.j  1.+0.j  2.+0.j -1.+0.j]]

s=[2, 3] 时,变换的形状将是 (2, 3)

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.numpy.fft.fftn(x, s=[2, 3]))
[[16. +0.j   -0.5+4.33j -0.5-4.33j]
 [ 0. +0.j   -4.5+0.87j -4.5-0.87j]]

可以使用 jnp.fft.ifftnjnp.fft.fftn 的结果中重构 x

>>> x_fftn = jnp.fft.fftn(x)
>>> jnp.allclose(x, jnp.fft.ifftn(x_fftn))
Array(True, dtype=bool)