jax.random.multinomial#

jax.random.multinomial(key, n, p, *, shape=None, dtype=<class 'float'>, unroll=1)[source]#

从多项分布中采样。

概率质量函数为

\[f(x;n,p) = \frac{n!}{x_1! \ldots x_k!} p_1^{x_1} \ldots p_k^{x_k}\]
参数:
  • key (Array) – PRNG 密钥。

  • n (RealArray) – 试验次数。应具有可广播到 p.shape[:-1] 的形状。

  • p (RealArray) – 每个结果的概率,结果沿最后一个轴排列。

  • shape (Shape | None | None) – 可选,一个非负整数元组,指定结果批次形状,即结果形状的前缀,不包括最后一个轴。必须与 p.shape[:-1] 广播兼容。默认值 (None) 生成的结果形状等于 p.shape

  • dtype (DTypeLikeFloat) – 可选,返回值的浮点数据类型(如果 jax_enable_x64 为 true,则默认为 float64,否则为 float32)。

  • unroll (int | bool) – 可选,传递给 jax.lax.scan() 的展开参数,位于此函数的实现内部。

返回:

一个数组,其中包含每个结果的计数,具有指定的 dtype 和形状

如果 shape 为 None,则为 p.shape,否则为 shape + (p.shape[-1],)