Problem

You are given the root of a binary tree and a positive integer k.

The level sum in the tree is the sum of the values of the nodes that are on the same level.

Return the kth largest level sum in the tree (not necessarily distinct). If there are fewer than k levels in the tree, return -1.

Note that two nodes are on the same level if they have the same distance from the root.
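The kth largest sum is taken "not necessarily distinct", so levels with equal sums each occupy their own rank. The sketch below shows that rule on a plain list of level sums and ignores the tree entirely. `kth_largest` is a hypothetical helper name, and k is assumed to be a positive integer.

```python
from typing import List


def kth_largest(level_sums: List[int], k: int) -> int:
    """Hypothetical helper: k-th largest value, counting duplicates separately, or -1."""
    if k > len(level_sums):
        return -1  # fewer than k levels
    # After a descending sort, the k-th largest value sits at index k - 1.
    return sorted(level_sums, reverse=True)[k - 1]


print(kth_largest([10, 3, 10], 2))  # 10: both levels summing to 10 count
print(kth_largest([10, 3, 10], 3))  # 3
print(kth_largest([10, 3, 10], 4))  # -1: only three levels
```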

Examples

Example 1:

        5
      /   \
     8     9
    / \   / \
   2   1 3   7
  / \
 4   6
  
Input: root = [5,8,9,2,1,3,7,4,6], k = 2
Output: 13
Explanation: The level sums are the following:
- Level 1: 5.
- Level 2: 8 + 9 = 17.
- Level 3: 2 + 1 + 3 + 7 = 13.
- Level 4: 4 + 6 = 10.
The 2nd largest level sum is 13.

Example 2:

    1
   /
  2
 /
3
Input: root = [1,2,null,3], k = 1
Output: 3
Explanation: The largest level sum is 3.
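To reproduce these examples locally, the level-order arrays (with `null` written as Python's `None`) first have to be turned into trees. The sketch below assumes a `TreeNode` class shaped like LeetCode's (`val`, `left`, `right`). `build_tree` and `level_sums` are hypothetical helpers. The first printed list matches the level sums in Example 1's explanation, and the second confirms that the largest level sum in Example 2 is 3.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    """Minimal node class, assumed to mirror LeetCode's TreeNode."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: build a tree from a level-order list (None = no child)."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two entries are this node's left and right children.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def level_sums(root: Optional[TreeNode]) -> List[int]:
    """Hypothetical helper: sum of each level, from the root level downward."""
    sums = []
    queue = deque([root] if root else [])
    while queue:
        level_total = 0
        # len(queue) is evaluated once, so only the current level is consumed.
        for _ in range(len(queue)):
            node = queue.popleft()
            level_total += node.val
            for child in (node.left, node.right):
                if child is not None:
                    queue.append(child)
        sums.append(level_total)
    return sums


print(level_sums(build_tree([5, 8, 9, 2, 1, 3, 7, 4, 6])))  # [5, 17, 13, 10]
print(level_sums(build_tree([1, 2, None, 3])))              # [1, 2, 3]
```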

Solution


Method 1 - Using BFS

To solve the problem, we can perform a level order traversal (BFS) of the binary tree to compute the sum of nodes at each level. We then keep track of these sums and, finally, find the kth largest sum.

  1. Breadth-First Search (BFS): Use BFS to access nodes level by level.
  2. Sum Calculation: For each level, calculate the sum of the node values.
  3. Store Level Sums: Store the level sums in a list.
  4. Find the kth Largest Sum: Sort the level sums in descending order; if there are at least k levels, return the element at index k - 1, otherwise return -1.

Code

Java
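import java.util.*;

// TreeNode is assumed to be the standard LeetCode definition (val, left, right).
public class Solution {
    public long kthLargestLevelSum(TreeNode root, int k) {
        if (root == null) return -1;

        Queue<TreeNode> queue = new LinkedList<>();
        List<Long> levelSums = new ArrayList<>();

        queue.offer(root);

        while (!queue.isEmpty()) {
            int levelSize = queue.size();
            // long, not int: a level sum may exceed the int range, and an int
            // cannot be added to a List<Long> (int does not box to Long)
            long levelSum = 0;

            for (int i = 0; i < levelSize; i++) {
                TreeNode node = queue.poll();
                levelSum += node.val;

                if (node.left != null) {
                    queue.offer(node.left);
                }
                if (node.right != null) {
                    queue.offer(node.right);
                }
            }

            levelSums.add(levelSum);
        }

        // Descending order puts the kth largest sum at index k - 1
        Collections.sort(levelSums, Collections.reverseOrder());

        return k <= levelSums.size() ? levelSums.get(k - 1) : -1;
    }
}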
Python
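from collections import deque
from typing import Optional

# TreeNode is assumed to be the standard LeetCode definition (val, left, right).
class Solution:
    def kthLargestLevelSum(self, root: Optional[TreeNode], k: int) -> int: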
        if not root:
            return -1
        
        queue = deque([root])
        level_sums = []
        
        while queue:
            level_size = len(queue)
            level_sum = 0
            
            for _ in range(level_size):
                curr = queue.popleft()
                level_sum += curr.val
                
                if curr.left:
                    queue.append(curr.left)
                if curr.right:
                    queue.append(curr.right)
            
            level_sums.append(level_sum)
        
        level_sums.sort(reverse=True)
        
        return level_sums[k - 1] if k <= len(level_sums) else -1

Complexity

  • ⏰ Time complexity: O(n + h log h), where n is the number of nodes in the tree and h is the number of levels.
    • O(n) to visit every node once and compute the level sums.
    • O(h log h) to sort the list of level sums.
    • In the worst case (a skewed tree), h equals n.
  • 🧺 Space complexity: O(n). The list of level sums takes O(h), and the BFS queue holds an entire level at a time, which can be roughly n/2 nodes.

Method 2 - Using max heap

Here is the approach:

  1. Perform a level-order traversal (BFS) of the binary tree.
  2. Calculate the sum of the nodes at each level.
  3. Insert each level sum into a max heap.
  4. Pop the maximum element from the heap k - 1 times; the element now at the top is the kth largest level sum. If the heap holds fewer than k sums, return -1.
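Python's `heapq` module only provides a min heap, so a max heap is usually simulated by storing negated values. The sketch below applies steps 3 and 4 to an already computed list of level sums. `kth_largest_max_heap` is a hypothetical name, and it builds the heap with `heapq.heapify`, which runs in linear time, instead of pushing the sums one by one.

```python
import heapq
from typing import List


def kth_largest_max_heap(sums: List[int], k: int) -> int:
    """Hypothetical helper: k-th largest of precomputed level sums via a max heap."""
    if k > len(sums):
        return -1  # fewer than k levels
    # heapq is a min heap, so negate the values to simulate a max heap.
    max_heap = [-s for s in sums]
    heapq.heapify(max_heap)  # linear-time heap construction
    # Discard the k - 1 largest sums; the k-th largest is then at the top.
    for _ in range(k - 1):
        heapq.heappop(max_heap)
    return -max_heap[0]


print(kth_largest_max_heap([5, 17, 13, 10], 2))  # 13 (Example 1)
```

With the Example 1 sums and k = 2, one pop removes 17, and the top of the heap is 13.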

Code

Java
import java.util.*;

// TreeNode is assumed to be the standard LeetCode definition (val, left, right).
public class Solution {
    public long kthLargestLevelSum(TreeNode root, int k) {
        // If the tree is empty, return -1
        if (root == null) return -1;

        // Queue for level-order traversal and a max heap holding every level sum
        Queue<TreeNode> queue = new LinkedList<>();
        PriorityQueue<Long> maxHeap = new PriorityQueue<>(Collections.reverseOrder());

        // Start with the root node in the queue
        queue.offer(root);

        while (!queue.isEmpty()) {
            // Number of nodes at the current level
            int levelSize = queue.size();
            // Sum for the current level (long to avoid overflow)
            long levelSum = 0;

            // Process each node at the current level
            for (int i = 0; i < levelSize; i++) {
                TreeNode curr = queue.poll();
                levelSum += curr.val;

                // Enqueue the children for the next level
                if (curr.left != null) {
                    queue.offer(curr.left);
                }
                if (curr.right != null) {
                    queue.offer(curr.right);
                }
            }

            // Add the level sum to the max heap
            maxHeap.offer(levelSum);
        }

        // Fewer than k levels: there is no kth largest sum
        if (maxHeap.size() < k) return -1;

        // Remove the k - 1 largest sums; the kth largest is then at the top
        for (int i = 0; i < k - 1; i++) {
            maxHeap.poll();
        }
        return maxHeap.peek();
    }
}
Python
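import heapq
from collections import deque
from typing import Optional

# TreeNode is assumed to be the standard LeetCode definition (val, left, right).
class Solution:
    def kthLargestLevelSum(self, root: Optional[TreeNode], k: int) -> int: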
        if not root:
            return -1
        
        queue = deque([root])
        max_heap = []
        
        while queue:
            level_size = len(queue)
            level_sum = 0
            
            for _ in range(level_size):
                curr = queue.popleft()
                level_sum += curr.val
                
                if curr.left:
                    queue.append(curr.left)
                if curr.right:
                    queue.append(curr.right)
            
            # heapq is a min heap, so push the negated sum to simulate a max heap
            heapq.heappush(max_heap, -level_sum)

        # Pop the k - 1 largest sums; running out early means fewer than k levels
        for _ in range(k - 1):
            if not max_heap:
                return -1
            heapq.heappop(max_heap)
        
        return -heapq.heappop(max_heap) if max_heap else -1

Complexity

  • ⏰ Time complexity: O(n + (h + k) log h)
    • BFS to compute the level sums takes O(n) time.
    • Inserting the h level sums into the heap takes O(h log h) time.
    • Removing the maximum k - 1 times takes O(k log h) time.
    • Each heap insertion or extraction costs O(log h).
    • In the worst case, h = n.
  • 🧺 Space complexity: O(n). The heap stores all h level sums (O(h)), and the BFS queue can hold roughly n/2 nodes at the widest level.

Method 3 - Using min heap

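The strategy is to maintain a min heap holding at most k level sums. As each level sum is computed, we add it to the heap; if the heap then exceeds size k, we remove its smallest element.

After all levels are processed, the heap holds the k largest sums, so its root (the smallest of them) is the kth largest level sum. If the heap holds fewer than k sums, the tree has fewer than k levels and the answer is -1.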
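The min heap bookkeeping can be checked in isolation on a precomputed list of level sums. `kth_largest_min_heap` below is a hypothetical helper that mirrors the heap logic of the code that follows, without the tree traversal.

```python
import heapq
from typing import List


def kth_largest_min_heap(sums: List[int], k: int) -> int:
    """Hypothetical helper: keep only the k largest sums seen so far in a min heap."""
    min_heap: List[int] = []
    for s in sums:
        heapq.heappush(min_heap, s)
        if len(min_heap) > k:
            heapq.heappop(min_heap)  # evict the smallest of the k + 1 candidates
    # Fewer than k levels means there is no k-th largest sum.
    return min_heap[0] if len(min_heap) == k else -1


print(kth_largest_min_heap([5, 17, 13, 10], 2))  # 13 (Example 1)
```

With the Example 1 sums [5, 17, 13, 10] and k = 2, the heap holds [5], then [5, 17], then [13, 17] after 5 is evicted, and still [13, 17] after 10 is pushed and immediately evicted; its root, 13, is the answer. Once the heap is full, `heapq.heappushpop` could fold the push and the eviction into a single call.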

Code

Java
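import java.util.*;

// TreeNode is assumed to be the standard LeetCode definition (val, left, right).
public class Solution {
    public long kthLargestLevelSum(TreeNode root, int k) {
        if (root == null) return -1;

        Queue<TreeNode> queue = new LinkedList<>();
        // Min heap that keeps only the k largest level sums seen so far
        PriorityQueue<Long> minHeap = new PriorityQueue<>();

        queue.offer(root);

        while (!queue.isEmpty()) {
            int levelSize = queue.size();
            long levelSum = 0;

            for (int i = 0; i < levelSize; i++) {
                TreeNode curr = queue.poll();
                levelSum += curr.val;

                if (curr.left != null) {
                    queue.offer(curr.left);
                }
                if (curr.right != null) {
                    queue.offer(curr.right);
                }
            }

            minHeap.offer(levelSum);

            // Evict the smallest sum so the heap never holds more than k elements
            if (minHeap.size() > k) {
                minHeap.poll();
            }
        }

        // The heap's root is the kth largest sum when there are at least k levels
        return minHeap.size() < k ? -1 : minHeap.peek();
    }
}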
Python
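import heapq
from collections import deque
from typing import Optional

# TreeNode is assumed to be the standard LeetCode definition (val, left, right).
class Solution:
    def kthLargestLevelSum(self, root: Optional[TreeNode], k: int) -> int: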
        if not root:
            return -1
        
        queue = deque([root])
        min_heap = []
        
        while queue:
            level_size = len(queue)
            level_sum = 0
            
            for _ in range(level_size):
                curr = queue.popleft()
                level_sum += curr.val
                
                if curr.left:
                    queue.append(curr.left)
                if curr.right:
                    queue.append(curr.right)
            
            heapq.heappush(min_heap, level_sum)
            
            if len(min_heap) > k:
                heapq.heappop(min_heap)
        
        return min_heap[0] if len(min_heap) == k else -1

Complexity

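  • ⏰ Time complexity: O(n + h log k)
    • BFS visits each node once: O(n).
    • Each heap insertion and removal takes O(log k) time, and this happens once per level, for h levels.
  • 🧺 Space complexity: O(n). The min heap holds at most min(h, k + 1) sums, but the BFS queue can hold roughly n/2 nodes at the widest level.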
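In Python, the standard library already offers the same "k largest" selection: `heapq.nlargest(k, iterable)` returns the k largest elements in descending order, and its documentation describes it as equivalent to `sorted(iterable, reverse=True)[:k]`. A minimal sketch applying it to a precomputed list of level sums, with `kth_largest_nlargest` as a hypothetical name:

```python
import heapq
from typing import List


def kth_largest_nlargest(sums: List[int], k: int) -> int:
    """Hypothetical helper: k-th largest level sum using heapq.nlargest."""
    if k > len(sums):
        return -1  # fewer than k levels
    # nlargest returns the k largest values in descending order,
    # so the last element is the k-th largest.
    return heapq.nlargest(k, sums)[-1]


print(kth_largest_nlargest([5, 17, 13, 10], 2))  # 13 (Example 1)
```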