Problem

You have a binary tree with a small defect. There is exactly one invalid node whose right child incorrectly points to another node at the same depth, to the right of the invalid node.

Given root, the root of the binary tree with this defect, return the root of the binary tree after removing this invalid node and every node underneath it (except the node it incorrectly points to).

Custom testing:

The test input is read as 3 lines:

  • TreeNode root
  • int fromNode (not available to correctBinaryTree)
  • int toNode (not available to correctBinaryTree)

After the binary tree rooted at root is parsed, the TreeNode with value of fromNode will have its right child pointer pointing to the TreeNode with a value of toNode. Then, root is passed to correctBinaryTree.
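The judge's setup can be reproduced locally when testing a solution. The sketch below assumes a TreeNode class shaped like LeetCode's. The helpers build_tree, find, and make_defective are hypothetical names chosen for illustration, not part of the judge's API. It parses a level-order list, then points fromNode's right pointer at toNode.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    """Minimal binary tree node (assumed to mirror LeetCode's TreeNode)."""

    def __init__(self, val: int = 0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


# Hypothetical helper: parse a LeetCode-style level-order list (None = no child).
def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    q = deque([root])
    i = 1
    while q and i < len(values):
        node = q.popleft()
        # The next two slots hold this node's left and right children.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            q.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            q.append(node.right)
        i += 1
    return root


# Hypothetical helper: depth-first search for the node holding val (values are unique).
def find(root: Optional[TreeNode], val: int) -> Optional[TreeNode]:
    stack = [root] if root else []
    while stack:
        node = stack.pop()
        if node.val == val:
            return node
        stack.extend(c for c in (node.left, node.right) if c)
    return None


# Hypothetical helper: build the tree, then inject the defect fromNode.right = toNode.
def make_defective(values, from_val: int, to_val: int) -> TreeNode:
    root = build_tree(values)
    find(root, from_val).right = find(root, to_val)
    return root


root = make_defective([8, 3, 1, 7, None, 9, 4, 2, None, None, None, 5, 6], 7, 4)
print(root.left.left.right is root.right.right)  # True: node 7 now points at node 4
```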

Examples

Example 1:

**![](https://fastly.jsdelivr.net/gh/doocs/leetcode@main/solution/1600-1699/1660.Correct%20a%20Binary%20Tree/images/ex1v2.png)**
Input: root = [1,2,3], fromNode = 2, toNode = 3
Output: [1,null,3]
Explanation: The node with value 2 is invalid, so remove it.

Example 2:

**![](https://fastly.jsdelivr.net/gh/doocs/leetcode@main/solution/1600-1699/1660.Correct%20a%20Binary%20Tree/images/ex2v3.png)**
Input: root = [8,3,1,7,null,9,4,2,null,null,null,5,6], fromNode = 7, toNode = 4
Output: [8,3,1,null,null,9,4,null,null,5,6]
Explanation: The node with value 7 is invalid, so remove it and the node underneath it, node 2.
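The inputs and outputs above use LeetCode's level-order serialization, where null marks a missing child and trailing nulls are dropped. A minimal Python sketch of that serializer follows, useful for comparing a corrected tree against the expected output. The local TreeNode class and the name to_list are assumptions for illustration, and the function is meant for a valid (corrected) tree: on a defective tree toNode would be emitted twice.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


# Hypothetical helper: serialize level by level, writing None for each missing child.
def to_list(root: Optional[TreeNode]) -> List[Optional[int]]:
    out: List[Optional[int]] = []
    q = deque([root])
    while q:
        node = q.popleft()
        if node is None:
            out.append(None)
            continue
        out.append(node.val)
        q.append(node.left)
        q.append(node.right)
    # Trailing Nones carry no information, so trim them.
    while out and out[-1] is None:
        out.pop()
    return out


# Expected output of Example 1: node 1 with only a right child, 3.
print(to_list(TreeNode(1, None, TreeNode(3))))  # [1, None, 3]
```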

Constraints:

  • The number of nodes in the tree is in the range [3, 10^4].
  • -10^9 <= Node.val <= 10^9
  • All Node.val are unique.
  • fromNode != toNode
  • fromNode and toNode will exist in the tree and will be on the same depth.
  • toNode is to the right of fromNode.
  • fromNode.right is null in the initial tree from the test data.

Solution

Method 1 – BFS with HashSet

Intuition

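The invalid node’s right pointer targets another node on the same depth, further to the right. If we traverse the tree level by level (BFS) and visit each level from right to left, every node to the right of the invalid node on its depth has already been enqueued by the time we reach the invalid node. So if a node’s right pointer leads to a node we have already seen, that node is the invalid one. A genuine right child sits one level deeper and cannot have been seen yet, so valid nodes never trigger this check. Performing the check while visiting the parent lets us detach the invalid node directly; its whole subtree goes with it, while toNode survives because it is still attached to its real parent.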
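The right-to-left order does not require reversing anything: enqueueing the right child before the left child into a FIFO queue is enough. A small Python sketch, assuming a local TreeNode class and a hypothetical function name right_to_left_levels, shows the order in which each level of Example 2's tree (before the defect is injected) comes out.

```python
from collections import deque
from typing import List


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


# Hypothetical helper: BFS that enqueues the right child before the left child,
# so every level is dequeued from right to left.
def right_to_left_levels(root: TreeNode) -> List[List[int]]:
    levels: List[List[int]] = []
    q = deque([root])
    while q:
        level = []
        for _ in range(len(q)):
            node = q.popleft()
            level.append(node.val)
            if node.right:
                q.append(node.right)
            if node.left:
                q.append(node.left)
        levels.append(level)
    return levels


# Example 2's tree before the defect is injected.
tree = TreeNode(8,
                TreeNode(3, TreeNode(7, TreeNode(2))),
                TreeNode(1, TreeNode(9), TreeNode(4, TreeNode(5), TreeNode(6))))
print(right_to_left_levels(tree))  # [[8], [1, 3], [4, 9, 7], [6, 5, 2]]
```

On level 3, node 4 is dequeued before node 7, which is exactly why the defect (7.right = 4) is detectable when node 7 is reached.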

Approach

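  1. Run BFS from the root, enqueueing each node’s right child before its left child. Because the queue is FIFO, every level is then dequeued from right to left.
  2. Keep a set of every node enqueued so far.
  3. Before enqueueing a child, check whether its right pointer targets a node already in the set. If it does, that child is the invalid node.
  4. Remove it by setting the parent’s left or right pointer (whichever holds it) to null, then return root. The invalid node’s whole subtree is dropped with it, while toNode stays reachable through its real parent.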
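To check the steps above end to end, here is a standalone Python sketch of the same method. It builds both examples by hand (pointing fromNode’s right pointer at toNode) and serializes the result. The local TreeNode class and the function names correct_binary_tree and to_list are hypothetical; on LeetCode the judge supplies TreeNode and calls Solution.correctBinaryTree.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


# Hypothetical standalone version of Method 1.
def correct_binary_tree(root: TreeNode) -> TreeNode:
    q = deque([root])  # right child enqueued first, so levels come out right to left
    seen = {root}      # every node enqueued so far
    while q:
        node = q.popleft()
        for attr in ("right", "left"):
            child = getattr(node, attr)
            if child is None:
                continue
            # All nodes to child's right on this depth are already in seen,
            # while a genuine right child (one level deeper) cannot be yet.
            if child.right is not None and child.right in seen:
                setattr(node, attr, None)  # drops child's whole subtree
                return root
            seen.add(child)
            q.append(child)
    return root


# Hypothetical helper: level-order serialization with trailing Nones trimmed.
def to_list(root: Optional[TreeNode]) -> List[Optional[int]]:
    out, q = [], deque([root])
    while q:
        node = q.popleft()
        out.append(node.val if node else None)
        if node:
            q.extend((node.left, node.right))
    while out and out[-1] is None:
        out.pop()
    return out


# Example 1: node 2's right pointer is redirected to node 3.
n3 = TreeNode(3)
ex1 = TreeNode(1, TreeNode(2, None, n3), n3)
print(to_list(correct_binary_tree(ex1)))  # [1, None, 3]

# Example 2: node 7's right pointer is redirected to node 4.
n4 = TreeNode(4, TreeNode(5), TreeNode(6))
n7 = TreeNode(7, TreeNode(2), n4)
ex2 = TreeNode(8, TreeNode(3, n7), TreeNode(1, TreeNode(9), n4))
print(to_list(correct_binary_tree(ex2)))  # [8, 3, 1, None, None, 9, 4, None, None, 5, 6]
```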

Code

C++

#include <queue>
#include <unordered_set>
using namespace std;

// TreeNode is the standard LeetCode definition (val, left, right).
class Solution {
public:
    TreeNode* correctBinaryTree(TreeNode* root) {
        queue<TreeNode*> q;             // right child enqueued first, so each level comes out right to left
        unordered_set<TreeNode*> seen;  // every node enqueued so far
        q.push(root);
        seen.insert(root);
        while (!q.empty()) {
            TreeNode* node = q.front();
            q.pop();
            for (TreeNode** slot : {&node->right, &node->left}) {
                TreeNode* child = *slot;
                if (!child) continue;
                // child->right targets a node already enqueued on this level,
                // so child is the invalid node: detach it along with its subtree.
                if (child->right && seen.count(child->right)) {
                    *slot = nullptr;
                    return root;
                }
                seen.insert(child);
                q.push(child);
            }
        }
        return root;
    }
};
Go

// TreeNode is the standard LeetCode definition (Val, Left, Right).
func correctBinaryTree(root *TreeNode) *TreeNode {
    q := []*TreeNode{root}                  // right child enqueued first, so each level comes out right to left
    seen := map[*TreeNode]bool{root: true}  // every node enqueued so far
    for len(q) > 0 {
        node := q[0]
        q = q[1:]
        for _, slot := range []**TreeNode{&node.Right, &node.Left} {
            child := *slot
            if child == nil {
                continue
            }
            // child.Right targets a node already enqueued on this level,
            // so child is the invalid node: detach it along with its subtree.
            if child.Right != nil && seen[child.Right] {
                *slot = nil
                return root
            }
            seen[child] = true
            q = append(q, child)
        }
    }
    return root
}
Java

import java.util.*;

class Solution {
    public TreeNode correctBinaryTree(TreeNode root) {
        Queue<TreeNode> q = new ArrayDeque<>();  // right child enqueued first, so each level comes out right to left
        Set<TreeNode> seen = new HashSet<>();    // every node enqueued so far
        q.offer(root);
        seen.add(root);
        while (!q.isEmpty()) {
            TreeNode node = q.poll();
            if (node.right != null) {
                // node.right.right is already enqueued on this level, so node.right is invalid.
                if (node.right.right != null && seen.contains(node.right.right)) {
                    node.right = null;
                    return root;
                }
                seen.add(node.right);
                q.offer(node.right);
            }
            if (node.left != null) {
                if (node.left.right != null && seen.contains(node.left.right)) {
                    node.left = null;  // detach the invalid node and its subtree
                    return root;
                }
                seen.add(node.left);
                q.offer(node.left);
            }
        }
        return root;
    }
}
Kotlin

class Solution {
    fun correctBinaryTree(root: TreeNode?): TreeNode? {
        if (root == null) return null
        val q = ArrayDeque<TreeNode>()  // right child enqueued first, so each level comes out right to left
        val seen = hashSetOf(root)      // every node enqueued so far
        q.addLast(root)
        while (q.isNotEmpty()) {
            val node = q.removeFirst()
            node.right?.let { child ->
                val target = child.right
                // target is already enqueued on this level, so child is the invalid node.
                if (target != null && target in seen) {
                    node.right = null
                    return root
                }
                seen.add(child)
                q.addLast(child)
            }
            node.left?.let { child ->
                val target = child.right
                if (target != null && target in seen) {
                    node.left = null  // detach the invalid node and its subtree
                    return root
                }
                seen.add(child)
                q.addLast(child)
            }
        }
        return root
    }
}
Python

from collections import deque
from typing import Optional


class Solution:
    def correctBinaryTree(self, root: 'Optional[TreeNode]') -> 'Optional[TreeNode]':
        q = deque([root])  # right child enqueued first, so each level comes out right to left
        seen = {root}      # every node enqueued so far
        while q:
            node = q.popleft()
            for attr in ('right', 'left'):
                child = getattr(node, attr)
                if child is None:
                    continue
                # child.right targets a node already enqueued on this level,
                # so child is the invalid node: detach it along with its subtree.
                if child.right is not None and child.right in seen:
                    setattr(node, attr, None)
                    return root
                seen.add(child)
                q.append(child)
        return root
Rust

// Omitted for brevity, as Rust's tree manipulation is verbose and not typical for interviews.
TypeScript

// TreeNode is the standard LeetCode definition (val, left, right).
class Solution {
    correctBinaryTree(root: TreeNode | null): TreeNode | null {
        if (!root) return null;
        const q: TreeNode[] = [root];            // right child enqueued first, so each level comes out right to left
        const seen = new Set<TreeNode>([root]);  // every node enqueued so far
        for (let head = 0; head < q.length; head++) {
            const node = q[head];
            for (const side of ['right', 'left'] as const) {
                const child = node[side];
                if (!child) continue;
                // child.right targets a node already enqueued on this level,
                // so child is the invalid node: detach it along with its subtree.
                if (child.right && seen.has(child.right)) {
                    node[side] = null;
                    return root;
                }
                seen.add(child);
                q.push(child);
            }
        }
        return root;
    }
}

Complexity

  • ⏰ Time complexity: O(n), where n is the number of nodes in the tree, as each node is visited once.
  • 🧺 Space complexity: O(n), for the visited set (up to n nodes) and the BFS queue.