Tags: problem · hard · algorithms · leetcode-3321 · leetcode 3321 · leetcode3321

Find X-Sum of All K-Long Subarrays II

Difficulty: Hard · Updated: Nov 10, 2025
Practice on:

Problem

You are given an array nums of n integers and two integers k and x.

The x-sum of an array is calculated by the following procedure:

  • Count the occurrences of all elements in the array.
  • Keep only the occurrences of the top x most frequent elements. If two elements have the same number of occurrences, the element with the bigger value is considered more frequent.
  • Calculate the sum of the resulting array.

Note that if an array has fewer than x distinct elements, its x-sum is the sum of the array.

Return an integer array answer of length n - k + 1 where answer[i] is the x-sum of the subarray nums[i..i + k - 1].

Examples

Example 1


Input: nums = [1,1,2,2,3,4,2,3], k = 6, x = 2

Output: [6,10,12]

Explanation:

  * For subarray `[1, 1, 2, 2, 3, 4]`, only elements 1 and 2 will be kept in the resulting array. Hence, `answer[0] = 1 + 1 + 2 + 2`.
  * For subarray `[1, 2, 2, 3, 4, 2]`, only elements 2 and 4 will be kept in the resulting array. Hence, `answer[1] = 2 + 2 + 2 + 4`. Note that 4 is kept in the array since it is bigger than 3 and 1, which occur the same number of times.
  * For subarray `[2, 2, 3, 4, 2, 3]`, only elements 2 and 3 are kept in the resulting array. Hence, `answer[2] = 2 + 2 + 2 + 3 + 3`.

Example 2


Input: nums = [3,8,7,8,7,5], k = 2, x = 2

Output: [11,15,15,15,12]

Explanation:

Since `k == x`, `answer[i]` is equal to the sum of the subarray `nums[i..i + k - 1]`.

Constraints

  • nums.length == n
  • 1 <= n <= 10^5
  • 1 <= nums[i] <= 10^9
  • 1 <= x <= k <= nums.length

Solution

Method 1 – Sliding Window with Ordered Sets for Top-x Elements

Intuition

For each window, we want to maintain the top x elements ranked by (frequency desc, value desc). By keeping two ordered sets—one for the top x and one for the rest—we can efficiently update the running sum as the window slides, only updating frequencies and set membership for the outgoing and incoming elements.

Approach

  1. Use a frequency map to track the count of each value in the current window.
  2. Maintain two ordered sets (sorted by frequency descending, value descending):
    • top: the best x entries (or fewer if not enough unique values).
    • rest: all other entries.
  3. Keep a running sum of freq * value for entries in top.
  4. For each window slide:
    • Remove the outgoing value and update sets and sum.
    • Insert the incoming value and update sets and sum.
    • Record the current sum.
  5. Return the result array.

Complexity

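  • Time complexity: O(n log U) with balanced ordered sets – each insert/erase is O(log U), where U is the number of distinct values in the window. Ports that substitute other containers can differ: for example, sorted Python lists shift elements on insertion, so each update costs O(U) in the worst case.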
  • Space complexity: O(U) – for the frequency map and set entries.

Code

C++
#include <vector>
#include <unordered_map>
#include <set>
using namespace std;

class Solution {
    // (frequency, value) pair describing one distinct value of the window.
    // operator< ranks "more frequent" entries first: higher frequency, and on
    // a tie the larger value. So begin() of a set is its best entry and
    // prev(end()) its worst.
    struct Entry {
        int freq, val;
        bool operator<(const Entry& rhs) const {
            return freq != rhs.freq ? freq > rhs.freq : val > rhs.val;
        }
        bool operator==(const Entry& rhs) const {
            return freq == rhs.freq && val == rhs.val;
        }
    };

    // Contribution of an entry to an x-sum (freq copies of val), in 64 bits.
    static long long weight(const Entry& e) {
        return static_cast<long long>(e.freq) * e.val;
    }

public:
    // Returns the x-sum of each window nums[i..i+k-1] for i = 0..n-k.
    // `best` holds the top-x entries of the current window (fewer when the
    // window has fewer distinct values), `others` holds every other entry,
    // and `bestSum` is the total weight of `best`.
    vector<long long> findXSum(vector<int>& nums, int k, int x) {
        const int n = static_cast<int>(nums.size());
        vector<long long> ans(n - k + 1);
        unordered_map<int, int> count;
        set<Entry> best, others;
        long long bestSum = 0;

        // Removes e from whichever set currently holds it.
        auto detach = [&](const Entry& e) {
            if (best.erase(e) > 0) {
                bestSum -= weight(e);
            } else {
                others.erase(e);
            }
        };

        // Adds one occurrence of v. The grown entry joins `best`; if that
        // overfills `best`, its worst entry is demoted to `others`.
        auto push = [&](int v) {
            const auto found = count.find(v);
            const int f = (found == count.end()) ? 0 : found->second;
            if (f > 0) detach(Entry{f, v});
            const Entry grown{f + 1, v};
            count[v] = grown.freq;
            best.insert(grown);
            bestSum += weight(grown);
            if (static_cast<int>(best.size()) > x) {
                const auto worst = prev(best.end());
                bestSum -= weight(*worst);
                others.insert(*worst);
                best.erase(worst);
            }
        };

        // Removes one occurrence of v. The shrunk entry goes to `others`; if
        // `best` now has room, the best entry of `others` is promoted.
        auto pop = [&](int v) {
            const auto found = count.find(v);
            if (found == count.end() || found->second == 0) return;
            const int f = found->second;
            detach(Entry{f, v});
            if (f == 1) {
                count.erase(found);
            } else {
                found->second = f - 1;
                others.insert(Entry{f - 1, v});
            }
            if (static_cast<int>(best.size()) < x && !others.empty()) {
                const auto top = others.begin();
                bestSum += weight(*top);
                best.insert(*top);
                others.erase(top);
            }
        };

        for (int i = 0; i < k; ++i) push(nums[i]);
        ans[0] = bestSum;
        for (int r = k; r < n; ++r) {
            pop(nums[r - k]);
            push(nums[r]);
            ans[r - k + 1] = bestSum;
        }
        return ans;
    }
};
Go
import (
    "container/heap"
    "sort"
)

// entry is a (frequency, value) pair describing one distinct value of a window.
type entry struct{ freq, val int }

// entryHeap is a slice of entries ordered best-first: higher frequency wins,
// and on a tie the larger value wins. It satisfies sort.Interface, and
// *entryHeap additionally satisfies container/heap's heap.Interface.
type entryHeap []entry

func (h entryHeap) Len() int { return len(h) }
func (h entryHeap) Less(i, j int) bool {
    if h[i].freq != h[j].freq {
        return h[i].freq > h[j].freq
    }
    return h[i].val > h[j].val
}
func (h entryHeap) Swap(i, j int)       { h[i], h[j] = h[j], h[i] }
func (h *entryHeap) Push(x interface{}) { *h = append(*h, x.(entry)) }
func (h *entryHeap) Pop() interface{} {
    old := *h
    n := len(old)
    x := old[n-1]
    *h = old[:n-1]
    return x
}

// Compile-time check that *entryHeap can drive a container/heap priority queue.
var _ heap.Interface = (*entryHeap)(nil)

// findXSum returns the x-sum of every window nums[i..i+k-1]: the total of all
// occurrences of the window's x most frequent values (ties go to the larger
// value), or the whole window sum when it has fewer than x distinct values.
//
// A dry run first collects every (freq, val) pair that ever describes a value
// in some window and ranks those pairs best-first once. The live pairs are then
// tracked in two Fenwick trees indexed by rank (live-pair count and freq*val
// total), so each window's x-sum is a prefix sum over the first x live ranks.
// This runs in O(n log n) time and O(n) space instead of rebuilding a heap of
// all distinct values for every window. Returns an empty slice when k is
// outside [1, len(nums)].
func findXSum(nums []int, k, x int) []int64 {
    n := len(nums)
    if k < 1 || k > n {
        return []int64{}
    }

    // sweep slides the window over nums. Before adding nums[i] it drops
    // nums[i-k]. It reports every count change as (v, old, new) and every
    // complete window by its start index. Both passes below use it, so they
    // observe exactly the same sequence of (freq, val) states.
    sweep := func(change func(v, oldF, newF int), window func(start int)) {
        cnt := make(map[int]int)
        for i, v := range nums {
            if i >= k {
                out := nums[i-k]
                change(out, cnt[out], cnt[out]-1)
                cnt[out]--
            }
            change(v, cnt[v], cnt[v]+1)
            cnt[v]++
            if i >= k-1 {
                window(i - k + 1)
            }
        }
    }

    // Pass 1: collect every live (freq, val) pair and rank them best-first.
    seen := make(map[entry]bool)
    sweep(func(v, _, newF int) {
        if newF > 0 {
            seen[entry{newF, v}] = true
        }
    }, func(int) {})
    keys := make(entryHeap, 0, len(seen))
    for e := range seen {
        keys = append(keys, e)
    }
    sort.Sort(keys)
    rank := make(map[entry]int, len(keys))
    for i, e := range keys {
        rank[e] = i + 1 // Fenwick trees are 1-indexed
    }

    // Fenwick trees over ranks. A value has exactly one frequency at a time,
    // so each rank holds at most one live pair.
    m := len(keys)
    cntTree := make([]int, m+1)
    sumTree := make([]int64, m+1)
    update := func(e entry, sign int) {
        w := int64(sign) * int64(e.freq) * int64(e.val)
        for i := rank[e]; i <= m; i += i & -i {
            cntTree[i] += sign
            sumTree[i] += w
        }
    }
    highBit := 1
    for highBit*2 <= m {
        highBit *= 2
    }
    // topSum descends the Fenwick tree to the longest rank prefix that holds
    // at most x live pairs and returns that prefix's freq*val total.
    topSum := func() int64 {
        pos, taken, total := 0, 0, int64(0)
        for step := highBit; step > 0; step >>= 1 {
            next := pos + step
            if next <= m && taken+cntTree[next] <= x {
                pos, taken, total = next, taken+cntTree[next], total+sumTree[next]
            }
        }
        return total
    }

    // Pass 2: replay the same sweep, keeping the trees in sync with the window.
    ans := make([]int64, n-k+1)
    sweep(func(v, oldF, newF int) {
        if oldF > 0 {
            update(entry{oldF, v}, -1)
        }
        if newF > 0 {
            update(entry{newF, v}, 1)
        }
    }, func(start int) {
        ans[start] = topSum()
    })
    return ans
}
Java
class Solution {
    /** Immutable (frequency, value) pair for one distinct value of the window. */
    private static final class Entry {
        final int freq;
        final int val;

        Entry(int freq, int val) {
            this.freq = freq;
            this.val = val;
        }

        /** Contribution of this entry to an x-sum: freq copies of val, as a long. */
        long weight() {
            return (long) freq * val;
        }

        @Override
        public boolean equals(Object other) {
            if (other == this) return true;
            if (!(other instanceof Entry)) return false;
            Entry that = (Entry) other;
            return that.freq == freq && that.val == val;
        }

        @Override
        public int hashCode() {
            return Objects.hash(freq, val);
        }
    }

    /** Best-first order: higher frequency first, then larger value first. */
    private final Comparator<Entry> bestFirst =
            Comparator.<Entry>comparingInt(e -> e.freq).thenComparingInt(e -> e.val).reversed();

    private Map<Integer, Integer> counts;  // value -> occurrences in the window
    private TreeSet<Entry> chosen;         // the top-x entries of the window
    private TreeSet<Entry> spare;          // every other entry
    private long chosenSum;                // total weight() over chosen

    /** Removes e from whichever set holds it, keeping chosenSum in sync. */
    private void detach(Entry e) {
        if (chosen.remove(e)) {
            chosenSum -= e.weight();
        } else {
            spare.remove(e);
        }
    }

    /** Adds one occurrence of v; demotes the worst chosen entry if chosen overflows. */
    private void increment(int v, int x) {
        int f = counts.getOrDefault(v, 0);
        if (f > 0) detach(new Entry(f, v));
        Entry grown = new Entry(f + 1, v);
        counts.put(v, grown.freq);
        chosen.add(grown);
        chosenSum += grown.weight();
        if (chosen.size() > x) {
            Entry worst = chosen.pollLast();
            chosenSum -= worst.weight();
            spare.add(worst);
        }
    }

    /** Removes one occurrence of v; promotes the best spare entry if chosen has room. */
    private void decrement(int v, int x) {
        Integer boxed = counts.get(v);
        if (boxed == null || boxed == 0) return;
        int f = boxed;
        detach(new Entry(f, v));
        if (f == 1) {
            counts.remove(v);
        } else {
            counts.put(v, f - 1);
            spare.add(new Entry(f - 1, v));
        }
        if (chosen.size() < x && !spare.isEmpty()) {
            Entry best = spare.pollFirst();
            chosen.add(best);
            chosenSum += best.weight();
        }
    }

    /** Returns the x-sum of every length-k window of nums. */
    public long[] findXSum(int[] nums, int k, int x) {
        int n = nums.length;
        long[] ans = new long[n - k + 1];
        counts = new HashMap<>(Math.max(16, n * 2));
        chosen = new TreeSet<>(bestFirst);
        spare = new TreeSet<>(bestFirst);
        chosenSum = 0L;
        for (int i = 0; i < k; i++) increment(nums[i], x);
        ans[0] = chosenSum;
        for (int r = k; r < n; r++) {
            decrement(nums[r - k], x);
            increment(nums[r], x);
            ans[r - k + 1] = chosenSum;
        }
        return ans;
    }
}
Kotlin
class Solution {
    /** One distinct value of the window: its current frequency and the value itself. */
    data class Entry(val freq: Int, val val_: Int)

    /**
     * Returns the x-sum of every length-[k] window of [nums].
     *
     * Entries are ordered best-first (higher frequency, then larger value).
     * `chosen` holds the top [x] entries of the current window and `chosenSum`
     * their total freq * value; `spare` holds all other entries.
     */
    fun findXSum(nums: IntArray, k: Int, x: Int): LongArray {
        val ans = LongArray(nums.size - k + 1)
        val order = compareByDescending<Entry> { it.freq }.thenByDescending { it.val_ }
        val counts = HashMap<Int, Int>()
        val chosen = java.util.TreeSet(order)
        val spare = java.util.TreeSet(order)
        var chosenSum = 0L

        fun weight(e: Entry): Long = e.freq.toLong() * e.val_

        // Removes e from whichever set holds it, keeping chosenSum in sync.
        fun detach(e: Entry) {
            if (chosen.remove(e)) chosenSum -= weight(e)
            else spare.remove(e)
        }

        // Adds one occurrence of v; demotes the worst chosen entry on overflow.
        fun increment(v: Int) {
            val f = counts[v] ?: 0
            if (f > 0) detach(Entry(f, v))
            val grown = Entry(f + 1, v)
            counts[v] = grown.freq
            chosen.add(grown)
            chosenSum += weight(grown)
            if (chosen.size > x) {
                val worst = chosen.pollLast()!!
                chosenSum -= weight(worst)
                spare.add(worst)
            }
        }

        // Removes one occurrence of v; promotes the best spare entry if there is room.
        fun decrement(v: Int) {
            val f = counts[v] ?: return
            detach(Entry(f, v))
            if (f == 1) {
                counts.remove(v)
            } else {
                counts[v] = f - 1
                spare.add(Entry(f - 1, v))
            }
            if (chosen.size < x && spare.isNotEmpty()) {
                val best = spare.pollFirst()!!
                chosen.add(best)
                chosenSum += weight(best)
            }
        }

        for (i in 0 until k) increment(nums[i])
        ans[0] = chosenSum
        for (r in k until nums.size) {
            decrement(nums[r - k])
            increment(nums[r])
            ans[r - k + 1] = chosenSum
        }
        return ans
    }
}
Python
from collections import Counter
import bisect

class Solution:
    def findXSum(self, nums: list[int], k: int, x: int) -> list[int]:
        """Return the x-sum of every length-``k`` window of ``nums``.

        Each distinct value in the window is keyed as ``(-freq, -val)``, so
        ascending order is best-first (higher frequency, then larger value).
        ``top`` holds the best ``x`` keys and ``rest`` holds the others; both
        are kept sorted with :mod:`bisect`. ``top_sum`` is the running total of
        ``freq * val`` over ``top``, and it is adjusted whenever a key leaves
        or enters ``top``. List insertion shifts elements, so each update costs
        O(U) in the worst case (U = distinct values in the window).

        Args:
            nums: Input values.
            k: Window length (assumed ``1 <= k <= len(nums)``).
            x: Number of most-frequent values kept per window.

        Returns:
            ``len(nums) - k + 1`` x-sums, one per window start.
        """
        freq: Counter = Counter()
        top: list[tuple[int, int]] = []   # best x keys, ascending == best first
        rest: list[tuple[int, int]] = []  # all other live keys, same order
        top_sum = 0                       # sum of freq * val over `top`

        def detach(v: int) -> None:
            """Remove v's current key from whichever list holds it."""
            nonlocal top_sum
            f = freq[v]
            if f == 0:
                return
            key = (-f, -v)
            i = bisect.bisect_left(top, key)
            if i < len(top) and top[i] == key:
                del top[i]
                top_sum -= f * v
                return
            j = bisect.bisect_left(rest, key)
            if j < len(rest) and rest[j] == key:
                del rest[j]

        def add(v: int) -> None:
            """Count one more v; demote the worst top key if top overflows."""
            nonlocal top_sum
            detach(v)
            f = freq[v] + 1
            freq[v] = f
            bisect.insort(top, (-f, -v))
            top_sum += f * v
            if len(top) > x:
                worst = top.pop()
                top_sum -= worst[0] * worst[1]  # (-f) * (-v) == f * v
                bisect.insort(rest, worst)

        def remove(v: int) -> None:
            """Count one fewer v; promote the best rest key if top has room."""
            nonlocal top_sum
            detach(v)
            f = freq[v] - 1
            if f > 0:
                freq[v] = f
                bisect.insort(rest, (-f, -v))
            else:
                del freq[v]
            if len(top) < x and rest:
                best = rest.pop(0)
                bisect.insort(top, best)
                top_sum += best[0] * best[1]  # (-f) * (-v) == f * v

        for v in nums[:k]:
            add(v)
        ans = [top_sum]
        for i in range(k, len(nums)):
            remove(nums[i - k])
            add(nums[i])
            ans.append(top_sum)
        return ans
Rust
use std::collections::{BTreeSet, HashMap};

impl Solution {
    /// Returns the x-sum of every length-`k` window of `nums`.
    ///
    /// Each distinct value is tracked as a `(freq, val)` tuple. Tuples compare
    /// by frequency and then by value, so the *largest* tuples are the most
    /// frequent. `active` keeps the `x` largest tuples (and `sum` their total
    /// `freq * val`); `bank` keeps the rest. A value whose count falls to zero
    /// stays around as `(0, val)`: it ranks below every live tuple and adds
    /// nothing to `sum`.
    pub fn find_x_sum(nums: Vec<i32>, k: i32, x: i32) -> Vec<i64> {
        type Pair = (i32, i32);

        /// Promotes the largest banked tuple when `active` has spare room, or
        /// when it outranks the smallest active tuple (which is then banked).
        fn promote(sum: &mut i64, active: &mut BTreeSet<Pair>, bank: &mut BTreeSet<Pair>, cap: usize) {
            let best = match bank.iter().next_back() {
                Some(&pair) => pair,
                None => return,
            };
            let worst = active.iter().next().copied();
            let has_room = active.len() < cap;
            let outranks = worst.map_or(true, |w| best > w);
            if !has_room && !outranks {
                return;
            }
            bank.remove(&best);
            active.insert(best);
            *sum += best.0 as i64 * best.1 as i64;
            if active.len() > cap {
                let evicted = worst.unwrap();
                active.remove(&evicted);
                bank.insert(evicted);
                *sum -= evicted.0 as i64 * evicted.1 as i64;
            }
        }

        /// Shifts `val`'s count by `delta`, re-keys its tuple in whichever set
        /// holds it, then lets `promote` restore the top-`cap` split.
        fn bump(
            val: i32,
            delta: i32,
            sum: &mut i64,
            counts: &mut HashMap<i32, i32>,
            active: &mut BTreeSet<Pair>,
            bank: &mut BTreeSet<Pair>,
            cap: usize,
        ) {
            let before = counts.get(&val).copied().unwrap_or(0);
            let after = before + delta;
            counts.insert(val, after);
            let stale = (before, val);
            let fresh = (after, val);
            if active.remove(&stale) {
                // The tuple stays active; only its weight changes.
                *sum += (after as i64 - before as i64) * val as i64;
                active.insert(fresh);
            } else {
                bank.remove(&stale);
                bank.insert(fresh);
            }
            promote(sum, active, bank, cap);
        }

        let (k, cap) = (k as usize, x as usize);
        let n = nums.len();
        let mut res: Vec<i64> = vec![0; n - k + 1];
        let mut counts: HashMap<i32, i32> = HashMap::new();
        let mut active: BTreeSet<Pair> = BTreeSet::new();
        let mut bank: BTreeSet<Pair> = BTreeSet::new();
        let mut sum: i64 = 0;

        for v in nums.iter().take(k).copied() {
            bump(v, 1, &mut sum, &mut counts, &mut active, &mut bank, cap);
        }
        res[0] = sum;

        for (i, &incoming) in nums.iter().enumerate().skip(k) {
            bump(incoming, 1, &mut sum, &mut counts, &mut active, &mut bank, cap);
            bump(nums[i - k], -1, &mut sum, &mut counts, &mut active, &mut bank, cap);
            res[i - k + 1] = sum;
        }

        res
    }
}
TypeScript
/** Node of the AVL tree backing SortedSet; `height` is 1 for a leaf. */
class AVLNode<T> {
    constructor(
        public value: T,
        public height = 1,
        public left: AVLNode<T> | null = null,
        public right: AVLNode<T> | null = null
    ) { }
}

/**
 * Ordered set backed by an AVL tree. Two elements are equal when `compare`
 * returns 0; adding an element equal to a stored one is a no-op, and deleting
 * an absent element is a no-op. add/delete/find/first/last run in O(log size).
 */
class SortedSet<T> {
    public size = 0;
    private root: AVLNode<T> | null = null;
    constructor(private compare: (a: T, b: T) => number) { }

    /** Inserts `value` unless an equal element is already stored. */
    add(value: T): void {
        this.root = this.insertInto(this.root, value);
    }

    /** Removes the element equal to `value`, if present. */
    delete(value: T): void {
        this.root = this.removeFrom(this.root, value);
    }

    /** Returns the stored element equal to `value`, or undefined. */
    find(value: T): T | undefined {
        let node: AVLNode<T> | null = this.root;
        while (node !== null) {
            const c = this.compare(value, node.value);
            if (c === 0) return node.value;
            node = c < 0 ? node.left : node.right;
        }
        return undefined;
    }

    /** Smallest element, or undefined when the set is empty. */
    first(): T | undefined {
        let node: AVLNode<T> | null = this.root;
        if (node === null) return undefined;
        while (node.left !== null) node = node.left;
        return node.value;
    }

    /** Largest element, or undefined when the set is empty. */
    last(): T | undefined {
        let node: AVLNode<T> | null = this.root;
        if (node === null) return undefined;
        while (node.right !== null) node = node.right;
        return node.value;
    }

    private heightOf(node: AVLNode<T> | null): number {
        return node === null ? 0 : node.height;
    }

    private refresh(node: AVLNode<T>): void {
        node.height = 1 + Math.max(this.heightOf(node.left), this.heightOf(node.right));
    }

    private rotateRight(node: AVLNode<T>): AVLNode<T> {
        const pivot = node.left!;
        node.left = pivot.right;
        pivot.right = node;
        this.refresh(node);
        this.refresh(pivot);
        return pivot;
    }

    private rotateLeft(node: AVLNode<T>): AVLNode<T> {
        const pivot = node.right!;
        node.right = pivot.left;
        pivot.left = node;
        this.refresh(node);
        this.refresh(pivot);
        return pivot;
    }

    /** Restores the AVL height invariant at `node` and returns the subtree root. */
    private rebalance(node: AVLNode<T>): AVLNode<T> {
        this.refresh(node);
        const skew = this.heightOf(node.left) - this.heightOf(node.right);
        if (skew > 1) {
            const child = node.left!;
            if (this.heightOf(child.left) < this.heightOf(child.right)) {
                node.left = this.rotateLeft(child); // left-right case
            }
            return this.rotateRight(node);
        }
        if (skew < -1) {
            const child = node.right!;
            if (this.heightOf(child.right) < this.heightOf(child.left)) {
                node.right = this.rotateRight(child); // right-left case
            }
            return this.rotateLeft(node);
        }
        return node;
    }

    private insertInto(node: AVLNode<T> | null, value: T): AVLNode<T> {
        if (node === null) {
            this.size++;
            return new AVLNode(value);
        }
        const c = this.compare(value, node.value);
        if (c === 0) return node;
        if (c < 0) node.left = this.insertInto(node.left, value);
        else node.right = this.insertInto(node.right, value);
        return this.rebalance(node);
    }

    private removeFrom(node: AVLNode<T> | null, value: T): AVLNode<T> | null {
        if (node === null) return null;
        const c = this.compare(value, node.value);
        if (c < 0) {
            node.left = this.removeFrom(node.left, value);
        } else if (c > 0) {
            node.right = this.removeFrom(node.right, value);
        } else {
            this.size--;
            const left = node.left;
            const right = node.right;
            if (left === null) return right;
            if (right === null) return left;
            // Two children: adopt the in-order successor's value, then unlink it.
            let successor: AVLNode<T> = right;
            while (successor.left !== null) successor = successor.left;
            node.value = successor.value;
            node.right = this.removeMin(right);
        }
        return this.rebalance(node);
    }

    /** Unlinks the smallest node of the subtree (size already accounted for). */
    private removeMin(node: AVLNode<T>): AVLNode<T> | null {
        if (node.left === null) return node.right;
        node.left = this.removeMin(node.left);
        return this.rebalance(node);
    }
}

/**
 * SortedSet of [num, freq] pairs that also keeps `sum`, a running total of
 * num * freq. `sum` is adjusted on every add/delete call whether or not the
 * underlying set changed, so it is only accurate when callers never add a
 * pair that is already stored or delete one that is absent.
 */
class SortedSetWithSum extends SortedSet<number[]> {
    sum = 0;

    add(pair: number[]) {
        const [num, freq] = pair;
        super.add([num, freq]);
        this.sum += num * freq;
    }

    delete(pair: number[]) {
        const [num, freq] = pair;
        super.delete([num, freq]);
        this.sum -= num * freq;
    }
}

/**
 * Returns the x-sum of every length-k window of `nums`.
 *
 * Each distinct value is stored as [value, frequency]. Both sets are ordered
 * ascending by frequency, then value, so `first()` is the weakest entry and
 * `last()` the strongest. `chosen` holds the (up to) x strongest entries of
 * the window and its `sum` is the window's x-sum; `others` holds the rest.
 */
function findXSum(nums: number[], k: number, x: number): number[] {
    const byFreqThenValue = (a: number[], b: number[]): number =>
        a[1] !== b[1] ? a[1] - b[1] : a[0] - b[0];
    const chosen = new SortedSetWithSum(byFreqThenValue);
    const others = new SortedSetWithSum(byFreqThenValue);
    const counts = new Map<number, number>();

    // When `chosen` is full, swap its weakest entry with the strongest entry
    // of `others` if the latter ranks higher.
    const fixBoundary = (): void => {
        if (!(chosen.size === x && others.size > 0)) return;
        const weakest = chosen.first()!;
        const strongest = others.last()!;
        if (byFreqThenValue(weakest, strongest) >= 0) return;
        chosen.delete(weakest);
        others.delete(strongest);
        chosen.add(strongest);
        others.add(weakest);
    };

    // Counts one more occurrence of `value`.
    const addOne = (value: number): void => {
        const prev = counts.get(value) ?? 0;
        counts.set(value, prev + 1);
        if (!prev) {
            (chosen.size < x ? chosen : others).add([value, 1]);
        } else {
            const owner = chosen.find([value, prev]) ? chosen : others;
            owner.delete([value, prev]);
            owner.add([value, prev + 1]);
        }
        fixBoundary();
    };

    // Counts one fewer occurrence of `value`; a vacated chosen slot is
    // refilled with the strongest entry of `others`.
    const dropOne = (value: number): void => {
        const prev = counts.get(value) ?? 0;
        counts.set(value, prev - 1);
        const remaining = prev - 1;
        if (chosen.find([value, prev])) {
            chosen.delete([value, prev]);
            if (remaining) {
                chosen.add([value, remaining]);
            } else if (others.size) {
                const promoted = others.last()!;
                others.delete(promoted);
                chosen.add(promoted);
            }
        } else {
            others.delete([value, prev]);
            if (remaining) {
                others.add([value, remaining]);
            }
        }
        fixBoundary();
    };

    const answer: number[] = [];
    for (let end = 0; end < nums.length; end++) {
        if (end >= k) {
            dropOne(nums[end - k]);
        }
        addOne(nums[end]);
        if (end + 1 >= k) {
            answer.push(chosen.sum);
        }
    }
    return answer;
}

Comments