Problem

You are given a **1-indexed** array `nums`. A subset of indices is *complete* if the product of every pair of its indices is a perfect square, i.e. if you select `nums[i]` and `nums[j]`, then `i * j` must be a perfect square.

Return the maximum element-sum of a complete subset, i.e. the largest possible value of the sum of `nums[i]` over all indices `i` in a complete subset.

Examples

Example 1


Input: nums = [8,7,3,5,7,2,4,9]

Output: 16

Explanation:

We select the elements at indices 2 and 8 (7 + 9 = 16); `2 * 8 = 16` is a perfect square.

Example 2


Input: nums = [8,10,3,8,1,13,7,9,4]

Output: 20

Explanation:

We select the elements at indices 1, 4, and 9 (8 + 8 + 4 = 20); `1 * 4`, `1 * 9`, and `4 * 9` are all
perfect squares.

Constraints

  • 1 <= n == nums.length <= 10^4
  • 1 <= nums[i] <= 10^9

Solution

Method 1 – Group by Square-Free Part

Intuition

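Write every index as k · m², where k has no repeated prime factor; k is the index's *square-free part*. Then i * j is a perfect square if and only if i and j have the same square-free part. So every index is compatible with all indices in its own group and with none outside it, which means a complete subset always lies within a single group. Since all values are positive, the best choice is an entire group: group indices by square-free part and return the largest group sum.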

Approach

  1. For each index i (1-based), compute its square-free part (product of primes with odd exponent in its factorization).
  2. Group all indices with the same square-free part.
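  3. For each group, sum the corresponding nums values. Use a 64-bit accumulator: the group with square-free part 1 can contain up to 100 indices (the perfect squares up to 10^4), each with a value up to 10^9, so a sum can reach about 10^11.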
  4. The answer is the maximum sum among all groups.

Code

C++
class Solution {
public:
    // Returns the maximum element-sum of a complete subset of 1-based indices,
    // i.e. a set of indices whose pairwise products are all perfect squares.
    //
    // Indices i and j have a square product exactly when their square-free
    // parts are equal, so values are summed per square-free part and the
    // largest total is returned (all values are positive, so taking a whole
    // group is optimal).
    //
    // The result is long long: one group can hold up to 100 indices (the
    // perfect squares up to 10^4) with values up to 10^9, which overflows int.
    long long maxSum(vector<int>& nums) {
        unordered_map<int, long long> groupSums;
        const int n = static_cast<int>(nums.size());
        for (int i = 1; i <= n; ++i) {
            groupSums[squareFreePart(i)] += nums[i - 1];
        }
        long long best = 0;  // also the answer for an empty array
        for (const auto& entry : groupSums) best = max(best, entry.second);
        return best;
    }

    // Same as maxSum; provided under the name the other solutions use.
    long long maximumSum(vector<int>& nums) { return maxSum(nums); }

private:
    // Returns the product of the primes that divide x an odd number of times.
    static int squareFreePart(int x) {
        int sqf = 1;
        for (int d = 2; d * d <= x; ++d) {
            int cnt = 0;
            while (x % d == 0) {
                x /= d;
                ++cnt;
            }
            if (cnt % 2) sqf *= d;
        }
        // Any leftover factor > 1 is a prime that appears exactly once.
        if (x > 1) sqf *= x;
        return sqf;
    }
};
Go
// maxSum returns the largest total of nums over a set of 1-based indices whose
// pairwise products are all perfect squares. Such indices share the same
// square-free part, so values are bucketed by that part and the biggest bucket
// total is returned (0 when nums is empty).
func maxSum(nums []int) int {
    // squareFree returns the product of the primes dividing k an odd number
    // of times.
    squareFree := func(k int) int {
        core := 1
        for p := 2; p*p <= k; p++ {
            odd := false
            for k%p == 0 {
                k /= p
                odd = !odd
            }
            if odd {
                core *= p
            }
        }
        // A leftover factor above 1 is a prime occurring exactly once.
        if k > 1 {
            core *= k
        }
        return core
    }

    sums := make(map[int]int)
    for idx, v := range nums {
        sums[squareFree(idx+1)] += v
    }

    best := 0
    for _, total := range sums {
        if total > best {
            best = total
        }
    }
    return best
}
Java
class Solution {
    /**
     * Returns the maximum element-sum of a complete subset of 1-based indices,
     * i.e. a set of indices whose pairwise products are all perfect squares.
     *
     * <p>Indices {@code i} and {@code j} have a square product exactly when their
     * square-free parts are equal, so values are summed per square-free part and
     * the largest total is returned.
     *
     * <p>Sums are accumulated as {@code long}: one group can contain up to 100
     * indices (the perfect squares up to 10^4) with values up to 10^9, which
     * overflows {@code int}.
     *
     * @param nums the values; {@code nums[i - 1]} belongs to index {@code i}
     * @return the largest group sum, or 0 when {@code nums} is empty
     */
    public long maximumSum(int[] nums) {
        Map<Integer, Long> groupSums = new HashMap<>();
        for (int i = 1; i <= nums.length; i++) {
            groupSums.merge(squareFreePart(i), (long) nums[i - 1], Long::sum);
        }
        long best = 0;
        for (long sum : groupSums.values()) {
            best = Math.max(best, sum);
        }
        return best;
    }

    /** Returns the product of the primes that divide {@code x} an odd number of times. */
    private static int squareFreePart(int x) {
        int sqf = 1;
        for (int d = 2; d * d <= x; d++) {
            int cnt = 0;
            while (x % d == 0) {
                x /= d;
                cnt++;
            }
            if (cnt % 2 == 1) sqf *= d;
        }
        // Any leftover factor > 1 is a prime that appears exactly once.
        if (x > 1) sqf *= x;
        return sqf;
    }
}
Kotlin
class Solution {
    /**
     * Returns the maximum element-sum of a complete subset of 1-based indices of [nums],
     * i.e. a set of indices whose pairwise products are all perfect squares.
     *
     * Indices `i` and `j` have a square product exactly when their square-free parts are
     * equal, so values are summed per square-free part and the largest total is returned.
     * Sums are [Long]: one group can hold up to 100 indices (the perfect squares up to
     * 10^4) with values up to 10^9, which overflows [Int].
     *
     * @return the largest group sum, or 0 when [nums] is empty.
     */
    fun maximumSum(nums: IntArray): Long =
        (1..nums.size)
            .groupingBy { squareFreePart(it) }
            .fold(0L) { sum, index -> sum + nums[index - 1] }
            .values
            .maxOrNull() ?: 0L

    /** Returns [value] with every square factor divided out, i.e. its square-free part. */
    private fun squareFreePart(value: Int): Int {
        var rest = value
        var d = 2
        while (d * d <= rest) {
            val square = d * d
            // After stripping d² repeatedly, d divides `rest` at most once, so the
            // remainder is the product of the primes with odd exponent.
            while (rest % square == 0) rest /= square
            d++
        }
        return rest
    }
}
Python
class Solution:
    def maximumSum(self, nums: list[int]) -> int:
        """Return the largest sum of nums over 1-based indices sharing a square-free part.

        Two indices multiply to a perfect square exactly when their square-free
        parts coincide, so values are totalled per square-free part and the
        biggest total is returned. Raises ValueError when ``nums`` is empty.
        """

        def core(k: int) -> int:
            # Product of the primes dividing k an odd number of times.
            result, p = 1, 2
            while p * p <= k:
                odd = False
                while k % p == 0:
                    k //= p
                    odd = not odd
                if odd:
                    result *= p
                p += 1
            # A leftover factor above 1 is a prime occurring exactly once.
            return result * k if k > 1 else result

        totals: dict[int, int] = {}
        for idx, value in enumerate(nums, start=1):
            key = core(idx)
            totals[key] = totals.get(key, 0) + value
        return max(totals.values())
Rust
impl Solution {
    /// Returns the maximum element-sum of a complete subset of 1-based indices,
    /// i.e. a set of indices whose pairwise products are all perfect squares.
    ///
    /// Indices `i` and `j` have a square product exactly when their square-free
    /// parts are equal, so values are summed per square-free part and the
    /// largest total is returned (0 for an empty input).
    ///
    /// Sums are accumulated as `i64`: one group can hold up to 100 indices (the
    /// perfect squares up to 10^4) with values up to 10^9, which overflows `i32`.
    pub fn maximum_sum(nums: Vec<i32>) -> i64 {
        use std::collections::HashMap;

        // Product of the primes that divide `x` an odd number of times.
        fn square_free(mut x: usize) -> usize {
            let mut sqf = 1;
            let mut d = 2;
            while d * d <= x {
                let mut cnt = 0;
                while x % d == 0 {
                    x /= d;
                    cnt += 1;
                }
                if cnt % 2 == 1 {
                    sqf *= d;
                }
                d += 1;
            }
            // Any leftover factor > 1 is a prime that appears exactly once.
            if x > 1 {
                sqf *= x;
            }
            sqf
        }

        let mut group: HashMap<usize, i64> = HashMap::new();
        for (i, &v) in nums.iter().enumerate() {
            *group.entry(square_free(i + 1)).or_insert(0) += i64::from(v);
        }
        // An empty input has no groups; report 0 instead of panicking.
        group.values().copied().max().unwrap_or(0)
    }
}
TypeScript
class Solution {
    /**
     * Totals `nums` per square-free part of each 1-based index and returns the
     * largest total. Two indices multiply to a perfect square exactly when their
     * square-free parts coincide. An empty array yields -Infinity (the maximum of
     * no values).
     */
    maximumSum(nums: number[]): number {
        // Product of the primes that divide k an odd number of times.
        const squareFree = (k: number): number => {
            let core = 1;
            for (let p = 2; p * p <= k; p++) {
                let odd = false;
                while (k % p === 0) {
                    k /= p;
                    odd = !odd;
                }
                if (odd) core *= p;
            }
            // A leftover factor above 1 is a prime occurring exactly once.
            return k > 1 ? core * k : core;
        };

        const totals = new Map<number, number>();
        for (let idx = 0; idx < nums.length; idx++) {
            const key = squareFree(idx + 1);
            totals.set(key, (totals.get(key) ?? 0) + nums[idx]);
        }

        let best = -Infinity;
        for (const total of totals.values()) best = Math.max(best, total);
        return best;
    }
}

Complexity

  • ⏰ Time complexity: O(n√n), since each of the n indices is factorized by trial division up to its square root (at most √n).
  • 🧺 Space complexity: O(n), for the group map.