Problem

Given an integer n, return the least number of perfect square numbers that sum to n.

A perfect square is an integer that is the square of an integer; in other words, it is the product of some integer with itself. For example, 1, 4, 9, and 16 are perfect squares while 3 and 11 are not.

Examples

Example 1:

Input: n = 12
Output: 3
Explanation: 12 = 4 + 4 + 4.

Example 2:

Input: n = 13
Output: 2
Explanation: 13 = 4 + 9.

Solution

Why Greedy Does Not Work

This is not a greedy problem. Take n = 12. If we greedily start with the largest perfect square not exceeding 12, i.e. 9 = 3^2, the remainder 3 can only be built from 1s, so we get 4 squares:

12 = 9 + 1 + 1 + 1

But the actual answer is 3 (12 = 4 + 4 + 4).
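To make the counterexample concrete, here is a small sketch that runs the greedy rule next to a plain bottom-up DP on the same input. The class `GreedyVsOptimal` and its method names are hypothetical helpers for illustration only.

```java
// Illustrative helper (hypothetical names): greedy choice vs. the true minimum.
class GreedyVsOptimal {
    // Greedy: repeatedly subtract the largest perfect square <= the remaining value.
    static int greedyCount(int n) {
        int count = 0;
        while (n > 0) {
            int root = (int) Math.sqrt(n);
            n -= root * root;
            count++;
        }
        return count;
    }

    // Bottom-up DP: dp[i] = min over squares j*j <= i of 1 + dp[i - j*j].
    static int optimalCount(int n) {
        int[] dp = new int[n + 1];
        for (int i = 1; i <= n; i++) {
            dp[i] = i; // upper bound: i copies of 1^2
            for (int j = 1; j * j <= i; j++) {
                dp[i] = Math.min(dp[i], 1 + dp[i - j * j]);
            }
        }
        return dp[n];
    }

    public static void main(String[] args) {
        System.out.println(greedyCount(12));  // 4  (9 + 1 + 1 + 1)
        System.out.println(optimalCount(12)); // 3  (4 + 4 + 4)
    }
}
```

Greedy commits to 9 and can never undo that choice, while the DP considers every first square.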

Method 1 - Recursion - TLE ❌

Let’s take n = 12. To find perfect squares that sum to a number n, we try every square i^2 that does not exceed n, i.e. every i from 1 to ⌊√n⌋. For each i, we use one copy of i^2 and recursively ask how many squares are needed for the remainder n - i^2. This gives the following recurrence:

$$ f(n) = \min_{1 \le i \le \lfloor \sqrt{n} \rfloor} \big( 1 + f(n - i^2) \big), \qquad f(0) = 0 $$

Our goal is to minimize the total number of perfect squares used across all valid combinations.
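For instance, with n = 12 the usable squares are 1, 4 and 9, so there are three first moves to compare. Using f(11) = 3 (9 + 1 + 1), f(8) = 2 (4 + 4) and f(3) = 3 (1 + 1 + 1):

$$ f(12) = \min\big(1 + f(11),\ 1 + f(8),\ 1 + f(3)\big) = 1 + \min(3,\ 2,\ 3) = 3 $$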

Code

Java
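public int numSquares(int n) {
	// Base cases: 0, 1, 2 and 3 can only be written as n copies of 1^2
	if (n < 4) {
		return n;
	}
	
	int ans = n; // upper bound: n = 1 + 1 + ... + 1
	
	// Try every square i*i <= n as the first term and recurse on the remainder
	for (int i = 1; i * i <= n; i++) {
		int square = i * i;
		ans = Math.min(ans, 1 + numSquares(n - square));
	}
	
	return ans;
}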

Method 2 - Top Down DP + Memoization

Now, let's look at the recursion tree (not complete) for the previous method.

We can see that the solution can be reached along many paths, but the least number of perfect squares that sum to n = 12 is 3: ps(12) = 2^2 + 2^2 + 2^2.

Note that the problem has overlapping subproblems. For example, f(2) and f(3) appear at least twice (the tree is intentionally not drawn completely; it only gives the gist). So, we can use memoization to speed up the algorithm.

We cache the result for each n and reuse it whenever the same subproblem shows up again.

class Solution {
    public int numSquares(int n) {
        return helper(n, new int[n + 1]);
    }
    
    // memo[k] == 0 means "not computed yet"; every k >= 4 has an answer >= 1
    public int helper(int n, int[] memo) {
        if (n < 4) {
            return n;
        }

        if (memo[n] != 0) {
            return memo[n];
        }

        int ans = n;
        
        for (int i = 1; i * i <= n; i++) {
            int square = i * i;
            ans = Math.min(ans, 1 + helper(n - square, memo));
        }
        
        return memo[n] = ans;
    }
}
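To see how much work the cache saves, the sketch below instruments both versions with a call counter. `CallCounter`, `naive`, `memo` and the counter fields are hypothetical names used only for this illustration; the recursion itself mirrors Methods 1 and 2, with a `HashMap` standing in for the `int[]` memo.

```java
import java.util.HashMap;
import java.util.Map;

// Illustrative sketch: count recursive calls with and without memoization.
class CallCounter {
    static long naiveCalls = 0;
    static long memoCalls = 0;

    // Plain recursion (Method 1) with a call counter.
    static int naive(int n) {
        naiveCalls++;
        if (n < 4) return n;
        int ans = n;
        for (int i = 1; i * i <= n; i++) {
            ans = Math.min(ans, 1 + naive(n - i * i));
        }
        return ans;
    }

    // Same recursion, but each n >= 4 is solved once and then served from the cache.
    static int memo(int n, Map<Integer, Integer> cache) {
        memoCalls++;
        if (n < 4) return n;
        Integer cached = cache.get(n);
        if (cached != null) return cached;
        int ans = n;
        for (int i = 1; i * i <= n; i++) {
            ans = Math.min(ans, 1 + memo(n - i * i, cache));
        }
        cache.put(n, ans);
        return ans;
    }

    public static void main(String[] args) {
        int n = 20;
        System.out.println(naive(n) + " using " + naiveCalls + " calls without memo");
        System.out.println(memo(n, new HashMap<>()) + " using " + memoCalls + " calls with memo");
    }
}
```

Both return the same answer; only the number of calls differs, and the gap widens quickly as n grows.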

Complexity

  • ⏰ Time complexity: O(n*√n)
  • 🧺 Space complexity: O(n)

Method 3 - Bottom up DP Solution O(n√n)

Since we want the minimum number of squares, we can devise the following recurrence.

Let f(i) be the minimum number of perfect squares that sum to i. Then

$$ f(i) = \min_{1 \le j \le \lfloor \sqrt{i} \rfloor} \big( 1 + f(i - j^2) \big), \qquad f(0) = 0 $$

In terms of the dp array, the key relation is dp[i] = min(dp[i], dp[i - square] + 1). For example, dp[5] = dp[4] + 1 = 1 + 1 = 2.

In the code below, target runs over every value from 1 to n (for example, up to 12), and dp[target] holds the answer for that value.

import java.util.Arrays;

public int numSquares(int n) {
	int[] dp = new int[n + 1];
	Arrays.fill(dp, Integer.MAX_VALUE);
	
	dp[0] = 0; // 0 squares are needed to reach a target of 0

	for (int target = 1; target <= n; target++) {
		// dp[target - j*j] is already final because target - j*j < target
		for (int j = 1; j * j <= target; j++) {
			dp[target] = Math.min(dp[target], dp[target - j * j] + 1);
		}
	}

	return dp[n];
}

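An equivalent variant initializes dp[i] = i (the sum of i ones, which is always a valid upper bound) instead of filling the array with Integer.MAX_VALUE:

public int numSquares(int n) {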
	int[] dp = new int[n + 1];
	
	dp[0] = 0;
	
	for (int i = 1; i <= n; i++) {
		dp[i] = i;
		
		for (int j = 1; j * j <= i; j++) {
			int square = j * j;
			dp[i] = Math.min(dp[i], 1 + dp[i - square]);
		}
	}
	
	return dp[n];
}
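To inspect the intermediate values rather than only dp[n], the sketch below builds the same table and returns it. `DpTable` and `buildTable` are hypothetical names for illustration; the loop is the same as the variant above.

```java
import java.util.Arrays;

// Sketch: fill the bottom-up table and return the whole array for inspection.
class DpTable {
    static int[] buildTable(int n) {
        int[] dp = new int[n + 1];          // dp[0] = 0 by default
        for (int i = 1; i <= n; i++) {
            dp[i] = i;                      // upper bound: i copies of 1^2
            for (int j = 1; j * j <= i; j++) {
                dp[i] = Math.min(dp[i], dp[i - j * j] + 1);
            }
        }
        return dp;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(buildTable(13)));
        // prints [0, 1, 2, 3, 1, 2, 3, 4, 2, 1, 2, 3, 3, 2]
    }
}
```

The entries match the examples: dp[12] = 3 and dp[13] = 2, and dp[7] = 4 is the first value that needs four squares.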

Complexity

  • ⏰ Time complexity: O(n*√n)
  • 🧺 Space complexity: O(n)

Method 4 - BFS

Why BFS? Because BFS traverses level by level, like a wave. So, unlike DFS, we don't have to go deep down one branch: the first time we reach node 0 starting from node n, we have found the closest one. Please refer to the tree below (not complete), where the root is n and two nodes are connected when their difference is a perfect square. Finding the least number of perfect squares that sum to n is then equivalent to finding the shortest path from node n to node 0.

So we run BFS level by level and stop at the first level where a node reaches 0; that level number is the answer. See also Binary Tree Level Order Traversal - Level by Level.

Code

Java
import java.util.HashSet;
import java.util.LinkedList;
import java.util.Queue;

class Solution {
    public int numSquares(int n) {
        Queue<Integer> q = new LinkedList<>();
        q.offer(n); // add initial number
        var visited = new HashSet<Integer>();
        visited.add(n);
        var level = 0;

        while (!q.isEmpty()) {
            level++;
            int sz = q.size();
            while (sz > 0) {
                var curr = q.poll();
                for (var i = 1; i * i <= curr; i++) {
                    var remainder = curr - i * i;

                    if (remainder == 0) {
                        return level;
                    }
                        
                    if (visited.add(remainder)) {
                        q.add(remainder);
                    }
                }
                sz--;
            }
        }
        return n; // only reached for n == 0, whose answer is 0
    }
}
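The BFS above only returns the level count. If you also want the squares themselves, one option is to remember each node's parent and walk back from 0 once BFS finishes. This is a sketch; `SquarePath` and `shortestSquares` are hypothetical names, and `ArrayDeque` is used instead of `LinkedList` as the queue.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Queue;

// Sketch: BFS from n toward 0 that records parents so the path can be rebuilt.
class SquarePath {
    static List<Integer> shortestSquares(int n) {
        Map<Integer, Integer> parent = new HashMap<>(); // node -> node it was first reached from
        Queue<Integer> q = new ArrayDeque<>();
        q.offer(n);
        parent.put(n, n); // the root points to itself
        while (!q.isEmpty()) {
            int curr = q.poll();
            if (curr == 0) break; // 0 was discovered at its minimal depth
            for (int i = 1; i * i <= curr; i++) {
                int next = curr - i * i;
                if (!parent.containsKey(next)) { // first discovery = shortest distance in BFS
                    parent.put(next, curr);
                    q.offer(next);
                }
            }
        }
        // Walk back from 0 to n; each edge corresponds to one perfect square.
        List<Integer> squares = new ArrayList<>();
        for (int node = 0; node != n; node = parent.get(node)) {
            squares.add(parent.get(node) - node);
        }
        return squares;
    }
}
```

For n = 12 this yields [4, 4, 4], and the list length always equals the level count returned by the BFS solution.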

Complexity

  • ⏰ Time complexity: O(n*√n)
  • 🧺 Space complexity: O(n)

Method 5 - Lagrange’s Four-square Theorem 🏆

Lagrange’s Four-Square Theorem

Lagrange’s four-square theorem, also known as Bachet’s conjecture, states that every natural number can be represented as the sum of four integer squares.

$$ p = a^2 + b^2 + c^2 + d^2 $$ where the four numbers a, b, c, d are all integers.

Let's look at some examples: 3, 31 and 310 can each be represented as a sum of four integer squares:

$$ 3 = 1^2 + 1^2 + 1^2 + 0^2 $$

$$ 31 = 5^2 + 2^2 + 1^2 + 1^2 $$

$$ 310 = 17^2 + 4^2 + 2^2 + 1^2 $$

This theorem was proved by Joseph Louis Lagrange in 1770.
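The examples above can be checked mechanically. Below is a brute-force sketch (hypothetical class `FourSquares`) that searches for a ≥ b ≥ c ≥ d ≥ 0 with a² + b² + c² + d² = n; by Lagrange's theorem the search always succeeds. It runs in roughly O(n^{3/2}) time, so it is for illustration only.

```java
// Brute-force sketch: find one four-square representation of n (not necessarily
// the same one written above).
class FourSquares {
    static int[] decompose(int n) {
        for (int a = 0; a * a <= n; a++) {
            for (int b = 0; b <= a && a * a + b * b <= n; b++) {
                for (int c = 0; c <= b && a * a + b * b + c * c <= n; c++) {
                    int rest = n - a * a - b * b - c * c;
                    int d = (int) Math.sqrt(rest);
                    if (d * d == rest && d <= c) { // keep a >= b >= c >= d
                        return new int[]{a, b, c, d};
                    }
                }
            }
        }
        throw new IllegalStateException("unreachable by Lagrange's four-square theorem");
    }
}
```

For example, `FourSquares.decompose(3)` returns {1, 1, 1, 0}; for 31 it may return a different but equally valid representation than the one listed above.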

Three-Square Theorem

Adrien-Marie Legendre extended the theorem in 1797–8 with his three-square theorem, by proving that a positive integer can be expressed as the sum of three squares if and only if it is not of the form 4^k(8m+7) for integers k and m.
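A quick way to convince yourself of the three-square condition is to compare it with a brute-force search for small n. The sketch below does that; `ThreeSquareCheck`, `isExcludedForm` and `isSumOfThreeSquares` are hypothetical names, and zeros are allowed among the three squares.

```java
// Sketch: Legendre's condition n = 4^k (8m + 7) vs. a brute-force three-square search.
class ThreeSquareCheck {
    // True when n (n >= 1) has the form 4^k (8m + 7). Do not call with n == 0.
    static boolean isExcludedForm(int n) {
        while (n % 4 == 0) {
            n /= 4; // strip the factors of 4
        }
        return n % 8 == 7;
    }

    // Brute force: is n = a^2 + b^2 + c^2 for some integers a, b, c >= 0?
    static boolean isSumOfThreeSquares(int n) {
        for (int a = 0; a * a <= n; a++) {
            for (int b = 0; a * a + b * b <= n; b++) {
                int rest = n - a * a - b * b;
                int c = (int) Math.sqrt(rest);
                if (c * c == rest) {
                    return true;
                }
            }
        }
        return false;
    }

    public static void main(String[] args) {
        for (int n = 1; n <= 1000; n++) {
            if (isSumOfThreeSquares(n) == isExcludedForm(n)) {
                throw new AssertionError("condition disagrees at n = " + n);
            }
        }
        System.out.println("Legendre's condition matches brute force for 1..1000");
    }
}
```

The smallest excluded values are 7, 15, 23 and 28 (= 4 · 7).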

Approach

Our default answer is 4 squares, but we need to check whether a solution with 1, 2 or 3 squares exists. If n is itself a perfect square, the answer is 1. Otherwise, we try to split n into 2 squares, i.e. n = i*i + j*j for some i with 1 ≤ i ≤ √n. If that fails, a 3-square solution exists if and only if n cannot be written in the form 4^k(8m+7) for integers k and m; if n does have that form, the answer is 4. So, for any positive integer there are only 4 possible results: 1, 2, 3 or 4.

Below is an O(√n) time solution based on the math above.

public int numSquares(int n) {
	if (n == 0) {
		return 0;
	}
	// case 1 - perfect square
	if (isSquare(n)) {
		return 1;
	}

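	// case 2 - n is a sum of two squares
	for(int i=1; i*i <= n; i++){
		if(isSquare(n - (i * i))) {
			return 2;
		}
	}

	// case 4 - n has the form 4^k(8m+7): strip the factors of 4, then test for 8m+7
	while(n % 4 ==0) {
		n = n/4;
	}
	
	if(n % 8 == 7){
		return 4;    // not a sum of three squares (Legendre)
	}	
	
	// case 3 - every remaining n is a sum of three squares
	return 3;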
}

private boolean isSquare(int n) {
	int sqrtN = (int)(Math.sqrt(n));
	return (sqrtN * sqrtN == n);
}
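Because the O(√n) method relies on two theorems rather than an explicit search, it is worth cross-checking it against the DP. The sketch below restates both solutions as static methods in a hypothetical class `CrossCheck` and compares them for every n up to a bound.

```java
// Sanity-check sketch: the Lagrange/Legendre shortcut vs. the O(n*sqrt(n)) DP.
class CrossCheck {
    static int numSquaresMath(int n) {
        if (n == 0) return 0;
        if (isSquare(n)) return 1;                 // case 1
        for (int i = 1; i * i <= n; i++) {
            if (isSquare(n - i * i)) return 2;     // case 2
        }
        while (n % 4 == 0) n /= 4;                 // strip 4^k
        return (n % 8 == 7) ? 4 : 3;               // cases 4 and 3
    }

    static boolean isSquare(int n) {
        int r = (int) Math.sqrt(n);
        return r * r == n;
    }

    // dp[i] = least number of squares summing to i, for all i in 0..limit.
    static int[] dpUpTo(int limit) {
        int[] dp = new int[limit + 1];
        for (int i = 1; i <= limit; i++) {
            dp[i] = i;
            for (int j = 1; j * j <= i; j++) {
                dp[i] = Math.min(dp[i], 1 + dp[i - j * j]);
            }
        }
        return dp;
    }

    public static void main(String[] args) {
        int limit = 10_000;
        int[] dp = dpUpTo(limit);
        for (int n = 0; n <= limit; n++) {
            if (numSquaresMath(n) != dp[n]) {
                throw new AssertionError("mismatch at n = " + n);
            }
        }
        System.out.println("Both methods agree for 0.." + limit);
    }
}
```

A check like this does not prove the shortcut correct, but it quickly catches mistakes in the case order or in the 4^k(8m+7) test.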