jax.tree.broadcast#
- jax.tree.broadcast(prefix_tree, full_tree, is_leaf=None)[源代码]#
将树前缀广播到给定树的完整结构中。
- 参数:
prefix_tree (Any) – 一个 pytree,它是 full_tree 的树前缀。
full_tree (Any) – 一个 pytree,具有将前缀叶子广播到的结构。
is_leaf (Callable[[Any], bool] | None) – 一个可选的指定函数,将在每个扁平化步骤调用。它应该返回一个布尔值,true 停止遍历并将整个子树视为叶子,false 表示扁平化应该遍历当前对象。
- 返回:
一个匹配 full_tree 结构的 pytree,其中 prefix_tree 的叶子已被广播到每个对应的子树的叶子中。
- 返回类型:
任意类型
示例
>>> import jax >>> prefix = (1, 2, 3) >>> full = (0, {'a': 0, 'b': 0}, (0, 0)) >>> jax.tree.broadcast(prefix, full) (1, {'a': 2, 'b': 2}, (3, 3))