jax.numpy.extract#
- jax.numpy.extract(condition, arr, *, size=None, fill_value=0)[源码]#
返回满足条件的数组元素。
JAX 对
numpy.extract()的实现。- 参数:
- 返回:
提取的条目的 1D 数组。如果指定了
size,则结果的形状为(size,),并用fill_value进行右填充。如果未指定size,则输出形状将取决于condition中 True 条目的数量。- 返回类型:
注意事项
此函数不需要
condition和arr之间严格的形状匹配。如果condition.size > arr.size,则condition将被截断;如果arr.size > condition.size,则arr将被截断。另请参阅
jax.numpy.compress():extract的多维版本。示例
从 1D 数组中提取值
>>> x = jnp.array([1, 2, 3, 4, 5, 6]) >>> mask = (x % 2 == 0) >>> jnp.extract(mask, x) Array([2, 4, 6], dtype=int32)
在最简单的情况下,这等同于布尔索引
>>> x[mask] Array([2, 4, 6], dtype=int32)
为了与 JAX 转换一起使用,您可以传递
size参数来为输出指定静态形状,以及一个可选的fill_value,其默认值为零>>> jnp.extract(mask, x, size=len(x), fill_value=0) Array([2, 4, 6, 0, 0, 0], dtype=int32)
请注意,与布尔索引不同,
extract不需要数组和条件的大小严格匹配,并且将有效地将两者截断为最小大小>>> short_mask = jnp.array([False, True]) >>> jnp.extract(short_mask, x) Array([2], dtype=int32) >>> long_mask = jnp.array([True, False, True, False, False, False, False, False]) >>> jnp.extract(long_mask, x) Array([1, 3], dtype=int32)