Input: matrix = [
[1,5,9],
[10,11,13],
[12,13,15]
], k = 8
Output: 13
Explanation: The elements in the matrix are [1,5,9,10,11,12,13,13,15], and the 8th smallest number is 13.
This example alone does not reveal much about the structure of the matrix. Let's look at more examples.
Let's take some examples to understand the problem better.
# example 1: every number in row i+1 is larger than every number in row i, but we can't establish any order between columns
1 2 3
4 5 6
7 8 9
# example 2: every number in column i+1 is larger than every number in column i, but we can't establish any order between rows
1 4 7
2 5 8
3 6 9
# example 3: we can't establish any order between columns or between rows
1 2 5
3 4 8
6 7 9
Now that we understand the problem better, let's try to solve it.
The simplest way to find the kth smallest element is to flatten the matrix into a single array, sort it, and pick the kth element. This approach does not leverage the sorted properties of the matrix rows and columns, but it is straightforward and easy to implement.
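Here is a minimal Python sketch of this baseline. The function name `kth_smallest_sort` is hypothetical. For an N × N matrix it runs in O(N² log N) time and uses O(N²) extra space.

```python
from typing import List


def kth_smallest_sort(matrix: List[List[int]], k: int) -> int:
    """Baseline: flatten, sort, index. O(N^2 log N) time, O(N^2) extra space."""
    # Gather every element into one list; the sorted rows/columns are ignored.
    flat = [value for row in matrix for value in row]
    flat.sort()
    # k is 1-based, so the kth smallest element sits at index k - 1.
    return flat[k - 1]


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(kth_smallest_sort(matrix, 8))  # 13
```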
Recall that two sorted arrays can be merged in time linear in their combined length.
A straightforward approach is to merge the first row with the second, then merge this result with the third row, and continue this process for all rows.
In the worst case, merging two arrays of sizes N and M takes about N+M comparisons. The i-th merge combines an already merged array of i·N elements with a row of N elements, so merging all N rows costs O(2N + 3N + 4N + … + N·N) = O(N · N(N+1)/2) = O(N³).
This approach is not efficient for large matrices.
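As a sketch of the row-by-row merge, the Python below uses the usual two-pointer merge. Both `merge_two` and `kth_smallest_sequential_merge` are hypothetical names. Each call to `merge_two` is linear, but the merged list keeps growing, which is where the O(N³) total comes from.

```python
from typing import List


def merge_two(a: List[int], b: List[int]) -> List[int]:
    """Two-pointer merge of two sorted lists in O(len(a) + len(b))."""
    merged = []
    i = j = 0
    while i < len(a) and j < len(b):
        if a[i] <= b[j]:
            merged.append(a[i])
            i += 1
        else:
            merged.append(b[j])
            j += 1
    # One list is exhausted; append whatever remains of the other.
    merged.extend(a[i:])
    merged.extend(b[j:])
    return merged


def kth_smallest_sequential_merge(matrix: List[List[int]], k: int) -> int:
    """Merge row 0 with row 1, then the result with row 2, and so on."""
    merged = list(matrix[0])
    for row in matrix[1:]:
        merged = merge_two(merged, row)
    return merged[k - 1]


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(kth_smallest_sequential_merge(matrix, 8))  # 13
```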
A divide-and-conquer strategy can be used to optimize merging: first, merge row 1 with row 2 and row 3 with row 4 separately, then merge these two results together, and repeat this process level by level until all rows are combined into one array.
Each level of merging touches all N² elements once, so it takes O(N²) time. With O(log N) levels, the overall time complexity is O(N² log N). That beats the O(N³) of sequential merging, but it only matches flattening and sorting, so it brings no improvement over the simplest approach.
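The level-by-level merge can be sketched in Python as follows. To keep it short, this sketch uses the standard-library `heapq.merge` for each two-way merge. The names `merge_rows` and `kth_smallest_pairwise_merge` are hypothetical. Each pass of the `while` loop is one level and roughly halves the number of lists.

```python
import heapq
from typing import List


def merge_rows(rows: List[List[int]]) -> List[int]:
    """Merge sorted rows pairwise, level by level, until one list remains."""
    lists = [list(row) for row in rows]
    while len(lists) > 1:
        next_level = []
        # Merge lists 0+1, 2+3, ... at this level.
        for i in range(0, len(lists) - 1, 2):
            next_level.append(list(heapq.merge(lists[i], lists[i + 1])))
        # With an odd count, the last list has no partner and moves up unchanged.
        if len(lists) % 2 == 1:
            next_level.append(lists[-1])
        lists = next_level
    return lists[0]


def kth_smallest_pairwise_merge(matrix: List[List[int]], k: int) -> int:
    return merge_rows(matrix)[k - 1]


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(kth_smallest_pairwise_merge(matrix, 8))  # 13
```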
Let’s pause and visualize the problem, as drawing it often reveals new insights and is helpful when stuck.
Consider the arrows that show the relationships in the matrix. Here, orange arrows indicate order within each row, and green arrows show order within each column.
This structure resembles a directed graph (on the right). If we rotate the matrix 45 degrees, we can see how each cell connects to its children: matrix[i][j] points to matrix[i][j+1] on its right and to matrix[i+1][j] below it.
In this graph, every cell is less than or equal to its children. The smallest element is always matrix[0][0]. The next smallest is either matrix[0][1] or matrix[1][0].
If matrix[0][1] is chosen as the second smallest, then matrix[1][0] and the children of matrix[0][1] (matrix[0][2] and matrix[1][1]) become candidates for the third smallest. Note that matrix[1][1] is also a child of matrix[1][0], so it is no smaller than matrix[1][0] and cannot come before that ancestor.
Thus, the true candidates for the third smallest are matrix[1][0] and matrix[0][2].
To efficiently find the next smallest, we need a data structure that can always give us the minimum among current candidates, remove it, and add its children. A priority queue is ideal for this purpose.
We could skip adding a child while one of its parents is still waiting in the queue, but simply adding both children is enough for correctness. However, a cell such as matrix[1][1] has two parents (matrix[0][1] and matrix[1][0]), so adding nodes this way can add it twice. We must therefore track visited cells to avoid processing the same cell more than once.
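To see why the visited set matters, here is a deliberately flawed Python sketch that pushes both children with no duplicate check. The name `kth_smallest_no_visited` is hypothetical. On the example matrix, cell (1, 1) is pushed once by (0, 1) and again by (1, 0), so it is popped twice. As a result, the function returns 11 for k = 6, even though the 6th smallest element is 12.

```python
import heapq
from typing import List


def kth_smallest_no_visited(matrix: List[List[int]], k: int) -> int:
    """Deliberately flawed: the frontier search WITHOUT a visited set.

    A cell (r, c) with r > 0 and c > 0 has two parents, (r - 1, c) and
    (r, c - 1), so it can be pushed once per parent and popped twice.
    """
    n = len(matrix)
    heap = [(matrix[0][0], 0, 0)]  # (value, row, col)
    result = matrix[0][0]
    for _ in range(k):
        result, row, col = heapq.heappop(heap)
        # Push both children with no duplicate check.
        if row + 1 < n:
            heapq.heappush(heap, (matrix[row + 1][col], row + 1, col))
        if col + 1 < n:
            heapq.heappush(heap, (matrix[row][col + 1], row, col + 1))
    return result


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(kth_smallest_no_visited(matrix, 6))  # 11, but the true 6th smallest is 12
```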
import java.util.*;

class Solution {
    // Inner class combining a row and a column into a single object
    static class Node {
        int row, col;

        Node(int r, int c) {
            row = r;
            col = c;
        }
    }

    public int kthSmallest(int[][] matrix, int k) {
        Set<String> visited = new HashSet<>();
        // Integer.compare avoids the overflow that a - b can hit for large values
        Queue<Node> minHeap = new PriorityQueue<>(
                (a, b) -> Integer.compare(matrix[a.row][a.col], matrix[b.row][b.col]));
        // add the first node
        Node source = new Node(0, 0);
        minHeap.add(source);
        visited.add(getKey(source));
        int m = matrix.length;
        int n = matrix[0].length; // m and n are the same, but this keeps the code generic
        // Remove the k - 1 smallest cells; the heap top is then the kth smallest
        while (k > 1) {
            Node node = minHeap.poll();
            // Child below the current cell
            Node rowNode = new Node(node.row + 1, node.col);
            String rowKey = getKey(rowNode);
            if (node.row + 1 < m && !visited.contains(rowKey)) {
                minHeap.add(rowNode);
                visited.add(rowKey);
            }
            // Child to the right of the current cell
            Node colNode = new Node(node.row, node.col + 1);
            String colKey = getKey(colNode);
            if (node.col + 1 < n && !visited.contains(colKey)) {
                minHeap.add(colNode);
                visited.add(colKey);
            }
            k--;
        }
        return matrix[minHeap.peek().row][minHeap.peek().col];
    }

    private String getKey(Node node) {
        return node.row + "->" + node.col;
    }
}
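By leveraging a min-heap (priority queue), we can always extract the smallest remaining candidate. Starting from the top-left cell, each time we extract a cell we add its right and down neighbors to the heap if they haven't been visited yet, and the kth cell to reach the top of the heap holds the kth smallest value. Each extraction adds at most two cells, so the heap never holds more than k + 1 entries and the running time is O(k log k). When k is small, this avoids traversing the entire matrix.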
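#include <queue>
#include <set>
#include <utility>
#include <vector>
using namespace std;

int kthSmallest(vector<vector<int>>& matrix, int k) {
    int n = matrix.size();
    // Order cells by value; "greater than" turns priority_queue into a min-heap
    auto cmp = [&](const pair<int, int>& a, const pair<int, int>& b) {
        return matrix[a.first][a.second] > matrix[b.first][b.second];
    };
    priority_queue<pair<int, int>, vector<pair<int, int>>, decltype(cmp)> minHeap(cmp);
    set<int> visited; // cells encoded as row * n + col
    minHeap.emplace(0, 0);
    visited.insert(0);
    int result = 0;
    // Pop k times; the kth popped value is the answer
    while (k--) {
        auto [row, col] = minHeap.top(); // structured bindings require C++17
        minHeap.pop();
        result = matrix[row][col];
        // insert(...).second is true only if the cell had not been visited
        if (row + 1 < n && visited.insert((row + 1) * n + col).second)
            minHeap.emplace(row + 1, col);
        if (col + 1 < n && visited.insert(row * n + (col + 1)).second)
            minHeap.emplace(row, col + 1);
    }
    return result;
}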
import java.util.PriorityQueue

fun kthSmallest(matrix: Array<IntArray>, k: Int): Int {
    val n = matrix.size
    // Min-heap of (row, col) pairs ordered by the value stored in that cell
    val minHeap = PriorityQueue(compareBy<Pair<Int, Int>> { matrix[it.first][it.second] })
    val visited = mutableSetOf<Int>() // cells encoded as row * n + col
    minHeap.add(0 to 0)
    visited.add(0)
    var result = 0
    // Pop k times; the kth popped value is the answer
    repeat(k) {
        val (row, col) = minHeap.poll()
        result = matrix[row][col]
        if (row + 1 < n) {
            val key = (row + 1) * n + col
            if (key !in visited) {
                minHeap.add(row + 1 to col)
                visited.add(key)
            }
        }
        if (col + 1 < n) {
            val key = row * n + (col + 1)
            if (key !in visited) {
                minHeap.add(row to col + 1)
                visited.add(key)
            }
        }
    }
    return result
}
import heapq

def kthSmallest(matrix, k):
    n = len(matrix)
    # Heap entries are (value, row, col), so heapq orders them by value
    minHeap = [(matrix[0][0], 0, 0)]
    visited = {0}  # cells encoded as row * n + col
    result = 0
    # Pop k times; the kth popped value is the answer
    for _ in range(k):
        val, row, col = heapq.heappop(minHeap)
        result = val
        if row + 1 < n:
            key = (row + 1) * n + col
            if key not in visited:
                heapq.heappush(minHeap, (matrix[row + 1][col], row + 1, col))
                visited.add(key)
        if col + 1 < n:
            key = row * n + (col + 1)
            if key not in visited:
                heapq.heappush(minHeap, (matrix[row][col + 1], row, col + 1))
                visited.add(key)
    return result
use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashSet};

fn kth_smallest(matrix: Vec<Vec<i32>>, k: i32) -> i32 {
    let n = matrix.len();
    // BinaryHeap is a max-heap; wrapping entries in Reverse makes it a min-heap
    let mut heap = BinaryHeap::new();
    let mut visited = HashSet::new(); // cells encoded as row * n + col
    heap.push(Reverse((matrix[0][0], 0, 0)));
    visited.insert(0);
    let mut result = 0;
    // Pop k times; the kth popped value is the answer
    for _ in 0..k {
        let Reverse((val, row, col)) = heap.pop().unwrap();
        result = val;
        // HashSet::insert returns true only if the key was not already present
        if row + 1 < n {
            let key = (row + 1) * n + col;
            if visited.insert(key) {
                heap.push(Reverse((matrix[row + 1][col], row + 1, col)));
            }
        }
        if col + 1 < n {
            let key = row * n + (col + 1);
            if visited.insert(key) {
                heap.push(Reverse((matrix[row][col + 1], row, col + 1)));
            }
        }
    }
    result
}
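Instead of scanning elements, we can binary search over values, between the smallest value matrix[0][0] and the largest value matrix[n-1][n-1]. For each guess, we count how many elements are less than or equal to it by walking from the bottom-left corner toward the top-right, which takes O(n) time per guess. The search converges on the smallest value v with at least k elements less than or equal to it. That value is always present in the matrix: if it were not, v - 1 would have the same count, contradicting minimality. The total time complexity is O(n log m), where m is the difference between the maximum and minimum values in the matrix.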
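To understand the getLessEqual function, note that for each mid value, the elements less than or equal to mid form a staircase-shaped region in the top-left of the matrix: every row splits into a prefix of elements ≤ mid followed by elements > mid. The function walks along that boundary from the bottom-left corner. Whenever matrix[i][j] ≤ mid, all of column j from row 0 to row i is ≤ mid, so it adds i + 1 and moves right; otherwise it moves up one row.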
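Here is a Python sketch of the same counting walk, mirroring getLessEqual in the Java code that follows; the name `count_less_equal` is illustrative. On the example matrix with val = 12, the walk visits (2,0), (2,1), (1,1), (1,2), (0,2) and adds 3 + 2 + 1 = 6.

```python
from typing import List


def count_less_equal(matrix: List[List[int]], val: int) -> int:
    """Count elements <= val in an n x n matrix sorted by rows and columns.

    Walks the staircase boundary from the bottom-left corner in O(n) steps.
    """
    n = len(matrix)
    i, j = n - 1, 0  # start at the bottom-left cell
    count = 0
    while i >= 0 and j < n:
        if matrix[i][j] > val:
            i -= 1          # this cell and everything right of it in row i exceed val
        else:
            count += i + 1  # rows 0..i of column j are all <= val
            j += 1
    return count


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(count_less_equal(matrix, 12))  # 6
print(count_less_equal(matrix, 13))  # 8
```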
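class Solution {
    public int kthSmallest(int[][] matrix, int k) {
        int n = matrix.length;
        // Binary search over values, not indices
        int lo = matrix[0][0], hi = matrix[n - 1][n - 1];
        while (lo <= hi) {
            int mid = lo + (hi - lo) / 2;
            int count = getLessEqual(matrix, mid);
            if (count < k) lo = mid + 1; // fewer than k elements <= mid: answer is larger
            else hi = mid - 1;           // at least k elements <= mid: try smaller values
        }
        // lo is now the smallest value with at least k elements <= it
        return lo;
    }

    // Counts elements <= val by walking from the bottom-left corner toward the top-right
    private int getLessEqual(int[][] matrix, int val) {
        int res = 0;
        int n = matrix.length, i = n - 1, j = 0;
        while (i >= 0 && j < n) {
            if (matrix[i][j] > val) i--; // too big: move up one row
            else {
                res += i + 1;            // rows 0..i of column j are all <= val
                j++;
            }
        }
        return res;
    }
}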
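For readers following the Python version of the heap solution, here is a sketch of a direct Python port of the Java binary search above. It makes the same assumption of an n × n matrix whose rows and columns are sorted in non-decreasing order. The name `kth_smallest_binary_search` is hypothetical. Python integers do not overflow, so the plain `(lo + hi) // 2` midpoint is safe here.

```python
from typing import List


def kth_smallest_binary_search(matrix: List[List[int]], k: int) -> int:
    """Binary search on the value range [matrix[0][0], matrix[n-1][n-1]]."""
    n = len(matrix)

    def count_less_equal(val: int) -> int:
        # Staircase walk from the bottom-left corner, O(n) per call.
        i, j, count = n - 1, 0, 0
        while i >= 0 and j < n:
            if matrix[i][j] > val:
                i -= 1
            else:
                count += i + 1
                j += 1
        return count

    lo, hi = matrix[0][0], matrix[n - 1][n - 1]
    while lo <= hi:
        mid = (lo + hi) // 2
        if count_less_equal(mid) < k:
            lo = mid + 1  # fewer than k elements <= mid: answer is larger
        else:
            hi = mid - 1  # at least k elements <= mid: try smaller values
    # lo is the smallest value with at least k elements <= it
    return lo


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(kth_smallest_binary_search(matrix, 8))  # 13
```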