jax.numpy.sinc#
- jax.numpy.sinc(x, /)[源代码]#
计算归一化的 sinc 函数。
JAX 对
numpy.sinc()的实现。归一化的 sinc 函数定义为
\[\mathrm{sinc}(x) = \frac{\sin({\pi x})}{\pi x}\]其中
sinc(0)返回其极限值1。sinc 函数是光滑且无限可微的。- 参数:
x (ArrayLike) – 输入数组;将被提升为非精确类型。
- 返回:
一个与
x形状相同的数组,包含结果。- 返回类型:
示例
>>> x = jnp.array([-1, -0.5, 0, 0.5, 1]) >>> with jnp.printoptions(precision=3, suppress=True): ... jnp.sinc(x) Array([-0. , 0.637, 1. , 0.637, -0. ], dtype=float32)
与计算该函数(在零点未定义)的朴素方法进行比较。
>>> with jnp.printoptions(precision=3, suppress=True): ... jnp.sin(jnp.pi * x) / (jnp.pi * x) Array([-0. , 0.637, nan, 0.637, -0. ], dtype=float32)
JAX 为 sinc 定义了自定义的梯度规则,以便即使对于高阶导数也能在零点进行精确的梯度求值。
>>> f = jnp.sinc >>> for i in range(1, 6): ... f = jax.grad(f) ... print(f"(d/dx)^{i} f(0.0) = {f(0.0):.2f}") ... (d/dx)^1 f(0.0) = 0.00 (d/dx)^2 f(0.0) = -3.29 (d/dx)^3 f(0.0) = 0.00 (d/dx)^4 f(0.0) = 19.48 (d/dx)^5 f(0.0) = 0.00