jax.tree_util.tree_flatten_with_path


jax.tree_util.tree_flatten_with_path(tree, is_leaf=None)

Flattens a pytree like tree_flatten, but also returns each leaf’s key path.
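A minimal sketch of basic usage, assuming a JAX release where keystr is exported from jax.tree_util. The nested tree is made up for illustration. Each result entry pairs a key path (a tuple of key entries such as DictKey or SequenceKey) with the leaf it leads to. Dict keys are visited in sorted order.

```python
import jax.numpy as jnp
from jax.tree_util import keystr, tree_flatten_with_path

# Dicts, lists and tuples are built-in pytree nodes; dict keys are visited in sorted order.
tree = {"params": {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}, "steps": [1, (2, 3)]}

key_leaf_pairs, treedef = tree_flatten_with_path(tree)

# keystr renders a key path as a readable string.
paths = [keystr(path) for path, _ in key_leaf_pairs]
# paths == ["['params']['b']", "['params']['w']", "['steps'][0]", "['steps'][1][0]", "['steps'][1][1]"]

for path_str, (_, leaf) in zip(paths, key_leaf_pairs):
    print(path_str, leaf)
```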

Parameters:
  • tree (Any) – a pytree to flatten. If it contains a custom type, it must be registered with register_pytree_with_keys.

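  • is_leaf (Callable[[Any], bool] | None) – an optionally specified function that is called at each flattening step. It should return a boolean. True stops the traversal and treats the whole subtree as a single leaf. False indicates that flattening should continue into the current object.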
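The sketch below exercises both parameters. Point, point_flatten_with_keys and point_unflatten are hypothetical names, defined only to show how a custom type registered with register_pytree_with_keys contributes its own keys (here GetAttrKey entries, rendered as ".x" and ".y"). The is_leaf predicate is an arbitrary example that treats every tuple as a leaf, so the tuple under "pair" comes back whole with the path "['pair']".

```python
from jax.tree_util import (
    GetAttrKey,
    keystr,
    register_pytree_with_keys,
    tree_flatten_with_path,
)


class Point:
    """Hypothetical two-field container, defined only for this example."""

    def __init__(self, x, y):
        self.x = x
        self.y = y


def point_flatten_with_keys(p):
    # Return (key, child) pairs plus auxiliary data (none needed here).
    return ((GetAttrKey("x"), p.x), (GetAttrKey("y"), p.y)), None


def point_unflatten(aux_data, children):
    # Rebuild a Point from its children, in the order they were flattened.
    return Point(*children)


register_pytree_with_keys(Point, point_flatten_with_keys, point_unflatten)

tree = {"origin": Point(0.0, 0.0), "pair": (1, 2)}

# Default traversal: the keys supplied for Point appear in the paths.
flat, _ = tree_flatten_with_path(tree)
paths = [keystr(path) for path, _ in flat]

# is_leaf returning True stops traversal, so the tuple stays a single leaf.
flat_tuples, _ = tree_flatten_with_path(
    tree, is_leaf=lambda node: isinstance(node, tuple)
)
tuple_paths = [keystr(path) for path, _ in flat_tuples]
```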

Return type:

tuple[list[tuple[KeyPath, Any]], PyTreeDef]

Returns:

A pair whose first element is a list of key-leaf pairs, each a (key path, leaf) tuple, and whose second element is a treedef representing the structure of the flattened tree.
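The returned pair can be taken apart and put back together. In the sketch below, the params dict and the scale_kernels helper are made up for illustration. It rewrites only the leaves whose last key is the dict key "kernel" and then rebuilds a tree of the original structure with tree_unflatten and the returned treedef.

```python
import jax.numpy as jnp
from jax.tree_util import DictKey, tree_flatten_with_path, tree_unflatten

params = {"dense": {"kernel": jnp.ones((2, 3)), "bias": jnp.ones(3)}}
key_leaf_pairs, treedef = tree_flatten_with_path(params)


def scale_kernels(path, leaf):
    # Hypothetical rule: halve only leaves whose last key is the dict key "kernel".
    last = path[-1]
    if isinstance(last, DictKey) and last.key == "kernel":
        return leaf * 0.5
    return leaf


new_leaves = [scale_kernels(path, leaf) for path, leaf in key_leaf_pairs]

# The treedef restores the original nesting around the new leaves.
new_params = tree_unflatten(treedef, new_leaves)
```

For this particular pattern, jax.tree_util.tree_map_with_path performs the flatten, map and unflatten steps in one call; the explicit version above just makes the role of the treedef visible.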