Problem

Given the root of a binary tree, collect the tree's nodes as if you were doing the following: collect and remove all leaves, then repeat until the tree is empty.

Examples

Example 1:

Input: root = [1,2,3,4,5]
          1
         / \
        2   3
       / \     
      4   5  

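Output: [[4,5,3],[2],[1]]

Explanation: The leaves 4, 5 and 3 are removed first. Node 2 then becomes a leaf and is removed next, and finally the root 1 is removed.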
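The root = [1,2,3,4,5] notation is assumed here to follow the common LeetCode-style convention: node values listed in level order, with None (null) marking a missing child. A minimal sketch of decoding it, assuming a TreeNode class with val, left and right (both TreeNode and build_tree are hypothetical helpers, not part of the original solution):

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    # Minimal binary tree node (assumed shape: val, left, right).
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    # Decode a level-order list; None means "no child at this position".
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])  # nodes still waiting for their children
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # Next slot is this node's left child.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        # The slot after that is its right child.
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


root = build_tree([1, 2, 3, 4, 5])  # the tree from Example 1
```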

Solution

The task is to simulate a process where we repeatedly collect all leaves of the binary tree and remove them, until the tree becomes empty. A “leaf” is a node that does not have any child nodes.
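A direct translation of this process strips the current leaves in repeated passes until nothing is left. The sketch below is a baseline for comparison, not the method used later; it assumes a minimal TreeNode class, find_leaves_by_simulation is a hypothetical name, and it detaches nodes from the input tree as it goes. Because every pass revisits all remaining nodes, a skewed tree of n nodes needs n passes and O(n²) time.

```python
from typing import List, Optional


class TreeNode:
    # Minimal binary tree node (assumed shape: val, left, right).
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def find_leaves_by_simulation(root: Optional[TreeNode]) -> List[List[int]]:
    rounds: List[List[int]] = []

    def strip(node: Optional[TreeNode], batch: List[int]) -> Optional[TreeNode]:
        # Remove the leaves that exist at the start of this pass,
        # record their values in batch, and return the pruned subtree.
        if node is None:
            return None
        if node.left is None and node.right is None:
            batch.append(node.val)
            return None  # detach this leaf from its parent
        node.left = strip(node.left, batch)
        node.right = strip(node.right, batch)
        return node  # not a leaf yet in this pass, even if its children just went

    while root is not None:
        batch: List[int] = []
        root = strip(root, batch)
        rounds.append(batch)
    return rounds


example = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print(find_leaves_by_simulation(example))  # [[4, 5, 3], [2], [1]]
```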

Method 1 - DFS Returning Each Node's Height From the Bottom

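The key is to reformulate the problem: instead of simulating the removals, determine for each node the index of the group it belongs to in the result list. That index is the round in which the node is removed, and computing it for every node is a standard DFS problem on trees.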
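To see the reformulation on Example 1, the following sketch records each node's removal round, which is its height measured from the bottom. It assumes a minimal TreeNode class and unique node values; node_heights is a hypothetical helper, not part of the solution.

```python
from typing import Dict, Optional


class TreeNode:
    # Minimal binary tree node (assumed shape: val, left, right).
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def node_heights(root: Optional[TreeNode]) -> Dict[int, int]:
    # Map each node's value to the round in which it is removed
    # (assumes values are unique so they can serve as keys).
    heights: Dict[int, int] = {}

    def height(node: Optional[TreeNode]) -> int:
        if node is None:
            return -1  # empty subtree, so a leaf ends up at 0
        h = max(height(node.left), height(node.right)) + 1
        heights[node.val] = h
        return h

    height(root)
    return heights


root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print(node_heights(root))  # {4: 0, 5: 0, 2: 1, 3: 0, 1: 2}
```

Grouping the values by height gives [[4, 5, 3], [2], [1]], the expected output.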

Approach

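  1. Height Calculation: Determine the round in which each node becomes a leaf. A node becomes a leaf only after all its children have been removed, so leaves are removed in round 0 and every other node is removed one round after its last-removed child. This round number is the node's height measured from the bottom (the variable depth in the code), and it is exactly the node's index in the result list.
  2. Recursive Structure: Use post-order recursion to compute each node's height as max(height(left), height(right)) + 1, then append the node's value to the group at that index, creating the group first if it does not exist yet.
  3. Base Case: An empty node (null in Java or None in Python) contributes nothing, so return -1. A leaf, whose children are both empty, then gets height max(-1, -1) + 1 = 0.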

Code

import java.util.ArrayList;
import java.util.List;

// Assumes the usual TreeNode class with fields val, left and right.
class Solution {
    public List<List<Integer>> findLeaves(TreeNode root) {
        List<List<Integer>> ans = new ArrayList<>();
        collectLeaves(root, ans);
        return ans;
    }

    // Returns the node's height from the bottom (leaves have height 0)
    // and records the node's value in the group with that index.
    private int collectLeaves(TreeNode node, List<List<Integer>> ans) {
        if (node == null) return -1; // Empty subtree, so a leaf gets height 0
        int lDepth = collectLeaves(node.left, ans);
        int rDepth = collectLeaves(node.right, ans);
        int depth = Math.max(lDepth, rDepth) + 1;
        while (ans.size() <= depth) {
            ans.add(new ArrayList<>()); // Create the group for this height if missing
        }
        ans.get(depth).add(node.val); // The node is removed in round `depth`
        return depth;
    }
}
from typing import List, Optional

# Assumes the usual TreeNode class with attributes val, left and right.
class Solution:
    def findLeaves(self, root: Optional[TreeNode]) -> List[List[int]]:
        ans: List[List[int]] = []

        def dfs(node: Optional[TreeNode]) -> int:
            # Returns the node's height from the bottom; leaves have height 0.
            if not node:
                return -1  # Empty subtree, so a leaf gets height 0
            l_depth = dfs(node.left)
            r_depth = dfs(node.right)
            depth = max(l_depth, r_depth) + 1
            if len(ans) <= depth:
                # Post-order guarantees len(ans) >= depth here, so one append suffices
                ans.append([])
            ans[depth].append(node.val)  # The node is removed in round `depth`
            return depth

        dfs(root)
        return ans

Complexity

  • ⏰ Time complexity: O(n), since each node is visited exactly once.
  • 🧺 Space complexity: O(n). The recursion stack takes O(h), where h is the tree height (up to n for a skewed tree), and the result lists store all n values.
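The O(h) recursion stack can matter in practice: a skewed tree has h close to n, and CPython's default recursion limit (typically 1000, see sys.getrecursionlimit()) can then be exceeded by the recursive Python version. A sketch of the same height grouping with an explicit stack, assuming a minimal TreeNode class (find_leaves_iterative is a hypothetical name, not part of the original):

```python
from typing import Dict, List, Optional, Tuple


class TreeNode:
    # Minimal binary tree node (assumed shape: val, left, right).
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def find_leaves_iterative(root: Optional[TreeNode]) -> List[List[int]]:
    ans: List[List[int]] = []
    if root is None:
        return ans
    height: Dict[TreeNode, int] = {}  # nodes hash by identity
    # Each entry is (node, children_done); a node is finalized only after
    # both children, which gives post-order without recursion.
    stack: List[Tuple[TreeNode, bool]] = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if children_done:
            lh = height[node.left] if node.left else -1
            rh = height[node.right] if node.right else -1
            h = max(lh, rh) + 1
            height[node] = h
            if len(ans) <= h:
                ans.append([])  # a child already filled index h - 1
            ans[h].append(node.val)
        else:
            stack.append((node, True))
            # Push right first so the left subtree is processed first.
            if node.right:
                stack.append((node.right, False))
            if node.left:
                stack.append((node.left, False))
    return ans


example = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print(find_leaves_iterative(example))  # [[4, 5, 3], [2], [1]]
```

The extra dictionary costs O(n) space, which stays within the O(n) bound above.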