jax.numpy.isin#
- jax.numpy.isin(element, test_elements, assume_unique=False, invert=False, *, method='auto')[源代码]#
确定
element
中的元素是否出现在test_elements
中。NumPy
numpy.isin()
的 JAX 实现。- 参数:
element (ArrayLike) – 用于检查成员资格的输入元素数组。
test_elements (ArrayLike) – N 维测试值数组,用于检查每个元素是否存在。
invert (布尔值) – 如果为 True,则返回
~isin(element, test_elements)
。默认为 False。assume_unique (布尔值) – 如果为 True,则假定输入数组是唯一的,这可以带来更高效的计算。如果输入数组不唯一且 assume_unique 设置为 True,则结果将是未定义的。
method – 字符串,指定用于计算结果的方法。支持的选项有 'compare_all'、'binary_search'、'sort' 和 'auto'(默认)。
- 返回:
一个形状为
element.shape
的布尔数组,指定每个元素是否出现在test_elements
中。- 返回类型:
示例
>>> elements = jnp.array([1, 2, 3, 4]) >>> test_elements = jnp.array([[1, 5, 6, 3, 7, 1]]) >>> jnp.isin(elements, test_elements) Array([ True, False, True, False], dtype=bool)