Kth Smallest Element in a BST
Problem
Given the root of a binary search tree, and an integer k, return the kth smallest value (1-indexed) of all the values of the nodes in the tree.
Examples
Example 1:
      3
     / \
    /   \
   1     4
    \
     \
      2
Input: root = [3,1,4,null,2], k = 1
Output: 1
Example 2:
          5
         / \
        /   \
       3     6
      / \
     /   \
    2     4
   /
  /
 1
Input: root = [5,3,6,2,4,null,null,1], k = 3
Output: 3
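The Input lines use LeetCode's level-order serialization, where null marks a missing child. To try the examples locally, here is a minimal Python sketch. TreeNode mirrors the node class LeetCode supplies; build_tree and inorder_values are hypothetical helpers written for this illustration, not part of the problem's API.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    """Binary tree node, mirroring the TreeNode class LeetCode provides."""

    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: build a tree from a level-order list (None = missing child)."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])  # nodes whose children have not been assigned yet
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next value (if present) is the left child
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        # The value after that (if present) is the right child
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def inorder_values(node: Optional[TreeNode]) -> List[int]:
    """Inorder traversal, used here only to inspect the built trees."""
    if node is None:
        return []
    return inorder_values(node.left) + [node.val] + inorder_values(node.right)


root1 = build_tree([3, 1, 4, None, 2])
root2 = build_tree([5, 3, 6, 2, 4, None, None, 1])
print(inorder_values(root1))  # [1, 2, 3, 4]
print(inorder_values(root2))  # [1, 2, 3, 4, 5, 6]
```

The inorder output of each tree is already sorted, so the expected answers (1 for k = 1 in Example 1, 3 for k = 3 in Example 2) can be read straight off the lists.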
Solution
Video explanation
Here is a video explaining the solution in detail: https://www.youtube.com/embed/Ep7Y220RXAM
Method 1 - Naive Inorder Traversal to List
Intuition
The most straightforward way to solve this is to materialize the sorted sequence. An inorder traversal (left subtree, then node, then right subtree) of a BST visits the values in ascending order, so once we have that sorted list we can simply pick the element at the k-th position.
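The claim that an inorder traversal of a BST is sorted can be spot-checked with a small sketch: insert random distinct values into a plain (unbalanced) BST and compare the inorder output with sorted(). The bst_insert helper is hypothetical and assumes smaller values go to the left.

```python
import random
from typing import List, Optional


class TreeNode:
    def __init__(self, val: int):
        self.val = val
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def bst_insert(root: Optional[TreeNode], val: int) -> TreeNode:
    """Standard (unbalanced) BST insertion: smaller values go left."""
    if root is None:
        return TreeNode(val)
    if val < root.val:
        root.left = bst_insert(root.left, val)
    else:
        root.right = bst_insert(root.right, val)
    return root


def inorder(node: Optional[TreeNode], out: List[int]) -> None:
    """Left subtree, then the node, then the right subtree."""
    if node is None:
        return
    inorder(node.left, out)
    out.append(node.val)
    inorder(node.right, out)


random.seed(0)
values = random.sample(range(1000), 50)  # 50 distinct values in random order
root = None
for v in values:
    root = bst_insert(root, v)

result: List[int] = []
inorder(root, result)
print(result == sorted(values))  # True
k = 7
print(result[k - 1] == sorted(values)[k - 1])  # True
```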
Approach
- Create an empty list.
- Perform a full inorder traversal of the BST.
- For each node visited, add its value to the list.
- After the traversal is complete, the list will contain all the node values in sorted order.
- Return the element at index k - 1 from the list (since k is 1-indexed).
Code
Java
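import java.util.ArrayList;
import java.util.List;

class Solution {
    public int kthSmallest(TreeNode root, int k) {
        // An inorder traversal of a BST produces the values in ascending order
        List<Integer> inorderList = new ArrayList<>();
        inorder(root, inorderList);
        return inorderList.get(k - 1); // k is 1-indexed
    }

    private void inorder(TreeNode node, List<Integer> list) {
        if (node == null) {
            return;
        }
        inorder(node.left, list);   // smaller values first
        list.add(node.val);         // then the node itself
        inorder(node.right, list);  // then larger values
    }
}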
Python
from typing import Optional

class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        inorder_list = []

        def inorder(node):
            # Left subtree, node, right subtree: ascending order for a BST
            if not node:
                return
            inorder(node.left)
            inorder_list.append(node.val)
            inorder(node.right)

        inorder(root)
        return inorder_list[k - 1]  # k is 1-indexed
Complexity
- ⏰ Time complexity: O(n), where n is the number of nodes, as we traverse the whole tree once to build the list.
- 🧺 Space complexity: O(n), as we store all n values in the list (the recursion stack adds O(h), where h is the height of the tree).
Method 2a - Recursive Inorder Traversal with counter
We don't actually need to store the entire list of nodes. We only need the k-th node visited during the inorder traversal, after which we can stop. We can do this by counting nodes as we visit them instead of adding them to a list.
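In Python, one way to express "count while visiting and stop at k" is a lazy generator: itertools.islice skips the first k - 1 values and next() pulls the k-th, after which the generator is simply abandoned and no further nodes are visited. This sketch swaps the explicit counter for itertools; inorder_iter and kth_smallest are hypothetical names, and it assumes k >= 1 (islice rejects a negative start).

```python
from itertools import islice
from typing import Iterator, Optional


class TreeNode:
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def inorder_iter(node: Optional[TreeNode]) -> Iterator[int]:
    """Lazily yield values in inorder; nodes are only visited on demand."""
    if node is None:
        return
    yield from inorder_iter(node.left)
    yield node.val
    yield from inorder_iter(node.right)


def kth_smallest(root: Optional[TreeNode], k: int) -> int:
    # islice skips the first k - 1 values; next() takes the k-th and stops there.
    # -1 is returned as a sentinel if the tree has fewer than k nodes.
    return next(islice(inorder_iter(root), k - 1, None), -1)


# Example 2 from the problem: [5,3,6,2,4,null,null,1]
root = TreeNode(5, TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)), TreeNode(6))
print(kth_smallest(root, 3))  # 3
```

The nesting of generator frames still follows the tree height, so the extra space remains O(h), as with the explicit recursion.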
Code
Java
import java.util.concurrent.atomic.AtomicInteger;

class Solution {
    public int kthSmallest(TreeNode root, int k) {
        // counter tracks how many nodes have been visited in inorder.
        // A plain int (or an immutable Integer) would not be shared across
        // recursive calls, so a mutable holder such as AtomicInteger
        // (or an int[1] / custom wrapper class) is used instead.
        AtomicInteger counter = new AtomicInteger(0);
        TreeNode kthNode = helper(root, k, counter);
        return kthNode == null ? -1 : kthNode.val;
    }

    private TreeNode helper(TreeNode root, int k, AtomicInteger counter) {
        if (root == null) {
            return null;
        }
        TreeNode left = helper(root.left, k, counter);
        // if the k-th smallest node was found in the left subtree, propagate it
        if (left != null) {
            return left;
        }
        // if the current node is the k-th smallest node
        if (counter.incrementAndGet() == k) {
            return root;
        }
        return helper(root.right, k, counter);
    }
}
Python
from typing import Optional

class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        self.count = 0  # nodes visited so far in inorder
        self.ans = -1   # stays -1 until the k-th node is found

        def inorder(node):
            if not node or self.ans != -1:
                return
            inorder(node.left)
            if self.ans != -1:  # found in the left subtree; stop
                return
            self.count += 1
            if self.count == k:
                self.ans = node.val
                return
            inorder(node.right)

        inorder(root)
        return self.ans
Complexity
- ⏰ Time complexity: O(h + k), where h is the height of the tree. It takes O(h) time to travel down to the smallest element, and then we visit k elements to find our answer.
- 🧺 Space complexity: O(h) due to the depth of recursion. For a balanced tree this is O(log n), and for an unbalanced (skewed) tree it is O(n).
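To make the O(h + k) bound concrete, the sketch below re-implements the counter-based search in Python with nonlocal variables instead of AtomicInteger, and additionally counts how many non-null nodes it enters. The entered counter and the build_balanced helper are illustrative additions, not part of the original solution.

```python
from typing import Optional, Tuple


class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def kth_smallest_instrumented(root: Optional[TreeNode], k: int) -> Tuple[int, int]:
    """Counter-based inorder search that also reports how many non-null nodes
    it entered. Returns (value, nodes_entered); value is -1 if not found."""
    count = 0    # nodes visited in inorder so far
    entered = 0  # non-null nodes whose recursive call was started
    answer = -1

    def helper(node: Optional[TreeNode]) -> bool:
        nonlocal count, entered, answer
        if node is None:
            return False
        entered += 1
        if helper(node.left):  # answer found in the left subtree
            return True
        count += 1
        if count == k:         # current node is the k-th smallest
            answer = node.val
            return True
        return helper(node.right)

    helper(root)
    return answer, entered


def build_balanced(lo: int, hi: int) -> Optional[TreeNode]:
    """Hypothetical helper: height-balanced BST holding lo..hi (inclusive)."""
    if lo > hi:
        return None
    mid = (lo + hi) // 2
    return TreeNode(mid, build_balanced(lo, mid - 1), build_balanced(mid + 1, hi))


# Example 2 from the problem, plus a balanced tree of 1023 nodes (height 10)
example = TreeNode(5, TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)), TreeNode(6))
balanced = build_balanced(1, 1023)
print(kth_smallest_instrumented(example, 3))      # (3, 4)
print(kth_smallest_instrumented(balanced, 1))     # (1, 10)
print(kth_smallest_instrumented(balanced, 1023))  # (1023, 1023)
```

For k = 1 on the balanced tree, only the 10 nodes on the leftmost path are entered; asking for the largest value forces a visit of every node, matching the O(h + k) estimate at both extremes.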
Method 2b - Recursive inorder traversal with countdown
Intuition
Instead of keeping a separate counter, we can treat k itself as a countdown: decrement it each time a node is visited, and the node at which it reaches 0 is the k-th smallest. Once that happens, the rest of the traversal can be skipped.
Approach
- Create a recursive inorder helper function.
- In the function, first traverse the left subtree.
- After returning from the left subtree, process the current node. This is the "visit" step. Decrement k.
- Check if k has just become 0. If it has, we have found our target. Store this node's value as the result and ensure the recursion stops.
- If k is still greater than 0, traverse the right subtree.
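The countdown steps above can also be threaded through return values instead of instance fields, which keeps the helper free of shared mutable state. This is a sketch of that variant; countdown and kth_smallest are hypothetical names, and -1 is used as a "not found" sentinel on the assumption that node values are non-negative.

```python
from typing import Optional, Tuple


class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def countdown(node: Optional[TreeNode], k: int) -> Tuple[int, Optional[int]]:
    """Inorder walk that passes the countdown through return values.
    Returns (remaining_k, found_value); found_value is None until the k-th node is visited."""
    if node is None:
        return k, None
    k, found = countdown(node.left, k)
    if found is not None:  # already found in the left subtree: stop
        return k, found
    k -= 1                 # "visit" the current node
    if k == 0:
        return 0, node.val
    return countdown(node.right, k)


def kth_smallest(root: Optional[TreeNode], k: int) -> int:
    _, found = countdown(root, k)
    return found if found is not None else -1


# Example 1 from the problem: [3,1,4,null,2]
root = TreeNode(3, TreeNode(1, None, TreeNode(2)), TreeNode(4))
print(kth_smallest(root, 2))  # 2
```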
Code
Java
class Solution {
    private int result;
    private int k; // countdown: nodes still to visit before the answer

    public int kthSmallest(TreeNode root, int k) {
        this.k = k;
        inorder(root);
        return result;
    }

    private void inorder(TreeNode node) {
        if (node == null) {
            return;
        }
        inorder(node.left);
        // Once the result is found (k reached 0), we don't need to continue
        if (k == 0) return;
        k--; // visit the current node
        if (k == 0) {
            result = node.val;
            return;
        }
        inorder(node.right);
    }
}
Python
from typing import Optional

class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        self.k = k  # countdown: nodes still to visit before the answer
        self.res = -1
        self.inorder(root)
        return self.res

    def inorder(self, node):
        if not node or self.k == 0:
            return
        self.inorder(node.left)
        if self.k == 0:  # answer already found in the left subtree
            return
        self.k -= 1  # visit the current node
        if self.k == 0:
            self.res = node.val
            return
        self.inorder(node.right)
Method 3 - Iterative Inorder Traversal 🏆
Intuition
To avoid the awkwardness of stopping a recursion early and to gain more direct control, we can simulate the inorder traversal iteratively using an explicit stack. This is the most common and generally preferred solution, and it also avoids deep call stacks on skewed trees.
Approach
- Initialize an empty stack and a current pointer to the root.
- Begin a loop that continues as long as current is not null or the stack is not empty.
- Inside the loop, use a nested while loop that goes as far left as possible from the current node, pushing each node onto the stack along the way.
- When the inner loop finishes, pop a node from the stack. This is the next node in the inorder sequence.
- Process the popped node: decrement k.
- If k is now 0, you've found the k-th smallest element. Return its value.
- If not, set the current pointer to the right child of the popped node and repeat the process.
Code
Java
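import java.util.ArrayDeque;
import java.util.Deque;

class Solution {
    public int kthSmallest(TreeNode root, int k) {
        // ArrayDeque is preferred over the legacy Stack class for LIFO use
        Deque<TreeNode> stack = new ArrayDeque<>();
        TreeNode curr = root;
        while (curr != null || !stack.isEmpty()) {
            // Push the current node and all of its left descendants
            while (curr != null) {
                stack.push(curr);
                curr = curr.left;
            }
            curr = stack.pop(); // next node in sorted order
            k--;
            if (k == 0) {
                return curr.val;
            }
            curr = curr.right;
        }
        return -1; // Should not be reached if k is valid
    }
}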
Python
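from typing import Optional

class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        stack = []
        curr = root
        while curr or stack:
            # Push the current node and all of its left descendants
            while curr:
                stack.append(curr)
                curr = curr.left
            curr = stack.pop()  # next node in sorted order
            k -= 1
            if k == 0:
                return curr.val
            curr = curr.right
        return -1  # Should not be reached if k is valid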
Complexity
- ⏰ Time complexity: O(h + k), where h is the height of the tree. It takes O(h) time to travel down to the smallest element, and then we visit k elements to find our answer.
- 🧺 Space complexity: O(h) for the explicit stack. For a balanced tree this is O(log n), and for an unbalanced (skewed) tree it is O(n).
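To see the O(h) space bound, and the difference between a balanced and a skewed tree, the following sketch instruments the iterative traversal to report the largest stack size it reaches. The peak counter and the tree-building helper are illustrative additions. On the skewed chain the explicit stack grows to 5000 entries without trouble, whereas a recursive traversal of the same tree in CPython would exceed the default recursion limit of 1000 (see sys.getrecursionlimit()) unless that limit were raised.

```python
from typing import Optional, Tuple


class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def kth_smallest_with_depth(root: Optional[TreeNode], k: int) -> Tuple[int, int]:
    """Iterative inorder search that also reports the peak stack size.
    Returns (value, peak_stack); value is -1 if the tree has fewer than k nodes."""
    stack = []
    peak = 0
    curr = root
    while curr or stack:
        while curr:                 # walk as far left as possible
            stack.append(curr)
            peak = max(peak, len(stack))
            curr = curr.left
        curr = stack.pop()          # next node in sorted order
        k -= 1
        if k == 0:
            return curr.val, peak
        curr = curr.right
    return -1, peak


def build_balanced(lo: int, hi: int) -> Optional[TreeNode]:
    """Hypothetical helper: height-balanced BST holding lo..hi (inclusive)."""
    if lo > hi:
        return None
    mid = (lo + hi) // 2
    return TreeNode(mid, build_balanced(lo, mid - 1), build_balanced(mid + 1, hi))


# Left-skewed chain: root n, its left child n - 1, ..., down to 1 (height n)
n = 5000
skewed = None
for v in range(1, n + 1):
    skewed = TreeNode(v, left=skewed)

balanced = build_balanced(1, 1023)  # height 10

print(kth_smallest_with_depth(skewed, 1))    # (1, 5000)
print(kth_smallest_with_depth(balanced, 1))  # (1, 10)
```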