Problem

Given a positive integer n, return the punishment number of n.

The punishment number of n is defined as the sum of the squares of all integers i such that:

  • 1 <= i <= n
  • The decimal representation of i * i can be partitioned into contiguous substrings such that the sum of the integer values of these substrings equals i.
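The definition can be checked directly by enumerating every way to cut the digits of i * i. The sketch below uses bitmask enumeration instead of recursion: each of the d - 1 gaps between adjacent digits is either a cut (bit set) or not, so every mask from 0 to 2^(d-1) - 1 describes one partition. The function name `satisfies` is hypothetical and is not part of the solution described later.

```python
def satisfies(i: int) -> bool:
    """Hypothetical helper: True if the digits of i * i can be split into
    contiguous pieces whose integer values sum to i."""
    digits = str(i * i)
    gaps = len(digits) - 1  # positions between adjacent digits
    for mask in range(1 << gaps):
        total, start = 0, 0
        for g in range(gaps):
            if (mask >> g) & 1:  # bit g set: cut after digit g
                total += int(digits[start:g + 1])
                start = g + 1
        total += int(digits[start:])  # the final piece
        if total == i:
            return True
    return False


print([i for i in range(1, 11) if satisfies(i)])  # [1, 9, 10]
```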

Examples

Example 1:

Input: n = 10
Output: 182
Explanation: There are exactly 3 integers i that satisfy the conditions in the statement:
- 1 since 1 * 1 = 1.
- 9 since 9 * 9 = 81 and 81 can be partitioned into 8 + 1.
- 10 since 10 * 10 = 100 and 100 can be partitioned into 10 + 0.
Hence, the punishment number of 10 is 1 + 81 + 100 = 182.

Example 2:

Input: n = 37
Output: 1478
Explanation: There are exactly 4 integers i that satisfy the conditions in the statement:
- 1 since 1 * 1 = 1.
- 9 since 9 * 9 = 81 and 81 can be partitioned into 8 + 1.
- 10 since 10 * 10 = 100 and 100 can be partitioned into 10 + 0.
- 36 since 36 * 36 = 1296 and 1296 can be partitioned into 1 + 29 + 6.
Hence, the punishment number of 37 is 1 + 81 + 100 + 1296 = 1478.
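The partitions quoted in both explanations can be verified mechanically. This small sketch uses a hypothetical `claims` table filled in from Examples 1 and 2; it confirms that each list of pieces concatenates back to i * i and sums to i, then adds up the squares.

```python
# Hypothetical table: each qualifying i mapped to the pieces given in the examples.
claims = {
    1: ["1"],
    9: ["8", "1"],
    10: ["10", "0"],
    36: ["1", "29", "6"],
}

for i, pieces in claims.items():
    # Read left to right, the pieces must spell out i * i exactly ...
    assert "".join(pieces) == str(i * i)
    # ... and their integer values must add up to i.
    assert sum(int(p) for p in pieces) == i

print(sum(i * i for i in claims))  # 1478
```

Summing only the entries with i <= 10 gives 182, matching Example 1.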

Constraints:

  • 1 <= n <= 1000

Solution

Method 1 - Backtracking

To solve the problem, we need to:

  1. Calculate the square of every integer i where 1 <= i <= n.
  2. Check if the square of i can be partitioned into contiguous substrings such that the sum of these substrings equals i.
  3. If the condition holds, add i * i to a running total; the total after all i have been checked is the punishment number.

Here is the approach to check the partition condition:

  • Convert the square of a number (i * i) to a string to facilitate substring manipulation.
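  • Use backtracking to try every way of cutting that string into contiguous pieces, stopping as soon as one partition's pieces sum to i; in that case, i * i is added to the answer. A branch can be abandoned early once its running sum exceeds i, because adding further pieces can never decrease it.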
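A variant of this check, assuming it helps to see *which* split succeeds: the same depth-first search, but it pushes each chosen piece onto a list and pops it when backtracking, so a successful search leaves the winning partition behind. `find_partition` is a hypothetical name; the early return on a negative remainder mirrors the `currSum > target` pruning in the Java code.

```python
from typing import List, Optional


def find_partition(sq_str: str, target: int) -> Optional[List[str]]:
    """Hypothetical helper: return one split of sq_str whose pieces sum to
    target, or None if no such split exists."""
    pieces: List[str] = []

    def backtrack(idx: int, remaining: int) -> bool:
        if remaining < 0:  # prune: the chosen pieces already exceed target
            return False
        if idx == len(sq_str):  # every digit has been used
            return remaining == 0
        for j in range(idx + 1, len(sq_str) + 1):
            pieces.append(sq_str[idx:j])  # choose the next piece
            if backtrack(j, remaining - int(sq_str[idx:j])):
                return True
            pieces.pop()  # undo the choice and try a longer piece
        return False

    return pieces if backtrack(0, target) else None


print(find_partition("1296", 36))  # ['1', '29', '6']
print(find_partition("100", 10))   # ['10', '0']
print(find_partition("49", 7))     # None
```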

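Here is the state space tree for backtracking when i = 36 and i^2 = 1296. The green node completes the partition 1 + 29 + 6 = 36; red nodes end a path whose pieces do not sum to 36.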

graph TD
    A["1296"] --> B1["1"]
    A --> B2["12"]
    A --> B3["129"]
    A --> B4["1296"]:::endNode

    B1 --> C1["2"]
    B1 --> C2["29"]
    B1 --> C3["296"]:::endNode

    C1 --> D1["9"]
    C1 --> D2["96"]:::endNode

    D1 --> E1["6"]:::endNode

    C2 --> D3["6"]:::success

    B2 --> C4["9"]
    B2 --> C5["96"]:::endNode

    C4 --> D4["6"]:::endNode

    B3 --> C6["6"]:::endNode

    %% Class Definitions for Styling %%
    classDef success fill:#d4f5d4,stroke:#34a853,stroke-width:2px,color:#34a853;
    classDef endNode fill:#ffe6e6,stroke:#ff4d4d,stroke-width:2px,color:#ff4d4d;
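The leaves of this tree are exactly the complete partitions of "1296". The sketch below, a plain generator with the hypothetical name `all_partitions`, lists them in the same depth-first order (shortest first piece first) without any pruning; of the 8 leaves, only 1 + 29 + 6 sums to 36.

```python
from typing import Iterator, List


def all_partitions(s: str) -> Iterator[List[str]]:
    """Hypothetical helper: yield every split of s into contiguous pieces,
    in depth-first order with the shortest first piece first."""
    if not s:
        yield []
        return
    for j in range(1, len(s) + 1):
        for rest in all_partitions(s[j:]):
            yield [s[:j]] + rest


for parts in all_partitions("1296"):
    total = sum(int(p) for p in parts)
    mark = "  <- success" if total == 36 else ""
    print(f"{' + '.join(parts)} = {total}{mark}")
```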
  

Code

Java
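class Solution {
    public int punishmentNumber(int n) {
        int ans = 0;
        for (int i = 1; i <= n; i++) {
            int sq = i * i;
            // Count i * i only if its digits can be split into pieces summing to i.
            if (canPartition(Integer.toString(sq), i)) {
                ans += sq;
            }
        }
        return ans;
    }

    private boolean canPartition(String sqStr, int target) {
        return helper(sqStr, 0, 0, target);
    }

    // Backtracking over sqStr[idx..]; currSum is the total of the pieces chosen so far.
    private boolean helper(String sqStr, int idx, int currSum, int target) {
        // Prune: adding more (non-negative) pieces can never lower the sum.
        if (currSum > target) {
            return false;
        }

        // All digits consumed: succeed only if the pieces sum exactly to target.
        if (idx == sqStr.length()) {
            return currSum == target;
        }

        // Try every possible next piece sqStr[idx..j).
        for (int j = idx + 1; j <= sqStr.length(); j++) {
            String substr = sqStr.substring(idx, j);
            if (helper(sqStr, j, currSum + Integer.parseInt(substr), target)) {
                return true;
            }
        }
        return false;
    }
}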
Python
class Solution:
    def punishment_number(self, n: int) -> int:
        def can_partition(sq_str: str, target: int) -> bool:
            # Backtracking: try every split of sq_str[idx:] into pieces,
            # tracking the running sum of the pieces chosen so far.
            def check(idx: int, curr_sum: int) -> bool:
                if curr_sum > target:  # prune: pieces already exceed target
                    return False
                if idx == len(sq_str):  # all digits used
                    return curr_sum == target
                for j in range(idx + 1, len(sq_str) + 1):
                    # Take sq_str[idx:j] as the next piece and recurse.
                    if check(j, curr_sum + int(sq_str[idx:j])):
                        return True
                return False

            return check(0, 0)

        ans = 0
        for i in range(1, n + 1):
            sq = i * i
            if can_partition(str(sq), i):
                ans += sq  # i qualifies, so its square counts
        return ans

Complexity

  • ⏰ Time complexity: O(n * 2^d), where d is the number of digits in n * n, the largest square checked. For a single number i:
    • Computing i * i and converting it to a string takes O(d).
    • The backtracking explores at most 2^(d-1) partitions, because each of the d - 1 gaps between adjacent digits is either a cut or not; this is O(2^d). Treating the parsing of each piece as constant time (under the constraint n <= 1000, d <= 7, so pieces are short and each check visits at most 64 partitions), summing over all i from 1 to n gives O(n * 2^d).
  • 🧺 Space complexity: O(d), the maximum recursion depth, since every recursive call consumes at least one digit.
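The number of complete partitions of a d-digit string is 2^(d - 1), one for each choice of cut or no cut at the d - 1 gaps, which sits within the O(2^d) bound. This sketch (hypothetical helper `count_leaves`) walks the same search tree as the solutions but without pruning and counts the complete partitions it reaches; the largest square allowed by the constraints, 1000000, has 7 digits and therefore 64 partitions.

```python
def count_leaves(s: str, idx: int = 0) -> int:
    """Hypothetical helper: count complete partitions of s[idx:] reached
    by the unpruned depth-first search."""
    if idx == len(s):
        return 1
    return sum(count_leaves(s, j) for j in range(idx + 1, len(s) + 1))


for i in (3, 36, 100, 1000):
    sq = str(i * i)
    d = len(sq)
    print(i, sq, count_leaves(sq), 2 ** (d - 1))
    assert count_leaves(sq) == 2 ** (d - 1)
```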