Problem

Given a 0-indexed integer array nums of size n and two integers lower and upper, return the number of fair pairs.

A pair (i, j) is fair if:

  • 0 <= i < j < n, and
  • lower <= nums[i] + nums[j] <= upper

Examples

Example 1:

Input: nums = [0,1,7,4,4,5], lower = 3, upper = 6
Output: 6
Explanation: There are 6 fair pairs: (0,3), (0,4), (0,5), (1,3), (1,4), and (1,5).

Example 2:

Input: nums = [1,7,9,2,5], lower = 11, upper = 11
Output: 1
Explanation: There is a single fair pair: (2,3).
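The definition and examples above can be checked with a direct O(n²) scan that tries every index pair and keeps the fair ones. This is a minimal sketch. The function name `fair_pairs_brute_force` is hypothetical, and the scan is far too slow for large inputs, but it is useful for cross-checking a faster solution.

```python
from typing import List, Tuple


def fair_pairs_brute_force(nums: List[int], lower: int, upper: int) -> List[Tuple[int, int]]:
    """Return every index pair (i, j) with i < j and lower <= nums[i] + nums[j] <= upper.

    Hypothetical reference helper: checks the definition directly in O(n^2) time.
    """
    n = len(nums)
    return [
        (i, j)
        for i in range(n)
        for j in range(i + 1, n)  # j > i, so each unordered pair is visited once
        if lower <= nums[i] + nums[j] <= upper
    ]
```

For the input of Example 1 it returns [(0, 3), (0, 4), (0, 5), (1, 3), (1, 4), (1, 5)], the six pairs listed there, and for Example 2 it returns [(2, 3)].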

Solution

This looks like a binary search problem, because for each element we need both a lower and an upper bound on its partner. For an element nums[i] and a partner value y, the condition lower <= nums[i] + y <= upper is equivalent to lower - nums[i] <= y <= upper - nums[i]. So for every nums[i], we look for partners whose values lie in the range [lower - nums[i], upper - nums[i]].
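To make the interval idea concrete, the sketch below computes [lower - nums[i], upper - nums[i]] for each element and counts the partners that land in it with a plain linear scan. It sorts first and looks only to the right of each position, as the approach below does. The helper names `partner_interval` and `partners_per_index` are hypothetical, and the linear scan is only for illustration; the binary searches replace it.

```python
from typing import List, Tuple


def partner_interval(x: int, lower: int, upper: int) -> Tuple[int, int]:
    # lower <= x + y <= upper  <=>  lower - x <= y <= upper - x
    return lower - x, upper - x


def partners_per_index(nums: List[int], lower: int, upper: int) -> List[int]:
    """For each position i of the sorted array, count the elements to its right
    whose values fall inside partner_interval(a[i], lower, upper).

    Hypothetical illustration helper: uses a linear scan instead of binary search.
    """
    a = sorted(nums)
    counts = []
    for i, x in enumerate(a):
        lo, hi = partner_interval(x, lower, upper)
        counts.append(sum(1 for y in a[i + 1:] if lo <= y <= hi))
    return counts
```

On Example 1 the sorted array is [0, 1, 4, 4, 5, 7] and the call returns [3, 3, 0, 0, 0, 0], which sums to 6: for 0 the interval is [3, 6] (partners 4, 4, 5), and for 1 it is [2, 5] (again 4, 4, 5).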

Here is the approach:

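  • Step 1: Sort the array.
    • Why? The constraint i < j only ensures that each unordered pair is counted once; the sum nums[i] + nums[j] does not depend on which element comes first. Sorting therefore does not change the answer, as long as every pair is still counted exactly once.
  • Step 2: For every index i, repeat Steps 3, 4, and 5, searching only the suffix nums[i + 1..n - 1] so that each pair is counted once.
  • Step 3: Binary search for the first index j in that suffix with nums[i] + nums[j] >= lower, i.e. nums[j] >= lower - nums[i] (a lower bound).
  • Step 4: Binary search for the last index k in that suffix with nums[i] + nums[k] <= upper, i.e. nums[k] <= upper - nums[i] (an upper bound).
  • Step 5: Every index from j to k is a valid partner, so add k - j + 1 to the answer (this is 0 when k < j).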
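In Python, Steps 3 to 5 can also be expressed with the standard library's `bisect` module instead of hand-written searches: `bisect_left` returns the index j from Step 3, and `bisect_right` returns one past the index k from Step 4, so their difference is k - j + 1 directly. This is a sketch of that substitution (the function name `count_fair_pairs_bisect` is hypothetical), not the code used in the Code section. It sorts a copy, so the caller's list is left unchanged.

```python
from bisect import bisect_left, bisect_right
from typing import List


def count_fair_pairs_bisect(nums: List[int], lower: int, upper: int) -> int:
    """Count fair pairs using bisect (hypothetical helper; assumes lower <= upper)."""
    a = sorted(nums)  # Step 1: sort a copy; the caller's list is not modified
    n = len(a)
    total = 0
    for i, x in enumerate(a):  # Step 2: every element, searching only a[i + 1:]
        # Step 3: first index j in [i + 1, n) with a[j] >= lower - x
        j = bisect_left(a, lower - x, i + 1, n)
        # Step 4: one past the last index k in [i + 1, n) with a[k] <= upper - x
        k_end = bisect_right(a, upper - x, i + 1, n)
        # Step 5: partners are a[j:k_end], i.e. k - j + 1 of them (0 if none)
        total += k_end - j
    return total
```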

Code

Java
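import java.util.Arrays;

class Solution {
    public long countFairPairs(int[] nums, int lower, int upper) {
        // Sorting does not change the answer; each unordered pair is still counted once
        Arrays.sort(nums);
        int n = nums.length;
        long ans = 0; // up to n * (n - 1) / 2 pairs, so use long

        for (int i = 0; i < n; i++) {
            // Search only to the right of i so that each pair is counted exactly once.
            // low: first index in [i + 1, n - 1] with nums[low] >= lower - nums[i]
            int low = lowerBound(nums, i + 1, n - 1, lower - nums[i]);
            // high: last index in [i + 1, n - 1] with nums[high] <= upper - nums[i]
            int high = upperBound(nums, i + 1, n - 1, upper - nums[i]);

            // Valid partners are nums[low..high]; this adds 0 when high < low
            ans += (high - low + 1);
        }

        return ans;
    }

    // Returns the first index in [start, end] whose value is >= target, or end + 1 if none
    private int lowerBound(int[] nums, int start, int end, int target) {
        while (start <= end) {
            int mid = start + (end - start) / 2;
            if (nums[mid] >= target) {
                end = mid - 1;
            } else {
                start = mid + 1;
            }
        }
        return start;
    }

    // Returns the last index in [start, end] whose value is <= target, or start - 1 if none
    private int upperBound(int[] nums, int start, int end, int target) {
        while (start <= end) {
            int mid = start + (end - start) / 2;
            if (nums[mid] <= target) {
                start = mid + 1;
            } else {
                end = mid - 1;
            }
        }
        return end;
    }
}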
Python
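from typing import List


class Solution:
    def countFairPairs(self, nums: List[int], lower: int, upper: int) -> int:
        # Sorting does not change the answer; each unordered pair is still counted once
        nums.sort()
        n = len(nums)
        ans = 0

        for i in range(n):
            # Search only to the right of i so that each pair is counted exactly once.
            # lb: first index in [i + 1, n - 1] with nums[lb] >= lower - nums[i]
            lb = self.lowerBound(nums, i + 1, n - 1, lower - nums[i])
            # ub: last index in [i + 1, n - 1] with nums[ub] <= upper - nums[i]
            ub = self.upperBound(nums, i + 1, n - 1, upper - nums[i])

            # Valid partners are nums[lb..ub]; this adds 0 when ub < lb
            ans += ub - lb + 1

        return ans

    def lowerBound(self, nums: List[int], start: int, end: int, target: int) -> int:
        """Return the first index in [start, end] with nums[index] >= target, or end + 1 if none."""
        while start <= end:
            mid = start + (end - start) // 2
            if nums[mid] >= target:
                end = mid - 1
            else:
                start = mid + 1
        return start

    def upperBound(self, nums: List[int], start: int, end: int, target: int) -> int:
        """Return the last index in [start, end] with nums[index] <= target, or start - 1 if none."""
        while start <= end:
            mid = start + (end - start) // 2
            if nums[mid] <= target:
                start = mid + 1
            else:
                end = mid - 1
        return end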

Complexity

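  • ⏰ Time complexity: O(n log n). Sorting takes O(n log n), and each of the n iterations runs two binary searches of O(log n) each, for another O(n log n).
  • 🧺 Space complexity: O(1) extra space beyond the sort, since the array is sorted in place and only a few variables are used. The sort itself may need auxiliary memory depending on the language (for example, Python's list.sort can use up to O(n)).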
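The binary searches are not the only way to do the counting phase. A commonly used alternative, which is not the approach described above, counts the pairs whose sum is at most some bound using two pointers on the sorted array, then takes count(upper) - count(lower - 1). Each pass is O(n), so the total time is still O(n log n) because of the sort. The function names below are hypothetical, and `lower - 1` assumes integer inputs, as in this problem.

```python
from typing import List


def count_pairs_at_most(a: List[int], bound: int) -> int:
    """Count pairs i < j in the sorted list a with a[i] + a[j] <= bound (two pointers)."""
    lo, hi = 0, len(a) - 1
    count = 0
    while lo < hi:
        if a[lo] + a[hi] <= bound:
            # a is sorted, so a[lo] + a[m] <= bound for every m in (lo, hi]
            count += hi - lo
            lo += 1
        else:
            # a[lo] is the smallest remaining value, so a[hi] cannot pair with anything left
            hi -= 1
    return count


def count_fair_pairs_two_pointers(nums: List[int], lower: int, upper: int) -> int:
    a = sorted(nums)  # O(n log n); each counting pass below is O(n)
    # pairs with sum in [lower, upper] = (sum <= upper) - (sum <= lower - 1)
    return count_pairs_at_most(a, upper) - count_pairs_at_most(a, lower - 1)
```

For Example 1 (sorted [0, 1, 4, 4, 5, 7]) there are 7 pairs with sum at most 6 and 1 pair with sum at most 2, giving 7 - 1 = 6.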