Problem
Given an integer n, return the number of structurally unique **BSTs** (binary search trees) that have exactly n nodes with unique values from 1 to n.
Examples
Example 1:
Input: n = 3
Output: 5
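Explanation: with values 1, 2 and 3, there are 2 + 1 + 2 = 5 trees:
- Root 1: two trees. The right subtree holds 2 and 3 in one of two shapes.
- Root 2: one tree, with 1 on the left and 3 on the right.
- Root 3: two trees. The left subtree holds 1 and 2 in one of two shapes.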
Example 2:
Input: n = 1
Output: 1
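You can check both examples by brute force. The small sketch below is not one of the three methods that follow. It actually builds every BST on the values 1..n and prints each tree's preorder sequence; for distinct keys, a preorder sequence identifies a BST uniquely. The nested-tuple tree representation and the helper names all_bsts and preorder are assumptions made for this illustration. Because the number of trees grows exponentially, this is only practical for small n.

```python
from typing import List, Optional

# Illustrative representation (assumption): a tree is (value, left, right); None is the empty tree.
Tree = Optional[tuple]


def all_bsts(lo: int, hi: int) -> List[Tree]:
    """Build every BST whose node values are exactly lo..hi (brute force)."""
    if lo > hi:
        return [None]  # exactly one way to hold no values: the empty tree
    trees: List[Tree] = []
    for root in range(lo, hi + 1):
        # Smaller values go left, larger values go right; combine every pair.
        for left in all_bsts(lo, root - 1):
            for right in all_bsts(root + 1, hi):
                trees.append((root, left, right))
    return trees


def preorder(tree: Tree) -> List[int]:
    """Visit the root, then the left subtree, then the right subtree."""
    if tree is None:
        return []
    value, left, right = tree
    return [value] + preorder(left) + preorder(right)


for t in all_bsts(1, 3):
    print(preorder(t))
```

For n = 3 this prints five sequences, one per tree, matching Example 1. all_bsts(1, 1) returns a single tree, matching Example 2.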
Solution
Let’s solve the problem of finding the number of structurally unique Binary Search Trees (BSTs) for a given number n using three different approaches: mathematical (Catalan numbers), recursion, and dynamic programming (DP).
Method 1 - Mathematics
The number of structurally unique Binary Search Trees (BSTs) with n distinct keys is the n-th Catalan number, so we can compute it directly from the closed-form formula.
The formula for the n-th Catalan number is:
$$ C_n = \frac{1}{n+1}\binom{2n}{n} = \frac{(2n)!}{(n+1)!n!} $$
For example, for n = 3:
$$ C_3 = \frac{6!}{4! \cdot 3!} = \frac{720}{24 \cdot 6} = \frac{720}{144} = 5 $$
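The closed form can be checked numerically with a short Python sketch. It assumes Python 3.8+ for math.comb. The function names catalan and catalan_factorial are illustrative, not part of the original solution. The two forms, binom(2n, n) / (n + 1) and (2n)! / ((n + 1)! n!), should agree for every n.

```python
import math


def catalan(n: int) -> int:
    """n-th Catalan number as binom(2n, n) / (n + 1); needs Python 3.8+."""
    # binom(2n, n) is always divisible by n + 1, so integer division is exact.
    return math.comb(2 * n, n) // (n + 1)


def catalan_factorial(n: int) -> int:
    """The same value written as (2n)! / ((n + 1)! * n!)."""
    return math.factorial(2 * n) // (math.factorial(n + 1) * math.factorial(n))


print([catalan(n) for n in range(8)])  # [1, 1, 2, 5, 14, 42, 132, 429]
print(catalan(3) == catalan_factorial(3) == 5)  # True
```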
Code
Java
import java.math.BigInteger;

class Solution {
    public int numTrees(int n) {
        // C_n = (2n)! / ((n + 1)! * n!); BigInteger keeps the factorials from overflowing
        BigInteger numerator = factorial(2 * n);
        BigInteger denominator = factorial(n + 1).multiply(factorial(n));
        return numerator.divide(denominator).intValue();
    }

    private BigInteger factorial(int num) {
        BigInteger result = BigInteger.ONE;
        for (int i = 2; i <= num; i++) {
            result = result.multiply(BigInteger.valueOf(i));
        }
        return result;
    }
}
Python
import math

class Solution:
    def numTrees(self, n: int) -> int:
        # C_n = (2n)! / ((n + 1)! * n!); Python integers are arbitrary precision
        numerator = math.factorial(2 * n)
        denominator = math.factorial(n + 1) * math.factorial(n)
        return numerator // denominator
Complexity
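- ⏰ Time complexity: O(n) multiplications for the factorials. This counts each big-integer multiplication as one step; in practice its cost grows with the size of the operands.
- 🧺 Space complexity: O(1), ignoring the space required for the large integers.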
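You can also avoid computing (2n)! explicitly. The Catalan numbers satisfy the multiplicative recurrence C_{k+1} = C_k · 2(2k + 1) / (k + 2), starting from C_0 = 1. This is a swapped-in variant, not the method above. The sketch below (the function name num_trees_iterative is hypothetical) performs n multiplications and exact integer divisions. Its intermediate values stay far smaller than (2n)!.

```python
def num_trees_iterative(n: int) -> int:
    """Catalan number C_n via C_{k+1} = C_k * 2(2k + 1) / (k + 2)."""
    c = 1  # C_0: one (empty) tree
    for k in range(n):
        # c * 2 * (2k + 1) equals (k + 2) * C_{k+1}, so the floor division is exact.
        c = c * 2 * (2 * k + 1) // (k + 2)
    return c


print(num_trees_iterative(3), num_trees_iterative(4))  # 5 14
```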
Method 2 - Recursion
The number of structurally unique BSTs can also be computed recursively. Treat each value in turn as the root, recursively count the possible left and right subtrees, and multiply the two counts. Let’s build up the cases.
- N = 0 ⇨ count = 1, because the only option is the empty (null) tree.
- N = 1 ⇨ count = 1, because there is only one single-node tree.
- N = 2 ⇨ count = 2, as shown below:

Tree 1    Tree 2
  1          2
   \        /
    2      1

- N = 3 ⇨ count = 5.
  - When 1 is the root, the remaining nodes 2 and 3 both go in the right subtree.
  - When 2 is the root, there is one node on each side: 1 on the left and 3 on the right.
  - When 3 is the root, both remaining nodes (1 and 2) go in the left subtree.
  - Thus the total is N(0) · N(2) + N(1) · N(1) + N(2) · N(0) = 2 + 1 + 2 = 5.
- N = 4 ⇨ count = 14. We consider 1, 2, 3 and 4 as the root in turn: N(0) · N(3) + N(1) · N(2) + N(2) · N(1) + N(3) · N(0) = 5 + 2 + 2 + 5 = 14.
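Generalizing these cases gives a recurrence. Write G(n) for the count with n nodes (G is notation introduced here for the derivation). Choosing value i as the root leaves i − 1 smaller values for the left subtree and n − i larger values for the right subtree. Every left subtree can be combined with every right subtree:

$$ G(0) = 1, \qquad G(n) = \sum_{i=1}^{n} G(i-1)\,G(n-i) \quad \text{for } n \ge 1 $$

This is the standard recurrence for the Catalan numbers, which is why G(n) equals the Catalan number C_n used in Method 1.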
The recursive function below takes the number of distinct values and returns the number of structurally unique binary search trees that store those values. For example, numTrees(4) returns 14, since there are 14 structurally unique binary search trees that store 1, 2, 3 and 4. The base case (n ≤ 1 returns 1) is simple, and the recursive step sums left count × right count over every choice of root. The code never constructs actual trees; it only counts them.
Code
Java
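class Solution {
    public int numTrees(int n) {
        // Base case: an empty tree or a single node gives exactly one shape
        if (n <= 1) {
            return 1;
        }
        int sum = 0;
        for (int root = 1; root <= n; root++) {
            int left = numTrees(root - 1);   // shapes for the values smaller than root
            int right = numTrees(n - root);  // shapes for the values larger than root
            sum += left * right;
        }
        return sum;
    }
}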
Python
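class Solution:
    def numTrees(self, n: int) -> int:
        # Base case: an empty tree or a single node gives exactly one shape
        if n <= 1:
            return 1
        ans = 0
        for root in range(1, n + 1):
            left = self.numTrees(root - 1)   # shapes for the values smaller than root
            right = self.numTrees(n - root)  # shapes for the values larger than root
            ans += left * right
        return ans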
Complexity
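- ⏰ Time complexity: O(3^n). For n ≥ 2, each call makes 2n recursive calls, so the total number of calls satisfies T(n) = 1 + 2(T(0) + T(1) + … + T(n − 1)). This gives T(n) = 3 · T(n − 1) for n ≥ 3, because the same subproblems are recomputed many times.
- 🧺 Space complexity: O(n) for the recursion stack.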
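To observe the growth rate of the plain recursion directly, the sketch below instruments a copy of it with a call counter. The names count_calls and num_trees are hypothetical, and the recursion mirrors the Python solution above. From n = 3 on, each call count is three times the previous one.

```python
def count_calls(n: int) -> int:
    """Count how many times the plain (unmemoized) recursion is invoked for n."""
    calls = 0

    def num_trees(k: int) -> int:
        nonlocal calls
        calls += 1
        if k <= 1:
            return 1
        total = 0
        for root in range(1, k + 1):
            total += num_trees(root - 1) * num_trees(k - root)
        return total

    num_trees(n)
    return calls


print([count_calls(n) for n in range(7)])  # [1, 1, 5, 15, 45, 135, 405]
```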
Method 3 - DP
The recursion recomputes the same subproblems over and over; for example, the count for 2 nodes is recomputed inside every larger call. With dynamic programming, we compute each count once, from smaller sizes up, and store it. Let count[i] be the number of unique BSTs with i nodes. count[i] is the sum, over every choice of root, of the number of left subtrees times the number of right subtrees:
i = 0, count[0] = 1                      # empty tree
i = 1, count[1] = 1                      # one tree
i = 2, count[2] = count[0] * count[1]    # 1 is root
                + count[1] * count[0]    # 2 is root
i = 3, count[3] = count[0] * count[2]    # 1 is root
                + count[1] * count[1]    # 2 is root
                + count[2] * count[0]    # 3 is root
i = 4, count[4] = count[0] * count[3]    # 1 is root
                + count[1] * count[2]    # 2 is root
                + count[2] * count[1]    # 3 is root
                + count[3] * count[0]    # 4 is root
...
i = n, count[n] = sum(count[k] * count[n-1-k]) for 0 <= k < n    # k + 1 is root
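The same recurrence can also be evaluated top-down by caching the recursive function from Method 2, so each count is computed only once. This memoized sketch uses functools.lru_cache from the Python standard library; the function name num_trees_memo is hypothetical. It does the same O(n^2) work as a bottom-up table but keeps the recursive shape.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def num_trees_memo(n: int) -> int:
    """Number of structurally unique BSTs with n nodes, memoized top-down."""
    if n <= 1:
        return 1  # empty tree or single node
    # Each value 1..n takes a turn as the root; multiply left and right counts.
    return sum(num_trees_memo(root - 1) * num_trees_memo(n - root)
               for root in range(1, n + 1))


print(num_trees_memo(4))  # 14
```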
Code
Java
class Solution {
    public int numTrees(int n) {
        int[] count = new int[n + 1];
        count[0] = 1; // Base case: empty tree
        // Starting at 1 avoids writing count[1] when n == 0;
        // the loop itself produces count[1] = count[0] * count[0] = 1.
        for (int nodes = 1; nodes <= n; nodes++) {
            for (int root = 1; root <= nodes; root++) {
                int left = root - 1;      // values smaller than root
                int right = nodes - root; // values larger than root
                count[nodes] += count[left] * count[right];
            }
        }
        return count[n];
    }
}
Python
class Solution:
    def numTrees(self, n: int) -> int:
        count = [0] * (n + 1)
        count[0] = 1  # Base case: empty tree
        # Starting at 1 avoids an IndexError on count[1] when n == 0;
        # the loop itself produces count[1] = count[0] * count[0] = 1.
        for nodes in range(1, n + 1):
            for root in range(1, nodes + 1):
                left = root - 1       # values smaller than root
                right = nodes - root  # values larger than root
                count[nodes] += count[left] * count[right]
        return count[n]
Complexity
- ⏰ Time complexity: O(n^2) due to the nested loops.
- 🧺 Space complexity: O(n) for the count array.
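The Java versions return int, so they are only correct while the count fits in a signed 32-bit integer. The sketch below (the function name largest_n_fitting_int32 is hypothetical) reuses the DP recurrence with Python's unbounded integers to find the largest such n. All terms in the Java loop are non-negative, so every partial sum is at most the final count. As a result, no intermediate value overflows as long as the final count fits.

```python
INT32_MAX = 2**31 - 1  # largest value a Java int can hold


def largest_n_fitting_int32() -> int:
    """Largest n whose BST count still fits in a signed 32-bit int."""
    count = [1]  # count[0] = 1: the empty tree
    while True:
        nodes = len(count)  # next size to compute
        nxt = sum(count[root - 1] * count[nodes - root]
                  for root in range(1, nodes + 1))
        if nxt > INT32_MAX:
            return nodes - 1  # the previous size was the last one that fit
        count.append(nxt)


print(largest_n_fitting_int32())  # 19
```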