jax.flatten_util.ravel_pytree#
- jax.flatten_util.ravel_pytree(pytree)[源代码]#
将数组的 pytree 展平(ravel)到一个一维数组。
- 参数:
pytree – 要展平的数组和标量的 pytree。
- 返回:
返回一个元组,其中第一个元素是一个一维数组,表示展平并连接的叶子值,其 dtype 由叶子值的 dtype 提升确定。第二个元素是一个可调用对象,用于将相同长度的一维向量反展平回与输入
pytree具有相同结构的 pytree。如果输入 pytree 为空(即没有叶子),则约定返回一个 dtype 为 float32 的一维空数组作为输出的第一个分量。
有关 dtype 提升的详细信息,请参阅 https://jax.net.cn/en/latest/type_promotion.html。