jax.random.multinomial#
- jax.random.multinomial(key, n, p, *, shape=None, dtype=<class 'float'>, unroll=1)[源代码]#
从多项分布中采样。
概率质量函数为
\[f(x;n,p) = \frac{n!}{x_1! \ldots x_k!} p_1^{x_1} \ldots p_k^{x_k}\]- 参数:
key (Array) – PRNG 密钥。
n (RealArray) – 试验次数。 形状应可广播到
p.shape[:-1]
。p (RealArray) – 每个结果的概率,结果沿最后一个轴。
shape (Shape | None) – 可选,一个非负整数元组,指定结果批处理形状,即结果形状的前缀(不包括最后一个轴)。 必须与
p.shape[:-1]
广播兼容。 默认值 (None) 生成一个等于p.shape
的结果形状。dtype (DTypeLikeFloat) – 可选,返回值的浮点数 dtype(如果 jax_enable_x64 为 true,则默认为 float64,否则为 float32)。
unroll (int | bool) – 可选,传递给
jax.lax.scan()
的展开参数,该函数位于此函数的实现内部。
- 返回:
- 具有指定 dtype 的每个结果的计数数组,形状为
p.shape
,如果shape
为 None,否则为shape + (p.shape[-1],)
。