jax.numpy.polymul#
- jax.numpy.polymul(a1, a2, *, trim_leading_zeros=False)[源代码]#
返回两个多项式的乘积。
JAX 对 (NumPy v2.3) 中
numpy.polymul()的实现。- 参数:
a1 (ArrayLike) – 多项式系数的一维数组。
a2 (ArrayLike) – 多项式系数的一维数组。
trim_leading_zeros ((Python v3.14)bool) – 默认为
False。如果为True,则会移除返回值中的前导零,以匹配 numpy 的结果。但这会阻止函数在编译的代码中使用。由于浮点算术误差累积的差异,被视为零的值的截止点可能导致 NumPy 和 JAX 之间,甚至不同 JAX 后端之间产生不一致的结果。trim_leading_zeros=True时,结果可能导致输出形状不一致。
- 返回:
包含两个多项式乘积系数的数组。输出的 dtype 始终提升为不精确(inexact)。
- 返回类型:
注意
jax.numpy.polymul()与接受标量输入(也接受)的 (NumPy v2.3) 中numpy.polymul()不同,它只接受数组作为输入。另请参阅
jax.numpy.polyadd(): 计算两个多项式的和。jax.numpy.polysub(): 计算两个多项式的差。jax.numpy.polydiv(): 计算多项式除法的商和余数。
示例
>>> x1 = np.array([2, 1, 0]) >>> x2 = np.array([0, 5, 0, 3]) >>> np.polymul(x1, x2) array([10, 5, 6, 3, 0]) >>> jnp.polymul(x1, x2) Array([ 0., 10., 5., 6., 3., 0.], dtype=float32)
如果
trim_leading_zeros=True,则结果与np.polymul的结果匹配。>>> jnp.polymul(x1, x2, trim_leading_zeros=True) Array([10., 5., 6., 3., 0.], dtype=float32)
对于 dtype 为
complex的输入数组>>> x3 = np.array([2., 1+2j, 1-2j]) >>> x4 = np.array([0, 5, 0, 3]) >>> np.polymul(x3, x4) array([10. +0.j, 5.+10.j, 11.-10.j, 3. +6.j, 3. -6.j]) >>> jnp.polymul(x3, x4) Array([ 0. +0.j, 10. +0.j, 5.+10.j, 11.-10.j, 3. +6.j, 3. -6.j], dtype=complex64) >>> jnp.polymul(x3, x4, trim_leading_zeros=True) Array([10. +0.j, 5.+10.j, 11.-10.j, 3. +6.j, 3. -6.j], dtype=complex64)