Problem

Given two integers m and n with n >= m, find the number of structurally unique Binary Search Trees (BSTs) whose keys are exactly the integers from m to n (inclusive).

Examples

Example 1:

Input: m = 3, n = 5
Output: 5
Explanation: See the trees below.
3                5     3            5          4 
 \              /       \          /         /   \ 
  4            4         5        3        3       5 
   \          /         /          \
    5        3         4            4

Example 2:

Input: m = 1, n = 3
Output: 5
Explanation: There are five structurally unique BSTs possible with numbers 1 to 3.

Example 3:

Input: m = 2, n = 3
Output: 2
Explanation: There are two structurally unique BSTs possible with numbers 2 and 3.

Solution

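We can reduce this problem to Unique Binary Search Trees BST 1 - Count them all by setting N = n - m + 1. A BST's shape depends only on the relative order of its keys, so any N consecutive integers produce exactly as many trees as the keys 1 to N. Every method below works with N alone.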
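To see why only the number of keys matters, a brute-force sketch can build every BST for a range and compare the shapes. The `Node`, `build_all`, and `shape` names are hypothetical helpers for this illustration only; the enumeration is exponential, so it is meant for tiny ranges.

```python
class Node:
    """Minimal BST node, a hypothetical helper for this illustration only."""

    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_all(lo, hi):
    """Return every BST whose keys are exactly lo..hi (inclusive)."""
    if lo > hi:
        return [None]  # exactly one way to build an empty subtree
    trees = []
    for root in range(lo, hi + 1):
        # Keys smaller than root go left, larger keys go right.
        for left in build_all(lo, root - 1):
            for right in build_all(root + 1, hi):
                trees.append(Node(root, left, right))
    return trees


def shape(node):
    """Encode a tree's structure while ignoring its key values."""
    if node is None:
        return "."
    return "(" + shape(node.left) + shape(node.right) + ")"


print(len(build_all(3, 5)))  # 5
print(len(build_all(1, 3)))  # 5
# Same set of shapes for [3, 5] and [1, 3]: only the key count matters.
print({shape(t) for t in build_all(3, 5)} == {shape(t) for t in build_all(1, 3)})  # True
```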

Method 1 - Mathematics

The number of structurally unique BSTs on N = n - m + 1 keys is the N-th Catalan number: $$ C_N = \frac{(2N)!}{(N+1)! \cdot N!} $$
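As a sanity check against the examples: m = 3, n = 5 and m = 1, n = 3 both give N = 3 keys, while m = 2, n = 3 gives N = 2:

$$ C_3 = \frac{6!}{4! \cdot 3!} = \frac{720}{24 \cdot 6} = 5, \qquad C_2 = \frac{4!}{3! \cdot 2!} = \frac{24}{6 \cdot 2} = 2 $$

These match the expected outputs 5, 5, and 2.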

Code

Java
import java.math.BigInteger;

class Solution {
    public int numTrees(int m, int n) {
        int numKeys = n - m + 1;
        return catalanNumber(numKeys);
    }

    private int catalanNumber(int n) {
        BigInteger numerator = factorial(2 * n);
        BigInteger denominator = factorial(n + 1).multiply(factorial(n));
        // intValue() silently truncates results larger than Integer.MAX_VALUE.
        return numerator.divide(denominator).intValue();
    }

    private BigInteger factorial(int num) {
        BigInteger result = BigInteger.ONE;
        for (int i = 2; i <= num; i++) {
            result = result.multiply(BigInteger.valueOf(i));
        }
        return result;
    }
}
Python
import math


class Solution:
    def numTrees(self, m: int, n: int) -> int:
        num_keys = n - m + 1
        return self.catalan_number(num_keys)

    def catalan_number(self, n: int) -> int:
        numerator = math.factorial(2 * n)
        denominator = math.factorial(n + 1) * math.factorial(n)
        return numerator // denominator

Complexity

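  • ⏰ Time complexity: O(N) big-integer multiplications to build the factorials, where N = n - m + 1; each multiplication gets slower as the numbers grow.
  • 🧺 Space complexity: O(1) extra space, ignoring the storage for the large intermediate integers.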
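If the three factorials feel heavy, here is a sketch of two equivalent ways to get the same Catalan number in Python: the binomial form C_N = binom(2N, N) / (N + 1) using `math.comb` (Python 3.8+), and the multiplicative recurrence C_{k+1} = C_k · 2(2k + 1) / (k + 2), which takes N steps and keeps intermediates close to the size of the answer. The function names `catalan_comb` and `catalan_iterative` are illustrative, not part of the original solution.

```python
import math


def catalan_comb(n: int) -> int:
    # C_n = (2n choose n) / (n + 1); math.comb requires Python 3.8+.
    return math.comb(2 * n, n) // (n + 1)


def catalan_iterative(n: int) -> int:
    # C_{k+1} = C_k * 2(2k + 1) / (k + 2), starting from C_0 = 1.
    c = 1
    for k in range(n):
        # Multiply before dividing so the integer division is exact.
        c = c * 2 * (2 * k + 1) // (k + 2)
    return c


print([catalan_comb(k) for k in range(6)])       # [1, 1, 2, 5, 14, 42]
print([catalan_iterative(k) for k in range(6)])  # [1, 1, 2, 5, 14, 42]
```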

Method 2 - Recursive

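The number of structurally unique BSTs can also be computed recursively. Since only the number of keys matters, the code works with N = n - m + 1 keys labeled 1 to N. Choosing the i-th smallest key as the root leaves i - 1 smaller keys for the left subtree and N - i larger keys for the right subtree. The two subtrees are independent, so their counts multiply, and summing over every choice of root gives the total.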
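Writing G(N) for the number of BSTs on N keys (notation introduced here, not in the code), the root-splitting argument is the recurrence that the code below evaluates:

$$ G(0) = G(1) = 1, \qquad G(N) = \sum_{i=1}^{N} G(i-1)\, G(N-i) $$

This is the standard Catalan recurrence, so G(N) = C_N.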

Code

Java
class Solution {
    public int numTrees(int m, int n) {
        int numKeys = n - m + 1;
        return countTrees(numKeys);
    }

    private int countTrees(int numKeys) {
        if (numKeys <= 1) {
            return 1;
        } else {
            int sum = 0;
            for (int root = 1; root <= numKeys; root++) {
                int left = countTrees(root - 1);
                int right = countTrees(numKeys - root);
                sum += left * right;
            }
            return sum;
        }
    }

    public static void main(String[] args) {
        Solution sol = new Solution();
        System.out.println(sol.numTrees(1, 3)); // Output: 5
        System.out.println(sol.numTrees(2, 3)); // Output: 2
    }
}
Python
class Solution:
    def numTrees(self, m: int, n: int) -> int:
        num_keys = n - m + 1
        return self.count_trees(num_keys)

    def count_trees(self, n: int) -> int:
        if n <= 1:
            return 1
        total = 0
        for root in range(1, n + 1):
            left = self.count_trees(root - 1)
            right = self.count_trees(n - root)
            total += left * right
        return total

Complexity

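  • ⏰ Time complexity: O(3^N), exponential because the same subproblem sizes are recomputed many times. The number of calls T(N) satisfies T(N) = 1 + 2(T(0) + ... + T(N-1)), which gives T(N) = 3·T(N-1) for N >= 3.
  • 🧺 Space complexity: O(N) for the recursion stack.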
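Most of that cost comes from recomputing the same subtree sizes. A sketch of the same recursion with memoization via Python's `functools.lru_cache` brings it down to O(N^2) time and O(N) space, essentially a top-down form of Method 3. Moving `count_trees` to module level is an illustrative restructuring, not the original code.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def count_trees(num_keys: int) -> int:
    # Same recurrence as the plain recursive version, but each
    # subtree size is computed once and then served from the cache.
    if num_keys <= 1:
        return 1
    total = 0
    for root in range(1, num_keys + 1):
        total += count_trees(root - 1) * count_trees(num_keys - root)
    return total


class Solution:
    def numTrees(self, m: int, n: int) -> int:
        return count_trees(n - m + 1)


print(Solution().numTrees(3, 5))  # 5
print(Solution().numTrees(2, 3))  # 2
```

Each of the N + 1 cached subproblems does O(N) work, hence O(N^2) overall. Very deep recursion could still hit Python's default recursion limit, which the bottom-up version avoids.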

Method 3 - Dynamic Programming

Using DP, we can build the answer bottom-up instead. count[k] stores the number of unique BSTs on k keys; starting from the base cases count[0] = count[1] = 1, each larger entry is filled from the smaller, already-computed entries with the same root-splitting sum, up to count[N].
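For N = 3 keys (the first two examples), the table fills in like this:

$$ \text{count}[2] = \text{count}[0]\,\text{count}[1] + \text{count}[1]\,\text{count}[0] = 1 + 1 = 2 $$
$$ \text{count}[3] = \text{count}[0]\,\text{count}[2] + \text{count}[1]\,\text{count}[1] + \text{count}[2]\,\text{count}[0] = 2 + 1 + 2 = 5 $$

So count[3] = 5, matching the expected output.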

Code

Java
class Solution {
    public int numTrees(int m, int n) {
        int numKeys = n - m + 1;
        int[] count = new int[numKeys + 1];
        
        // Base cases
        count[0] = 1; // Empty tree
        count[1] = 1; // One node tree
        
        // Fill the DP table
        for (int nodes = 2; nodes <= numKeys; nodes++) {
            for (int root = 1; root <= nodes; root++) {
                int left = root - 1;
                int right = nodes - root;
                count[nodes] += count[left] * count[right];
            }
        }
        
        return count[numKeys];
    }
}
Python
class Solution:
    def numTrees(self, m: int, n: int) -> int:
        num_keys = n - m + 1
        count = [0] * (num_keys + 1)

        # Base cases
        count[0] = 1  # Empty tree
        count[1] = 1  # One node tree

        # Fill the DP table
        for nodes in range(2, num_keys + 1):
            for root in range(1, nodes + 1):
                left = root - 1
                right = nodes - root
                count[nodes] += count[left] * count[right]

        return count[num_keys]

Complexity

  • ⏰ Time complexity: O(N^2), where N = n - m + 1, due to the nested loops.
  • 🧺 Space complexity: O(N) for the count array.
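One caveat for all three Java versions: they return an `int`, so results are exact only while the Catalan number fits in a 32-bit signed integer. Beyond that, Method 1's `intValue()` silently truncates and the recursive and DP versions overflow. A quick Python check (Python integers do not overflow) finds the limit; `largest_safe_num_keys` is a hypothetical helper for illustration.

```python
import math

JAVA_INT_MAX = 2**31 - 1  # Java's Integer.MAX_VALUE


def largest_safe_num_keys(limit: int = JAVA_INT_MAX) -> int:
    """Largest N whose Catalan number C_N is <= limit (hypothetical helper)."""
    num_keys = 0
    # Catalan numbers never decrease, so stop at the first one above the limit.
    while math.comb(2 * (num_keys + 1), num_keys + 1) // (num_keys + 2) <= limit:
        num_keys += 1
    return num_keys


print(largest_safe_num_keys())  # 19
```

So the Java code is exact only for n - m + 1 <= 19; larger ranges need a wider return type (`long` raises the ceiling, `BigInteger` removes it).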