jax.numpy.argwhere#
jax.numpy.argwhere(a*size=Nonefill_value=None[source]#

查找非零数组元素的索引

JAX 对 numpy.argwhere() 的实现。

jnp.argwhere(x) 本质上等同于 jnp.column_stack(jnp.nonzero(x)),并对零维度(即标量)输入进行特殊处理。

由于 argwhere 的输出大小依赖于数据,该函数通常与 JIT 不兼容。JAX 版本添加了可选的 size 参数,该参数指定了输出主维度的大小——为了使 jnp.argwhere 能够用非静态操作数进行编译,必须静态指定此参数。有关 size 及其语义的完整讨论,请参阅 jax.numpy.nonzero()

参数:
返回:

一个形状为 [size, x.ndim] 的二维数组。如果未将 size 指定为参数,则其等于 x 中非零元素的数量。

返回类型:

数组

示例

二维数组

>>> x = jnp.array([[1, 0, 2],
...                [0, 3, 0]])
>>> jnp.argwhere(x)
Array([[0, 0],
       [0, 2],
       [1, 1]], dtype=int32)

使用 jax.numpy.column_stack()jax.numpy.nonzero() 进行的等效计算

>>> jnp.column_stack(jnp.nonzero(x))
Array([[0, 0],
       [0, 2],
       [1, 1]], dtype=int32)

零维度(即标量)输入的特殊情况

>>> jnp.argwhere(1)
Array([], shape=(1, 0), dtype=int32)
>>> jnp.argwhere(0)
Array([], shape=(0, 0), dtype=int32)