jax.numpy.roots#

jax.numpy.roots(p, *, strip_zeros=True)[源代码]#

返回给定系数 p 的多项式的根。

numpy.roots()的 JAX 实现。

参数:
  • p (ArrayLike) – 秩为 1 的多项式系数数组。

  • strip_zeros (bool) – bool, 默认值=True。如果为 True,则将去除系数中的前导零,类似于numpy.roots()。如果设置为 False,则不会去除前导零,并且未定义的根将在函数输出中表示为 NaN 值。strip_zeros必须设置为False,该函数才能与jax.jit()和其他 JAX 转换兼容。

返回:

一个包含多项式根的数组。

返回类型:

数组

注意

与此函数的np.roots不同,jnp.roots返回复数数组中的根,而与根的值无关。

另请参阅

示例

>>> coeffs = jnp.array([0, 1, 2])

默认行为与 numpy 匹配并去除前导零

>>> jnp.roots(coeffs)
Array([-2.+0.j], dtype=complex64)

使用strip_zeros=False,额外的根设置为 NaN

>>> jnp.roots(coeffs, strip_zeros=False)
Array([-2. +0.j, nan+nanj], dtype=complex64)