jax.numpy.argmax#

jax.numpy.argmax(a, axis=None, out=None, keepdims=None)[源代码]#

返回数组最大值的索引。

JAX 实现 numpy.argmax()

参数:
  • a (ArrayLike) – 输入数组

  • axis (int | None) – 可选的整数,指定查找最大值所在的轴。如果未指定 axis,则 a 将被展平。

  • out (None) – JAX 未使用

  • keepdims (bool | None) – 如果为 True,则返回一个与 a 具有相同维度的数组。

返回:

一个包含指定轴上最大值索引的数组。

返回类型:

Array

另请参阅

注意

当最大值在特定轴上出现多次时,返回最小的索引。

示例

>>> x = jnp.array([1, 3, 5, 4, 2])
>>> jnp.argmax(x)
Array(2, dtype=int32)
>>> x = jnp.array([[1, 3, 2],
...                [5, 4, 1]])
>>> jnp.argmax(x, axis=1)
Array([1, 0], dtype=int32)
>>> jnp.argmax(x, axis=1, keepdims=True)
Array([[1],
       [0]], dtype=int32)