jax.numpy.isin#

jax.numpy.isin(element, test_elements, assume_unique=False, invert=False, *, method='auto')[源代码]#

确定 element 中的元素是否出现在 test_elements 中。

JAX 实现的 numpy.isin()

参数:
  • element (ArrayLike) – 输入数组,用于检查成员资格的元素。

  • test_elements (ArrayLike) – N 维数组,用于检查每个元素是否存在。

  • invert (bool) – 如果为 True,则返回 ~isin(element, test_elements)。默认为 False。

  • assume_unique (bool) – 如果为 true,则假定输入数组是唯一的,这可以提高计算效率。如果输入数组不是唯一的,并且 assume_unique 设置为 True,则结果是未定义的。

  • method – 字符串,指定用于计算结果的方法。支持的选项有 ‘compare_all’、‘binary_search’、‘sort’ 和 ‘auto’(默认)。

返回:

一个形状为 element.shape 的布尔数组,用于指定每个元素是否出现在 test_elements 中。

返回类型:

数组

示例

>>> elements = jnp.array([1, 2, 3, 4])
>>> test_elements = jnp.array([[1, 5, 6, 3, 7, 1]])
>>> jnp.isin(elements, test_elements)
Array([ True, False,  True, False], dtype=bool)