Recover Binary Search Tree Problem

Problem

You are given the root of a binary search tree (BST), where the values of exactly two nodes of the tree were swapped by mistake. Recover the tree without changing its structure.

Note: A solution using O(n) space is pretty straightforward. Could you devise a constant-space solution?

Example

Example 1:

Input:
     *1
    /
  *3
    \
     2

Output:

     *3
    /
  *1
    \
     2
Input: root = [1,3,null,null,2]
Output: [3,1,null,null,2]
Explanation: 3 cannot be a left child of 1 because 3 > 1. Swapping 1 and 3 makes the BST valid.

Example 2:

Input:
    *3
   /  \
  1    4 
      /
     *2 
    
Output:
    *2
   /  \
  1    4 
      /
     3 
Input: root = [3,1,4,null,null,2]
Output: [2,1,4,null,null,3]
Explanation: 2 cannot be in the right subtree of 3 because 2 < 3. Swapping 2 and 3 makes the BST valid.

Solution

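An inorder traversal of a valid BST returns the values in increasing order. So wherever an element is smaller than the one visited just before it, at least one of the two is a swapped node: at the first such drop it is the previous element, and at the last drop it is the current one.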
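This property alone already gives the straightforward O(n)-space solution mentioned in the problem note. A minimal sketch, assuming a TreeNode with val, left and right fields (declared here as a nested stand-in, since the judge normally supplies it) and a hypothetical class name RecoverWithList: collect the nodes in inorder, sort their values, and write the sorted values back in the same order. Sorting makes this O(n log n) time, and the node list costs O(n) space.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class RecoverWithList {
    // Minimal stand-in for the judge's TreeNode (assumed fields: val, left, right).
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int val) { this.val = val; }
        TreeNode(int val, TreeNode left, TreeNode right) {
            this.val = val;
            this.left = left;
            this.right = right;
        }
    }

    // O(n)-space approach: record the nodes in inorder, sort their values,
    // then write the sorted values back into the same inorder positions.
    static void recoverTree(TreeNode root) {
        List<TreeNode> nodes = new ArrayList<>();
        collect(root, nodes);
        int[] vals = new int[nodes.size()];
        for (int i = 0; i < vals.length; i++) vals[i] = nodes.get(i).val;
        Arrays.sort(vals);
        for (int i = 0; i < vals.length; i++) nodes.get(i).val = vals[i];
    }

    // Recursive inorder walk: left subtree, node, right subtree.
    private static void collect(TreeNode node, List<TreeNode> out) {
        if (node == null) return;
        collect(node.left, out);
        out.add(node);
        collect(node.right, out);
    }
}
```

On Example 2, the nodes are visited as 1, 3, 2, 4; after sorting, the root receives 2 and the leaf under 4 receives 3, matching the expected output.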

Consider the following invalid BST, in which 8 and 20 have been swapped:

         10
        /  \
       5    (8)
      / \
     2   (20)

We know that an inorder traversal of a valid BST visits the nodes in increasing order. Now suppose two nodes, with values X and Y where X < Y, have been swapped. The larger value Y now occupies the earlier position in the inorder sequence, so it is greater than the element that follows it. Likewise, the smaller value X now occupies the later position, so it is less than the element that precedes it. Let's look at the example above.

The inorder traversal is 2, 5, 20, 10, 8.

  • Looking closely, 20 > 10, so 20 (the element before the first drop) is the first misplaced node.
  • Further on, 10 > 8, so 8 (the element after the second drop) is the second misplaced node.

So we compare the previous and current nodes as we go through the inorder traversal. In the example above, when prev is 20 and curr is 10, … so prev is the first node. Similarly, when prev is 10 and curr is 8, prev > curr again, so curr is the second node. In short:

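  • the first time prev > curr => first = prev
  • every time prev > curr => second = curr (the last inversion wins)

If the two swapped nodes are adjacent in the inorder sequence, there is only one inversion, and both first and second come from that single pair. In Example 2 the inorder traversal is 1, 3, 2, 4, so the only inversion, 3 > 2, gives first = 3 and second = 2.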
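The rule can be tried on a plain array of inorder values before applying it to tree nodes. In this sketch, the hypothetical helper findSwapped returns the indices of first and second, or null when the sequence is already sorted.

```java
import java.util.Arrays;

class InversionScan {
    // Applies the prev/curr rule to an inorder sequence and returns
    // {indexOfFirst, indexOfSecond}, or null if the sequence is already sorted.
    static int[] findSwapped(int[] inorder) {
        int first = -1, second = -1;
        for (int i = 1; i < inorder.length; i++) {
            if (inorder[i - 1] > inorder[i]) {   // prev > curr: an inversion
                if (first == -1) first = i - 1;  // first inversion: first = prev
                second = i;                      // every inversion: second = curr
            }
        }
        return first == -1 ? null : new int[] {first, second};
    }

    public static void main(String[] args) {
        // Two inversions (20 > 10 and 10 > 8): values 20 and 8.
        System.out.println(Arrays.toString(findSwapped(new int[] {2, 5, 20, 10, 8}))); // prints [2, 4]
        // Adjacent swap, one inversion (3 > 2): values 3 and 2.
        System.out.println(Arrays.toString(findSwapped(new int[] {1, 3, 2, 4}))); // prints [1, 2]
    }
}
```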

Code

Java
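public class Solution {
	// TreeNode is the standard definition supplied by the judge (fields val, left, right).
	TreeNode first;  // prev at the first inversion (the larger misplaced value)
	TreeNode second; // curr at the last inversion (the smaller misplaced value)
	TreeNode prev;   // node visited just before curr in the inorder traversal

	public void recoverTree(TreeNode root) {
		if (root == null) {
			return;
		}

		// Reset state in case this Solution instance is reused.
		first = second = prev = null;
		inorder(root);

		// Swap the two misplaced values back; the tree structure is unchanged.
		if (second != null && first != null) {
			int val = second.val;
			second.val = first.val;
			first.val = val;
		}

	}

	// Recursive inorder traversal; uses O(h) stack space for a tree of height h.
	public void inorder(TreeNode curr) {
		if (curr == null) {
			return;
		}

		inorder(curr.left);

		if (prev == null) {
			// First node in inorder: nothing to compare against yet.
			prev = curr;
		} else {
			if (curr.val < prev.val) {
				// Inversion found: prev > curr.
				if (first == null) {
					first = prev; // only the first inversion sets first
				}

				second = curr; // every inversion updates second
			}

			prev = curr;
		}

		inorder(curr.right);
	}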
}
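The recursive solution above still uses O(h) stack space, where h is the tree height, so it does not quite meet the constant-space follow-up. One standard way to get O(1) extra space is Morris inorder traversal, a technique not used in the original solution. The sketch below applies the same prev/curr rule inside a Morris traversal; the class name MorrisRecover and the nested TreeNode stand-in are assumptions for illustration.

```java
class MorrisRecover {
    // Minimal stand-in for the judge's TreeNode (assumed fields: val, left, right).
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int val) { this.val = val; }
        TreeNode(int val, TreeNode left, TreeNode right) {
            this.val = val;
            this.left = left;
            this.right = right;
        }
    }

    // Morris inorder traversal: temporary right-pointer "threads" replace the
    // recursion stack and are removed again, so the tree ends up unchanged
    // apart from the swapped values.
    static void recoverTree(TreeNode root) {
        TreeNode first = null, second = null, prev = null;
        TreeNode curr = root;
        while (curr != null) {
            if (curr.left == null) {
                // No left subtree: visit curr, then move right (possibly along a thread).
                if (prev != null && prev.val > curr.val) {
                    if (first == null) first = prev;
                    second = curr;
                }
                prev = curr;
                curr = curr.right;
            } else {
                // Inorder predecessor: rightmost node of the left subtree.
                TreeNode pred = curr.left;
                while (pred.right != null && pred.right != curr) pred = pred.right;
                if (pred.right == null) {
                    pred.right = curr; // create a thread back to curr, then go left
                    curr = curr.left;
                } else {
                    pred.right = null; // left subtree finished: remove thread, visit curr
                    if (prev != null && prev.val > curr.val) {
                        if (first == null) first = prev;
                        second = curr;
                    }
                    prev = curr;
                    curr = curr.right;
                }
            }
        }
        // Swap the two misplaced values back.
        if (first != null && second != null) {
            int tmp = first.val;
            first.val = second.val;
            second.val = tmp;
        }
    }
}
```

Each edge is followed a constant number of times, so the traversal stays O(n) time while using only a few pointer variables.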