jax.random.bernoulli#
- jax.random.bernoulli(key, p=0.5, shape=None, mode='low')[源代码]#
使用给定的形状和均值采样伯努利随机值。
这些值根据概率质量函数分布
\[f(k; p) = p^k(1 - p)^{1 - k}\]其中 \(k \in \{0, 1\}\) 且 \(0 \le p \le 1\)。
- 参数:
key (ArrayLike) – 用作随机密钥的 PRNG 密钥。
p (RealArray) – 可选,用于随机变量均值的浮点数或浮点数数组。必须与
shape兼容。默认为 0.5。shape (Shape | None) – 可选,表示结果形状的非负整数元组。必须与
p.shape兼容。默认值 (None) 产生与p.shape相等的结果形状。mode (str) – 可选,“high” 或 “low”,表示采样时使用的位数。默认为 'low'。当 p 值很小时,请设置为“high”以获得正确的采样。在 float32 中采样时,mode='low' 的 bernoulli 采样对于 p < ~1E-7 会产生不正确的结果。mode="high" 会将采样成本大约加倍。
- 返回:
一个具有布尔 dtype 的随机数组,其形状由
shape(如果shape不是 None)或p.shape(否则)给出。- 返回类型: