jax.tree.broadcast#

jax.tree.broadcast(prefix_tree, full_tree, is_leaf=None)[源代码]#

将树前缀广播到给定树的完整结构中。

参数:
  • prefix_tree (Any) – 一个 pytree,它是 full_tree 的树前缀。

  • full_tree (Any) – 一个具有要将前缀叶子广播到的结构的 pytree。

  • is_leaf (Callable[[Any], bool] | None) – 一个可选的指定函数,将在每个展平步骤中调用。它应返回一个布尔值,如果为 True,则停止遍历,并将整个子树视为叶子;如果为 False,则表示展平应继续遍历当前对象。

返回:

一个与 full_tree 结构匹配的 pytree,其中 prefix_tree 的叶子已广播到每个相应子树的叶子中。

返回类型:

任意类型

示例

>>> import jax
>>> prefix = (1, 2, 3)
>>> full = (0, {'a': 0, 'b': 0}, (0, 0))
>>> jax.tree.broadcast(prefix, full)
(1, {'a': 2, 'b': 2}, (3, 3))