Problem

Consider a matrix M with dimensions width * height, such that every cell has value 0 or 1, and any square sub-matrix of M of size sideLength * sideLength has at most maxOnes ones.

Return the maximum possible number of ones that the matrix M can have.
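To make the constraint concrete, here is a minimal brute-force checker in Python. The helper name `is_valid` is hypothetical and not part of the problem's interface. It scans every sideLength x sideLength window of a candidate matrix and compares the window's number of ones with maxOnes.

```python
from typing import List


def is_valid(matrix: List[List[int]], side_length: int, max_ones: int) -> bool:
    """Hypothetical helper: True if every side_length x side_length window
    of `matrix` contains at most max_ones ones."""
    rows, cols = len(matrix), len(matrix[0])
    for top in range(rows - side_length + 1):
        for left in range(cols - side_length + 1):
            # Count the ones in the window whose top-left corner is (top, left).
            ones = sum(
                matrix[r][c]
                for r in range(top, top + side_length)
                for c in range(left, left + side_length)
            )
            if ones > max_ones:
                return False
    return True


# The matrix from Example 1 below.
example_1 = [[1, 0, 1], [0, 0, 0], [1, 0, 1]]
print(is_valid(example_1, 2, 1), sum(map(sum, example_1)))  # True 4
```

This check runs in O(width * height * sideLength^2) time, which is fine for verifying small examples but is not meant as a solution.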

Examples

Example 1:

Input: width = 3, height = 3, sideLength = 2, maxOnes = 1
Output: 4
Explanation:
In a 3*3 matrix, no 2*2 sub-matrix may contain more than 1 one.
The best solution, with 4 ones, is:
[1,0,1]
[0,0,0]
[1,0,1]

Example 2:

Input: width = 3, height = 3, sideLength = 2, maxOnes = 2
Output: 6
Explanation:
One optimal matrix with 6 ones (every 2*2 sub-matrix contains exactly 2 ones) is:
[1,0,1]
[1,0,1]
[1,0,1]
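For tiny inputs like these, the stated outputs can be reproduced by exhaustive search. The sketch below is an illustration, not the intended solution: it tries all 2^(width * height) 0/1 matrices, so it is only usable for very small grids. The function name `brute_force_max_ones` and the choice to lay the matrix out as `height` rows of `width` cells are assumptions; the answer does not depend on the orientation.

```python
from itertools import product


def brute_force_max_ones(width: int, height: int, side_length: int, max_ones: int) -> int:
    """Hypothetical helper: try every 0/1 matrix and keep the best valid one.
    Only feasible for tiny grids (2 ** (width * height) candidates)."""
    best = 0
    for bits in product((0, 1), repeat=width * height):
        # Reshape the flat tuple into `height` rows of `width` cells (assumed layout).
        m = [bits[r * width:(r + 1) * width] for r in range(height)]
        # Valid if every side_length x side_length window has at most max_ones ones.
        ok = all(
            sum(
                m[r][c]
                for r in range(top, top + side_length)
                for c in range(left, left + side_length)
            ) <= max_ones
            for top in range(height - side_length + 1)
            for left in range(width - side_length + 1)
        )
        if ok:
            best = max(best, sum(bits))
    return best


print(brute_force_max_ones(3, 3, 2, 1))  # 4
print(brute_force_max_ones(3, 3, 2, 2))  # 6
```

A checker like this is mainly useful for cross-checking a faster method on random small inputs.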

Constraints:

  • 1 <= width, height <= 100
  • 1 <= sideLength <= min(width, height)
  • 0 <= maxOnes <= sideLength * sideLength

Solution

Method 1 – Greedy (Pattern Tiling)

Intuition

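The constraint is local: every sideLength x sideLength sub-matrix may contain at most maxOnes ones. Group the cells of M into sideLength^2 classes by their position modulo sideLength, i.e. cell (r, c) belongs to class (r mod sideLength, c mod sideLength). Any sideLength x sideLength sub-matrix contains exactly one cell from each class, so choosing maxOnes classes and setting every cell in them to 1 (equivalently, tiling one sideLength x sideLength pattern with maxOnes ones over the whole matrix) never violates the constraint. To maximize the total, put the ones on the maxOnes pattern cells that occur most often in the tiling.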
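The key fact behind the tiling, that every sideLength x sideLength window contains exactly one cell from each class (r mod sideLength, c mod sideLength), can be checked directly. It holds because any sideLength consecutive integers have pairwise distinct residues modulo sideLength. The sketch below (hypothetical helper name `window_covers_each_class_once`) only confirms that the tiled pattern is feasible; it does not by itself prove that tiling is optimal.

```python
def window_covers_each_class_once(width: int, height: int, side_length: int) -> bool:
    """Hypothetical helper: True if every side_length x side_length window of a
    height x width grid holds exactly one cell of each residue class."""
    for top in range(height - side_length + 1):
        for left in range(width - side_length + 1):
            classes = [
                (r % side_length, c % side_length)
                for r in range(top, top + side_length)
                for c in range(left, left + side_length)
            ]
            # side_length^2 cells with pairwise distinct classes means each
            # of the side_length^2 classes appears exactly once.
            if len(set(classes)) != side_length * side_length:
                return False
    return True


print(window_covers_each_class_once(7, 5, 3))  # True
```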

Approach

  1. For each cell (i, j) in the sideLength x sideLength pattern, count how many times it appears in the full matrix.
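  2. For pattern cell (i, j), the number of indices in [0, width) congruent to i modulo sideLength is ceil((width - i) / sideLength), and likewise for j along the height, so the count is ((width - i + sideLength - 1) // sideLength) * ((height - j + sideLength - 1) // sideLength).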
  3. Collect all counts, sort them in descending order, and sum the top maxOnes counts.
  4. Return the total sum as the answer.
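Applying these steps to Example 1 (width = height = 3, sideLength = 2): the counts are 4 for (0, 0), 2 for (0, 1), 2 for (1, 0), and 1 for (1, 1). The top 1 count gives 4 and the top 2 give 6, matching both examples. The sketch below goes one step further and builds the tiled matrix itself. The helper name `build_tiled_matrix` is hypothetical, and indexing the grid as `grid[x][y]` with x along `width` and y along `height` is an assumption chosen to match step 2's pairing of i with width and j with height.

```python
from typing import List


def build_tiled_matrix(width: int, height: int, side_length: int, max_ones: int) -> List[List[int]]:
    """Hypothetical helper: put ones on the max_ones most frequent pattern cells
    and repeat the side_length x side_length pattern over the whole grid."""
    counts = []
    for i in range(side_length):
        for j in range(side_length):
            # Number of cells (x, y) with x % side_length == i and y % side_length == j.
            cnt = ((width - i + side_length - 1) // side_length) * (
                (height - j + side_length - 1) // side_length
            )
            counts.append((cnt, i, j))
    counts.sort(reverse=True)  # most frequent pattern cells first
    chosen = {(i, j) for _, i, j in counts[:max_ones]}
    # grid[x][y] is 1 exactly when (x, y) falls on a chosen pattern cell.
    return [
        [1 if (x % side_length, y % side_length) in chosen else 0 for y in range(height)]
        for x in range(width)
    ]


grid = build_tiled_matrix(3, 3, 2, 1)
for row in grid:
    print(row)
print(sum(map(sum, grid)))  # 4
```

With maxOnes = 2 the same call yields the Example 2 matrix, where every row is [1, 0, 1].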

Code

C++
#include <algorithm>
#include <vector>
using namespace std;

class Solution {
public:
    int maximumNumberOfOnes(int width, int height, int sideLength, int maxOnes) {
        // For each pattern cell (i, j), count how many cells of M repeat it.
        vector<int> counts;
        for (int i = 0; i < sideLength; ++i) {
            for (int j = 0; j < sideLength; ++j) {
                int cnt = ((width - i + sideLength - 1) / sideLength) * ((height - j + sideLength - 1) / sideLength);
                counts.push_back(cnt);
            }
        }
        // Sort descending and sum the maxOnes largest counts (maxOnes <= sideLength^2).
        sort(counts.rbegin(), counts.rend());
        int ans = 0;
        for (int i = 0; i < maxOnes; ++i) ans += counts[i];
        return ans;
    }
};
Go
import "sort"
func maximumNumberOfOnes(width, height, sideLength, maxOnes int) int {
    counts := []int{}
    for i := 0; i < sideLength; i++ {
        for j := 0; j < sideLength; j++ {
            cnt := ((width - i + sideLength - 1) / sideLength) * ((height - j + sideLength - 1) / sideLength)
            counts = append(counts, cnt)
        }
    }
    sort.Sort(sort.Reverse(sort.IntSlice(counts)))
    ans := 0
    for i := 0; i < maxOnes; i++ {
        ans += counts[i]
    }
    return ans
}
Java
import java.util.*;
class Solution {
    public int maximumNumberOfOnes(int width, int height, int sideLength, int maxOnes) {
        List<Integer> counts = new ArrayList<>();
        for (int i = 0; i < sideLength; i++) {
            for (int j = 0; j < sideLength; j++) {
                int cnt = ((width - i + sideLength - 1) / sideLength) * ((height - j + sideLength - 1) / sideLength);
                counts.add(cnt);
            }
        }
        counts.sort(Collections.reverseOrder());
        int ans = 0;
        for (int i = 0; i < maxOnes; i++) ans += counts.get(i);
        return ans;
    }
}
Kotlin
class Solution {
    fun maximumNumberOfOnes(width: Int, height: Int, sideLength: Int, maxOnes: Int): Int {
        val counts = mutableListOf<Int>()
        for (i in 0 until sideLength) {
            for (j in 0 until sideLength) {
                val cnt = ((width - i + sideLength - 1) / sideLength) * ((height - j + sideLength - 1) / sideLength)
                counts.add(cnt)
            }
        }
        counts.sortDescending()
        return counts.take(maxOnes).sum()
    }
}
Python
class Solution:
    def maximumNumberOfOnes(self, width: int, height: int, sideLength: int, maxOnes: int) -> int:
        counts = []
        for i in range(sideLength):
            for j in range(sideLength):
                cnt = ((width - i + sideLength - 1) // sideLength) * ((height - j + sideLength - 1) // sideLength)
                counts.append(cnt)
        counts.sort(reverse=True)
        return sum(counts[:maxOnes])
Rust
impl Solution {
    pub fn maximum_number_of_ones(width: i32, height: i32, side_length: i32, max_ones: i32) -> i32 {
        let mut counts = vec![];
        for i in 0..side_length {
            for j in 0..side_length {
                let cnt = ((width - i + side_length - 1) / side_length) * ((height - j + side_length - 1) / side_length);
                counts.push(cnt);
            }
        }
        counts.sort_by(|a, b| b.cmp(a));
        counts.into_iter().take(max_ones as usize).sum()
    }
}
TypeScript
class Solution {
    maximumNumberOfOnes(width: number, height: number, sideLength: number, maxOnes: number): number {
        const counts: number[] = [];
        for (let i = 0; i < sideLength; i++) {
            for (let j = 0; j < sideLength; j++) {
                const cnt = Math.floor((width - i + sideLength - 1) / sideLength) * Math.floor((height - j + sideLength - 1) / sideLength);
                counts.push(cnt);
            }
        }
        counts.sort((a, b) => b - a);
        return counts.slice(0, maxOnes).reduce((a, b) => a + b, 0);
    }
}

Complexity

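  • ⏰ Time complexity: O(sideLength^2 log sideLength). Computing the sideLength^2 counts takes O(sideLength^2), and sorting them takes O(sideLength^2 log(sideLength^2)) = O(sideLength^2 log sideLength).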
  • 🧺 Space complexity: O(sideLength^2), for storing the counts.
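Since only the maxOnes largest counts are needed, a full sort is not strictly required. One possible variation (a sketch, not the method above) selects them with `heapq.nlargest`, which takes roughly O(sideLength^2 log maxOnes) time. The function name `maximum_number_of_ones_heap` is hypothetical.

```python
import heapq


def maximum_number_of_ones_heap(width: int, height: int, side_length: int, max_ones: int) -> int:
    """Same greedy as above, but keeps only the max_ones largest counts
    with a heap instead of sorting all side_length^2 of them."""
    counts = (
        ((width - i + side_length - 1) // side_length)
        * ((height - j + side_length - 1) // side_length)
        for i in range(side_length)
        for j in range(side_length)
    )
    # heapq.nlargest(n, iterable) returns the n largest items (an empty list for n == 0).
    return sum(heapq.nlargest(max_ones, counts))


print(maximum_number_of_ones_heap(3, 3, 2, 1))  # 4
print(maximum_number_of_ones_heap(3, 3, 2, 2))  # 6
```

For the given constraints (sideLength <= 100) the difference from sorting is negligible; the variation mainly makes explicit that only a top-k selection is needed.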