jax.random.choice#
- jax.random.choice(key, a, shape=(), replace=True, p=None, axis=0, mode=None)[source]#
从给定数组生成随机样本。
警告
如果
p的非零元素少于shape中指定的请求样本数,并且replace=False,则此函数的输出定义不明。请确保使用适当的输入。- 参数:
key (ArrayLike) – 用作随机密钥的 PRNG 密钥。
a (int | ArrayLike) – 数组或整数。如果为 ndarray,则从其元素生成随机样本。如果为整数,则随机样本的生成方式如同 a 为 arange(a)。
shape (Shape) – 整数元组,可选。输出形状。如果给定的形状是,例如,
(m, n),则抽取m * n个样本。默认值为 (),在这种情况下返回单个值。replace (bool) – 布尔值。样本是有放回还是无放回。默认值为 True。
p (RealArray | None) – 一维数组类,a 中每个条目的关联概率。如果未给定,则样本假定在 a 的所有条目上具有均匀分布。
axis (int) – 整数,可选。执行选择的轴。默认值 0 按行选择。
mode (str | None) – 可选,“high”或“low”,用于在 p 为 None 且 replace = False 时使用的 gumbel 采样器的位数。默认值由
use_high_dynamic_range_gumbel配置确定,该配置默认为“low”。使用 mode=”low”时,在 float32 中采样对于概率小于约 1E-7 的选择会有偏差;使用 mode=”high”时,此限制会降低到约 1E-14。mode=”high”大约使采样成本加倍。
- 返回:
一个形状为 shape 的数组,包含来自 a 的样本。
- 返回类型: