jax.numpy.nanargmax#

jax.numpy.nanargmax(a, axis=None, out=None, keepdims=None)[源]#

返回数组中最大值的索引,忽略非数字 (NaN) 值。

NumPy 中 numpy.nanargmax() 的 JAX 实现。

参数:
  • a (ArrayLike) – 输入数组

  • axis (int | None) – 可选整数,指定查找最大值的轴。如果未指定 axis,则 a 将被展平。

  • out (None) – JAX 未使用

  • keepdims (bool | None) – 如果为 True,则返回一个与 a 维度数量相同的数组。

返回:

一个数组,包含指定轴上最大值的索引。

返回类型:

Array

注意

如果一个轴上的值全部为 NaN,则返回的索引将是 -1。这与 numpy.nanargmax() 的行为不同,后者会引发错误。

另请参阅

示例

>>> x = jnp.array([1, 3, 5, 4, jnp.nan])

使用标准的 argmax() 可能会导致意外结果

>>> jnp.argmax(x)
Array(4, dtype=int32)

使用 nanargmax 返回最大非 NaN 值的索引。

>>> jnp.nanargmax(x)
Array(2, dtype=int32)
>>> x = jnp.array([[1, 3, jnp.nan],
...                [5, 4, jnp.nan]])
>>> jnp.nanargmax(x, axis=1)
Array([1, 0], dtype=int32)
>>> jnp.nanargmax(x, axis=1, keepdims=True)
Array([[1],
       [0]], dtype=int32)