jax.numpy.round#

jax.numpy.round(a, decimals=0, out=None)[源代码]#

将输入四舍五入到指定的小数位数。

numpy.round() 的 JAX 实现。

参数:
  • a (ArrayLike) – 输入数组或标量。

  • decimals (int) – int,默认值=0。需要对输入进行四舍五入的小数位数。必须静态指定。当 decimals < 0 时未实现。

  • out (None) – JAX 未使用。

返回:

一个包含四舍五入到指定 decimals 的值,并与 a 具有相同形状和数据类型的数组。

返回类型:

数组

注意

对于恰好介于两个舍入后小数位值中间的数,jnp.round 会四舍五入到最近的偶数。

另请参阅

示例

>>> x = jnp.array([1.532, 3.267, 6.149])
>>> jnp.round(x)
Array([2., 3., 6.], dtype=float32)
>>> jnp.round(x, decimals=2)
Array([1.53, 3.27, 6.15], dtype=float32)

对于恰好介于两个舍入值中间的数

>>> x1 = jnp.array([10.5, 21.5, 12.5, 31.5])
>>> jnp.round(x1)
Array([10., 22., 12., 32.], dtype=float32)