jax.lax.round#
- jax.lax.round(x, rounding_method=RoundingMethod.AWAY_FROM_ZERO)[源代码]#
逐元素取整。
将值四舍五入到最接近的整数。此函数直接降低到 stablehlo.round 操作。
- 参数:
x (ArrayLike) – 要取整的数组或标量值。必须具有浮点类型。
rounding_method (RoundingMethod) – 用于对半分值(例如,
0.5
)进行取整的方法。有关可能的值,请参见jax.lax.RoundingMethod
。
- 返回值:
与
x
形状和数据类型相同的数组,包含x
的逐元素取整结果。- 返回类型:
另请参见
jax.lax.floor()
:向负无穷方向舍入到下一个整数jax.lax.ceil()
:向正无穷方向舍入到下一个整数
示例
>>> import jax.numpy as jnp >>> from jax import lax >>> x = jnp.array([-1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5]) >>> jax.lax.round(x) # defaults method is AWAY_FROM_ZERO Array([-2., -1., -1., 0., 1., 1., 2.], dtype=float32) >>> jax.lax.round(x, rounding_method=jax.lax.RoundingMethod.TO_NEAREST_EVEN) Array([-2., -1., -0., 0., 0., 1., 2.], dtype=float32)