jax.numpy.power#
- jax.numpy.power(x1, x2, /)[源代码]#
计算元素级底数
x1的x2次幂。JAX 对
numpy.power的实现。- 参数:
x1 (ArrayLike) – 标量或数组。指定底数。
x2 (ArrayLike) – 标量或数组。指定指数。
x1和x2应具有相同的形状或兼容广播。
- 返回:
一个包含底数
x1的x2次幂的数组,具有与输入相同的 dtype。- 返回类型:
注意
当
x2是一个具体的整数标量时,jnp.power会被降低为jax.lax.integer_pow()。当
x2是一个被跟踪的标量或一个数组时,jnp.power会被降低为jax.lax.pow()。jnp.power对于整数类型计算具体负整数次幂时会引发TypeError。对于非具体的幂,操作无效,返回的值由实现定义。jnp.power对于负数计算非整数次幂时返回nan。
另请参阅
jax.lax.pow(): 计算元素级幂 \(x^y\)。jax.lax.integer_pow(): 计算元素级幂 \(x^y\),其中 \(y\) 是一个固定的整数。jax.numpy.float_power(): 通过提升到非精确 dtype 来计算第一个数组的第二个数组次幂,元素级。jax.numpy.pow(): 计算第一个数组的第二个数组次幂,元素级。
示例
标量整数的输入
>>> jnp.power(4, 3) Array(64, dtype=int32, weak_type=True)
具有相同形状的输入
>>> x1 = jnp.array([2, 4, 5]) >>> x2 = jnp.array([3, 0.5, 2]) >>> jnp.power(x1, x2) Array([ 8., 2., 25.], dtype=float32)
支持广播的输入
>>> x3 = jnp.array([-2, 3, 1]) >>> x4 = jnp.array([[4, 1, 6], ... [1.3, 3, 5]]) >>> jnp.power(x3, x4) Array([[16., 3., 1.], [nan, 27., 1.]], dtype=float32)