Problem

You are given a 0-indexed positive integer array nums and a positive integer k.

A pair of numbers (num1, num2) is called excellent if the following conditions are satisfied:

  • Both the numbers num1 and num2 exist in the array nums.
  • The sum of the number of set bits in num1 OR num2 and num1 AND num2 is greater than or equal to k, where OR is the bitwise OR operation and AND is the bitwise AND operation.

Return the number of distinct excellent pairs.

Two pairs (a, b) and (c, d) are considered distinct if either a != c or b != d. For example, (1, 2) and (2, 1) are distinct.

Note that a pair (num1, num2) such that num1 == num2 can also be excellent if you have at least one occurrence of num1 in the array.
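
As a reference point, the definition above can be checked directly by brute force. The following is a minimal sketch, not an efficient solution: the function name `count_excellent_pairs_bruteforce` is a hypothetical helper introduced here, and it assumes the inputs are small enough that trying every ordered pair of distinct values is acceptable.

```python
from itertools import product


def count_excellent_pairs_bruteforce(nums, k):
    """Count ordered pairs (a, b) of values from nums that meet the definition."""
    values = set(nums)  # repeated values in nums do not create new pairs
    total = 0
    # product(..., repeat=2) yields ordered pairs, including a == b.
    for a, b in product(values, repeat=2):
        if bin(a | b).count("1") + bin(a & b).count("1") >= k:
            total += 1
    return total


print(count_excellent_pairs_bruteforce([1, 2, 3, 1], 3))  # 5
print(count_excellent_pairs_bruteforce([5, 1, 1], 10))    # 0
```

The two calls use the inputs from Example 1 and Example 2 and reproduce their expected outputs.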

Examples

Example 1:

Input: nums = [1,2,3,1], k = 3
Output: 5
Explanation: The excellent pairs are the following:
- (3, 3). (3 AND 3) and (3 OR 3) are both equal to (11) in binary. The total number of set bits is 2 + 2 = 4, which is greater than or equal to k = 3.
- (2, 3) and (3, 2). (2 AND 3) is equal to (10) in binary, and (2 OR 3) is equal to (11) in binary. The total number of set bits is 1 + 2 = 3.
- (1, 3) and (3, 1). (1 AND 3) is equal to (01) in binary, and (1 OR 3) is equal to (11) in binary. The total number of set bits is 1 + 2 = 3.
So the number of excellent pairs is 5.

Example 2:

Input: nums = [5,1,1], k = 10
Output: 0
Explanation: There are no excellent pairs for this array.

Solution

Intuition

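The key idea is to reduce the problem to counting set bits. For any pair (num1, num2), the number of set bits in num1 | num2 plus the number of set bits in num1 & num2 equals the number of set bits in num1 plus the number of set bits in num2. To see why, consider each bit position separately: a bit set in both numbers is counted once in the OR and once in the AND (2 in total), a bit set in exactly one number is counted once in the OR only (1 in total), and a bit set in neither is never counted. In every case the contribution matches what the two numbers contribute on their own. Since pairs are formed from distinct values, we can count the set bits of each unique number, sort the counts, and use binary search to count the pairs whose counts sum to at least k.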
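
The identity can be sanity-checked exhaustively on a small range. This is only an illustrative sketch: `popcount` and `identity_holds` are names introduced here, and the range of 0 to 255 is an arbitrary choice that covers all 8-bit values.

```python
def popcount(x):
    """Number of set bits in a non-negative integer."""
    return bin(x).count("1")


def identity_holds(a, b):
    """Check popcount(a | b) + popcount(a & b) == popcount(a) + popcount(b)."""
    return popcount(a | b) + popcount(a & b) == popcount(a) + popcount(b)


# Every pair of 8-bit values satisfies the identity.
assert all(identity_holds(a, b) for a in range(256) for b in range(256))
```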

Approach

  1. Remove duplicates from nums to get unique numbers.
  2. For each unique number, count its set bits and store in a list.
  3. Sort the list of set bit counts.
  4. For each count c, use binary search to find the first position whose count is at least k - c; every count from that position to the end of the list forms an excellent pair with c.
  5. Sum up the results for all counts to get the total number of excellent pairs.
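
The steps above can be traced on Example 1 to see the intermediate values. This is a sketch for illustration only: `trace_steps` is a hypothetical helper that returns each intermediate result, and it uses Python's `bisect_left` as the binary search.

```python
from bisect import bisect_left


def trace_steps(nums, k):
    """Return the intermediate values of the approach for inspection."""
    unique = sorted(set(nums))                            # step 1
    counts = sorted(bin(x).count("1") for x in unique)    # steps 2 and 3
    # Step 4: for each count c, how many counts are >= k - c.
    per_count = [len(counts) - bisect_left(counts, k - c) for c in counts]
    return unique, counts, per_count, sum(per_count)      # step 5


print(trace_steps([1, 2, 3, 1], 3))
# ([1, 2, 3], [1, 1, 2], [1, 1, 3], 5)
```

Here the two values with one set bit (1 and 2) each pair only with 3, while 3 pairs with all three values, giving 1 + 1 + 3 = 5.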

Complexity

  • ⏰ Time complexity: O(N + u log u), where N is the length of nums and u is the number of unique values. Deduplication takes O(N); sorting the counts and running one binary search per count takes O(u log u).
  • 🧺 Space complexity: O(u), for storing the unique numbers and their bit counts.

Code

C++

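#include <algorithm>
#include <unordered_set>
#include <vector>
using namespace std;

class Solution {
public:
    // Count the set bits of a non-negative integer.
    int countBits(int x) {
        int cnt = 0;
        while (x) {
            cnt += x & 1;
            x >>= 1;
        }
        return cnt;
    }
    long long excellentPairs(vector<int>& nums, int k) {
        // Deduplicate: repeated values do not create new pairs.
        unordered_set<int> s(nums.begin(), nums.end());
        vector<int> bits;
        for (int x : s) bits.push_back(countBits(x));
        sort(bits.begin(), bits.end());
        long long ans = 0;
        int n = bits.size();
        for (int i = 0; i < n; ++i) {
            int need = k - bits[i];
            // First index whose count is >= need; all later counts also qualify.
            int idx = lower_bound(bits.begin(), bits.end(), need) - bits.begin();
            ans += n - idx;
        }
        return ans;
    }
};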

Java

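import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

class Solution {
    public long excellentPairs(int[] nums, int k) {
        // Deduplicate: repeated values do not create new pairs.
        Set<Integer> set = new HashSet<>();
        for (int num : nums) set.add(num);
        int[] bits = new int[set.size()];
        int idx = 0;
        for (int num : set) bits[idx++] = Integer.bitCount(num);
        Arrays.sort(bits);
        long ans = 0;
        int n = bits.length;
        for (int i = 0; i < n; i++) {
            int need = k - bits[i];
            // Find the first index whose count is >= need. Arrays.binarySearch
            // may return any matching index when counts repeat, which would
            // undercount, so the lower bound is searched for manually.
            int lo = 0, hi = n;
            while (lo < hi) {
                int mid = (lo + hi) >>> 1;
                if (bits[mid] < need) lo = mid + 1;
                else hi = mid;
            }
            ans += n - lo;
        }
        return ans;
    }
}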

Python

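from bisect import bisect_left


class Solution:
    def excellentPairs(self, nums: list[int], k: int) -> int:
        # Deduplicate: repeated values do not create new pairs.
        s = set(nums)
        # Sorted set-bit counts of the distinct values.
        bits = sorted(bin(x).count('1') for x in s)
        n = len(bits)
        ans = 0
        for b in bits:
            # First index whose count is >= k - b; all later counts also qualify.
            idx = bisect_left(bits, k - b)
            ans += n - idx
        return ans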