Problem

Given an m x n matrix of ones and zeros, return how many square submatrices have all ones.

Examples

Example 1:

Input: matrix =
[
  [0,1,1,1],
  [1,1,1,1],
  [0,1,1,1]
]
Output: 15
Explanation: 
There are **10** squares of side 1.
There are **4** squares of side 2.
There is **1** square of side 3.
Total number of squares = 10 + 4 + 1 = **15**.

Example 2:

Input: matrix = 
[
  [1,0,1],
  [1,1,0],
  [1,1,0]
]
Output: 7
Explanation: 
There are **6** squares of side 1.  
There is **1** square of side 2. 
Total number of squares = 6 + 1 = **7**.
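As a sanity check on the counts above, here is a minimal brute-force sketch. The function name `count_by_side` is hypothetical and not part of the original solution. It tries every top-left corner and every side length and tallies the all-ones squares by side. It runs in roughly O(m*n*min(m,n)^3) time, so it is only meant for checking small inputs like these examples.

```python
from collections import Counter
from typing import Dict, List


def count_by_side(matrix: List[List[int]]) -> Dict[int, int]:
    """Brute force: count all-ones square submatrices, grouped by side length."""
    m, n = len(matrix), len(matrix[0])
    counts: Counter = Counter()
    for top in range(m):
        for left in range(n):
            # Try every side length that still fits inside the matrix.
            for side in range(1, min(m - top, n - left) + 1):
                if all(
                    matrix[r][c] == 1
                    for r in range(top, top + side)
                    for c in range(left, left + side)
                ):
                    counts[side] += 1
    return dict(counts)


example1 = [[0, 1, 1, 1], [1, 1, 1, 1], [0, 1, 1, 1]]
example2 = [[1, 0, 1], [1, 1, 0], [1, 1, 0]]
print(count_by_side(example1))  # {1: 10, 2: 4, 3: 1}
print(count_by_side(example2))  # {1: 6, 2: 1}
```

Summing the values of each dictionary gives the expected totals of 15 and 7.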

Solution

Method 1 - Using Recursion

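In the recursive approach, we define a helper function that returns the side length of the largest all-ones square whose bottom-right corner is the cell (i, j). If that largest square has side k, then exactly k all-ones squares have their bottom-right corner at (i, j), one of each side 1 through k. Summing the helper's value over every cell therefore counts each square exactly once.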

Here is the approach:

  1. Base Case: If i < 0 or j < 0 (the recursion only moves up and to the left, so these are the only bounds to check), or if the current cell value is 0, return 0.
  2. Recursive Case:
    • If the current cell value is 1, the largest square ending at (i, j) has side length 1 plus the minimum of the largest sides ending at:
      • The cell to the left (left)
      • The cell above (up)
      • The cell diagonally up-left (diag)
    • Recursively compute these values for each cell.
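The whole method rests on the claim that the largest side at (i, j) equals the number of all-ones squares whose bottom-right corner is (i, j). The sketch below checks that claim cell by cell on both examples. Both function names are hypothetical helpers for illustration. `largest_side` is the naive recursion from the steps above, and `squares_ending_at` counts the same quantity by brute force.

```python
from typing import List


def largest_side(matrix: List[List[int]], i: int, j: int) -> int:
    """Naive recursion: side of the largest all-ones square ending at (i, j)."""
    if i < 0 or j < 0 or matrix[i][j] == 0:
        return 0
    return 1 + min(
        largest_side(matrix, i, j - 1),      # left
        largest_side(matrix, i - 1, j),      # up
        largest_side(matrix, i - 1, j - 1),  # diag
    )


def squares_ending_at(matrix: List[List[int]], i: int, j: int) -> int:
    """Brute force: count all-ones squares whose bottom-right corner is (i, j)."""
    count = 0
    for side in range(1, min(i, j) + 2):
        cells = (
            matrix[r][c]
            for r in range(i - side + 1, i + 1)
            for c in range(j - side + 1, j + 1)
        )
        if all(v == 1 for v in cells):
            count += 1
    return count


grids = ([[0, 1, 1, 1], [1, 1, 1, 1], [0, 1, 1, 1]], [[1, 0, 1], [1, 1, 0], [1, 1, 0]])
for grid in grids:
    for i in range(len(grid)):
        for j in range(len(grid[0])):
            assert largest_side(grid, i, j) == squares_ending_at(grid, i, j)
print("largest side == number of squares ending there, for every cell")
```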

Code

Java
public class Solution {

    public int countSquares(int[][] matrix) {
        int m = matrix.length, n = matrix[0].length;
        int ans = 0;
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                if (matrix[i][j] == 1) {
                    ans += helper(matrix, i, j);
                }
            }
        }
        return ans;
    }

    private int helper(int[][] matrix, int i, int j) {
        if (i < 0 || j < 0) {
            return 0;
        }
        if (matrix[i][j] == 0) {
            return 0;
        }
        int left = helper(matrix, i, j - 1);
        int up = helper(matrix, i - 1, j);
        int diag = helper(matrix, i - 1, j - 1);
        return Math.min(Math.min(left, up), diag) + 1;
    }

}
Python
from typing import List

class Solution:
    
    def countSquares(self, matrix: List[List[int]]) -> int:
        def helper(i: int, j: int) -> int:
            if i < 0 or j < 0:
                return 0
            if matrix[i][j] == 0:
                return 0
            left = helper(i, j - 1)
            up = helper(i - 1, j)
            diag = helper(i - 1, j - 1)
            return min(left, up, diag) + 1

        m, n = len(matrix), len(matrix[0])
        ans = 0
        for i in range(m):
            for j in range(n):
                if matrix[i][j] == 1:
                    ans += helper(i, j)
        return ans

Complexity

Because the same cells are recomputed along many different call paths, this recursion has heavily overlapping subproblems, and its running time grows exponentially with the matrix size.

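  • ⏰ Time complexity: O(m*n*3^(m+n)) in the worst case, where m and n are the dimensions of the matrix. Each call branches into three recursive calls, and the recursion from (i, j) can be up to i + j + 1 levels deep. A single top-level call therefore costs O(3^(m+n)), and there are up to m*n top-level calls.
  • 🧺 Space complexity: O(m+n) for the recursion stack, since every recursive call moves at least one row up or one column left.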
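To see the blow-up concretely, the sketch below counts how many times the naive recursion is invoked on all-ones k x k matrices. `count_naive_calls` is a hypothetical instrumented copy of the Python helper above with a call counter added. It does not change the algorithm.

```python
from typing import List


def count_naive_calls(matrix: List[List[int]]) -> int:
    """Run the naive recursive solution and return how many times helper was called."""
    calls = 0

    def helper(i: int, j: int) -> int:
        nonlocal calls
        calls += 1
        if i < 0 or j < 0 or matrix[i][j] == 0:
            return 0
        return min(helper(i, j - 1), helper(i - 1, j), helper(i - 1, j - 1)) + 1

    for i in range(len(matrix)):
        for j in range(len(matrix[0])):
            if matrix[i][j] == 1:
                helper(i, j)
    return calls


for k in range(1, 7):
    ones = [[1] * k for _ in range(k)]
    print(k, count_naive_calls(ones))  # k=1: 4, k=2: 37, k=3: 225, then much faster growth
```

Even a 6 x 6 grid of ones already needs tens of thousands of calls. This is the redundancy that memoization removes.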

Method 2 - Top Down DP with memoization

Here is the approach:

  1. Memoization Array: We’ll use a 2D array dp where dp[i][j] stores the side length of the largest all-ones square whose bottom-right corner is (i, j), with -1 marking cells that have not been computed yet.
  2. Base Case: If i < 0 or j < 0, or if the current cell value is 0, return 0.
  3. Recursive Case: If we have already computed dp[i][j], return its value to avoid redundant calculations. Otherwise, recursively compute the value as in the naive recursion, but store the result in dp[i][j] before returning it.
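In Python, the manual dp table can also be replaced by the standard library's `functools.lru_cache` decorator, which memoizes the helper on its (i, j) arguments. Here is a minimal sketch of that alternative. The name `count_squares_cached` is illustrative only. The recurrence is unchanged, and only the caching mechanism differs.

```python
from functools import lru_cache
from typing import List


def count_squares_cached(matrix: List[List[int]]) -> int:
    m, n = len(matrix), len(matrix[0])

    @lru_cache(maxsize=None)  # unbounded cache: each (i, j) is computed at most once
    def helper(i: int, j: int) -> int:
        # Above or left of the matrix, or a zero cell: no square ends here.
        if i < 0 or j < 0 or matrix[i][j] == 0:
            return 0
        return 1 + min(helper(i, j - 1), helper(i - 1, j), helper(i - 1, j - 1))

    # Zero cells contribute 0, so summing over every cell is safe.
    return sum(helper(i, j) for i in range(m) for j in range(n))


print(count_squares_cached([[0, 1, 1, 1], [1, 1, 1, 1], [0, 1, 1, 1]]))  # 15
print(count_squares_cached([[1, 0, 1], [1, 1, 0], [1, 1, 0]]))  # 7
```

Like the explicit dp version, this still recurses up to about m + n levels deep. Very large inputs may therefore hit Python's default recursion limit.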

Code

Java
import java.util.Arrays;

public class Solution {

    public int countSquares(int[][] matrix) {
        int m = matrix.length, n = matrix[0].length;
        int[][] dp = new int[m][n];
        for (int i = 0; i < m; i++) {
            Arrays.fill(dp[i], -1);
        }
        int ans = 0;
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                if (matrix[i][j] == 1) {
                    ans += helper(matrix, dp, i, j);
                }
            }
        }
        return ans;
    }

    private int helper(int[][] matrix, int[][] dp, int i, int j) {
        if (i < 0 || j < 0) {
            return 0;
        }
        if (matrix[i][j] == 0) {
            return 0;
        }
        if (dp[i][j] != -1) {
            return dp[i][j];
        }
        int left = helper(matrix, dp, i, j - 1);
        int up = helper(matrix, dp, i - 1, j);
        int diag = helper(matrix, dp, i - 1, j - 1);
        dp[i][j] = Math.min(Math.min(left, up), diag) + 1;
        return dp[i][j];
    }
}
Python
from typing import List

class Solution:
    
    def countSquares(self, matrix: List[List[int]]) -> int:
        def helper(i: int, j: int) -> int:
            if i < 0 or j < 0:
                return 0
            if matrix[i][j] == 0:
                return 0
            if dp[i][j] != -1:
                return dp[i][j]
            left = helper(i, j - 1)
            up = helper(i - 1, j)
            diag = helper(i - 1, j - 1)
            dp[i][j] = min(left, up, diag) + 1
            return dp[i][j]

        m, n = len(matrix), len(matrix[0])
        dp = [[-1] * n for _ in range(m)]
        ans = 0
        for i in range(m):
            for j in range(n):
                if matrix[i][j] == 1:
                    ans += helper(i, j)
        return ans

Complexity

  • ⏰ Time complexity: O(m*n), as each cell is computed once and stored.
  • 🧺 Space complexity: O(m*n) for the memoization array, plus O(m+n) for the recursion stack.
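The O(m*n) bound can be checked by counting helper calls in the memoized version. Every 1-cell receives one top-level call and is computed exactly once, and that computation makes three child calls. The total number of calls is therefore exactly 4 times the number of ones. The sketch below uses a hypothetical instrumented copy of the Python helper above.

```python
from typing import List


def count_memo_calls(matrix: List[List[int]]) -> int:
    """Run the memoized solution and return how many times helper was called."""
    m, n = len(matrix), len(matrix[0])
    dp = [[-1] * n for _ in range(m)]
    calls = 0

    def helper(i: int, j: int) -> int:
        nonlocal calls
        calls += 1
        if i < 0 or j < 0 or matrix[i][j] == 0:
            return 0
        if dp[i][j] != -1:
            return dp[i][j]  # already computed: no further calls
        dp[i][j] = min(helper(i, j - 1), helper(i - 1, j), helper(i - 1, j - 1)) + 1
        return dp[i][j]

    for i in range(m):
        for j in range(n):
            if matrix[i][j] == 1:
                helper(i, j)
    return calls


for k in range(1, 7):
    ones = [[1] * k for _ in range(k)]
    print(k, count_memo_calls(ones))  # always 4 * k * k
```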

Method 3 - Bottom up DP

Here is the approach:

  1. Define the DP Array: As in the memoization approach, create a 2D array dp where dp[i][j] is the side length of the largest all-ones square whose bottom-right corner is (i, j).
  2. Initialization:
    • If matrix[i][j] is 1 and either i or j is 0, set dp[i][j] to 1 (since it’s on the boundary and can only form a 1x1 square).
  3. Iterative Case:
    • For every other cell with matrix[i][j] equal to 1, compute dp[i][j] = 1 + min(dp[i-1][j], dp[i][j-1], dp[i-1][j-1]). Cells with value 0 keep dp[i][j] = 0.
    • Accumulate the sum of all dp[i][j] values to get the total count of all possible square submatrices.
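To make the recurrence concrete, the sketch below fills the same dp table as the code that follows and returns it, so the per-cell values for Example 1 can be inspected. `build_dp` is a hypothetical helper and not part of the original solution.

```python
from typing import List


def build_dp(matrix: List[List[int]]) -> List[List[int]]:
    m, n = len(matrix), len(matrix[0])
    dp = [[0] * n for _ in range(m)]  # zero cells stay 0
    for i in range(m):
        for j in range(n):
            if matrix[i][j] == 1:
                if i == 0 or j == 0:
                    dp[i][j] = 1  # first row/column: only a 1x1 square fits
                else:
                    dp[i][j] = 1 + min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1])
    return dp


table = build_dp([[0, 1, 1, 1], [1, 1, 1, 1], [0, 1, 1, 1]])
for row in table:
    print(row)
# [0, 1, 1, 1]
# [1, 1, 2, 2]
# [0, 1, 2, 3]
print(sum(map(sum, table)))  # 15
```

The 3 in the bottom-right corner means that cell closes squares of side 1, 2, and 3. The row sums 3 + 6 + 6 add up to the expected 15.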

Code

Java
public class Solution {
    public int countSquares(int[][] matrix) {
        int m = matrix.length, n = matrix[0].length;
        int[][] dp = new int[m][n];
        int ans = 0;
        
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                if (matrix[i][j] == 1) {
                    if (i == 0 || j == 0) {
                        dp[i][j] = 1;
                    } else {
                        dp[i][j] = Math.min(Math.min(dp[i-1][j], dp[i][j-1]), dp[i-1][j-1]) + 1;
                    }
                    ans += dp[i][j];
                }
            }
        }
        return ans;
    }
}
Python
from typing import List

class Solution:
    
    def countSquares(self, matrix: List[List[int]]) -> int:
        m, n = len(matrix), len(matrix[0])
        dp = [[0] * n for _ in range(m)]
        ans = 0
        for i in range(m):
            for j in range(n):
                if matrix[i][j] == 1:
                    if i == 0 or j == 0:
                        dp[i][j] = 1
                    else:
                        dp[i][j] = min(dp[i-1][j], dp[i][j-1], dp[i-1][j-1]) + 1
                    ans += dp[i][j]
        return ans

Complexity

  • ⏰ Time complexity: O(m*n), where m and n are dimensions of the matrix.
  • 🧺 Space complexity: O(m*n) for the DP array. Because each row depends only on the row above it, this can be reduced to O(n) with a single-row DP array.
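One possible O(n)-space version keeps a single row plus one saved value for the up-left diagonal. Here is a minimal sketch. The function name `count_squares_one_row` and the extra zero column padded on the left of the row are choices made here for illustration, not taken from the original.

```python
from typing import List


def count_squares_one_row(matrix: List[List[int]]) -> int:
    m, n = len(matrix), len(matrix[0])
    # dp[j + 1] holds the value for column j; dp[0] is a permanent 0 pad for j - 1 = -1.
    dp = [0] * (n + 1)
    ans = 0
    for i in range(m):
        prev_diag = 0  # value of (i-1, j-1) before it was overwritten in this row
        for j in range(n):
            up = dp[j + 1]  # still holds (i-1, j) from the previous row
            if matrix[i][j] == 1:
                # dp[j] was already updated this row, so it is the left neighbor (i, j-1).
                dp[j + 1] = 1 + min(up, dp[j], prev_diag)
                ans += dp[j + 1]
            else:
                dp[j + 1] = 0
            prev_diag = up  # (i-1, j) becomes the diagonal for column j + 1
    return ans


print(count_squares_one_row([[0, 1, 1, 1], [1, 1, 1, 1], [0, 1, 1, 1]]))  # 15
print(count_squares_one_row([[1, 0, 1], [1, 1, 0], [1, 1, 0]]))  # 7
```

The only subtle point is the diagonal: by the time column j is processed, the slot that held (i-1, j-1) has already been overwritten with (i, j-1), so its old value has to be saved before it is lost.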