jax.numpy.nanargmax#
- jax.numpy.nanargmax(a, axis=None, out=None, keepdims=None)[源]#
返回数组中最大值的索引,忽略非数字 (NaN) 值。
NumPy 中
numpy.nanargmax()
的 JAX 实现。- 参数:
- 返回:
一个数组,包含指定轴上最大值的索引。
- 返回类型:
注意
如果一个轴上的值全部为 NaN,则返回的索引将是 -1。这与
numpy.nanargmax()
的行为不同,后者会引发错误。另请参阅
jax.numpy.argmax()
: 返回最大值的索引。jax.numpy.nanargmin()
: 计算argmin
并忽略 NaN 值。
示例
>>> x = jnp.array([1, 3, 5, 4, jnp.nan])
使用标准的
argmax()
可能会导致意外结果>>> jnp.argmax(x) Array(4, dtype=int32)
使用
nanargmax
返回最大非 NaN 值的索引。>>> jnp.nanargmax(x) Array(2, dtype=int32)
>>> x = jnp.array([[1, 3, jnp.nan], ... [5, 4, jnp.nan]]) >>> jnp.nanargmax(x, axis=1) Array([1, 0], dtype=int32)
>>> jnp.nanargmax(x, axis=1, keepdims=True) Array([[1], [0]], dtype=int32)