Optimal binary search tree construction for minimum search cost
Problem
The Optimal Binary Search Tree (OBST) problem asks for a binary search tree (BST) that minimises the weighted search cost of accessing its nodes. Given a set of node values and their corresponding access frequencies (or probabilities), the goal is to find the BST arrangement with the minimum possible weighted search cost. Each node contributes its frequency multiplied by its depth in the tree, with the root at depth 1, and the total cost is the sum of these contributions over all nodes.
Examples
Example 1
Input:
Node values: [10, 20]
Access frequencies: [50, 100]
Output:
Optimal BST cost: 200
Explanation:
Possible trees:
Tree 1:
  20
 /
10
Cost = 100*1 + 50*2 = 200
Tree 2:
10
 \
  20
Cost = 50*1 + 100*2 = 250
Optimal cost is achieved with Tree 1.
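The cost definition can be checked mechanically. The sketch below is illustrative only: it assumes a tree is written as a nested tuple `(frequency, left, right)` with `None` for a missing child, and the helper name `weighted_cost` is hypothetical rather than part of the solution that follows.

```python
def weighted_cost(tree, depth=1):
    """Sum frequency * depth over a tree given as (frequency, left, right) or None."""
    if tree is None:
        return 0
    frequency, left, right = tree
    return (frequency * depth
            + weighted_cost(left, depth + 1)
            + weighted_cost(right, depth + 1))


# Keys 10 and 20 with frequencies 50 and 100.
root_20 = (100, (50, None, None), None)   # 20 at the root, 10 as its left child
root_10 = (50, None, (100, None, None))   # 10 at the root, 20 as its right child

print(weighted_cost(root_20))  # 200
print(weighted_cost(root_10))  # 250
```

The root is counted at depth 1, so every access pays for at least one comparison.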
Example 2
Input:
Node values: [10, 20, 30]
Access frequencies: [100, 150, 200]
Output:
Optimal BST cost: 750
Explanation:
Using dynamic programming, the optimal BST cost is calculated as 750. It is achieved by placing 20 at the root with 10 and 30 as its children: 150*1 + 100*2 + 200*2 = 750.
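For three keys it is feasible to try every tree shape directly from the frequency-times-depth definition. The following sketch does that exhaustively and reports the cheapest tree for each choice of root. The function names `best_cost` and `cost_with_root` are assumptions made for illustration, and the approach is exponential, so it is only meant as a sanity check on tiny inputs.

```python
def best_cost(frequencies, depth=1):
    """Cheapest weighted cost over all BST shapes, with keys given in sorted order."""
    if not frequencies:
        return 0
    return min(
        frequencies[k] * depth
        + best_cost(frequencies[:k], depth + 1)
        + best_cost(frequencies[k + 1:], depth + 1)
        for k in range(len(frequencies))
    )


def cost_with_root(frequencies, k):
    """Cheapest cost when the key at index k is forced to be the root."""
    return (frequencies[k]
            + best_cost(frequencies[:k], depth=2)
            + best_cost(frequencies[k + 1:], depth=2))


frequencies = [100, 150, 200]
for k, key in enumerate([10, 20, 30]):
    print(key, cost_with_root(frequencies, k))
# 10 950
# 20 750
# 30 800
```

Rooting the tree at 20 is cheapest here because it keeps both remaining keys at depth 2.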
Method 1 - Using DP
To minimise the search cost, nodes with higher frequencies should ideally be closer to the root, subject to the BST ordering of the keys. The problem is solved using dynamic programming because it has overlapping subproblems (the cost of the same key range is needed by many larger ranges) and optimal substructure (the left and right subtrees of an optimal tree are themselves optimal for their key ranges). The DP table keeps track of the minimum cost for each contiguous range of keys.
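To make the overlapping subproblems concrete, here is a minimal top-down sketch that caches each range with `functools.lru_cache`. The names `optimal_bst_cost_memo` and `solve` are hypothetical. The total frequency of the range is added at every level because hanging a subtree under a root pushes each of its keys one level deeper.

```python
from functools import lru_cache


def optimal_bst_cost_memo(frequencies):
    """Top-down variant: frequencies must be listed in sorted key order."""
    frequencies = tuple(frequencies)

    @lru_cache(maxsize=None)
    def solve(i, j):
        # An empty range contributes nothing.
        if i > j:
            return 0
        # Every key in i..j sits one level below the parent of this subtree.
        range_weight = sum(frequencies[i:j + 1])
        return range_weight + min(
            solve(i, k - 1) + solve(k + 1, j) for k in range(i, j + 1)
        )

    return solve(0, len(frequencies) - 1)


print(optimal_bst_cost_memo([50, 100]))        # 200
print(optimal_bst_cost_memo([100, 150, 200]))  # 750
```

Without the cache the same `(i, j)` ranges would be recomputed many times, which is the inefficiency the DP table removes.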
Approach
- Sort nodes by their values (if not already sorted), keeping each frequency paired with its key.
- Use the dynamic programming approach:
  - Define dp[i][j] as the minimum search cost for keys i...j.
  - Base case: For a single key i, dp[i][i] = frequency[i].
  - Recursive case: For multiple keys i...j, compute the cost for each possible root k, where i ≤ k ≤ j, and keep the minimum: dp[i][j] = min over k of (dp[i][k-1] + dp[k+1][j]) + weight(i, j), where an empty range costs 0.
  - Use prefix sums (or a precomputed table) to obtain weight(i, j), the sum of the frequencies of keys i...j.
- Read the optimal BST cost from dp[0][n-1].
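The prefix sum step can be sketched as follows. This is one possible way to get range weights in constant time with O(n) extra memory; the helper names `build_prefix` and `range_weight` are assumptions for illustration.

```python
from itertools import accumulate


def build_prefix(frequencies):
    """prefix[t] holds the sum of the first t frequencies, so prefix[0] == 0."""
    return [0] + list(accumulate(frequencies))


def range_weight(prefix, i, j):
    """Total frequency of keys i..j inclusive."""
    return prefix[j + 1] - prefix[i]


prefix = build_prefix([100, 150, 200])
print(prefix)                      # [0, 100, 250, 450]
print(range_weight(prefix, 0, 2))  # 450
print(range_weight(prefix, 1, 2))  # 350
```

A full n x n weight table gives the same values; the prefix array simply avoids the quadratic storage.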
Code
Java
class Solution {
    // Assumes keys are sorted and frequencies[i] belongs to keys[i].
    public int optimalBST(int[] keys, int[] frequencies) {
        int n = keys.length;
        if (n == 0) {
            return 0; // an empty tree has no search cost
        }
        int[][] dp = new int[n][n];
        int[][] weight = new int[n][n];

        // Fill weight matrix: weight[i][j] = sum of frequencies[i..j]
        for (int i = 0; i < n; i++) {
            weight[i][i] = frequencies[i];
            for (int j = i + 1; j < n; j++) {
                weight[i][j] = weight[i][j - 1] + frequencies[j];
            }
        }

        // Fill DP table
        for (int length = 1; length <= n; length++) { // length of subtree range
            for (int i = 0; i <= n - length; i++) {
                int j = i + length - 1;
                dp[i][j] = Integer.MAX_VALUE;

                // Find optimal root
                for (int k = i; k <= j; k++) {
                    int cost = (k > i ? dp[i][k - 1] : 0) +
                               (k < j ? dp[k + 1][j] : 0) +
                               weight[i][j];
                    dp[i][j] = Math.min(dp[i][j], cost);
                }
            }
        }
        return dp[0][n - 1];
    }

    // Example usage
    public static void main(String[] args) {
        Solution solution = new Solution();
        int[] keys = {10, 20, 30};
        int[] frequencies = {100, 150, 200};
        int optimalCost = solution.optimalBST(keys, frequencies);
        System.out.println("Optimal BST cost: " + optimalCost); // Optimal BST cost: 750
    }
}
Python
class Solution:
    def optimalBST(self, keys, frequencies):
        # Assumes keys are sorted and frequencies[i] belongs to keys[i].
        n = len(keys)
        if n == 0:
            return 0  # an empty tree has no search cost

        # Create cost and weight tables
        dp = [[0] * n for _ in range(n)]
        weight = [[0] * n for _ in range(n)]

        # Fill weights matrix: weight[i][j] = sum of frequencies[i..j]
        for i in range(n):
            weight[i][i] = frequencies[i]
            for j in range(i + 1, n):
                weight[i][j] = weight[i][j - 1] + frequencies[j]

        # Fill DP table
        for length in range(1, n + 1):  # length of subtree range
            for i in range(n - length + 1):
                j = i + length - 1
                dp[i][j] = float('inf')

                # Find optimal root
                for k in range(i, j + 1):
                    cost = (dp[i][k - 1] if k > i else 0) + \
                           (dp[k + 1][j] if k < j else 0) + \
                           weight[i][j]
                    dp[i][j] = min(dp[i][j], cost)

        return dp[0][n - 1]


# Example usage
solution = Solution()
keys = [10, 20, 30]
frequencies = [100, 150, 200]
optimal_cost = solution.optimalBST(keys, frequencies)
print("Optimal BST cost:", optimal_cost)  # Optimal BST cost: 750
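The solutions above return only the cost. If the arrangement itself is needed, one common extension is to record which root achieved the minimum for each range and rebuild the tree from that table. The sketch below is one possible way to do this, not part of the original solution: the function name `optimal_bst_with_roots`, the `root` table, and the nested-tuple output `(key, left, right)` are all assumptions chosen for illustration.

```python
def optimal_bst_with_roots(keys, frequencies):
    """Return (minimum cost, tree) where tree is (key, left, right) or None."""
    n = len(keys)
    if n == 0:
        return 0, None

    # prefix[t] = sum of the first t frequencies
    prefix = [0]
    for f in frequencies:
        prefix.append(prefix[-1] + f)

    dp = [[0] * n for _ in range(n)]
    root = [[0] * n for _ in range(n)]  # root[i][j] = index of the best root for keys i..j

    for length in range(1, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            best, best_k = None, i
            for k in range(i, j + 1):
                cost = (dp[i][k - 1] if k > i else 0) + (dp[k + 1][j] if k < j else 0)
                if best is None or cost < best:
                    best, best_k = cost, k
            dp[i][j] = best + prefix[j + 1] - prefix[i]
            root[i][j] = best_k

    def build(i, j):
        # Rebuild the subtree for keys i..j from the recorded roots.
        if i > j:
            return None
        k = root[i][j]
        return (keys[k], build(i, k - 1), build(k + 1, j))

    return dp[0][n - 1], build(0, n - 1)


cost, tree = optimal_bst_with_roots([10, 20, 30], [100, 150, 200])
print(cost)  # 750
print(tree)  # (20, (10, None, None), (30, None, None))
```

When several roots tie, this sketch keeps the smallest index; any tied root gives the same cost.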
Complexity
- ⏰ Time complexity: O(n^3), where n is the number of nodes (due to nested loops for ranges and roots).
- 🧺 Space complexity: O(n^2) due to the DP and weight tables.
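The cubic bound can be seen by counting how many candidate roots the loops try. The sketch below mirrors the loop structure of the DP without doing any cost arithmetic; the function name `count_root_trials` is hypothetical. The count equals n(n+1)(n+2)/6, which grows as n^3.

```python
def count_root_trials(n):
    """Count (i, j, k) combinations visited by the range and root loops."""
    trials = 0
    for length in range(1, n + 1):
        for i in range(n - length + 1):
            trials += length  # one trial per candidate root k in i..j
    return trials


print(count_root_trials(3))   # 10
print(count_root_trials(10))  # 220
```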