Problem
Given two integers m and n where n >= m, find the total number of structurally unique Binary Search Trees (BSTs) that contain every number from m to n (inclusive).
Examples
Example 1:
Input: m = 3, n = 5
Output: 5
Explanation: The five trees are shown below.
```
3             5   3         5     4
 \           /     \       /     / \
  4         4       5     3     3   5
   \       /       /       \
    5     3       4         4
```
Example 2:
Input: m = 1, n = 3
Output: 5
Explanation: There are five structurally unique BSTs possible with numbers 1 to 3.
Example 3:
Input: m = 2, n = 3
Output: 2
Explanation: There are two structurally unique BSTs possible with numbers 2 and 3.
Solution
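This problem reduces to Unique Binary Search Trees BST 1 - Count them all with N = n - m + 1. The number of BSTs depends only on how many distinct keys there are, not on their values, so relabeling m, ..., n as 1, ..., N leaves the count unchanged. The methods below work directly with that key count.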
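To see why only the number of keys matters, here is a small brute-force sketch. The helper names `build_trees` and `shape` are hypothetical and not from the original. It builds every BST over a key range as nested `(left, key, right)` tuples, then strips the keys and compares the remaining shapes.

```python
def build_trees(lo: int, hi: int) -> list:
    """Return every BST whose keys are exactly lo..hi, as nested (left, key, right) tuples."""
    if lo > hi:
        return [None]  # one way to build an empty subtree
    trees = []
    for root in range(lo, hi + 1):
        # Keys below root must go left, keys above root must go right.
        for left in build_trees(lo, root - 1):
            for right in build_trees(root + 1, hi):
                trees.append((left, root, right))
    return trees


def shape(tree):
    """Drop the keys and keep only the structure of a tree."""
    if tree is None:
        return None
    left, _, right = tree
    return (shape(left), shape(right))


print(len(build_trees(3, 5)))  # 5
print(len(build_trees(2, 3)))  # 2

# Relabeling 3..5 as 1..3 yields exactly the same set of shapes.
shapes_3_5 = {shape(t) for t in build_trees(3, 5)}
shapes_1_3 = {shape(t) for t in build_trees(1, 3)}
print(shapes_3_5 == shapes_1_3, len(shapes_3_5))  # True 5
```

Enumerating the trees is exponential and is only meant to illustrate the reduction. The methods below count the trees without building them.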
Method 1 - Mathematics
The number of BSTs on N = n - m + 1 distinct keys is the N-th Catalan number, given by:
$$
C_N = \frac{(2N)!}{(N+1)! \cdot N!}
$$
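As a worked check against the examples: the ranges [3, 5] and [1, 3] both contain 3 keys, and [2, 3] contains 2 keys.
$$
C_3 = \frac{6!}{4! \cdot 3!} = \frac{720}{24 \cdot 6} = 5, \qquad C_2 = \frac{4!}{3! \cdot 2!} = \frac{24}{6 \cdot 2} = 2
$$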
Code
Java
```java
import java.math.BigInteger;

class Solution {
    public int numTrees(int m, int n) {
        int numKeys = n - m + 1;
        return catalanNumber(numKeys);
    }

    // C_n = (2n)! / ((n + 1)! * n!), computed exactly with BigInteger.
    private int catalanNumber(int n) {
        BigInteger numerator = factorial(2 * n);
        BigInteger denominator = factorial(n + 1).multiply(factorial(n));
        // intValue() keeps only the low 32 bits, so the result is exact only while C_n fits in an int.
        return numerator.divide(denominator).intValue();
    }

    private BigInteger factorial(int num) {
        BigInteger result = BigInteger.ONE;
        for (int i = 2; i <= num; i++) {
            result = result.multiply(BigInteger.valueOf(i));
        }
        return result;
    }
}
```
Python
```python
import math


class Solution:
    def numTrees(self, m: int, n: int) -> int:
        num_keys = n - m + 1
        return self.catalan_number(num_keys)

    def catalan_number(self, n: int) -> int:
        # C_n = (2n)! / ((n + 1)! * n!); Python ints are arbitrary precision.
        numerator = math.factorial(2 * n)
        denominator = math.factorial(n + 1) * math.factorial(n)
        return numerator // denominator
```
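Complexity
- ⏰ Time complexity: O(N) multiplications to build the factorials, where N = n - m + 1. Each multiplication operates on a big integer, so the actual cost grows with the size of the operands.
- 🧺 Space complexity: O(1), ignoring the space required for large integers.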
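The factorial form builds (2N)! in full before dividing, which is far larger than the answer. Below is a sketch of two equivalent formulations, assuming Python 3.8+ for `math.comb`. The first is the binomial form C_k = binom(2k, k) / (k + 1). The second is the standard multiplicative recurrence C_{k+1} = C_k · 2(2k + 1) / (k + 2), which needs only O(1) running values. The function names are illustrative.

```python
import math


def catalan_binomial(k: int) -> int:
    # C_k = binom(2k, k) / (k + 1); the division is exact.
    return math.comb(2 * k, k) // (k + 1)


def catalan_iterative(k: int) -> int:
    # C_0 = 1 and C_{i+1} = C_i * 2(2i + 1) / (i + 2).
    # Multiplying before dividing keeps every step an exact integer division.
    c = 1
    for i in range(k):
        c = c * 2 * (2 * i + 1) // (i + 2)
    return c


def num_trees(m: int, n: int) -> int:
    return catalan_iterative(n - m + 1)


print([catalan_binomial(k) for k in range(8)])  # [1, 1, 2, 5, 14, 42, 132, 429]
print(num_trees(3, 5), num_trees(2, 3))         # 5 2
```

The iterative version keeps the intermediate values close to the size of the answer, which matters if the same idea is ported to a fixed-width integer type.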
Method 2 - Recursive
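The count can also be computed recursively. With N = n - m + 1 keys, try each key in turn as the root. If the root is the r-th smallest key, the r - 1 smaller keys must form the left subtree and the N - r larger keys must form the right subtree. Every left subtree can be paired with every right subtree, so that root contributes (number of left subtrees) × (number of right subtrees). The answer is the sum of these contributions over all roots.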
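Written as a recurrence, with G(N) denoting the number of BSTs on N keys (the name G is introduced here for illustration):
$$
G(0) = G(1) = 1, \qquad G(N) = \sum_{r=1}^{N} G(r-1)\, G(N-r)
$$
For example, since G(2) = G(0)G(1) + G(1)G(0) = 2:
$$
G(3) = G(0)\,G(2) + G(1)\,G(1) + G(2)\,G(0) = 2 + 1 + 2 = 5
$$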
Code
Java
```java
class Solution {
    public int numTrees(int m, int n) {
        int numKeys = n - m + 1;
        return countTrees(numKeys);
    }

    // Number of BSTs on numKeys distinct keys; only the number of keys matters.
    private int countTrees(int numKeys) {
        if (numKeys <= 1) {
            return 1; // empty tree or single node
        }
        int sum = 0;
        for (int root = 1; root <= numKeys; root++) {
            int left = countTrees(root - 1);        // keys smaller than root
            int right = countTrees(numKeys - root); // keys larger than root
            sum += left * right;
        }
        return sum;
    }

    public static void main(String[] args) {
        Solution sol = new Solution();
        System.out.println(sol.numTrees(1, 3)); // Output: 5
        System.out.println(sol.numTrees(2, 3)); // Output: 2
    }
}
```
Python
```python
class Solution:
    def numTrees(self, m: int, n: int) -> int:
        num_keys = n - m + 1
        return self.count_trees(num_keys)

    def count_trees(self, n: int) -> int:
        if n <= 1:
            return 1
        total = 0
        for root in range(1, n + 1):
            left = self.count_trees(root - 1)
            right = self.count_trees(n - root)
            total += left * right
        return total
```
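Complexity
- ⏰ Time complexity: exponential, Θ(3^N) where N = n - m + 1. Without memoization, the call for N keys recomputes every smaller size twice, so the number of calls T(N) = 1 + 2·(T(0) + ... + T(N-1)) triples with each additional key.
- 🧺 Space complexity: O(N) for the recursion stack.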
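To see where the exponential cost comes from, the sketch below mirrors `count_trees` but returns how many calls it makes instead of the tree count. The name `count_calls` is hypothetical. For every N >= 2 the total comes out to 5 · 3^(N-2).

```python
def count_calls(n: int) -> int:
    """Number of invocations the naive count_trees(n) performs, including itself."""
    if n <= 1:
        return 1
    calls = 1  # this call
    for root in range(1, n + 1):
        # Same two recursive calls count_trees makes for this root.
        calls += count_calls(root - 1) + count_calls(n - root)
    return calls


print([count_calls(n) for n in range(1, 8)])  # [1, 5, 15, 45, 135, 405, 1215]
```

Caching results per size (memoization) collapses this to one computation per size, which is what the dynamic-programming method does bottom-up.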
Method 3 - Dynamic Programming
Dynamic programming removes the repeated work. Let count[i] be the number of unique BSTs on i keys. Start from the base cases count[0] = count[1] = 1, then fill the table for i = 2, ..., N using the same sum over roots as the recursion. Every smaller size is already in the table when it is needed.
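For example, with N = 4 the table fills in as follows. Each term is count[left size] · count[right size] for one choice of root.
$$
\begin{aligned}
\text{count}[2] &= 1 \cdot 1 + 1 \cdot 1 = 2 \\
\text{count}[3] &= 1 \cdot 2 + 1 \cdot 1 + 2 \cdot 1 = 5 \\
\text{count}[4] &= 1 \cdot 5 + 1 \cdot 2 + 2 \cdot 1 + 5 \cdot 1 = 14
\end{aligned}
$$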
Code
Java
```java
class Solution {
    public int numTrees(int m, int n) {
        int numKeys = n - m + 1;
        int[] count = new int[numKeys + 1];

        // Base cases
        count[0] = 1; // Empty tree
        count[1] = 1; // One node tree

        // Fill the DP table
        for (int nodes = 2; nodes <= numKeys; nodes++) {
            for (int root = 1; root <= nodes; root++) {
                int left = root - 1;
                int right = nodes - root;
                count[nodes] += count[left] * count[right];
            }
        }
        return count[numKeys];
    }
}
```
Python
```python
class Solution:
    def numTrees(self, m: int, n: int) -> int:
        num_keys = n - m + 1
        count = [0] * (num_keys + 1)

        # Base cases
        count[0] = 1  # Empty tree
        count[1] = 1  # One node tree

        # Fill the DP table
        for nodes in range(2, num_keys + 1):
            for root in range(1, nodes + 1):
                left = root - 1
                right = nodes - root
                count[nodes] += count[left] * count[right]
        return count[num_keys]
```
Complexity
- ⏰ Time complexity: O(N^2), where N = n - m + 1, due to the nested loop over table sizes and roots.
- 🧺 Space complexity: O(N) for the storage array.
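All three Java versions return an `int`, and Python integers are unbounded, so Python is a convenient place to check where those versions stop being exact. The quick sketch below uses hypothetical helper names. Beyond the key count it reports, the recursive and DP Java versions would overflow, and the BigInteger version would silently truncate in `intValue()`. Under the same check, a `long` return type would extend the exact range further.

```python
def dp_counts(limit: int) -> list:
    """count[i] = number of BSTs on i keys, for i = 0..limit (exact in Python)."""
    count = [0] * (limit + 1)
    count[0] = 1
    for nodes in range(1, limit + 1):
        for root in range(1, nodes + 1):
            count[nodes] += count[root - 1] * count[nodes - root]
    return count


def largest_fitting(bits: int, limit: int = 64) -> int:
    """Largest key count whose answer fits in a signed integer of the given width."""
    max_value = 2 ** (bits - 1) - 1
    counts = dp_counts(limit)
    return max(i for i, c in enumerate(counts) if c <= max_value)


print(dp_counts(5))         # [1, 1, 2, 5, 14, 42]
print(largest_fitting(32))  # 19
print(largest_fitting(64))  # 35
```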