Problem

You are given an m x n matrix mat that has its rows sorted in non-decreasing order and an integer k.

You are allowed to choose exactly one element from each row to form an array.

Return the kth smallest array sum among all possible arrays.
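The definition can be checked directly by enumerating all n^m arrays and sorting their sums. The helper name `kth_smallest_brute_force` is hypothetical. This is only a testing sketch for small inputs such as the examples below, not the intended solution, because the number of arrays grows exponentially with m.

```python
from itertools import product


def kth_smallest_brute_force(mat: list[list[int]], k: int) -> int:
    """Hypothetical reference checker: enumerate every array and sort the sums.

    product(*mat) yields one tuple per way of picking one element from each
    row, so this does n ** m work and is only suitable for tiny inputs.
    """
    sums = sorted(sum(choice) for choice in product(*mat))
    return sums[k - 1]


print(kth_smallest_brute_force([[1, 3, 11], [2, 4, 6]], 5))  # 7
```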

Examples

Example 1

Input: mat = [[1,3,11],[2,4,6]], k = 5
Output: 7
Explanation: Choosing one element from each row, the k smallest sums are:
[1,2], [1,4], [3,2], [3,4], [1,6]. The 5th sum is 7.

Example 2

Input: mat = [[1,3,11],[2,4,6]], k = 9
Output: 17

Example 3

Input: mat = [[1,10,10],[1,4,5],[2,3,6]], k = 7
Output: 9
Explanation: Choosing one element from each row, the k smallest sums are:
[1,1,2], [1,1,3], [1,4,2], [1,4,3], [1,1,6], [1,5,2], [1,5,3]. The 7th sum is 9.

Constraints

  • m == mat.length
  • n == mat[i].length
  • 1 <= m, n <= 40
  • 1 <= mat[i][j] <= 5000
  • 1 <= k <= min(200, n^m)
  • mat[i] is a non-decreasing array.

Solution

Method 1 – Min Heap + Row-wise Merging

Intuition

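We want the k-th smallest sum obtained by picking one element from each row. Process the rows one at a time and keep only the k smallest partial sums. Dropping the rest is safe: if a partial sum s is not among the k smallest, then k other partial sums are no larger than s, and extending each of them with the same choices from the remaining rows gives k full sums that are no larger than any full sum built on s. Since each row is sorted, combining the kept sums with the next row resembles merging sorted lists, and a min-heap hands back the smallest combined sums first.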
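The "merging sorted lists" analogy can be made literal. Instead of forming all |ans| * n sums for each row, a heap can walk only the frontier of index pairs (i, j), which is the classic k-smallest-pair-sums technique. This is a swapped-in refinement of the same pruning idea, not the method implemented in the Code section, and the names `merge_k_smallest` and `kth_smallest_frontier` are hypothetical.

```python
import heapq


def merge_k_smallest(a: list[int], b: list[int], k: int) -> list[int]:
    """Return the k smallest values of a[i] + b[j], assuming a and b are sorted.

    From (i, j) we step to (i, j + 1); a new i + 1 is seeded only when j == 0,
    so every index pair is pushed at most once.
    """
    if not a or not b:
        return []
    result: list[int] = []
    heap = [(a[0] + b[0], 0, 0)]
    while heap and len(result) < k:
        s, i, j = heapq.heappop(heap)
        result.append(s)
        if j + 1 < len(b):
            heapq.heappush(heap, (a[i] + b[j + 1], i, j + 1))
        if j == 0 and i + 1 < len(a):
            heapq.heappush(heap, (a[i + 1] + b[0], i + 1, 0))
    return result


def kth_smallest_frontier(mat: list[list[int]], k: int) -> int:
    # Only the k smallest first-row values can matter, by the same argument.
    ans = mat[0][:k]
    for row in mat[1:]:
        ans = merge_k_smallest(ans, row, k)
    return ans[k - 1]


print(kth_smallest_frontier([[1, 3, 11], [2, 4, 6]], 5))  # 7
```

Each merge then pushes O(k) pairs rather than up to k * n sums, at the cost of slightly more bookkeeping.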

Approach

  1. Start with the first row as the initial list of sums.
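  2. For each subsequent row, add every element of that row to every kept partial sum, push these sums into a min-heap (the Go and TypeScript versions sort them instead), and pop the k smallest, or all of them if fewer than k exist, as the new list.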
  3. Repeat until all rows are processed.
  4. The kept list is in ascending order, so the k-th smallest sum is its entry at index k - 1.
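To make these steps concrete on Example 3, the hypothetical helper `kth_smallest_trace` below runs the same row-by-row merge and records the list kept after each row. It is a sketch for illustration only.

```python
import heapq


def kth_smallest_trace(mat: list[list[int]], k: int) -> list[list[int]]:
    """Run the row-by-row merge and return the kept sums after each row."""
    ans = mat[0]
    history = [ans]
    for row in mat[1:]:
        heap: list[int] = []
        for x in ans:
            for y in row:
                heapq.heappush(heap, x + y)
        # Keep only the k smallest sums, popped in ascending order.
        ans = [heapq.heappop(heap) for _ in range(min(k, len(heap)))]
        history.append(ans)
    return history


for kept in kth_smallest_trace([[1, 10, 10], [1, 4, 5], [2, 3, 6]], 7):
    print(kept)
# [1, 10, 10]
# [2, 5, 6, 11, 11, 14, 14]
# [4, 5, 7, 8, 8, 8, 9]
```

The last list has 9 at index k - 1 = 6, matching the expected output.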

Code

C++
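#include <functional>
#include <queue>
#include <vector>
using namespace std;

class Solution {
public:
    int kthSmallest(vector<vector<int>>& mat, int k) {
        vector<int> ans = mat[0];  // partial sums after row 0
        for (size_t i = 1; i < mat.size(); ++i) {
            // Min-heap of every kept partial sum plus every element of row i.
            priority_queue<int, vector<int>, greater<int>> pq;
            for (int x : ans) {
                for (int y : mat[i]) {
                    pq.push(x + y);
                }
            }
            // Keep only the k smallest sums, in ascending order.
            ans.clear();
            for (int j = 0; j < k && !pq.empty(); ++j) {
                ans.push_back(pq.top());
                pq.pop();
            }
        }
        return ans[k - 1];
    }
};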
Go
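// Uses sort.Ints, so the enclosing file needs: import "sort"
func kthSmallest(mat [][]int, k int) int {
    ans := mat[0] // partial sums after row 0
    for i := 1; i < len(mat); i++ {
        // Every kept partial sum plus every element of row i.
        sums := make([]int, 0, len(ans)*len(mat[i]))
        for _, x := range ans {
            for _, y := range mat[i] {
                sums = append(sums, x+y)
            }
        }
        // Sorting stands in for the min-heap; keep the k smallest.
        sort.Ints(sums)
        if len(sums) > k {
            sums = sums[:k]
        }
        ans = sums
    }
    return ans[k-1]
}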
Java
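import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;

class Solution {
    public int kthSmallest(int[][] mat, int k) {
        List<Integer> ans = new ArrayList<>();
        for (int x : mat[0]) ans.add(x);  // partial sums after row 0
        for (int i = 1; i < mat.length; i++) {
            // Min-heap of every kept partial sum plus every element of row i.
            PriorityQueue<Integer> pq = new PriorityQueue<>();
            for (int x : ans) {
                for (int y : mat[i]) {
                    pq.offer(x + y);
                }
            }
            // Keep only the k smallest sums, in ascending order.
            ans = new ArrayList<>();
            for (int j = 0; j < k && !pq.isEmpty(); j++) {
                ans.add(pq.poll());
            }
        }
        return ans.get(k - 1);
    }
}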
Kotlin
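import java.util.PriorityQueue

class Solution {
    fun kthSmallest(mat: Array<IntArray>, k: Int): Int {
        var ans: List<Int> = mat[0].toList()  // partial sums after row 0
        for (i in 1 until mat.size) {
            // Min-heap of every kept partial sum plus every element of row i.
            val pq = PriorityQueue<Int>()
            for (x in ans) {
                for (y in mat[i]) {
                    pq.add(x + y)
                }
            }
            // Keep only the k smallest sums, in ascending order.
            val next = mutableListOf<Int>()
            repeat(minOf(k, pq.size)) {
                next.add(pq.poll())
            }
            ans = next
        }
        return ans[k - 1]
    }
}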
Python
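from heapq import heappop, heappush


def kthSmallest(mat: list[list[int]], k: int) -> int:
    ans: list[int] = mat[0]  # partial sums after row 0
    for i in range(1, len(mat)):
        # Min-heap of every kept partial sum plus every element of row i.
        heap: list[int] = []
        for x in ans:
            for y in mat[i]:
                heappush(heap, x + y)
        # Keep only the k smallest sums, in ascending order.
        ans = [heappop(heap) for _ in range(min(k, len(heap)))]
    return ans[k - 1]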
Rust
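use std::collections::BinaryHeap;

impl Solution {
    pub fn kth_smallest(mat: Vec<Vec<i32>>, k: i32) -> i32 {
        let k = k as usize;
        let mut ans = mat[0].clone(); // partial sums after row 0
        for row in mat.iter().skip(1) {
            // Every kept partial sum plus every element of this row.
            let mut heap = BinaryHeap::new();
            for &x in &ans {
                for &y in row {
                    heap.push(x + y);
                }
            }
            // into_sorted_vec() returns ascending order, so take(k) keeps the
            // k smallest. Wrapping values in Reverse would flip that order and
            // keep the k largest instead.
            ans = heap.into_sorted_vec().into_iter().take(k).collect();
        }
        ans[k - 1]
    }
}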
TypeScript
class Solution {
    kthSmallest(mat: number[][], k: number): number {
        let ans = mat[0].slice(); // partial sums after row 0
        for (let i = 1; i < mat.length; i++) {
            // Every kept partial sum plus every element of row i.
            const sums: number[] = [];
            for (const x of ans) {
                for (const y of mat[i]) {
                    sums.push(x + y);
                }
            }
            // Sorting stands in for the min-heap; keep the k smallest.
            sums.sort((a, b) => a - b);
            ans = sums.slice(0, k);
        }
        return ans[k - 1];
    }
}

Complexity

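  • ⏰ Time complexity: O(m * n * K * log(n * K)), where K = max(n, k) bounds how many partial sums are kept between rows. Each of the m - 1 merges forms at most K * n sums and pushes them into a heap (or sorts them), at O(log(n * K)) per sum.
  • 🧺 Space complexity: O(n * K), for the up to K * n candidate sums held during a merge; only up to K sums are carried from one row to the next.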
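For a sense of scale under the largest inputs the constraints allow (m = n = 40, k = 200), the hypothetical helper `count_heap_pushes` below counts how many sums the row-by-row merge pushes, tracking the same kept-list sizes as the code above. It is an arithmetic sketch, not a benchmark.

```python
def count_heap_pushes(m: int, n: int, k: int) -> int:
    """Count heap pushes made by the row-by-row merge on an m x n matrix."""
    kept = n  # the first row is taken as-is
    pushes = 0
    for _ in range(m - 1):
        pushes += kept * n        # every kept sum is paired with every element
        kept = min(k, kept * n)   # only the k smallest survive to the next row
    return pushes


print(count_heap_pushes(40, 40, 200))  # 305600
```

Only the first merge pushes n * n = 1,600 sums; each of the remaining 38 merges pushes k * n = 8,000.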