Problem

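Given k sorted arrays, write an algorithm to merge them into one sorted array. In the complexity analysis below, n denotes the length of each array (assumed equal for simplicity), so the total number of elements is N = nk.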

Examples

Input: arrays = [
	[1, 4, 7],
	[2, 5, 8],
	[3, 6, 9]
]
Output: [1, 2, 3, 4, 5, 6, 7, 8, 9]

Solution

Method 1 - Naive - Merge and sort the array

  • Create an array result[] of size n * k, where n is the length of each of the k arrays.
  • Copy all elements from the k arrays into the result array. This operation takes O(nk) time.
  • Sort the result[] array using merge sort. This takes O(nk log(nk)) time.

Method 2 - Merge 2 Arrays at a time

One efficient approach is to merge the arrays in pairs. After the first round of merges, the number of arrays drops to k/2. We repeat the merging process, reducing the number to k/4, and continue until only one array remains. The time complexity of this method is O(nk log k). Here's why: in the first round, each merge combines two arrays of size n in O(2n) time, and there are k/2 such merges, so the round takes O(nk) time. Every later round also takes O(nk) time, because the arrays double in size while their number halves; every element is still copied exactly once per round. Since the number of arrays halves each round, there are O(log k) rounds, giving an overall time complexity of O(nk log k).

Code

Java
public class Solution {
    /**
     * Merges k sorted arrays into one sorted array by merging them in pairs
     * (divide and conquer over the array indices).
     *
     * @param arrays k arrays, each sorted in non-decreasing order; may be null or empty
     * @return a newly allocated sorted array containing every input element; an input
     *         array is never returned directly, so callers may mutate the result freely
     */
    public int[] mergeKSortedArrays(int[][] arrays) {
        if (arrays == null || arrays.length == 0) {
            return new int[0];
        }
        return mergeKArrays(arrays, 0, arrays.length - 1);
    }

    /**
     * Recursively merges arrays[start..end] (both inclusive).
     *
     * @return a new sorted array holding every element of the given range
     */
    private int[] mergeKArrays(int[][] arrays, int start, int end) {
        if (start == end) {
            // Copy so that a single-array input does not hand back the caller's own array.
            return arrays[start].clone();
        }

        // Written this way instead of (start + end) / 2 to avoid int overflow.
        int mid = start + (end - start) / 2;
        int[] left = mergeKArrays(arrays, start, mid);
        int[] right = mergeKArrays(arrays, mid + 1, end);

        return mergeTwoArrays(left, right);
    }

    /**
     * Merges two sorted arrays into a new sorted array. On ties the element from
     * arr1 is taken first.
     */
    private int[] mergeTwoArrays(int[] arr1, int[] arr2) {
        int len1 = arr1.length, len2 = arr2.length;
        int[] result = new int[len1 + len2];
        int i = 0, j = 0, k = 0;

        while (i < len1 && j < len2) {
            if (arr1[i] <= arr2[j]) {
                result[k++] = arr1[i++];
            } else {
                result[k++] = arr2[j++];
            }
        }

        // At most one of the two inputs still has elements left; copy its tail.
        while (i < len1) {
            result[k++] = arr1[i++];
        }

        while (j < len2) {
            result[k++] = arr2[j++];
        }

        return result;
    }

    public static void main(String[] args) {
        Solution sol = new Solution();
        int[][] arrays = {
            {1, 4, 7},
            {2, 5, 8},
            {3, 6, 9}
        };

        int[] result = sol.mergeKSortedArrays(arrays);
        // Fully-qualified so the snippet compiles without a separate import.
        System.out.println("Merged array: " + java.util.Arrays.toString(result));  // Expected output: [1, 2, 3, 4, 5, 6, 7, 8, 9]
    }
}
Python
from typing import List


class Solution:
    """Merge k sorted lists by merging them in pairs (divide and conquer)."""

    def mergeKSortedArrays(self, arrays: List[List[int]]) -> List[int]:
        """Return a new sorted list containing every element of ``arrays``.

        Args:
            arrays: k lists, each sorted in non-decreasing order. The outer list
                may be empty (or None), and individual lists may be empty.

        Returns:
            A newly built sorted list. The input lists are neither mutated nor
            returned directly, so the caller may modify the result freely.
        """
        if not arrays:
            return []
        return self.mergeKArrays(arrays, 0, len(arrays) - 1)

    def mergeKArrays(self, arrays: List[List[int]], start: int, end: int) -> List[int]:
        """Merge ``arrays[start..end]`` (both inclusive) into a new sorted list."""
        if start == end:
            # Copy so a single-list input does not hand back the caller's own list.
            return list(arrays[start])

        mid = start + (end - start) // 2
        left = self.mergeKArrays(arrays, start, mid)
        right = self.mergeKArrays(arrays, mid + 1, end)

        return self.mergeTwoArrays(left, right)

    def mergeTwoArrays(self, arr1: List[int], arr2: List[int]) -> List[int]:
        """Merge two sorted lists into a new sorted list.

        On ties the element from ``arr1`` is taken first, so the merge is stable.
        """
        len1, len2 = len(arr1), len(arr2)
        result: List[int] = []
        i = j = 0

        while i < len1 and j < len2:
            if arr1[i] <= arr2[j]:
                result.append(arr1[i])
                i += 1
            else:
                result.append(arr2[j])
                j += 1

        # At most one of these tails is non-empty; append whatever remains.
        result.extend(arr1[i:])
        result.extend(arr2[j:])

        return result


# Example usage
if __name__ == "__main__":
    sol = Solution()
    arrays = [
        [1, 4, 7],
        [2, 5, 8],
        [3, 6, 9]
    ]
    result = sol.mergeKSortedArrays(arrays)
    print("Merged array:", result)  # Expected output: [1, 2, 3, 4, 5, 6, 7, 8, 9]

Complexity

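  • ⏰ Time complexity: O(N log k), where N is the total number of elements across all k arrays (N = nk when each array has n elements) and k is the number of arrays. There are O(log k) rounds of pairwise merging, and each round copies every element once.
  • 🧺 Space complexity: O(N) for the merged result and the intermediate arrays created while merging, plus O(log k) for the recursion stack.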

Method 3 - Use minHeap

  • Use a min-heap to keep track of the smallest elements among the k arrays.
  • Insert the first element of each array into the heap along with the array index and element index.
  • Extract the smallest element from the heap and add it to the result array. If the array it came from has a next element, insert that element into the heap.
  • Repeat until the heap is empty.

Code

Java
public class Solution {
    /**
     * Merges k sorted arrays using a min-heap that holds at most one pending
     * element from each array.
     *
     * @param arrays k arrays, each sorted in non-decreasing order; the outer array may
     *               be null or empty, and individual arrays may be empty
     * @return a new array containing every input element in non-decreasing order
     */
    public int[] mergeKSortedArrays(int[][] arrays) {
        if (arrays == null || arrays.length == 0) {
            return new int[0];
        }

        // Integer.compare instead of a.value - b.value: the subtraction overflows when
        // the operands have opposite signs and large magnitudes (e.g. MIN_VALUE vs
        // MAX_VALUE), which would make the heap return elements out of order.
        java.util.PriorityQueue<Element> minHeap =
                new java.util.PriorityQueue<>((a, b) -> Integer.compare(a.value, b.value));
        int totalSize = 0;

        // Seed the heap with the first element of every non-empty array.
        for (int i = 0; i < arrays.length; i++) {
            if (arrays[i].length > 0) {
                minHeap.add(new Element(arrays[i][0], i, 0));
                totalSize += arrays[i].length;
            }
        }

        int[] result = new int[totalSize];
        int index = 0;

        // Repeatedly take the smallest pending element, then replace it in the heap
        // with the next element from the same array (if there is one).
        while (!minHeap.isEmpty()) {
            Element current = minHeap.poll();
            result[index++] = current.value;

            int[] source = arrays[current.arrayIndex];
            int nextIndex = current.elementIndex + 1;
            if (nextIndex < source.length) {
                minHeap.add(new Element(source[nextIndex], current.arrayIndex, nextIndex));
            }
        }

        return result;
    }

    /** A heap entry: a value together with the position it was read from. */
    private static class Element {
        final int value;        // the element itself
        final int arrayIndex;   // index of the input array it came from
        final int elementIndex; // its position within that input array

        Element(int value, int arrayIndex, int elementIndex) {
            this.value = value;
            this.arrayIndex = arrayIndex;
            this.elementIndex = elementIndex;
        }
    }
}
Python
import heapq
from typing import List


class Solution:
    """Merge k sorted lists with a min-heap holding one pending element per list."""

    def mergeKSortedArrays(self, arrays: List[List[int]]) -> List[int]:
        """Return a new sorted list containing every element of ``arrays``.

        Args:
            arrays: k lists, each sorted in non-decreasing order. The outer list
                may be empty (or None), and individual lists may be empty.

        Returns:
            A new list with all elements in non-decreasing order.
        """
        if not arrays:
            return []

        # Heap entries are (value, array_index, element_index). Each list has at most
        # one entry in the heap, so equal values are ordered by the unique
        # array_index and the comparison never falls through to element_index.
        min_heap = [(array[0], i, 0) for i, array in enumerate(arrays) if array]
        heapq.heapify(min_heap)

        result: List[int] = []
        while min_heap:
            value, array_index, element_index = heapq.heappop(min_heap)
            result.append(value)

            # Replace the popped entry with its successor from the same list, if any.
            source = arrays[array_index]
            next_index = element_index + 1
            if next_index < len(source):
                heapq.heappush(min_heap, (source[next_index], array_index, next_index))

        return result

Complexity

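  • ⏰ Time complexity: O(N log k), where N is the total number of elements across all k arrays. Each element is pushed onto the heap and popped from it exactly once. Each of those operations costs O(log k) because the heap never holds more than k entries.
  • 🧺 Space complexity: O(k) auxiliary space for the heap, which holds at most one entry per array at any time; the output array itself takes O(N).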