jax.numpy.argwhere#
-
jax.numpy.argwhere(a*size=Nonefill_value=None[source]#
查找非零数组元素的索引
JAX 对 numpy.argwhere()
的实现。
jnp.argwhere(x)
本质上等同于 jnp.column_stack(jnp.nonzero(x))
,并对零维度(即标量)输入进行特殊处理。
由于 argwhere
的输出大小依赖于数据,该函数通常与 JIT 不兼容。JAX 版本添加了可选的 size
参数,该参数指定了输出主维度的大小——为了使 jnp.argwhere
能够用非静态操作数进行编译,必须静态指定此参数。有关 size
及其语义的完整讨论,请参阅 jax.numpy.nonzero()
。
- 参数:
a (ArrayLike) – 要查找非零元素的数组
size (int | None) – 一个可选的整数,用于静态指定预期的非零元素数量。必须指定此参数,以便在诸如 jax.jit()
的 JAX 变换中使用 argwhere
。有关更多信息,请参阅 jax.numpy.nonzero()
。
fill_value (ArrayLike | None) – 一个可选数组,指定当 size
指定时的填充值。有关更多信息,请参阅 jax.numpy.nonzero()
。
- 返回:
一个形状为 [size, x.ndim]
的二维数组。如果未将 size
指定为参数,则其等于 x
中非零元素的数量。
- 返回类型:
-
示例
二维数组
>>> x = jnp.array([[1, 0, 2],
... [0, 3, 0]])
>>> jnp.argwhere(x)
Array([[0, 0],
[0, 2],
[1, 1]], dtype=int32)
使用 jax.numpy.column_stack()
和 jax.numpy.nonzero()
进行的等效计算
>>> jnp.column_stack(jnp.nonzero(x))
Array([[0, 0],
[0, 2],
[1, 1]], dtype=int32)
零维度(即标量)输入的特殊情况
>>> jnp.argwhere(1)
Array([], shape=(1, 0), dtype=int32)
>>> jnp.argwhere(0)
Array([], shape=(0, 0), dtype=int32)
查找非零数组元素的索引
JAX 对 numpy.argwhere()
的实现。
jnp.argwhere(x)
本质上等同于 jnp.column_stack(jnp.nonzero(x))
,并对零维度(即标量)输入进行特殊处理。
由于 argwhere
的输出大小依赖于数据,该函数通常与 JIT 不兼容。JAX 版本添加了可选的 size
参数,该参数指定了输出主维度的大小——为了使 jnp.argwhere
能够用非静态操作数进行编译,必须静态指定此参数。有关 size
及其语义的完整讨论,请参阅 jax.numpy.nonzero()
。
- 参数:
a (ArrayLike) – 要查找非零元素的数组
size (int | None) – 一个可选的整数,用于静态指定预期的非零元素数量。必须指定此参数,以便在诸如
jax.jit()
的 JAX 变换中使用argwhere
。有关更多信息,请参阅jax.numpy.nonzero()
。fill_value (ArrayLike | None) – 一个可选数组,指定当
size
指定时的填充值。有关更多信息,请参阅jax.numpy.nonzero()
。
- 返回:
一个形状为
[size, x.ndim]
的二维数组。如果未将size
指定为参数,则其等于x
中非零元素的数量。- 返回类型:
示例
二维数组
>>> x = jnp.array([[1, 0, 2],
... [0, 3, 0]])
>>> jnp.argwhere(x)
Array([[0, 0],
[0, 2],
[1, 1]], dtype=int32)
使用 jax.numpy.column_stack()
和 jax.numpy.nonzero()
进行的等效计算
>>> jnp.column_stack(jnp.nonzero(x))
Array([[0, 0],
[0, 2],
[1, 1]], dtype=int32)
零维度(即标量)输入的特殊情况
>>> jnp.argwhere(1)
Array([], shape=(1, 0), dtype=int32)
>>> jnp.argwhere(0)
Array([], shape=(0, 0), dtype=int32)