Problem

Given two arrays sorted in ascending order, form every sum of one element from the first array and one element from the second array. Find the N-th smallest of these sums. Duplicates are counted, so the sums form a multiset rather than a set.

Examples

Example 1:

Input: arr1 = [1, 7, 11], arr2 = [2, 4, 6], N = 3
Output: 7
Explanation: The sorted sums are [3, 5, 7, 9, 11, 13, 13, 15, 17], and the 3rd smallest sum is 7.

Example 2:

Input: arr1 = [1, 1, 2], arr2 = [1, 2, 3], N = 2
Output: 2
Explanation: The sorted sums are [2, 2, 3, 3, 3, 4, 4, 4, 5], and the 2nd smallest sum is 2.
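Because duplicates count, each explanation's list is simply every pairwise sum in ascending order. A brute-force reference sketch makes that definition concrete; the function names `all_sums_sorted` and `nth_smallest_sum_brute` are ours, not from the original. It costs O(m·n log(m·n)) time, so it is mainly useful for checking a faster method on small inputs.

```python
from itertools import product


def all_sums_sorted(arr1, arr2):
    """Every arr1[i] + arr2[j], duplicates kept, in ascending order."""
    return sorted(a + b for a, b in product(arr1, arr2))


def nth_smallest_sum_brute(arr1, arr2, n):
    """N-th smallest (1-based) pairwise sum, found by full enumeration."""
    sums = all_sums_sorted(arr1, arr2)
    if not 1 <= n <= len(sums):
        raise ValueError("n must be between 1 and len(arr1) * len(arr2)")
    return sums[n - 1]


print(all_sums_sorted([1, 7, 11], [2, 4, 6]))            # [3, 5, 7, 9, 11, 13, 13, 15, 17]
print(nth_smallest_sum_brute([1, 7, 11], [2, 4, 6], 3))  # 7
print(all_sums_sorted([1, 1, 2], [1, 2, 3]))             # [2, 2, 3, 3, 3, 4, 4, 4, 5]
print(nth_smallest_sum_brute([1, 1, 2], [1, 2, 3], 2))   # 2
```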

Solution

Method 1 - Priority Queue

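Instead of generating all len(arr1) × len(arr2) sums, we can use a min-heap (priority queue). Because both arrays are sorted, arr1[i] + arr2[j] is never larger than arr1[i + 1] + arr2[j] or arr1[i] + arr2[j + 1], so the next-smallest sum is always a neighbour of a pair that has already been extracted. Here are the steps: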

  1. Initialize a min-heap of (sum, i, j) entries, where i and j index arr1 and arr2.
  2. Use a set of visited index pairs so that no pair is pushed twice: (i + 1, j + 1) is a neighbour of both (i + 1, j) and (i, j + 1).
  3. Start with the smallest pair, (0, 0), whose sum is arr1[0] + arr2[0]. Each time a pair (i, j) is extracted, push its unvisited neighbours (i + 1, j) and (i, j + 1).
  4. Extract from the heap N times; the N-th extracted sum is the answer.
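A minimal Python sketch of these steps follows, assuming both arrays are sorted ascending; the name `nth_smallest_sum_heap` is hypothetical. It uses the same (i + 1, j) / (i, j + 1) expansion as the Java code, but keeps visited pairs as tuples rather than "i,j" strings, and it stops after N - 1 expansions, reading the answer from the top of the heap instead of popping it.

```python
import heapq


def nth_smallest_sum_heap(arr1, arr2, n):
    """Return the n-th smallest (1-based) arr1[i] + arr2[j]; both arrays sorted ascending."""
    m, k = len(arr1), len(arr2)
    if not 1 <= n <= m * k:
        raise ValueError("n must be between 1 and len(arr1) * len(arr2)")

    # Steps 1 and 3: a min-heap of (sum, i, j), seeded with the smallest pair (0, 0)
    heap = [(arr1[0] + arr2[0], 0, 0)]
    # Step 2: index pairs already pushed, so no pair enters the heap twice
    visited = {(0, 0)}

    # Step 4: extract n - 1 times, pushing each extracted pair's unvisited neighbours
    for _ in range(n - 1):
        _, i, j = heapq.heappop(heap)
        for ni, nj in ((i + 1, j), (i, j + 1)):
            if ni < m and nj < k and (ni, nj) not in visited:
                visited.add((ni, nj))
                heapq.heappush(heap, (arr1[ni] + arr2[nj], ni, nj))

    # The heap's minimum would be the n-th extraction, i.e. the n-th smallest sum
    return heap[0][0]


print(nth_smallest_sum_heap([1, 7, 11], [2, 4, 6], 3))  # 7
print(nth_smallest_sum_heap([1, 1, 2], [1, 2, 3], 2))   # 2
```

Storing tuples in the heap lets Python's tuple ordering compare the sum first, so no custom comparator is needed.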

Code

Java
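import java.util.HashSet;
import java.util.PriorityQueue;

public class Solution {
    public static int nthSmallestSum(int[] arr1, int[] arr2, int N) {
        if (arr1.length == 0 || arr2.length == 0 || N <= 0) {
            return -1; // No sums exist, or N is not a valid rank
        }
        // Entries are {sum, i, j}; Integer.compare avoids the overflow a[0] - b[0] can cause
        PriorityQueue<int[]> minHeap =
            new PriorityQueue<>((a, b) -> Integer.compare(a[0], b[0]));
        HashSet<String> visited = new HashSet<>();
        // Offsets that generate the neighbours (i+1, j) and (i, j+1)
        int[] directions = {1, 0, 0, 1};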

        // Initial pair (sum, index1, index2)
        minHeap.offer(new int[] {arr1[0] + arr2[0], 0, 0});
        visited.add(0 + "," + 0);

        while (!minHeap.isEmpty() && N > 0) {
            int[] current = minHeap.poll();
            int sum = current[0];
            int i = current[1];
            int j = current[2];
            N--;

            if (N == 0) {
                return sum;
            }

            // Generate next pairs (i+1, j) and (i, j+1)
            for (int k = 0; k < 2; k++) {
                int ni = i + directions[k * 2];
                int nj = j + directions[k * 2 + 1];
                if (ni < arr1.length && nj < arr2.length
                    && !visited.contains(ni + "," + nj)) {
                    minHeap.offer(new int[] {arr1[ni] + arr2[nj], ni, nj});
                    visited.add(ni + "," + nj);
                }
            }
        }

        return -1; // N is larger than arr1.length * arr2.length
    }

    public static void main(String[] args) {
        int[] arr1 = {1, 7, 11};
        int[] arr2 = {2, 4, 6};
        int N = 3;
        System.out.println("The " + N + "-th smallest sum is: "
            + nthSmallestSum(arr1, arr2, N)); // Output: 7

        int[] arr1_2 = {1, 1, 2};
        int[] arr2_2 = {1, 2, 3};
        int N_2 = 2;
        System.out.println("The " + N_2 + "-th smallest sum is: "
            + nthSmallestSum(arr1_2, arr2_2, N_2)); // Output: 2
    }
}
Python
import heapq


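def nth_element_of_sums(arr1, arr2, N):
    n1, n2 = len(arr1), len(arr2)
    if N < 1 or N > n1 * n2:
        raise ValueError("N must be between 1 and len(arr1) * len(arr2)")

    # Treat each column j as the sorted list arr1[0] + arr2[j], arr1[1] + arr2[j], ...
    # and merge the columns like k sorted lists. Every pair (i, j) has exactly one
    # predecessor (i - 1, j), so no visited set is needed.
    # Columns j >= N can be skipped: the N pairs (0, 0) .. (0, N - 1) are no larger.
    min_heap = [(arr1[0] + arr2[j], 0, j) for j in range(min(n2, N))]
    heapq.heapify(min_heap)

    # Pop from the heap N times; the N-th pop is the N-th smallest sum
    for _ in range(N):
        current_sum, i, j = heapq.heappop(min_heap)
        if i + 1 < n1:
            heapq.heappush(min_heap, (arr1[i + 1] + arr2[j], i + 1, j))

    return current_sum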


# Example usage:
arr1 = [1, 3, 5]
arr2 = [2, 4, 6]
N = 4
print(
    f"The {N}th smallest sum is: {nth_element_of_sums(arr1, arr2, N)}"
)  # Output: 7

Complexity

  • ⏰ Time complexity: O(N log N). There are at most N extractions, each followed by at most two insertions, and the heap never holds more than N entries, so every heap operation costs O(log N).
  • 🧺 Space complexity: O(min(N, m · n)), where m and n are the lengths of the two arrays.
    • The heap holds at most N entries and the visited set at most 2N - 1 index pairs; neither can exceed m · n, the total number of pairs.
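The counting behind these bounds, written out for the visited-set version (the one in the numbered steps and the Java code), is below. H_t and V_t denote the heap and visited set after t extract-and-expand rounds; this notation is ours, not from the original. The N-th smallest sum is the next extraction after N - 1 rounds.

```latex
\begin{aligned}
|H_0| &= |V_0| = 1 \\
|H_t| &\le |H_{t-1}| - 1 + 2 = |H_{t-1}| + 1
  &&\Rightarrow\; |H_{N-1}| \le 1 + (N - 1) = N \\
|V_t| &\le |V_{t-1}| + 2
  &&\Rightarrow\; |V_{N-1}| \le 1 + 2(N - 1) = 2N - 1 \\
\text{heap operations} &\le \underbrace{N}_{\text{pops}} + \underbrace{2(N - 1)}_{\text{pushes}} = 3N - 2,
  \text{ each } O(\log N) &&\Rightarrow\; O(N \log N)
\end{aligned}
```

Both sizes are also capped by m · n, since each index pair is pushed at most once.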