Problem

You are given a 0-indexed integer array nums.

The distinct count of a subarray of nums is defined as:

  • Let nums[i..j] be a subarray of nums consisting of all the indices from i to j such that 0 <= i <= j < nums.length. Then the number of distinct values in nums[i..j] is called the distinct count of nums[i..j].

Return the sum of the squares of the distinct counts of all subarrays of nums.

Since the answer may be very large, return it modulo 10^9 + 7.

A subarray is a contiguous non-empty sequence of elements within an array.
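To make the definition concrete, here is a minimal Python sketch that transcribes it literally. The hypothetical helper distinct_count computes the distinct count of nums[i..j] with a set, and the hypothetical function sum_of_squared_distinct_counts sums the squares over all subarrays. Its O(n^3) running time makes it useful only as a reference for checking small inputs.

```python
def distinct_count(nums, i, j):
    """Number of distinct values in nums[i..j] (inclusive), straight from the definition.

    Hypothetical helper for illustration only.
    """
    return len(set(nums[i:j + 1]))


def sum_of_squared_distinct_counts(nums):
    """Sum of distinct_count(nums, i, j) ** 2 over all 0 <= i <= j < len(nums), mod 10**9 + 7.

    Hypothetical reference implementation: a literal O(n^3) transcription of the
    problem statement, intended only for checking small inputs.
    """
    MOD = 10**9 + 7
    n = len(nums)
    total = 0
    for i in range(n):
        for j in range(i, n):
            total += distinct_count(nums, i, j) ** 2
    return total % MOD
```

For example, sum_of_squared_distinct_counts([1, 2, 1]) returns 15 and sum_of_squared_distinct_counts([2, 2]) returns 3, matching the examples that follow.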

Examples

Example 1

Input: nums = [1,2,1]
Output: 15
Explanation: Six possible subarrays are:
[1]: 1 distinct value
[2]: 1 distinct value
[1]: 1 distinct value
[1,2]: 2 distinct values
[2,1]: 2 distinct values
[1,2,1]: 2 distinct values
The sum of the squares of the distinct counts in all subarrays is equal to 1^2 + 1^2 + 1^2 + 2^2 + 2^2 + 2^2 = 15.

Example 2

Input: nums = [2,2]
Output: 3
Explanation: Three possible subarrays are:
[2]: 1 distinct value
[2]: 1 distinct value
[2,2]: 1 distinct value
The sum of the squares of the distinct counts in all subarrays is equal to 1^2 + 1^2 + 1^2 = 3.

Constraints

  • 1 <= nums.length <= 10^5
  • 1 <= nums[i] <= 10^5

Solution

Method 1 - Enumerate All Subarrays with a Frequency Map

Intuition

Every subarray is nums[l..r] for some left endpoint l and right endpoint r. If we fix l and extend r one step at a time, the distinct count can be updated in O(1) expected time: a hash map stores the frequency of each value in the current window, and the count grows by one exactly when nums[r] appears in the window for the first time. Adding the square of the current count after each extension visits every subarray exactly once.
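The incremental update behind enumerating subarrays can be written as a recurrence. Here d(l, r) denotes the distinct count of nums[l..r] and freq_{l,r-1} denotes the frequency map of nums[l..r-1]; this notation is introduced only for illustration.

$$
d(l, l-1) = 0, \qquad
d(l, r) = d(l, r-1) + \begin{cases} 1 & \text{if } \mathrm{freq}_{l,r-1}[\mathrm{nums}[r]] = 0, \\ 0 & \text{otherwise,} \end{cases}
$$

$$
\text{answer} = \left( \sum_{l=0}^{n-1} \sum_{r=l}^{n-1} d(l, r)^2 \right) \bmod (10^9 + 7).
$$

For Example 1, nums = [1,2,1], the left endpoints l = 0, 1, 2 produce d = (1, 2, 2), (1, 2), and (1), so the sum is 1 + 4 + 4 + 1 + 4 + 1 = 15.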

Approach

  1. For each left endpoint l, start with an empty frequency map and distinct = 0.
  2. Extend the right endpoint r from l to n - 1, incrementing the frequency of nums[r].
  3. If the frequency of nums[r] has just become 1, the value is new to nums[l..r], so increment distinct.
  4. Add distinct^2 to the answer, reducing modulo 10^9 + 7 after each addition.

Code

C++
#include <vector>
#include <unordered_map>
using namespace std;
class Solution {
public:
    int sumCounts(vector<int>& nums) {
        const int MOD = 1e9+7;
        int n = nums.size();
        long long ans = 0;
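        for (int l = 0; l < n; ++l) {
            unordered_map<int, int> freq;  // value -> count in nums[l..r]
            int distinct = 0;              // number of distinct values in nums[l..r]
            for (int r = l; r < n; ++r) {
                if (++freq[nums[r]] == 1) ++distinct;  // nums[r] is new to the window
                ans = (ans + 1LL * distinct * distinct) % MOD;
            }
        }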
        }
        return ans;
    }
};
Go
func sumCounts(nums []int) int {
    mod := int(1e9+7)
    n := len(nums)
    ans := 0
    for l := 0; l < n; l++ {
        freq := map[int]int{}
        distinct := 0
        for r := l; r < n; r++ {
            freq[nums[r]]++
            if freq[nums[r]] == 1 {
                distinct++
            }
            ans = (ans + distinct*distinct) % mod
        }
    }
    return ans
}
Java
import java.util.*;
class Solution {
    public int sumCounts(int[] nums) {
        int MOD = 1_000_000_007, n = nums.length;
        long ans = 0;
        for (int l = 0; l < n; ++l) {
            Map<Integer, Integer> freq = new HashMap<>();
            int distinct = 0;
            for (int r = l; r < n; ++r) {
                freq.put(nums[r], freq.getOrDefault(nums[r], 0) + 1);
                if (freq.get(nums[r]) == 1) ++distinct;
                ans = (ans + 1L * distinct * distinct) % MOD;
            }
        }
        return (int)ans;
    }
}
Kotlin
fun sumCounts(nums: IntArray): Int {
    val MOD = 1_000_000_007
    val n = nums.size
    var ans = 0L
    for (l in 0 until n) {
        val freq = mutableMapOf<Int, Int>()
        var distinct = 0
        for (r in l until n) {
            freq[nums[r]] = freq.getOrDefault(nums[r], 0) + 1
            if (freq[nums[r]] == 1) distinct++
            ans = (ans + 1L * distinct * distinct) % MOD
        }
    }
    return ans.toInt()
}
Python
def sumCounts(nums):
    MOD = 10**9 + 7
    n = len(nums)
    ans = 0
    for l in range(n):
        freq = {}
        distinct = 0
        for r in range(l, n):
            freq[nums[r]] = freq.get(nums[r], 0) + 1
            if freq[nums[r]] == 1:
                distinct += 1
            ans = (ans + distinct * distinct) % MOD
    return ans
Rust
use std::collections::HashMap;
pub fn sum_counts(nums: Vec<i32>) -> i32 {
    let modulo = 1_000_000_007;
    let n = nums.len();
    let mut ans = 0i64;
    for l in 0..n {
        let mut freq = HashMap::new();
        // i64 so that distinct * distinct (up to 10^10) cannot overflow i32
        let mut distinct: i64 = 0;
        for r in l..n {
            *freq.entry(nums[r]).or_insert(0) += 1;
            if *freq.get(&nums[r]).unwrap() == 1 { distinct += 1; }
            ans = (ans + distinct * distinct) % modulo;
        }
    }
    ans as i32
}
TypeScript
function sumCounts(nums: number[]): number {
    const MOD = 1e9 + 7;
    const n = nums.length;
    let ans = 0;
    for (let l = 0; l < n; l++) {
        const freq: Record<number, number> = {};
        let distinct = 0;
        for (let r = l; r < n; r++) {
            freq[nums[r]] = (freq[nums[r]] || 0) + 1;
            if (freq[nums[r]] === 1) distinct++;
            ans = (ans + distinct * distinct) % MOD;
        }
    }
    return ans;
}

Complexity

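  • ⏰ Time complexity: O(n^2), because each of the n(n+1)/2 subarrays is produced by one O(1) expected-time extension of the window.
  • 🧺 Space complexity: O(n), for the frequency map, which holds at most n distinct keys.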
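With n up to 10^5, the nested loops perform n(n+1)/2, roughly 5 × 10^9, window extensions in the worst case, which is likely too slow for typical time limits. The following is a sketch of a faster alternative that is not part of the method above; it swaps in two Fenwick trees (binary indexed trees) supporting range add and range sum. It relies on one observation: when the right endpoint moves to r, the distinct count of nums[l..r] increases by one exactly for l in (p, r], where p is the previous index of nums[r] (or -1 if there is none). Since (c + 1)^2 - c^2 = 2c + 1, the running sum of squares over all left endpoints grows by twice the sum of the current counts in that range plus the range length. The names sum_counts_fast, range_add, and prefix_sum are hypothetical. Under these assumptions it runs in O(n log n) time and O(n) space.

```python
def sum_counts_fast(nums):
    """Sum of squared distinct counts over all subarrays, mod 10**9 + 7, in O(n log n).

    Hypothetical sketch (not the original solution): two Fenwick trees store
    c[l] = distinct count of nums[l..r] for the current right endpoint r.
    """
    MOD = 10**9 + 7
    n = len(nums)
    # 1-indexed Fenwick trees over a difference array, enabling range add + range sum.
    b1 = [0] * (n + 2)
    b2 = [0] * (n + 2)

    def _add(tree, i, x):
        while i <= n:
            tree[i] += x
            i += i & -i

    def _prefix(tree, i):
        s = 0
        while i > 0:
            s += tree[i]
            i -= i & -i
        return s

    def range_add(lo, hi, x):
        """Add x to every position lo..hi (1-indexed, inclusive)."""
        _add(b1, lo, x)
        _add(b1, hi + 1, -x)
        _add(b2, lo, x * (lo - 1))
        _add(b2, hi + 1, -x * hi)

    def prefix_sum(i):
        """Sum of positions 1..i."""
        return _prefix(b1, i) * i - _prefix(b2, i)

    last = {}  # value -> most recent index where it appeared (0-indexed)
    sq = 0     # sum over l <= r of c[l]^2 for the current right endpoint r
    ans = 0
    for r, v in enumerate(nums):
        p = last.get(v, -1)
        # 0-indexed left endpoints p+1..r gain a new distinct value; 1-indexed that is p+2..r+1.
        lo, hi = p + 2, r + 1
        s = prefix_sum(hi) - prefix_sum(lo - 1)
        sq += 2 * s + (hi - lo + 1)  # sum of (2c + 1) over the updated range
        range_add(lo, hi, 1)
        last[v] = r
        ans = (ans + sq) % MOD
    return ans
```

For example, sum_counts_fast([1, 2, 1]) returns 15: after each right endpoint the running sum of squares is 1, then 5, then 9. The Fenwick pair is chosen over a lazy segment tree only because range add with range sum is all this update rule needs.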