Problem

Given an m x n matrix grid containing an odd number of integers, where each row is sorted in non-decreasing order, return the median of the matrix.

You must solve the problem in less than O(m * n) time complexity.

Examples

Example 1:

1
2
3
Input: grid = [[1,1,2],[2,3,3],[1,3,4]]
Output: 2
Explanation: The elements of the matrix in sorted order are 1, 1, 1, 2, **2**, 3, 3, 3, 4. The median is 2.

Example 2:

1
2
3
Input: grid = [[1,1,3,3,4]]
Output: 3
Explanation: The elements of the matrix in sorted order are 1, 1, **3**, 3, 4. The median is 3.

Constraints:

  • m == grid.length
  • n == grid[i].length
  • 1 <= m, n <= 500
  • m and n are both odd.
  • 1 <= grid[i][j] <= 10^6
  • grid[i] is sorted in non-decreasing order.

Solution

Method 1 – Binary Search on Value

Intuition

Since each row is sorted, we can binary search over the range of possible values instead of over positions. For a candidate value x, the number of matrix elements <= x can be counted with one binary search (upper bound) per row. This count never decreases as x grows, and the median is the smallest value x for which more than ⌊m·n / 2⌋ elements are <= x.

Approach

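  1. Set low to the minimum value in the matrix (the smallest first element of any row) and high to the maximum value (the largest last element of any row).
  2. Let k = ⌊m·n / 2⌋, the 0-based index of the median in sorted order. While low < high, do:
    • Set mid = ⌊(low + high) / 2⌋.
    • For each row, count the elements <= mid with an upper-bound binary search, and sum these counts.
    • If the total count is <= k, the median must be larger than mid, so set low = mid + 1; otherwise set high = mid.
  3. Return low as the median.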

Code

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
class Solution {
public:
    // Median of a matrix whose rows are each sorted in non-decreasing order
    // and whose element count is odd.
    //
    // Searches over values: the answer is the smallest value `v` for which
    // more than rows*cols/2 elements are <= v. Every probe sums, across all
    // rows, the upper-bound position of the probe value.
    int matrixMedian(vector<vector<int>>& grid) {
        const int rows = grid.size();
        const int cols = grid[0].size();

        // Matrix extremes: smallest row head and largest row tail.
        int lo = grid[0][0];
        int hi = grid[0][cols - 1];
        for (const auto& row : grid) {
            lo = min(lo, row[0]);
            hi = max(hi, row[cols - 1]);
        }

        const int half = rows * cols / 2;  // 0-based index of the median
        while (lo < hi) {
            const int guess = (lo + hi) / 2;
            int atMost = 0;
            for (const auto& row : grid) {
                atMost += static_cast<int>(upper_bound(row.begin(), row.end(), guess) - row.begin());
            }
            if (atMost > half) {
                hi = guess;      // guess already covers the median position
            } else {
                lo = guess + 1;  // median lies strictly above guess
            }
        }
        return lo;
    }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
// matrixMedian returns the median of grid, a matrix whose rows are each
// sorted in non-decreasing order and whose element count is odd.
//
// It binary-searches over values: the answer is the smallest v such that
// more than rows*cols/2 elements are <= v. Each probe counts, per row, the
// elements <= v with an upper-bound binary search.
func matrixMedian(grid [][]int) int {
    rows, cols := len(grid), len(grid[0])

    // upperBound reports how many leading entries of row are <= target.
    upperBound := func(row []int, target int) int {
        left, right := 0, cols
        for left < right {
            pivot := (left + right) / 2
            if row[pivot] > target {
                right = pivot
            } else {
                left = pivot + 1
            }
        }
        return left
    }

    // Matrix extremes: smallest row head and largest row tail.
    lo, hi := grid[0][0], grid[0][cols-1]
    for _, row := range grid[1:] {
        if row[0] < lo {
            lo = row[0]
        }
        if row[cols-1] > hi {
            hi = row[cols-1]
        }
    }

    half := rows * cols / 2 // 0-based index of the median
    for lo < hi {
        guess := (lo + hi) / 2
        seen := 0
        for _, row := range grid {
            seen += upperBound(row, guess)
        }
        if seen > half {
            hi = guess
        } else {
            lo = guess + 1
        }
    }
    return lo
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class Solution {
    /**
     * Returns the median of a matrix whose rows are each sorted in
     * non-decreasing order and whose total element count is odd.
     *
     * <p>Binary-searches over values: the answer is the smallest value
     * {@code v} such that more than {@code rows * cols / 2} elements are
     * {@code <= v}.
     *
     * @param grid matrix with sorted rows and an odd number of elements
     * @return the median element
     */
    public int matrixMedian(int[][] grid) {
        int rows = grid.length;
        int cols = grid[0].length;

        // Matrix extremes: smallest row head and largest row tail.
        int lo = grid[0][0];
        int hi = grid[0][cols - 1];
        for (int[] row : grid) {
            lo = Math.min(lo, row[0]);
            hi = Math.max(hi, row[cols - 1]);
        }

        int half = rows * cols / 2; // 0-based index of the median
        while (lo < hi) {
            int guess = (lo + hi) / 2;
            int atMost = 0;
            for (int[] row : grid) {
                atMost += countNotAbove(row, cols, guess);
            }
            if (atMost > half) {
                hi = guess;
            } else {
                lo = guess + 1;
            }
        }
        return lo;
    }

    /** Upper bound: how many of the first {@code len} sorted entries of {@code row} are {@code <= target}. */
    private static int countNotAbove(int[] row, int len, int target) {
        int left = 0;
        int right = len;
        while (left < right) {
            int pivot = (left + right) >>> 1;
            if (row[pivot] > target) {
                right = pivot;
            } else {
                left = pivot + 1;
            }
        }
        return left;
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class Solution {
    /**
     * Returns the median of [grid], whose rows are each sorted in
     * non-decreasing order and whose total element count is odd.
     *
     * Binary-searches over values: the answer is the smallest value `v`
     * such that more than `rows * width / 2` elements are `<= v`.
     */
    fun matrixMedian(grid: Array<IntArray>): Int {
        val width = grid[0].size
        // Matrix extremes: smallest row head and largest row tail.
        var lo = grid.minOf { it[0] }
        var hi = grid.maxOf { it[width - 1] }
        val half = grid.size * width / 2 // 0-based index of the median
        while (lo < hi) {
            val guess = (lo + hi) / 2
            val atMost = grid.sumOf { row -> countNotAbove(row, width, guess) }
            if (atMost > half) hi = guess else lo = guess + 1
        }
        return lo
    }

    /** Upper bound: how many of the first [len] sorted entries of [row] are `<= target`. */
    private fun countNotAbove(row: IntArray, len: Int, target: Int): Int {
        var left = 0
        var right = len
        while (left < right) {
            val pivot = (left + right) / 2
            if (row[pivot] > target) right = pivot else left = pivot + 1
        }
        return left
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
def matrix_median(grid: list[list[int]]) -> int:
    """Return the median of a matrix whose rows are sorted, in sub-O(m*n) time.

    Binary-searches over values rather than positions: the median is the
    smallest value ``v`` for which more than ``total // 2`` elements are
    ``<= v``. Each probe counts, per row, the elements ``<= v`` with
    ``bisect.bisect_right``.

    Args:
        grid: Non-empty matrix whose rows are each sorted in non-decreasing
            order and whose total number of elements is odd.

    Returns:
        The median element.

    Raises:
        ValueError: If ``grid`` is empty or contains an empty row.
    """
    from bisect import bisect_right

    if not grid or not all(grid):
        raise ValueError("grid must be non-empty and contain only non-empty rows")

    # Matrix extremes: smallest row head and largest row tail.
    low = min(row[0] for row in grid)
    high = max(row[-1] for row in grid)
    # 0-based index of the median in sorted order (element count is odd).
    k = sum(len(row) for row in grid) // 2
    while low < high:
        mid = (low + high) // 2
        # bisect_right gives how many entries of a sorted row are <= mid.
        count = sum(bisect_right(row, mid) for row in grid)
        if count <= k:
            low = mid + 1  # median is strictly greater than mid
        else:
            high = mid  # mid already covers the median position
    return low
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
impl Solution {
    /// Returns the median of `grid`, a matrix whose rows are each sorted in
    /// non-decreasing order and whose total element count is odd.
    ///
    /// Binary-searches over values rather than positions: the median is the
    /// smallest `v` such that more than `total / 2` elements are `<= v`.
    /// Each probe counts, per row, the elements `<= v` with
    /// `slice::partition_point` (an upper-bound search).
    ///
    /// # Panics
    ///
    /// Panics if `grid` is empty or if any row is empty.
    pub fn matrix_median(grid: Vec<Vec<i32>>) -> i32 {
        // Matrix extremes: smallest row head and largest row tail.
        let mut low = grid
            .iter()
            .map(|row| *row.first().expect("every row must be non-empty"))
            .min()
            .expect("grid must be non-empty");
        let mut high = grid
            .iter()
            .map(|row| *row.last().expect("every row must be non-empty"))
            .max()
            .expect("grid must be non-empty");
        // 0-based index of the median in sorted order.
        let k = grid.iter().map(Vec::len).sum::<usize>() / 2;
        while low < high {
            // Floor midpoint computed in i64: cannot overflow, and since
            // low < high it is always strictly below `high`, so the loop
            // makes progress even for negative values.
            let mid = ((i64::from(low) + i64::from(high)) >> 1) as i32;
            let cnt: usize = grid
                .iter()
                .map(|row| row.partition_point(|&x| x <= mid))
                .sum();
            if cnt <= k {
                low = mid + 1; // median is strictly greater than mid
            } else {
                high = mid; // mid already covers the median position
            }
        }
        low
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class Solution {
    /**
     * Returns the median of a matrix whose rows are each sorted in
     * non-decreasing order and whose total element count is odd.
     *
     * Binary-searches over values rather than positions: the median is the
     * smallest value `v` for which more than `floor(total / 2)` elements are
     * `<= v`. Each probe uses one upper-bound binary search per row.
     *
     * @param grid - Non-empty matrix; every row non-empty and sorted ascending.
     * @returns The median element.
     * @throws RangeError if `grid` is empty or contains an empty row.
     */
    matrixMedian(grid: readonly (readonly number[])[]): number {
        // Matrix extremes (smallest row head, largest row tail) and total size.
        let low = Infinity;
        let high = -Infinity;
        let total = 0;
        for (const row of grid) {
            const first = row[0];
            const last = row[row.length - 1];
            if (first === undefined || last === undefined) {
                throw new RangeError('matrixMedian: every row must be non-empty');
            }
            low = Math.min(low, first);
            high = Math.max(high, last);
            total += row.length;
        }
        if (total === 0) {
            throw new RangeError('matrixMedian: grid must be non-empty');
        }

        // 0-based index of the median in sorted order (total is odd).
        const k = Math.floor(total / 2);
        while (low < high) {
            const mid = Math.floor((low + high) / 2);
            let cnt = 0;
            for (const row of grid) {
                cnt += Solution.countAtMost(row, mid);
            }
            if (cnt <= k) low = mid + 1; // median is strictly greater than mid
            else high = mid; // mid already covers the median position
        }
        return low;
    }

    /** Upper bound: number of entries of the sorted `row` that are `<= x`. */
    private static countAtMost(row: readonly number[], x: number): number {
        let lo = 0;
        let hi = row.length;
        while (lo < hi) {
            const md = (lo + hi) >>> 1;
            // md < row.length, so the fallback is never taken; it only
            // satisfies noUncheckedIndexedAccess without a non-null assertion.
            const value = row[md] ?? Infinity;
            if (value <= x) lo = md + 1;
            else hi = md;
        }
        return lo;
    }
}

Complexity

  • ⏰ Time complexity: O(m * log n * log(max - min)). The binary search over the value range runs O(log(max - min)) iterations. Each iteration does one O(log n) binary search per row.
  • 🧺 Space complexity: O(1), only variables for search and counting.