Problem
Given k
sorted arrays, write an algorithm to merge them into one sorted array.
Examples
Input: arrays = [
[1, 4, 7],
[2, 5, 8],
[3, 6, 9]
]
Output: [1, 2, 3, 4, 5, 6, 7, 8, 9]
Solution
Method 1 - Naive - Merge and sort the array
- Create an array `result[]` with size `n * k`.
- Copy all elements from the `k` arrays into the `result` array. This operation takes `O(nk)` time.
- Sort the `result[]` array using merge sort. This takes `O(nk log(nk))` time.
Method 2 - Merge 2 Arrays at a time
One efficient approach is to merge the arrays in pairs. After the first round of merges, the number of arrays is reduced to `k/2`. We repeat the merging process, reducing the number to `k/4`, and continue until only one array remains. The time complexity of this method is `O(nk log k)`. Here's why: each merge operation in the first round takes `O(2n)` time to combine two arrays of size `n`. Since there are `k/2` merges, the first round takes `O(nk)` time. Each subsequent round also takes `O(nk)` time, because it merges half as many arrays that are each twice as long. With `O(log k)` rounds in total, the overall time complexity is `O(nk log k)`.
Code
Java
/**
 * Merges k sorted arrays into one sorted array by recursively splitting the
 * list of arrays in half and merging the two merged halves.
 */
public class Solution {
    /**
     * Merges all of the given sorted arrays into a single sorted array.
     *
     * @param arrays the sorted input arrays; the inputs are never modified
     * @return a newly allocated sorted array holding every input element
     */
    public int[] mergeKSortedArrays(int[][] arrays) {
        int k = arrays.length;
        if (k == 0) return new int[0];
        // With a single input the recursion would hand back that very array;
        // clone it so the caller never receives an alias of its own input.
        if (k == 1) return arrays[0].clone();
        return mergeKArrays(arrays, 0, k - 1);
    }

    /**
     * Merges the arrays at indices start..end (inclusive).
     *
     * @param arrays the sorted input arrays
     * @param start  index of the first array in the range
     * @param end    index of the last array in the range
     * @return the merged contents of the range; when the range holds a single
     *         array, that array itself is returned
     */
    private int[] mergeKArrays(int[][] arrays, int start, int end) {
        if (start == end) {
            return arrays[start];
        }
        // Written as start + (end - start) / 2 to avoid overflow of start + end.
        int mid = start + (end - start) / 2;
        int[] left = mergeKArrays(arrays, start, mid);
        int[] right = mergeKArrays(arrays, mid + 1, end);
        return mergeTwoArrays(left, right);
    }

    /**
     * Merges two sorted arrays into a new sorted array. On ties the element
     * from arr1 is taken first.
     *
     * @param arr1 first sorted array
     * @param arr2 second sorted array
     * @return a new array of length arr1.length + arr2.length
     */
    private int[] mergeTwoArrays(int[] arr1, int[] arr2) {
        int len1 = arr1.length, len2 = arr2.length;
        int[] result = new int[len1 + len2];
        int i = 0, j = 0, k = 0;
        while (i < len1 && j < len2) {
            if (arr1[i] <= arr2[j]) {
                result[k++] = arr1[i++];
            } else {
                result[k++] = arr2[j++];
            }
        }
        // At most one of the inputs still has elements; bulk-copy that tail.
        System.arraycopy(arr1, i, result, k, len1 - i);
        k += len1 - i;
        System.arraycopy(arr2, j, result, k, len2 - j);
        return result;
    }

    public static void main(String[] args) {
        Solution sol = new Solution();
        int[][] arrays = {
            {1, 4, 7},
            {2, 5, 8},
            {3, 6, 9}
        };
        int[] result = sol.mergeKSortedArrays(arrays);
        // Fully qualified so the snippet compiles without a separate import line.
        System.out.println("Merged array: " + java.util.Arrays.toString(result)); // Expected output: [1, 2, 3, 4, 5, 6, 7, 8, 9]
    }
}
Python
class Solution:
    """Merge k sorted lists by recursively merging halves of the list of lists."""

    def mergeKSortedArrays(self, arrays: List[List[int]]) -> List[int]:
        """Return a new sorted list containing every element of ``arrays``.

        Args:
            arrays: The sorted input lists. They are never modified.

        Returns:
            A new list; an empty list when ``arrays`` is empty.
        """
        if not arrays:
            return []
        return self.mergeKArrays(arrays, 0, len(arrays) - 1)

    def mergeKArrays(self, arrays: List[List[int]], start: int, end: int) -> List[int]:
        """Merge the lists at indices ``start``..``end`` (inclusive).

        Args:
            arrays: The sorted input lists.
            start: Index of the first list in the range.
            end: Index of the last list in the range.

        Returns:
            A new sorted list with the merged contents of the range.
        """
        if start == end:
            # Copy so the caller never receives an alias of an input list;
            # otherwise mutating the result would silently mutate the input.
            return list(arrays[start])
        mid = start + (end - start) // 2
        left = self.mergeKArrays(arrays, start, mid)
        right = self.mergeKArrays(arrays, mid + 1, end)
        return self.mergeTwoArrays(left, right)

    def mergeTwoArrays(self, arr1: List[int], arr2: List[int]) -> List[int]:
        """Merge two sorted lists into a new sorted list.

        On ties the element from ``arr1`` is taken first.
        """
        len1, len2 = len(arr1), len(arr2)
        result = []
        i = j = 0
        while i < len1 and j < len2:
            if arr1[i] <= arr2[j]:
                result.append(arr1[i])
                i += 1
            else:
                result.append(arr2[j])
                j += 1
        # At most one of the inputs still has elements; append that tail.
        result.extend(arr1[i:])
        result.extend(arr2[j:])
        return result
# Example usage: merge three sorted lists of three elements each.
sol = Solution()
arrays = [[1, 4, 7], [2, 5, 8], [3, 6, 9]]
result = sol.mergeKSortedArrays(arrays)
# Expected output: [1, 2, 3, 4, 5, 6, 7, 8, 9]
print("Merged array:", result)
Complexity
- ⏰ Time complexity: `O(N log k)`, where `N` is the total number of elements across all `k` arrays (`N = nk` when every array has `n` elements), and `k` is the number of arrays.
- 🧺 Space complexity: `O(N)`, for storing the merged result in the result array.
Method 3 - Use minHeap
- Use a min-heap to keep track of the smallest elements among the `k` arrays.
- Insert the first element of each array into the heap along with the array index and element index.
- Extract the smallest element from the heap, add it to the result array, and then insert the next element from the same array into the heap.
- Repeat until the heap is empty.
Code
Java
/**
 * Merges k sorted arrays into one sorted array using a min-heap that holds the
 * current front element of each non-exhausted array.
 */
public class Solution {
    /**
     * Merges all of the given sorted arrays into a single sorted array.
     *
     * @param arrays the sorted input arrays; empty arrays are skipped
     * @return a newly allocated sorted array holding every input element
     */
    public int[] mergeKSortedArrays(int[][] arrays) {
        // Integer.compare instead of a.value - b.value: the subtraction
        // overflows for values far apart (e.g. Integer.MAX_VALUE and
        // Integer.MIN_VALUE) and would order the heap incorrectly.
        PriorityQueue<Element> minHeap =
            new PriorityQueue<>((a, b) -> Integer.compare(a.value, b.value));
        int totalSize = 0;
        // Initialize the heap with the first element of each array
        for (int i = 0; i < arrays.length; i++) {
            if (arrays[i].length > 0) {
                minHeap.add(new Element(arrays[i][0], i, 0));
                totalSize += arrays[i].length;
            }
        }
        int[] result = new int[totalSize];
        int index = 0;
        // Process the heap
        while (!minHeap.isEmpty()) {
            Element current = minHeap.poll();
            result[index++] = current.value;
            // If there's a next element in the same array, add it to the heap
            int next = current.elementIndex + 1;
            int[] source = arrays[current.arrayIndex];
            if (next < source.length) {
                minHeap.add(new Element(source[next], current.arrayIndex, next));
            }
        }
        return result;
    }

    /** Heap entry: a value plus the position it was read from. */
    private static class Element {
        int value;        // the element's value; the heap is ordered by this field
        int arrayIndex;   // index of the source array within the input
        int elementIndex; // index of the value within its source array

        Element(int value, int arrayIndex, int elementIndex) {
            this.value = value;
            this.arrayIndex = arrayIndex;
            this.elementIndex = elementIndex;
        }
    }
}
Python
class Solution:
    """Merge k sorted lists using a min-heap of per-list cursors."""

    def mergeKSortedArrays(self, arrays: List[List[int]]) -> List[int]:
        """Return a new sorted list containing every element of ``arrays``.

        Each heap entry is ``(value, list index, position in that list)``, so
        the heap holds at most one entry per input list. Empty lists are
        skipped.
        """
        # Seed the heap with the head of every non-empty list.
        frontier = [
            (row[0], row_idx, 0)
            for row_idx, row in enumerate(arrays)
            if row
        ]
        heapq.heapify(frontier)

        merged = []
        while frontier:
            smallest, row_idx, col_idx = heapq.heappop(frontier)
            merged.append(smallest)
            # Advance the cursor of the list the popped value came from.
            successor = col_idx + 1
            source = arrays[row_idx]
            if successor < len(source):
                heapq.heappush(frontier, (source[successor], row_idx, successor))
        return merged
Complexity
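- ⏰ Time complexity: `O(N log k)`, where `N` is the total number of elements across all arrays. Each element is inserted into and extracted from the heap once, and each heap operation takes `O(log k)`.
- 🧺 Space complexity: `O(k)` for the heap, which holds up to `k` elements at a time (plus `O(N)` for the output array).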