jax.numpy.unique_inverse#
- jax.numpy.unique_inverse(x, /, *, size=None, fill_value=None)[源]#
从 x 返回唯一值,以及索引、逆索引和计数。
JAX 实现
numpy.unique_inverse();这等同于调用jax.numpy.unique()并将 return_inverse 和 equal_nan 设置为 True。由于
unique_inverse的输出大小取决于数据,因此该函数通常不兼容jit()和其他 JAX 变换。JAX 版本添加了一个可选的size参数,必须将其静态指定才能在此类上下文中jnp.unique。- 参数:
x (ArrayLike) – 将从中提取唯一值的 N 维数组。
size (int | None) – 如果指定,则仅返回前
size个排序后的唯一元素。如果唯一元素的数量少于size指定的数量,则返回值将用fill_value填充。fill_value (ArrayLike | None) – 当指定
size且元素数量少于指定数量时,用fill_value填充剩余的条目。默认为最小的唯一值。
- 返回:
values:形状为
(n_unique,)的数组,包含来自x的唯一值。
inverse_indices:形状为
x.shape的数组。包含x中每个值在values中的索引。对于 1D 输入,values[inverse_indices]等同于x。
- 返回类型:
一个元组
(values, indices, inverse_indices, counts),具有以下属性
另请参阅
jax.numpy.unique(): 计算唯一值的通用函数。jax.numpy.unique_values():仅计算values。jax.numpy.unique_counts(): 仅计算values和counts。jax.numpy.unique_all():计算values、indices、inverse_indices和counts。
示例
这里我们计算一维数组中的唯一值
>>> x = jnp.array([3, 4, 1, 3, 1]) >>> result = jnp.unique_inverse(x)
结果是一个
NamedTuple,带有两个命名属性。values属性包含数组中的唯一值>>> result.values Array([1, 3, 4], dtype=int32)
indices属性包含输入数组中唯一values的索引inverse_indices属性包含values中输入的索引>>> result.inverse_indices Array([1, 2, 0, 1, 0], dtype=int32) >>> jnp.all(x == result.values[result.inverse_indices]) Array(True, dtype=bool)
有关
size和fill_value参数的示例,请参阅jax.numpy.unique()。