Problem

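Given a 2D array where each row is sorted in ascending order, find the median of all elements in the matrix. When the total number of elements is even, the median is the average of the two middle values.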

Examples

Example 1

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
Input:
A = [
  [2, 4, 5, 6],
  [1, 2, 2, 4],
  [3, 4, 4, 5],
  [1, 2, 3, 3]
]
Output:
3
Explanation:
Merged array: [1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 6] (16 elements). The two middle elements (8th and 9th) are both 3, so the median is (3+3)/2 = 3.

Solution

Method 1 – Brute Force Merge

Intuition

The simplest way is to flatten the matrix into a single array, sort it, and pick the median. This is easy to implement but not optimal for large matrices.

Approach

  1. Traverse all rows and columns, collect all elements into a new array.
  2. Sort the array.
  3. Return the middle element (or average of two middle elements if total count is even).

Code

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
class Solution {
public:
    // Returns the median of all elements in `mat` by flattening and sorting.
    // For an even element count, returns the average of the two middle values
    // rounded toward negative infinity (the same result as Python's `//`).
    // Precondition: mat contains at least one element.
    int matrixMedian(vector<vector<int>>& mat) {
        vector<int> arr;
        for (auto& row : mat) arr.insert(arr.end(), row.begin(), row.end());
        sort(arr.begin(), arr.end());
        size_t n = arr.size();
        if (n % 2 == 1) return arr[n / 2];
        // Widen to 64 bits: adding two ints can overflow (undefined behavior).
        // hi >= lo because arr is sorted, so (hi - lo) / 2 floors and the
        // result lies between lo and hi, hence fits back into an int.
        long long lo = arr[n / 2 - 1], hi = arr[n / 2];
        return static_cast<int>(lo + (hi - lo) / 2);
    }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
// matrixMedian returns the median of all elements in mat by flattening and
// sorting. For an even element count it returns the average of the two middle
// values rounded toward negative infinity (the same result as Python's //).
// mat must contain at least one element.
func matrixMedian(mat [][]int) int {
    var all []int
    for _, row := range mat {
        all = append(all, row...)
    }
    sort.Ints(all)
    n := len(all)
    if n%2 == 1 {
        return all[n/2]
    }
    lo, hi := all[n/2-1], all[n/2]
    // Overflow-free floor((lo+hi)/2): >> on signed ints is an arithmetic shift
    // (i.e. floors), and the last term restores the carry when both are odd.
    return (lo >> 1) + (hi >> 1) + (lo & hi & 1)
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
class Solution {
    /**
     * Returns the median of all elements in {@code mat} by flattening and sorting.
     * For an even element count, returns the average of the two middle values
     * rounded toward negative infinity (the same result as Python's {@code //}).
     *
     * @param mat matrix whose rows are sorted ascending; must contain at least one element
     * @return the median value
     */
    public int matrixMedian(int[][] mat) {
        int total = 0;
        for (int[] row : mat) total += row.length;
        // A primitive array avoids boxing every element into an Integer.
        int[] arr = new int[total];
        int pos = 0;
        for (int[] row : mat) {
            System.arraycopy(row, 0, arr, pos, row.length);
            pos += row.length;
        }
        Arrays.sort(arr);
        int n = arr.length;
        if (n % 2 == 1) return arr[n / 2];
        // Widen to long: adding two ints can overflow. hi >= lo (sorted), so
        // (hi - lo) / 2 floors and the result lies in [lo, hi], i.e. fits an int.
        long lo = arr[n / 2 - 1], hi = arr[n / 2];
        return (int) (lo + (hi - lo) / 2);
    }
}
1
2
3
4
5
6
7
class Solution {
    /**
     * Returns the median of all elements in [mat] by flattening and sorting.
     *
     * For an even element count, returns the average of the two middle values rounded toward
     * negative infinity (the same result as Python's `//`). [mat] must contain at least one element.
     */
    fun matrixMedian(mat: List<List<Int>>): Int {
        val sorted = mat.flatten().sorted()
        val n = sorted.size
        if (n % 2 == 1) return sorted[n / 2]
        // Widen to Long: adding two Ints can overflow. hi >= lo (sorted), so (hi - lo) / 2
        // floors and the result lies between lo and hi, i.e. fits back into an Int.
        val lo = sorted[n / 2 - 1].toLong()
        val hi = sorted[n / 2].toLong()
        return (lo + (hi - lo) / 2).toInt()
    }
}
1
2
3
4
5
6
7
def matrix_median(mat: list[list[int]]) -> int:
    """Return the median of every value in ``mat`` by flattening and sorting.

    With an even number of values, the two middle ones are averaged using
    floor division.
    """
    ordered = sorted(value for row in mat for value in row)
    half, odd = divmod(len(ordered), 2)
    if odd:
        return ordered[half]
    return (ordered[half - 1] + ordered[half]) // 2
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
impl Solution {
    /// Returns the median of all elements in `mat` by flattening and sorting.
    ///
    /// For an even element count, returns the average of the two middle values
    /// rounded toward negative infinity (the same result as Python's `//`).
    /// `mat` must contain at least one element.
    pub fn matrix_median(mat: Vec<Vec<i32>>) -> i32 {
        let mut arr: Vec<i32> = mat.into_iter().flatten().collect();
        arr.sort_unstable();
        let n = arr.len();
        if n % 2 == 1 {
            return arr[n / 2];
        }
        // Widen to i64: adding two i32 values panics on overflow in debug builds
        // and wraps in release. hi >= lo (sorted), so the division floors and the
        // result lies between lo and hi, i.e. fits back into an i32.
        let lo = i64::from(arr[n / 2 - 1]);
        let hi = i64::from(arr[n / 2]);
        (lo + (hi - lo) / 2) as i32
    }
}
1
2
3
4
5
6
7
8
class Solution {
    /**
     * Median of every value in `mat`, found by flattening and sorting.
     * With an even count, the two middle values are averaged and floored.
     */
    matrixMedian(mat: number[][]): number {
        const values: number[] = [];
        for (const row of mat) {
            for (const v of row) values.push(v);
        }
        values.sort((x, y) => x - y);
        const half = Math.floor(values.length / 2);
        return values.length % 2 === 1
            ? values[half]
            : Math.floor((values[half - 1] + values[half]) / 2);
    }
}

Complexity

  • ⏰ Time complexity: O(mn log(mn)), sorting all elements.
  • 🧺 Space complexity: O(mn), storing all elements.

Method 2 – Binary Search on Value

Intuition

Since each row is sorted, we can use binary search on the value range to efficiently count how many elements are less than or equal to a candidate value, and thus find the median without flattening the matrix.

Approach

  1. Find the minimum and maximum values in the matrix (first and last elements of each row).
  2. Use binary search between min and max:
    • For each mid value, count how many elements in the matrix are ≤ mid (use upper_bound/bisect in each row).
    • If count < desired, move min up; else move max down.
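  3. When low == high, low is the k-th smallest element, where k = ⌈mn/2⌉ (the smallest value with at least k elements ≤ it); this is the median when mn is odd.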

Code

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
class Solution {
public:
    // Returns the median of all elements in `mat`, whose rows are each sorted
    // ascending, without flattening the matrix. For an even element count,
    // returns the average of the two middle values rounded toward negative
    // infinity, matching the brute-force method.
    // Precondition: mat is non-empty and every row is non-empty.
    int matrixMedian(vector<vector<int>>& mat) {
        long long total = 0;
        for (auto& row : mat) total += row.size();
        long long half = total / 2;
        if (total % 2 == 1) return kthSmallest(mat, half + 1);
        long long lo = kthSmallest(mat, half), hi = kthSmallest(mat, half + 1);
        // hi >= lo, so this floors; 64-bit arithmetic avoids int overflow.
        return static_cast<int>(lo + (hi - lo) / 2);
    }

private:
    // Returns the k-th smallest (1-based) element: binary-searches the value
    // range [min, max] for the smallest v with count(elements <= v) >= k.
    static int kthSmallest(const vector<vector<int>>& mat, long long k) {
        long long low = mat[0].front(), high = mat[0].back();
        for (auto& row : mat) {
            low = min(low, static_cast<long long>(row.front()));
            high = max(high, static_cast<long long>(row.back()));
        }
        while (low < high) {
            // 64-bit: high - low can exceed INT_MAX. high - low > 0, so this floors.
            long long mid = low + (high - low) / 2;
            long long cnt = 0;
            for (auto& row : mat) {
                // Number of elements <= mid in this sorted row.
                cnt += upper_bound(row.begin(), row.end(), mid) - row.begin();
            }
            if (cnt < k) low = mid + 1;
            else high = mid;
        }
        return static_cast<int>(low);
    }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
// matrixMedian returns the median of all elements in mat, whose rows are each
// sorted ascending, without flattening the matrix. For an even element count
// it returns the average of the two middle values rounded toward negative
// infinity, matching the brute-force method.
// mat and each of its rows must be non-empty.
func matrixMedian(mat [][]int) int {
    total := 0
    for _, row := range mat {
        total += len(row)
    }
    half := total / 2
    if total%2 == 1 {
        return kthSmallest(mat, half+1)
    }
    lo, hi := kthSmallest(mat, half), kthSmallest(mat, half+1)
    // Overflow-free floor((lo+hi)/2); >> on signed ints is arithmetic (floors).
    return (lo >> 1) + (hi >> 1) + (lo & hi & 1)
}

// kthSmallest returns the k-th smallest (1-based) element of mat by
// binary-searching the value range for the smallest v with count(<= v) >= k.
func kthSmallest(mat [][]int, k int) int {
    low, high := mat[0][0], mat[0][len(mat[0])-1]
    for _, row := range mat[1:] {
        if row[0] < low {
            low = row[0]
        }
        if last := row[len(row)-1]; last > high {
            high = last
        }
    }
    for low < high {
        // high-low > 0, so this floors. (low+high)/2 truncates toward zero for
        // negative sums, which can leave mid == high and loop forever.
        mid := low + (high-low)/2
        cnt := 0
        for _, row := range mat {
            // Index of the first element > mid == number of elements <= mid.
            cnt += sort.SearchInts(row, mid+1)
        }
        if cnt < k {
            low = mid + 1
        } else {
            high = mid
        }
    }
    return low
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
class Solution {
    /**
     * Returns the median of all elements in {@code mat}, whose rows are each sorted
     * ascending, without flattening the matrix. For an even element count, returns the
     * average of the two middle values rounded toward negative infinity, matching the
     * brute-force method.
     *
     * @param mat non-empty matrix with non-empty rows, each sorted ascending
     * @return the median value
     */
    public int matrixMedian(int[][] mat) {
        long total = 0;
        for (int[] row : mat) total += row.length;
        long half = total / 2;
        if (total % 2 == 1) return kthSmallest(mat, half + 1);
        long lo = kthSmallest(mat, half), hi = kthSmallest(mat, half + 1);
        // hi >= lo, so this floors; long arithmetic avoids int overflow.
        return (int) (lo + (hi - lo) / 2);
    }

    /**
     * Returns the k-th smallest (1-based) element: binary-searches the value range for
     * the smallest v such that at least k elements are {@code <= v}.
     */
    private static int kthSmallest(int[][] mat, long k) {
        long low = mat[0][0], high = mat[0][mat[0].length - 1];
        for (int[] row : mat) {
            low = Math.min(low, row[0]);
            high = Math.max(high, row[row.length - 1]);
        }
        while (low < high) {
            // long: high - low can exceed Integer.MAX_VALUE. high - low > 0, so this floors.
            long mid = low + (high - low) / 2;
            long cnt = 0;
            for (int[] row : mat) cnt += upperBound(row, mid);
            if (cnt < k) low = mid + 1;
            else high = mid;
        }
        return (int) low;
    }

    /**
     * Returns the number of elements in the sorted {@code row} that are {@code <= value}.
     * Arrays.binarySearch is not used because, with duplicates, it may return any
     * matching index rather than the first one.
     */
    private static int upperBound(int[] row, long value) {
        int lo = 0, hi = row.length;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (row[mid] <= value) lo = mid + 1;
            else hi = mid;
        }
        return lo;
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
class Solution {
    /**
     * Returns the median of all elements in [mat], whose rows are each sorted ascending,
     * without flattening the matrix.
     *
     * For an even element count, returns the average of the two middle values rounded toward
     * negative infinity, matching the brute-force method. [mat] and each of its rows must be
     * non-empty.
     */
    fun matrixMedian(mat: List<List<Int>>): Int {
        val total = mat.sumOf { it.size.toLong() }
        val half = total / 2
        if (total % 2 == 1L) return kthSmallest(mat, half + 1)
        val lo = kthSmallest(mat, half).toLong()
        val hi = kthSmallest(mat, half + 1).toLong()
        // hi >= lo, so this floors; Long arithmetic avoids Int overflow.
        return (lo + (hi - lo) / 2).toInt()
    }

    /** Returns the [k]-th smallest (1-based) element: the smallest v with at least k elements <= v. */
    private fun kthSmallest(mat: List<List<Int>>, k: Long): Int {
        var low = mat.minOf { it.first() }.toLong()
        var high = mat.maxOf { it.last() }.toLong()
        while (low < high) {
            // high - low > 0, so this floors. (low + high) / 2 truncates toward zero for negative
            // sums, which can leave mid == high and loop forever.
            val mid = low + (high - low) / 2
            val count = mat.sumOf { row -> upperBound(row, mid).toLong() }
            if (count < k) low = mid + 1 else high = mid
        }
        return low.toInt()
    }

    /** Number of elements in the sorted [row] that are <= [value], found by binary search. */
    private fun upperBound(row: List<Int>, value: Long): Int {
        var lo = 0
        var hi = row.size
        while (lo < hi) {
            val mid = (lo + hi) ushr 1
            if (row[mid] <= value) lo = mid + 1 else hi = mid
        }
        return lo
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
def matrix_median(mat: list[list[int]]) -> int:
    """Return the median of all values in ``mat`` without flattening it.

    Every row of ``mat`` must be non-empty and sorted ascending. For an even
    number of values, the two middle values are averaged with floor division,
    matching the brute-force method.
    """
    lo_bound = min(row[0] for row in mat)
    hi_bound = max(row[-1] for row in mat)

    def kth_smallest(k: int) -> int:
        # Smallest value v such that at least k elements (1-based k) are <= v.
        low, high = lo_bound, hi_bound
        while low < high:
            mid = (low + high) // 2  # floor division: safe for negative ranges
            if sum(bisect_right(row, mid) for row in mat) < k:
                low = mid + 1
            else:
                high = mid
        return low

    total = sum(len(row) for row in mat)
    half = total // 2
    if total % 2:
        return kth_smallest(half + 1)
    return (kth_smallest(half) + kth_smallest(half + 1)) // 2
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
impl Solution {
    /// Returns the median of all elements in `mat`, whose rows are each sorted
    /// ascending, without flattening the matrix.
    ///
    /// For an even element count, returns the average of the two middle values
    /// rounded toward negative infinity, matching the brute-force method.
    /// `mat` and each of its rows must be non-empty.
    pub fn matrix_median(mat: Vec<Vec<i32>>) -> i32 {
        let total: usize = mat.iter().map(Vec::len).sum();
        let half = total / 2;
        if total % 2 == 1 {
            return Self::kth_smallest(&mat, half + 1);
        }
        let lo = i64::from(Self::kth_smallest(&mat, half));
        let hi = i64::from(Self::kth_smallest(&mat, half + 1));
        // hi >= lo, so this floors; i64 arithmetic avoids i32 overflow.
        (lo + (hi - lo) / 2) as i32
    }

    /// Returns the `k`-th smallest (1-based) element: the smallest value `v`
    /// with at least `k` elements `<= v`, found by binary search on values.
    fn kth_smallest(mat: &[Vec<i32>], k: usize) -> i32 {
        let mut low = mat
            .iter()
            .map(|row| i64::from(row[0]))
            .min()
            .expect("matrix must be non-empty");
        let mut high = mat
            .iter()
            .map(|row| i64::from(row[row.len() - 1]))
            .max()
            .expect("matrix must be non-empty");
        while low < high {
            // i64 avoids overflow of low + high; high - low > 0, so this floors.
            // (low + high) / 2 truncates toward zero and can loop forever on negatives.
            let mid = low + (high - low) / 2;
            let count: usize = mat
                .iter()
                .map(|row| row.partition_point(|&x| i64::from(x) <= mid))
                .sum();
            if count < k {
                low = mid + 1;
            } else {
                high = mid;
            }
        }
        low as i32
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
class Solution {
    /**
     * Median of all elements in `mat` (each row sorted ascending, non-empty),
     * computed without flattening the matrix. For an even element count, the
     * two middle values are averaged and floored, matching the brute-force method.
     */
    matrixMedian(mat: number[][]): number {
        let total = 0;
        for (const row of mat) total += row.length;
        const half = Math.floor(total / 2);
        if (total % 2 === 1) return this.kthSmallest(mat, half + 1);
        return Math.floor((this.kthSmallest(mat, half) + this.kthSmallest(mat, half + 1)) / 2);
    }

    /** k-th smallest (1-based) element: the smallest v with at least k elements <= v. */
    private kthSmallest(mat: number[][], k: number): number {
        let low = mat[0][0], high = mat[0][mat[0].length - 1];
        for (const row of mat) {
            low = Math.min(low, row[0]);
            high = Math.max(high, row[row.length - 1]);
        }
        while (low < high) {
            const mid = Math.floor((low + high) / 2);
            let count = 0;
            for (const row of mat) count += Solution.upperBound(row, mid);
            if (count < k) low = mid + 1;
            else high = mid;
        }
        return low;
    }

    /** Number of elements in the sorted `row` that are <= `value` (binary search). */
    private static upperBound(row: number[], value: number): number {
        let lo = 0, hi = row.length;
        while (lo < hi) {
            const mid = (lo + hi) >>> 1;
            if (row[mid] <= value) lo = mid + 1;
            else hi = mid;
        }
        return lo;
    }
}

Complexity

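  • ⏰ Time complexity: O(m · log n · log(max − min)). The value-range binary search takes O(log(max − min)) steps, and each step counts elements ≤ mid with a binary search (upper_bound/bisect) in each of the m rows. Implementations that scan each row linearly instead take O(mn · log(max − min)).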
  • 🧺 Space complexity: O(1), only counters and variables used.