jax.tree.broadcast#
- jax.tree.broadcast(prefix_tree, full_tree, is_leaf=None)[源代码]#
将树前缀广播到给定树的完整结构中。
- 参数:
prefix_tree (Any) – 一个 pytree,它是 full_tree 的树前缀。
full_tree (Any) – 一个具有要将前缀叶子广播到的结构的 pytree。
is_leaf (Callable[[Any], bool] | None) – 一个可选的指定函数,将在每个展平步骤中调用。它应返回一个布尔值,如果为 True,则停止遍历,并将整个子树视为叶子;如果为 False,则表示展平应继续遍历当前对象。
- 返回:
一个与 full_tree 结构匹配的 pytree,其中 prefix_tree 的叶子已广播到每个相应子树的叶子中。
- 返回类型:
任意类型
示例
>>> import jax >>> prefix = (1, 2, 3) >>> full = (0, {'a': 0, 'b': 0}, (0, 0)) >>> jax.tree.broadcast(prefix, full) (1, {'a': 2, 'b': 2}, (3, 3))