Problem
Given the root of a binary search tree and an integer k, return the kth smallest value (1-indexed) of all the values of the nodes in the tree.
Examples
Example 1:
    3
   / \
  1   4
   \
    2
Input: root = [3,1,4,null,2], k = 1
Output: 1
Example 2:
        5
       / \
      3   6
     / \
    2   4
   /
  1
Input: root = [5,3,6,2,4,null,null,1], k = 3
Output: 3
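To run the examples outside LeetCode, the `root = [...]` input has to be turned into a tree. LeetCode supplies `TreeNode` and does this deserialization itself; the sketch below assumes the usual level-order format, where `null` (Python `None`) marks a missing child and the children of missing nodes are not listed. `TreeNode` here is a minimal stand-in and `build_tree` is a hypothetical helper name.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    # Minimal stand-in for LeetCode's TreeNode (assumption).
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a level-order list; None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])  # nodes whose children are still to be read
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next value (if any) is the left child of this node.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        # The value after that (if any) is the right child.
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


root = build_tree([3, 1, 4, None, 2])
print(root.left.right.val)  # 2
```

The queue-based reading mirrors how the list was produced: breadth-first, two child slots per existing node.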
Solution
Method 1 - Naive Inorder Traversal to List
Intuition
The most straightforward way to solve this is to materialize the sorted sequence. If we have a sorted list of all the node values, we can simply pick the element at the k-th position.
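The fact this method relies on is that an inorder traversal (left subtree, node, right subtree) of a BST visits the values in ascending order. A minimal sketch on Example 2's tree, built by hand with a stand-in `TreeNode` class (assumed to mirror LeetCode's); `inorder_values` is a hypothetical helper name:

```python
from typing import List, Optional


class TreeNode:
    # Minimal stand-in for LeetCode's TreeNode (assumption).
    def __init__(self, val: int, left: Optional["TreeNode"] = None,
                 right: Optional["TreeNode"] = None):
        self.val = val
        self.left = left
        self.right = right


def inorder_values(node: Optional[TreeNode]) -> List[int]:
    # Left subtree, then the node itself, then the right subtree.
    # (List concatenation keeps the demo short; it is not the efficient way.)
    if node is None:
        return []
    return inorder_values(node.left) + [node.val] + inorder_values(node.right)


# Example 2: [5,3,6,2,4,null,null,1]
root = TreeNode(5, TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)), TreeNode(6))
values = inorder_values(root)
print(values)         # [1, 2, 3, 4, 5, 6]
print(values[3 - 1])  # 3, the answer for k = 3
```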
Approach
Create an empty list.
Perform a full inorder traversal of the BST.
For each node visited, add its value to the list.
After the traversal is complete, the list will contain all the node values in sorted order.
Return the element at index k-1 from the list (since k is 1-indexed).
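Code
Java
import java.util.ArrayList;
import java.util.List;

// TreeNode is the standard LeetCode definition (val, left, right).
class Solution {
    public int kthSmallest(TreeNode root, int k) {
        List<Integer> inorderList = new ArrayList<>();
        inorder(root, inorderList);
        // Inorder order of a BST is ascending, so the k-th smallest is at index k - 1
        return inorderList.get(k - 1);
    }

    // Appends the values of the subtree rooted at node in ascending order
    private void inorder(TreeNode node, List<Integer> list) {
        if (node == null) {
            return;
        }
        inorder(node.left, list);
        list.add(node.val);
        inorder(node.right, list);
    }
}
Python
from typing import Optional

# TreeNode is the standard LeetCode definition (val, left, right).
class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        inorder_list = []

        def inorder(node):
            if not node:
                return
            inorder(node.left)
            inorder_list.append(node.val)
            inorder(node.right)

        inorder(root)
        # Inorder order of a BST is ascending, so the k-th smallest is at index k - 1
        return inorder_list[k - 1]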
Complexity
⏰ Time complexity: O(n), where n is the number of nodes, as we traverse the tree once to build the list.
🧺 Space complexity: O(n), as we store all n values in the list; the recursion stack adds O(h), where h is the height of the tree, which does not change the bound.
Method 2a - Recursive Inorder Traversal with Counter
We don’t actually need to store the entire list of nodes. We just need to find the k-th node visited during the inorder traversal and stop. We can do this by counting the nodes as we visit them, instead of adding them to a list.
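In Python, the "count and stop" idea can also be expressed with a generator: a lazy inorder generator yields one value at a time, and `itertools.islice` stops pulling after the k-th value, so the rest of the tree is never visited. This is a variant sketch, separate from the Java and Python solutions in this method; `TreeNode` is a minimal stand-in, and `inorder_gen` and `kth_smallest_lazy` are hypothetical names.

```python
from itertools import islice
from typing import Iterator, Optional


class TreeNode:
    # Minimal stand-in for LeetCode's TreeNode (assumption).
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def inorder_gen(node: Optional[TreeNode]) -> Iterator[int]:
    # Lazily yield values in inorder (ascending) order.
    if node is None:
        return
    yield from inorder_gen(node.left)
    yield node.val
    yield from inorder_gen(node.right)


def kth_smallest_lazy(root: Optional[TreeNode], k: int) -> int:
    # islice skips the first k - 1 values and takes the k-th; the generator is
    # never resumed afterwards, so traversal stops there.
    # Raises StopIteration if k is larger than the number of nodes.
    return next(islice(inorder_gen(root), k - 1, k))


# Example 1: [3,1,4,null,2]
root = TreeNode(3, TreeNode(1, None, TreeNode(2)), TreeNode(4))
print(kth_smallest_lazy(root, 1))  # 1
```

Each yielded value is passed up through up to h nested `yield from` generators, so in the worst case this costs more than the explicit counter; it is shown only as a compact way to express "stop after the k-th visit".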
Code
Java
import java.util.concurrent.atomic.AtomicInteger;

// TreeNode is the standard LeetCode definition (val, left, right).
class Solution {
    public int kthSmallest(TreeNode root, int k) {
        // counter tracks how many nodes have been visited so far.
        // A plain int passed to the recursive calls cannot be updated by them
        // (Java passes it by value), so a mutable holder such as AtomicInteger
        // is used; an int[1] array or a field would work too.
        AtomicInteger counter = new AtomicInteger(0);
        TreeNode kthNode = helper(root, k, counter);
        return kthNode == null ? -1 : kthNode.val;
    }

    // Returns the k-th smallest node if it lies in this subtree, otherwise null
    private TreeNode helper(TreeNode root, int k, AtomicInteger counter) {
        if (root == null) {
            return null;
        }
        TreeNode left = helper(root.left, k, counter);
        // the k-th smallest node was found in the left subtree
        if (left != null) {
            return left;
        }
        // the current node is the k-th smallest node
        if (counter.incrementAndGet() == k) {
            return root;
        }
        return helper(root.right, k, counter);
    }
}
Python
from typing import Optional

# TreeNode is the standard LeetCode definition (val, left, right).
class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        self.count = 0  # nodes visited so far in inorder
        self.ans = -1

        def inorder(node):
            # Stop on an empty subtree or once the answer has been found
            if not node or self.count >= k:
                return
            inorder(node.left)
            if self.count >= k:
                return
            self.count += 1
            if self.count == k:
                self.ans = node.val
                return
            inorder(node.right)

        inorder(root)
        return self.ans
Complexity
⏰ Time complexity: O(h + k), where h is the height of the tree. It takes O(h) time to travel down to the smallest element; then we visit k elements to find the answer.
🧺 Space complexity: O(h) due to the depth of recursion: O(log n) for a balanced tree and O(n) for an unbalanced (skewed) tree.
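One way to see the O(h + k) bound is to count how many nodes the counter-based traversal actually enters. The sketch below is a Python port of the counter idea with an extra visit counter; `count_visits`, `left_chain` and `balanced_bst` are hypothetical helpers, and `TreeNode` is a minimal stand-in. It compares a left-skewed chain (height 15) with a balanced tree (height 4) holding the same 15 values, for k = 1.

```python
from typing import List, Optional


class TreeNode:
    # Minimal stand-in for LeetCode's TreeNode (assumption).
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def left_chain(values: List[int]) -> Optional[TreeNode]:
    # values ascending; each node's left child is the next smaller value (height n).
    root = None
    for v in values:
        root = TreeNode(v, left=root)
    return root


def balanced_bst(values: List[int]) -> Optional[TreeNode]:
    # values ascending; the middle element becomes the root (height about log2 n).
    if not values:
        return None
    mid = len(values) // 2
    return TreeNode(values[mid], balanced_bst(values[:mid]), balanced_bst(values[mid + 1:]))


def count_visits(root: Optional[TreeNode], k: int):
    """Return (k-th smallest value, number of nodes entered by the traversal)."""
    visited = 0  # nodes entered, including those on the way down
    seen = 0     # nodes counted in inorder, as in Method 2a
    answer = None

    def inorder(node):
        nonlocal visited, seen, answer
        if node is None or answer is not None:
            return
        visited += 1
        inorder(node.left)
        if answer is not None:
            return
        seen += 1
        if seen == k:
            answer = node.val
            return
        inorder(node.right)

    inorder(root)
    return answer, visited


values = list(range(1, 16))
print(count_visits(left_chain(values), 1))    # (1, 15)
print(count_visits(balanced_bst(values), 1))  # (1, 4)
```

For k = 1 the work is dominated by the descent, so the count tracks the height h rather than n.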
Method 2b - Recursive Inorder Traversal with Countdown
Intuition
As in Method 2a, we only need the k-th node visited during the inorder traversal. Instead of keeping a separate counter, we can treat k itself as a countdown: decrement it at every visited node, and the node at which it reaches 0 is the answer.
Approach
Create a recursive inorder helper function.
In the function, first traverse the left subtree.
After returning from the left subtree, process the current node. This is the “visit” step: decrement k.
Check whether k has just become 0. If it has, we have found our target: store this node’s value as the result and stop the recursion.
If k is still greater than 0, traverse the right subtree.
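To make the countdown concrete, the sketch below records (node value, k after decrement) at each visit step on Example 2 with k = 3. `trace_countdown` is a hypothetical helper and `TreeNode` a minimal stand-in; the traversal logic follows the steps above.

```python
from typing import List, Optional, Tuple


class TreeNode:
    # Minimal stand-in for LeetCode's TreeNode (assumption).
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def trace_countdown(root: Optional[TreeNode], k: int) -> List[Tuple[int, int]]:
    steps = []
    remaining = k  # the countdown

    def inorder(node):
        nonlocal remaining
        if node is None or remaining == 0:
            return
        inorder(node.left)
        if remaining == 0:  # answer already found in the left subtree
            return
        remaining -= 1      # the "visit" step
        steps.append((node.val, remaining))
        if remaining == 0:  # this node is the k-th smallest
            return
        inorder(node.right)

    inorder(root)
    return steps


# Example 2: [5,3,6,2,4,null,null,1]
root = TreeNode(5, TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)), TreeNode(6))
print(trace_countdown(root, 3))  # [(1, 2), (2, 1), (3, 0)]
```

The last recorded node is the one at which the countdown hits 0; nodes 4, 5 and 6 are never visited.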
Code
Java
// TreeNode is the standard LeetCode definition (val, left, right).
class Solution {
    private int result;
    private int k; // countdown: nodes still to visit before the answer

    public int kthSmallest(TreeNode root, int k) {
        this.k = k;
        inorder(root);
        return result;
    }

    private void inorder(TreeNode node) {
        if (node == null) {
            return;
        }
        inorder(node.left);
        // Once the result is found, there is no need to continue
        if (k == 0) {
            return;
        }
        k--;
        if (k == 0) {
            result = node.val;
            return;
        }
        inorder(node.right);
    }
}
Python
from typing import Optional

# TreeNode is the standard LeetCode definition (val, left, right).
class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        self.k = k  # countdown: nodes still to visit before the answer
        self.res = -1
        self.inorder(root)
        return self.res

    def inorder(self, node):
        if not node or self.k == 0:
            return
        self.inorder(node.left)
        # Once the result is found, there is no need to continue
        if self.k == 0:
            return
        self.k -= 1
        if self.k == 0:
            self.res = node.val
            return
        self.inorder(node.right)
Method 3 - Iterative Inorder Traversal 🏆
Intuition
To avoid the awkwardness of stopping a recursion and to gain more direct control, we can simulate the inorder traversal iteratively using an explicit stack. This is often the preferred solution: it returns as soon as the k-th node is popped and does not depend on the depth of the call stack.
Approach
Initialize an empty stack and a current pointer to the root.
Begin a loop that continues as long as current is not null or the stack is not empty.
Inside the loop, have a nested while loop that goes as far left as possible from the current node, pushing each node onto the stack along the way.
When the inner loop finishes, pop a node from the stack. This is the next node in the inorder sequence.
Process the popped node: decrement k.
If k is now 0, you’ve found the k-th smallest element. Return its value.
If not, set the current pointer to the right child of the popped node and repeat the process.
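The sketch below follows the same steps but also records the stack contents (as values) just before each pop, which shows how the stack holds the path of pending ancestors. It uses Example 2 with k = 3; `kth_smallest_traced` is a hypothetical name and `TreeNode` a minimal stand-in.

```python
from typing import List, Optional, Tuple


class TreeNode:
    # Minimal stand-in for LeetCode's TreeNode (assumption).
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def kth_smallest_traced(root: Optional[TreeNode], k: int) -> Tuple[int, List[List[int]]]:
    stack: List[TreeNode] = []
    snapshots: List[List[int]] = []
    curr = root
    while curr or stack:
        while curr:                   # slide as far left as possible
            stack.append(curr)
            curr = curr.left
        snapshots.append([n.val for n in stack])  # stack right before the pop
        curr = stack.pop()            # next node in inorder order
        k -= 1
        if k == 0:
            return curr.val, snapshots
        curr = curr.right             # continue with the right subtree
    return -1, snapshots              # k larger than the number of nodes


# Example 2: [5,3,6,2,4,null,null,1]
root = TreeNode(5, TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)), TreeNode(6))
answer, snapshots = kth_smallest_traced(root, 3)
print(answer)     # 3
print(snapshots)  # [[5, 3, 2, 1], [5, 3, 2], [5, 3]]
```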
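Code
Java
import java.util.ArrayDeque;
import java.util.Deque;

// TreeNode is the standard LeetCode definition (val, left, right).
class Solution {
    public int kthSmallest(TreeNode root, int k) {
        // ArrayDeque is the recommended stack implementation over the legacy Stack class
        Deque<TreeNode> stack = new ArrayDeque<>();
        TreeNode curr = root;
        while (curr != null || !stack.isEmpty()) {
            // Go as far left as possible, remembering the path
            while (curr != null) {
                stack.push(curr);
                curr = curr.left;
            }
            // Next node in inorder (ascending) order
            curr = stack.pop();
            k--;
            if (k == 0) {
                return curr.val;
            }
            curr = curr.right;
        }
        return -1; // Not reached if 1 <= k <= number of nodes
    }
}
Python
from typing import Optional

# TreeNode is the standard LeetCode definition (val, left, right).
class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        stack = []
        curr = root
        while curr or stack:
            # Go as far left as possible, remembering the path
            while curr:
                stack.append(curr)
                curr = curr.left
            # Next node in inorder (ascending) order
            curr = stack.pop()
            k -= 1
            if k == 0:
                return curr.val
            curr = curr.right
        return -1  # Not reached if 1 <= k <= number of nodes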
Complexity
⏰ Time complexity: O(h + k), where h is the height of the tree. It takes O(h) time to travel down to the smallest element; then we visit k elements to find the answer.
🧺 Space complexity: O(h) for the explicit stack: O(log n) for a balanced tree and O(n) for an unbalanced (skewed) tree.
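The O(h) stack bound can be checked by tracking the largest stack size the iterative traversal reaches. The sketch below runs a full traversal (k = n = 15) on three shapes holding the values 1 to 15: a left-skewed chain, a balanced tree, and a right-skewed chain. `peak_stack`, `left_chain`, `right_chain` and `balanced_bst` are hypothetical helpers and `TreeNode` is a minimal stand-in.

```python
from typing import List, Optional


class TreeNode:
    # Minimal stand-in for LeetCode's TreeNode (assumption).
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def left_chain(values: List[int]) -> Optional[TreeNode]:
    # values ascending; every node's left child is the next smaller value.
    root = None
    for v in values:
        root = TreeNode(v, left=root)
    return root


def right_chain(values: List[int]) -> Optional[TreeNode]:
    # values ascending; every node's right child is the next larger value.
    root = None
    for v in reversed(values):
        root = TreeNode(v, right=root)
    return root


def balanced_bst(values: List[int]) -> Optional[TreeNode]:
    # values ascending; the middle element becomes the root.
    if not values:
        return None
    mid = len(values) // 2
    return TreeNode(values[mid], balanced_bst(values[:mid]), balanced_bst(values[mid + 1:]))


def peak_stack(root: Optional[TreeNode], k: int) -> int:
    # Same loop as Method 3, returning the largest stack size seen.
    stack, curr, peak = [], root, 0
    while curr or stack:
        while curr:
            stack.append(curr)
            peak = max(peak, len(stack))
            curr = curr.left
        curr = stack.pop()
        k -= 1
        if k == 0:
            return peak
        curr = curr.right
    return peak


values = list(range(1, 16))
print(peak_stack(left_chain(values), 15))    # 15
print(peak_stack(balanced_bst(values), 15))  # 4
print(peak_stack(right_chain(values), 15))   # 1
```

The O(n) worst case comes from left-heavy shapes, where the whole path to the minimum is pushed at once; a right-skewed chain never holds more than one node on the stack.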