Problem#
Given a Binary Search Tree (BST) and a key node K, calculate the sum of all node values in the BST except for the nodes adjacent to K. Adjacent nodes are the parent and the immediate children of the key node. The key node itself is still counted.
Examples#
Example 1#
graph TD;
A(8):::yellow --- B(7) & C(10):::blue
B --- D(2)
B ~~~ N1:::hidden
C --- E(9):::yellow & F(13):::yellow
classDef hidden display:none
classDef blue fill:#ADD8E6,stroke:#000,stroke-width:1px;
classDef yellow fill:#FFD700,stroke:#000,stroke-width:1px,color:#000;
Input: root = [8, 7, 10, 2, null, 9, 13], k = 10
Output: 19
Explanation:
Nodes 8, 9, and 13 are adjacent to K = 10, so we exclude them. The remaining nodes, 7, 2, and 10, sum up to 19.
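The arithmetic in this example can be checked directly. The snippet below is only a sanity check on the numbers above, not part of the solution; the variable names are illustrative.

```python
# Values from Example 1, listed level by level (the null placeholder is omitted).
values = [8, 7, 10, 2, 9, 13]

# Nodes adjacent to K = 10: its parent (8) and its two children (9 and 13).
adjacent = [8, 9, 13]

total = sum(values)        # sum of every node in the tree
excluded = sum(adjacent)   # sum of the adjacent nodes
result = total - excluded  # what remains: 7 + 2 + 10

print(total, excluded, result)  # 49 30 19
```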
Solution#
Method 1 - Find the key node#
To compute the sum of the BST while excluding the nodes adjacent to the key node:
1. Identify the adjacent nodes, which are:
   - The parent of the key node (if it exists).
   - The left and right children of the key node (if they exist).
2. Traverse the BST, summing every node except those identified as adjacent to the key node.
The approach relies on a standard tree traversal (DFS or BFS). Once the key node is located, its parent and children are marked, and the traversal skips only their values. The key node itself and the descendants of an excluded node still count, as Example 1 shows: 10, 7, and 2 are all part of the result.
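The first step, collecting the adjacent nodes, can be sketched on its own. This is a minimal illustration that assumes a simple `TreeNode` class with `val`, `left`, and `right` fields; the helper name `find_adjacent` is hypothetical and not part of the listings in this article.

```python
class TreeNode:
    """Minimal node type assumed for this sketch."""
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_adjacent(root, key):
    """Return the nodes adjacent to the node holding `key`, or None if absent.

    Walks down from the root using the BST ordering and remembers the parent.
    """
    parent, node = None, root
    while node is not None and node.val != key:
        parent = node
        node = node.left if key < node.val else node.right
    if node is None:
        return None
    # Keep only the neighbours that actually exist.
    return [n for n in (parent, node.left, node.right) if n is not None]


# Tree from Example 1.
root = TreeNode(8, TreeNode(7, TreeNode(2)), TreeNode(10, TreeNode(9), TreeNode(13)))
print([n.val for n in find_adjacent(root, 10)])  # [8, 9, 13]
```

Because the search follows a single root-to-leaf path, this step costs O(h) time and, written as a loop, no extra stack space.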
Approach#
1. Traverse the BST while calculating the total sum.
2. Identify the key node K during traversal.
3. Gather the adjacent nodes:
   - The parent of the key node.
   - The left and right children of the key node.
4. Exclude the adjacent nodes, either with a second traversal that skips them or by subtracting their values from the total.
5. Return the sum of the non-adjacent nodes.
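The "subtract from the total" variant of these steps can be sketched as follows. This is an alternative formulation, not the listing used in this article: it assumes a simple `TreeNode` class, the function name `sum_excluding_adjacent` is illustrative, and returning 0 when the key is absent is an assumed convention.

```python
class TreeNode:
    """Minimal node type assumed for this sketch."""
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def sum_excluding_adjacent(root, key):
    # Step 1: BST search for the key, tracking the parent on the way down.
    parent, node = None, root
    while node is not None and node.val != key:
        parent = node
        node = node.left if key < node.val else node.right
    if node is None:
        return 0  # assumed convention when the key is not in the tree

    # Step 2: add up the values of the adjacent nodes that exist.
    adjacent_sum = sum(
        n.val for n in (parent, node.left, node.right) if n is not None
    )

    # Step 3: total of the whole tree, then subtract the adjacent values.
    def total(n):
        return 0 if n is None else n.val + total(n.left) + total(n.right)

    return total(root) - adjacent_sum


# Tree from Example 1.
root = TreeNode(8, TreeNode(7, TreeNode(2)), TreeNode(10, TreeNode(9), TreeNode(13)))
print(sum_excluding_adjacent(root, 10))  # 19
```

Subtracting avoids building a set of excluded nodes; at most three values are ever removed from the total.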
Code#
Java
import java.util.HashSet;
import java.util.Set;

// Minimal node type, assumed here because the original listing omits it.
class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;
    TreeNode(int val) { this.val = val; }
}

class Solution {
    // Helper to locate the key node; returns {keyNode, parent}.
    private TreeNode[] findKeyNode(TreeNode node, TreeNode parent, int key) {
        if (node == null) return new TreeNode[] {null, null};
        if (node.val == key) return new TreeNode[] {node, parent};
        // Use the BST ordering to go left or right
        if (key < node.val) {
            return findKeyNode(node.left, node, key);
        } else {
            return findKeyNode(node.right, node, key);
        }
    }

    // Sum all node values, skipping only the nodes in the exclude set.
    // The subtree below an excluded node is still visited.
    private int computeTotalSum(TreeNode node, Set<TreeNode> exclude) {
        if (node == null) return 0;
        int own = exclude.contains(node) ? 0 : node.val;
        return own
            + computeTotalSum(node.left, exclude)
            + computeTotalSum(node.right, exclude);
    }

    public int sumExcludingAdjacent(TreeNode root, int key) {
        // Find the key node and its parent
        TreeNode[] result = findKeyNode(root, null, key);
        TreeNode keyNode = result[0];
        TreeNode parentNode = result[1];
        if (keyNode == null) return 0; // Key not found
        // Adjacent nodes: parent and children. The key node itself stays in the sum.
        Set<TreeNode> excludeNodes = new HashSet<>();
        excludeNodes.add(parentNode); // null when the key is the root; harmless
        excludeNodes.add(keyNode.left);
        excludeNodes.add(keyNode.right);
        // Compute total sum excluding adjacent nodes
        return computeTotalSum(root, excludeNodes);
    }

    public static void main(String[] args) {
        Solution sol = new Solution();
        TreeNode root = new TreeNode(8);
        root.left = new TreeNode(7);
        root.right = new TreeNode(10);
        root.left.left = new TreeNode(2);
        root.right.left = new TreeNode(9);
        root.right.right = new TreeNode(13);
        int key = 10;
        System.out.println(sol.sumExcludingAdjacent(root, key)); // Output: 19
    }
}
Python
# Minimal node type, assumed here because the original listing omits it.
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    def sumExcludingAdjacent(self, root, key):
        # Helper function to locate the key node and its parent.
        def findKeyNode(node, parent, key):
            if not node:
                return None, None
            if node.val == key:
                return node, parent
            # Use the BST ordering to go left or right
            if key < node.val:
                return findKeyNode(node.left, node, key)
            else:
                return findKeyNode(node.right, node, key)

        # Sum all node values, skipping only the nodes in `exclude`.
        # The subtree below an excluded node is still visited.
        def computeTotalSum(node, exclude):
            if not node:
                return 0
            own = 0 if node in exclude else node.val
            return (
                own
                + computeTotalSum(node.left, exclude)
                + computeTotalSum(node.right, exclude)
            )

        # Find the key node and its parent
        keyNode, parentNode = findKeyNode(root, None, key)
        if not keyNode:
            return 0  # Key not found
        # Adjacent nodes: parent and children. The key node itself stays in the sum.
        excludeNodes = {parentNode, keyNode.left, keyNode.right}
        # Compute total sum excluding adjacent nodes
        return computeTotalSum(root, excludeNodes)


# Example usage
if __name__ == "__main__":
    # Construct the tree
    sol = Solution()
    root = TreeNode(8)
    root.left = TreeNode(7)
    root.right = TreeNode(10)
    root.left.left = TreeNode(2)
    root.right.left = TreeNode(9)
    root.right.right = TreeNode(13)
    key = 10
    print(sol.sumExcludingAdjacent(root, key))  # Output: 19
Complexity#
⏰ Time complexity: O(n). Every node is visited once to sum the non-adjacent nodes; locating the key node adds only O(h).
🧺 Space complexity: O(h), where h is the height of the tree, for the recursive calls during traversal. For a skewed tree, h can be as large as n.
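The O(h) recursion depth matters in practice for skewed trees. As a hedged sketch, the summation step can be written with an explicit stack instead of recursion, which sidesteps CPython's recursion limit (1000 by default). The `TreeNode` class and the function name `sum_excluding` are assumptions for illustration only.

```python
class TreeNode:
    """Minimal node type assumed for this sketch."""
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def sum_excluding(root, exclude):
    """Sum all node values except those of nodes in `exclude`, iteratively."""
    total = 0
    stack = [root] if root is not None else []
    while stack:
        node = stack.pop()
        if node not in exclude:
            total += node.val
        # Children are always pushed, even below an excluded node.
        if node.left is not None:
            stack.append(node.left)
        if node.right is not None:
            stack.append(node.right)
    return total


# Degenerate BST: a right-leaning chain 1 -> 2 -> ... -> 5000 (height 5000).
nodes = [TreeNode(v) for v in range(1, 5001)]
for upper, lower in zip(nodes, nodes[1:]):
    upper.right = lower

# Key node 3: its parent is 2 and its only child is 4.
exclude = {nodes[1], nodes[3]}
print(sum_excluding(nodes[0], exclude))  # 12502494
```

A recursive traversal of this chain would need about 5000 nested calls, so the loop form is the safer choice when the tree may be unbalanced.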