jax.numpy.argwhere#
- jax.numpy.argwhere(a, *, size=None, fill_value=None)[源代码]#
查找非零数组元素的索引
JAX 对
numpy.argwhere()的实现。jnp.argwhere(x)基本等同于jnp.column_stack(jnp.nonzero(x)),并对零维(即标量)输入进行特殊处理。由于
argwhere的输出大小取决于数据,因此该函数通常不兼容 JIT。JAX 版本增加了可选的size参数,该参数指定输出的第一个维度的大小 - 必须静态指定才能用非静态操作数编译jnp.argwhere。有关size及其语义的完整讨论,请参阅jax.numpy.nonzero()。- 参数:
a (ArrayLike) – 要查找非零元素的数组
size (int | None) – 可选整数,用于静态指定预期的非零元素数量。为了在
jax.jit()等 JAX 变换中使用argwhere,必须指定此参数。有关更多信息,请参阅jax.numpy.nonzero()。fill_value (ArrayLike | None) – 当指定
size时,用于指定填充值的可选数组。有关更多信息,请参阅jax.numpy.nonzero()。
- 返回:
形状为
[size, x.ndim]的二维数组。如果size未作为参数指定,则它等于x中的非零元素数量。- 返回类型:
示例
二维数组
>>> x = jnp.array([[1, 0, 2], ... [0, 3, 0]]) >>> jnp.argwhere(x) Array([[0, 0], [0, 2], [1, 1]], dtype=int32)
使用
jax.numpy.column_stack()和jax.numpy.nonzero()实现的等效计算>>> jnp.column_stack(jnp.nonzero(x)) Array([[0, 0], [0, 2], [1, 1]], dtype=int32)
零维(即标量)输入的特殊情况
>>> jnp.argwhere(1) Array([], shape=(1, 0), dtype=int32) >>> jnp.argwhere(0) Array([], shape=(0, 0), dtype=int32)