jax.tree_util.keystr#

jax.tree_util.keystr(keys, *, simple=False, separator='')[源代码]#

用于美观打印键元组的辅助函数。

参数:
  • keys (KeyPath) – 一个由 KeyEntry 或任何可以转换为字符串的类组成的元组。

  • simple (bool) – 如果为 True,则使用键的简化字符串表示。键的简化表示比默认表示更紧凑,但在某些情况下可能不明确(例如,“0”可能指列表中的第一个项,或者字典键中的整数 0 或字符串“0”)。

  • separator (str) – 用于连接键的字符串表示形式的分隔符。

返回:

连接所有键的字符串表示形式的字符串。

返回类型:

str

示例

>>> import jax
>>> params = {'foo': {'bar': {'baz': 1, 'bat': [2, 3]}}}
>>> for path, _ in jax.tree_util.tree_leaves_with_path(params):
...   print(jax.tree_util.keystr(path))
['foo']['bar']['bat'][0]
['foo']['bar']['bat'][1]
['foo']['bar']['baz']
>>> for path, _ in jax.tree_util.tree_leaves_with_path(params):
...   print(jax.tree_util.keystr(path, simple=True, separator='/'))
foo/bar/bat/0
foo/bar/bat/1
foo/bar/baz