jax.random.bernoulli#
- jax.random.bernoulli(key, p=np.float32(0.5), shape=None, mode='low')[source]#
采样具有给定形状和均值的伯努利随机值。
这些值根据概率质量函数分布
\[f(k; p) = p^k(1 - p)^{1 - k}\]其中 \(k \in \{0, 1\}\) 且 \(0 \le p \le 1\).
- 参数:
key (ArrayLike) – 用作随机密钥的 PRNG 密钥。
p (RealArray) – 可选,一个浮点数或浮点数数组,表示随机变量的均值。必须与
shape
广播兼容。默认值为 0.5。shape (Shape | None) – 可选,表示结果形状的非负整数元组。必须与
p.shape
广播兼容。默认值 (None) 生成一个等于p.shape
的结果形状。mode (str) – 可选,“high” 或 “low”,用于确定采样时使用的位数。default='low'。设置为 “high”,以便在较小的 p 值处进行正确的采样。当在 float32 中采样时,对于 p < ~1E-7,bernoulli 样本使用 mode='low' 会产生不正确的结果。mode="high" 大约使采样成本翻倍。
- 返回:
一个具有布尔 dtype 的随机数组,如果
shape
不是 None,则形状由shape
给定,否则由p.shape
给定。- 返回类型: