Remove Duplicates in a Binary Search Tree (BST)
Medium · Updated: Aug 19, 2025
Problem
Given a Binary Search Tree (BST) that may contain duplicate values (placed so that its inorder traversal is non-decreasing), remove all duplicate nodes so that every value appears only once in the tree. The resulting tree should remain a valid BST.
Examples
Example 1
Input: [1,2,2,3,4,4,4,5] (BST with inorder traversal: 1 2 2 3 4 4 4 5)
Output: [1,2,3,4,5] (BST with inorder traversal: 1 2 3 4 5)
Explanation: All duplicate nodes are removed, only unique values remain.
Constraints
- The BST may contain duplicate values.
- The operation must be done in-place.
Solution
Method 1 - Inorder Traversal & Remove
Intuition
Duplicates in a BST appear as consecutive nodes in an inorder traversal. By tracking the previous node, we can detect and remove duplicates efficiently.
Approach
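Traverse the BST in order, remembering the last node that was kept. If the current node's value equals that node's value, remove the current node and continue. Its already-processed left subtree may still hold kept nodes, so that subtree must be reconnected with the processed right subtree rather than discarded. When the left subtree is non-empty, its rightmost node is exactly the last kept node, so the right subtree can be attached there as its right child. Otherwise, record the current node as the last kept node and proceed.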
Complexity
- Time Complexity: O(N), where N is the number of nodes in the BST (each node is visited once).
- Space Complexity: O(H), where H is the height of the tree (due to the recursion stack); H can be as large as N for a skewed tree.
Code
C++
#include <iostream>
using namespace std;
/// Binary tree node. Duplicate values are expected to be placed so that an
/// inorder walk is non-decreasing (equal values end up adjacent).
struct TreeNode {
    int val;
    TreeNode *left, *right;
    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
};

/// Removes duplicate values from a BST in place.
class Solution {
public:
    /// @brief Deletes every node whose value equals the value of the last kept
    ///        node in inorder order, so each value survives exactly once.
    /// @param root Root of the tree. Nodes must have been allocated with `new`,
    ///             because removed nodes are released with `delete`.
    /// @return The new root (differs from @p root if the root was removed);
    ///         nullptr for an empty tree.
    TreeNode* removeDuplicates(TreeNode* root) {
        TreeNode* prev = nullptr;
        return removeDup(root, prev);
    }

private:
    /// Cleans the subtree rooted at @p node and returns its new root.
    /// @param prev In/out: last node kept so far in inorder order (nullptr if none).
    TreeNode* removeDup(TreeNode* node, TreeNode*& prev) {
        if (!node) return nullptr;
        node->left = removeDup(node->left, prev);
        if (prev && prev->val == node->val) {
            // node is a duplicate. Its cleaned left subtree may still hold kept
            // nodes, so it must be reconnected, not dropped. When that subtree is
            // non-empty, prev is its rightmost node (the last kept node in
            // inorder order) and has no right child, so the cleaned right
            // subtree can be hung there in O(1).
            TreeNode* left = node->left;
            TreeNode* tail = left ? prev : nullptr;
            TreeNode* right = removeDup(node->right, prev);
            delete node;
            if (!tail) return right;
            tail->right = right;
            return left;
        }
        prev = node;
        node->right = removeDup(node->right, prev);
        return node;
    }
};
Java
/** Binary tree node; duplicate values may appear. */
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int x) { val = x; }
}

/** Removes duplicate values from a BST in place. */
class Solution {
    /** Last node kept so far in inorder order, or null if none. */
    private TreeNode prev = null;

    /**
     * Removes every node whose value equals the value of the last kept node in
     * inorder order, so each value survives exactly once.
     *
     * @param root root of the tree (may be null)
     * @return the new root, which differs from {@code root} if the root was removed
     */
    public TreeNode removeDuplicates(TreeNode root) {
        // Reset so state from a previous call cannot cause a wrong removal here.
        prev = null;
        return removeDup(root);
    }

    /** Cleans the subtree rooted at {@code node} and returns its new root. */
    private TreeNode removeDup(TreeNode node) {
        if (node == null) return null;
        node.left = removeDup(node.left);
        if (prev != null && prev.val == node.val) {
            // node is a duplicate. Its cleaned left subtree may still hold kept
            // nodes, so it must be reconnected, not dropped. When that subtree
            // is non-empty, prev is its rightmost node (last kept in inorder
            // order) with no right child, so the cleaned right subtree goes there.
            TreeNode left = node.left;
            TreeNode tail = (left != null) ? prev : null;
            TreeNode right = removeDup(node.right);
            if (tail == null) return right;
            tail.right = right;
            return left;
        }
        prev = node;
        node.right = removeDup(node.right);
        return node;
    }
}
Python
class TreeNode:
    """Binary tree node; duplicate values may appear."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    """Removes duplicate values from a BST in place."""

    def removeDuplicates(self, root):
        """Remove every node whose value equals that of the last kept node in inorder order.

        Args:
            root: Root TreeNode of the tree, or None.

        Returns:
            The new root, which differs from ``root`` if the root itself was removed.
        """
        # Last node kept so far in inorder order; reset per call.
        self.prev = None
        return self._removeDup(root)

    def _removeDup(self, node):
        """Clean the subtree rooted at ``node`` and return its new root."""
        if node is None:
            return None
        node.left = self._removeDup(node.left)
        if self.prev is not None and self.prev.val == node.val:
            # node is a duplicate. Its cleaned left subtree may still hold kept
            # nodes, so it must be reconnected, not dropped. When that subtree is
            # non-empty, self.prev is its rightmost node (last kept in inorder
            # order) with no right child, so the cleaned right subtree goes there.
            left = node.left
            tail = self.prev if left is not None else None
            right = self._removeDup(node.right)
            if tail is None:
                return right
            tail.right = right
            return left
        self.prev = node
        node.right = self._removeDup(node.right)
        return node