jax.numpy.setdiff1d#
- jax.numpy.setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None)[源码]#
计算两个一维数组的集合差。
JAX 对
numpy.setdiff1d()的实现。由于
setdiff1d的输出大小依赖于数据,因此该函数通常不兼容jit()和其他 JAX 变换。JAX 版本添加了可选的size参数,必须将其静态指定,才能在这些上下文中jnp.setdiff1d。 才能用于这些上下文。- 参数:
ar1 (ArrayLike) – 要进行差分的第一个元素数组。
ar2 (ArrayLike) – 要进行差分的第二个元素数组。
assume_unique (bool) – 如果为 True,则假定输入数组包含唯一值。这允许更高效的实现,但如果
assume_unique为 True 且输入数组包含重复值,则行为未定义。默认值:False。size (int | None) – 如果指定,则仅返回前
size个已排序的元素。如果元素少于size所指示的数量,则返回值将用fill_value进行填充。fill_value (ArrayLike | None) – 当指定
size且元素数量少于指定数量时,用fill_value填充剩余的条目。默认为最小值。
- 返回:
即
ar1中不包含在ar2中的元素。- 返回类型:
包含输入数组元素集合差的数组
另请参阅
jax.numpy.intersect1d(): 两个一维数组的集合交集。jax.numpy.setxor1d():两个一维数组的集合异或。jax.numpy.union1d():两个一维数组的集合并集。
示例
计算两个数组的集合差
>>> ar1 = jnp.array([1, 2, 3, 4]) >>> ar2 = jnp.array([3, 4, 5, 6]) >>> jnp.setdiff1d(ar1, ar2) Array([1, 2], dtype=int32)
由于输出形状是动态的,因此在
jit()和其他变换下会失败>>> jax.jit(jnp.setdiff1d)(ar1, ar2) Traceback (most recent call last): ... ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4]. The error occurred while tracing the function setdiff1d at /Users/vanderplas/github/jax-ml/jax/jax/_src/numpy/setops.py:64 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1.
为了确保静态已知的输出形状,您可以传递一个静态
size参数>>> jit_setdiff1d = jax.jit(jnp.setdiff1d, static_argnames=['size']) >>> jit_setdiff1d(ar1, ar2, size=2) Array([1, 2], dtype=int32)
如果
size太小,则差值将被截断>>> jit_setdiff1d(ar1, ar2, size=1) Array([1], dtype=int32)
如果
size太大,则输出将用fill_value填充>>> jit_setdiff1d(ar1, ar2, size=4, fill_value=0) Array([1, 2, 0, 0], dtype=int32)