jax.numpy.argmax#
- jax.numpy.argmax(a, axis=None, out=None, keepdims=None)[源代码]#
返回数组最大值的索引。
JAX 实现
numpy.argmax()。- 参数:
- 返回:
一个包含指定轴上最大值索引的数组。
- 返回类型:
另请参阅
jax.numpy.argmin(): 返回最小值索引。jax.numpy.nanargmax(): 计算argmax,同时忽略 NaN 值。
注意
当最大值在特定轴上出现多次时,返回最小的索引。
示例
>>> x = jnp.array([1, 3, 5, 4, 2]) >>> jnp.argmax(x) Array(2, dtype=int32)
>>> x = jnp.array([[1, 3, 2], ... [5, 4, 1]]) >>> jnp.argmax(x, axis=1) Array([1, 0], dtype=int32)
>>> jnp.argmax(x, axis=1, keepdims=True) Array([[1], [0]], dtype=int32)