Problem

Given an array arr of positive integers, consider all binary trees such that:

  • Each node has either 0 or 2 children;
  • The values of arr correspond to the values of each leaf in an in-order traversal of the tree.
  • The value of each non-leaf node is equal to the product of the largest leaf value in its left and right subtree, respectively.

Among all possible binary trees considered, return the smallest possible sum of the values of each non-leaf node. It is guaranteed this sum fits into a 32-bit integer.

A node is a leaf if and only if it has zero children.

Examples

Example 1:

      24              24
     /  \            /  \
   12    4          6    8
  /  \                  / \
 6    2                2   4

Input: arr = [6,2,4]
Output: 32
Explanation: There are two possible trees shown.
The first has a non-leaf node sum 36, and the second has non-leaf node sum 32.
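
To make the arithmetic in Example 1 concrete, here is a small sketch that evaluates a given tree. The nested-tuple representation and the helper name `non_leaf_sum` are illustrative assumptions and are not part of the problem statement.

```python
from typing import Tuple, Union

# Hypothetical representation: a leaf is an int, an internal node is a (left, right) tuple.
Tree = Union[int, Tuple["Tree", "Tree"]]


def non_leaf_sum(tree: Tree) -> Tuple[int, int]:
    """Return (largest leaf value, sum of non-leaf node values) for `tree`."""
    if isinstance(tree, int):
        return tree, 0  # a leaf contributes nothing to the sum
    left_max, left_sum = non_leaf_sum(tree[0])
    right_max, right_sum = non_leaf_sum(tree[1])
    node_value = left_max * right_max  # value of this non-leaf node
    return max(left_max, right_max), left_sum + right_sum + node_value


# The two trees for arr = [6, 2, 4]
first = ((6, 2), 4)   # 6*2 + 6*4 = 12 + 24
second = (6, (2, 4))  # 2*4 + 6*4 = 8 + 24
print(non_leaf_sum(first)[1], non_leaf_sum(second)[1])  # 36 32
```

Both trees share the root value 24; they differ only in which pair is multiplied at the lower level, which is what the solutions below try to optimize.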

Example 2:

       44
      /   \
     4     11
Input: arr = [4,11]
Output: 44

Solution

Method 1 - Recursion

  1. The goal is to optimally convert the array into a binary tree.
  2. The optimal way to divide the array into left and right subtrees is unknown.
  3. Notice that dividing the array creates two subarrays, resulting in two subproblems.
  4. Therefore, we try every possible split point, recursively solve both subarrays, and keep the split with the smallest total.

Code

Java
public class Solution {

    public int mctFromLeafValues(int[] arr) {
        int n = arr.length;
        return dfs(arr, 0, n - 1);
    }

    private int dfs(int[] arr, int s, int e) {
        if (s == e) return 0;

        int ans = Integer.MAX_VALUE;
        for (int i = s; i < e; i++) {
            int left = dfs(arr, s, i);
            int right = dfs(arr, i + 1, e);
            int maxLeft = max(arr, s, i);
            int maxRight = max(arr, i + 1, e);
            int rootVal = maxLeft * maxRight;
            ans = Math.min(ans, left + right + rootVal);
        }

        return ans;
    }

    private int max(int[] arr, int l, int r) {
        int max = 0;
        for (int i = l; i <= r; i++) {
            max = Math.max(max, arr[i]);
        }
        return max;
    }
}
Python
from typing import List


class Solution:
    def mctFromLeafValues(self, arr: List[int]) -> int:
        return self.dfs(arr, 0, len(arr) - 1)

    def dfs(self, arr: List[int], s: int, e: int) -> int:
        # A single leaf has no non-leaf nodes, so it costs nothing.
        if s == e:
            return 0

        ans = float('inf')
        # Try every split: arr[s..i] is the left subtree, arr[i+1..e] the right.
        for i in range(s, e):
            left = self.dfs(arr, s, i)
            right = self.dfs(arr, i + 1, e)
            max_left = self.max_in_range(arr, s, i)
            max_right = self.max_in_range(arr, i + 1, e)
            root_val = max_left * max_right
            ans = min(ans, left + right + root_val)

        return ans

    def max_in_range(self, arr: List[int], l: int, r: int) -> int:
        return max(arr[l:r + 1])

Complexity

  • ⏰ Time complexity: exponential, on the order of 3^n, where n is the length of the array. Without memoization the same subarray is solved many times: the number of calls satisfies T(n) = 1 + 2 * (T(1) + ... + T(n - 1)), which gives T(n) = 3^(n - 1).
  • 🧺 Space complexity: O(n) for the recursion stack, where n is the length of the array.
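
As a rough illustration of why the plain recursion blows up, the sketch below counts how many times `dfs` would be invoked on an array of n leaves and compares that with the number of distinct (s, e) ranges. The function names are hypothetical, and the count assumes the un-memoized recursion described in this method.

```python
from functools import lru_cache


def count_plain_calls(n: int) -> int:
    """Number of dfs invocations the un-memoized recursion makes on n leaves."""
    @lru_cache(maxsize=None)  # caching here only speeds up the counting itself
    def calls(length: int) -> int:
        if length == 1:
            return 1
        # one call for this range, plus the calls made for every (left, right) split
        return 1 + sum(calls(k) + calls(length - k) for k in range(1, length))
    return calls(n)


def count_distinct_ranges(n: int) -> int:
    """Number of distinct (s, e) ranges with s <= e."""
    return n * (n + 1) // 2


for n in (5, 10, 15):
    print(n, count_plain_calls(n), count_distinct_ranges(n))
```

The gap between the two columns is the repeated work that caching results per range would avoid.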

Method 2 - Top Down DP with memoization

We will use dynamic programming to solve the problem. The recurrence is:

dp[i][j] = min over i <= k < j of ( dp[i][k] + dp[k + 1][j] + max(A[i..k]) * max(A[k + 1..j]) ), with dp[i][i] = 0

Here dp[i][j] is the smallest possible sum of non-leaf node values for the subarray from i to j. It is found by splitting the subarray at every k with i <= k < j, combining the best results for the left part A[i..k] and the right part A[k + 1..j], and adding the value of the new root, which is the product of the two parts' maxima.

Code

Java
public class Solution {
    public int mctFromLeafValues(int[] arr) {
        int n = arr.length;
        int[][] dp = new int[n][n];
        return dfs(arr, 0, n - 1, dp);
    }

    private int dfs(int[] arr, int s, int e, int[][] dp) {
        if (s == e) return 0;
        if (dp[s][e] > 0) return dp[s][e];

        int ans = Integer.MAX_VALUE;
        for (int i = s; i < e; i++) {
            int left = dfs(arr, s, i, dp);
            int right = dfs(arr, i + 1, e, dp);
            int maxLeft = max(arr, s, i);
            int maxRight = max(arr, i + 1, e);
            int rootVal = maxLeft * maxRight;
            ans = Math.min(ans, left + right + rootVal);
        }
        dp[s][e] = ans;
        return ans;
    }

    private int max(int[] arr, int l, int r) {
        int max = 0;
        for (int i = l; i <= r; i++) {
            max = Math.max(max, arr[i]);
        }
        return max;
    }
}
Python
from typing import List


class Solution:
    def mctFromLeafValues(self, arr: List[int]) -> int:
        n = len(arr)
        dp = [[0] * n for _ in range(n)]
        return self.dfs(arr, 0, n - 1, dp)

    def dfs(self, arr: List[int], s: int, e: int, dp: List[List[int]]) -> int:
        if s == e:
            return 0
        if dp[s][e] != 0:
            return dp[s][e]

        ans = float('inf')
        for i in range(s, e):
            left = self.dfs(arr, s, i, dp)
            right = self.dfs(arr, i + 1, e, dp)
            max_left = self.max_in_range(arr, s, i)
            max_right = self.max_in_range(arr, i + 1, e)
            root_val = max_left * max_right
            ans = min(ans, left + right + root_val)

        dp[s][e] = ans
        return ans

    def max_in_range(self, arr: List[int], l: int, r: int) -> int:
        return max(arr[l:r+1])

Complexity

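  • ⏰ Time complexity: O(n^3) for the recurrence itself, where n is the length of the array: the dp array has O(n^2) entries and each entry tries up to O(n) split points. The code above also rescans the subarray to find each maximum, which costs another O(n) per split, so it runs in O(n^4) as written; precomputing the range maxima removes that extra factor.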
  • 🧺 Space complexity: O(n^2) for the dp array used for memoization.
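
A minimal bottom-up sketch of the same recurrence is shown here, assuming we precompute a table of range maxima so that each split costs O(1). The names `mct_bottom_up` and `range_max` are illustrative choices, not part of the solutions above.

```python
from typing import List


def mct_bottom_up(arr: List[int]) -> int:
    n = len(arr)
    # range_max[i][j] = max(arr[i..j]), filled in O(n^2)
    range_max = [[0] * n for _ in range(n)]
    for i in range(n):
        range_max[i][i] = arr[i]
        for j in range(i + 1, n):
            range_max[i][j] = max(range_max[i][j - 1], arr[j])

    # dp[i][j] = minimum non-leaf sum for leaves arr[i..j]; dp[i][i] stays 0
    dp = [[0] * n for _ in range(n)]
    for length in range(2, n + 1):          # solve shorter ranges first
        for i in range(n - length + 1):
            j = i + length - 1
            dp[i][j] = min(
                dp[i][k] + dp[k + 1][j] + range_max[i][k] * range_max[k + 1][j]
                for k in range(i, j)
            )
    return dp[0][n - 1]


print(mct_bottom_up([6, 2, 4]))  # 32
```

Iterating by increasing range length guarantees that both dp[i][k] and dp[k + 1][j] are already filled when dp[i][j] is computed.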

Method 3 - Greedy

Here is the approach:

  1. The previous methods try every way to divide the array. Can we find the optimal division directly?
  2. We do have some insights. In the final binary tree, each non-leaf node value is the product of the maximum leaf values in its left and right subtrees.
  3. If a large value sits at the deepest level, it becomes the subtree maximum for many ancestors and is multiplied many times, which inflates the sum.
  4. Thus, we should pair up small values at the deepest level. Among all adjacent pairs arr[i], arr[i + 1], pick the one with the smallest product and make the two leaves siblings.
  5. After processing the pair, the smaller value can never again be a subtree maximum, so it is discarded. The larger value stands in for the merged subtree.
  6. The O(n^2) solution scans the array for the index i where arr[i] * arr[i + 1] is smallest, adds that product to the result, and discards the smaller of the pair. Repeat until the array size is 1.
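
The steps above can be traced on a small input. The following sketch records each merge the greedy performs; the function name `greedy_trace` and the (left, right, product) step format are assumptions made for illustration.

```python
from typing import List, Tuple


def greedy_trace(arr: List[int]) -> Tuple[int, List[Tuple[int, int, int]]]:
    """Return (total cost, steps), where each step is (left, right, product)."""
    values = list(arr)  # work on a copy so the caller's list is untouched
    steps = []
    total = 0
    while len(values) > 1:
        # index of the adjacent pair with the smallest product (first one on ties)
        i = min(range(len(values) - 1), key=lambda k: values[k] * values[k + 1])
        left, right = values[i], values[i + 1]
        steps.append((left, right, left * right))
        total += left * right
        # replace the pair with its larger value; the smaller one is discarded
        values[i:i + 2] = [max(left, right)]
    return total, steps


print(greedy_trace([6, 2, 4]))  # (32, [(2, 4, 8), (6, 4, 24)])
```

For arr = [6, 2, 4] the pair (2, 4) is merged first for a cost of 8, leaving [6, 4], which is merged for 24.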

Code

Java
public class Solution {

    public int mctFromLeafValues(int[] arr) {
        int res = 0;
        while (arr.length > 1) {
            int minProduct = Integer.MAX_VALUE;
            int minIndex = -1;

            for (int i = 0; i < arr.length - 1; i++) {
                int product = arr[i] * arr[i + 1];
                if (product < minProduct) {
                    minProduct = product;
                    minIndex = i;
                }
            }

            res += minProduct;

            // Create a new array excluding the smaller element of the pair (arr[minIndex], arr[minIndex + 1])
            int[] newArr = new int[arr.length - 1];
            for (int i = 0, j = 0; i < arr.length; i++) {
                if ((i != minIndex) && (i != minIndex + 1)) {
                    newArr[j++] = arr[i];
                } else if (i == minIndex) {
                    newArr[j++] = Math.max(arr[minIndex], arr[minIndex + 1]);
                }
            }

            arr = newArr;
        }
        return res;
    }
}
Python
from typing import List


class Solution:
    def mctFromLeafValues(self, arr: List[int]) -> int:
        res = 0
        while len(arr) > 1:
            min_product = float('inf')
            min_index = -1

            for i in range(len(arr) - 1):
                product = arr[i] * arr[i + 1]
                if product < min_product:
                    min_product = product
                    min_index = i

            res += min_product

            # Create a new array excluding the smaller element of the pair (arr[min_index], arr[min_index + 1])
            new_arr = []
            for i in range(len(arr)):
                if i != min_index and i != min_index + 1:
                    new_arr.append(arr[i])
                elif i == min_index:
                    new_arr.append(max(arr[min_index], arr[min_index + 1]))

            arr = new_arr
        
        return res

Complexity

  • ⏰ Time complexity: O(n^2), where n is the length of the array. This is because in the worst case scenario, we need to go through the array n times, and in each pass, we search for the smallest product pair in O(n) time.
  • 🧺 Space complexity: O(n), since we use a new array to hold the reduced elements each time we remove a pair.

Method 4 - Monotonic Decreasing Stack

When building a node in the tree, we compare two numbers, a and b. The smaller one is removed and won’t be used again, while the larger one remains.

The problem can be translated as: Given an array A, choose two neighboring elements a and b. Remove the smaller one (min(a, b)) and the cost is a * b. What is the minimum cost to reduce the array to one element?

Removing a number a costs a * b, where b >= a, so a must be removed by a neighbor that is at least as large. To minimize this cost, b should be as small as possible.

b has two possibilities: the first larger number on the left and the first larger number on the right.

The cost to remove a is a * min(left, right).
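
A direct, deliberately unoptimized sketch of this observation follows. For each element it scans left and right for the nearest larger value and charges a * min(left, right). The tie-breaking rule (strictly larger on the left, larger or equal on the right) is an assumption chosen here so that equal values are not charged twice; the function name is hypothetical.

```python
from typing import List


def removal_cost_sum(arr: List[int]) -> int:
    total = 0
    n = len(arr)
    for i, a in enumerate(arr):
        # nearest strictly larger value to the left, if any
        left = next((arr[j] for j in range(i - 1, -1, -1) if arr[j] > a), None)
        # nearest value to the right that is >= a, if any
        right = next((arr[j] for j in range(i + 1, n) if arr[j] >= a), None)
        candidates = [x for x in (left, right) if x is not None]
        if candidates:  # the one surviving maximum has no candidate and costs nothing
            total += a * min(candidates)
    return total


print(removal_cost_sum([6, 2, 4]))  # 32
```

This runs in O(n^2) because of the two scans per element; a monotonic stack finds the same neighbors in a single pass.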

Algorithm

So, here is the algorithm:

  1. The greedy solution (Method 3) removes the smaller value of the adjacent pair arr[i], arr[i + 1] with the smallest product in each iteration.
  2. Each iteration actually removes a local minimum value:
    • Consider elements arr[i - 1], arr[i], and arr[i + 1], where arr[i] is the local minimum.
    • The product added to the final result is arr[i] * min(arr[i - 1], arr[i + 1]).
  3. The problem therefore reduces to removing every local minimum, using the first larger element on its left and on its right.

This is similar to Trapping Rain Water or Largest Rectangle in Histogram.

Use a stack that is kept in decreasing order. When a value arr[i] arrives that is at least as large as the top of the stack, pop the top as mid, compute the product mid * min(arr[i], stack.peek()), and add it to the answer. Repeat until the top is larger than arr[i], then push arr[i].

If we look at the number array closely, in order to obtain the smallest sum of all non-leaf values, we should merge the smallest values first, i.e. smaller values should be lower leaves to minimize multiplication as the tree grows.

Ex: arr = [4, 3, 2, 1, 5]

There are various ways to construct a tree following the problem’s requirements. To achieve the smallest sum, we should merge 2 and 1 first since they are the smallest values. We use a stack in decreasing order for this. After merging 2 and 1, the resulting parent node value is 2. This parent node should ideally have the smallest possible parent as well. We observe that the smallest candidate next to 2 is 3, making 3 the left child and 1 * 2 = 2 the right child.

      ...
      / \
     3   2
        / \
       2   1

If we observe closely, 3, 2, 1 forms a decreasing sequence. So, whenever we encounter a “dip” in the array, we calculate the multiplication. For instance, at arr[i] with the condition arr[i-1] >= arr[i] <= arr[i+1], the minimum multiplication is arr[i] * min(arr[i-1], arr[i+1]). In the example above, this translates to arr[i] = 1, arr[i-1] = 2, arr[i+1] = 5, so the first merge costs 1 * min(2, 5) = 2.
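
To see the order of merges on this example, here is a small tracing sketch of the stack idea. It checks for an empty stack instead of using a sentinel value, and the function name `stack_trace` and the (removed, partner, cost) tuple format are illustrative assumptions.

```python
from typing import List, Tuple


def stack_trace(arr: List[int]) -> Tuple[int, List[Tuple[int, int, int]]]:
    """Return (total, merges); each merge is (removed value, partner value, cost)."""
    stack: List[int] = []
    merges = []
    total = 0
    for num in arr:
        while stack and stack[-1] <= num:
            mid = stack.pop()
            # partner is the smaller of the two larger neighbours (left may be missing)
            partner = min(stack[-1], num) if stack else num
            merges.append((mid, partner, mid * partner))
            total += mid * partner
        stack.append(num)
    # what remains is strictly decreasing: each value merges with the one below it
    while len(stack) > 1:
        mid = stack.pop()
        merges.append((mid, stack[-1], mid * stack[-1]))
        total += mid * stack[-1]
    return total, merges


print(stack_trace([4, 3, 2, 1, 5]))
# (40, [(1, 2, 2), (2, 3, 6), (3, 4, 12), (4, 5, 20)])
```

Nothing is merged until 5 arrives; it then pops 1, 2, 3, and 4 in turn, each paired with the smaller of its two larger neighbors.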

Code

Java
import java.util.Stack;

public class Solution {
    public int mctFromLeafValues(int[] arr) {
        int ans = 0;
        Stack<Integer> stack = new Stack<>();
        stack.push(Integer.MAX_VALUE);

        for (int num : arr) {
            while (stack.peek() <= num) {
                int mid = stack.pop();
                ans += mid * Math.min(stack.peek(), num);
            }
            stack.push(num);
        }

        while (stack.size() > 2) {
            ans += stack.pop() * stack.peek();
        }

        return ans;
    }
}
Python
from typing import List


class Solution:
    def mctFromLeafValues(self, arr: List[int]) -> int:
        ans = 0
        stack = [float('inf')]

        for num in arr:
            while stack[-1] <= num:
                mid = stack.pop()
                ans += mid * min(stack[-1], num)
            stack.append(num)

        while len(stack) > 2:
            ans += stack.pop() * stack[-1]

        return ans

Complexity

  • ⏰ Time complexity: O(n), where n is the length of the array. Each element is pushed onto and popped from the stack at most once.
  • 🧺 Space complexity: O(n), for the stack used to store elements.