jax.random.categorical#

jax.random.categorical(key, logits, axis=-1, shape=None, replace=True, mode=None)[source]#

从分类分布中采样随机值。

有放回的采样使用 Gumbel max 技巧。 无放回的采样使用 Gumbel top-k 技巧。 有关参考,请参见 [1]。

参数:
  • key (ArrayLike) – 用作随机密钥的 PRNG 密钥。

  • logits (RealArray) – 要从中采样的分类分布的未标准化对数概率,因此 softmax(logits, axis) 给出相应的概率。

  • axis (int) – logits 属于同一分类分布的轴。

  • shape (Shape | None) – 可选,一个表示结果形状的非负整数元组。 必须与 np.delete(logits.shape, axis) 广播兼容。 默认值 (None) 生成一个等于 np.delete(logits.shape, axis) 的结果形状。

  • replace (bool) – 如果为 True(默认值),则执行有放回的采样。 如果为 False,则执行无放回的采样。

  • mode (str | None) – 可选,“high” 或 “low” 表示要在 gumbel 采样器中使用多少位。 默认值由 use_high_dynamic_range_gumbel 配置确定,该配置默认为 “low”。 使用 mode="low",在 float32 中,对于概率小于约 1E-7 的事件,采样会有偏差;使用 mode="high",此限制会降至约 1E-14。 mode="high" 大约会使采样成本增加一倍。

返回:

一个具有 int dtype 的随机数组,其形状由 shape 给定(如果 shape 不为 None),否则由 np.delete(logits.shape, axis) 给定。

返回类型:

Array

参考文献