jax.numpy.power#

jax.numpy.power(x1, x2, /)[源代码]#

计算元素级别的基数 x1x2 次方。

JAX 实现的 numpy.power

参数:
  • x1 (ArrayLike) – 标量或数组。指定底数。

  • x2 (ArrayLike) – 标量或数组。指定指数。x1x2 应该具有相同的形状,或者广播兼容。

返回:

一个包含 x1x2 次方的数组,其 dtype 与输入相同。

返回类型:

数组

注意

  • x2 是一个具体的整数标量时,jnp.power 会降级为 jax.lax.integer_pow()

  • x2 是一个跟踪标量或数组时,jnp.power 会降级为 jax.lax.pow()

  • 对于整数类型升到具体的负整数次幂,jnp.power 会引发 TypeError。 对于非具体的幂,该操作无效,并且返回值是实现定义的。

  • 对于负值升到非整数幂,jnp.power 返回 nan

另请参阅

  • jax.lax.pow():计算元素级别的幂,\(x^y\)

  • jax.lax.integer_pow():计算元素级别的幂 \(x^y\),其中 \(y\) 是一个固定的整数。

  • jax.numpy.float_power():计算第一个数组升到第二个数组的幂,按元素进行,通过提升到非精确的 dtype。

  • jax.numpy.pow():计算第一个数组升到第二个数组的幂,按元素进行。

示例

带有标量整数的输入

>>> jnp.power(4, 3)
Array(64, dtype=int32, weak_type=True)

具有相同形状的输入

>>> x1 = jnp.array([2, 4, 5])
>>> x2 = jnp.array([3, 0.5, 2])
>>> jnp.power(x1, x2)
Array([ 8.,  2., 25.], dtype=float32)

支持广播的输入

>>> x3 = jnp.array([-2, 3, 1])
>>> x4 = jnp.array([[4, 1, 6],
...                 [1.3, 3, 5]])
>>> jnp.power(x3, x4)
Array([[16.,  3.,  1.],
       [nan, 27.,  1.]], dtype=float32)