jax.random.categorical#
- jax.random.categorical(key, logits, axis=-1, shape=None, replace=True, mode=None)[source]#
从分类分布中采样随机值。
有放回的采样使用 Gumbel max 技巧。 无放回的采样使用 Gumbel top-k 技巧。 有关参考,请参见 [1]。
- 参数:
key (ArrayLike) – 用作随机密钥的 PRNG 密钥。
logits (RealArray) – 要从中采样的分类分布的未标准化对数概率,因此 softmax(logits, axis) 给出相应的概率。
axis (int) – logits 属于同一分类分布的轴。
shape (Shape | None) – 可选,一个表示结果形状的非负整数元组。 必须与
np.delete(logits.shape, axis)
广播兼容。 默认值 (None) 生成一个等于np.delete(logits.shape, axis)
的结果形状。replace (bool) – 如果为 True(默认值),则执行有放回的采样。 如果为 False,则执行无放回的采样。
mode (str | None) – 可选,“high” 或 “low” 表示要在 gumbel 采样器中使用多少位。 默认值由
use_high_dynamic_range_gumbel
配置确定,该配置默认为 “low”。 使用 mode="low",在 float32 中,对于概率小于约 1E-7 的事件,采样会有偏差;使用 mode="high",此限制会降至约 1E-14。 mode="high" 大约会使采样成本增加一倍。
- 返回:
一个具有 int dtype 的随机数组,其形状由
shape
给定(如果shape
不为 None),否则由np.delete(logits.shape, axis)
给定。- 返回类型:
参考文献