Problem

You are given an m x n 2D array board representing a chessboard, where board[i][j] represents the value of the cell (i, j).

Rooks in the same row or column attack each other. You need to place three rooks on the chessboard such that the rooks do not attack each other.

Return the maximum sum of the cell values on which the rooks are placed.
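
A minimal Python sketch can make the rules precise. The helper name `placement_sum` is hypothetical and not part of any solution below; it checks that the chosen cells share no row and no column, then returns their total.

```python
from typing import List, Tuple

def placement_sum(board: List[List[int]], cells: List[Tuple[int, int]]) -> int:
    """Return the total value of the cells, or raise ValueError if two rooks attack each other."""
    rows = {r for r, _ in cells}
    cols = {c for _, c in cells}
    # Non-attacking rooks never share a row or a column.
    if len(rows) != len(cells) or len(cols) != len(cells):
        raise ValueError("two rooks share a row or a column")
    return sum(board[r][c] for r, c in cells)

board = [[-3, 1, 1, 1], [-3, 1, -3, 1], [-3, 2, 1, 1]]  # Example 1 below
print(placement_sum(board, [(0, 2), (1, 3), (2, 1)]))  # 4
```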

Note

This problem is the same as Maximum Value Sum by Placing Three Rooks I, except that the board can be much larger (up to 500 x 500), so the solution must be efficient.

Examples

Example 1

$$ \Huge \begin{array}{|c|c|c|c|} \hline -3 & 1 & \fbox{1} ♖ & 1 \\ \hline -3 & 1 & -3 & \fbox{1} ♖ \\ \hline -3 & \fbox{2} ♖ & 1 & 1 \\ \hline \end{array} $$

Input: board = [[-3,1,1,1],[-3,1,-3,1],[-3,2,1,1]]
Output: 4
Explanation:
We can place the rooks in the cells `(0, 2)`, `(1, 3)`, and `(2, 1)` for a sum of `1 + 1 + 2 = 4`.

Example 2

$$ \Huge \begin{array}{|c|c|c|} \hline \fbox{1} & 2 & 3 \\ \hline 4 & \fbox{5} & 6 \\ \hline 7 & 8 & \fbox{9} \\ \hline \end{array} $$

Input: board = [[1,2,3],[4,5,6],[7,8,9]]
Output: 15
Explanation:
We can place the rooks in the cells `(0, 0)`, `(1, 1)`, and `(2, 2)` for a sum of `1 + 5 + 9 = 15`.

Example 3

$$ \Huge \begin{array}{|c|c|c|} \hline 1 & 1 & \fbox{1} \\ \hline 1 & \fbox{1} & 1 \\ \hline \fbox{1} & 1 & 1 \\ \hline \end{array} $$

Input: board = [[1,1,1],[1,1,1],[1,1,1]]
Output: 3
Explanation:
We can place the rooks in the cells `(0, 2)`, `(1, 1)`, and `(2, 0)` for a sum of `1 + 1 + 1 = 3`.

Constraints

  • 3 <= m == board.length <= 500
  • 3 <= n == board[i].length <= 500
  • -10^9 <= board[i][j] <= 10^9
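
The value range matters in languages with fixed-width integers: three cells can sum to anywhere in [-3 × 10^9, 3 × 10^9]. A quick Python check (the constant names are illustrative) confirms that this range falls outside 32-bit signed integers but fits easily in 64 bits:

```python
INT32_MAX = 2**31 - 1   # 2147483647
INT32_MIN = -2**31
INT64_MAX = 2**63 - 1

max_sum = 3 * 10**9     # three cells of 10^9
min_sum = -3 * 10**9    # three cells of -10^9

print(max_sum > INT32_MAX, min_sum < INT32_MIN)  # True True
print(max_sum <= INT64_MAX)                      # True
```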

Solution

Method 1 – Brute Force Enumeration

Intuition

The most direct approach is exhaustive search. Three non-attacking rooks occupy three distinct rows and three distinct columns, so we enumerate every triple of rows, every triple of columns, and each of the 3! = 6 ways to match the rows to the columns, keeping the best sum. The method is simple but only practical for small boards.

Approach

  1. For every combination of three distinct rows and three distinct columns:
    • For each of the 6 permutations of the columns, sum the values at (row[i], col[perm[i]]).
    • Track the maximum sum found.
  2. Return the maximum sum.
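
For one fixed triple of rows and triple of columns, the inner step of the approach can be sketched in Python as below. `best_assignment` is a hypothetical helper name; it simply tries all 6 row-to-column assignments with `itertools.permutations`.

```python
from itertools import permutations
from typing import List, Sequence

def best_assignment(board: List[List[int]], rows: Sequence[int], cols: Sequence[int]) -> int:
    """Best sum over the 3! = 6 ways to give each chosen row a distinct chosen column."""
    return max(
        sum(board[r][c] for r, c in zip(rows, perm))
        for perm in permutations(cols)
    )

board = [[-3, 1, 1, 1], [-3, 1, -3, 1], [-3, 2, 1, 1]]  # Example 1
print(best_assignment(board, (0, 1, 2), (1, 2, 3)))  # 4
```

The outer loops of Method 1 call this step for every row triple and column triple.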

Complexity

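  • ⏰ Time complexity: O(C(m,3) · C(n,3) · 6) = O(m^3 · n^3) – every triple of rows, every triple of columns, and all 6 permutations for each. This is far too slow for the full constraints (m = n = 500), so the method is mainly useful for small boards or as a reference to test faster methods against.
  • 🧺 Space complexity: O(1) – only a few variables for tracking the maximum sum.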
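
To see why, a short Python calculation (the helper name `brute_force_cases` is illustrative) counts the cases Method 1 visits at the largest allowed board size:

```python
from math import comb

def brute_force_cases(m: int, n: int) -> int:
    """Number of (row triple, column triple, assignment) cases Method 1 evaluates."""
    return comb(m, 3) * comb(n, 3) * 6

print(comb(500, 3))                 # 20708500
print(brute_force_cases(500, 500))  # 2573051833500000, roughly 2.6e15
```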

Code

```cpp
#include <vector>
#include <algorithm>
#include <climits>
using namespace std;

class Solution {
public:
    long long maximumValueSum(vector<vector<int>>& board) {
        int m = board.size(), n = board[0].size();
        long long ans = LLONG_MIN;
        for (int r1 = 0; r1 < m; ++r1)
        for (int r2 = r1+1; r2 < m; ++r2)
        for (int r3 = r2+1; r3 < m; ++r3)
        for (int c1 = 0; c1 < n; ++c1)
        for (int c2 = c1+1; c2 < n; ++c2)
        for (int c3 = c2+1; c3 < n; ++c3) {
            int r[3] = {r1, r2, r3};
            vector<int> c = {c1, c2, c3};  // already sorted, as next_permutation requires
            do {
                // long long: a sum of three values up to 1e9 overflows int
                long long sum = 0;
                for (int i = 0; i < 3; ++i) sum += board[r[i]][c[i]];
                ans = max(ans, sum);
            } while (next_permutation(c.begin(), c.end()));
        }
        return ans;
    }
};
```

```go
import "math"

func maximumValueSum(board [][]int) int64 {
    m, n := len(board), len(board[0])
    // start below any reachable sum: three cells can total -3e9, which is less than -(1 << 31)
    ans := int64(math.MinInt64)
    for r1 := 0; r1 < m; r1++ {
        for r2 := r1+1; r2 < m; r2++ {
            for r3 := r2+1; r3 < m; r3++ {
                for c1 := 0; c1 < n; c1++ {
                    for c2 := c1+1; c2 < n; c2++ {
                        for c3 := c2+1; c3 < n; c3++ {
                            r := [3]int{r1, r2, r3}
                            c := [3]int{c1, c2, c3}
                            // all 6 ways to assign the three columns to the three rows
                            perms := [6][3]int{{c[0],c[1],c[2]},{c[0],c[2],c[1]},{c[1],c[0],c[2]},{c[1],c[2],c[0]},{c[2],c[0],c[1]},{c[2],c[1],c[0]}}
                            for _, p := range perms {
                                var sum int64
                                for i := 0; i < 3; i++ {
                                    sum += int64(board[r[i]][p[i]])
                                }
                                if sum > ans {
                                    ans = sum
                                }
                            }
                        }
                    }
                }
            }
        }
    }
    return ans
}
```

```java
class Solution {
    public long maximumValueSum(int[][] board) {
        int m = board.length, n = board[0].length;
        long ans = Long.MIN_VALUE;
        for (int r1 = 0; r1 < m; ++r1)
        for (int r2 = r1+1; r2 < m; ++r2)
        for (int r3 = r2+1; r3 < m; ++r3)
        for (int c1 = 0; c1 < n; ++c1)
        for (int c2 = c1+1; c2 < n; ++c2)
        for (int c3 = c2+1; c3 < n; ++c3) {
            int[] r = {r1, r2, r3}, c = {c1, c2, c3};
            // all 6 ways to assign the three columns to the three rows
            int[][] perm = {{c[0],c[1],c[2]},{c[0],c[2],c[1]},{c[1],c[0],c[2]},{c[1],c[2],c[0]},{c[2],c[0],c[1]},{c[2],c[1],c[0]}};
            for (int[] p : perm) {
                // long: a sum of three values up to 1e9 overflows int
                long sum = 0;
                for (int i = 0; i < 3; ++i) sum += board[r[i]][p[i]];
                ans = Math.max(ans, sum);
            }
        }
        return ans;
    }
}
```

```kotlin
class Solution {
    fun maximumValueSum(board: Array<IntArray>): Long {
        val m = board.size
        val n = board[0].size
        var ans = Long.MIN_VALUE
        for (r1 in 0 until m)
        for (r2 in r1+1 until m)
        for (r3 in r2+1 until m)
        for (c1 in 0 until n)
        for (c2 in c1+1 until n)
        for (c3 in c2+1 until n) {
            val r = listOf(r1, r2, r3)
            val c = listOf(c1, c2, c3)
            // all 6 ways to assign the three columns to the three rows
            val perm = listOf(
                listOf(c[0],c[1],c[2]), listOf(c[0],c[2],c[1]), listOf(c[1],c[0],c[2]),
                listOf(c[1],c[2],c[0]), listOf(c[2],c[0],c[1]), listOf(c[2],c[1],c[0])
            )
            for (p in perm) {
                var sum = 0L  // Long: a sum of three values up to 1e9 overflows Int
                for (i in 0..2) sum += board[r[i]][p[i]]
                ans = maxOf(ans, sum)
            }
        }
        return ans
    }
}
```

```python
from itertools import combinations, permutations
from typing import List

class Solution:
    def maximumValueSum(self, board: List[List[int]]) -> int:
        m, n = len(board), len(board[0])
        ans = float('-inf')
        for rows in combinations(range(m), 3):
            for cols in combinations(range(n), 3):
                # try all 6 ways of assigning the chosen columns to the chosen rows
                for perm in permutations(cols):
                    s = sum(board[rows[i]][perm[i]] for i in range(3))
                    ans = max(ans, s)
        return ans
```

```rust
impl Solution {
    pub fn maximum_value_sum(board: Vec<Vec<i32>>) -> i64 {
        let m = board.len();
        let n = board[0].len();
        let mut ans = i64::MIN;
        for r1 in 0..m {
            for r2 in r1+1..m {
                for r3 in r2+1..m {
                    for c1 in 0..n {
                        for c2 in c1+1..n {
                            for c3 in c2+1..n {
                                let r = [r1, r2, r3];
                                let c = [c1, c2, c3];
                                // all 6 ways to assign the three columns to the three rows
                                let perms = [
                                    [c[0],c[1],c[2]],[c[0],c[2],c[1]],[c[1],c[0],c[2]],
                                    [c[1],c[2],c[0]],[c[2],c[0],c[1]],[c[2],c[1],c[0]]
                                ];
                                for p in &perms {
                                    // widen to i64: a sum of three values up to 1e9 overflows i32
                                    let sum: i64 = (0..3).map(|i| board[r[i]][p[i]] as i64).sum();
                                    ans = ans.max(sum);
                                }
                            }
                        }
                    }
                }
            }
        }
        ans
    }
}
```

```typescript
function maximumValueSum(board: number[][]): number {
    const m = board.length, n = board[0].length;
    let ans = -Infinity;
    for (let r1 = 0; r1 < m; ++r1)
    for (let r2 = r1+1; r2 < m; ++r2)
    for (let r3 = r2+1; r3 < m; ++r3)
    for (let c1 = 0; c1 < n; ++c1)
    for (let c2 = c1+1; c2 < n; ++c2)
    for (let c3 = c2+1; c3 < n; ++c3) {
        const r = [r1, r2, r3], c = [c1, c2, c3];
        // all 6 ways to assign the three columns to the three rows
        const perms = [
            [c[0],c[1],c[2]],[c[0],c[2],c[1]],[c[1],c[0],c[2]],
            [c[1],c[2],c[0]],[c[2],c[0],c[1]],[c[2],c[1],c[0]]
        ];
        for (const p of perms) {
            let sum = 0;  // sums up to 3e9 are exact in a double
            for (let i = 0; i < 3; ++i) sum += board[r[i]][p[i]];
            ans = Math.max(ans, sum);
        }
    }
    return ans;
}
```

Method 2 – Candidate-Column Pruning (heuristic, practical for large m, n)

Intuition

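When m and n can be up to 500, enumerating all C(n,3) column triples (about 2 × 10^7 of them) for every choice of rows is infeasible. Only three rooks are placed, so most columns will not contribute to the optimum. We compute a cheap score for each column (here, the sum of its top 3 values) and keep only the top-M columns by that score; enumerating triples inside this reduced set is much faster. This is a heuristic, not an exact algorithm. It cannot miss the optimum when every column is kept (M ≥ n) and K ≥ 3, because a rook outside its column's top 3 rows can always be moved into one of those rows that the other two rooks do not use without lowering the sum. With M < n, however, it can discard a column the optimum needs, for example a column with one very large value and very negative values elsewhere, whose top-3 score is low. A larger M (e.g., 80–200) lowers that risk at the cost of speed.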
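
The scoring step can be sketched in Python as follows. `candidate_columns` is a hypothetical helper name, and the top-3 score is the one this method uses; `heapq.nlargest` picks both the top values per column and the top-M columns.

```python
import heapq
from typing import List

def candidate_columns(board: List[List[int]], M: int) -> List[int]:
    """Return up to M column indices with the largest top-3 sums, best first."""
    n = len(board[0])
    scores = []
    for j in range(n):
        column = [row[j] for row in board]
        # cheap score: sum of the three largest values in the column
        scores.append((sum(heapq.nlargest(3, column)), j))
    # whole (score, column) tuples are compared, so ties go to the larger column index
    return [j for _, j in heapq.nlargest(min(M, n), scores)]

board = [[5, 0, 9], [5, 0, -9], [5, 0, -9]]
print(candidate_columns(board, 2))  # [0, 1]
```

Column 2 holds the board's largest value, 9, yet is ranked last because of its two negative entries, which is exactly the kind of column the pruning can lose.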

Approach

  1. For each column c, compute score[c] = the sum of the top-3 values in column c (another cheap score, such as the sum of the top-K values, could be used instead).
  2. Keep the M columns with the largest score[c] (clip M to at most n).
  3. For each candidate column, precompute a short list of its top-K (value, row) pairs. Enumerate all C(M,3) triples of candidate columns, and for each triple try every combination of one entry from each column's list, skipping combinations that reuse a row. No separate loop over the 6 column permutations is needed, because this product already covers every assignment of rows to the three columns. Track the maximum sum.
  4. Return the maximum.

This reduces the search from C(n,3) column triples to C(M,3), with at most K^3 row combinations per triple.
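
The following Python sketch illustrates the risk of pruning. `brute_force` and `pruned` are hypothetical helper names: `pruned` is a compact re-implementation of the steps above, and the tiny values M = 3 and K = 3 are chosen only to make the effect visible on a 3 x 4 board.

```python
import heapq
from itertools import combinations, permutations, product
from typing import List

def brute_force(board: List[List[int]]) -> int:
    """Exact answer: every row triple, column triple, and assignment."""
    m, n = len(board), len(board[0])
    return max(
        sum(board[r][c] for r, c in zip(rows, perm))
        for rows in combinations(range(m), 3)
        for cols in combinations(range(n), 3)
        for perm in permutations(cols)
    )

def pruned(board: List[List[int]], M: int, K: int) -> int:
    """Top-M columns by top-3 score, top-K (value, row) pairs per column."""
    m, n = len(board), len(board[0])
    scores = [(sum(heapq.nlargest(3, [row[j] for row in board])), j) for j in range(n)]
    candidates = [j for _, j in heapq.nlargest(min(M, n), scores)]
    top = {j: heapq.nlargest(K, [(board[i][j], i) for i in range(m)]) for j in candidates}
    best = float("-inf")
    for a, b, c in combinations(candidates, 3):
        # the product already covers every row-to-column assignment
        for (v1, r1), (v2, r2), (v3, r3) in product(top[a], top[b], top[c]):
            if r1 != r2 and r1 != r3 and r2 != r3:
                best = max(best, v1 + v2 + v3)
    return best

# Column 3 holds the single best cell but has a very low top-3 score.
board = [
    [0, 0, 0, 10],
    [0, 0, 0, -100],
    [0, 0, 0, -100],
]
print(brute_force(board))       # 10
print(pruned(board, M=3, K=3))  # 0: column 3 was discarded
print(pruned(board, M=4, K=3))  # 10: every column kept
```

With M = 3 the low-scoring column 3 is dropped, so the pruned search returns 0 even though 10 is achievable; keeping every column restores the exact answer.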

Code

```cpp
#include <vector>
#include <algorithm>
#include <functional>
#include <unordered_map>
#include <climits>
#include <utility>
using namespace std;

class Solution {
public:
    long long maximumValueSum(vector<vector<int>>& board) {
        const int M_default = 120;  // number of candidate columns (tunable)
        const int K = 8;            // top rows kept per candidate column (tunable)
        int m = board.size(), n = board[0].size();
        int M = min(M_default, n);

        // score columns by the sum of their top-3 values
        vector<pair<long long,int>> col_scores; col_scores.reserve(n);
        for (int j = 0; j < n; ++j) {
            vector<int> col; col.reserve(m);
            for (int i = 0; i < m; ++i) col.push_back(board[i][j]);
            sort(col.begin(), col.end(), greater<int>());
            long long s = 0;
            for (int t = 0; t < (int)col.size() && t < 3; ++t) s += col[t];
            col_scores.emplace_back(s, j);
        }

        sort(col_scores.begin(), col_scores.end(), greater<>());
        vector<int> candidates;
        for (int i = 0; i < M && i < (int)col_scores.size(); ++i) candidates.push_back(col_scores[i].second);

        // precompute top-K (value, row) pairs per candidate column
        unordered_map<int, vector<pair<int,int>>> topk_rows;
        for (int j : candidates) {
            vector<pair<int,int>> vals; vals.reserve(m);
            for (int i = 0; i < m; ++i) vals.emplace_back(board[i][j], i);
            sort(vals.begin(), vals.end(), [](const auto &a, const auto &b){ return a.first > b.first; });
            if ((int)vals.size() > K) vals.resize(K);
            topk_rows[j] = move(vals);
        }

        long long ans = LLONG_MIN;
        int C = candidates.size();
        // enumerate triples of candidate columns; the product of their top-K
        // lists already covers every row-to-column assignment
        for (int a = 0; a < C; ++a)
        for (int b = a+1; b < C; ++b)
        for (int c = b+1; c < C; ++c) {
            const auto &L0 = topk_rows[candidates[a]];
            const auto &L1 = topk_rows[candidates[b]];
            const auto &L2 = topk_rows[candidates[c]];
            for (const auto &x : L0)
            for (const auto &y : L1)
            for (const auto &z : L2) {
                int r0 = x.second, r1 = y.second, r2 = z.second;
                if (r0 == r1 || r0 == r2 || r1 == r2) continue;  // two rooks in one row
                long long s = (long long)x.first + y.first + z.first;
                ans = max(ans, s);
            }
        }
        return ans;
    }
};
```

```java
import java.util.*;

class Solution {
    public long maximumValueSum(int[][] board) {
        final int M_default = 120;  // number of candidate columns (tunable)
        final int K = 8;            // top rows kept per candidate column (tunable)
        int m = board.length, n = board[0].length;
        int M = Math.min(M_default, n);

        // score columns by the sum of their top-3 values
        List<long[]> colScores = new ArrayList<>();
        for (int j = 0; j < n; ++j) {
            int[] col = new int[m];
            for (int i = 0; i < m; ++i) col[i] = board[i][j];
            Arrays.sort(col);
            long s = 0;
            for (int t = 0; t < Math.min(3, m); ++t) s += col[m-1-t];
            colScores.add(new long[]{s, j});
        }

        colScores.sort((a,b) -> Long.compare(b[0], a[0]));
        List<Integer> candidates = new ArrayList<>();
        for (int i = 0; i < Math.min(M, colScores.size()); ++i) candidates.add((int)colScores.get(i)[1]);

        // precompute top-K (value, row) pairs per candidate column
        Map<Integer, List<int[]>> topkRows = new HashMap<>();
        for (int j : candidates) {
            List<int[]> vals = new ArrayList<>();
            for (int i = 0; i < m; ++i) vals.add(new int[]{board[i][j], i});
            vals.sort((a,b) -> Integer.compare(b[0], a[0]));
            if (vals.size() > K) vals = vals.subList(0, K);
            topkRows.put(j, vals);
        }

        long ans = Long.MIN_VALUE;
        int C = candidates.size();
        // enumerate triples of candidate columns; the product of their top-K
        // lists already covers every row-to-column assignment
        for (int a = 0; a < C; ++a)
        for (int b = a+1; b < C; ++b)
        for (int c = b+1; c < C; ++c) {
            List<int[]> L0 = topkRows.get(candidates.get(a));
            List<int[]> L1 = topkRows.get(candidates.get(b));
            List<int[]> L2 = topkRows.get(candidates.get(c));
            for (int[] x : L0)
            for (int[] y : L1)
            for (int[] z : L2) {
                int r0 = x[1], r1 = y[1], r2 = z[1];
                if (r0 == r1 || r0 == r2 || r1 == r2) continue;  // two rooks in one row
                long s = (long)x[0] + y[0] + z[0];
                ans = Math.max(ans, s);
            }
        }
        return ans;
    }
}
```

```python
from heapq import nlargest
from itertools import combinations, product
from typing import List

class Solution:
    def maximumValueSum(self, board: List[List[int]]) -> int:
        """Heuristic solver using candidate-column pruning.

        M (candidate columns) and K (top rows per column) are tunable; the
        result is exact only if no column used by the optimum is pruned away.
        """
        M_default = 120
        K = 8

        m, n = len(board), len(board[0])
        M = min(M_default, n)

        # score columns by the sum of the top-3 values in that column
        col_scores = []
        for j in range(n):
            col_vals = [board[i][j] for i in range(m)]
            col_scores.append((sum(nlargest(3, col_vals)), j))

        # pick top-M columns by score
        candidates = [j for _, j in nlargest(M, col_scores)]

        # for each candidate column, precompute its top-K (value, row) pairs
        topk_rows = {}
        for j in candidates:
            topk_rows[j] = nlargest(K, [(board[i][j], i) for i in range(m)])

        ans = -10**30
        # enumerate column triples; the product over the three top-K lists
        # already covers every assignment of rows to these columns
        for c1, c2, c3 in combinations(candidates, 3):
            for (v1, r1), (v2, r2), (v3, r3) in product(topk_rows[c1], topk_rows[c2], topk_rows[c3]):
                if r1 == r2 or r1 == r3 or r2 == r3:
                    continue  # two rooks would share a row
                s = v1 + v2 + v3
                if s > ans:
                    ans = s
        return ans
```

Complexity

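  • ⏰ Time complexity: O(n · m log m + C(M,3) · K^3) – scoring the columns and building the top-K lists costs O(m log m) per column with a full sort (O(m log K) with a heap, as in the Python version); the enumeration then runs only over triples of the M candidate columns, with at most K^3 row combinations each. Choose M and K to balance speed against the risk of missing the optimum.
  • 🧺 Space complexity: O(m + n + M · K) – the column scores, one column buffer, and the top-K rows for each candidate column.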
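
For a sense of scale, a quick Python calculation with the default parameters M = 120 and K = 8 from the code above counts the row combinations the enumeration visits (an upper bound, since combinations that reuse a row are visited and then skipped):

```python
from math import comb

M, K = 120, 8                # defaults used in the Method 2 code
triples = comb(M, 3)
print(triples)               # 280840 candidate column triples
print(triples * K**3)        # 143790080 row combinations in total
print(triples * 6 * K**3)    # 862740480 if each of the 6 column orders were also enumerated
```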