Problem

Given the root of a binary search tree, and an integer k, return the kth smallest value (1-indexed) of all the values of the nodes in the tree.

Examples

Example 1:

        3
       / \
      1   4
       \
        2
Input: root = [3,1,4,null,2], k = 1
Output: 1

Example 2:

                  5
                /    \
               /      \
              3        6
             /  \     
            /    \    
           2     4       
          /
         /
        1
Input: root = [5,3,6,2,4,null,null,1], k = 3
Output: 3
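The examples use LeetCode-style level-order notation, where `null` marks a missing child. The sketch below rebuilds both example trees and checks that an inorder walk of each one comes out sorted. It assumes a `TreeNode` shaped like LeetCode's, and `build_tree` is a hypothetical helper.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    """Minimal node type, assumed to mirror LeetCode's definition."""
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: build a tree from level-order values (None = no child)."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next value is this node's left child, the one after it the right child.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def inorder_values(node: Optional[TreeNode]) -> List[int]:
    """Left subtree, then node, then right subtree: ascending order for a BST."""
    if node is None:
        return []
    return inorder_values(node.left) + [node.val] + inorder_values(node.right)


example1 = build_tree([3, 1, 4, None, 2])
example2 = build_tree([5, 3, 6, 2, 4, None, None, 1])
print(inorder_values(example1))  # prints [1, 2, 3, 4]
print(inorder_values(example2))  # prints [1, 2, 3, 4, 5, 6]
```

Reading off position k - 1 of these lists gives the expected answers: 1 for Example 1 (k = 1) and 3 for Example 2 (k = 3).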

Solution


Method 1 - Naive Inorder Traversal to List

Intuition

The most straightforward way to solve this is to materialize the sorted sequence. If we have a sorted list of all the node values, we can simply pick the element at the k-th position.

Approach

  1. Create an empty list.
  2. Perform a full inorder traversal of the BST.
  3. For each node visited, add its value to the list.
  4. After the traversal is complete, the list will contain all the node values in sorted order.
  5. Return the element at index k-1 from the list (since k is 1-indexed).

Code

import java.util.ArrayList;
import java.util.List;

class Solution {
    public int kthSmallest(TreeNode root, int k) {
        List<Integer> inorderList = new ArrayList<>();
        inorder(root, inorderList);
        return inorderList.get(k - 1);
    }

    private void inorder(TreeNode node, List<Integer> list) {
        if (node == null) {
            return;
        }
        inorder(node.left, list);
        list.add(node.val);
        inorder(node.right, list);
    }
}
from typing import Optional

class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        inorder_list = []
        def inorder(node):
            if not node:
                return
            inorder(node.left)
            inorder_list.append(node.val)
            inorder(node.right)
        
        inorder(root)
        return inorder_list[k-1]

Complexity

  • ⏰ Time complexity: O(n), where n is the number of nodes, as we traverse the tree once to build the list.
  • 🧺 Space complexity: O(n) as we store all n elements in the list.

Method 2a - Recursive Inorder Traversal with counter

We don’t actually need to store the entire list of nodes. We only need the k-th node visited during the inorder traversal, and we can stop as soon as we reach it. To do this, we count nodes as we visit them instead of adding them to a list.
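In Python, the shared counter that the Java code keeps in an `AtomicInteger` is usually handled with a closure and `nonlocal`. The sketch below takes that approach, and the helper's boolean return value is what stops the traversal early. The function name `kth_smallest_counter` is illustrative, and the `TreeNode` class is an assumed stand-in for LeetCode's.

```python
from typing import Optional


class TreeNode:
    """Minimal node type (assumed; LeetCode supplies its own)."""
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def kth_smallest_counter(root: Optional[TreeNode], k: int) -> Optional[int]:
    """Inorder traversal that counts visited nodes and stops at the k-th one."""
    count = 0
    answer: Optional[int] = None

    def visit(node: Optional[TreeNode]) -> bool:
        # Returns True once the k-th node has been found, so callers stop too.
        nonlocal count, answer
        if node is None:
            return False
        if visit(node.left):
            return True
        count += 1
        if count == k:
            answer = node.val
            return True
        return visit(node.right)

    visit(root)
    return answer  # None if k exceeds the number of nodes


# Example 2 from above: 5(3(2(1), 4), 6)
root = TreeNode(5, TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)), TreeNode(6))
print(kth_smallest_counter(root, 3))  # prints 3
```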

Code

import java.util.concurrent.atomic.AtomicInteger;

class Solution {
	public int kthSmallest(TreeNode root, int k) {
		// counter tracks how many nodes have been visited so far.
		// A local `int` can't be updated by the recursive helper (and `Integer`
		// is immutable), so we pass a mutable holder such as `AtomicInteger`;
		// any small wrapper class would work too.
		AtomicInteger counter = new AtomicInteger(0);
		TreeNode kthNode = helper(root, k, counter);
		return kthNode == null ? -1 : kthNode.val;
	}

	private TreeNode helper(TreeNode root, int k, AtomicInteger counter) {
		if (root == null) {
			return null;
		}

		TreeNode left = helper(root.left, k, counter);

		// if the k-th smallest node was found in the left subtree
		if (left != null) {
			return left;
		}

		// if the current node is the k-th smallest node
		if (counter.incrementAndGet() == k) {
			return root;
		}

		return helper(root.right, k, counter);
	}
}
from typing import Optional

class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        self.count = 0
        self.ans = -1

        def inorder(node):
            # Stop if the subtree is empty or the k-th node was already found
            if not node or self.count >= k:
                return

            inorder(node.left)

            # The answer may have been found in the left subtree
            if self.count >= k:
                return

            self.count += 1
            if self.count == k:
                self.ans = node.val
                return

            inorder(node.right)

        inorder(root)
        return self.ans

Complexity

  • ⏰ Time complexity: O(h + k). It takes O(h) time to travel down to the smallest element. Then, we visit k elements to find our answer.
  • 🧺 Space complexity: O(h) due to depth of recursion. For balanced tree it is O(log n) and for unbalanced/skewed tree it is O(n).

Method 2b - Recursive inorder traversal with countdown

Intuition

Instead of keeping a separate counter, we can treat k itself as a countdown. We decrement it each time the inorder traversal visits a node and stop as soon as it reaches 0.

Approach

  1. Create a recursive inorder helper function.
  2. In the function, first traverse the left subtree.
  3. After returning from the left subtree, process the current node. This is the “visit” step. Decrement k.
  4. Check if k has just become 0. If it has, we have found our target. Store this node’s value as the result and ensure the recursion stops.
  5. If k is still greater than 0, traverse the right subtree.
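The countdown can also be threaded through return values instead of shared fields, which keeps the helper free of mutable state. In this variant, each call returns how much of k is left after its subtree, together with the value once it has been found. It is a sketch only; the names `countdown` and `kth_smallest_countdown` are hypothetical, and `TreeNode` is an assumed stand-in for LeetCode's.

```python
from typing import Optional, Tuple


class TreeNode:
    """Minimal node type (assumed; LeetCode supplies its own)."""
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def countdown(node: Optional[TreeNode], k: int) -> Tuple[int, Optional[int]]:
    """Return (k remaining after this subtree, k-th value if it was found here)."""
    if node is None:
        return k, None
    k, found = countdown(node.left, k)
    if found is not None:
        return k, found              # answer came from the left subtree; stop
    k -= 1                           # the "visit" step for the current node
    if k == 0:
        return 0, node.val
    return countdown(node.right, k)  # right subtree only if still searching


def kth_smallest_countdown(root: Optional[TreeNode], k: int) -> Optional[int]:
    return countdown(root, k)[1]


# Example 1 from above: [3,1,4,null,2]
root = TreeNode(3, TreeNode(1, None, TreeNode(2)), TreeNode(4))
print(kth_smallest_countdown(root, 1))  # prints 1
```

Because the "found" signal travels back up through the return value, no call ever descends into a right subtree after the answer is known.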

Code

class Solution {
    private int result;
    private int k;

    public int kthSmallest(TreeNode root, int k) {
        this.k = k;
        inorder(root);
        return result;
    }

    private void inorder(TreeNode node) {
        if (node == null) {
            return;
        }
        
        inorder(node.left);
        
        // Once result is found, we don't need to continue
        if (this.k == 0) return;

        k--;
        if (k == 0) {
            result = node.val;
            return;
        }
        
        inorder(node.right);
    }
}
from typing import Optional

class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        self.k = k
        self.res = -1
        self.inorder(root)
        return self.res

    def inorder(self, node):
        if not node or self.k == 0:
            return

        self.inorder(node.left)

        if self.k == 0:
            return
        
        self.k -= 1
        if self.k == 0:
            self.res = node.val
            return

        self.inorder(node.right)

Method 3 - Iterative Inorder Traversal 🏆

Intuition

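To avoid the awkwardness of stopping a recursion early and to gain more direct control, we can simulate the inorder traversal iteratively with an explicit stack. Finding the answer then becomes a plain return from the loop, and the traversal depth is no longer limited by the call stack. This is the most common and generally preferred optimal solution.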
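One concrete reason to prefer the explicit stack in Python is that recursion depth is capped by the interpreter (CPython's default limit is 1000 frames), and a skewed BST can be far deeper than that. The sketch below compares a recursive countdown with the iterative version on a hypothetical left-skewed chain of 5000 nodes. The function names are illustrative, and `TreeNode` is an assumed stand-in for LeetCode's.

```python
from typing import Optional


class TreeNode:
    """Minimal node type (assumed; LeetCode supplies its own)."""
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def kth_recursive(root: Optional[TreeNode], k: int) -> Optional[int]:
    """Recursive inorder countdown: one Python frame per level on the path."""
    def visit(node: Optional[TreeNode]) -> Optional[int]:
        nonlocal k
        if node is None:
            return None
        found = visit(node.left)
        if found is not None:
            return found
        k -= 1
        if k == 0:
            return node.val
        return visit(node.right)
    return visit(root)


def kth_iterative(root: Optional[TreeNode], k: int) -> Optional[int]:
    """Iterative inorder: the path is kept in a list, not on the call stack."""
    stack = []
    curr = root
    while curr or stack:
        while curr:
            stack.append(curr)
            curr = curr.left
        curr = stack.pop()
        k -= 1
        if k == 0:
            return curr.val
        curr = curr.right
    return None


# Hypothetical worst case: a left-skewed chain 5000 -> 4999 -> ... -> 1
depth = 5000
root = None
for v in range(1, depth + 1):
    root = TreeNode(v, left=root)

try:
    print("recursive:", kth_recursive(root, 1))
except RecursionError:
    print("recursive: RecursionError")
print("iterative:", kth_iterative(root, 1))  # prints iterative: 1
```

Under CPython's default limit, the recursive call is expected to raise `RecursionError`, while the iterative version only grows a Python list.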

Approach

  1. Initialize an empty stack and a current pointer to the root.
  2. Begin a loop that continues as long as current is not null or the stack is not empty.
  3. Inside the loop, have a nested while loop that goes as far left as possible from the current node, pushing each node onto the stack along the way.
  4. When the inner loop finishes, pop a node from the stack. This is the next node in the inorder sequence.
  5. Process the popped node: decrement k.
  6. If k is now 0, you’ve found the k-th smallest element. Return its value.
  7. If not, set the current pointer to the right child of the popped node and repeat the process.

Code

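import java.util.ArrayDeque;
import java.util.Deque;

class Solution {
    public int kthSmallest(TreeNode root, int k) {
        // ArrayDeque is preferred over the legacy Stack class for LIFO use
        Deque<TreeNode> stack = new ArrayDeque<>();
        TreeNode curr = root;

        while (curr != null || !stack.isEmpty()) {
            // Go as far left as possible, remembering the path
            while (curr != null) {
                stack.push(curr);
                curr = curr.left;
            }
            // The popped node is the next one in inorder (sorted) order
            curr = stack.pop();
            k--;
            if (k == 0) {
                return curr.val;
            }
            curr = curr.right;
        }
        return -1; // Not reached if 1 <= k <= number of nodes
    }
}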
from typing import Optional

class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        stack = []
        curr = root

        while curr or stack:
            while curr:
                stack.append(curr)
                curr = curr.left
            
            curr = stack.pop()
            k -= 1
            if k == 0:
                return curr.val
            
            curr = curr.right

Complexity

  • ⏰ Time complexity: O(h + k). It takes O(h) time to travel down to the smallest element. Then, we visit k elements to find our answer.
  • 🧺 Space complexity: O(h) for using explicit stack. For balanced tree it is O(log n) and for unbalanced/skewed tree it is O(n).
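To see the O(h + k) bound concretely, the sketch below instruments the iterative traversal with a push counter. The counter and the name `kth_smallest_with_stats` are added purely for illustration, and `TreeNode` is an assumed stand-in for LeetCode's. On Example 2, which has six nodes, finding k = 1 or k = 3 pushes only the four nodes on the leftmost path.

```python
from typing import Optional, Tuple


class TreeNode:
    """Minimal node type (assumed; LeetCode supplies its own)."""
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def kth_smallest_with_stats(root: Optional[TreeNode], k: int) -> Tuple[int, int]:
    """Iterative inorder traversal that also reports how many nodes were pushed."""
    stack = []
    curr = root
    pushes = 0
    while curr or stack:
        while curr:
            stack.append(curr)
            pushes += 1
            curr = curr.left
        curr = stack.pop()
        k -= 1
        if k == 0:
            return curr.val, pushes
        curr = curr.right
    raise ValueError("k is larger than the number of nodes")


# Example 2 from above: 5(3(2(1), 4), 6)
root = TreeNode(5, TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)), TreeNode(6))
print(kth_smallest_with_stats(root, 1))  # prints (1, 4)
print(kth_smallest_with_stats(root, 3))  # prints (3, 4)
```

Only when k grows large enough to require entering right subtrees (k = 5 or 6 here) does the traversal touch additional nodes.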