jax.random.wrap_key_data#
- jax.random.wrap_key_data(key_bits_array, *, impl=None)[source]#
将密钥数据位数组包装成一个PRNG密钥数组。
- 参数:
key_bits_array (Array) — 一个
uint32
数组,其尾部形状对应于由impl
指定的 PRNG 实现的密钥形状。impl (PRNGSpecDesc | None) — 可选,指定一个 PRNG 实现,如同
random.key
中所示。
- 返回:
- 一个 PRNG 密钥数组,其 dtype 是
jax.dtypes.prng_key
的子类型 对应于
impl
,且其形状等于key_bits_array.shape
的前导形状,直到密钥位维度。
- 一个 PRNG 密钥数组,其 dtype 是