Problem

Let A be an N by M matrix in which every row and every column is sorted.

Given i1, j1, i2, and j2 (0-based indices), compute the number of elements of A that are smaller than A[i1][j1] or larger than A[i2][j2].
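A small sketch makes the precondition concrete. The hypothetical helper `is_sorted_rows_and_cols` checks that every row and every column is sorted. It assumes a rectangular matrix and reads "sorted" as non-decreasing (ascending), which matches the example below.

```python
from typing import List


def is_sorted_rows_and_cols(matrix: List[List[int]]) -> bool:
    """Hypothetical helper: True if every row and every column is non-decreasing."""
    for r, row in enumerate(matrix):
        for c, value in enumerate(row):
            # Row order: compare with the right-hand neighbour.
            if c + 1 < len(row) and value > row[c + 1]:
                return False
            # Column order: compare with the neighbour directly below.
            if r + 1 < len(matrix) and value > matrix[r + 1][c]:
                return False
    return True


print(is_sorted_rows_and_cols([[1, 3], [2, 6]]))  # True
print(is_sorted_rows_and_cols([[1, 3], [0, 6]]))  # False (column 0 decreases)
```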

Examples

Example 1

$$ \begin{array}{|c|c|c|c|c|c|} \hline 1 & 3 & 7 & 10 & 15 & 20 \\ \hline 2 & \colorbox{orange}{6} & 9 & 14 & 22 & 25 \\ \hline 3 & 8 & 10 & 15 & 25 & 30 \\ \hline 10 & 11 & 12 & \colorbox{green}{23} & 30 & 35 \\ \hline 20 & 25 & 30 & 35 & 40 & 45 \\ \hline \end{array} $$

Input: 
A = [[1, 3, 7, 10, 15, 20],
    [2, 6, 9, 14, 22, 25],
    [3, 8, 10, 15, 25, 30],
    [10, 11, 12, 23, 30, 35],
    [20, 25, 30, 35, 40, 45]]
i1 = 1, j1 = 1, i2 = 3, j2 = 3
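Output: 14
Explanation: There are 14 numbers in the matrix smaller than 6 or greater than 23. Four are smaller than 6 (1 and 3 in row 0, 2 in row 1, 3 in row 2), and ten are greater than 23 (25 in row 1; 25 and 30 in row 2; 30 and 35 in row 3; 25, 30, 35, 40, and 45 in row 4).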
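The count for this example can be checked with a brute-force scan that visits every element once, in O(N * M) time. This is only a sketch for verifying the example, not the intended solution.

```python
# Brute-force check of Example 1: scan every element once.
A = [[1, 3, 7, 10, 15, 20],
     [2, 6, 9, 14, 22, 25],
     [3, 8, 10, 15, 25, 30],
     [10, 11, 12, 23, 30, 35],
     [20, 25, 30, 35, 40, 45]]

low, high = A[1][1], A[3][3]  # 6 and 23 (0-based indices)

smaller = [x for row in A for x in row if x < low]
larger = [x for row in A for x in row if x > high]

print(smaller)                      # [1, 3, 2, 3]
print(larger)                       # [25, 25, 30, 30, 35, 25, 30, 35, 40, 45]
print(len(smaller) + len(larger))   # 14
```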

Solution

The matrix is sorted in both rows and columns, but this solution only needs the row ordering. Within a sorted row, the elements smaller than A[i1][j1] form a prefix and the elements larger than A[i2][j2] form a suffix, and binary search locates the boundary of either one in O(log M) time. Repeating this for every row counts the qualifying elements without scanning the whole matrix.
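The per-row boundary search is also available in Python's standard `bisect` module. The sketch below is an alternative to the hand-written loops in the Code section: `bisect.bisect_left` gives the length of the "smaller than" prefix, and `bisect.bisect_right` gives the start of the "larger than" suffix. The function names `count_smaller` and `count_larger` are illustrative only.

```python
import bisect
from typing import List


def count_smaller(matrix: List[List[int]], target: int) -> int:
    # In a sorted row, bisect_left returns the first index whose value is >= target,
    # which is the length of the prefix of elements strictly smaller than target.
    return sum(bisect.bisect_left(row, target) for row in matrix)


def count_larger(matrix: List[List[int]], target: int) -> int:
    # bisect_right returns the first index whose value is > target, so the suffix
    # from that index to the end holds the elements strictly larger than target.
    return sum(len(row) - bisect.bisect_right(row, target) for row in matrix)


A = [[1, 3, 7, 10, 15, 20],
     [2, 6, 9, 14, 22, 25],
     [3, 8, 10, 15, 25, 30],
     [10, 11, 12, 23, 30, 35],
     [20, 25, 30, 35, 40, 45]]

print(count_smaller(A, 6))   # 4
print(count_larger(A, 23))   # 10
```

The two functions differ only in how they treat ties: a value equal to the target is excluded from both counts, which matches the strict comparisons in the problem.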

Approach

  1. Count Elements Smaller than A[i1][j1]: For each row, use binary search to find the count of elements smaller than A[i1][j1].
  2. Count Elements Larger than A[i2][j2]: For each row, use binary search to find the count of elements larger than A[i2][j2].
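  3. Sum the Results: Add the counts from both steps to get the number of elements that are either smaller than A[i1][j1] or larger than A[i2][j2]. This is valid when A[i1][j1] ≤ A[i2][j2] (as in the example), because then no element can satisfy both conditions. If A[i1][j1] > A[i2][j2], every element satisfies at least one condition, so the answer is simply N * M.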
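Steps 1 and 2 can also be carried out without binary search. The sketch below swaps in a different technique, a "staircase" walk, which additionally uses the column ordering. Because each column is sorted, the qualifying prefix of a row can only get shorter from one row to the next, so a single column pointer that starts at the top-right corner and only moves left handles all rows in O(N + M) time. The helper names are hypothetical.

```python
from typing import Callable, List


def _staircase_prefix_count(matrix: List[List[int]], keep: Callable[[int], bool]) -> int:
    """Sum, over all rows, the length of the row prefix whose values satisfy `keep`.

    Assumes the matrix is sorted along rows and columns and that `keep` is monotone
    (if keep(x) holds, keep(y) holds for every y <= x). Then each row's prefix is no
    longer than the previous row's, so the column pointer never has to move right.
    """
    if not matrix:
        return 0
    col = len(matrix[0]) - 1  # start at the top-right corner
    total = 0
    for row in matrix:
        # Step left past values that fall outside this row's prefix.
        while col >= 0 and not keep(row[col]):
            col -= 1
        total += col + 1  # row[0..col] all satisfy keep
    return total


def count_smaller_staircase(matrix: List[List[int]], target: int) -> int:
    return _staircase_prefix_count(matrix, lambda x: x < target)


def count_larger_staircase(matrix: List[List[int]], target: int) -> int:
    # Elements > target are everything outside the "<= target" prefixes.
    size = sum(len(row) for row in matrix)
    return size - _staircase_prefix_count(matrix, lambda x: x <= target)


A = [[1, 3, 7, 10, 15, 20],
     [2, 6, 9, 14, 22, 25],
     [3, 8, 10, 15, 25, 30],
     [10, 11, 12, 23, 30, 35],
     [20, 25, 30, 35, 40, 45]]

print(count_smaller_staircase(A, 6))    # 4
print(count_larger_staircase(A, 23))    # 10
```

Whether this beats O(N log M) depends on the shape of the matrix: the walk tends to win when N and M are of similar size, while per-row binary search is better for a few very long rows.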

Code

Java
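public class Solution {
    public int countElements(int[][] matrix, int i1, int j1, int i2, int j2) {
        int low = matrix[i1][j1];
        int high = matrix[i2][j2];

        // If low > high, every element is either < low or > high, so adding the
        // two counts would double-count; the answer is simply every element.
        if (low > high) {
            return matrix.length * matrix[0].length;
        }

        int numLessThan = countLessThan(matrix, low);
        int numGreaterThan = countGreaterThan(matrix, high);
        return numLessThan + numGreaterThan;
    }

    // Counts elements strictly smaller than target. In each sorted row, the
    // binary search ends with `left` at the first index whose value is >= target,
    // which equals the number of elements < target in that row.
    private int countLessThan(int[][] matrix, int target) {
        int count = 0;
        for (int[] row : matrix) {
            int left = 0, right = row.length - 1;
            while (left <= right) {
                int mid = left + (right - left) / 2;  // avoids int overflow
                if (row[mid] < target) {
                    left = mid + 1;
                } else {
                    right = mid - 1;
                }
            }
            count += left;
        }
        return count;
    }

    // Counts elements strictly greater than target. The search ends with `left`
    // at the first index whose value is > target; the rest of the row qualifies.
    private int countGreaterThan(int[][] matrix, int target) {
        int count = 0;
        for (int[] row : matrix) {
            int left = 0, right = row.length - 1;
            while (left <= right) {
                int mid = left + (right - left) / 2;
                if (row[mid] > target) {
                    right = mid - 1;
                } else {
                    left = mid + 1;
                }
            }
            count += row.length - left;
        }
        return count;
    }

    // Example usage:
    public static void main(String[] args) {
        Solution solution = new Solution();
        int[][] matrix = {
            {1, 3, 7, 10, 15, 20},
            {2, 6, 9, 14, 22, 25},
            {3, 8, 10, 15, 25, 30},
            {10, 11, 12, 23, 30, 35},
            {20, 25, 30, 35, 40, 45}
        };
        System.out.println(solution.countElements(matrix, 1, 1, 3, 3));  // Output: 14
    }
}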
Python
from typing import List


class Solution:
    def countElements(self, matrix: List[List[int]], i1: int, j1: int, i2: int, j2: int) -> int:
        low = matrix[i1][j1]
        high = matrix[i2][j2]

        # If low > high, every element is either < low or > high, so adding the
        # two counts would double-count; the answer is simply every element.
        if low > high:
            return len(matrix) * len(matrix[0])

        def count_less_than(target: int) -> int:
            # Per row, the search ends with `left` at the first index whose value
            # is >= target, i.e. the number of elements < target in that row.
            count = 0
            for row in matrix:
                left, right = 0, len(row) - 1
                while left <= right:
                    mid = (left + right) // 2
                    if row[mid] < target:
                        left = mid + 1
                    else:
                        right = mid - 1
                count += left
            return count

        def count_greater_than(target: int) -> int:
            # Per row, the search ends with `left` at the first index whose value
            # is > target; every element from there to the end of the row counts.
            count = 0
            for row in matrix:
                left, right = 0, len(row) - 1
                while left <= right:
                    mid = (left + right) // 2
                    if row[mid] > target:
                        right = mid - 1
                    else:
                        left = mid + 1
                count += len(row) - left
            return count

        return count_less_than(low) + count_greater_than(high)


# Example usage:
solution = Solution()
matrix = [[1, 3, 7, 10, 15, 20],
          [2, 6, 9, 14, 22, 25],
          [3, 8, 10, 15, 25, 30],
          [10, 11, 12, 23, 30, 35],
          [20, 25, 30, 35, 40, 45]]
print(solution.countElements(matrix, 1, 1, 3, 3))  # Output: 14

Complexity

  • ⏰ Time complexity: O(N * log M), where N is the number of rows and M is the number of columns. Each of the two counting passes runs one O(log M) binary search per row.
  • 🧺 Space complexity: O(1), since only a few counters and indices are used regardless of the matrix size.
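To make the time bound concrete: on a row of length M, each iteration of the `while left <= right` loop shrinks the candidate range from s elements to at most ⌊s/2⌋, so the loop runs at most ⌊log₂ M⌋ + 1 times. With two counting passes over N rows, the total number of loop iterations is bounded by

$$ 2 \cdot N \cdot \left( \lfloor \log_2 M \rfloor + 1 \right) = O(N \log M). $$

For the 5 × 6 matrix in Example 1 this gives at most 2 · 5 · (2 + 1) = 30 iterations.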