Problem#
Given the root of a binary tree, return the most frequent subtree sum. If there is a tie, return all the values with the highest frequency in any order.
The subtree sum of a node is defined as the sum of all the node values formed by the subtree rooted at that node (including the node itself).
Examples#
Example 1:
  5
 / \
2  -3
Input: root = [5,2,-3]
Output: [2,-3,4]
Example 2:
  5
 / \
2  -5
Input: root = [5,2,-5]
Output: [2]
Explanation: The subtree sums for the subtrees rooted at 5, 2, and -5 are [2, 2, -5], so 2 is the most frequent sum.
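The sums in both examples can be checked mechanically. The sketch below is only an illustration: `TreeNode` is a hypothetical minimal node class modeled on the usual judge-provided definition, and `all_subtree_sums` is a helper name introduced here, not part of the problem statement.

```python
from typing import List, Optional


class TreeNode:
    """Hypothetical minimal node type (assumed, not given in the problem text)."""

    def __init__(self, val: int = 0,
                 left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def all_subtree_sums(root: Optional[TreeNode]) -> List[int]:
    """Return the subtree sum of every node, listed in postorder."""
    sums: List[int] = []

    def visit(node: Optional[TreeNode]) -> int:
        if node is None:
            return 0  # an empty subtree contributes nothing
        total = node.val + visit(node.left) + visit(node.right)
        sums.append(total)
        return total

    visit(root)
    return sums


example_1 = TreeNode(5, TreeNode(2), TreeNode(-3))
example_2 = TreeNode(5, TreeNode(2), TreeNode(-5))
print(all_subtree_sums(example_1))  # [2, -3, 4]
print(all_subtree_sums(example_2))  # [2, -5, 2]
```

In Example 1 every sum occurs once, so all three values are returned. In Example 2 the sum 2 occurs twice and wins outright.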
Solution#
Method 1 - Using Postorder Traversal#
To solve this problem, we need to:
1. Calculate the subtree sum for each node in the binary tree.
2. Track the frequency of each subtree sum using a dictionary or a hash map.
3. Identify the subtree sum(s) that appear most frequently.
Approach#
1. Use a helper function to recursively calculate the subtree sum for each node.
2. During the recursion, update the frequency map for each subtree sum.
3. After computing the subtree sums, find the maximum frequency and return the corresponding subtree sums.
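The last step, picking every value that ties for the maximum frequency, can be looked at in isolation. This is a minimal sketch that assumes the subtree sums have already been collected into a plain list. The function name `most_frequent_values` is made up for illustration, and `collections.Counter` stands in for the frequency map.

```python
from collections import Counter
from typing import Iterable, List


def most_frequent_values(sums: Iterable[int]) -> List[int]:
    """Return every value whose count equals the highest count."""
    freq = Counter(sums)
    if not freq:
        return []  # no sums at all, e.g. an empty tree
    max_freq = max(freq.values())
    return [value for value, count in freq.items() if count == max_freq]


print(most_frequent_values([2, -5, 2]))  # [2]
print(most_frequent_values([2, -3, 4]))  # [2, -3, 4]
```

Scanning the map a second time keeps tie handling simple: nothing has to be tracked or reset while the counts are still changing.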
Code#
Java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// TreeNode is assumed to be the standard judge-provided class with val, left, and right fields.
public class Solution {
    public int[] findFrequentTreeSum(TreeNode root) {
        if (root == null) return new int[0];

        // Maps each subtree sum to the number of subtrees that produce it.
        Map<Integer, Integer> freq = new HashMap<>();
        postOrderSum(root, freq);

        // Find the highest frequency.
        int maxFreq = 0;
        for (int count : freq.values()) {
            if (count > maxFreq) {
                maxFreq = count;
            }
        }

        // Collect every sum that reaches the highest frequency.
        List<Integer> ans = new ArrayList<>();
        for (Map.Entry<Integer, Integer> entry : freq.entrySet()) {
            if (entry.getValue() == maxFreq) {
                ans.add(entry.getKey());
            }
        }

        int[] res = new int[ans.size()];
        for (int i = 0; i < ans.size(); i++) {
            res[i] = ans.get(i);
        }
        return res;
    }

    // Returns the sum of the subtree rooted at node and records it in freq.
    private int postOrderSum(TreeNode node, Map<Integer, Integer> freq) {
        if (node == null) return 0;
        int leftSum = postOrderSum(node.left, freq);
        int rightSum = postOrderSum(node.right, freq);
        int totalSum = node.val + leftSum + rightSum;
        freq.put(totalSum, freq.getOrDefault(totalSum, 0) + 1);
        return totalSum;
    }
}
Python
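from collections import defaultdict
from typing import Dict, List, Optional

# TreeNode is assumed to be the standard judge-provided class with val, left, and right attributes.
class Solution:
    def findFrequentTreeSum(self, root: Optional[TreeNode]) -> List[int]:
        # Maps each subtree sum to the number of subtrees that produce it.
        freq: Dict[int, int] = defaultdict(int)

        def subtree_sum(node: Optional[TreeNode]) -> int:
            if not node:
                return 0
            left_sum = subtree_sum(node.left)
            right_sum = subtree_sum(node.right)
            total_sum = node.val + left_sum + right_sum
            freq[total_sum] += 1
            return total_sum

        subtree_sum(root)
        # default=0 covers the empty tree, where freq has no entries.
        max_freq = max(freq.values(), default=0)
        ans = [s for s, c in freq.items() if c == max_freq]
        return ans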
Complexity#
⏰ Time complexity: O(n), where n is the number of nodes in the tree, since we visit each node exactly once.
🧺 Space complexity: O(n). The recursive call stack grows to the height of the tree, which is n for a fully skewed tree, and the frequency map holds up to n distinct sums.
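Because the call stack can grow as deep as the tree, a very skewed tree may exceed Python's default recursion limit. One possible workaround, sketched below, replaces the recursion with an explicit stack. This is an alternative to Method 1 rather than the method itself: `find_frequent_tree_sum_iterative` and the minimal `TreeNode` class are hypothetical names used only for this illustration.

```python
from collections import Counter
from typing import Dict, List, Optional, Tuple


class TreeNode:
    """Hypothetical minimal node type (assumed, hashable by identity)."""

    def __init__(self, val: int = 0,
                 left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def find_frequent_tree_sum_iterative(root: Optional[TreeNode]) -> List[int]:
    if root is None:
        return []

    freq: Counter = Counter()
    # Finished subtree sums, keyed by node; a missing child counts as 0.
    subtotal: Dict[Optional[TreeNode], int] = {None: 0}
    # Each entry is (node, children_done): a node is pushed once before its
    # children and processed when it is popped the second time.
    stack: List[Tuple[Optional[TreeNode], bool]] = [(root, False)]

    while stack:
        node, children_done = stack.pop()
        if node is None:
            continue
        if children_done:
            total = node.val + subtotal[node.left] + subtotal[node.right]
            subtotal[node] = total
            freq[total] += 1
        else:
            stack.append((node, True))
            stack.append((node.right, False))
            stack.append((node.left, False))

    max_freq = max(freq.values())
    return [s for s, c in freq.items() if c == max_freq]


print(find_frequent_tree_sum_iterative(TreeNode(5, TreeNode(2), TreeNode(-5))))  # [2]
```

The asymptotic cost is unchanged at O(n) time and O(n) space, but the extra memory now lives on the heap rather than the interpreter's call stack.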