Problem
You are given the root of a binary tree with n nodes. Each node is assigned a unique value from 1 to n. You are also given an array queries of size m.
You have to perform m independent queries on the tree where, in the ith query, you do the following:
- Remove the subtree rooted at the node with the value queries[i] from the tree. It is guaranteed that queries[i] will not be equal to the value of the root.
Return an array answer of size m where answer[i] is the height of the tree after performing the ith query.
Note:
- The queries are independent, so the tree returns to its initial state after each query.
- The height of a tree is the number of edges in the longest simple path from the root to some node in the tree.
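The inputs in the examples use LeetCode-style level-order serialization, where null marks a missing child. As a minimal sketch (build_tree and height are illustrative helper names, not part of the problem), the following Python rebuilds a tree from such a list and measures its height in edges, so a single node has height 0 and an empty tree has height -1:

```python
from collections import deque


class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val, self.left, self.right = val, left, right


def build_tree(values):
    """Build a tree from a level-order list in which None marks a missing child."""
    nodes = [None if v is None else TreeNode(v) for v in values]
    slots = deque(nodes[1:])  # child slots, handed out two per real node in order
    for node in nodes:
        if node is not None:  # missing nodes get no child slots
            node.left = slots.popleft() if slots else None
            node.right = slots.popleft() if slots else None
    return nodes[0] if nodes else None


def height(node):
    """Number of edges on the longest path from node down to a leaf (-1 if empty)."""
    if node is None:
        return -1
    return 1 + max(height(node.left), height(node.right))


# Example 1's input tree: the longest path is 1 -> 4 -> 5 -> 7, i.e. 3 edges
root = build_tree([1, 3, 4, 2, None, 6, 5, None, None, None, None, None, 7])
print(height(root))  # 3
```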
Examples
Example 1:
[Figure: Input tree (query node 4 highlighted)]
[Figure: Tree after removing query node]
Input: root = [1,3,4,2,null,6,5,null,null,null,null,null,7], queries = [4]
Output: [2]
Explanation: The diagram above shows the tree after removing the subtree rooted at node with value 4.
The height of the tree is 2 (The path 1 -> 3 -> 2).
Example 2:
[Figure: Input tree (query nodes 8, 2, 3 and 4 highlighted)]
Input: root = [5,8,9,2,1,3,7,4,6], queries = [3,2,4,8]
Output: [3,2,3,2]
Explanation: We have the following queries:
- Removing the subtree rooted at node with value 3. The height of the tree becomes 3 (The path 5 -> 8 -> 2 -> 4).
- Removing the subtree rooted at node with value 2. The height of the tree becomes 2 (The path 5 -> 8 -> 1).
- Removing the subtree rooted at node with value 4. The height of the tree becomes 3 (The path 5 -> 8 -> 2 -> 6).
- Removing the subtree rooted at node with value 8. The height of the tree becomes 2 (The path 5 -> 9 -> 3).
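Each explanation names one longest remaining path. A small sketch that recovers such a path, assuming the same level-order encoding and breaking ties in favour of the left child (longest_path is a hypothetical helper, not part of the required answer); with that tie-break it reproduces the paths listed for both examples:

```python
from collections import deque


class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val, self.left, self.right = val, left, right


def build_tree(values):
    """Build a tree from a level-order list in which None marks a missing child."""
    nodes = [None if v is None else TreeNode(v) for v in values]
    slots = deque(nodes[1:])
    for node in nodes:
        if node is not None:
            node.left = slots.popleft() if slots else None
            node.right = slots.popleft() if slots else None
    return nodes[0] if nodes else None


def longest_path(node, removed):
    """Values on one longest root-to-node path that avoids the removed subtree."""
    if node is None or node.val == removed:
        return []  # a missing or removed node contributes nothing
    left = longest_path(node.left, removed)
    right = longest_path(node.right, removed)
    # Extend the longer child path; ties go to the left child
    return [node.val] + (left if len(left) >= len(right) else right)


root1 = build_tree([1, 3, 4, 2, None, 6, 5, None, None, None, None, None, 7])
print(longest_path(root1, 4))  # [1, 3, 2], so the height is 2

root2 = build_tree([5, 8, 9, 2, 1, 3, 7, 4, 6])
for q in [3, 2, 4, 8]:
    path = longest_path(root2, q)
    print(q, path, len(path) - 1)
```

For Example 2 the loop prints the paths 5 -> 8 -> 2 -> 4, 5 -> 8 -> 1, 5 -> 8 -> 2 -> 6 and 5 -> 9 -> 3, with heights 3, 2, 3 and 2.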
Solution
Method 1 - DFS
We can run a DFS to compute the height of the tree directly. Here is the approach:
- Calculate the initial height of the complete tree using Depth-First Search (DFS). This value serves only as a reference; the queries do not need it.
- For each query:
  - Perform DFS again, but treat the node with value queries[i] as if it were null, so its whole subtree is skipped.
  - The value returned by this DFS is the height of the modified tree.
- No restoration step is needed: the DFS never modifies the tree, so every query starts from the original tree, as the independence requirement demands.
Code
Java
public class Solution {
    public int[] treeQueries(TreeNode root, int[] queries) {
        int originalHeight = getHeight(root); // reference only; not used below
        int[] result = new int[queries.length];
        for (int i = 0; i < queries.length; i++) {
            result[i] = dfs(root, queries[i]);
        }
        return result;
    }

    // Height in edges, treating the node whose value is `block` as missing
    public int dfs(TreeNode node, int block) {
        if (node == null || node.val == block) {
            return -1;
        }
        return 1 + Math.max(dfs(node.left, block), dfs(node.right, block));
    }

    // Height in edges of the full tree; an empty tree has height -1
    public int getHeight(TreeNode node) {
        if (node == null) return -1;
        return 1 + Math.max(getHeight(node.left), getHeight(node.right));
    }
}
Python
class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None

class Solution:
    def getHeight(self, node):
        # Height in edges; an empty tree has height -1
        if not node:
            return -1
        return 1 + max(self.getHeight(node.left), self.getHeight(node.right))

    def dfs(self, node, block):
        # Same as getHeight, but the node whose value is `block` counts as missing
        if not node or node.val == block:
            return -1
        return 1 + max(self.dfs(node.left, block), self.dfs(node.right, block))

    def treeQueries(self, root, queries):
        originalHeight = self.getHeight(root)  # reference only; not used below
        result = []
        for q in queries:
            result.append(self.dfs(root, q))
        return result
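The recursive dfs above can exceed Python's default recursion limit (1000 frames) on a deep, skewed tree. A minimal sketch of the same per-query idea done level by level instead (height_without and tree_queries are hypothetical names); it still costs O(n) per query, but needs no recursion:

```python
from collections import deque


class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val, self.left, self.right = val, left, right


def build_tree(values):
    """Build a tree from a level-order list in which None marks a missing child."""
    nodes = [None if v is None else TreeNode(v) for v in values]
    slots = deque(nodes[1:])
    for node in nodes:
        if node is not None:
            node.left = slots.popleft() if slots else None
            node.right = slots.popleft() if slots else None
    return nodes[0] if nodes else None


def height_without(root, removed):
    """Height in edges after skipping the subtree rooted at `removed`, without recursion."""
    if root is None or root.val == removed:
        return -1
    level, frontier = -1, [root]
    while frontier:
        level += 1  # every node in `frontier` sits at depth `level`
        frontier = [
            child
            for node in frontier
            for child in (node.left, node.right)
            if child is not None and child.val != removed  # drop the removed subtree
        ]
    return level


def tree_queries(root, queries):
    return [height_without(root, q) for q in queries]


root = build_tree([5, 8, 9, 2, 1, 3, 7, 4, 6])
print(tree_queries(root, [3, 2, 4, 8]))  # [3, 2, 3, 2]
```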
Complexity
- ⏰ Time complexity: O(n * m), where n is the number of nodes and m is the number of queries.
  - Calculating the height of the tree using DFS takes O(n).
  - For each query, the DFS that skips the subtree rooted at the queried node also takes O(n).
  - Thus, for m queries, it is O(n * m).
- 🧺 Space complexity: O(n) for storing the tree structure and the recursion stack during DFS.
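To see where the O(n * m) bound comes from, a small instrumented sketch (visits is a hypothetical helper) counts the nodes one Method 1 DFS enters for a single query. Each count is at most n, so m queries cost at most n * m node visits:

```python
from collections import deque


class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val, self.left, self.right = val, left, right


def build_tree(values):
    """Build a tree from a level-order list in which None marks a missing child."""
    nodes = [None if v is None else TreeNode(v) for v in values]
    slots = deque(nodes[1:])
    for node in nodes:
        if node is not None:
            node.left = slots.popleft() if slots else None
            node.right = slots.popleft() if slots else None
    return nodes[0] if nodes else None


def visits(node, removed):
    """Non-null nodes a Method 1 style DFS enters for one query (the removed node counts once)."""
    if node is None:
        return 0
    if node.val == removed:
        return 1  # the DFS inspects this node and then stops
    return 1 + visits(node.left, removed) + visits(node.right, removed)


root = build_tree([5, 8, 9, 2, 1, 3, 7, 4, 6])  # Example 2, n = 9
queries = [3, 2, 4, 8]                            # m = 4
counts = [visits(root, q) for q in queries]
print(counts, sum(counts), 9 * len(queries))      # [9, 7, 9, 5] 30 36
```

Removing a node with a large subtree makes that query cheaper, but a query on a leaf still visits all n nodes, which is what drives the worst case.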
Method 2 - Precompute depth and height of each node
To improve efficiency, we need to avoid recomputing the height of the entire tree for each query. Instead, we precompute the depth and height of every node once, so that each query can be answered from stored values.
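Each node has a depth (the number of edges from the root to it) and a height (the number of edges on the longest downward path from it to a leaf). The longest root-to-leaf path passing through the node therefore has length depth + height.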
For example, for node 4 in the tree below, the longest path through it has length 4; for node 7, it is 3.
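A minimal sketch of these quantities, using Example 1's tree (depth_and_height is a hypothetical helper): one DFS records each node's depth and height, and the longest root-to-leaf path through a node is their sum. The maximum of that sum over all nodes is the height of the whole tree.

```python
from collections import deque


class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val, self.left, self.right = val, left, right


def build_tree(values):
    """Build a tree from a level-order list in which None marks a missing child."""
    nodes = [None if v is None else TreeNode(v) for v in values]
    slots = deque(nodes[1:])
    for node in nodes:
        if node is not None:
            node.left = slots.popleft() if slots else None
            node.right = slots.popleft() if slots else None
    return nodes[0] if nodes else None


def depth_and_height(root):
    """Map each node value to its depth and its height (both counted in edges)."""
    depth, height = {}, {}

    def dfs(node, d):
        if node is None:
            return -1
        depth[node.val] = d
        height[node.val] = 1 + max(dfs(node.left, d + 1), dfs(node.right, d + 1))
        return height[node.val]

    dfs(root, 0)
    return depth, height


root = build_tree([1, 3, 4, 2, None, 6, 5, None, None, None, None, None, 7])  # Example 1
depth, height = depth_and_height(root)
through = {v: depth[v] + height[v] for v in depth}  # longest root-to-leaf path via v
print(through)
print(max(through.values()))  # 3, the height of the whole tree
```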
When a node, say node 4, is removed, all of its descendants are removed with it, which cuts every path that passes through it. Every node deeper than node 4 descends from some node at node 4's depth, so if node 4 has peers (siblings or cousins at the same depth), the new height is the longest path through one of those peers. Since all peers share the same depth, this means locating the peer with the greatest height.
For node 4, that peer is node 7, resulting in a height of 3.
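Before any pruning, this rule can be checked directly: for a query q at depth d, the new height is the largest d + height over the other nodes at depth d, or d - 1 if there are none. The sketch below (heights_after_removal is a hypothetical name) scans every peer, so a query costs time proportional to the width of its level; keeping only the top two heights per depth removes that cost.

```python
from collections import defaultdict, deque


class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val, self.left, self.right = val, left, right


def build_tree(values):
    """Build a tree from a level-order list in which None marks a missing child."""
    nodes = [None if v is None else TreeNode(v) for v in values]
    slots = deque(nodes[1:])
    for node in nodes:
        if node is not None:
            node.left = slots.popleft() if slots else None
            node.right = slots.popleft() if slots else None
    return nodes[0] if nodes else None


def heights_after_removal(root, queries):
    depth, height = {}, {}

    def dfs(node, d):
        if node is None:
            return -1
        depth[node.val] = d
        height[node.val] = 1 + max(dfs(node.left, d + 1), dfs(node.right, d + 1))
        return height[node.val]

    dfs(root, 0)
    level = defaultdict(list)  # depth -> values of all nodes at that depth
    for v, d in depth.items():
        level[d].append(v)
    answers = []
    for q in queries:
        d = depth[q]
        # Best path through a surviving peer; with no peer, nothing below depth d - 1 survives
        answers.append(max((d + height[p] for p in level[d] if p != q), default=d - 1))
    return answers


root1 = build_tree([1, 3, 4, 2, None, 6, 5, None, None, None, None, None, 7])
root2 = build_tree([5, 8, 9, 2, 1, 3, 7, 4, 6])
print(heights_after_removal(root1, [4, 7]))        # [2, 2]
print(heights_after_removal(root2, [3, 2, 4, 8]))  # [3, 2, 3, 2]
```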
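To support this, we group nodes by depth and, within each depth, sort them by height, retaining only the two largest heights. Two entries are enough: if the removed node holds the largest height, the answer comes from the second entry; otherwise it comes from the first. In the tree above, depth 2 contains nodes [4, 5, 6, 7], but we retain only nodes [4, 7], as they have the largest heights.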
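Keeping the two largest heights per depth is a small grouping step. This sketch uses heapq.nlargest instead of re-sorting after every insertion, and hard-codes the depths and heights of Example 2's tree (computed by hand from that tree) so it runs on its own; top_two_by_depth is a hypothetical name. Entries are (height, value) pairs, so ties in height are broken by the larger node value.

```python
import heapq
from collections import defaultdict

# Depth and height (in edges) of every node in Example 2's tree, keyed by node value
depth = {5: 0, 8: 1, 9: 1, 2: 2, 1: 2, 3: 2, 7: 2, 4: 3, 6: 3}
height = {5: 3, 8: 2, 9: 1, 2: 1, 1: 0, 3: 0, 7: 0, 4: 0, 6: 0}


def top_two_by_depth(depth, height):
    """For each depth, the two (height, value) pairs with the largest heights."""
    groups = defaultdict(list)
    for v, d in depth.items():
        groups[d].append((height[v], v))
    # nlargest returns its result ordered from largest to smallest
    return {d: heapq.nlargest(2, pairs) for d, pairs in groups.items()}


top2 = top_two_by_depth(depth, height)
print(top2[2])  # [(1, 2), (0, 7)]: node 2 leads depth 2, node 7 is the runner-up
```

With a fixed count of 2, nlargest is linear in the number of nodes at each level, so the whole grouping step stays O(n).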
When a node is ‘removed’ by a query, we look up its peers and identify the one with the greatest height.
If the removed node is the only node at its depth, it has no peers: its parent (at depth - 1) remains, and nothing deeper survives, so the height becomes depth - 1. Otherwise, the answer is depth plus the maximum height among the remaining nodes at that depth.
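Given such a per-depth table, each query is answered in O(1). A minimal sketch using a hand-built table for Example 2 (the entries are the top two (height, value) pairs per depth of that tree; answer_query is a hypothetical name):

```python
# For each depth of Example 2's tree: the top two (height, value) pairs, largest first
top2 = {0: [(3, 5)], 1: [(2, 8), (1, 9)], 2: [(1, 2), (0, 7)], 3: [(0, 6), (0, 4)]}
depth = {5: 0, 8: 1, 9: 1, 2: 2, 1: 2, 3: 2, 7: 2, 4: 3, 6: 3}


def answer_query(q, depth, top2):
    """Tree height after removing the subtree rooted at q."""
    d = depth[q]
    peers = top2[d]
    if len(peers) == 1:
        return d - 1  # q is alone at its depth: nothing below depth d - 1 survives
    if peers[0][1] == q:
        return d + peers[1][0]  # q held the largest height: use the runner-up
    return d + peers[0][0]  # otherwise the tallest peer survives


print([answer_query(q, depth, top2) for q in [3, 2, 4, 8]])  # [3, 2, 3, 2]
```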
So, here is the approach:
- DFS Traversal: Compute the depth and height of each node using a DFS traversal.
- Storing Node Heights: Store nodes according to their depth and maintain a list of the highest and second highest nodes in terms of height for each depth level.
- Query Processing: For each query, determine the new height of the tree after removing the node by considering these pre-stored heights.
Code
Java
import java.util.*;

public class Solution {
    // Maps to store the depth and height of each node, keyed by node value
    private Map<Integer, Integer> nodeDepths = new HashMap<>();
    private Map<Integer, Integer> nodeHeights = new HashMap<>();

    // Main function to handle the tree queries
    public int[] treeQueries(TreeNode root, int[] queries) {
        // Compute depth and height for each node
        computeDepthAndHeight(root, 0);

        // Map to store nodes according to their depth and heights
        Map<Integer, List<int[]>> peers = new HashMap<>();
        for (Map.Entry<Integer, Integer> entry : nodeDepths.entrySet()) {
            int val = entry.getKey();
            int depth = entry.getValue();
            // Initialize list for this depth if not already present
            peers.putIfAbsent(depth, new ArrayList<>());
            List<int[]> peerList = peers.get(depth);
            // Add the node's height and value to the list
            peerList.add(new int[]{nodeHeights.get(val), val});
            // Sort the list to keep nodes with the highest heights first
            peerList.sort((a, b) -> b[0] - a[0]);
            // Keep only the top 2 heights for this depth
            if (peerList.size() > 2) {
                peerList.remove(peerList.size() - 1);
            }
        }

        // Array to store the results of each query
        int[] result = new int[queries.length];
        // Process each query
        for (int i = 0; i < queries.length; i++) {
            int q = queries[i];
            int depth = nodeDepths.get(q);
            List<int[]> peerList = peers.get(depth);
            // Determine the longest path after removing the node
            if (peerList.size() == 1) { // No peer, path length equals depth - 1
                result[i] = depth - 1;
            } else if (peerList.get(0)[1] == q) { // The removed node has the largest height
                result[i] = peerList.get(1)[0] + depth;
            } else { // The peer with the largest height survives
                result[i] = peerList.get(0)[0] + depth;
            }
        }
        return result;
    }

    // DFS to compute depth and height
    private int computeDepthAndHeight(TreeNode node, int depth) {
        if (node == null) return -1;
        // Store the depth of the current node
        nodeDepths.put(node.val, depth);
        // Recursively compute the heights of left and right children
        int leftHeight = computeDepthAndHeight(node.left, depth + 1);
        int rightHeight = computeDepthAndHeight(node.right, depth + 1);
        // Compute the current node's height
        int height = 1 + Math.max(leftHeight, rightHeight);
        // Store the height of the current node
        nodeHeights.put(node.val, height);
        return height;
    }
}
Python
import collections
from typing import List, Optional

# TreeNode is assumed to be defined as in Method 1 (LeetCode provides it)
class Solution:
    def treeQueries(self, R: Optional[TreeNode], Q: List[int]) -> List[int]:
        Depth = collections.defaultdict(int)
        Height = collections.defaultdict(int)

        def dfs(node, depth):
            if not node:
                return -1
            Depth[node.val] = depth
            cur = max(dfs(node.left, depth + 1), dfs(node.right, depth + 1)) + 1
            Height[node.val] = cur
            return cur

        dfs(R, 0)
        peers = collections.defaultdict(list)  # Group nodes according to their depth, keep the top 2 heights.
        for val, depth in Depth.items():
            peers[depth].append((-Height[val], val))  # Use negative heights for sorting in descending order
            peers[depth].sort()
            if len(peers[depth]) > 2:
                peers[depth].pop()

        ans = []
        for q in Q:
            depth = Depth[q]
            if len(peers[depth]) == 1:  # No peer, path length equals depth - 1.
                ans.append(depth - 1)
            elif peers[depth][0][1] == q:  # The removed node has the largest height, look for the node with 2nd largest height.
                ans.append(-peers[depth][1][0] + depth)
            else:  # Look for the node with the largest height.
                ans.append(-peers[depth][0][0] + depth)
        return ans
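Both implementations above use a recursive DFS, which in Python can exceed the default recursion limit (1000 frames) on a deep, skewed tree. A hedged, recursion-free sketch of the same method: a BFS records depths and an order in which every parent precedes its children, and a pass over that order in reverse fills in heights bottom-up (tree_queries_iterative is a hypothetical name):

```python
import heapq
from collections import defaultdict, deque


class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val, self.left, self.right = val, left, right


def build_tree(values):
    """Build a tree from a level-order list in which None marks a missing child."""
    nodes = [None if v is None else TreeNode(v) for v in values]
    slots = deque(nodes[1:])
    for node in nodes:
        if node is not None:
            node.left = slots.popleft() if slots else None
            node.right = slots.popleft() if slots else None
    return nodes[0] if nodes else None


def tree_queries_iterative(root, queries):
    depth, height = {}, {}
    order = []  # BFS order: every parent appears before its children
    queue = deque([(root, 0)])
    while queue:
        node, d = queue.popleft()
        depth[node.val] = d
        order.append(node)
        for child in (node.left, node.right):
            if child is not None:
                queue.append((child, d + 1))
    # Walking BFS order backwards visits children before their parent
    for node in reversed(order):
        left_h = height[node.left.val] if node.left else -1
        right_h = height[node.right.val] if node.right else -1
        height[node.val] = 1 + max(left_h, right_h)
    # Keep the top two (height, value) pairs per depth
    groups = defaultdict(list)
    for v, d in depth.items():
        groups[d].append((height[v], v))
    top2 = {d: heapq.nlargest(2, pairs) for d, pairs in groups.items()}
    answers = []
    for q in queries:
        d = depth[q]
        peers = top2[d]
        if len(peers) == 1:
            answers.append(d - 1)  # no peer at this depth
        elif peers[0][1] == q:
            answers.append(d + peers[1][0])  # removed node was the tallest
        else:
            answers.append(d + peers[0][0])  # tallest peer survives
    return answers


root = build_tree([5, 8, 9, 2, 1, 3, 7, 4, 6])
print(tree_queries_iterative(root, [3, 2, 4, 8]))  # [3, 2, 3, 2]
```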
Complexity
- ⏰ Time complexity: O(n + m)
  - The DFS traversal that computes the depth and height of each node takes O(n).
  - Each of the n nodes is then inserted into the list for its depth, which holds at most k = 2 entries after trimming. Keeping such a small structure ordered costs O(log k) per insertion in a priority queue (or O(k log k) when re-sorting a list, as the code does), and since k is a constant this is O(1), so the grouping step is O(n) overall.
  - Finally, each query is answered in O(1) from the stored lists, so answering all queries takes O(m).
- 🧺 Space complexity: O(n + m)
  - O(n) for the recursion stack of the DFS, the height and depth maps, and the sorted lists of nodes at each depth level, all proportional to the number of nodes n.
  - O(m) for the output array of answers.