jax.random.categorical#
- jax.random.categorical(key, logits, axis=-1, shape=None, replace=True)[source]#
从分类分布中抽取随机值。
有放回采样使用 Gumbel max 技巧。无放回采样使用 Gumbel top-k 技巧。参见 [1] 作为参考。
- 参数:
key (ArrayLike) – 用作随机键的 PRNG 键。
logits (RealArray) – 要从中采样的分类分布的未归一化对数概率,因此 softmax(logits, axis) 给出相应的概率。
axis (int) – logits 属于同一分类分布的轴。
shape (Shape | None | None) – 可选,表示结果形状的非负整数元组。必须与
np.delete(logits.shape, axis)
广播兼容。默认值 (None) 生成的结果形状等于np.delete(logits.shape, axis)
。replace (bool) – 如果为 True(默认),则执行有放回采样。如果为 False,则执行无放回采样。
- 返回:
如果
shape
不是 None,则返回具有 int dtype 和shape
给定形状的随机数组,否则返回np.delete(logits.shape, axis)
。- 返回类型:
参考文献