Problem

You are given two integers, m and k, and a stream of integers. Implement a data structure that calculates the MKAverage of the stream.

The MKAverage can be calculated using these steps:

  1. If the stream contains fewer than m elements, the MKAverage is -1. Otherwise, copy the last m elements of the stream to a separate container.
  2. Remove the smallest k elements and the largest k elements from the container.
  3. Calculate the average value for the rest of the elements rounded down to the nearest integer.
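
The three steps can be followed literally to get a brute-force reference. This is only a minimal sketch for checking faster implementations against, not the intended solution: the function name `mk_average_bruteforce` is made up for illustration, and it sorts the whole window on every call, so it costs O(m log m) per query.

```python
def mk_average_bruteforce(stream, m, k):
    """Compute the MKAverage of `stream` by following the three steps literally."""
    # Step 1: fewer than m elements means there is no MKAverage yet.
    if len(stream) < m:
        return -1
    window = sorted(stream[-m:])  # copy of the last m elements, sorted
    # Step 2: drop the k smallest and the k largest values.
    middle = window[k:m - k]
    # Step 3: floor of the average of what remains.
    return sum(middle) // len(middle)


print(mk_average_bruteforce([3, 1], 3, 1))               # -1
print(mk_average_bruteforce([3, 1, 10], 3, 1))           # 3
print(mk_average_bruteforce([3, 1, 10, 5, 5, 5], 3, 1))  # 5
```

The three calls mirror the queries in Example 1 below.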

Implement the MKAverage class:

  • MKAverage(int m, int k) Initializes the MKAverage object with an empty stream and the two integers m and k.
  • void addElement(int num) Inserts a new element num into the stream.
  • int calculateMKAverage() Calculates and returns the MKAverage for the current stream rounded down to the nearest integer.

Examples

Example 1

**Input**
["MKAverage", "addElement", "addElement", "calculateMKAverage", "addElement", "calculateMKAverage", "addElement", "addElement", "addElement", "calculateMKAverage"]
[[3, 1], [3], [1], [], [10], [], [5], [5], [5], []]
**Output**
[null, null, null, -1, null, 3, null, null, null, 5]

**Explanation**
MKAverage obj = new MKAverage(3, 1); 
obj.addElement(3);        // current elements are [3]
obj.addElement(1);        // current elements are [3,1]
obj.calculateMKAverage(); // return -1, because m = 3 and only 2 elements exist.
obj.addElement(10);       // current elements are [3,1,10]
obj.calculateMKAverage(); // The last 3 elements are [3,1,10].
                          // After removing smallest and largest 1 element the container will be [3].
                          // The average of [3] equals 3/1 = 3, return 3
obj.addElement(5);        // current elements are [3,1,10,5]
obj.addElement(5);        // current elements are [3,1,10,5,5]
obj.addElement(5);        // current elements are [3,1,10,5,5,5]
obj.calculateMKAverage(); // The last 3 elements are [5,5,5].
                          // After removing smallest and largest 1 element the container will be [5].
                          // The average of [5] equals 5/1 = 5, return 5

Constraints

  • 3 <= m <= 10^5
  • 1 <= k*2 < m
  • 1 <= num <= 10^5
  • At most 10^5 calls will be made to addElement and calculateMKAverage.

Solution

Method 1 – Multiset/SortedList Partitioning

Intuition

To efficiently maintain the last m elements and quickly remove the k smallest and k largest, we use three multisets (or sorted lists): one for the k smallest, one for the k largest, and one for the middle elements. As new elements are added and old ones removed, we rebalance these multisets to always keep the correct partition.

Approach

  1. Use a queue to store the last m elements.
  2. Use three multisets (or sorted lists):
    • left: k smallest elements
    • right: k largest elements
    • mid: the rest (m - 2k elements)
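  3. When adding a new element:
    • Insert it by value: into left if it is no larger than left's maximum, into right if it is no smaller than right's minimum, and into mid otherwise.
    • If the queue exceeds m, remove the oldest element from whichever multiset contains it.
    • After each insertion or removal, rebalance by moving boundary elements through mid, so that once m elements have arrived left and right hold k elements each and mid holds m - 2k.
  4. To calculate MKAverage:
    • If fewer than m elements have been added, return -1.
    • Otherwise, return the running sum of mid divided by m - 2k, rounded down.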
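
One way to decide which multiset a new element belongs in is to compare it with the current boundaries of left and right. The sketch below shows only that decision, using plain sorted Python lists as stand-ins for the multisets; the helper name `choose_bucket` is hypothetical and does not appear in the solutions that follow.

```python
def choose_bucket(num, left, right):
    """Pick the bucket for `num`; `left` and `right` are sorted lists."""
    # Anything no larger than the biggest "small" value stays with the smallest.
    if left and num <= left[-1]:
        return "left"
    # Anything no smaller than the smallest "large" value joins the largest.
    if right and num >= right[0]:
        return "right"
    # Everything else (including the case of empty boundaries) goes to the middle.
    return "mid"


# Window [3, 1, 10] with k = 1 is partitioned as left=[1], mid=[3], right=[10].
left, right = [1], [10]
print(choose_bucket(0, left, right))   # left
print(choose_bucket(5, left, right))   # mid
print(choose_bucket(12, left, right))  # right
```

After such an insertion one bucket may be one element too large or too small; the rebalancing step then moves a boundary element to restore the sizes without breaking the ordering between buckets.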

Code

Java

import java.util.*;
class MKAverage {
    Queue<Integer> q = new LinkedList<>();
    TreeMap<Integer, Integer> left = new TreeMap<>(), mid = new TreeMap<>(), right = new TreeMap<>();
    int m, k, sizeL = 0, sizeM = 0, sizeR = 0;
    long sumM = 0;
    public MKAverage(int m, int k) {
        this.m = m; this.k = k;
    }
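    public void addElement(int num) {
        q.offer(num);
        // Route by value so that every key in left <= every key in mid <= every key in right.
        if (!left.isEmpty() && num <= left.lastKey()) { add(left, num); sizeL++; }
        else if (!right.isEmpty() && num >= right.firstKey()) { add(right, num); sizeR++; }
        else { add(mid, num); sumM += num; sizeM++; }
        // Restore the target sizes by moving boundary elements through mid.
        rebalance();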
        if (q.size() > m) {
            int old = q.poll();
            if (left.containsKey(old)) { remove(left, old); sizeL--; }
            else if (right.containsKey(old)) { remove(right, old); sizeR--; }
            else { remove(mid, old); sumM -= old; sizeM--; }
            rebalance();
        }
    }
    public int calculateMKAverage() {
        if (q.size() < m) return -1;
        return (int)(sumM / (m - 2 * k));
    }
    void add(TreeMap<Integer, Integer> map, int x) { map.put(x, map.getOrDefault(x, 0) + 1); }
    void remove(TreeMap<Integer, Integer> map, int x) {
        map.put(x, map.get(x) - 1);
        if (map.get(x) == 0) map.remove(x);
    }
    void rebalance() {
        while (sizeL < k && sizeM > 0) {
            int x = mid.firstKey();
            remove(mid, x); sumM -= x; sizeM--;
            add(left, x); sizeL++;
        }
        while (sizeL > k) {
            int x = left.lastKey();
            remove(left, x); sizeL--;
            add(mid, x); sumM += x; sizeM++;
        }
        while (sizeR < k && sizeM > 0) {
            int x = mid.lastKey();
            remove(mid, x); sumM -= x; sizeM--;
            add(right, x); sizeR++;
        }
        while (sizeR > k) {
            int x = right.firstKey();
            remove(right, x); sizeR--;
            add(mid, x); sumM += x; sizeM++;
        }
    }
}

Python

from collections import deque
from sortedcontainers import SortedList

class MKAverage:
    def __init__(self, m: int, k: int):
        self.m, self.k = m, k
        self.q = deque()
        self.left = SortedList()
        self.mid = SortedList()
        self.right = SortedList()
        self.sumM = 0

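    def addElement(self, num: int) -> None:
        self.q.append(num)
        # Route by value so that left <= mid <= right always holds.
        if self.left and num <= self.left[-1]:
            self.left.add(num)
        elif self.right and num >= self.right[0]:
            self.right.add(num)
        else:
            self.mid.add(num)
            self.sumM += num
        # Restore the target sizes by moving boundary elements through mid.
        self._rebalance()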
        if len(self.q) > self.m:
            old = self.q.popleft()
            if old in self.left:
                self.left.remove(old)
            elif old in self.right:
                self.right.remove(old)
            else:
                self.mid.remove(old)
                self.sumM -= old
            self._rebalance()

    def calculateMKAverage(self) -> int:
        if len(self.q) < self.m:
            return -1
        return self.sumM // (self.m - 2 * self.k)

    def _rebalance(self):
        while len(self.left) < self.k and self.mid:
            x = self.mid.pop(0)
            self.left.add(x)
            self.sumM -= x
        while len(self.left) > self.k:
            x = self.left.pop(-1)
            self.mid.add(x)
            self.sumM += x
        while len(self.right) < self.k and self.mid:
            x = self.mid.pop(-1)
            self.right.add(x)
            self.sumM -= x
        while len(self.right) > self.k:
            x = self.right.pop(0)
            self.mid.add(x)
            self.sumM += x

Complexity

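  • ⏰ Time complexity: O(log m) per addElement call, since each call performs a constant number of insertions, removals, and boundary moves, each logarithmic in m for a TreeMap (and approximately so for a SortedList). calculateMKAverage is O(1) because the sum of mid is maintained incrementally.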
  • 🧺 Space complexity: O(m), for storing the last m elements and the three multisets.