jax.lax.round#

jax.lax.round(x, rounding_method=RoundingMethod.AWAY_FROM_ZERO)[源代码]#

逐元素取整。

将值四舍五入到最接近的整数。此函数直接降低到 stablehlo.round 操作。

参数:
  • x (ArrayLike) – 要取整的数组或标量值。必须具有浮点类型。

  • rounding_method (RoundingMethod) – 用于对半分值(例如,0.5)进行取整的方法。有关可能的值,请参见 jax.lax.RoundingMethod

返回值:

x 形状和数据类型相同的数组,包含 x 的逐元素取整结果。

返回类型:

Array

另请参见

示例

>>> import jax.numpy as jnp
>>> from jax import lax
>>> x = jnp.array([-1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5])
>>> jax.lax.round(x)  # defaults method is AWAY_FROM_ZERO
Array([-2., -1., -1.,  0.,  1.,  1.,  2.], dtype=float32)
>>> jax.lax.round(x, rounding_method=jax.lax.RoundingMethod.TO_NEAREST_EVEN)
Array([-2., -1., -0.,  0.,  0.,  1.,  2.], dtype=float32)