jax.tree_util.keystr#

jax.tree_util.keystr(keys, *, simple=False, separator='')[源代码]#

用于美观打印键元组的辅助函数。

参数:
  • keys (KeyPath) – KeyEntry 元组或任何可以转换为字符串的类。

  • simple (bool) – 如果为 True,则使用简化的字符串表示形式表示键。键的简单表示形式将比默认形式更紧凑,但在某些情况下是模棱两可的(例如,“0”可能指列表中的第一个项目,也可能指整数 0 或字符串 “0” 的字典键)。

  • separator (str) – 用于连接键的字符串表示形式的分隔符。

返回:

连接所有键的字符串表示形式的字符串。

返回类型:

str

示例

>>> import jax
>>> params = {'foo': {'bar': {'baz': 1, 'bat': [2, 3]}}}
>>> for path, _ in jax.tree_util.tree_leaves_with_path(params):
...   print(jax.tree_util.keystr(path))
['foo']['bar']['bat'][0]
['foo']['bar']['bat'][1]
['foo']['bar']['baz']
>>> for path, _ in jax.tree_util.tree_leaves_with_path(params):
...   print(jax.tree_util.keystr(path, simple=True, separator='/'))
foo/bar/bat/0
foo/bar/bat/1
foo/bar/baz