jax.random.multivariate_normal#
- jax.random.multivariate_normal(key, mean, cov, shape=None, dtype=None, method='cholesky')[source]#
根据给定的均值和协方差,对多元正态随机值进行采样。
这些值根据概率密度函数返回
\[f(x;\mu, \Sigma) = (2\pi)^{-k/2} \det(\Sigma)^{-1}e^{-\frac{1}{2}(x - \mu)^T \Sigma^{-1} (x - \mu)}\]其中 \(k\) 是维度,\(\mu\) 是均值 (由
mean
给出),\(\Sigma\) 是协方差矩阵 (由cov
给出)。- 参数:
key (ArrayLike) – 用作随机密钥的 PRNG 密钥。
mean (RealArray) – 形状为
(..., n)
的均值向量。cov (RealArray) – 形状为
(..., n, n)
的正定协方差矩阵。 批处理形状...
必须与mean
的批处理形状广播兼容。shape (Shape | None) – 可选,一个非负整数的元组,用于指定结果的批处理形状;也就是结果形状的前缀,不包括最后一个轴。 必须与
mean.shape[:-1]
和cov.shape[:-2]
广播兼容。默认值 (None) 通过一起广播mean
和cov
的批处理形状来生成结果批处理形状。dtype (DTypeLikeFloat | None) – 可选,返回值的浮点 dtype (如果 jax_enable_x64 为 true,则默认为 float64,否则为 float32)。
method (str) – 可选,计算
cov
的因子方法。 必须是 ‘svd’、‘eigh’ 和 ‘cholesky’ 之一。 默认为 ‘cholesky’。 对于奇异协方差矩阵,请使用 ‘svd’ 或 ‘eigh’。
- 返回:
具有指定 dtype 的随机数组,其形状由
shape + mean.shape[-1:]
给出 (如果shape
不是 None),否则由broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:]
给出。- 返回类型: