jax.random.categorical#

jax.random.categorical(key, logits, axis=-1, shape=None, replace=True, mode=None)[源代码]#

从分类分布中抽取随机值。

有放回抽样使用 Gumbel max 技巧。无放回抽样使用 Gumbel top-k 技巧。参见 [1]。

参数:
  • key (ArrayLike) – 用作随机密钥的 PRNG 密钥。

  • logits (RealArray) – 分类分布(或多个分布)的未归一化对数概率,以便 softmax(logits, axis) 给出相应的概率。

  • axis (int) – logits 属于同一分类分布的轴。

  • shape (Shape | None) – 可选,一个非负整数元组,表示结果的形状。必须与 np.delete(logits.shape, axis) 广播兼容。默认值 (None) 生成一个结果形状,等于 np.delete(logits.shape, axis)

  • replace (bool) – 如果为 True(默认),则进行有放回抽样。如果为 False,则进行无放回抽样。

  • mode (str | None) – 可选,值为“high”或“low”,用于指定 Gumbel 采样器使用的位数。默认值由 use_high_dynamic_range_gumbel 配置决定,该配置默认为“low”。当 mode="low" 时,在 float32 采样中,概率小于约 1E-7 的事件将有偏差;当 mode="high" 时,此限制推至约 1E-14。mode="high" 会使采样成本大约翻倍。

返回:

一个具有 int 数据类型和由 shape(如果 shape 不为 None)或 np.delete(logits.shape, axis)(否则)指定的形状的随机数组。

返回类型:

Array

参考文献