jax.numpy.power#

jax.numpy.power(x1, x2, /)[source]#

计算元素级别的基数 x1x2 次方。

JAX 对 numpy.power 的实现。

参数:
  • x1 (ArrayLike) – 标量或数组。指定底数。

  • x2 (ArrayLike) – 标量或数组。指定指数。x1x2 应该具有相同的形状或广播兼容。

返回:

一个数组,包含基数 x1x2 次方,数据类型与输入相同。

返回类型:

数组

注意

  • x2 是一个具体的整数标量时,jnp.power 会降低为 jax.lax.integer_pow()

  • x2 是一个 traced 标量或数组时,jnp.power 会降低为 jax.lax.pow()

  • jnp.power 对于整数类型提升为具体的负整数次方会引发 TypeError。对于非具体次方,该操作无效,并且返回的值是实现定义的。

  • jnp.power 对于负值提升为非整数次方会返回 nan

另请参阅

  • jax.lax.pow(): 计算元素级别的幂,\(x^y\)

  • jax.lax.integer_pow(): 计算元素级别的幂 \(x^y\),其中 \(y\) 是一个固定的整数。

  • jax.numpy.float_power(): 通过提升到非精确数据类型,计算第一个数组的第二个数组次幂,元素级别。

  • jax.numpy.pow(): 计算第一个数组的第二个数组次幂,元素级别。

示例

标量整数输入

>>> jnp.power(4, 3)
Array(64, dtype=int32, weak_type=True)

相同形状的输入

>>> x1 = jnp.array([2, 4, 5])
>>> x2 = jnp.array([3, 0.5, 2])
>>> jnp.power(x1, x2)
Array([ 8.,  2., 25.], dtype=float32)

广播兼容的输入

>>> x3 = jnp.array([-2, 3, 1])
>>> x4 = jnp.array([[4, 1, 6],
...                 [1.3, 3, 5]])
>>> jnp.power(x3, x4)
Array([[16.,  3.,  1.],
       [nan, 27.,  1.]], dtype=float32)