Problem

You are given the root of a binary tree with n nodes. Each node is assigned a unique value from 1 to n. You are also given an array queries of size m.

You have to perform m independent queries on the tree where in the ith query you do the following:

  • Remove the subtree rooted at the node with the value queries[i] from the tree. It is guaranteed that queries[i] will not be equal to the value of the root.

Return an array answer of size m where answer[i] is the height of the tree after performing the ith query.

Note:

  • The queries are independent, so the tree returns to its initial state after each query.
  • The height of a tree is the number of edges in the longest simple path from the root to some node in the tree.
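Under this convention an empty tree has height -1 and a single node has height 0. Here is a minimal Python sketch of the convention. `Node` and `height` are illustrative names, not part of the problem's interface.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    """Minimal binary tree node (illustrative stand-in for LeetCode's TreeNode)."""
    val: int
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def height(node: Optional[Node]) -> int:
    """Height measured in edges: -1 for an empty tree, 0 for a lone node."""
    if node is None:
        return -1
    return 1 + max(height(node.left), height(node.right))


# A chain 1 -> 2 -> 3 has two edges on its longest root-to-node path.
chain = Node(1, left=Node(2, right=Node(3)))
print(height(None), height(Node(1)), height(chain))  # -1 0 2
```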

Examples

Example 1:

---
title: Input tree
---
 graph TD;
     A1((1))
     B3((3))
     C4((4)):::query
     D2((2))
     E6((6))
     F5((5))
     G7((7))
     Null3_1(("null")):::null
     Null5_1(("null")):::null
     
     A1 --> B3
     A1 --> C4
     B3 --> D2
     B3 --> Null3_1
     C4 --> E6
     C4 --> F5
     F5 --> Null5_1
     F5 --> G7
     
     classDef query fill:#f96,stroke:#333,stroke-width:4px;
     classDef null fill:#eee,stroke:#999,stroke-width:2px,stroke-dasharray: 5,5;
  
---
title: Tree after removing query node
---
 graph TD;
     A1((1))
     B3((3))
     D2((2))
     Null1_1(("null")):::null
     Null3_1(("null")):::null
     
     A1 --> B3 & Null1_1
     B3 --> D2 & Null3_1

classDef null fill:#eee,stroke:#999,stroke-width:2px,stroke-dasharray: 5,5;
  
Input: root = [1,3,4,2,null,6,5,null,null,null,null,null,7], queries = [4]
Output: [2]
Explanation: The diagram above shows the tree after removing the subtree rooted at node with value 4.
The height of the tree is 2 (The path 1 -> 3 -> 2).
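The `root` arrays use LeetCode's level-order encoding, where `null` marks a missing child. The minimal Python sketch below decodes that format and checks Example 1 by brute force. `build_tree` and `height_without` are hypothetical helpers written only for this illustration.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Decode a level-order list where None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # Consume up to two entries for this node: left child, then right child.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def height_without(node: Optional[TreeNode], removed: Optional[int] = None) -> int:
    """Height in edges, treating the subtree rooted at value `removed` as absent."""
    if node is None or node.val == removed:
        return -1
    return 1 + max(height_without(node.left, removed), height_without(node.right, removed))


root = build_tree([1, 3, 4, 2, None, 6, 5, None, None, None, None, None, 7])
print(height_without(root))     # 3: original tree, path 1 -> 4 -> 5 -> 7
print(height_without(root, 4))  # 2: path 1 -> 3 -> 2
```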

Example 2:

 graph TD;
     A5((5))
     B8((8)):::query
     C9((9))
     D2((2)):::query
     E1((1))
     F3((3)):::query
     G7((7))
     H4((4)):::query
     I6((6))
     
     A5 --> B8
     A5 --> C9
     B8 --> D2
     B8 --> E1
     C9 --> F3
     C9 --> G7
     D2 --> H4
     D2 --> I6
     
     classDef query fill:#f96,stroke:#333,stroke-width:4px;
  
Input: root = [5,8,9,2,1,3,7,4,6], queries = [3,2,4,8]
Output: [3,2,3,2]
Explanation: We have the following queries:
- Removing the subtree rooted at node with value 3. The height of the tree becomes 3 (The path 5 -> 8 -> 2 -> 4).
- Removing the subtree rooted at node with value 2. The height of the tree becomes 2 (The path 5 -> 8 -> 1).
- Removing the subtree rooted at node with value 4. The height of the tree becomes 3 (The path 5 -> 8 -> 2 -> 6).
- Removing the subtree rooted at node with value 8. The height of the tree becomes 2 (The path 5 -> 9 -> 3).
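The same brute-force idea reproduces Example 2. The minimal sketch below builds the tree by hand from the diagram. `height_without` and `brute_force` are illustrative names.

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val, self.left, self.right = val, left, right


def height_without(node, removed):
    """Height in edges, skipping the subtree rooted at value `removed`."""
    if node is None or node.val == removed:
        return -1
    return 1 + max(height_without(node.left, removed), height_without(node.right, removed))


def brute_force(root, queries):
    # Queries are independent: nothing is mutated, the subtree is only skipped.
    return [height_without(root, q) for q in queries]


# Example 2: 5 -> (8 -> (2 -> (4, 6), 1), 9 -> (3, 7))
root = TreeNode(5,
                TreeNode(8, TreeNode(2, TreeNode(4), TreeNode(6)), TreeNode(1)),
                TreeNode(9, TreeNode(3), TreeNode(7)))
print(brute_force(root, [3, 2, 4, 8]))  # [3, 2, 3, 2]
```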

Solution


Method 1 - DFS

We can run a DFS that computes the height of the tree while treating the queried subtree as absent. Here is the approach:

  1. For each query:
    • Perform a DFS from the root, but return -1 as soon as the traversal reaches the node whose value equals queries[i], so that node's entire subtree is skipped.
    • The value the DFS returns for the root is the height of the modified tree.
  2. The tree itself is never modified, so no restoration step is needed between queries.

Code

Java
public class Solution {
    public int[] treeQueries(TreeNode root, int[] queries) {
        int[] result = new int[queries.length];

        // Queries are independent: the tree is never modified, the subtree is only skipped
        for (int i = 0; i < queries.length; i++) {
            result[i] = dfs(root, queries[i]);
        }

        return result;
    }

    // Height in edges of the tree rooted at node, treating the subtree rooted at `block` as removed
    private int dfs(TreeNode node, int block) {
        if (node == null || node.val == block) {
            return -1;
        }
        return 1 + Math.max(dfs(node.left, block), dfs(node.right, block));
    }
}
Python
class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None

class Solution:
    def dfs(self, node, block):
        # Height in edges, treating the subtree rooted at `block` as removed
        if not node or node.val == block:
            return -1
        return 1 + max(self.dfs(node.left, block), self.dfs(node.right, block))

    def treeQueries(self, root, queries):
        # Queries are independent: the tree is never modified, the subtree is only skipped
        result = []
        
        for q in queries:
            result.append(self.dfs(root, q))
        
        return result
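The recursive Python version above can hit CPython's default recursion limit (usually 1000) on a deep, skewed tree. The minimal sketch below implements the same skip-the-subtree DFS with an explicit stack. `height_skipping` and `tree_queries_iterative` are illustrative names, and the per-query cost is still O(n).

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val, self.left, self.right = val, left, right


def height_skipping(root: Optional[TreeNode], block: int) -> int:
    """Iterative DFS: height in edges, ignoring the subtree rooted at `block`."""
    best = -1
    stack = [(root, 0)] if root is not None else []
    while stack:
        node, depth = stack.pop()
        if node.val == block:
            continue  # skip the removed node and everything below it
        best = max(best, depth)
        for child in (node.left, node.right):
            if child is not None:
                stack.append((child, depth + 1))
    return best


def tree_queries_iterative(root: Optional[TreeNode], queries: List[int]) -> List[int]:
    return [height_skipping(root, q) for q in queries]


# A left-leaning chain 1 -> 2 -> ... -> 5000, deeper than the default recursion limit.
root = TreeNode(1)
node = root
for v in range(2, 5001):
    node.left = TreeNode(v)
    node = node.left
print(tree_queries_iterative(root, [2, 4000]))  # [0, 3998]
```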

Complexity

  • ⏰ Time complexity: O(n * m), where n is the number of nodes and m is the number of queries.
    • A single DFS over the tree takes O(n).
    • Each query runs one DFS that visits every node outside the removed subtree, which is O(n) in the worst case.
    • Thus, for m queries, it is O(n * m).
  • 🧺 Space complexity: O(n) in the worst case for the recursion stack (O(h) for a tree of height h, which becomes O(n) for a skewed tree), plus O(m) for the output array.

Method 2 - Precompute depth and height of each node

To improve efficiency, we need to avoid recomputing the height of the entire tree for every query. Instead, we precompute the depth and height of every node once, which is enough to answer each query in constant time.

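Each node has a depth (its distance in edges from the root) and a height (the number of edges on the longest downward path from it to a leaf). The longest root-to-leaf path passing through the node therefore has length depth + height.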
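To make depth and height concrete, here is a minimal Python sketch that records both for every node in one post-order DFS, using Example 2's tree. `depth_and_height` is an illustrative name. The largest depth + height over all nodes equals the tree's height (3 here).

```python
from typing import Dict, Optional, Tuple


class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val, self.left, self.right = val, left, right


def depth_and_height(root: Optional[TreeNode]) -> Tuple[Dict[int, int], Dict[int, int]]:
    """Return {val: depth} and {val: height}, both measured in edges."""
    depth: Dict[int, int] = {}
    height: Dict[int, int] = {}

    def dfs(node: Optional[TreeNode], d: int) -> int:
        if node is None:
            return -1
        depth[node.val] = d
        height[node.val] = 1 + max(dfs(node.left, d + 1), dfs(node.right, d + 1))
        return height[node.val]

    dfs(root, 0)
    return depth, height


# Example 2's tree.
root = TreeNode(5,
                TreeNode(8, TreeNode(2, TreeNode(4), TreeNode(6)), TreeNode(1)),
                TreeNode(9, TreeNode(3), TreeNode(7)))
depth, height = depth_and_height(root)
# Length of the longest root-to-leaf path through each node.
longest_through = {v: depth[v] + height[v] for v in depth}
print(longest_through[8], longest_through[9], longest_through[1])  # 3 2 2
```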

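For example, suppose nodes 4, 5, 6 and 7 all sit at depth 2, where node 4 has height 2 and node 7 has height 1. The longest path through node 4 then has length 2 + 2 = 4, and the longest path through node 7 has length 2 + 1 = 3.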

When a node, say node 4, is removed, its entire subtree goes with it, cutting every path that passes through it. The other nodes at the same depth (its peers, i.e. siblings and cousins) are unaffected, and every remaining node deeper than node 4 hangs below one of them. So the new height is the length of the longest path through a remaining peer, which means locating the peer with the greatest height.

For node 4, that peer is node 7, so the height of the tree after removing node 4 is 3.

To support this, we group nodes by depth and, within each depth, sort them by height, keeping only the top two. Two are enough: if the removed node holds the largest height, the answer comes from the second; otherwise it comes from the first. In the example above, depth 2 holds [4, 5, 6, 7], but we retain only nodes 4 and 7 because they have the largest heights.
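The Java and Python implementations in this method re-sort a small list after each insertion. An equivalent alternative is to group nodes by level first and then keep the two tallest with `heapq.nlargest`. The Python sketch below assumes depths and heights are already known; the values are Example 2's, read off its diagram, and `top_two_per_depth` is an illustrative name.

```python
import heapq
from collections import defaultdict
from typing import Dict, List, Tuple


def top_two_per_depth(depth: Dict[int, int],
                      height: Dict[int, int]) -> Dict[int, List[Tuple[int, int]]]:
    """Group nodes by depth and keep the two (height, val) pairs with the largest height."""
    by_depth: Dict[int, List[Tuple[int, int]]] = defaultdict(list)
    for val, d in depth.items():
        by_depth[d].append((height[val], val))
    # heapq.nlargest returns items in descending order; ties on height fall back to val.
    return {d: heapq.nlargest(2, nodes) for d, nodes in by_depth.items()}


# Depths and heights of Example 2's tree.
depth = {5: 0, 8: 1, 9: 1, 2: 2, 1: 2, 3: 2, 7: 2, 4: 3, 6: 3}
height = {5: 3, 8: 2, 9: 1, 2: 1, 1: 0, 3: 0, 7: 0, 4: 0, 6: 0}
top = top_two_per_depth(depth, height)
print(top[2])  # [(1, 2), (0, 7)]
```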

When a node is removed by a query, we look up its peers and pick the one with the greatest height.

If the removed node is the only node at its depth, it has no peers: every deeper node was in its subtree, so the longest remaining path has length depth - 1 (ending at its parent). Otherwise, the answer is depth plus the largest height among the remaining nodes at that depth, i.e. the second-largest height if the removed node held the largest, and the largest otherwise.
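The query rule can be written as a small function over the precomputed data. In this minimal Python sketch, the `(height, val)` pairs for each depth are assumed to be sorted by descending height, and the numbers are Example 2's. `answer_query` is an illustrative name.

```python
from typing import Dict, List, Tuple


def answer_query(q: int, depth: Dict[int, int],
                 top: Dict[int, List[Tuple[int, int]]]) -> int:
    """Tree height after removing q's subtree, given the top-two (height, val) pairs per depth."""
    d = depth[q]
    peers = top[d]
    if len(peers) == 1:
        # q is alone at depth d, so every deeper node was in its subtree.
        return d - 1
    best_height, best_val = peers[0]
    if best_val == q:
        return d + peers[1][0]  # q held the largest height: use the runner-up
    return d + best_height


# Precomputed data for Example 2 (depths, and top-two lists by depth).
depth = {5: 0, 8: 1, 9: 1, 2: 2, 1: 2, 3: 2, 7: 2, 4: 3, 6: 3}
top = {0: [(3, 5)], 1: [(2, 8), (1, 9)], 2: [(1, 2), (0, 7)], 3: [(0, 6), (0, 4)]}
print([answer_query(q, depth, top) for q in [3, 2, 4, 8]])  # [3, 2, 3, 2]
```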

So, here is the approach:

  1. DFS Traversal: Compute the depth and height of each node using a DFS traversal.
  2. Storing Node Heights: Store nodes according to their depth and maintain a list of the highest and second highest nodes in terms of height for each depth level.
  3. Query Processing: For each query, determine the new height of the tree after removing the node by considering these pre-stored heights.

Code

Java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class Solution {
    // Maps to store the depth and height of each node
    private Map<Integer, Integer> nodeDepths = new HashMap<>();
    private Map<Integer, Integer> nodeHeights = new HashMap<>();

    // Main function to handle the tree queries
    public int[] treeQueries(TreeNode root, int[] queries) {
        // Compute depth and height for each node
        computeDepthAndHeight(root, 0);

        // Map to store nodes according to their depth and heights
        Map<Integer, List<int[]>> peers = new HashMap<>();
        for (Map.Entry<Integer, Integer> entry : nodeDepths.entrySet()) {
            int val = entry.getKey();
            int depth = entry.getValue();

            // Initialize list for this depth if not already present
            peers.putIfAbsent(depth, new ArrayList<>());
            List<int[]> peerList = peers.get(depth);

            // Add the node's height and value to the list
            peerList.add(new int[]{nodeHeights.get(val), val});

            // Sort the list to keep nodes with the highest heights first
            peerList.sort((a, b) -> b[0] - a[0]);
            // Keep only the top 2 heights for this depth
            if (peerList.size() > 2) {
                peerList.remove(peerList.size() - 1);
            }
        }

        // Array to store the results of each query
        int[] result = new int[queries.length];
        // Process each query
        for (int i = 0; i < queries.length; i++) {
            int q = queries[i];
            int depth = nodeDepths.get(q);

            List<int[]> peerList = peers.get(depth);
            // Determine the longest path after removing the node
            if (peerList.size() == 1) {  // No peer, path length equals depth - 1
                result[i] = depth - 1;
            } else if (peerList.get(0)[1] == q) {  // The removed node has the largest height
                result[i] = peerList.get(1)[0] + depth;
            } else {  // Find the node with the largest height
                result[i] = peerList.get(0)[0] + depth;
            }
        }

        return result;
    }

    // DFS to compute depth and height
    private int computeDepthAndHeight(TreeNode node, int depth) {
        if (node == null) return -1;
        
        // Store the depth of the current node
        nodeDepths.put(node.val, depth);

        // Recursively compute the heights of left and right children
        int leftHeight = computeDepthAndHeight(node.left, depth + 1);
        int rightHeight = computeDepthAndHeight(node.right, depth + 1);
        
        // Compute the current node's height
        int height = 1 + Math.max(leftHeight, rightHeight);

        // Store the height of the current node
        nodeHeights.put(node.val, height);
        return height;
    }
}
Python
import collections
from typing import List, Optional

# TreeNode is assumed to be defined (as on LeetCode, or as in Method 1's Python code).
class Solution:
    def treeQueries(self, root: Optional[TreeNode], queries: List[int]) -> List[int]:
        depth_of = {}   # node value -> depth (edges from the root)
        height_of = {}  # node value -> height (edges on the longest downward path)

        def dfs(node, depth):
            if not node:
                return -1
            depth_of[node.val] = depth
            cur = max(dfs(node.left, depth + 1), dfs(node.right, depth + 1)) + 1
            height_of[node.val] = cur
            return cur

        dfs(root, 0)

        # Group nodes by depth, keeping only the top 2 heights per depth.
        peers = collections.defaultdict(list)
        for val, depth in depth_of.items():
            peers[depth].append((-height_of[val], val))  # Negative heights sort in descending order
            peers[depth].sort()
            if len(peers[depth]) > 2:
                peers[depth].pop()

        ans = []
        for q in queries:
            depth = depth_of[q]
            if len(peers[depth]) == 1:  # No peer: path length equals depth - 1.
                ans.append(depth - 1)
            elif peers[depth][0][1] == q:  # Removed node has the largest height: use the 2nd largest.
                ans.append(-peers[depth][1][0] + depth)
            else:  # Otherwise use the node with the largest height.
                ans.append(-peers[depth][0][0] + depth)
        return ans
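To gain some confidence that Method 2 agrees with Method 1, here is a minimal randomized cross-check in Python. `fast` is a compact re-implementation of Method 2's logic (not the Solution class above), `brute` is Method 1's skip-the-subtree DFS, and `random_tree` is a hypothetical helper that grows trees with unique values 1..n rooted at value 1. The generator is seeded so runs are reproducible.

```python
import random
from typing import List


class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val, self.left, self.right = val, left, right


def random_tree(n: int, rng: random.Random) -> TreeNode:
    """Attach nodes 2..n one at a time to a random free child slot."""
    nodes = [TreeNode(1)]
    for v in range(2, n + 1):
        while True:
            parent = rng.choice(nodes)
            side = rng.choice(("left", "right"))
            if getattr(parent, side) is None:
                child = TreeNode(v)
                setattr(parent, side, child)
                nodes.append(child)
                break
    return nodes[0]


def brute(root: TreeNode, queries: List[int]) -> List[int]:
    def h(node, q):
        if node is None or node.val == q:
            return -1
        return 1 + max(h(node.left, q), h(node.right, q))
    return [h(root, q) for q in queries]


def fast(root: TreeNode, queries: List[int]) -> List[int]:
    depth, height = {}, {}

    def dfs(node, d):
        if node is None:
            return -1
        depth[node.val] = d
        height[node.val] = 1 + max(dfs(node.left, d + 1), dfs(node.right, d + 1))
        return height[node.val]

    dfs(root, 0)
    groups = {}
    for v, d in depth.items():
        groups.setdefault(d, []).append((height[v], v))
    # Keep the two largest (height, val) pairs per depth, tallest first.
    top = {d: sorted(lst, reverse=True)[:2] for d, lst in groups.items()}
    ans = []
    for q in queries:
        d = depth[q]
        peers = top[d]
        if len(peers) == 1:
            ans.append(d - 1)
        elif peers[0][1] == q:
            ans.append(d + peers[1][0])
        else:
            ans.append(d + peers[0][0])
    return ans


rng = random.Random(0)
for _ in range(200):
    n = rng.randint(2, 30)
    root = random_tree(n, rng)
    queries = list(range(2, n + 1))  # every non-root value
    assert brute(root, queries) == fast(root, queries)
print("ok")
```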

Complexity

  • ⏰ Time complexity: O(n + m)
    • DFS traversal to compute height and depth of each node takes O(n)
    • We then iterate over the nodes once, inserting each into the list for its depth. Each list holds at most three entries before being trimmed back to two, so sorting and trimming it is O(1), and this step is O(n) overall.
    • Finally, answering the queries takes O(m) time, since each query is a constant number of map lookups.
  • 🧺 Space complexity: O(n + m)
    • O(n) for the recursion stack during the DFS (O(h) for a tree of height h, up to O(n) for a skewed tree), the depth and height maps, and the per-depth lists (at most two entries per level).
    • O(m) for the output array.