jax.flatten_util.ravel_pytree

jax.flatten_util.ravel_pytree#

jax.flatten_util.ravel_pytree(pytree)[source]#

将 pytree 中的数组展平(扁平化)为一个一维数组。

参数:

pytree – 要展平的数组和标量的 pytree。

返回值:

一个包含两个元素的元组,第一个元素是一维数组,表示扁平化并连接的叶子值,数据类型由提升叶子值的数据类型确定;第二个元素是一个可调用对象,用于将相同长度的一维向量解展平回与输入 pytree 结构相同的 pytree。如果输入 pytree 为空(即没有叶子),则按照约定,输出的第一个组件中返回一个数据类型为 float32 的一维空数组。

有关数据类型提升的详细信息,请参阅 https://jax.net.cn/en/latest/type_promotion.html