jax.numpy.round#
- jax.numpy.round(a, decimals=0, out=None)[源代码]#
将输入四舍五入到指定的小数位数。
numpy.round()
的 JAX 实现。- 参数:
a (ArrayLike) – 输入数组或标量。
decimals (int) – int,默认值=0。需要对输入进行四舍五入的小数位数。必须静态指定。当
decimals < 0
时未实现。out (None) – JAX 未使用。
- 返回:
一个包含四舍五入到指定
decimals
的值,并与a
具有相同形状和数据类型的数组。- 返回类型:
注意
对于恰好介于两个舍入后小数位值中间的数,
jnp.round
会四舍五入到最近的偶数。另请参阅
jax.numpy.floor()
: 将输入向下舍入到最近的整数。jax.numpy.ceil()
: 将输入向上舍入到最近的整数。jax.numpy.fix()
and :func:numpy.trunc`: 将输入向零方向舍入到最近的整数。
示例
>>> x = jnp.array([1.532, 3.267, 6.149]) >>> jnp.round(x) Array([2., 3., 6.], dtype=float32) >>> jnp.round(x, decimals=2) Array([1.53, 3.27, 6.15], dtype=float32)
对于恰好介于两个舍入值中间的数
>>> x1 = jnp.array([10.5, 21.5, 12.5, 31.5]) >>> jnp.round(x1) Array([10., 22., 12., 32.], dtype=float32)