jax.numpy.sqrt#
- jax.numpy.sqrt(x, /)[源代码]#
计算输入数组的逐元素非负平方根。
JAX 对
numpy.sqrt的实现。- 参数:
x (ArrayLike) – 输入数组或标量。
- 返回:
一个包含
x元素非负平方根的数组。- 返回类型:
注意
对于实值负输入,
jnp.sqrt会产生一个nan输出。对于复值负输入,
jnp.sqrt会产生一个complex输出。
另请参阅
jax.numpy.square():计算输入值的逐元素平方。jax.numpy.power():计算逐元素的基数x1指数x2。
示例
>>> x = jnp.array([-8-6j, 1j, 4]) >>> with jnp.printoptions(precision=3, suppress=True): ... jnp.sqrt(x) Array([1. -3.j , 0.707+0.707j, 2. +0.j ], dtype=complex64) >>> jnp.sqrt(-1) Array(nan, dtype=float32, weak_type=True)