jax.random.categorical#

jax.random.categorical(key, logits, axis=-1, shape=None, replace=True)[source]#

从分类分布中抽取随机值。

有放回采样使用 Gumbel max 技巧。无放回采样使用 Gumbel top-k 技巧。参见 [1] 作为参考。

参数:
  • key (ArrayLike) – 用作随机键的 PRNG 键。

  • logits (RealArray) – 要从中采样的分类分布的未归一化对数概率,因此 softmax(logits, axis) 给出相应的概率。

  • axis (int) – logits 属于同一分类分布的轴。

  • shape (Shape | None | None) – 可选,表示结果形状的非负整数元组。必须与 np.delete(logits.shape, axis) 广播兼容。默认值 (None) 生成的结果形状等于 np.delete(logits.shape, axis)

  • replace (bool) – 如果为 True(默认),则执行有放回采样。如果为 False,则执行无放回采样。

返回:

如果 shape 不是 None,则返回具有 int dtype 和 shape 给定形状的随机数组,否则返回 np.delete(logits.shape, axis)

返回类型:

Array

参考文献