jax.tree_util.keystr#

jax.tree_util.keystr(keys, *, simple=False, separator='')[源代码]#

用于美观地打印键元组的辅助函数。

参数:
  • keys (KeyPath) – 一个由 KeyEntry 或任何可转换为字符串的类组成的元组。

  • simple (bool) – 如果为 True,则对键使用简化的字符串表示形式。简化的键表示形式将比默认形式更紧凑,但在某些情况下可能会产生歧义(例如,“0”可能指列表中的第一个项,或整数 0 或字符串“0”的字典键)。

  • separator (str) – 用于连接键的字符串表示形式的分隔符。

返回:

一个将所有键的字符串表示形式连接起来的字符串。

返回类型:

str

示例

>>> import jax
>>> params = {'foo': {'bar': {'baz': 1, 'bat': [2, 3]}}}
>>> for path, _ in jax.tree_util.tree_leaves_with_path(params):
...   print(jax.tree_util.keystr(path))
['foo']['bar']['bat'][0]
['foo']['bar']['bat'][1]
['foo']['bar']['baz']
>>> for path, _ in jax.tree_util.tree_leaves_with_path(params):
...   print(jax.tree_util.keystr(path, simple=True, separator='/'))
foo/bar/bat/0
foo/bar/bat/1
foo/bar/baz