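jax.tree_util.tree_leaves_with_path

jax.tree_util.tree_leaves_with_path(tree, is_leaf=None)

Gets the leaves of a pytree, like tree_leaves, and also returns each leaf's key path.

Parameters:
tree (Any) – a pytree. If it contains a custom type, it must be registered with register_pytree_with_keys.
is_leaf (Callable[[Any], bool] | None) – an optional function called at each flattening step. If it returns True, traversal stops and the whole subtree is treated as a single leaf; if it returns False, flattening continues into the current object.

Returns:
A list of (key path, leaf) pairs, in the same order as tree_leaves. Each key path is a tuple of KeyEntry objects that locates the leaf within tree.

Return type:
list[tuple[tuple[KeyEntry, …], Any]]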
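A minimal usage sketch follows. It assumes a recent JAX release that also provides jax.tree_util.keystr for rendering a key path as a string; the nested params dict is a hypothetical example, not part of the API. It shows the (key path, leaf) pairs for a dict of lists and tuples, and how is_leaf can stop traversal at tuples so that each tuple is returned as one leaf.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import keystr, tree_leaves_with_path

# Hypothetical nested pytree: a dict holding a list of arrays and a tuple of floats.
params = {
    "dense": [jnp.ones((2, 3)), jnp.zeros((3,))],
    "scale": (1.0, 2.0),
}

# Each entry is a (key_path, leaf) pair; key_path is a tuple of KeyEntry objects
# (here DictKey and SequenceKey), and keystr turns it into a readable string.
pairs = tree_leaves_with_path(params)
for path, leaf in pairs:
    print(keystr(path), jnp.shape(leaf))

# With is_leaf, any tuple is treated as a single leaf, so traversal stops at
# params["scale"] instead of descending into its two floats.
grouped = tree_leaves_with_path(params, is_leaf=lambda x: isinstance(x, tuple))
for path, leaf in grouped:
    print(keystr(path), leaf if isinstance(leaf, tuple) else jnp.shape(leaf))
```

Because the pairs come back in tree_leaves order, they line up one-to-one with the output of tree_leaves for the same tree and is_leaf.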