Problem

Given the root of a binary tree, replace the value of each node in the tree with the sum of all its cousins’ values.

Two nodes of a binary tree are cousins if they have the same depth but different parents.

Return the root of the modified tree.

Note that the depth of a node is the number of edges in the path from the root node to it.
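The definition can be applied literally. This is slow, but it is useful as a reference when testing the faster methods below. The following is a minimal O(n²) sketch in Python. It assumes a LeetCode-style `TreeNode` class, which is defined in the block because the original text does not include it. The function name `cousin_sums_brute_force` is hypothetical. The function returns a map from each node to the sum of its cousins' original values and does not modify the tree.

```python
from collections import deque
from typing import Dict, Optional


class TreeNode:
    # Minimal node class mirroring LeetCode's definition (assumed here).
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def cousin_sums_brute_force(root: Optional[TreeNode]) -> Dict[TreeNode, int]:
    """Map each node to the sum of its cousins' original values, O(n^2)."""
    if root is None:
        return {}
    # Record (depth, parent) for every node with a BFS; the root's parent is None.
    info = {}
    queue = deque([(root, 0, None)])
    while queue:
        node, depth, parent = queue.popleft()
        info[node] = (depth, parent)
        for child in (node.left, node.right):
            if child is not None:
                queue.append((child, depth + 1, node))
    # Apply the definition directly: same depth, different parent.
    result = {}
    for a, (da, pa) in info.items():
        result[a] = sum(b.val for b, (db, pb) in info.items()
                        if db == da and pb is not pa)
    return result
```

A node is never counted as its own cousin because it shares its parent with itself. The root is the only node at depth 0, so its sum is 0.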

Examples

Example 1:

Input:

      5
     / \
    4   9
   / \   \
  1  10   7

Output:

      0
     / \
    0   0
   / \   \
  7   7   11

Input: root = [5,4,9,1,10,null,7]
Output: [0,0,0,7,7,null,11]
Explanation: The diagram above shows the initial binary tree and the binary tree after changing the value of each node.
- Node with value 5 does not have any cousins so its sum is 0.
- Node with value 4 does not have any cousins so its sum is 0.
- Node with value 9 does not have any cousins so its sum is 0.
- Node with value 1 has a cousin with value 7 so its sum is 7.
- Node with value 10 has a cousin with value 7 so its sum is 7.
- Node with value 7 has cousins with values 1 and 10 so its sum is 11.

Example 2:

Input:

    3
   / \
  1   2

Output:

    0
   / \
  0   0

Input: root = [3,1,2]
Output: [0,0,0]
Explanation: The diagram above shows the initial binary tree and the binary tree after changing the value of each node.
- Node with value 3 does not have any cousins so its sum is 0.
- Node with value 1 does not have any cousins so its sum is 0.
- Node with value 2 does not have any cousins so its sum is 0.
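The `Input`/`Output` arrays in the examples use LeetCode's level-order serialization, where `null` marks a missing child. To run the examples locally, a pair of conversion helpers is convenient. The sketch below assumes a LeetCode-style `TreeNode` class, and the names `build_tree` and `to_list` are hypothetical. Python's `None` stands in for `null`.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    # Minimal LeetCode-style node (assumed definition).
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from level-order form, e.g. [5, 4, 9, 1, 10, None, 7]."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two entries are this node's left and right children.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def to_list(root: Optional[TreeNode]) -> List[Optional[int]]:
    """Serialize back to level-order form, dropping trailing None entries."""
    out: List[Optional[int]] = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            out.append(None)
            continue
        out.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while out and out[-1] is None:
        out.pop()
    return out
```

For example, `to_list(build_tree([5, 4, 9, 1, 10, None, 7]))` gives back `[5, 4, 9, 1, 10, None, 7]`.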

Similar Problem

Cousins in Binary Tree

Solution

Method 1 - BFS - Two Pass

Here is the approach:

  1. Initialization:

    • Create a queue for level-order traversal (BFS).
    • Create a list to store the sum of values at each level (levelSums).
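    • Create a map (siblingSums) that stores, for each node, the sum of its own value and its sibling's value, i.e. the total of all children of its parent. The root has no parent, so its entry is its own value.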
  2. First BFS Traversal (Calculate Level Sums and Sibling Sums):

    • Initialize the queue with the root node and set siblingSums[root] = root.val.
    • For each level in the tree:
      • Add up the node values of the current level (levelSum).
      • For each node, enqueue its children and compute the sum of its left and right child values (siblingSum).
      • Store that siblingSum in siblingSums for each existing child.
      • Append levelSum to the levelSums list.
  3. Second BFS Traversal (Update Node Values with Cousin Sums):

    • Reinitialize the queue with the root node.
    • For each level in the tree:
      • Retrieve the level sum for the current level from levelSums.
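      • For each node, compute cousinSum = levelSum - siblingSums[node]. This removes the node itself and its sibling from the level total, leaving exactly the nodes at the same depth with a different parent.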
      • Update the node’s value with cousinSum.
  4. Output the Modified Tree:

    • The tree is updated in place, and the modified root is returned.
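Applied to Example 1, the bookkeeping works out as follows. Here S_d is the level sum at depth d, and siblingSums(v) is the sum of v and its sibling (for the root, its own value), all computed from the original values:

```latex
\begin{aligned}
\mathrm{cousinSum}(v) &= S_{d(v)} - \mathrm{siblingSums}(v),
  \qquad S_d = \sum_{u \,:\, d(u) = d} \mathrm{val}(u) \\[4pt]
S_0 &= 5, & \mathrm{cousinSum}(5) &= 5 - 5 = 0 \\
S_1 &= 4 + 9 = 13, & \mathrm{cousinSum}(4) = \mathrm{cousinSum}(9) &= 13 - (4 + 9) = 0 \\
S_2 &= 1 + 10 + 7 = 18, & \mathrm{cousinSum}(1) = \mathrm{cousinSum}(10) &= 18 - (1 + 10) = 7 \\
 & & \mathrm{cousinSum}(7) &= 18 - 7 = 11
\end{aligned}
```

These values match the expected output [0,0,0,7,7,null,11].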

Code

Java
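import java.util.*;

// TreeNode is the standard LeetCode definition (int val; TreeNode left, right).
public class Solution {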
    public TreeNode replaceValueInTree(TreeNode root) {
        // If the tree is empty, return null
        if (root == null) return null;

        // Initialize a queue for BFS traversal
        Queue<TreeNode> queue = new LinkedList<>();
        // List to store the sum of values at each level
        List<Integer> levelSums = new ArrayList<>();
        // Map to track the sum of sibling values for each node
        Map<TreeNode, Integer> siblingSums = new HashMap<>();
        
        // Start BFS with the root node
        queue.offer(root);
        siblingSums.put(root, root.val);
        
        // First BFS to calculate level sums and track sibling sums
        while (!queue.isEmpty()) {
            int levelSize = queue.size(); // Number of nodes at the current level
            int levelSum = 0; // Sum of node values at the current level
            
            // Traverse all nodes at the current level
            for (int i = 0; i < levelSize; i++) {
                TreeNode node = queue.poll();
                levelSum += node.val;
                int siblingSum = 0;
                
                // If the node has a left child, process it
                if (node.left != null) {
                    queue.offer(node.left);
                    siblingSum += node.left.val;
                }
                
                // If the node has a right child, process it
                if (node.right != null) {
                    queue.offer(node.right);
                    siblingSum += node.right.val;
                }

                // Update sibling sum for children
                if (node.left != null) {
                    siblingSums.put(node.left, siblingSum);
                }
                if (node.right != null) {
                    siblingSums.put(node.right, siblingSum);
                }
            }
            
            // Store the level sum
            levelSums.add(levelSum);
        }
        
        // BFS to modify the tree with cousin sums
        queue.clear();
        queue.offer(root);
        int level = 0;
        
        while (!queue.isEmpty()) {
            int levelSize = queue.size(); // Number of nodes at the current level
            int levelSum = levelSums.get(level); // Sum of node values at the current level
            
            // Traverse all nodes at the current level
            for (int i = 0; i < levelSize; i++) {
                TreeNode node = queue.poll();
                
                // Calculate the cousin sum
                int cousinSum = levelSum - siblingSums.get(node);
                
                // Replace the node's value with the cousin sum
                node.val = cousinSum;
                
                // Process the left child
                if (node.left != null) {
                    queue.offer(node.left);
                }
                
                // Process the right child
                if (node.right != null) {
                    queue.offer(node.right);
                }
            }
            
            level++;
        }
        
        return root;
    }
}
Python
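from collections import deque
from typing import Optional

# TreeNode is the standard LeetCode definition (val, left, right).
class Solution:
    def replaceValueInTree(self, root: Optional[TreeNode]) -> Optional[TreeNode]: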
        if not root:
            return None
        
        queue = deque([root])
        level_sums = []
        sibling_sums = {root: root.val}
        
        # First BFS to calculate level sums and track sibling sums
        while queue:
            level_size = len(queue)
            level_sum = 0
            
            for _ in range(level_size):
                node = queue.popleft()
                level_sum += node.val
                sibling_sum = 0
                
                if node.left:
                    queue.append(node.left)
                    sibling_sum += node.left.val
                
                if node.right:
                    queue.append(node.right)
                    sibling_sum += node.right.val
                
                if node.left:
                    sibling_sums[node.left] = sibling_sum
                
                if node.right:
                    sibling_sums[node.right] = sibling_sum
            
            level_sums.append(level_sum)
        
        # BFS to modify the tree with cousin sums
        queue.append(root)
        level = 0
        
        while queue:
            level_size = len(queue)
            level_sum = level_sums[level]
            
            for _ in range(level_size):
                node = queue.popleft()
                
                # Calculate the cousin sum
                cousin_sum = level_sum - sibling_sums[node]
                
                node.val = cousin_sum
                
                if node.left:
                    queue.append(node.left)
                
                if node.right:
                    queue.append(node.right)
            
            level += 1
        
        return root

Complexity

  • ⏰ Time complexity: O(n), where n is the number of nodes in the tree. Each node is visited once per pass, so twice in total.
  • 🧺 Space complexity: O(n) for the queue, the level sums, and the sibling-sum map.

Method 2 - BFS + Kids Sum

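  1. Traverse the tree level by level. Before expanding a level, keep a copy of its nodes (the parents), then sum the values of all their children, i.e. the total of the next level.
  2. For each parent, subtract the sum of its own children's values from that next-level total. The result is the cousin sum for each of those children, so write it to both of them. The root is set to 0 up front because it has no cousins.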

Code

Java
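import java.util.*;

// TreeNode is the standard LeetCode definition (int val; TreeNode left, right).
public class Solution {
    public TreeNode replaceValueInTree(TreeNode root) {
        if (root == null) return null;
        root.val = 0; // The root has no cousins
        Queue<TreeNode> q = new LinkedList<>();
        q.offer(root);
        while (!q.isEmpty()) {
            // Snapshot of the current level: these nodes are the parents
            List<TreeNode> parents = new ArrayList<>(q);
            // Sum of the original values on the next level
            int sum = 0;
            for (int sz = q.size(); sz > 0; --sz) {
                TreeNode n = q.poll();
                for (TreeNode kid : new TreeNode[]{n.left, n.right}) {
                    if (kid != null) {
                        q.offer(kid);
                        sum += kid.val;
                    }
                }
            }
            // Each child's new value is the next-level sum minus the values
            // of all children of the same parent (itself and its sibling)
            for (TreeNode n : parents) {
                int replacedVal = sum;
                for (TreeNode kid : new TreeNode[]{n.left, n.right}) {
                    if (kid != null) {
                        replacedVal -= kid.val;
                    }
                }
                for (TreeNode kid : new TreeNode[]{n.left, n.right}) {
                    if (kid != null) {
                        kid.val = replacedVal;
                    }
                }
            }
        }
        return root;
    }
}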
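For comparison with Method 1, the same idea can be sketched in Python. This is an illustrative rendering of the Java code above, not part of the original solution. It assumes a LeetCode-style `TreeNode` (defined in the block), and the function name `replace_value_in_tree` is hypothetical. The key detail is ordering. The next-level sum is computed before any child is rewritten, so it always uses original values.

```python
from collections import deque
from typing import Optional


class TreeNode:
    # Minimal LeetCode-style node (assumed definition).
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def replace_value_in_tree(root: Optional[TreeNode]) -> Optional[TreeNode]:
    """Single BFS: each parent rewrites its own children (Method 2)."""
    if root is None:
        return None
    root.val = 0  # the root never has cousins
    queue = deque([root])
    while queue:
        parents = list(queue)  # snapshot of the current level
        # Sum of the original values on the next level (children not yet rewritten).
        next_sum = 0
        for _ in range(len(queue)):
            node = queue.popleft()
            for kid in (node.left, node.right):
                if kid is not None:
                    queue.append(kid)
                    next_sum += kid.val
        # Each child gets next_sum minus the values of itself and its sibling.
        for parent in parents:
            kids = [k for k in (parent.left, parent.right) if k is not None]
            family = sum(k.val for k in kids)
            for kid in kids:
                kid.val = next_sum - family
    return root
```

A parent's own value may already have been overwritten when its level is processed. That is harmless because only the children's original values are read.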

Complexity

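  • ⏰ Time complexity: O(n). Each node is enqueued and polled once, and each child's value is read a constant number of times.
  • 🧺 Space complexity: O(w) for the queue and the copy of the current level, where w is the maximum width of the tree; this is O(n) in the worst case.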