Find Leaves of Binary Tree
Difficulty: Medium (updated Aug 2, 2025)
Problem
Given the root of a binary tree, collect the tree's nodes as if you were doing the following: collect and remove all leaves, then repeat until the tree is empty. Return the collected values grouped by round.
Examples
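Example 1:
Input: root = [1,2,3,4,5]
```
    1
   / \
  2   3
 / \
4   5
```
Output: [[4,5,3],[2],[1]]
Explanation: Removing the leaves 4, 5 and 3 leaves the tree [1,2]. Removing the new leaf 2 leaves [1]. Removing 1 empties the tree.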
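The Input line uses level-order (breadth-first) notation. A minimal sketch of decoding it follows, assuming the usual LeetCode-style TreeNode class and a hypothetical helper build_tree (not part of the original solution), where None marks a missing child.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    """Minimal binary-tree node (assumed to match the usual LeetCode definition)."""

    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: build a tree from level-order notation such as [1,2,3,4,5].

    None entries mark missing children, as in the LeetCode input format.
    """
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next value (if present) is the left child of the dequeued node.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        # The value after that (if present) is its right child.
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


root = build_tree([1, 2, 3, 4, 5])
print(root.val, root.left.val, root.right.val, root.left.left.val, root.left.right.val)
# prints: 1 2 3 4 5
```

The helper consumes the array in pairs, attaching the left and then the right child of each dequeued node, which is the order in which level-order notation lists them.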
Solution
The task is to simulate a process where we repeatedly collect all leaves of the binary tree and remove them until the tree becomes empty; the leaves removed in each round form one list of the output. A "leaf" is a node that does not have any child nodes.
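Taken literally, the process can be simulated round by round. The sketch below is a naive baseline, not the method this article uses: the hypothetical function find_leaves_simulated walks the remaining tree each round, detaches every current leaf, and records its value. It assumes a minimal LeetCode-style TreeNode and mutates the input tree.

```python
from typing import List, Optional


class TreeNode:
    # Minimal LeetCode-style node (assumed definition).
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def find_leaves_simulated(root: Optional[TreeNode]) -> List[List[int]]:
    """Naive simulation: peel off all current leaves each round until the tree is empty."""
    rounds: List[List[int]] = []

    def strip(node: Optional[TreeNode], leaves: List[int]) -> Optional[TreeNode]:
        # Returns the subtree with its current leaves removed,
        # recording removed values from left to right.
        if node is None:
            return None
        if node.left is None and node.right is None:
            leaves.append(node.val)
            return None  # detach this leaf from its parent
        node.left = strip(node.left, leaves)
        node.right = strip(node.right, leaves)
        return node

    while root is not None:
        leaves: List[int] = []
        root = strip(root, leaves)
        rounds.append(leaves)
    return rounds


example = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print(find_leaves_simulated(example))  # [[4, 5, 3], [2], [1]]
```

Every round re-traverses what is left of the tree, and there is one round per height level, so this runs in O(n·h) time (O(n²) for a skewed tree). That cost motivates computing each node's round directly instead.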
Method 1 - DFS with Returning Node Depth From Bottom
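The solution lies in reformulating the problem: instead of simulating the removals, determine for each node the index of the list in the result where it belongs. That index is the round in which the node is removed, and it can be computed bottom-up with a standard post-order DFS on the tree.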
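To see the reformulation concretely, the following sketch (a hypothetical removal_rounds function, assuming distinct node values so they can serve as dictionary keys) records the index computed bottom-up for every node. On Example 1 it maps 4, 5 and 3 to index 0, 2 to index 1, and 1 to index 2, which are exactly the groups in the expected output.

```python
from typing import Dict, Optional


class TreeNode:
    # Minimal LeetCode-style node (assumed definition).
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def removal_rounds(root: Optional[TreeNode]) -> Dict[int, int]:
    """Map each node's value to the result-list index (removal round) it belongs to.

    Assumes node values are distinct, since they are used as dictionary keys.
    """
    rounds: Dict[int, int] = {}

    def height(node: Optional[TreeNode]) -> int:
        # Post-order: a node's index depends on both children's indices.
        if node is None:
            return -1  # empty subtree sits one level below the leaves
        h = max(height(node.left), height(node.right)) + 1
        rounds[node.val] = h
        return h

    height(root)
    return rounds


example = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print(removal_rounds(example))  # {4: 0, 5: 0, 2: 1, 3: 0, 1: 2}
```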
Approach
- Depth Calculation: Determine the round in which each node becomes a leaf. A node becomes a leaf only after all of its children have been removed, so grouping nodes by this round number gives the answer directly. Here "depth" means the node's height measured from the bottom: leaves have depth 0.
- Recursive Structure: Use post-order recursion to compute each node's depth as max(left, right) + 1, where left and right are the depths returned for its two children. The node's value is then appended to the result list at index depth.
- Base Case: An empty node (null in Java or None in Python) contributes nothing, so return -1. This makes a leaf's depth max(-1, -1) + 1 = 0.
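Written as a recurrence, with h(v) denoting the value the DFS returns for node v (an empty subtree returns -1), and evaluated on Example 1:

```latex
\begin{aligned}
h(\emptyset) &= -1 \\
h(v) &= \max\bigl(h(v_{\text{left}}),\, h(v_{\text{right}})\bigr) + 1 \\[4pt]
h(4) = h(5) = h(3) &= \max(-1, -1) + 1 = 0 \\
h(2) &= \max\bigl(h(4), h(5)\bigr) + 1 = \max(0, 0) + 1 = 1 \\
h(1) &= \max\bigl(h(2), h(3)\bigr) + 1 = \max(1, 0) + 1 = 2
\end{aligned}
```

Grouping by h gives index 0 = [4, 5, 3], index 1 = [2], index 2 = [1], matching the expected output.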
Code
Java
```java
import java.util.ArrayList;
import java.util.List;

// TreeNode is the standard LeetCode definition (int val; TreeNode left, right).
class Solution {
    public List<List<Integer>> findLeaves(TreeNode root) {
        List<List<Integer>> ans = new ArrayList<>();
        collectLeaves(root, ans);
        return ans;
    }

    // Returns the node's height from the bottom (leaves = 0), which is also
    // the index of the round in which the node is removed.
    private int collectLeaves(TreeNode node, List<List<Integer>> ans) {
        if (node == null) return -1; // Empty subtree sits one level below the leaves
        int lDepth = collectLeaves(node.left, ans);
        int rDepth = collectLeaves(node.right, ans);
        int depth = Math.max(lDepth, rDepth) + 1;
        while (ans.size() <= depth) {
            ans.add(new ArrayList<>()); // Ensure a group exists for this round
        }
        ans.get(depth).add(node.val); // Add node value to its round's group
        return depth;
    }
}
```
Python
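```python
from typing import List, Optional

# TreeNode is the standard LeetCode definition (val, left, right).
class Solution:
    def findLeaves(self, root: Optional[TreeNode]) -> List[List[int]]:
        ans: List[List[int]] = []

        def dfs(node: Optional[TreeNode]) -> int:
            # Returns the node's height from the bottom (leaves = 0),
            # i.e. the index of the round in which the node is removed.
            if not node:
                return -1  # Empty subtree contributes nothing
            l_depth = dfs(node.left)
            r_depth = dfs(node.right)
            depth = max(l_depth, r_depth) + 1
            if len(ans) <= depth:
                ans.append([])  # Children are processed first, so at most one new group is needed
            ans[depth].append(node.val)  # Append node value to its round's group
            return depth

        dfs(root)
        return ans
```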
Complexity
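- ⏰ Time complexity: O(n), since each node is visited exactly once and each append to the result is amortized O(1).
- 🧺 Space complexity: O(n). The recursion stack takes O(h), where h is the height of the tree (up to O(n) for a skewed tree), and the result lists store all n node values.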
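Because the recursion depth grows with the tree height, a very skewed tree can exceed CPython's default recursion limit of 1000 when using the recursive Python solution. One workaround, not part of the original article, is an explicit-stack post-order traversal that computes the same heights. The sketch below assumes a minimal LeetCode-style TreeNode and stores heights in a dictionary keyed by id(node), so duplicate values are handled.

```python
from typing import Dict, List, Optional, Tuple


class TreeNode:
    # Minimal LeetCode-style node (assumed definition).
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def find_leaves_iterative(root: Optional[TreeNode]) -> List[List[int]]:
    """Same height-grouping idea as the DFS, using an explicit stack instead of recursion."""
    ans: List[List[int]] = []
    height: Dict[int, int] = {}  # id(node) -> height from the bottom
    stack: List[Tuple[TreeNode, bool]] = [(root, False)] if root is not None else []
    while stack:
        node, children_done = stack.pop()
        if not children_done:
            # Revisit this node after both children are finished (post-order).
            stack.append((node, True))
            if node.right is not None:
                stack.append((node.right, False))
            if node.left is not None:
                stack.append((node.left, False))
            continue
        left_h = height[id(node.left)] if node.left is not None else -1
        right_h = height[id(node.right)] if node.right is not None else -1
        h = max(left_h, right_h) + 1
        height[id(node)] = h
        if len(ans) <= h:
            ans.append([])  # a child with height h - 1 already created group h - 1
        ans[h].append(node.val)
    return ans


example = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print(find_leaves_iterative(example))  # [[4, 5, 3], [2], [1]]
```

Each node is pushed twice (once to expand, once to finalize), so time stays O(n), and the stack plus the height table use O(n) extra space. Pushing the right child before the left keeps the left-to-right order within each group that the recursive version produces.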