Problem

The Optimal Binary Search Tree (OBST) problem asks for a binary search tree (BST) over a given set of keys that minimises the total weighted search cost. Given the node values and their access frequencies (or probabilities), the goal is to find the BST arrangement with the minimum possible cost. Each node contributes its frequency multiplied by its depth in the tree (the root is at depth 1), and the cost of a tree is the sum of these contributions.
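To make the cost definition concrete, here is a minimal sketch that evaluates the weighted cost of a given tree. The nested-tuple tree representation, the function name `weighted_cost`, and the sample keys and frequencies are assumptions made for illustration; they are not part of the problem statement.

```python
def weighted_cost(tree, freq, depth=1):
    """Sum of freq[key] * depth over all nodes; the root sits at depth 1.

    `tree` is either None or a tuple (key, left_subtree, right_subtree).
    """
    if tree is None:
        return 0
    key, left, right = tree
    return (freq[key] * depth
            + weighted_cost(left, freq, depth + 1)
            + weighted_cost(right, freq, depth + 1))


# Hypothetical data: three keys with their access frequencies.
freq = {1: 4, 2: 2, 3: 6}

balanced = (2, (1, None, None), (3, None, None))  # 2 at the root
chain = (3, (1, None, (2, None, None)), None)     # 3 at the root, then 1, then 2

print(weighted_cost(balanced, freq))  # 2*1 + 4*2 + 6*2 = 22
print(weighted_cost(chain, freq))     # 6*1 + 4*2 + 2*3 = 20
```

In this sample the balanced tree is not the cheapest one, which is why the problem needs a search over tree shapes rather than a fixed layout.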

Examples

Example 1

Input:
Node values: [10, 20]
Access frequencies: [50, 100]

Output:
Optimal BST cost: 200

Explanation:
Possible trees (the root is at depth 1):
Tree 1:
  20
 /
10
Cost = 100*1 + 50*2 = 200

Tree 2:
 10
  \
   20
Cost = 50*1 + 100*2 = 250

Optimal cost is achieved with Tree 1.

Example 2

Input:
Node values: [10, 20, 30]
Access frequencies: [100, 150, 200]

Output:
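Optimal BST cost: 750

Explanation:
The optimal tree has 20 at the root with 10 and 30 as its children: Cost = 150*1 + 100*2 + 200*2 = 750. For comparison, the best tree rooted at 30 costs 800 and the best tree rooted at 10 costs 950, so placing the most frequent key at the root is not optimal here.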

Method 1 - Using DP

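To minimise the search cost, nodes with higher frequencies should generally sit closer to the root, but the BST ordering constraint means that simply picking the most frequent key as the root is not always optimal. Dynamic programming fits because the problem has optimal substructure (the left and right subtrees of an optimal tree are themselves optimal for their key ranges) and overlapping subproblems (the same key range i...j is needed under many different roots). The DP table stores the minimum cost for each contiguous range of keys.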
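The effect of overlapping subproblems can be seen by counting how often a plain recursion evaluates a key range compared with a memoised one. This is a rough sketch: the function name `solve_and_count` and the sample frequencies are hypothetical, and `functools.lru_cache` is used here only as a convenient memo table.

```python
from functools import lru_cache


def solve_and_count(frequencies, memoize):
    """Return (optimal cost, number of times the recursive body ran)."""
    calls = 0

    def best(i, j):
        nonlocal calls
        if i > j:                      # empty key range: nothing to search
            return 0
        calls += 1
        weight = sum(frequencies[i:j + 1])
        # Try every key in i..j as the root of this range.
        return weight + min(best(i, k - 1) + best(k + 1, j)
                            for k in range(i, j + 1))

    if memoize:
        # Rebinding the name makes the recursive calls go through the cache too.
        best = lru_cache(maxsize=None)(best)
    return best(0, len(frequencies) - 1), calls


freqs = [4, 2, 6, 3, 1, 5, 7, 2]  # hypothetical frequencies for 8 sorted keys
plain_cost, plain_calls = solve_and_count(freqs, memoize=False)
memo_cost, memo_calls = solve_and_count(freqs, memoize=True)
print(plain_cost == memo_cost)   # True
print(plain_calls, memo_calls)   # 2187 36
```

With 8 keys there are only 36 distinct non-empty ranges, yet the plain recursion evaluates ranges 2187 times; a table indexed by (i, j) removes that repetition.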

Approach

  1. Sort nodes by their values (if not already sorted).
  2. Use the dynamic programming approach:
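    • Define dp[i][j] as the minimum search cost of a BST built from keys i...j.
    • Base case: for a single key i, dp[i][i] = frequency[i]. An empty range costs 0.
    • Recursive case: for keys i...j, try each possible root k, where i ≤ k ≤ j, and take the minimum of dp[i][k-1] + dp[k+1][j] + weight(i, j), where weight(i, j) is the sum of the frequencies of keys i...j. The weight term covers the root's own access plus the extra level that every key in the two subtrees gains when they are attached below k.
    • Precompute the range sums weight(i, j), for example with prefix sums, so that each lookup takes O(1).
  3. The optimal BST cost is dp[0][n-1].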
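The range-sum step in the approach above can be sketched with a one-dimensional prefix-sum array, which answers weight(i, j) in O(1) using O(n) extra space. The helper name `make_range_weight` is an assumption for illustration only.

```python
from itertools import accumulate


def make_range_weight(frequencies):
    """Return a function weight(i, j) giving sum(frequencies[i..j]), inclusive."""
    # prefix[t] holds the sum of the first t frequencies, so prefix[0] == 0.
    prefix = [0] + list(accumulate(frequencies))

    def weight(i, j):
        return prefix[j + 1] - prefix[i]

    return weight


weight = make_range_weight([100, 150, 200])
print(weight(0, 2))  # 450
print(weight(1, 2))  # 350
print(weight(1, 1))  # 150
```

A full two-dimensional weight table gives the same values; the prefix-sum form simply stores less.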

Code

Java

class Solution {
    public int optimalBST(int[] keys, int[] frequencies) {
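        // keys are assumed sorted in ascending order; frequencies[i] belongs to keys[i]
        int n = keys.length;
        if (n == 0) {
            return 0;  // an empty tree has no search cost
        }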
        int[][] dp = new int[n][n];
        int[][] weight = new int[n][n];
        
        // Fill weight matrix
        for (int i = 0; i < n; i++) {
            weight[i][i] = frequencies[i];
            for (int j = i + 1; j < n; j++) {
                weight[i][j] = weight[i][j - 1] + frequencies[j];
            }
        }
        
        // Fill DP table
        for (int length = 1; length <= n; length++) {  // length of subtree range
            for (int i = 0; i <= n - length; i++) {
                int j = i + length - 1;
                dp[i][j] = Integer.MAX_VALUE;
                
                // Find optimal root
                for (int k = i; k <= j; k++) {
                    int cost = (k > i ? dp[i][k - 1] : 0) +
                            (k < j ? dp[k + 1][j] : 0) +
                            weight[i][j];
                    dp[i][j] = Math.min(dp[i][j], cost);
                }
            }
        }
        
        return dp[0][n - 1];
    }

    // Example usage
    public static void main(String[] args) {
        Solution solution = new Solution();
        int[] keys = {10, 20, 30};
        int[] frequencies = {100, 150, 200};
        int optimalCost = solution.optimalBST(keys, frequencies);
        System.out.println("Optimal BST cost: " + optimalCost);
    }
}

Python

class Solution:
    def optimalBST(self, keys, frequencies):
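        # keys are assumed sorted in ascending order; frequencies[i] belongs to keys[i]
        n = len(keys)
        if n == 0:
            return 0  # an empty tree has no search cost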
        # Create cost and weight tables
        dp = [[0] * n for _ in range(n)]
        weight = [[0] * n for _ in range(n)]
        
        # Fill weights matrix
        for i in range(n):
            weight[i][i] = frequencies[i]
            for j in range(i + 1, n):
                weight[i][j] = weight[i][j - 1] + frequencies[j]
        
        # Fill DP table
        for length in range(1, n + 1):  # length of subtree range
            for i in range(n - length + 1):
                j = i + length - 1
                dp[i][j] = float('inf')
                
                # Find optimal root
                for k in range(i, j + 1):
                    cost = (dp[i][k - 1] if k > i else 0) + \
                        (dp[k + 1][j] if k < j else 0) + \
                        weight[i][j]
                    dp[i][j] = min(dp[i][j], cost)

        return dp[0][n - 1]


# Example usage
solution = Solution()
keys = [10, 20, 30]
frequencies = [100, 150, 200]
optimal_cost = solution.optimalBST(keys, frequencies)
print("Optimal BST cost:", optimal_cost)
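The solutions above return only the minimum cost. If the tree itself is needed, one common extension is to record which root achieved the minimum for each range and rebuild the tree from that table. The following is a sketch under assumptions: the function name `optimal_bst_tree`, the `root` table, the nested-tuple output format, and the sample input are all hypothetical additions, and ties are broken in favour of the smallest root index.

```python
def optimal_bst_tree(keys, frequencies):
    """Return (cost, tree); tree is a nested (key, left, right) tuple or None."""
    n = len(keys)
    if n == 0:
        return 0, None

    # prefix[t] is the sum of the first t frequencies.
    prefix = [0]
    for f in frequencies:
        prefix.append(prefix[-1] + f)

    # dp[i][j]: best cost for keys[i..j]; root[i][j]: root index achieving it.
    dp = [[0] * n for _ in range(n)]
    root = [[0] * n for _ in range(n)]
    for length in range(1, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            best = None
            for k in range(i, j + 1):
                left = dp[i][k - 1] if k > i else 0
                right = dp[k + 1][j] if k < j else 0
                cost = left + right + prefix[j + 1] - prefix[i]
                if best is None or cost < best:
                    best, root[i][j] = cost, k
            dp[i][j] = best

    def build(i, j):
        """Rebuild the subtree for keys[i..j] from the recorded roots."""
        if i > j:
            return None
        k = root[i][j]
        return (keys[k], build(i, k - 1), build(k + 1, j))

    return dp[0][n - 1], build(0, n - 1)


cost, tree = optimal_bst_tree([1, 2, 3], [4, 2, 6])
print(cost)  # 20
print(tree)  # (3, (1, None, (2, None, None)), None)
```

The extra table costs another O(n^2) space and does not change the running time of the DP.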

Complexity

  • ⏰ Time complexity: O(n^3), where n is the number of nodes: there are O(n^2) key ranges i...j, and each one tries up to n candidate roots.
  • 🧺 Space complexity: O(n^2) for the DP table and the weight table.
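As a quick sanity check on the cubic bound, the sketch below counts how many candidate roots the three nested loops try in total. The function name `count_root_trials` is hypothetical; the loop bounds mirror the DP loops in the Code section.

```python
def count_root_trials(n):
    """Count how many (i, j, k) combinations the three nested DP loops visit."""
    trials = 0
    for length in range(1, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            trials += j - i + 1  # one trial per candidate root k in i..j
    return trials


for n in (10, 20, 40):
    print(n, count_root_trials(n))
# 10 220
# 20 1540
# 40 11480
```

The count equals n(n+1)(n+2)/6, so doubling n multiplies the work by a factor that approaches 8, as expected for O(n^3).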