Problem

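Given a chain of matrices, find the most efficient way to multiply them together. The task is not to perform the multiplications, but only to decide the order in which to perform them. The chain is described by an array arr of length n: the i-th matrix (1-based, named A, B, C, … from left to right) has dimensions arr[i-1] × arr[i], so arr describes n − 1 matrices. Multiplying a p × q matrix by a q × r matrix costs p·q·r scalar multiplications. Print the optimal parenthesization (bracketing) of the chain, i.e. the one that minimizes the total number of scalar multiplications.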

Examples

Example 1

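Input: arr = [40, 20, 30, 10, 30]
Output: ((A(BC))D)
Explanation: The minimum number of scalar multiplications is 26000: BC costs 20·30·10 = 6000, A(BC) costs 40·20·10 = 8000, and multiplying the result by D costs 40·10·30 = 12000. The optimal order is ((A(BC))D).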

Example 2

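Input: arr = [10, 20, 30, 40, 30]
Output: (((AB)C)D)
Explanation: The minimum number of scalar multiplications is 30000: AB costs 10·20·30 = 6000, (AB)C costs 10·30·40 = 12000, and multiplying the result by D costs 10·40·30 = 12000. The optimal order is (((AB)C)D).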

Example 3

Input: arr = [10, 20, 30]
Output: (AB)
Explanation: With only two matrices there is exactly one way to multiply them, (AB), costing 10·20·30 = 6000 scalar multiplications.

Similar Problems

Matrix Chain Multiplication

Solution

Method 1 - DP

Reasoning / Intuition

Matrix multiplication is associative, so the order of multiplication can be changed by parenthesizing the product differently. The goal is to minimize the total number of scalar multiplications. We use dynamic programming to compute the minimum cost and store the split points to reconstruct the optimal parenthesization.

Approach

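  • Let dp[i][j] be the minimum cost of multiplying matrices i through j. Compute it for increasing chain lengths by trying every split point k with i ≤ k < j, where the cost of a split is dp[i][k] + dp[k+1][j] + arr[i-1]·arr[k]·arr[j].
  • Store the best split point k for each subproblem in a separate table bracket[i][j].
  • Reconstruct the parenthesization recursively from the split points, naming matrices A, B, C, … from left to right.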

Code

C++
#include <iostream>
#include <vector>
#include <climits>
using namespace std;

class Solution {
public:
  /**
   * Recursively builds the parenthesization of matrices i..j from the split table.
   *
   * @param i       first matrix of the sub-chain (1-based)
   * @param j       last matrix of the sub-chain (1-based)
   * @param bracket bracket[i][j] holds the split point k chosen for sub-chain i..j
   * @param name    next letter to assign to a leaf; advanced once per leaf so matrices
   *                are labelled A, B, C, ... in left-to-right order
   * @return the parenthesized expression for matrices i..j
   */
  string printParenthesis(int i, int j, vector<vector<int>>& bracket, char& name) {
    if (i == j) {
      return string(1, name++);
    }
    // Split as (i..k)(k+1..j) and wrap the pair in one set of parentheses.
    const int k = bracket[i][j];
    string res = "(";
    res += printParenthesis(i, k, bracket, name);
    res += printParenthesis(k + 1, j, bracket, name);
    res += ")";
    return res;
  }

  /**
   * Returns the optimal (minimum scalar multiplications) parenthesization of a chain.
   *
   * Matrix i (1-based) has dimensions arr[i-1] x arr[i], so arr describes
   * arr.size() - 1 matrices, named 'A', 'B', 'C', ... from left to right
   * (chains longer than 26 matrices continue past 'Z' in ASCII order).
   *
   * @param arr chain dimensions
   * @return the parenthesized chain, a single letter for one matrix, or "" when
   *         arr has fewer than two entries (no matrices)
   */
  string matrixChainOrder(vector<int>& arr) {
    const int n = static_cast<int>(arr.size());
    if (n < 2) {
      // No matrices: nothing to parenthesize (and bracket[1] would be out of range).
      return "";
    }
    // Costs are kept in long long: one product arr[i-1]*arr[k]*arr[j] already exceeds
    // INT_MAX for dimensions around 1300, and signed int overflow is undefined behaviour.
    vector<vector<long long>> dp(n, vector<long long>(n, 0));
    vector<vector<int>> bracket(n, vector<int>(n, 0));
    for (int len = 2; len < n; ++len) {
      for (int i = 1; i + len - 1 < n; ++i) {
        const int j = i + len - 1;
        dp[i][j] = LLONG_MAX;
        for (int k = i; k < j; ++k) {
          const long long q = dp[i][k] + dp[k + 1][j]
                              + static_cast<long long>(arr[i - 1]) * arr[k] * arr[j];
          if (q < dp[i][j]) {  // strict '<': ties keep the smallest split point
            dp[i][j] = q;
            bracket[i][j] = k;
          }
        }
      }
    }
    char name = 'A';
    return printParenthesis(1, n - 1, bracket, name);
  }
};

// Example usage:
// vector<int> arr = {40, 20, 30, 10, 30};
// Solution sol;
// cout << sol.matrixChainOrder(arr) << endl;
Java
import java.util.*;

class Solution {
  /**
   * Recursively builds the parenthesization of matrices {@code i..j} from the split table.
   *
   * @param i first matrix of the sub-chain (1-based)
   * @param j last matrix of the sub-chain (1-based)
   * @param bracket {@code bracket[i][j]} is the split point k chosen for sub-chain i..j
   * @param name one-element holder for the next leaf letter; advanced once per leaf so that
   *     matrices are labelled A, B, C, ... in left-to-right order
   * @return the parenthesized expression for matrices {@code i..j}
   */
  private String printParenthesis(int i, int j, int[][] bracket, char[] name) {
    if (i == j) {
      return String.valueOf(name[0]++);
    }
    // Split as (i..k)(k+1..j) and wrap the pair in one set of parentheses.
    int k = bracket[i][j];
    StringBuilder sb = new StringBuilder();
    sb.append('(');
    sb.append(printParenthesis(i, k, bracket, name));
    sb.append(printParenthesis(k + 1, j, bracket, name));
    sb.append(')');
    return sb.toString();
  }

  /**
   * Returns the optimal (minimum scalar multiplications) parenthesization of a matrix chain.
   *
   * <p>Matrix {@code i} (1-based) has dimensions {@code arr[i-1] x arr[i]}, so {@code arr}
   * describes {@code arr.length - 1} matrices, named 'A', 'B', 'C', ... from left to right
   * (chains longer than 26 matrices continue past 'Z' in char order).
   *
   * @param arr chain dimensions
   * @return the parenthesized chain, a single letter for one matrix, or {@code ""} when
   *     {@code arr} has fewer than two entries (no matrices)
   */
  public String matrixChainOrder(int[] arr) {
    int n = arr.length;
    if (n < 2) {
      // No matrices: nothing to parenthesize (and bracket[1] would be out of range).
      return "";
    }
    // Costs are kept in long: one product arr[i-1]*arr[k]*arr[j] already exceeds
    // Integer.MAX_VALUE for dimensions around 1300, and int overflow wraps silently into a
    // negative "cheapest" cost that would corrupt the chosen split points.
    long[][] dp = new long[n][n];
    int[][] bracket = new int[n][n];
    for (int len = 2; len < n; len++) {
      for (int i = 1; i + len - 1 < n; i++) {
        int j = i + len - 1;
        dp[i][j] = Long.MAX_VALUE;
        for (int k = i; k < j; k++) {
          long q = dp[i][k] + dp[k + 1][j] + (long) arr[i - 1] * arr[k] * arr[j];
          if (q < dp[i][j]) { // strict '<': ties keep the smallest split point
            dp[i][j] = q;
            bracket[i][j] = k;
          }
        }
      }
    }
    char[] name = {'A'};
    return printParenthesis(1, n - 1, bracket, name);
  }
}

// Example usage:
// int[] arr = {40, 20, 30, 10, 30};
// Solution sol = new Solution();
// System.out.println(sol.matrixChainOrder(arr));
Python
class Solution:
  def matrixChainOrder(self, arr):
    """Return the optimal (minimum scalar multiplications) parenthesization.

    Matrix ``i`` (1-based) has dimensions ``arr[i-1] x arr[i]``, so ``arr``
    describes ``len(arr) - 1`` matrices, named 'A', 'B', 'C', ... from left
    to right (chains longer than 26 continue past 'Z' in code-point order).

    Args:
      arr: sequence of chain dimensions.

    Returns:
      The parenthesized chain as a string, a single letter for a single
      matrix, or "" when ``arr`` has fewer than two entries (no matrices).
    """
    n = len(arr)
    if n < 2:
      # No matrices: nothing to parenthesize (and bracket[1] would not exist).
      return ""

    # dp[i][j]: minimum scalar multiplications for matrices i..j (1-based).
    # bracket[i][j]: split point k achieving it, i.e. (i..k)(k+1..j).
    dp = [[0] * n for _ in range(n)]
    bracket = [[0] * n for _ in range(n)]
    for length in range(2, n):
      for i in range(1, n - length + 1):
        j = i + length - 1
        dp[i][j] = float('inf')
        for k in range(i, j):
          q = dp[i][k] + dp[k + 1][j] + arr[i - 1] * arr[k] * arr[j]
          if q < dp[i][j]:  # strict '<': ties keep the smallest split point
            dp[i][j] = q
            bracket[i][j] = k

    parts = []  # output tokens, collected left to right and joined once
    next_letter = [ord('A')]  # mutable cell so the nested helper can advance it

    def emit(i, j):
      # Append the parenthesization of matrices i..j to `parts`.
      if i == j:
        parts.append(chr(next_letter[0]))
        next_letter[0] += 1
        return
      k = bracket[i][j]
      parts.append("(")
      emit(i, k)
      emit(k + 1, j)
      parts.append(")")

    emit(1, n - 1)
    return "".join(parts)

# Example usage:
# arr = [40, 20, 30, 10, 30]
# sol = Solution()
# print(sol.matrixChainOrder(arr))

Complexity

  • ⏰ Time complexity: O(n^3), where n is the number of matrices.
  • 🧺 Space complexity: O(n^2) for the DP and bracket tables.