Problem

Given an integer n, compute the total number of axis-aligned squares in an n x n grid (i.e., a chessboard of side length n). As a follow-up, generalize the solution to an m x n grid and return the number of axis-aligned squares (not rectangles).

Examples

Example 1

Input: n = 8
Output: 204
Explanation: Sum of squares 1^2 + 2^2 + ... + 8^2 = 204

If you thought the answer was 64, think again! :)

At first glance you might answer 64, because there are 8 × 8 unit squares. But that only counts the 1×1 squares. Larger squares are formed by grouping adjacent unit squares: a k×k square can start at any of (8 − k + 1) horizontal positions and any of (8 − k + 1) vertical positions, so there are (8 − k + 1)^2 distinct k×k squares. The full breakdown for k = 1..8:

  • 1×1: 8 × 8 = 64
  • 2×2: 7 × 7 = 49
  • 3×3: 6 × 6 = 36
  • 4×4: 5 × 5 = 25
  • 5×5: 4 × 4 = 16
  • 6×6: 3 × 3 = 9
  • 7×7: 2 × 2 = 4
  • 8×8: 1 × 1 = 1

So the total equals 1^2 + 2^2 + … + 8^2 = 64 + 49 + … + 1 = 204 squares.
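The per-size breakdown above can be reproduced in a few lines of Python. This is a minimal sketch; the helper name squares_by_size is hypothetical and simply tabulates (n - k + 1)^2 for each side length k before summing.

```python
def squares_by_size(n: int) -> dict[int, int]:
    """Map each side length k to the number of k x k squares in an n x n grid (hypothetical helper)."""
    # A k x k square has (n - k + 1) possible left columns and (n - k + 1) possible top rows.
    return {k: (n - k + 1) ** 2 for k in range(1, n + 1)}


counts = squares_by_size(8)
for k, c in counts.items():
    print(f"{k}x{k}: {c}")
print("total:", sum(counts.values()))  # total: 204
```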

Follow-up

Generalize the solution to an m x n grid and return the number of axis-aligned squares (not rectangles).

Follow-up Example

Input: m = 2, n = 3
Output: 8
Explanation: 1x1 squares = 2*3 = 6; 2x2 squares = 1*2 = 2; total = 8
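To double-check the follow-up example, a brute-force enumeration can list every square explicitly. This is an illustrative sketch: the function name enumerate_squares and the (top_row, left_col, side) representation with 0-based cell indices are assumptions, not part of the original solution.

```python
def enumerate_squares(m: int, n: int) -> list[tuple[int, int, int]]:
    """List every axis-aligned square in an m x n grid of unit cells
    as (top_row, left_col, side), using 0-based cell indices (hypothetical helper)."""
    squares = []
    for side in range(1, min(m, n) + 1):
        # The top-left cell must leave room for `side` cells downward and rightward.
        for top in range(m - side + 1):
            for left in range(n - side + 1):
                squares.append((top, left, side))
    return squares


found = enumerate_squares(2, 3)
print(len(found))  # 8
print([s for s in found if s[2] == 2])  # [(0, 0, 2), (0, 1, 2)]
```

Enumeration costs time proportional to the answer itself, so it is only useful for checking small cases.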

Solution

Method 1 - Closed form for n x n (Math formula)

Intuition

For an n x n grid, a k x k square can be placed in (n - k + 1)^2 different positions; summing that quantity for k = 1..n gives the total number of squares. As k goes from 1 to n, the factor n - k + 1 runs through n, n - 1, ..., 1, so the sum is exactly 1^2 + 2^2 + ... + n^2, which has the closed form n * (n + 1) * (2*n + 1) / 6.
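Written out, the step from the per-size sum to the closed form is the substitution j = n - k + 1, which only reverses the order of the terms; the last equality is the standard sum-of-squares identity. The second line plugs in n = 8 as a check.

```latex
\sum_{k=1}^{n} (n - k + 1)^2 \;=\; \sum_{j=1}^{n} j^2 \;=\; \frac{n(n+1)(2n+1)}{6},
\qquad j = n - k + 1,
\\[4pt]
n = 8:\quad \frac{8 \cdot 9 \cdot 17}{6} = \frac{1224}{6} = 204 .
```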

Example (n = 8):

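  • k = 1: (8 - 1 + 1)^2 = 8^2 = 64
  • k = 2: (8 - 2 + 1)^2 = 7^2 = 49
  • k = 3: (8 - 3 + 1)^2 = 6^2 = 36
  • k = 4: (8 - 4 + 1)^2 = 5^2 = 25
  • k = 5: (8 - 5 + 1)^2 = 4^2 = 16
  • k = 6: (8 - 6 + 1)^2 = 3^2 = 9
  • k = 7: (8 - 7 + 1)^2 = 2^2 = 4
  • k = 8: (8 - 8 + 1)^2 = 1^2 = 1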

Summing these gives 64 + 49 + 36 + 25 + 16 + 9 + 4 + 1 = 204, matching the formula 8 * 9 * 17 / 6 = 204. Each time k increases by one, the number of placements drops from (n - k + 1)^2 to (n - k)^2, so the terms run through the perfect squares n^2, (n - 1)^2, ..., 1^2.

Approach

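  1. Compute the closed form n * (n + 1) * (2*n + 1) / 6. The product is always divisible by 6 (n(n + 1) is even, and one of n, n + 1, 2n + 1 is a multiple of 3), so the integer division is exact.
  2. Use 64-bit arithmetic: the intermediate product n * (n + 1) * (2*n + 1) exceeds the 32-bit signed range once n ≥ 1024, even though the final count still fits.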
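To make the overflow concern in step 2 concrete, the sketch below finds the smallest n whose intermediate product exceeds a signed 32-bit int. Python integers never overflow, so the exact product can be compared against the limit; the helper name first_overflowing_n is hypothetical.

```python
INT32_MAX = 2**31 - 1


def first_overflowing_n(limit: int = INT32_MAX) -> int:
    """Smallest n whose intermediate product n*(n+1)*(2n+1) exceeds `limit` (hypothetical helper)."""
    n = 1
    # Exact arithmetic: Python ints are arbitrary precision, so this comparison is safe.
    while n * (n + 1) * (2 * n + 1) <= limit:
        n += 1
    return n


print(first_overflowing_n())  # 1024
```

For n = 1024 the product is 2,150,630,400, just above 2^31 - 1, while the final answer 358,438,400 would still fit in 32 bits; widening before multiplying avoids the problem.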

Complexity

  • Time complexity: O(1) – single arithmetic formula computed in constant time.
  • Space complexity: O(1) – only a constant number of variables is used.

Code

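class Solution {
public:
  // Closed form for an n x n grid: 1^2 + 2^2 + ... + n^2.
  // Widen to 64 bits: the intermediate product overflows int once n >= 1024.
  long long countSquares(int n) {
    long long x = n;
    return x * (x + 1) * (2 * x + 1) / 6;
  }
};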
package main

func countSquares(n int) int {
  return n * (n + 1) * (2*n + 1) / 6
}

func main() {
  _ = countSquares(8)
}
class Solution {
  // Use long: the intermediate product overflows int once n >= 1024.
  public long countSquares(int n) {
    long x = n;
    return x * (x + 1) * (2 * x + 1) / 6;
  }
}
class Solution {
  // Widen to Long: the intermediate product overflows Int once n >= 1024.
  fun countSquares(n: Int): Long {
    val x = n.toLong()
    return x * (x + 1) * (2 * x + 1) / 6
  }
}
class Solution:
  def countSquares(self, n: int) -> int:
    return n * (n + 1) * (2 * n + 1) // 6
pub struct Solution;

impl Solution {
  // Closed form; widen to i64 so the intermediate product cannot overflow i32.
  pub fn count_squares(n: i32) -> i64 {
    let n = n as i64;
    n * (n + 1) * (2 * n + 1) / 6
  }
}
function countSquares(n: number): number {
  return (n * (n + 1) * (2 * n + 1)) / 6;
}

Follow-up Solution

Method 2 - Direct summation (general m x n)

Intuition

For a general m x n grid, a k x k square can be placed in (m - k + 1) * (n - k + 1) positions. Sum over k = 1..min(m,n) to count all squares.

Approach

  1. Compute limit = min(m, n).
  2. Loop k from 1 to limit and accumulate (m - k + 1) * (n - k + 1).
  3. Return the accumulated sum.

Complexity

  • Time complexity: O(min(m, n)) – we iterate up to the smaller side to sum contributions of each square size.
  • Space complexity: O(1) – only a few scalar variables are used.
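The summation can also be collapsed to a closed form, which the original solution does not present; treat this as an optional O(1) alternative rather than the article's method. Assuming m ≤ n (swap otherwise), write d = n - m and j = m - k + 1. Each term (m - k + 1) * (n - k + 1) becomes j * (j + d), so the total is m * (m + 1) * (2*m + 1) / 6 + d * m * (m + 1) / 2. The function name count_squares_closed is hypothetical.

```python
def count_squares_closed(m: int, n: int) -> int:
    """O(1) count of axis-aligned squares in an m x n grid (hypothetical helper)."""
    if m > n:
        m, n = n, m  # make m the shorter side
    d = n - m  # extra columns beyond an m x m grid
    # With j = m - k + 1, each term becomes j * (j + d), so the total is
    # sum(j^2) + d * sum(j) for j = 1..m; both divisions are exact.
    return m * (m + 1) * (2 * m + 1) // 6 + d * m * (m + 1) // 2


print(count_squares_closed(2, 3))  # 8
print(count_squares_closed(8, 8))  # 204
```

When m == n, d is 0 and the expression reduces to the Method 1 formula.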

Code

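#include <algorithm>

class Solution {
public:
  // Sum (m - s + 1) * (n - s + 1) over every side length s up to min(m, n).
  long long countSquares(int m, int n) {
    int minv = std::min(m, n);
    long long ans = 0;
    for (int s = 1; s <= minv; ++s) ans += 1LL * (m - s + 1) * (n - s + 1);
    return ans;
  }
};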
func countSquares(m int, n int) int {
  minv := m
  if n < minv { minv = n }
  ans := 0
  for s := 1; s <= minv; s++ {
    ans += (m - s + 1) * (n - s + 1)
  }
  return ans
}
class Solution {
  public int countSquares(int m, int n) {
    int min = Math.min(m, n);
    int ans = 0;
    for (int s = 1; s <= min; s++) ans += (m - s + 1) * (n - s + 1);
    return ans;
  }
}
class Solution {
  fun countSquares(m: Int, n: Int): Int {
    val minv = kotlin.math.min(m, n)
    var ans = 0
    for (s in 1..minv) ans += (m - s + 1) * (n - s + 1)
    return ans
  }
}
class SolutionMN:
  def countSquares(self, m: int, n: int) -> int:
    ans = 0
    mn = min(m, n)
    for s in range(1, mn + 1):
      ans += (m - s + 1) * (n - s + 1)
    return ans
pub struct Solution;

impl Solution {
  pub fn count_squares(m: i32, n: i32) -> i32 {
    let minv = std::cmp::min(m, n);
    let mut ans = 0i32;
    for s in 1..=minv { ans += (m - s + 1) * (n - s + 1); }
    ans
  }
}
function countSquares(m: number, n: number): number {
  let ans = 0;
  const minv = Math.min(m, n);
  for (let s = 1; s <= minv; s++) ans += (m - s + 1) * (n - s + 1);
  return ans;
}
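A quick consistency check ties the two methods together: on square grids the general summation must agree with the n x n closed form, and swapping the sides must never change the count. This is a small test sketch; closed_form and summation are hypothetical helper names mirroring Method 1 and Method 2.

```python
def closed_form(n: int) -> int:
    # Method 1: n x n grid.
    return n * (n + 1) * (2 * n + 1) // 6


def summation(m: int, n: int) -> int:
    # Method 2: general m x n grid.
    return sum((m - k + 1) * (n - k + 1) for k in range(1, min(m, n) + 1))


# On square grids the general summation must agree with the closed form.
for n in range(0, 101):
    assert summation(n, n) == closed_form(n), n
# Swapping the sides never changes the count.
for m in range(1, 30):
    for n in range(1, 30):
        assert summation(m, n) == summation(n, m)
print("all checks passed")
```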