jax.tree_util.tree_leaves_with_path
- jax.tree_util.tree_leaves_with_path(tree, is_leaf=None)
Gets the leaves of a pytree, as tree_leaves does, and returns each leaf together with its key path.
- Parameters:
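tree (Any) – a pytree. If it contains a custom type, that type must be registered with register_pytree_with_keys.
is_leaf (Callable[[Any], bool] | None) – an optionally specified function that is called at each flattening step. It should return a boolean: True stops the traversal and treats the whole subtree as a leaf, and False means flattening continues into the current object.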
- Return type:
list[tuple[KeyPath, Any]]
- Returns:
A list of (key path, leaf) pairs, one per leaf, in the same order as the leaves returned by tree_leaves.
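As a minimal sketch of the behavior described above, assuming a plain nested dict and list as input (the tree contents are illustrative only), each returned key path can be rendered as a readable string with jax.tree_util.keystr:

```python
from jax.tree_util import keystr, tree_leaves, tree_leaves_with_path

# Illustrative nested pytree: a dict holding a dict and a list.
tree = {"params": {"w": 1.0, "b": 2.0}, "layers": [3, 4]}

pairs = tree_leaves_with_path(tree)
for path, leaf in pairs:
    # `path` is a tuple of key entries (DictKey, SequenceKey, ...);
    # keystr joins them into a single string.
    print(keystr(path), leaf)
# ['layers'][0] 3
# ['layers'][1] 4
# ['params']['b'] 2.0
# ['params']['w'] 1.0

# The leaves come out in the same order as tree_leaves(tree).
assert [leaf for _, leaf in pairs] == tree_leaves(tree)
```

Dictionary keys are flattened in sorted order, which is why the 'layers' entries appear before the 'params' entries.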
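The requirement on custom types can be illustrated with a hypothetical Point class (not part of JAX). This is a minimal sketch, assuming the class stores its children in the attributes x and y; it registers the class with register_pytree_with_keys and labels each child with a GetAttrKey so the resulting key paths read like attribute access:

```python
from jax.tree_util import GetAttrKey, keystr, register_pytree_with_keys, tree_leaves_with_path

# Hypothetical custom container, used only for illustration.
class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

# Hypothetical helper: return (key, child) pairs plus auxiliary data (none here).
def point_flatten_with_keys(p):
    children = ((GetAttrKey("x"), p.x), (GetAttrKey("y"), p.y))
    return children, None

# Hypothetical helper: rebuild a Point from its children, in flattening order.
def point_unflatten(aux_data, children):
    return Point(*children)

register_pytree_with_keys(Point, point_flatten_with_keys, point_unflatten)

pairs = tree_leaves_with_path([Point(1, 2), Point(3, 4)])
print([(keystr(path), leaf) for path, leaf in pairs])
# [('[0].x', 1), ('[0].y', 2), ('[1].x', 3), ('[1].y', 4)]
```

The keys returned by the flatten function become the last entries of each leaf's key path; the list index contributed by the outer list comes first.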
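A minimal sketch of is_leaf, assuming you want tuples kept whole instead of being flattened into their elements (the tree below is illustrative only). When the predicate returns True for a node, that node is returned as a single leaf and its key path stops there:

```python
from jax.tree_util import keystr, tree_leaves_with_path

tree = {"a": (1, 2), "b": [3, 4]}

# Treat every tuple as a leaf; other containers (dict, list) are still traversed.
pairs = tree_leaves_with_path(tree, is_leaf=lambda node: isinstance(node, tuple))
for path, leaf in pairs:
    print(keystr(path), leaf)
# ['a'] (1, 2)
# ['b'][0] 3
# ['b'][1] 4
```

Without is_leaf, the tuple under "a" would instead produce two leaves with the key paths ['a'][0] and ['a'][1].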