Finding MK Average
Hard
Updated: Aug 2, 2025
Practice on:
Problem
You are given two integers, m and k, and a stream of integers. You are tasked to implement a data structure that calculates the MKAverage for the stream.
The MKAverage can be calculated using these steps:
- If the number of elements in the stream is less than `m`, you should consider the MKAverage to be `-1`. Otherwise, copy the last `m` elements of the stream to a separate container.
- Remove the smallest `k` elements and the largest `k` elements from the container.
- Calculate the average value of the remaining elements, rounded down to the nearest integer.
Implement the MKAverage class:
- `MKAverage(int m, int k)` Initializes the MKAverage object with an empty stream and the two integers `m` and `k`.
- `void addElement(int num)` Inserts a new element `num` into the stream.
- `int calculateMKAverage()` Calculates and returns the MKAverage for the current stream, rounded down to the nearest integer.
Examples
Example 1
**Input**
["MKAverage", "addElement", "addElement", "calculateMKAverage", "addElement", "calculateMKAverage", "addElement", "addElement", "addElement", "calculateMKAverage"]
[[3, 1], [3], [1], [], [10], [], [5], [5], [5], []]
**Output**
[null, null, null, -1, null, 3, null, null, null, 5]
**Explanation**
MKAverage obj = new MKAverage(3, 1);
obj.addElement(3); // current elements are [3]
obj.addElement(1); // current elements are [3,1]
obj.calculateMKAverage(); // return -1, because m = 3 and only 2 elements exist.
obj.addElement(10); // current elements are [3,1,10]
obj.calculateMKAverage(); // The last 3 elements are [3,1,10].
// After removing smallest and largest 1 element the container will be [3].
// The average of [3] equals 3/1 = 3, return 3
obj.addElement(5); // current elements are [3,1,10,5]
obj.addElement(5); // current elements are [3,1,10,5,5]
obj.addElement(5); // current elements are [3,1,10,5,5,5]
obj.calculateMKAverage(); // The last 3 elements are [5,5,5].
// After removing smallest and largest 1 element the container will be [5].
// The average of [5] equals 5/1 = 5, return 5
Constraints
- `3 <= m <= 10^5`
- `1 <= k*2 < m`
- `1 <= num <= 10^5`
- At most `10^5` calls will be made to `addElement` and `calculateMKAverage`.
Solution
Method 1 – Multiset/SortedList Partitioning
Intuition
To efficiently maintain the last m elements and quickly remove the k smallest and k largest, we use three multisets (or sorted lists): one for the k smallest, one for the k largest, and one for the middle elements. As new elements are added and old ones removed, we rebalance these multisets to always keep the correct partition.
Approach
- Use a queue to store the last m elements.
- Use three multisets (or sorted lists):
- left: k smallest elements
- right: k largest elements
- mid: the rest (m - 2k elements)
- When adding a new element:
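- Add it to the multiset whose value range contains it: left if it is no larger than the largest element of left, right if it is no smaller than the smallest element of right, otherwise mid.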
- If the queue exceeds m, remove the oldest element and rebalance.
- Always maintain the sizes: left and right have k elements each, mid has m-2k.
- To calculate MKAverage:
- If less than m elements, return -1.
- Otherwise, return the integer average of the mid set.
Code
Java
import java.util.*;
/**
 * Maintains the "MK average" of an integer stream: the floor of the mean of the
 * last {@code m} elements after discarding the {@code k} smallest and the
 * {@code k} largest of them.
 *
 * <p>The window is split into three sorted multisets (value -> multiplicity)
 * that always satisfy {@code max(left) <= min(mid)} and
 * {@code max(mid) <= min(right)}. Once the window holds {@code m} elements,
 * {@code left} and {@code right} hold exactly {@code k} elements each, so the
 * answer is the running sum of {@code mid} divided by {@code m - 2k}.
 */
class MKAverage {
    private final int m;
    private final int k;

    /** The last (at most) m stream elements, oldest first. */
    private final Deque<Integer> window = new ArrayDeque<>();

    private final TreeMap<Integer, Integer> left = new TreeMap<>();
    private final TreeMap<Integer, Integer> mid = new TreeMap<>();
    private final TreeMap<Integer, Integer> right = new TreeMap<>();

    /** Element counts (with multiplicity) of {@code left} and {@code right}. */
    private int sizeL = 0;
    private int sizeR = 0;

    /** Sum of the elements in {@code mid}; long because m * num can exceed int range. */
    private long sumM = 0;

    /**
     * @param m window length
     * @param k number of smallest and of largest elements to discard
     */
    public MKAverage(int m, int k) {
        this.m = m;
        this.k = k;
    }

    /**
     * Appends {@code num} to the stream, evicting the oldest element once the
     * window would exceed {@code m} elements.
     *
     * @param num the new stream element
     */
    public void addElement(int num) {
        // Place the value by comparison with the partition boundaries, not by
        // how full each set is; otherwise the ordering invariant is lost.
        if (left.isEmpty() || num <= left.lastKey()) {
            add(left, num);
            sizeL++;
        } else if (right.isEmpty() || num >= right.firstKey()) {
            // If right is empty here, mid is empty too (rebalance guarantees it),
            // so num is still >= everything outside right.
            add(right, num);
            sizeR++;
        } else {
            add(mid, num);
            sumM += num;
        }
        window.offer(num);

        if (window.size() > m) {
            int old = window.poll();
            // With duplicates the value may live in several sets; removing any
            // one copy is fine because rebalance() restores the sizes.
            if (left.containsKey(old)) {
                remove(left, old);
                sizeL--;
            } else if (right.containsKey(old)) {
                remove(right, old);
                sizeR--;
            } else {
                remove(mid, old);
                sumM -= old;
            }
        }
        rebalance();
    }

    /**
     * @return the MK average of the current window, or {@code -1} if fewer than
     *         {@code m} elements have been added
     */
    public int calculateMKAverage() {
        if (window.size() < m) {
            return -1;
        }
        return (int) (sumM / (m - 2 * k));
    }

    /** Adds one occurrence of {@code x} to the multiset. */
    private void add(TreeMap<Integer, Integer> map, int x) {
        map.merge(x, 1, Integer::sum);
    }

    /** Removes one occurrence of {@code x}; drops the key when its count reaches zero. */
    private void remove(TreeMap<Integer, Integer> map, int x) {
        map.computeIfPresent(x, (key, count) -> count == 1 ? null : count - 1);
    }

    /**
     * Restores the size constraints without breaking the ordering invariant.
     * Oversized outer sets are shrunk into {@code mid} first, and only then are
     * undersized outer sets refilled from {@code mid}; doing it in this order
     * guarantees that a non-empty {@code mid} implies both outer sets are full.
     */
    private void rebalance() {
        while (sizeL > k) {
            int x = left.lastKey();
            remove(left, x);
            sizeL--;
            add(mid, x);
            sumM += x;
        }
        while (sizeR > k) {
            int x = right.firstKey();
            remove(right, x);
            sizeR--;
            add(mid, x);
            sumM += x;
        }
        while (sizeL < k && !mid.isEmpty()) {
            int x = mid.firstKey();
            remove(mid, x);
            sumM -= x;
            add(left, x);
            sizeL++;
        }
        while (sizeR < k && !mid.isEmpty()) {
            int x = mid.lastKey();
            remove(mid, x);
            sumM -= x;
            add(right, x);
            sizeR++;
        }
    }
}
Python
from collections import deque
from sortedcontainers import SortedList
class MKAverage:
    """Sliding-window MK average over an integer stream.

    The last ``m`` elements are split into three sorted lists that always
    satisfy ``max(left) <= min(mid)`` and ``max(mid) <= min(right)``. Once the
    window holds ``m`` elements, ``left`` and ``right`` hold exactly ``k``
    elements each, so the answer is ``sumM // (m - 2 * k)``.
    """

    def __init__(self, m: int, k: int):
        """Create an empty stream with window length ``m`` and trim count ``k``."""
        self.m, self.k = m, k
        self.q = deque()           # last (at most) m elements, oldest first
        self.left = SortedList()   # the k smallest elements of the window
        self.mid = SortedList()    # everything between left and right
        self.right = SortedList()  # the k largest elements of the window
        self.sumM = 0              # running sum of the elements in mid

    def addElement(self, num: int) -> None:
        """Append ``num`` to the stream, evicting the oldest element if needed."""
        # Place the value by comparison with the partition boundaries, not by
        # how full each list is; otherwise the ordering invariant is lost.
        if not self.left or num <= self.left[-1]:
            self.left.add(num)
        elif not self.right or num >= self.right[0]:
            # If right is empty here, mid is empty too (_rebalance guarantees
            # it), so num is still >= everything outside right.
            self.right.add(num)
        else:
            self.mid.add(num)
            self.sumM += num
        self.q.append(num)

        if len(self.q) > self.m:
            old = self.q.popleft()
            # With duplicates the value may live in several lists; removing any
            # one copy is fine because _rebalance restores the sizes.
            if old in self.left:
                self.left.remove(old)
            elif old in self.right:
                self.right.remove(old)
            else:
                self.mid.remove(old)
                self.sumM -= old
        self._rebalance()

    def calculateMKAverage(self) -> int:
        """Return the MK average, or -1 if fewer than ``m`` elements were added."""
        if len(self.q) < self.m:
            return -1
        return self.sumM // (self.m - 2 * self.k)

    def _rebalance(self) -> None:
        """Restore the size constraints while keeping the lists ordered.

        Oversized outer lists are shrunk into ``mid`` first, and only then are
        undersized outer lists refilled from ``mid``; this order guarantees
        that a non-empty ``mid`` implies both outer lists are full.
        """
        while len(self.left) > self.k:
            x = self.left.pop(-1)
            self.mid.add(x)
            self.sumM += x
        while len(self.right) > self.k:
            x = self.right.pop(0)
            self.mid.add(x)
            self.sumM += x
        while len(self.left) < self.k and self.mid:
            x = self.mid.pop(0)
            self.left.add(x)
            self.sumM -= x
        while len(self.right) < self.k and self.mid:
            x = self.mid.pop(-1)
            self.right.add(x)
            self.sumM -= x
Complexity
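- ⏰ Time complexity: `O(log m)` per `addElement`, since each insertion/removal in a SortedList/TreeMap is logarithmic in m; `calculateMKAverage` is `O(1)` because the sum of the middle elements is maintained incrementally.
- 🧺 Space complexity: `O(m)`, for storing the last m elements and the three multisets.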