jax.random.multinomial#

jax.random.multinomial(key, n, p, *, shape=None, dtype=None, unroll=1)[源代码]#

从多项分布中采样。

概率质量函数为

\[f(x;n,p) = \frac{n!}{x_1! \ldots x_k!} p_1^{x_1} \ldots p_k^{x_k}\]
参数:
  • key (Array) – PRNG 密钥。

  • n (RealArray) – 试验次数。其形状应可广播到 p.shape[:-1]

  • p (RealArray) – 每个结果的概率,结果沿最后一个轴排列。

  • shape (Shape | None) – 可选,一个非负整数元组,指定结果批次的形状,即结果形状中除最后一个轴之外的前缀。必须与 p.shape[:-1] 兼容。默认值 (None) 生成的结果形状等于 p.shape

  • dtype (DTypeLikeFloat | None) – 可选,返回值的浮点数 dtype(如果 jax_enable_x64 为 true,则默认为 float64,否则为 float32)。

  • unroll (int | bool) – 可选,传递给函数内部 jax.lax.scan() 的 unroll 参数。

返回:

一个包含每个结果计数的数组,具有指定的 dtype 和形状

p.shape (如果 shape 为 None),否则为 shape + (p.shape[-1],)