jax.tree_util.keystr#
- jax.tree_util.keystr(keys, *, simple=False, separator='')[源代码]#
用于美观打印键元组的辅助函数。
- 参数:
- 返回:
连接所有键的字符串表示形式的字符串。
- 返回类型:
示例
>>> import jax >>> params = {'foo': {'bar': {'baz': 1, 'bat': [2, 3]}}} >>> for path, _ in jax.tree_util.tree_leaves_with_path(params): ... print(jax.tree_util.keystr(path)) ['foo']['bar']['bat'][0] ['foo']['bar']['bat'][1] ['foo']['bar']['baz'] >>> for path, _ in jax.tree_util.tree_leaves_with_path(params): ... print(jax.tree_util.keystr(path, simple=True, separator='/')) foo/bar/bat/0 foo/bar/bat/1 foo/bar/baz