Kth Smallest Element in a Sorted Matrix
Problem
Given an n x n matrix where each of the rows and columns is sorted in ascending order, return the kth smallest element in the matrix.
Note that it is the kth smallest element in the sorted order, not the kth distinct element.
You must find a solution with a memory complexity better than O(n^2).
Examples
Example 1:
Input: matrix = [
[1,5,9],
[10,11,13],
[12,13,15]
], k = 8
Output: 13
Explanation: The elements in the matrix are [1,5,9,10,11,12,13,13,15], and the 8th smallest number is 13
Example 2:
Input: matrix = [[-5]], k = 1
Output: -5
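The note about "sorted order, not distinct" matters in Example 1, where 13 appears twice. A minimal Python sketch of the difference follows; both helper names are made up for illustration and are not part of the problem statement.

```python
def kth_in_sorted_order(matrix, k):
    """kth smallest counting duplicates, which is what the problem asks for."""
    return sorted(v for row in matrix for v in row)[k - 1]


def kth_distinct(matrix, k):
    """kth smallest among distinct values only (NOT what the problem asks for)."""
    return sorted({v for row in matrix for v in row})[k - 1]


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(kth_in_sorted_order(matrix, 8))  # 13
print(kth_distinct(matrix, 8))         # 15
```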
Solution
Example 1 above does not reveal much about how the matrix is structured, so let's look at a few more examples to understand the problem better.
# Example 1: every number in row i+1 is greater than every number in row i, but there is no such ordering between columns
1 2 3
4 5 6
7 8 9
# Example 2: every number in column i+1 is greater than every number in column i, but there is no such ordering between rows
1 4 7
2 5 8
3 6 9
# Example 3: there is no such ordering between rows or between columns
1 2 5
3 4 8
6 7 9
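The three cases can be checked mechanically. The sketch below is only an illustration (all function names are hypothetical): each matrix has sorted rows and sorted columns, yet reading it row by row or column by column gives a fully sorted sequence only in the first two cases.

```python
def rows_and_cols_sorted(m):
    """True if every row and every column is in ascending order."""
    rows_ok = all(row[j] <= row[j + 1] for row in m for j in range(len(row) - 1))
    cols_ok = all(m[i][j] <= m[i + 1][j]
                  for i in range(len(m) - 1) for j in range(len(m[0])))
    return rows_ok and cols_ok


def row_major_sorted(m):
    """True if reading the matrix row by row yields a sorted sequence."""
    flat = [v for row in m for v in row]
    return flat == sorted(flat)


def col_major_sorted(m):
    """True if reading the matrix column by column yields a sorted sequence."""
    flat = [m[i][j] for j in range(len(m[0])) for i in range(len(m))]
    return flat == sorted(flat)


ex1 = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
ex2 = [[1, 4, 7], [2, 5, 8], [3, 6, 9]]
ex3 = [[1, 2, 5], [3, 4, 8], [6, 7, 9]]

for m in (ex1, ex2, ex3):
    print(rows_and_cols_sorted(m), row_major_sorted(m), col_major_sorted(m))
# True True False
# True False True
# True False False
```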
Now that we understand the problem better, let's try to solve it.
Method 1 – Naive Approach Using Auxiliary Array
Intuition
The simplest way to find the kth smallest element is to flatten the matrix into a single array, sort it, and pick the kth element. This approach does not leverage the sorted properties of the matrix rows and columns, but it is straightforward and easy to implement.
Approach
- Traverse the matrix and copy all elements into a new array.
- Sort the array in ascending order.
- Return the kth element (1-based index).
- Edge cases: If k is out of bounds, handle appropriately.
Notes
In this approach:
- we haven’t used the first property of the problem : rows are sorted
- we haven’t used the second property of the problem : columns are sorted
Code
C++
class Solution {
public:
int kthSmallest(vector<vector<int>>& mat, int k) {
int n = mat.size();
vector<int> arr;
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
arr.push_back(mat[i][j]);
}
}
sort(arr.begin(), arr.end());
return arr[k - 1];
}
};
Go
func kthSmallest(mat [][]int, k int) int {
n := len(mat)
arr := []int{}
for i := 0; i < n; i++ {
for j := 0; j < n; j++ {
arr = append(arr, mat[i][j])
}
}
sort.Ints(arr)
return arr[k-1]
}
Java
class Solution {
public int kthSmallest(int[][] mat, int k) {
int n = mat.length;
int[] arr = new int[n * n];
int idx = 0;
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
arr[idx++] = mat[i][j];
}
}
Arrays.sort(arr);
return arr[k - 1];
}
}
Kotlin
class Solution {
fun kthSmallest(mat: List<List<Int>>, k: Int): Int {
val arr = mat.flatten().sorted()
return arr[k - 1]
}
}
Python
def kth_smallest(mat: list[list[int]], k: int) -> int:
    # Flatten the matrix, sort it, and pick the kth element (1-based)
    arr = [v for row in mat for v in row]
    arr.sort()
    return arr[k - 1]
Rust
impl Solution {
pub fn kth_smallest(mat: Vec<Vec<i32>>, k: i32) -> i32 {
let mut arr: Vec<i32> = mat.into_iter().flatten().collect();
arr.sort();
arr[(k-1) as usize]
}
}
TypeScript
class Solution {
kthSmallest(mat: number[][], k: number): number {
const arr = mat.flat().sort((a, b) => a - b);
return arr[k - 1];
}
}
Complexity
- ⏰ Time complexity: O(n^2 log n^2), because we copy all n^2 elements and sort them.
- 🧺 Space complexity: O(n^2), storing all elements in an auxiliary array.
Method 2 - Using Merging of Two Rows
Linear Merge
We know that merging two sorted arrays can be done in O(N) time.
A straightforward approach is to merge the first row with the second, then merge this result with the third row, and continue this process for all rows.
flowchart TD
    A1[Row 1] --> M12[Merged 1 & 2]
    A2[Row 2] --> M12
    M12 --> M123[Merged 1 & 2 & 3]
    A3[Row 3] --> M123
    M123 --> M1234[Merged 1 & 2 & 3 & 4]
    A4[Row 4] --> M1234
In the worst case, merging two arrays of sizes N and M requires N+M comparisons. Merging the rows one at a time therefore costs O(2N + 3N + 4N + ... + N·N) = O(N³) in total, because the last merge produces all N² elements.
This approach is not efficient for large matrices.
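As a worked version of that sum: assuming the i-th row is merged into an accumulated array of (i-1)N elements, the step costs at most iN comparisons, and adding up the steps for i = 2, ..., N gives the cubic bound.

```latex
\sum_{i=2}^{N} iN
  = N\left(\frac{N(N+1)}{2} - 1\right)
  = \frac{N^3 + N^2 - 2N}{2}
  = O(N^3)
```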
Merge using Divide & Conquer
A divide and conquer strategy can be used to optimize merging: first, merge row 1 with row 2 and row 3 with row 4 separately, then merge these two results together, and repeat this process in levels until all rows are combined into one array.
flowchart TD
    A1[Row 1] --> M12[Merged 1 & 2]
    A2[Row 2] --> M12
    A3[Row 3] --> M34[Merged 3 & 4]
    A4[Row 4] --> M34
    M12 --> M1234[Merged 1 & 2 & 3 & 4]
    M34 --> M1234
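Each level of merging requires O(N²) time, and with O(log N) levels, the overall time complexity is O(N² log N). This improves on the O(N³) linear merge, but it only matches the naive sort from Method 1. The code in this section implements the simpler linear merge.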
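The level-by-level merge could be sketched in Python as follows. This is an illustrative sketch rather than a reference implementation; `merge` and `kth_smallest_pairwise` are names chosen here, and an odd row left over at a level is simply carried to the next one.

```python
def merge(a, b):
    """Merge two sorted lists into one sorted list in O(len(a) + len(b))."""
    out = []
    i = j = 0
    while i < len(a) and j < len(b):
        if a[i] <= b[j]:
            out.append(a[i])
            i += 1
        else:
            out.append(b[j])
            j += 1
    out.extend(a[i:])
    out.extend(b[j:])
    return out


def kth_smallest_pairwise(mat, k):
    """Merge rows pairwise (1&2, 3&4, ...) level by level until one list remains."""
    lists = [row[:] for row in mat]
    while len(lists) > 1:
        merged = [merge(lists[i], lists[i + 1])
                  for i in range(0, len(lists) - 1, 2)]
        if len(lists) % 2 == 1:
            merged.append(lists[-1])  # odd list out moves up unchanged
        lists = merged
    return lists[0][k - 1]


print(kth_smallest_pairwise([[1, 5, 9], [10, 11, 13], [12, 13, 15]], 8))  # 13
```

Like the linear merge, this still materialises all N² elements, so it does not reduce the memory use.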
Code
C++
class Solution {
public:
int kthSmallest(vector<vector<int>>& mat, int k) {
int n = mat.size();
vector<int> merged = mat[0];
for (int i = 1; i < n; ++i) {
vector<int> temp;
int p1 = 0, p2 = 0;
while (p1 < merged.size() && p2 < mat[i].size()) {
if (merged[p1] < mat[i][p2]) temp.push_back(merged[p1++]);
else temp.push_back(mat[i][p2++]);
}
while (p1 < merged.size()) temp.push_back(merged[p1++]);
while (p2 < mat[i].size()) temp.push_back(mat[i][p2++]);
merged = temp;
}
return merged[k-1];
}
};
Go
func kthSmallest(mat [][]int, k int) int {
n := len(mat)
merged := make([]int, len(mat[0]))
copy(merged, mat[0])
for i := 1; i < n; i++ {
temp := []int{}
p1, p2 := 0, 0
for p1 < len(merged) && p2 < len(mat[i]) {
if merged[p1] < mat[i][p2] {
temp = append(temp, merged[p1])
p1++
} else {
temp = append(temp, mat[i][p2])
p2++
}
}
for p1 < len(merged) {
temp = append(temp, merged[p1])
p1++
}
for p2 < len(mat[i]) {
temp = append(temp, mat[i][p2])
p2++
}
merged = temp
}
return merged[k-1]
}
Java
class Solution {
public int kthSmallest(int[][] mat, int k) {
int n = mat.length;
int[] merged = Arrays.copyOf(mat[0], n);
for (int i = 1; i < n; i++) {
int[] temp = new int[n * (i+1)];
int p1 = 0, p2 = 0, idx = 0;
while (p1 < merged.length && p2 < n) {
if (merged[p1] < mat[i][p2]) temp[idx++] = merged[p1++];
else temp[idx++] = mat[i][p2++];
}
while (p1 < merged.length) temp[idx++] = merged[p1++];
while (p2 < n) temp[idx++] = mat[i][p2++];
merged = Arrays.copyOf(temp, idx);
}
return merged[k-1];
}
}
Kotlin
class Solution {
fun kthSmallest(mat: List<List<Int>>, k: Int): Int {
var merged = mat[0]
for (i in 1 until mat.size) {
val temp = mutableListOf<Int>()
var p1 = 0; var p2 = 0
while (p1 < merged.size && p2 < mat[i].size) {
if (merged[p1] < mat[i][p2]) temp.add(merged[p1++])
else temp.add(mat[i][p2++])
}
while (p1 < merged.size) temp.add(merged[p1++])
while (p2 < mat[i].size) temp.add(mat[i][p2++])
merged = temp
}
return merged[k-1]
}
}
Python
def kth_smallest(mat: list[list[int]], k: int) -> int:
    merged = mat[0][:]
    for i in range(1, len(mat)):
        # Merge the accumulated result with row i
        temp = []
        p1, p2 = 0, 0
        while p1 < len(merged) and p2 < len(mat[i]):
            if merged[p1] < mat[i][p2]:
                temp.append(merged[p1])
                p1 += 1
            else:
                temp.append(mat[i][p2])
                p2 += 1
        temp.extend(merged[p1:])
        temp.extend(mat[i][p2:])
        merged = temp
    return merged[k - 1]
Rust
impl Solution {
pub fn kth_smallest(mat: Vec<Vec<i32>>, k: i32) -> i32 {
let mut merged = mat[0].clone();
for i in 1..mat.len() {
let mut temp = Vec::with_capacity(merged.len() + mat[i].len());
let (mut p1, mut p2) = (0, 0);
while p1 < merged.len() && p2 < mat[i].len() {
if merged[p1] < mat[i][p2] {
temp.push(merged[p1]);
p1 += 1;
} else {
temp.push(mat[i][p2]);
p2 += 1;
}
}
temp.extend_from_slice(&merged[p1..]);
temp.extend_from_slice(&mat[i][p2..]);
merged = temp;
}
merged[(k-1) as usize]
}
}
TypeScript
class Solution {
kthSmallest(mat: number[][], k: number): number {
let merged = mat[0].slice();
for (let i = 1; i < mat.length; i++) {
let temp: number[] = [];
let p1 = 0, p2 = 0;
while (p1 < merged.length && p2 < mat[i].length) {
if (merged[p1] < mat[i][p2]) temp.push(merged[p1++]);
else temp.push(mat[i][p2++]);
}
while (p1 < merged.length) temp.push(merged[p1++]);
while (p2 < mat[i].length) temp.push(mat[i][p2++]);
merged = temp;
}
return merged[k-1];
}
}
Complexity
- ⏰ Time complexity: O(N^3), because merging each row with the previous result takes up to O(N^2) and there are N rows.
- 🧺 Space complexity: O(N^2), as we store the merged array at each step.
Method 3 - Using Priority Queue with matrix node
Let's pause and visualize the problem, as drawing it often reveals new insights and is helpful when stuck.
Consider the arrows that show the relationships in the matrix. Here, orange arrows indicate order within each row, and green arrows show order within each column.

This structure resembles a directed graph (on the right). If we rotate the matrix 45 degrees, we can see how each cell connects to its children.
In this graph, every cell is smaller than its children. The smallest element is always matrix[0][0]. The next smallest could be either matrix[0][1] or matrix[1][0].
If matrix[0][1] is chosen as the second smallest, then matrix[1][0] and the children of matrix[0][1] (which are matrix[0][2] and matrix[1][1]) become candidates for the third smallest. Note that matrix[1][1] is also a child of matrix[1][0], so it can't be considered before its ancestor.
Thus, the true candidates for the third smallest are matrix[1][0] and matrix[0][2].
To efficiently find the next smallest, we need a data structure that can always give us the minimum among current candidates, remove it, and add its children. A priority queue is ideal for this purpose.
While we could avoid adding a child if its ancestor is still present, simply adding both children is sufficient for correctness. However, we must ensure that we do not revisit the same cell, as duplicates can occur when adding nodes this way.
Code
Java
static class Node { // Inner class for combining row and column into single object
int row, col;
Node(int r, int c) {
row = r;
col = c;
}
}
public int kthSmallest(int[][] matrix, int k) {
Set<String> visited = new HashSet<>();
// Integer.compare avoids the overflow that subtracting two ints can cause
Queue<Node> minHeap = new PriorityQueue<>((a, b) -> Integer.compare(matrix[a.row][a.col], matrix[b.row][b.col]));
// add the first node
Node source = new Node(0, 0);
minHeap.add(source);
visited.add(getKey(source));
int m = matrix.length;
int n = matrix[0].length;//m and n are same, but just to make generic
while (k > 1) {
Node node = minHeap.poll();
Node rowNode = new Node(node.row+1, node.col);
String rowKey = getKey(rowNode);
if (node.row + 1< m && !visited.contains(rowKey)) {
minHeap.add(rowNode);
visited.add(rowKey);
}
Node colNode = new Node(node.row, node.col+1);
String colKey = getKey(colNode);
if (node.col + 1< n && !visited.contains(colKey)) {
minHeap.add(colNode);
visited.add(colKey);
}
k--;
}
return matrix[minHeap.peek().row][minHeap.peek().col];
}
private String getKey(Node node) {
return ""+node.row+"->"+node.col;
}
Complexity
- ⏰ Time complexity: O(k log k)
  - Each insertion and removal from the min-heap takes O(log k) time.
  - Up to k elements are processed.
- 🧺 Space complexity: O(k)
  - The min-heap and visited set can grow up to size k in the worst case.
Method 4 - Using Priority Queue and Starting with First Row 🏆
Rather than using the tree-like graph from method 3, we can simply start by adding the first row's elements to a min-heap.
Steps:
- Initialize a min-heap with all elements from the first row.
- Repeat k-1 times: each time you remove the smallest element, track its row and column, and insert the next element from the same column into the heap.
Further thoughts:
- You could also use the first column to build the min-heap and, for each extraction, insert the next element from the same row.
- This approach is equivalent to the solution for [Find K Pairs with Smallest Sums](find-k-pairs-with-smallest-sums).
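The steps above can be sketched in Python with the standard `heapq` module. This is a minimal sketch under the problem's assumptions (square matrix, 1 <= k <= n²); the function name is hypothetical. Each heap entry stores the value together with its row and column, so no separate node class is needed.

```python
import heapq


def kth_smallest_first_row(matrix, k):
    n = len(matrix)
    # Seed the heap with the first row; columns beyond the first k cannot
    # contain the answer, so at most min(n, k) entries are needed.
    heap = [(matrix[0][c], 0, c) for c in range(min(n, k))]
    heapq.heapify(heap)
    # Remove the smallest element k-1 times, each time replacing it with
    # the next element from the same column (if there is one).
    for _ in range(k - 1):
        _, r, c = heapq.heappop(heap)
        if r + 1 < n:
            heapq.heappush(heap, (matrix[r + 1][c], r + 1, c))
    return heap[0][0]


print(kth_smallest_first_row([[1, 5, 9], [10, 11, 13], [12, 13, 15]], 8))  # 13
```

Because each column contributes at most one entry at a time, no visited set is required.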
Code
Java
public int kthSmallest(int[][] matrix, int k) {
class Node { // Inner class for combining row and column into single object
int row, col;
Node(int r, int c) {
row = r;
col = c;
}
}
// Integer.compare avoids the overflow that subtracting two ints can cause
Queue<Node> minHeap = new PriorityQueue<>((a, b) -> Integer.compare(matrix[a.row][a.col], matrix[b.row][b.col]));
int n = matrix[0].length;
// add the first row
for (int i = 0; i<n && i<k; i++) {
minHeap.add(new Node(0, i));
}
while (k > 1) {
Node node = minHeap.poll();
if (node.row + 1<matrix.length) {
minHeap.add(new Node(node.row + 1, node.col));
}
k--;
}
return matrix[minHeap.peek().row][minHeap.peek().col];
}
Complexity
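- ⏰ Time complexity: O(k log min(n, k)), because the heap never holds more than min(n, k) nodes and we perform up to k removals and insertions.
- 🧺 Space complexity: O(min(n, k)), for the heap.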
Method 5 - Using Priority Queue without using classes
Intuition
By leveraging a min-heap (priority queue), we can always extract the smallest available element from the matrix. Starting from the top-left, for each extracted cell, we add its right and down neighbors to the heap if they haven't been visited. This efficiently finds the kth smallest element without traversing the entire matrix.
Approach
- Initialize a min-heap with the top-left element.
- Use a set to track visited cells and avoid duplicates.
- For each extraction, add the right and down neighbors to the heap if not visited.
- Repeat until the kth smallest element is extracted.
Code
C++
int kthSmallest(vector<vector<int>>& matrix, int k) {
int n = matrix.size();
auto cmp = [&](const pair<int, int>& a, const pair<int, int>& b) {
return matrix[a.first][a.second] > matrix[b.first][b.second];
};
priority_queue<pair<int, int>, vector<pair<int, int>>, decltype(cmp)> minHeap(cmp);
set<int> visited;
minHeap.emplace(0, 0);
visited.insert(0);
int result = 0;
while (k--) {
auto [row, col] = minHeap.top(); minHeap.pop();
result = matrix[row][col];
if (row + 1 < n && visited.insert((row + 1) * n + col).second)
minHeap.emplace(row + 1, col);
if (col + 1 < n && visited.insert(row * n + (col + 1)).second)
minHeap.emplace(row, col + 1);
}
return result;
}
Go
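import "container/heap"

// cell stores a matrix value together with its coordinates.
type cell struct{ val, row, col int }

// cellHeap is a min-heap of cells ordered by value (implements heap.Interface).
type cellHeap []cell

func (h cellHeap) Len() int           { return len(h) }
func (h cellHeap) Less(i, j int) bool { return h[i].val < h[j].val }
func (h cellHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
func (h *cellHeap) Push(x any)        { *h = append(*h, x.(cell)) }
func (h *cellHeap) Pop() any {
	old := *h
	x := old[len(old)-1]
	*h = old[:len(old)-1]
	return x
}

func kthSmallest(matrix [][]int, k int) int {
	n := len(matrix)
	minHeap := &cellHeap{{matrix[0][0], 0, 0}}
	visited := map[int]bool{0: true}
	var result int
	for ; k > 0; k-- {
		curr := heap.Pop(minHeap).(cell)
		result = curr.val
		// Push the cell below, then the cell to the right, if not seen yet
		if r, c := curr.row+1, curr.col; r < n && !visited[r*n+c] {
			visited[r*n+c] = true
			heap.Push(minHeap, cell{matrix[r][c], r, c})
		}
		if r, c := curr.row, curr.col+1; c < n && !visited[r*n+c] {
			visited[r*n+c] = true
			heap.Push(minHeap, cell{matrix[r][c], r, c})
		}
	}
	return result
}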
Java
public int kthSmallest(int[][] matrix, int k) {
int n = matrix.length;
PriorityQueue<int[]> minHeap = new PriorityQueue<>(Comparator.comparingInt(a -> matrix[a[0]][a[1]]));
Set<Integer> visited = new HashSet<>();
minHeap.offer(new int[]{0, 0});
visited.add(0);
int result = 0;
while (k-- > 0) {
int[] curr = minHeap.poll();
int row = curr[0], col = curr[1];
result = matrix[row][col];
if (row + 1 < n) {
int key = (row + 1) * n + col;
if (!visited.contains(key)) {
minHeap.offer(new int[]{row + 1, col});
visited.add(key);
}
}
if (col + 1 < n) {
int key = row * n + (col + 1);
if (!visited.contains(key)) {
minHeap.offer(new int[]{row, col + 1});
visited.add(key);
}
}
}
return result;
}
Kotlin
fun kthSmallest(matrix: Array<IntArray>, k: Int): Int {
val n = matrix.size
val minHeap = PriorityQueue(compareBy<Pair<Int, Int>> { matrix[it.first][it.second] })
val visited = mutableSetOf<Int>()
minHeap.add(0 to 0)
visited.add(0)
var result = 0
repeat(k) {
val (row, col) = minHeap.poll()
result = matrix[row][col]
if (row + 1 < n) {
val key = (row + 1) * n + col
if (key !in visited) {
minHeap.add(row + 1 to col)
visited.add(key)
}
}
if (col + 1 < n) {
val key = row * n + (col + 1)
if (key !in visited) {
minHeap.add(row to col + 1)
visited.add(key)
}
}
}
return result
}
Python
import heapq

def kthSmallest(matrix, k):
    n = len(matrix)
    minHeap = [(matrix[0][0], 0, 0)]
    visited = {0}  # cells are keyed as row * n + col
    result = 0
    for _ in range(k):
        val, row, col = heapq.heappop(minHeap)
        result = val
        if row + 1 < n:
            key = (row + 1) * n + col
            if key not in visited:
                heapq.heappush(minHeap, (matrix[row + 1][col], row + 1, col))
                visited.add(key)
        if col + 1 < n:
            key = row * n + (col + 1)
            if key not in visited:
                heapq.heappush(minHeap, (matrix[row][col + 1], row, col + 1))
                visited.add(key)
    return result
Rust
use std::collections::{BinaryHeap, HashSet};
use std::cmp::Reverse;
fn kth_smallest(matrix: Vec<Vec<i32>>, k: i32) -> i32 {
let n = matrix.len();
let mut heap = BinaryHeap::new();
let mut visited = HashSet::new();
heap.push(Reverse((matrix[0][0], 0, 0)));
visited.insert(0);
let mut result = 0;
for _ in 0..k {
let Reverse((val, row, col)) = heap.pop().unwrap();
result = val;
if row + 1 < n {
let key = (row + 1) * n + col;
if visited.insert(key) {
heap.push(Reverse((matrix[row + 1][col], row + 1, col)));
}
}
if col + 1 < n {
let key = row * n + (col + 1);
if visited.insert(key) {
heap.push(Reverse((matrix[row][col + 1], row, col + 1)));
}
}
}
result
}
TypeScript
function kthSmallest(matrix: number[][], k: number): number {
const n = matrix.length;
const minHeap: [number, number, number][] = [[matrix[0][0], 0, 0]];
const visited = new Set<number>([0]);
let result = 0;
for (let i = 0; i < k; i++) {
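// TypeScript has no built-in heap; re-sorting the array stands in for one here.
// Each sort costs O(h log h) for an array of size h, so this is slower than a real binary heap.
minHeap.sort((a, b) => a[0] - b[0]);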
const [val, row, col] = minHeap.shift()!;
result = val;
if (row + 1 < n) {
const key = (row + 1) * n + col;
if (!visited.has(key)) {
minHeap.push([matrix[row + 1][col], row + 1, col]);
visited.add(key);
}
}
if (col + 1 < n) {
const key = row * n + (col + 1);
if (!visited.has(key)) {
minHeap.push([matrix[row][col + 1], row, col + 1]);
visited.add(key);
}
}
}
return result;
}
Complexity
- ⏰ Time complexity: O(k log k), because each heap operation (insert and pop) takes O(log k) time and we process up to k elements.
- 🧺 Space complexity: O(k), since the heap and visited set can grow up to k elements in the worst case.
Method 6 - Binary Search
Instead of scanning all elements, we binary search over the value range between the smallest and largest values in the matrix. For each guess, we count how many elements are less than or equal to it by walking from the bottom-left corner towards the top-right, which takes O(n) time per guess. The search converges on the smallest value whose count reaches k; the count only changes at values that occur in the matrix, so this value is always an actual element. The total time complexity is O(n log m), where m is the range between the minimum and maximum values in the matrix.
To understand the getLessEqual function, note that each mid value effectively partitions every row, marking how many of its elements are less than or equal to mid:
* * * * | *
* * * | * *
* * * | * *
* * | * * *
* | * * * *
| * * * * *
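The counting walk and the surrounding binary search could look like this in Python. It is a sketch under the problem's assumptions (square matrix, valid k), with hypothetical function names; it uses the `lo < hi` form of the search, which narrows down to the same smallest value whose count reaches k.

```python
def count_less_equal(matrix, val):
    """Count elements <= val by walking from the bottom-left corner in O(n)."""
    n = len(matrix)
    row, col = n - 1, 0
    count = 0
    while row >= 0 and col < n:
        if matrix[row][col] > val:
            row -= 1            # everything to the right in this row is larger too
        else:
            count += row + 1    # this cell and all cells above it in the column
            col += 1
    return count


def kth_smallest_binary_search(matrix, k):
    n = len(matrix)
    lo, hi = matrix[0][0], matrix[n - 1][n - 1]
    while lo < hi:
        mid = lo + (hi - lo) // 2
        if count_less_equal(matrix, mid) < k:
            lo = mid + 1        # fewer than k elements <= mid: answer is larger
        else:
            hi = mid            # mid could still be the answer
    return lo


print(kth_smallest_binary_search([[1, 5, 9], [10, 11, 13], [12, 13, 15]], 8))  # 13
```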
Code
Java
public int kthSmallest(int[][] matrix, int k) {
int n = matrix.length;
int lo = matrix[0][0], hi = matrix[n - 1][n - 1];
while (lo <= hi) {
int mid = lo + (hi - lo) / 2;
int count = getLessEqual(matrix, mid);
if (count < k) lo = mid + 1;
else hi = mid - 1;
}
return lo;
}
private int getLessEqual(int[][] matrix, int val) {
int res = 0;
int n = matrix.length, i = n - 1, j = 0;
while (i >= 0 && j < n) {
if (matrix[i][j] > val) i--;
else {
res += i + 1;
j++;
}
}
return res;
}
Complexity
- ⏰ Time complexity: O(n log m), because each binary search step takes O(n) time and we perform O(log m) steps, where m is the value range.
- 🧺 Space complexity: O(1), since only a constant amount of extra space is used for counters and variables.