jax.numpy.intersect1d#
- jax.numpy.intersect1d(ar1, ar2, assume_unique=False, return_indices=False, *, size=None, fill_value=None)[源代码]#
计算两个一维数组的集合交集。
JAX 对
numpy.intersect1d()的实现。由于
intersect1d的输出大小依赖于数据,因此该函数通常不兼容jit()以及其他 JAX 变换。JAX 版本增加了可选的size参数,该参数必须静态指定,才能在这些上下文中使用jnp.intersect1d。- 参数:
ar1 (ArrayLike) – 要计算交集的第一个数组。
ar2 (ArrayLike) – 要计算交集的第二个数组。
assume_unique (bool) – 如果为 True,则假定输入数组包含唯一值。这允许更高效的实现,但如果
assume_unique为 True 且输入数组包含重复值,则行为未定义。默认值:False。return_indices (bool) – 如果为 True,则返回索引数组,这些索引指定了相交值在输入数组中首次出现的位置。
size (int | None) – 如果指定,则只返回前
size个排序后的元素。如果元素少于size指定的数量,则返回值将用fill_value填充,返回的索引将用越界索引填充。fill_value (ArrayLike | None) – 当指定
size且元素数量少于指定数量时,用fill_value填充剩余的条目。默认为交集中最小的值。
- 返回:
一个数组
intersection,或者如果return_indices=True,则是一个元组数组(intersection, ar1_indices, ar2_indices)。返回的值是intersection: 一个一维数组,包含出现在ar1和ar2中的每个值。ar1_indices: (如果 return_indices=True 则返回) 一个形状为intersection.shape的数组,包含intersection中值的展平ar1索引。对于一维输入,intersection等同于ar1[ar1_indices]。ar2_indices: (如果 return_indices=True 则返回) 一个形状为intersection.shape的数组,包含intersection中值的展平ar2索引。对于一维输入,intersection等同于ar2[ar2_indices]。
- 返回类型:
另请参阅
jax.numpy.union1d():两个一维数组的集合并集。jax.numpy.setxor1d():两个一维数组的集合异或。jax.numpy.setdiff1d(): 两个一维数组的集合差集。
示例
>>> ar1 = jnp.array([1, 2, 3, 4]) >>> ar2 = jnp.array([3, 4, 5, 6]) >>> jnp.intersect1d(ar1, ar2) Array([3, 4], dtype=int32)
计算带索引的交集
>>> intersection, ar1_indices, ar2_indices = jnp.intersect1d(ar1, ar2, return_indices=True) >>> intersection Array([3, 4], dtype=int32)
ar1_indices提供交集值在ar1中的索引>>> ar1_indices Array([2, 3], dtype=int32) >>> jnp.all(intersection == ar1[ar1_indices]) Array(True, dtype=bool)
ar2_indices提供交集值在ar2中的索引>>> ar2_indices Array([0, 1], dtype=int32) >>> jnp.all(intersection == ar2[ar2_indices]) Array(True, dtype=bool)