jax.numpy.square#
- jax.numpy.square(x, /)[来源]#
计算输入数组的逐元素平方。
JAX 对
numpy.square的实现。- 参数:
x (ArrayLike) – 输入数组或标量。
- 返回:
包含
x元素平方的数组。- 返回类型:
注意
jnp.square等同于计算jnp.power(x, 2)。另请参阅
jax.numpy.sqrt(): 计算输入数组的逐元素非负平方根。jax.numpy.power(): 计算逐元素的基数x1指数x2。jax.lax.integer_pow(): 计算逐元素幂 \(x^y\),其中 \(y\) 是一个固定整数。jax.numpy.float_power(): 通过提升到非精确 dtype 来计算第一个数组的第二个数组的逐元素幂。
示例
>>> x = jnp.array([3, -2, 5.3, 1]) >>> jnp.square(x) Array([ 9. , 4. , 28.090002, 1. ], dtype=float32) >>> jnp.power(x, 2) Array([ 9. , 4. , 28.090002, 1. ], dtype=float32)
对于整数输入
>>> x1 = jnp.array([2, 4, 5, 6]) >>> jnp.square(x1) Array([ 4, 16, 25, 36], dtype=int32)
对于复数值输入
>>> x2 = jnp.array([1-3j, -1j, 2]) >>> jnp.square(x2) Array([-8.-6.j, -1.+0.j, 4.+0.j], dtype=complex64)