Problem

Given a Binary Search Tree (BST) and a key value K, calculate the sum of all node values in the BST except those of the nodes adjacent to the key node. Adjacent nodes are the parent and the immediate (left and right) children of the key node; the key node itself is still included in the sum.

Examples

Example 1

[Figure: the BST for this example. Root 8 has children 7 and 10; 7 has a left child 2; 10 has children 9 and 13. The key node 10 is highlighted in blue, and its adjacent nodes 8, 9, and 13 in yellow.]

Input: root = [8, 7, 10, 2, null, 9, 13], k = 10
Output: 19
Explanation:
Nodes 8, 9, and 13 are adjacent to K = 10 (its parent and its two children), so they are excluded. The key node itself is kept, so the remaining nodes 7, 2, and 10 sum to 7 + 2 + 10 = 19.
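The input list appears to use the LeetCode-style level-order encoding, in which null marks a missing child. Assuming that convention, this sketch shows how such a list maps onto the tree in the figure; TreeNode and build_tree are hypothetical names introduced only for illustration.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    # Hypothetical minimal binary tree node (not defined in the original)
    def __init__(self, val: int) -> None:
        self.val = val
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a level-order list where None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next value (if any) is the left child, the one after it the right child
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


root = build_tree([8, 7, 10, 2, None, 9, 13])
print(root.val, root.left.val, root.right.val)      # 8 7 10
print(root.left.left.val, root.left.right)          # 2 None
print(root.right.left.val, root.right.right.val)    # 9 13
```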

Solution

Method 1 - Find the key node

To compute the sum while excluding the nodes adjacent to the key node:

  • Identify the adjacent nodes, which are:
    • The parent of the key node (if it exists).
    • The left and right children of the key node (if they exist).
  • Traverse the BST and add up every node except those identified as adjacent to the key node. The key node itself is still counted.
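As a minimal sketch of the first bullet, assuming distinct node values, the key can be found with an ordinary iterative BST search that remembers the last node it stepped from; that node is the parent. The names TreeNode and adjacent_nodes are hypothetical, not part of the original solution.

```python
from typing import List, Optional


class TreeNode:
    # Hypothetical minimal binary tree node
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def adjacent_nodes(root: Optional[TreeNode], k: int) -> List[TreeNode]:
    """Return the existing parent and children of the node holding k."""
    parent, node = None, root
    # Walk down using the BST ordering, remembering the node we came from
    while node is not None and node.val != k:
        parent = node
        node = node.left if k < node.val else node.right
    if node is None:
        return []  # k is not in the tree, so nothing is adjacent
    # Keep only the neighbours that actually exist
    return [n for n in (parent, node.left, node.right) if n is not None]


root = TreeNode(8,
                TreeNode(7, TreeNode(2)),
                TreeNode(10, TreeNode(9), TreeNode(13)))
print([n.val for n in adjacent_nodes(root, 10)])  # [8, 9, 13]
print([n.val for n in adjacent_nodes(root, 8)])   # [7, 10]
```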

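Because the tree is a BST, the key node and its parent can be located in O(h) time by comparing K with each node on the way down, exactly as in an ordinary BST search. After that, any full traversal (DFS or BFS) can sum the tree while skipping the values of the at most three adjacent nodes. One subtlety: skipping a node's value must not skip its subtree. In Example 1 the root 8 is adjacent to K, yet its descendants 7 and 2 still count toward the answer.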
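To make that subtlety concrete, here is a BFS sketch that assumes the adjacent nodes have already been collected into a set (nodes are compared by identity). It adds a node's value only when the node is not excluded, but always enqueues the children. An early return on excluded nodes would instead lose 7 and 2 in Example 1, because the excluded root is their ancestor. TreeNode and sum_skipping are hypothetical names.

```python
from collections import deque


class TreeNode:
    # Hypothetical minimal node; default identity hashing lets nodes go in a set
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def sum_skipping(root, excluded):
    """BFS over every node; excluded nodes add nothing but are still expanded."""
    total = 0
    queue = deque([root] if root is not None else [])
    while queue:
        node = queue.popleft()
        if node not in excluded:
            total += node.val
        # Enqueue children unconditionally: excluding a node's value
        # must not hide the nodes below it
        for child in (node.left, node.right):
            if child is not None:
                queue.append(child)
    return total


root = TreeNode(8,
                TreeNode(7, TreeNode(2)),
                TreeNode(10, TreeNode(9), TreeNode(13)))
# Adjacent to K = 10: the parent 8 and the children 9 and 13
excluded = {root, root.right.left, root.right.right}
print(sum_skipping(root, excluded))  # 19
```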

Approach

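  1. Search the BST for the key node (K), recording its parent on the way down.
  2. Gather the adjacent nodes:
    • The parent of the key node (if any).
    • The left and right children of the key node (if any).
  3. Traverse the whole tree and add each node's value unless it is one of the adjacent nodes. Keep descending into an adjacent node's children, since they are not adjacent themselves. (Alternatively, adjacency can be checked on the fly during a single traversal.)
  4. Return the accumulated sum.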
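For the single-traversal option, one possible sketch relies on an extra assumption: node values are distinct, as is usual for a BST. Then a node is adjacent to K exactly when its parent's value is K or one of its children's values is K, so adjacency can be decided locally during one DFS. It never uses the BST ordering, and unlike the code in the Code section it returns the full tree sum when K is absent (no node is adjacent then). The function name sum_excluding_adjacent_one_pass is hypothetical.

```python
class TreeNode:
    # Hypothetical minimal binary tree node
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def sum_excluding_adjacent_one_pass(root, k):
    """Single iterative DFS; assumes distinct values so adjacency is a local check."""
    total = 0
    stack = [(root, None)] if root is not None else []  # (node, parent's value)
    while stack:
        node, parent_val = stack.pop()
        # Adjacent to K: the parent holds K, or one of the children holds K
        is_adjacent = (
            parent_val == k
            or (node.left is not None and node.left.val == k)
            or (node.right is not None and node.right.val == k)
        )
        if not is_adjacent:
            total += node.val
        # Always descend, passing this node's value down as the parent value
        for child in (node.left, node.right):
            if child is not None:
                stack.append((child, node.val))
    return total


root = TreeNode(8,
                TreeNode(7, TreeNode(2)),
                TreeNode(10, TreeNode(9), TreeNode(13)))
print(sum_excluding_adjacent_one_pass(root, 10))  # 19
```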

Code

import java.util.HashSet;
import java.util.Set;

class Solution {

    // Helper: locate the key node using the BST ordering and remember its parent.
    // Returns {keyNode, parent}; keyNode is null if the key is not in the tree.
    private TreeNode[] findKeyNode(TreeNode node, TreeNode parent, int key) {
        if (node == null) return new TreeNode[]{null, null};
        if (node.val == key) return new TreeNode[]{node, parent};
        // Smaller keys live in the left subtree, larger keys in the right one
        if (key < node.val) {
            return findKeyNode(node.left, node, key);
        } else {
            return findKeyNode(node.right, node, key);
        }
    }

    // Sum all values except those of excluded nodes. An excluded node adds 0
    // itself, but its children are still visited (otherwise excluding the
    // parent would wrongly drop its whole subtree).
    private int computeTotalSum(TreeNode node, Set<TreeNode> exclude) {
        if (node == null) return 0;
        int own = exclude.contains(node) ? 0 : node.val;
        return own
                + computeTotalSum(node.left, exclude)
                + computeTotalSum(node.right, exclude);
    }

    public int sumExcludingAdjacent(TreeNode root, int key) {
        // Find the key node and its parent
        TreeNode[] result = findKeyNode(root, null, key);
        TreeNode keyNode = result[0];
        TreeNode parentNode = result[1];
        if (keyNode == null) return 0;  // Key not found

        // Adjacent nodes: the parent and the immediate children. The key node
        // itself is NOT excluded. HashSet permits null, so missing neighbours are harmless.
        Set<TreeNode> excludeNodes = new HashSet<>();
        excludeNodes.add(parentNode);
        excludeNodes.add(keyNode.left);
        excludeNodes.add(keyNode.right);

        // Compute total sum excluding adjacent nodes
        return computeTotalSum(root, excludeNodes);
    }

    public static void main(String[] args) {
        Solution sol = new Solution();

        TreeNode root = new TreeNode(8);
        root.left = new TreeNode(7);
        root.right = new TreeNode(10);
        root.left.left = new TreeNode(2);
        root.right.left = new TreeNode(9);
        root.right.right = new TreeNode(13);

        int key = 10;
        System.out.println(sol.sumExcludingAdjacent(root, key)); // Output: 19
    }
}

// Minimal binary tree node (assumed definition; not shown in the original)
class TreeNode {
    int val;
    TreeNode left, right;

    TreeNode(int val) {
        this.val = val;
    }
}
class TreeNode:
    # Minimal binary tree node (assumed definition; not shown in the original)
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:

    def sumExcludingAdjacent(self, root, key):
        # Helper function to locate the key node and store its parent.
        def findKeyNode(node, parent, key):
            if node is None:
                return None, None
            if node.val == key:
                return node, parent
            # Smaller keys live in the left subtree, larger keys in the right one
            if key < node.val:
                return findKeyNode(node.left, node, key)
            else:
                return findKeyNode(node.right, node, key)

        def computeTotalSum(node, exclude):
            if node is None:
                return 0
            # An excluded node adds 0 itself, but its subtree is still visited
            own = 0 if node in exclude else node.val
            return (
                own +
                computeTotalSum(node.left, exclude) +
                computeTotalSum(node.right, exclude)
            )

        # Find the key node and its parent
        keyNode, parentNode = findKeyNode(root, None, key)
        if keyNode is None:
            return 0  # Key not found

        # Adjacent nodes: parent and immediate children (the key node itself is kept)
        excludeNodes = {parentNode, keyNode.left, keyNode.right}

        # Compute total sum excluding adjacent nodes
        return computeTotalSum(root, excludeNodes)


# Example usage
if __name__ == "__main__":
    # Construct the tree
    sol = Solution()
    root = TreeNode(8)
    root.left = TreeNode(7)
    root.right = TreeNode(10)
    root.left.left = TreeNode(2)
    root.right.left = TreeNode(9)
    root.right.right = TreeNode(13)

    key = 10
    print(sol.sumExcludingAdjacent(root, key))  # Output: 19
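An equivalent two-pass formulation, sketched here as an alternative rather than the author's method, sums the whole tree and then subtracts the values of the (at most three) adjacent nodes, so no exclusion set is needed. It assumes distinct values so the BST search finds a unique key node, and it follows the code above in returning 0 when the key is missing. TreeNode, tree_sum, and sum_excluding_adjacent_by_subtraction are hypothetical names.

```python
class TreeNode:
    # Hypothetical minimal binary tree node
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def tree_sum(node):
    # Plain recursive sum of every value in the subtree rooted at node
    if node is None:
        return 0
    return node.val + tree_sum(node.left) + tree_sum(node.right)


def sum_excluding_adjacent_by_subtraction(root, k):
    # BST search for k, remembering the parent of the current node
    parent, node = None, root
    while node is not None and node.val != k:
        parent = node
        node = node.left if k < node.val else node.right
    if node is None:
        return 0  # key not found: same convention as the code above
    # Subtract only the neighbours that exist
    neighbours = (parent, node.left, node.right)
    return tree_sum(root) - sum(n.val for n in neighbours if n is not None)


root = TreeNode(8,
                TreeNode(7, TreeNode(2)),
                TreeNode(10, TreeNode(9), TreeNode(13)))
print(sum_excluding_adjacent_by_subtraction(root, 10))  # 19
```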

Complexity

  • ⏰ Time complexity: O(n). Locating the key node takes O(h), and the summing traversal visits each of the n nodes once.
  • 🧺 Space complexity: O(h), where h is the height of the tree, for the recursion stack; the exclusion set holds at most three nodes. For a skewed tree, h can be as large as n.