jax.random.multinomial#
- jax.random.multinomial(key, n, p, *, shape=None, dtype=<class 'float'>, unroll=1)[source]#
从多项分布中采样。
概率质量函数为
\[f(x;n,p) = \frac{n!}{x_1! \ldots x_k!} p_1^{x_1} \ldots p_k^{x_k}\]- 参数:
key (Array) – PRNG 密钥。
n (RealArray) – 试验次数。应具有可广播到
p.shape[:-1]
的形状。p (RealArray) – 每个结果的概率,结果沿最后一个轴排列。
shape (Shape | None | None) – 可选,一个非负整数元组,指定结果批次形状,即结果形状的前缀,不包括最后一个轴。必须与
p.shape[:-1]
广播兼容。默认值 (None) 生成的结果形状等于p.shape
。dtype (DTypeLikeFloat) – 可选,返回值的浮点数据类型(如果 jax_enable_x64 为 true,则默认为 float64,否则为 float32)。
unroll (int | bool) – 可选,传递给
jax.lax.scan()
的展开参数,位于此函数的实现内部。
- 返回:
- 一个数组,其中包含每个结果的计数,具有指定的 dtype 和形状
如果 shape 为 None,则为
p.shape
,否则为shape + (p.shape[-1],)
。