jax.lax.approx_max_k#
- jax.lax.approx_max_k(operand, k, reduction_dimension=-1, recall_target=0.95, reduction_input_size_override=-1, aggregate_to_topk=True)[source]#
以近似方式返回
operand
的最大k
值及其索引。有关算法详细信息,请参阅 https://arxiv.org/abs/2206.14286。
- 参数:
operand (Array) – 用于搜索 max-k 的数组。 必须为浮点数类型。
k (int) – 指定 max-k 的数量。
reduction_dimension (int) – 搜索的整数维度。 默认值:-1。
recall_target (float) – 近似的召回目标。
reduction_input_size_override (int) – 当设置为正值时,它会覆盖由
operand[reduction_dim]
确定的大小,以评估召回率。 当给定的operand
只是 SPMD 或分布式管道中整体计算的子集时,此选项非常有用,其中真实输入大小不能通过操作数形状来推迟。aggregate_to_topk (bool) – 如果为 true,则将近似结果聚合到排序后的 top-k 中。 如果为 false,则返回未排序的近似结果。 在这种情况下,近似结果的数量由实现定义,并且大于或等于指定的
k
。
- 返回:
两个数组的元组。 这些数组是输入
operand
沿reduction_dimension
的最大k
值和相应的索引。 这些数组的维度与输入operand
相同,但reduction_dimension
除外:当aggregate_to_topk
为 true 时,归约维度为k
;否则,它大于等于k
,其中大小由实现定义。- 返回类型:
我们鼓励用户使用 jit 包装
approx_max_k
。 有关最大内积搜索 (MIPS) 的示例,请参见以下内容>>> import functools >>> import jax >>> import numpy as np >>> @functools.partial(jax.jit, static_argnames=["k", "recall_target"]) ... def mips(qy, db, k=10, recall_target=0.95): ... dists = jax.lax.dot(qy, db.transpose()) ... # returns (f32[qy_size, k], i32[qy_size, k]) ... return jax.lax.approx_max_k(dists, k=k, recall_target=recall_target) >>> >>> qy = jax.numpy.array(np.random.rand(50, 64)) >>> db = jax.numpy.array(np.random.rand(1024, 64)) >>> dot_products, neighbors = mips(qy, db, k=10)