jax.numpy.power#

jax.numpy.power(x1, x2, /)[源代码]#

计算元素级底数 x1x2 次幂。

JAX 对 numpy.power 的实现。

参数:
  • x1 (ArrayLike) – 标量或数组。指定底数。

  • x2 (ArrayLike) – 标量或数组。指定指数。x1x2 应具有相同的形状或兼容广播。

返回:

一个包含底数 x1x2 次幂的数组,具有与输入相同的 dtype。

返回类型:

Array

注意

  • x2 是一个具体的整数标量时,jnp.power 会被降低为 jax.lax.integer_pow()

  • x2 是一个被跟踪的标量或一个数组时,jnp.power 会被降低为 jax.lax.pow()

  • jnp.power 对于整数类型计算具体负整数次幂时会引发 TypeError。对于非具体的幂,操作无效,返回的值由实现定义。

  • jnp.power 对于负数计算非整数次幂时返回 nan

另请参阅

示例

标量整数的输入

>>> jnp.power(4, 3)
Array(64, dtype=int32, weak_type=True)

具有相同形状的输入

>>> x1 = jnp.array([2, 4, 5])
>>> x2 = jnp.array([3, 0.5, 2])
>>> jnp.power(x1, x2)
Array([ 8.,  2., 25.], dtype=float32)

支持广播的输入

>>> x3 = jnp.array([-2, 3, 1])
>>> x4 = jnp.array([[4, 1, 6],
...                 [1.3, 3, 5]])
>>> jnp.power(x3, x4)
Array([[16.,  3.,  1.],
       [nan, 27.,  1.]], dtype=float32)