jax.random.wrap_key_data#
- jax.random.wrap_key_data(key_bits_array, *, impl=None)[源码]#
将密钥数据位的数组封装为 PRNG 密钥数组。
- 参数:
key_bits_array (Array) – 一个
uint32数组,其最后一个维度对应于impl指定的 PRNG 实现的密钥形状。impl (PRNGSpecDesc | None) – 可选参数,指定一个 PRNG 实现,如
random.key中所述。
- 返回:
- 一个 PRNG 密钥数组,其 dtype 是
jax.dtypes.prng_key的子类型 对应于
impl,并且其形状等于key_bits_array.shape的前导形状(直到密钥位数维度)。
- 一个 PRNG 密钥数组,其 dtype 是