jax.random.choice#
- jax.random.choice(key, a, shape=(), replace=True, p=None, axis=0, mode=None)[source]#
从给定数组生成随机样本。
警告
如果
p
拥有的非零元素数量少于请求的样本数量(在shape
中指定),并且replace=False
,则此函数的输出是定义不明确的。请确保使用适当的输入。- 参数:
key (ArrayLike) – 用作随机密钥的 PRNG 密钥。
a (int | ArrayLike) – 数组或整数。如果是一个 ndarray,则从其元素生成随机样本。如果是一个整数,则生成随机样本,就像 a 是 arange(a) 一样。
shape (Shape) – 整数元组,可选。输出形状。如果给定的形状是,例如
(m, n)
,则抽取m * n
个样本。 默认值为 (),在这种情况下,返回单个值。replace (bool) – 布尔值。样本是否进行替换。默认值为 True。
p (RealArray | None) – 类似 1-D 数组,与 a 中每个条目关联的概率。如果未给出,则样本假定 a 中所有条目的均匀分布。
axis (int) – 整数,可选。执行选择的轴。默认值 0,按行选择。
mode (str | None) – 可选,“high” 或 “low”,用于在使用 Gumbel 采样器时使用多少位,p is None 并且 replace = False。默认值由
use_high_dynamic_range_gumbel
配置确定,默认为 “low”。使用 mode="low",在 float32 中,对于概率小于约 1E-7 的选择,采样会有偏差;使用 mode="high",此限制被推低至约 1E-14。mode="high" 大约使采样成本翻倍。
- 返回:
一个形状为 shape 的数组,包含来自 a 的样本。
- 返回类型: