Convert binary tree to its Sum tree
Medium | Updated: Aug 2, 2025
Problem
Given a binary tree, write an algorithm to convert it into its Sum tree.
Definition of Sum Tree
A Sum tree of a binary tree is a tree with the same shape in which each node stores the sum of the values of all nodes in its left and right subtrees of the original tree. The node's own value is not included. In particular, every leaf of the original tree becomes 0 in the Sum tree, because both of its subtrees are empty.
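Stated formally, the definition might read as follows. The notation is introduced here only for illustration: val(u) is a node's original value, ℓ(v) and r(v) are the left and right children of v, and T(x) is the set of nodes in the subtree rooted at x (empty when x is absent).

```latex
\mathrm{new}(v) \;=\; \sum_{u \in T(\ell(v))} \mathrm{val}(u) \;+\; \sum_{u \in T(r(v))} \mathrm{val}(u)
```

For a leaf both sums range over empty sets, which is why leaves become 0.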
Examples
Example 1
Input:
        10
      /    \
     2      6
    / \    / \
   1   3  4   5
Output:
        21
      /    \
     4      9
    / \    / \
   0   0  0   0
Explanation:
- Node 10 becomes 2+6+1+3+4+5 = 21.
- Node 2 becomes 1+3 = 4.
- Node 6 becomes 4+5 = 9.
- Leaf nodes 1, 3, 4, 5 are replaced with 0.
Example 2
Input:
        5
      /   \
     3     8
    / \
   2   1
Output:
        14
      /    \
     3      0
    / \
   0   0
Explanation:
- Node 5 becomes 3+8+2+1 = 14.
- Node 3 becomes 2+1 = 3.
- Leaf nodes 8, 2, 1 are replaced with 0.
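Both input/output pairs above can be checked mechanically against the definition. The sketch below is an illustration, not part of the original solution: it assumes a hypothetical nested-tuple tree format `(value, left, right)` with `None` for an empty subtree, and the names `subtree_total`, `is_sum_tree_of`, and `leaf` are made up for this example.

```python
def subtree_total(tree):
    """Sum of all values in a tuple tree (value, left, right), or 0 for None."""
    if tree is None:
        return 0
    value, left, right = tree
    return value + subtree_total(left) + subtree_total(right)


def is_sum_tree_of(original, converted):
    """True if `converted` has the same shape as `original` and every node
    holds the sum of the original left and right subtrees."""
    if original is None or converted is None:
        # Shapes match only if both subtrees are empty here.
        return original is None and converted is None
    _, o_left, o_right = original
    c_value, c_left, c_right = converted
    return (
        c_value == subtree_total(o_left) + subtree_total(o_right)
        and is_sum_tree_of(o_left, c_left)
        and is_sum_tree_of(o_right, c_right)
    )


def leaf(value):
    """Hypothetical shorthand for a node with no children."""
    return (value, None, None)


# Example 1 and Example 2 from the text, in the tuple format.
example1_in = (10, (2, leaf(1), leaf(3)), (6, leaf(4), leaf(5)))
example1_out = (21, (4, leaf(0), leaf(0)), (9, leaf(0), leaf(0)))
example2_in = (5, (3, leaf(2), leaf(1)), leaf(8))
example2_out = (14, (3, leaf(0), leaf(0)), leaf(0))

print(is_sum_tree_of(example1_in, example1_out))  # True
print(is_sum_tree_of(example2_in, example2_out))  # True
```

Because the check re-sums every subtree, it takes O(n·h) time; that is fine for testing small examples but slower than the single-pass postorder method.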
Solution
Method 1 - Using Postorder Traversal
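A node's new value depends on the sums of its subtrees, so both children must be processed before the node itself. Postorder traversal gives exactly that order and lets the conversion run in a single pass.
- Traverse the tree recursively in postorder (left subtree → right subtree → node itself).
- For each node:
  - Obtain the sums of the original values in its left and right subtrees from the recursive calls.
  - Save the node's original value, then overwrite it with the sum of the two subtree sums.
  - Return the original value plus the new value (that is, the original sum of the whole subtree) to the parent, so the parent can compute its own value.
- Leaf nodes need no special handling: both of their subtrees are empty and contribute 0, so they become 0 automatically.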
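The steps above amount to the following recurrence, written with notation introduced here for illustration: S(x) is the sum of the original values in the subtree rooted at x (0 for an empty subtree), val(v) is the original value of v, and ℓ(v), r(v) are its children. The right-hand block traces the recurrence on Example 1, labelling nodes by their original values.

```latex
\begin{aligned}
S(\varnothing) &= 0 \\
\mathrm{new}(v) &= S(\ell(v)) + S(r(v)) \\
S(v) &= \mathrm{val}(v) + \mathrm{new}(v)
\end{aligned}
\qquad
\begin{aligned}
\mathrm{new}(2) &= S(1) + S(3) = 1 + 3 = 4, & S(2) &= 2 + 4 = 6 \\
\mathrm{new}(6) &= S(4) + S(5) = 4 + 5 = 9, & S(6) &= 6 + 9 = 15 \\
\mathrm{new}(10) &= S(2) + S(6) = 6 + 15 = 21, & S(10) &= 10 + 21 = 31
\end{aligned}
```

Here new(v) is the value written into the node, and S(v) is the value returned to the parent.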
Code
Java
// Minimal binary tree node used by the solution
class TreeNode {
    int value;
    TreeNode left, right;

    TreeNode(int value) {
        this.value = value;
    }
}

class Solution {
    // Converts the tree rooted at root into its Sum tree, in place
    public void convertToSumTree(TreeNode root) {
        helper(root);
    }

    // Rewrites node's subtree as a Sum tree and returns the ORIGINAL sum of that subtree
    private int helper(TreeNode node) {
        // Base case: an empty subtree contributes 0
        if (node == null) {
            return 0;
        }
        // Postorder: process the left and right subtrees first
        int leftSum = helper(node.left);
        int rightSum = helper(node.right);
        // Save the original value before overwriting it
        int originalValue = node.value;
        // The node now holds the sum of its original left and right subtrees
        node.value = leftSum + rightSum;
        // Report the original subtree total (own value + both subtrees) to the parent
        return originalValue + node.value;
    }

    // Example usage (Example 1)
    public static void main(String[] args) {
        TreeNode root = new TreeNode(10);
        root.left = new TreeNode(2);
        root.right = new TreeNode(6);
        root.left.left = new TreeNode(1);
        root.left.right = new TreeNode(3);
        root.right.left = new TreeNode(4);
        root.right.right = new TreeNode(5);
        Solution solution = new Solution();
        solution.convertToSumTree(root);
        // Prints "21 4 9"
        System.out.println(root.value + " " + root.left.value + " " + root.right.value);
    }
}
Python
class TreeNode:
    # Minimal binary tree node used by the solution
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


class Solution:
    def convertToSumTree(self, root):
        def helper(node):
            # Base case: an empty subtree contributes 0
            if node is None:
                return 0
            # Postorder: process the left and right subtrees first
            left_sum = helper(node.left)
            right_sum = helper(node.right)
            # Save the original value before overwriting it
            original_value = node.value
            # The node now holds the sum of its original left and right subtrees
            node.value = left_sum + right_sum
            # Report the original subtree total to the parent
            return original_value + node.value

        helper(root)


# Example usage (Example 1)
root = TreeNode(10, TreeNode(2, TreeNode(1), TreeNode(3)), TreeNode(6, TreeNode(4), TreeNode(5)))
solution = Solution()
solution.convertToSumTree(root)
print(root.value, root.left.value, root.right.value)  # 21 4 9
Complexity
- ⏰ Time complexity: O(n), where n is the number of nodes in the binary tree. The postorder traversal visits each node exactly once and does constant work per node.
- 🧺 Space complexity: O(h), where h is the height of the tree, for the recursion stack. This is O(log n) for a balanced tree but O(n) in the worst case for a skewed tree.
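Because the recursion depth equals the tree height, a deeply skewed tree can exceed Python's default recursion limit (1000). The sketch below is one way around that and is a variant, not the method described above: the same postorder logic driven by an explicit stack. It assumes the same `TreeNode` shape (`value`, `left`, `right`) as the Python solution; the function name `convert_to_sum_tree_iterative` is hypothetical, and it returns the original total of the tree so callers can sanity-check the result.

```python
from typing import Optional


class TreeNode:
    """Assumed node shape, matching the fields used in the Python solution."""

    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def convert_to_sum_tree_iterative(root: Optional[TreeNode]) -> int:
    """Convert `root` in place to its Sum tree without recursion.

    Returns the sum of all ORIGINAL values (0 for an empty tree).
    """
    if root is None:
        return 0
    totals = {}  # finished node -> original sum of its subtree
    stack = [(root, False)]  # (node, children_already_processed)
    while stack:
        node, children_done = stack.pop()
        if children_done:
            # Both children are finished: same work as the recursive helper.
            # A missing child (None) is never a key, so pop() falls back to 0.
            original = node.value
            node.value = totals.pop(node.left, 0) + totals.pop(node.right, 0)
            totals[node] = original + node.value
        else:
            # Re-push the node, then its children, so children finish first.
            stack.append((node, True))
            if node.right is not None:
                stack.append((node.right, False))
            if node.left is not None:
                stack.append((node.left, False))
    return totals.pop(root)


# Example 1 from the text.
root = TreeNode(10, TreeNode(2, TreeNode(1), TreeNode(3)), TreeNode(6, TreeNode(4), TreeNode(5)))
total = convert_to_sum_tree_iterative(root)
print(total, root.value, root.left.value, root.right.value)  # 31 21 4 9
```

Extra space stays O(h): the stack and the `totals` dictionary only hold a bounded number of entries per level of the current root-to-node path.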