jax.tree.broadcast#

jax.tree.broadcast(prefix_tree, full_tree, is_leaf=None)[源代码]#

将树前缀广播到给定树的完整结构中。

参数:
  • prefix_tree (Any) – 一个 pytree,它是 full_tree 的树前缀。

  • full_tree (Any) – 一个 pytree,具有将前缀叶子广播到的结构。

  • is_leaf (Callable[[Any], bool] | None) – 一个可选的指定函数,将在每个扁平化步骤调用。它应该返回一个布尔值,true 停止遍历并将整个子树视为叶子,false 表示扁平化应该遍历当前对象。

返回:

一个匹配 full_tree 结构的 pytree,其中 prefix_tree 的叶子已被广播到每个对应的子树的叶子中。

返回类型:

任意类型

示例

>>> import jax
>>> prefix = (1, 2, 3)
>>> full = (0, {'a': 0, 'b': 0}, (0, 0))
>>> jax.tree.broadcast(prefix, full)
(1, {'a': 2, 'b': 2}, (3, 3))