Problem
You are given the root of a binary tree and a positive integer k.
The level sum in the tree is the sum of the values of the nodes that are on the same level.
Return the k-th largest level sum in the tree (not necessarily distinct). If there are fewer than k levels in the tree, return -1.
Note that two nodes are on the same level if they have the same distance from the root.
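"Not necessarily distinct" means that equal level sums each occupy their own rank. For example, if two levels both sum to 17, they are the 1st and 2nd largest. Here is a minimal Python sketch that assumes the level sums are already in a list. The name `kth_largest` and the sample list are made up for illustration.

```python
from typing import List


def kth_largest(level_sums: List[int], k: int) -> int:
    """Return the k-th largest value, counting duplicates separately; -1 if fewer than k values."""
    if k > len(level_sums):
        return -1
    # Sort descending; equal sums stay as separate entries, so ties each take one rank
    return sorted(level_sums, reverse=True)[k - 1]


# Two levels share the sum 17, so both the 1st and 2nd largest are 17
print(kth_largest([5, 17, 17, 10], 2))  # 17
print(kth_largest([5, 17, 17, 10], 3))  # 10
print(kth_largest([5, 17, 17, 10], 5))  # -1
```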
Examples
Example 1:
        5
      /   \
     8     9
    / \   / \
   2   1 3   7
  / \
 4   6
Input: root = [5,8,9,2,1,3,7,4,6], k = 2
Output: 13
Explanation: The level sums are the following:
- Level 1: 5.
- Level 2: 8 + 9 = 17.
- Level 3: 2 + 1 + 3 + 7 = 13.
- Level 4: 4 + 6 = 10.
The 2nd largest level sum is 13.
Example 2:
      1
     /
    2
   /
  3
Input: root = [1,2,null,3], k = 1
Output: 3
Explanation: The largest level sum is 3.
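The `Input` arrays use a level-order encoding in which `null` marks a missing child. This sketch assumes the LeetCode convention and writes `null` as Python's `None`. It defines a minimal `TreeNode` class and a hypothetical `build_tree` helper that decode the encoding. It then prints each level's values and sum, which reproduces the explanations for both examples.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    """Minimal binary tree node (assumed to match the judge's definition)."""

    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None, right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Decode a level-order list where None marks a missing child (hypothetical helper)."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two entries are this node's left and right children
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def level_values(root: Optional[TreeNode]) -> List[List[int]]:
    """Group node values by level using a breadth-first traversal."""
    levels: List[List[int]] = []
    queue = deque([root] if root else [])
    while queue:
        level = []
        for _ in range(len(queue)):
            node = queue.popleft()
            level.append(node.val)
            if node.left:
                queue.append(node.left)
            if node.right:
                queue.append(node.right)
        levels.append(level)
    return levels


for values in ([5, 8, 9, 2, 1, 3, 7, 4, 6], [1, 2, None, 3]):
    for depth, level in enumerate(level_values(build_tree(values)), start=1):
        print(f"Level {depth}: {level} -> {sum(level)}")
```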
Solution
Method 1 - Using BFS
To solve the problem, we can perform a level order traversal (BFS) of the binary tree to compute the sum of the nodes at each level. We keep track of these sums and finally pick the k-th largest one.
- Breadth-First Search (BFS): Use BFS to access nodes level by level.
- Sum Calculation: For each level, calculate the sum of the node values.
- Store Level Sums: Store the level sums in a list.
- Find k-th Largest Sum: Sort the list of level sums in descending order and return the k-th largest sum if it exists; otherwise, return -1.
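One variation on the sort step uses Python's `heapq.nlargest(k, iterable)`. This is a swapped-in technique, not part of the method as described. The function returns the k largest items in descending order without sorting the whole list. This minimal sketch assumes the level sums are already in a list, and the function name is hypothetical.

```python
import heapq
from typing import List


def kth_largest_level_sum_from_sums(level_sums: List[int], k: int) -> int:
    """Pick the k-th largest level sum, or -1 if there are fewer than k levels."""
    if k > len(level_sums):
        return -1
    # nlargest returns the k largest values sorted from largest to smallest,
    # so the last element is the k-th largest (duplicates are kept)
    return heapq.nlargest(k, level_sums)[-1]


print(kth_largest_level_sum_from_sums([5, 17, 13, 10], 2))  # 13
```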
Code
Java
import java.util.*;

// TreeNode is assumed to be the standard LeetCode definition (val, left, right)
public class Solution {
    public long kthLargestLevelSum(TreeNode root, int k) {
        if (root == null) return -1;
        Queue<TreeNode> queue = new LinkedList<>();
        List<Long> levelSums = new ArrayList<>();
        queue.offer(root);
        while (!queue.isEmpty()) {
            int levelSize = queue.size();
            // long, not int: a level sum can exceed the int range, and List<Long> needs a long to box
            long levelSum = 0;
            for (int i = 0; i < levelSize; i++) {
                TreeNode node = queue.poll();
                levelSum += node.val;
                if (node.left != null) {
                    queue.offer(node.left);
                }
                if (node.right != null) {
                    queue.offer(node.right);
                }
            }
            levelSums.add(levelSum);
        }
        // Sort descending so the k-th largest sum sits at index k - 1
        Collections.sort(levelSums, Collections.reverseOrder());
        return k <= levelSums.size() ? levelSums.get(k - 1) : -1;
    }
}
Python
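from collections import deque
from typing import Optional


# TreeNode is assumed to be the standard LeetCode definition (val, left, right)
class Solution:
    def kthLargestLevelSum(self, root: Optional[TreeNode], k: int) -> int:
        if not root:
            return -1
        queue = deque([root])
        level_sums = []
        while queue:
            # All nodes currently in the queue belong to the same level
            level_size = len(queue)
            level_sum = 0
            for _ in range(level_size):
                curr = queue.popleft()
                level_sum += curr.val
                if curr.left:
                    queue.append(curr.left)
                if curr.right:
                    queue.append(curr.right)
            level_sums.append(level_sum)
        # Sort descending so the k-th largest sum sits at index k - 1
        level_sums.sort(reverse=True)
        return level_sums[k - 1] if k <= len(level_sums) else -1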
Complexity
- ⏰ Time complexity: O(n + h log h), where n is the number of nodes in the tree and h is the number of levels.
  - O(n) to compute the sum of every level.
  - O(h log h) to sort the list of level sums.
  - In the worst case (a completely skewed tree), h equals n.
- 🧺 Space complexity: O(h) for storing the level sums, plus O(w) for the BFS queue, where w is the maximum number of nodes on a single level (about n/2 in a perfect binary tree).
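The BFS queue can hold an entire level at once. In a perfect binary tree, the last level alone contains about half of the nodes. This small sketch builds perfect trees with a hypothetical `perfect_tree` helper. It then records the queue length at the start of each level, which is the tree's maximum level width w.

```python
from collections import deque
from typing import Optional


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None, right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def perfect_tree(depth: int) -> Optional[TreeNode]:
    """Build a perfect binary tree with `depth` levels (2**depth - 1 nodes), all values 1."""
    if depth == 0:
        return None
    return TreeNode(1, perfect_tree(depth - 1), perfect_tree(depth - 1))


def max_level_width(root: Optional[TreeNode]) -> int:
    """Run the level-order traversal and return the largest number of nodes on one level."""
    if root is None:
        return 0
    queue = deque([root])
    widest = 0
    while queue:
        # At the start of each iteration the queue holds exactly one full level
        widest = max(widest, len(queue))
        for _ in range(len(queue)):
            node = queue.popleft()
            if node.left:
                queue.append(node.left)
            if node.right:
                queue.append(node.right)
    return widest


for depth in range(1, 6):
    n = 2 ** depth - 1
    print(f"n = {n:2d}, levels = {depth}, widest level = {max_level_width(perfect_tree(depth))}")
```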
Method 2 - Using max heap
Here is the approach:
- Perform a level-order traversal (BFS) of the binary tree.
- Calculate the sum of the nodes at each level.
- Insert each level sum into a max heap.
- Extract the maximum element from the heap k - 1 times; the element left at the top is the k-th largest level sum. If there are fewer than k levels, return -1.
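Python's `heapq` module only provides a min heap. A common way to get max-heap behaviour is to push negated values. The sketch below applies the steps above to the level sums of Example 1 (5, 17, 13, 10), which it assumes are precomputed. The function name is hypothetical.

```python
import heapq
from typing import List


def kth_largest_via_max_heap(level_sums: List[int], k: int) -> int:
    """Return the k-th largest level sum using a negated min heap as a max heap."""
    if k > len(level_sums):
        return -1
    max_heap: List[int] = []
    for s in level_sums:
        # Negate each sum so the smallest heap entry corresponds to the largest sum
        heapq.heappush(max_heap, -s)
    # Discard the k - 1 largest sums
    for _ in range(k - 1):
        heapq.heappop(max_heap)
    # The top of the heap is now the k-th largest; undo the negation
    return -max_heap[0]


print(kth_largest_via_max_heap([5, 17, 13, 10], 2))  # 13
```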
Code
Java
import java.util.*;

// TreeNode is assumed to be the standard LeetCode definition (val, left, right)
public class Solution {
    public long kthLargestLevelSum(TreeNode root, int k) {
        // If the tree is empty, return -1
        if (root == null) return -1;
        // Initialize a queue for level-order traversal and a max heap holding every level sum
        Queue<TreeNode> queue = new LinkedList<>();
        PriorityQueue<Long> maxHeap = new PriorityQueue<>(Collections.reverseOrder());
        // Start with the root node in the queue
        queue.offer(root);
        while (!queue.isEmpty()) {
            // Get the number of nodes at the current level
            int levelSize = queue.size();
            // Use long: a level sum can exceed the int range
            long levelSum = 0;
            // Process each node at the current level
            for (int i = 0; i < levelSize; i++) {
                TreeNode curr = queue.poll();
                levelSum += curr.val;
                // Enqueue the children so they are processed with the next level
                if (curr.left != null) {
                    queue.offer(curr.left);
                }
                if (curr.right != null) {
                    queue.offer(curr.right);
                }
            }
            // Add the level sum to the max heap
            maxHeap.offer(levelSum);
        }
        // Fewer than k levels means there is no k-th largest sum
        if (maxHeap.size() < k) return -1;
        // Remove the k - 1 largest sums; the new top is the k-th largest
        for (int i = 0; i < k - 1; i++) {
            maxHeap.poll();
        }
        return maxHeap.peek();
    }
}
Python
import heapq
from collections import deque
from typing import Optional


# TreeNode is assumed to be the standard LeetCode definition (val, left, right)
class Solution:
    def kthLargestLevelSum(self, root: Optional[TreeNode], k: int) -> int:
        if not root:
            return -1
        queue = deque([root])
        max_heap = []
        while queue:
            level_size = len(queue)
            level_sum = 0
            for _ in range(level_size):
                curr = queue.popleft()
                level_sum += curr.val
                if curr.left:
                    queue.append(curr.left)
                if curr.right:
                    queue.append(curr.right)
            # heapq is a min heap, so push the negated sum to get max-heap behaviour
            heapq.heappush(max_heap, -level_sum)
        # Fewer than k levels means there is no k-th largest sum
        if len(max_heap) < k:
            return -1
        # Discard the k - 1 largest sums
        for _ in range(k - 1):
            heapq.heappop(max_heap)
        # The new top is the k-th largest; undo the negation
        return -max_heap[0]
Complexity
- ⏰ Time complexity: O(n + (h + k) log h)
  - The BFS that computes the level sums takes O(n) time.
  - Pushing the h level sums onto the heap takes O(h log h) time, since each push costs O(log h).
  - Removing the maximum up to k times takes O(k log h) time, since each removal costs O(log h).
  - In the worst case, h = n.
- 🧺 Space complexity: O(h) for the heap, which holds every level sum, plus O(w) for the BFS queue, where w is the maximum number of nodes on a single level.
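The O(h log h) term comes from h separate pushes. A variation on the method above builds the heap from the full list in one call to `heapq.heapify`, which runs in linear time. That lowers the construction cost to O(h), for O(n + h + k log h) overall. This sketch assumes the level sums are collected into a list first, and the function name is hypothetical.

```python
import heapq
from typing import List


def kth_largest_heapified(level_sums: List[int], k: int) -> int:
    """Build a max heap in O(h) with heapify, then pop k - 1 times in O(k log h)."""
    if k > len(level_sums):
        return -1
    # Negation turns Python's min heap into a max heap; heapify reorders the list in linear time
    max_heap = [-s for s in level_sums]
    heapq.heapify(max_heap)
    for _ in range(k - 1):
        heapq.heappop(max_heap)
    return -max_heap[0]
```

For instance, `kth_largest_heapified([5, 17, 13, 10], 3)` pops 17 and 13 and returns 10.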
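Method 3 - Using min heap
The strategy is to maintain a min heap of at most k level sums. As we compute each level sum, we add it to the heap; if the heap then holds more than k sums, we remove the smallest one.
After the traversal, the heap holds the k largest level sums, and its root is the k-th largest. If the tree has fewer than k levels, the heap never reaches size k and we return -1.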
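The sketch below shows why the heap's root ends up as the answer. It replays the strategy on Example 1's level sums with k = 2 and prints the heap after each level. The level sums are assumed precomputed, and `trace_min_heap` is a hypothetical name.

```python
import heapq
from typing import List


def trace_min_heap(level_sums: List[int], k: int) -> int:
    """Keep only the k largest sums seen so far; the heap root is the k-th largest of them."""
    min_heap: List[int] = []
    for level, s in enumerate(level_sums, start=1):
        heapq.heappush(min_heap, s)
        if len(min_heap) > k:
            # Evict the smallest sum; it can no longer be among the k largest
            heapq.heappop(min_heap)
        print(f"after level {level} (sum {s}): heap = {sorted(min_heap)}")
    # With fewer than k levels the heap never fills, so there is no k-th largest
    return min_heap[0] if len(min_heap) == k else -1


print(trace_min_heap([5, 17, 13, 10], 2))
```

With k = 2 the heap settles at [13, 17] after the third level. The fourth sum, 10, is evicted immediately, so the function returns 13, matching Example 1.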
Code
Java
import java.util.*;

// TreeNode is assumed to be the standard LeetCode definition (val, left, right)
public class Solution {
    public long kthLargestLevelSum(TreeNode root, int k) {
        if (root == null) return -1;
        Queue<TreeNode> queue = new LinkedList<>();
        // Min heap holding at most k sums; its root is the smallest of the k largest
        PriorityQueue<Long> minHeap = new PriorityQueue<>();
        queue.offer(root);
        while (!queue.isEmpty()) {
            int levelSize = queue.size();
            long levelSum = 0;
            for (int i = 0; i < levelSize; i++) {
                TreeNode curr = queue.poll();
                levelSum += curr.val;
                if (curr.left != null) {
                    queue.offer(curr.left);
                }
                if (curr.right != null) {
                    queue.offer(curr.right);
                }
            }
            minHeap.offer(levelSum);
            // Evict the smallest sum once the heap holds more than k
            if (minHeap.size() > k) {
                minHeap.poll();
            }
        }
        return minHeap.size() < k ? -1 : minHeap.peek();
    }
}
Python
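import heapq
from collections import deque
from typing import Optional


# TreeNode is assumed to be the standard LeetCode definition (val, left, right)
class Solution:
    def kthLargestLevelSum(self, root: Optional[TreeNode], k: int) -> int:
        if not root:
            return -1
        queue = deque([root])
        min_heap = []
        while queue:
            level_size = len(queue)
            level_sum = 0
            for _ in range(level_size):
                curr = queue.popleft()
                level_sum += curr.val
                if curr.left:
                    queue.append(curr.left)
                if curr.right:
                    queue.append(curr.right)
            heapq.heappush(min_heap, level_sum)
            # Keep only the k largest sums seen so far
            if len(min_heap) > k:
                heapq.heappop(min_heap)
        # The heap fills to size k only if the tree has at least k levels
        return min_heap[0] if len(min_heap) == k else -1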
Complexity
- ⏰ Time complexity: O(n + h log k)
  - The BFS takes O(n) time, and each of the h level sums costs one insertion and at most one removal, each O(log k).
- 🧺 Space complexity: O(k) for the min heap, plus O(w) for the BFS queue, where w is the maximum number of nodes on a single level.
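All three methods compute the same list of level sums and differ only in how they select the k-th largest. A quick randomized check can therefore confirm that the selection steps agree, including on ties and on k larger than the number of levels. This is a testing sketch: the function names are hypothetical, and random lists stand in for level sums.

```python
import heapq
import random
from typing import List


def by_sorting(sums: List[int], k: int) -> int:
    """Method 1 selection: sort descending and index."""
    return sorted(sums, reverse=True)[k - 1] if k <= len(sums) else -1


def by_max_heap(sums: List[int], k: int) -> int:
    """Method 2 selection: negated min heap, pop k - 1 times."""
    if k > len(sums):
        return -1
    heap = [-s for s in sums]
    heapq.heapify(heap)
    for _ in range(k - 1):
        heapq.heappop(heap)
    return -heap[0]


def by_min_heap(sums: List[int], k: int) -> int:
    """Method 3 selection: min heap capped at size k."""
    heap: List[int] = []
    for s in sums:
        heapq.heappush(heap, s)
        if len(heap) > k:
            heapq.heappop(heap)
    return heap[0] if len(heap) == k else -1


random.seed(0)
for _ in range(1000):
    # A small value range makes equal level sums (ties) common
    sums = [random.randint(1, 10) for _ in range(random.randint(1, 8))]
    k = random.randint(1, 10)
    assert by_sorting(sums, k) == by_max_heap(sums, k) == by_min_heap(sums, k), (sums, k)
print("all three selection strategies agree")
```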