Count Number of Bad Pairs
Problem
You are given a 0-indexed integer array nums. A pair of indices (i, j) is a bad pair if i < j and j - i != nums[j] - nums[i].
Return the total number of bad pairs in nums.
Examples
Example 1:
Input: nums = [4,1,3,3]
Output: 5
Explanation: The pair (0, 1) is a bad pair since 1 - 0 != 1 - 4, 1 != -3.
The pair (0, 2) is a bad pair since 2 - 0 != 3 - 4, 2 != -1.
The pair (0, 3) is a bad pair since 3 - 0 != 3 - 4, 3 != -1.
The pair (1, 2) is a bad pair since 2 - 1 != 3 - 1, 1 != 2.
The pair (2, 3) is a bad pair since 3 - 2 != 3 - 3, 1 != 0.
There are a total of 5 bad pairs, so we return 5.
Example 2:
Input: nums = [1,2,3,4,5]
Output: 0
Explanation: There are no bad pairs.
Constraints:
- 1 <= nums.length <= 10^5
- 1 <= nums[i] <= 10^9
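The upper bound on nums.length affects the return type. With n = 10^5 indices there are n(n - 1)/2 pairs, and that count no longer fits in a signed 32-bit integer, which is why the Java solutions return `long`. A quick arithmetic check in Python, whose integers do not overflow:

```python
# Largest input size allowed by the constraints
n = 10**5
max_pairs = n * (n - 1) // 2

# Largest value a signed 32-bit integer (Java's int) can hold
int32_max = 2**31 - 1

print(max_pairs)              # 4999950000
print(max_pairs > int32_max)  # True
```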
Solution
Video explanation
Here is a video explaining the methods below in detail: https://www.youtube.com/embed/dFo37xIW4FU
Method 1 - Naive
The naive approach involves using two nested loops to check every possible pair (i, j) and determine if it qualifies as a bad pair. Simply iterate through the array, and for each index i, compare it with every other index j where j > i.
Code
Java
```java
class Solution {
    public long countBadPairs(int[] nums) {
        long badPairs = 0;
        int n = nums.length;
        // Check every pair (i, j) with i < j
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                // (i, j) is bad when the index gap differs from the value gap
                if (j - i != nums[j] - nums[i]) {
                    badPairs++;
                }
            }
        }
        return badPairs;
    }
}
```
Python
```python
from typing import List


class Solution:
    def countBadPairs(self, nums: List[int]) -> int:
        bad_pairs: int = 0
        n: int = len(nums)
        # Check every pair (i, j) with i < j
        for i in range(n):
            for j in range(i + 1, n):
                if j - i != nums[j] - nums[i]:
                    bad_pairs += 1
        return bad_pairs
```
Complexity
- ⏰ Time complexity: O(n^2), due to the nested loops over all pairs.
- 🧺 Space complexity: O(1), as we only need a few extra variables for counting.
Method 2 - Using Hashmap + Maths
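To solve this problem, we need to count the number of pairs (i, j) such that i < j and j - i != nums[j] - nums[i]. Rearranging the equation (subtracting nums[j] from both sides and adding i to both sides), we get:

$$j - \mathrm{nums}[j] \neq i - \mathrm{nums}[i]$$

Let's call:

$$\mathrm{key}[k] = k - \mathrm{nums}[k]$$

We observe that for a pair (i, j) to not be a "bad pair" (that is, to be a good pair), the keys must match:

$$\mathrm{key}[i] = \mathrm{key}[j]$$

So the bad pairs are exactly the pairs where this condition does not hold. Rather than counting them directly, it is easier to count the good pairs and subtract them from the total number of pairs.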
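As an illustration of the rearrangement on Example 1, the sketch below computes key[k] = k - nums[k] for every index and groups indices by key. Any two indices in the same group form a good pair, so a group of size c contributes c(c - 1)/2 good pairs. This grouped count is an equivalent alternative to the incremental counting used in the solution below, and `count_good_pairs_by_group` is a hypothetical helper written only for this illustration.

```python
from collections import Counter
from typing import List


def count_good_pairs_by_group(nums: List[int]) -> int:
    """Hypothetical helper: count good pairs by grouping indices on k - nums[k]."""
    keys = [k - value for k, value in enumerate(nums)]
    groups = Counter(keys)
    # A group of c indices with equal keys yields c * (c - 1) / 2 good pairs
    return sum(c * (c - 1) // 2 for c in groups.values())


nums = [4, 1, 3, 3]
print([k - value for k, value in enumerate(nums)])  # [-4, 0, -1, 0]
n = len(nums)
good = count_good_pairs_by_group(nums)
print(good)                     # 1, the pair (1, 3)
print(n * (n - 1) // 2 - good)  # 5 bad pairs
```

Indices 1 and 3 share key 0, so (1, 3) is the only good pair, and the remaining 6 - 1 = 5 pairs are bad.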
Approach
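- Use a hashmap to store the count of each value of $\mathrm{key}[k] = k - \mathrm{nums}[k]$ encountered so far.
- Traverse the array and, for every index j, compute $\mathrm{key}[j] = j - \mathrm{nums}[j]$. The number of previously seen indices i < j with $\mathrm{key}[i] = \mathrm{key}[j]$ is exactly the number of good pairs (i, j) that end at j; add it to a running total, then increment the count stored for $\mathrm{key}[j]$.
- The total number of pairs (i, j) with i < j is $\frac{n(n-1)}{2}$.
- Subtract the number of good pairs from the total pairs to get the number of bad pairs.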
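To make the one-pass counting concrete, the sketch below traces these steps on Example 1 and records how many good pairs each index j closes. `trace_good_pairs` is a hypothetical helper for illustration only, and it uses a plain dict in place of the hashmap.

```python
from typing import Dict, List, Tuple


def trace_good_pairs(nums: List[int]) -> List[Tuple[int, int, int]]:
    """Hypothetical helper: return (j, key, good pairs ending at j) for each index."""
    seen: Dict[int, int] = {}
    steps = []
    for j, value in enumerate(nums):
        key = j - value
        # Each earlier index with the same key forms a good pair (i, j)
        matches = seen.get(key, 0)
        steps.append((j, key, matches))
        seen[key] = matches + 1
    return steps


for j, key, matches in trace_good_pairs([4, 1, 3, 3]):
    print(f"j={j} key={key} new good pairs={matches}")
# j=0 key=-4 new good pairs=0
# j=1 key=0 new good pairs=0
# j=2 key=-1 new good pairs=0
# j=3 key=0 new good pairs=1
```

Summing the last column gives 1 good pair. With 4 * 3 / 2 = 6 total pairs, that leaves 6 - 1 = 5 bad pairs, matching Example 1.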
Code
Java
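```java
import java.util.HashMap;
import java.util.Map;

class Solution {
    public long countBadPairs(int[] nums) {
        // Maps key = index - value to how many earlier indices had that key
        Map<Integer, Integer> map = new HashMap<>();
        long goodPairs = 0;
        int n = nums.length;
        for (int j = 0; j < n; j++) {
            int key = j - nums[j];
            // Every earlier index with the same key forms a good pair with j
            int seen = map.getOrDefault(key, 0);
            goodPairs += seen;
            map.put(key, seen + 1);
        }
        // Cast before multiplying to avoid int overflow when n is large
        long totalPairs = (long) n * (n - 1) / 2;
        return totalPairs - goodPairs;
    }
}
```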
Python
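```python
from collections import defaultdict
from typing import List


class Solution:
    def countBadPairs(self, nums: List[int]) -> int:
        # Maps key = index - value to how many earlier indices had that key
        count_map: defaultdict[int, int] = defaultdict(int)
        good_pairs: int = 0
        n: int = len(nums)
        for j in range(n):
            key = j - nums[j]
            # Every earlier index with the same key forms a good pair with j
            good_pairs += count_map[key]
            count_map[key] += 1
        total_pairs: int = n * (n - 1) // 2
        return total_pairs - good_pairs
```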
Complexity
- ⏰ Time complexity: O(n), since we traverse the array once and each hashmap operation takes expected O(1) time.
- 🧺 Space complexity: O(n), since the hashmap can hold up to n distinct keys.
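One way to gain confidence in Method 2 is to compare it against the brute-force Method 1 on many small random inputs. The sketch below is a hypothetical test harness: `count_bad_pairs_naive` and `count_bad_pairs_fast` are illustrative names that re-implement the two methods above as plain functions.

```python
import random
from collections import defaultdict
from typing import List


def count_bad_pairs_naive(nums: List[int]) -> int:
    """Method 1: check every pair (i, j) with i < j."""
    n = len(nums)
    return sum(
        1
        for i in range(n)
        for j in range(i + 1, n)
        if j - i != nums[j] - nums[i]
    )


def count_bad_pairs_fast(nums: List[int]) -> int:
    """Method 2: total pairs minus pairs whose keys j - nums[j] match."""
    counts = defaultdict(int)
    good = 0
    for j, value in enumerate(nums):
        key = j - value
        good += counts[key]
        counts[key] += 1
    n = len(nums)
    return n * (n - 1) // 2 - good


random.seed(0)
for _ in range(1000):
    # Small sizes and a narrow value range make equal keys (good pairs) common
    size = random.randint(1, 8)
    nums = [random.randint(1, 6) for _ in range(size)]
    assert count_bad_pairs_naive(nums) == count_bad_pairs_fast(nums), nums

print(count_bad_pairs_fast([4, 1, 3, 3]))     # 5
print(count_bad_pairs_fast([1, 2, 3, 4, 5]))  # 0
```

The narrow value range is a deliberate choice: with values up to 10^9, random arrays would almost never contain a good pair, and the comparison would only exercise the trivial case.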