Problem
You are given a 0-indexed integer array `nums`. A pair of indices `(i, j)` is a bad pair if `i < j` and `j - i != nums[j] - nums[i]`.

Return the total number of bad pairs in `nums`.
Examples
Example 1:
Input: nums = [4,1,3,3]
Output: 5
Explanation: The pair (0, 1) is a bad pair since 1 - 0 != 1 - 4, 1 != -3.
The pair (0, 2) is a bad pair since 2 - 0 != 3 - 4, 2 != -1.
The pair (0, 3) is a bad pair since 3 - 0 != 3 - 4, 3 != -1.
The pair (1, 2) is a bad pair since 2 - 1 != 3 - 1, 1 != 2.
The pair (2, 3) is a bad pair since 3 - 2 != 3 - 3, 1 != 0.
There are a total of 5 bad pairs, so we return 5.
Example 2:
Input: nums = [1,2,3,4,5]
Output: 0
Explanation: There are no bad pairs.
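The Python sketch below makes the definition concrete by listing the bad pairs of Example 1 directly from the definition. The helper name `list_bad_pairs` is our own and is not part of the problem statement.

```python
from itertools import combinations


def list_bad_pairs(nums):
    """Return every (i, j) with i < j and j - i != nums[j] - nums[i]."""
    # combinations(range(n), 2) yields each index pair with i < j exactly once.
    return [
        (i, j)
        for i, j in combinations(range(len(nums)), 2)
        if j - i != nums[j] - nums[i]
    ]


print(list_bad_pairs([4, 1, 3, 3]))  # [(0, 1), (0, 2), (0, 3), (1, 2), (2, 3)]
```

The only pair missing from that list is (1, 3), where 3 - 1 == 3 - 1. For `nums = [1, 2, 3, 4, 5]` the helper returns an empty list, which matches Example 2.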
Constraints:
1 <= nums.length <= 10^5
1 <= nums[i] <= 10^9
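These bounds affect which integer types the solutions can use. The quick arithmetic check below assumes a 32-bit signed `int`, as in Java. It shows that the number of index pairs at the largest `n` does not fit in 32 bits, while every value of `j - nums[j]` does.

```python
# Assumes Java's 32-bit signed int range.
INT32_MIN, INT32_MAX = -2**31, 2**31 - 1

n_max, v_max = 10**5, 10**9

# Number of index pairs (i, j) with i < j when n is at its upper bound.
total_pairs = n_max * (n_max - 1) // 2
print(total_pairs, total_pairs > INT32_MAX)  # 4999950000 True

# Extreme values of j - nums[j], given 0 <= j < n and 1 <= nums[j] <= 10^9.
min_key = 0 - v_max
max_key = (n_max - 1) - 1
print(min_key, max_key, INT32_MIN <= min_key and max_key <= INT32_MAX)  # -1000000000 99998 True
```

The same figure, roughly 5 × 10^9, is also the number of pair comparisons a quadratic scan would make at that size.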
Solution
Method 1 - Naive
The naive approach uses two nested loops to check every pair `(i, j)` and decide whether it is a bad pair. For each index `i`, compare it with every index `j` where `j > i`.
Code
Java
```java
class Solution {
    public long countBadPairs(int[] nums) {
        long badPairs = 0;
        int n = nums.length;
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                if (j - i != nums[j] - nums[i]) {
                    badPairs++;
                }
            }
        }
        return badPairs;
    }
}
```
Python
```python
from typing import List

class Solution:
    def countBadPairs(self, nums: List[int]) -> int:
        bad_pairs: int = 0
        n: int = len(nums)
        # Check every pair (i, j) with i < j against the definition.
        for i in range(n):
            for j in range(i + 1, n):
                if j - i != nums[j] - nums[i]:
                    bad_pairs += 1
        return bad_pairs
```
Complexity
- ⏰ Time complexity: O(n^2), due to the nested loops.
- 🧺 Space complexity: O(1), since only a few extra variables are needed for counting.
Method 2 - Using Hashmap + Maths
We need to count the pairs (i, j) with i < j and j - i != nums[j] - nums[i]. It is easier to work with the complementary *good* pairs, where equality holds. That condition rearranges as follows:

$$ j - i = nums[j] - nums[i] \iff j - nums[j] = i - nums[i] $$

Define $k_i = i - nums[i]$ for every index, so that $k_j = j - nums[j]$. A pair (i, j) is then a good (not bad) pair exactly when

$$
k_i = k_j
$$

and a bad pair exactly when $k_i \neq k_j$. Instead of counting bad pairs directly, we count the good pairs and subtract them from the total number of pairs.
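The short Python sketch below checks this rearrangement on Example 1. It computes every $k_i = i - nums[i]$ and confirms, pair by pair, that the original good-pair condition agrees with $k_i = k_j$. The function name `check_key_equivalence` is hypothetical and only used for illustration.

```python
from itertools import combinations


def check_key_equivalence(nums):
    """Compute k_i = i - nums[i] and verify: good pair <=> equal keys."""
    keys = [i - x for i, x in enumerate(nums)]
    for i, j in combinations(range(len(nums)), 2):
        is_good = j - i == nums[j] - nums[i]
        # The rearranged condition must agree with the original one.
        assert is_good == (keys[i] == keys[j])
    return keys


print(check_key_equivalence([4, 1, 3, 3]))  # [-4, 0, -1, 0]
```

Only indices 1 and 3 share a key (0). Example 1 therefore has exactly one good pair, and 4 · 3 / 2 - 1 = 5 bad pairs.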
Approach
- Use a hashmap to record how many times each value of $k_i$ has been seen so far.
- Traverse the array and, for every index j, compute $k_j$. Every earlier index i with $k_i = k_j$ forms a good pair (i, j), so add the current count for $k_j$ to the number of good pairs, then increment that count.
- The total number of pairs (i, j) with i < j is $\frac{n(n-1)}{2}$.
- Subtract the number of good pairs from the total number of pairs to get the number of bad pairs.
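The good pairs can also be counted by grouping indices by key first. This is an alternative to the incremental hashmap walk described above, not a replacement for it. A group of c indices that share a key contributes $\binom{c}{2} = \frac{c(c-1)}{2}$ good pairs. The minimal Python sketch below uses `collections.Counter`; the function name is our own.

```python
from collections import Counter


def count_bad_pairs_grouped(nums):
    """Bad pairs = all pairs minus, for each key group of size c, c*(c-1)/2."""
    n = len(nums)
    # Group sizes for each key k = i - nums[i].
    groups = Counter(i - x for i, x in enumerate(nums))
    good = sum(c * (c - 1) // 2 for c in groups.values())
    return n * (n - 1) // 2 - good


print(count_bad_pairs_grouped([4, 1, 3, 3]))     # 5
print(count_bad_pairs_grouped([1, 2, 3, 4, 5]))  # 0
```

Both versions are linear. The incremental walk adds $0 + 1 + \dots + (c-1)$ for each group, which sums to the same $\frac{c(c-1)}{2}$.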
Code
Java
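```java
import java.util.HashMap;
import java.util.Map;

class Solution {
    public long countBadPairs(int[] nums) {
        // Maps each key k = index - value to how many times it has appeared so far.
        Map<Integer, Integer> map = new HashMap<>();
        long goodPairs = 0;
        int n = nums.length;
        for (int j = 0; j < n; j++) {
            int key = j - nums[j];
            // Every earlier index with the same key forms a good pair with j.
            int seen = map.getOrDefault(key, 0);
            goodPairs += seen;
            map.put(key, seen + 1);
        }
        // Cast before multiplying: n * (n - 1) overflows int for large n.
        long totalPairs = (long) n * (n - 1) / 2;
        return totalPairs - goodPairs;
    }
}
```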
Python
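```python
from collections import defaultdict
from typing import List

class Solution:
    def countBadPairs(self, nums: List[int]) -> int:
        # Maps each key k = index - value to how many times it has appeared so far.
        count_map: defaultdict[int, int] = defaultdict(int)
        good_pairs: int = 0
        n: int = len(nums)
        for j in range(n):
            key = j - nums[j]
            # Every earlier index with the same key forms a good pair with j.
            good_pairs += count_map[key]
            count_map[key] += 1
        total_pairs: int = n * (n - 1) // 2
        return total_pairs - good_pairs
```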
Complexity
- ⏰ Time complexity: O(n), since we traverse the array once (with expected O(1) hashmap operations).
- 🧺 Space complexity: O(n), due to the hashmap storage.
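As a sanity check, the two methods can be compared on many small random arrays. The sketch below restates both as standalone Python functions; the names `bad_pairs_naive` and `bad_pairs_hashmap` are hypothetical. A randomized agreement check like this builds confidence but is not a proof.

```python
import random
from collections import defaultdict


def bad_pairs_naive(nums):
    """Method 1: test every pair (i, j) with i < j against the definition."""
    n = len(nums)
    return sum(
        1
        for i in range(n)
        for j in range(i + 1, n)
        if j - i != nums[j] - nums[i]
    )


def bad_pairs_hashmap(nums):
    """Method 2: count good pairs via keys j - nums[j], subtract from total."""
    seen = defaultdict(int)
    good = 0
    for j, x in enumerate(nums):
        key = j - x
        good += seen[key]
        seen[key] += 1
    n = len(nums)
    return n * (n - 1) // 2 - good


random.seed(0)
for _ in range(500):
    # Small values make equal keys (good pairs) common, exercising both branches.
    nums = [random.randint(1, 6) for _ in range(random.randint(1, 12))]
    assert bad_pairs_naive(nums) == bad_pairs_hashmap(nums), nums
print("methods agree")
```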