jax.numpy.minimum#
- jax.numpy.minimum = <jnp.ufunc 'minimum'>#
返回输入数组的逐元素最小值。
JAX 对
numpy.minimum的实现。- 参数:
x – 输入数组或标量。
y – 输入数组或标量。
x和y应具有相同的形状或可广播兼容。args (ArrayLike)
out (None)
where (None)
- 返回:
包含
x和y元素级最小值的数组。- 返回类型:
任意类型
注意
- 对于每个元素对,
jnp.minimum返回 如果两个元素都是有限数字,则返回较小的一个。
nan如果其中一个元素是nan。
另请参阅
jax.numpy.maximum():返回输入数组的元素级最大值。jax.numpy.fmin():返回输入数组的元素级最小值,忽略 NaN。jax.numpy.amin():返回给定轴上数组元素的最小值。jax.numpy.nanmin():返回给定轴上数组元素的最小值,忽略 NaN。
示例
具有
x.shape == y.shape的输入>>> x = jnp.array([2, 3, 5, 1]) >>> y = jnp.array([-3, 6, -4, 7]) >>> jnp.minimum(x, y) Array([-3, 3, -4, 1], dtype=int32)
具有广播兼容性的输入
>>> x1 = jnp.array([[1, 5, 2], ... [-3, 4, 7]]) >>> y1 = jnp.array([-2, 3, 6]) >>> jnp.minimum(x1, y1) Array([[-2, 3, 2], [-3, 3, 6]], dtype=int32)
具有
nan的输入>>> nan = jnp.nan >>> x2 = jnp.array([[2.5, nan, -2], ... [nan, 5, 6], ... [-4, 3, 7]]) >>> y2 = jnp.array([1, nan, 5]) >>> jnp.minimum(x2, y2) Array([[ 1., nan, -2.], [nan, nan, 5.], [-4., nan, 5.]], dtype=float32)