jax.random.categorical#
- jax.random.categorical(key, logits, axis=-1, shape=None, replace=True, mode=None)[源代码]#
从分类分布中抽取随机值。
有放回抽样使用 Gumbel max 技巧。无放回抽样使用 Gumbel top-k 技巧。参见 [1]。
- 参数:
key (ArrayLike) – 用作随机密钥的 PRNG 密钥。
logits (RealArray) – 分类分布(或多个分布)的未归一化对数概率,以便 softmax(logits, axis) 给出相应的概率。
axis (int) – logits 属于同一分类分布的轴。
shape (Shape | None) – 可选,一个非负整数元组,表示结果的形状。必须与
np.delete(logits.shape, axis)广播兼容。默认值 (None) 生成一个结果形状,等于np.delete(logits.shape, axis)。replace (bool) – 如果为 True(默认),则进行有放回抽样。如果为 False,则进行无放回抽样。
mode (str | None) – 可选,值为“high”或“low”,用于指定 Gumbel 采样器使用的位数。默认值由
use_high_dynamic_range_gumbel配置决定,该配置默认为“low”。当 mode="low" 时,在 float32 采样中,概率小于约 1E-7 的事件将有偏差;当 mode="high" 时,此限制推至约 1E-14。mode="high" 会使采样成本大约翻倍。
- 返回:
一个具有 int 数据类型和由
shape(如果shape不为 None)或np.delete(logits.shape, axis)(否则)指定的形状的随机数组。- 返回类型:
参考文献