jax.numpy.polyval#
- jax.numpy.polyval(p, x, *, unroll=16)[源代码]#
在特定值处计算多项式。
JAX 实现
numpy.polyval()。对于长度为
M的一维多项式系数p,该函数返回以下值:\[p_0 x^{M - 1} + p_1 x^{M - 2} + ... + p_{M - 1}\]- 参数:
p (ArrayLike) – 形状为
(M,)的多项式系数数组。x (ArrayLike) – 数字或数字数组。
unroll (int) – 用于控制
lax.scan的展开步数的数字。必须静态指定。
- 返回:
与
x具有相同形状的数组。- 返回类型:
注意
unroll参数是 JAX 特有的。它不影响正确性,但可能对评估高阶多项式的性能产生重大影响。该参数控制jnp.polyval实现中lax.scan的展开步数。为了在加速器上提高运行时性能,可以考虑将unroll设置为128(甚至更高),但会增加编译时间。另请参阅
jax.numpy.polyfit(): 最小二乘多项式拟合。jax.numpy.poly(): 查找具有给定根的多项式的系数。jax.numpy.roots(): 计算给定系数的多项式的根。
示例
>>> p = jnp.array([2, 5, 1]) >>> jnp.polyval(p, 3) Array(34., dtype=float32)
如果
x是二维数组,则polyval返回与x形状相同的二维数组。>>> x = jnp.array([[2, 1, 5], ... [3, 4, 7], ... [1, 3, 5]]) >>> jnp.polyval(p, x) Array([[ 19., 8., 76.], [ 34., 53., 134.], [ 8., 34., 76.]], dtype=float32)