Problem#
Given a (0-indexed) integer array nums and two integers low and high, return the number of nice pairs.
A nice pair is a pair (i, j) where 0 <= i < j < nums.length and low <= (nums[i] XOR nums[j]) <= high.
Examples#
Example 1#
Input: nums = [1, 4, 2, 7], low = 2, high = 6
Output: 6
Explanation: All nice pairs (i, j) are as follows:
- (0, 1): nums[0] XOR nums[1] = 5
- (0, 2): nums[0] XOR nums[2] = 3
- (0, 3): nums[0] XOR nums[3] = 6
- (1, 2): nums[1] XOR nums[2] = 6
- (1, 3): nums[1] XOR nums[3] = 3
- (2, 3): nums[2] XOR nums[3] = 5
Example 2#
Input: nums = [9, 8, 4, 2, 1], low = 5, high = 14
Output: 8
Explanation: All nice pairs (i, j) are as follows:
- (0, 2): nums[0] XOR nums[2] = 13
- (0, 3): nums[0] XOR nums[3] = 11
- (0, 4): nums[0] XOR nums[4] = 8
- (1, 2): nums[1] XOR nums[2] = 12
- (1, 3): nums[1] XOR nums[3] = 10
- (1, 4): nums[1] XOR nums[4] = 9
- (2, 3): nums[2] XOR nums[3] = 6
- (2, 4): nums[2] XOR nums[4] = 5
Constraints#
1 <= nums.length <= 2 * 10^4
1 <= nums[i] <= 2 * 10^4
1 <= low <= high <= 2 * 10^4
Solution#
Method 1 – Trie (Bitwise Prefix Tree)#
Intuition#
To efficiently count pairs (i, j) with low <= (nums[i] XOR nums[j]) <= high, we use a Trie to store the binary representation of numbers seen so far. For each number, we count how many previous numbers have XOR with it in the given range by querying the Trie.
Approach#
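Define a Trie node with two children (for bit 0 and bit 1) and a count of the inserted numbers whose bit path passes through it.
Let f(x) be the number of pairs (i, j), i < j, with (nums[i] XOR nums[j]) < x (strictly less). Compute f(x) in one pass over nums with a fresh Trie:
For each number a, walk the Trie from the highest bit down, following the stored values v for which (a XOR v) agrees with x on every bit examined so far. Whenever x has a 1 at the current bit, every value in the child subtree whose XOR bit is 0 already gives an XOR smaller than x, so add that subtree's count; then continue into the child whose XOR bit is 1. Values whose XOR equals x exactly are never counted.
Then insert a into the Trie, so each number is only paired with numbers that came before it.
A pair satisfies low <= XOR <= high exactly when XOR < high + 1 but not XOR < low, so the answer is f(high + 1) - f(low).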
Code#
Implementations are given in C++, Go, Java, Kotlin, Python, Rust, and TypeScript.
C++
// Node of a binary trie over the bits of inserted numbers.
// Each node owns its children and frees its whole subtree when destroyed.
class TrieNode {
public:
    TrieNode* child[2] = {nullptr, nullptr};  // child[bit], or nullptr if no value continues with that bit
    int cnt = 0;  // number of inserted values whose bit path passes through this node

    TrieNode() = default;
    ~TrieNode() {
        delete child[0];
        delete child[1];
    }
    // Children are owning raw pointers: a shallow copy would delete them twice.
    TrieNode(const TrieNode&) = delete;
    TrieNode& operator=(const TrieNode&) = delete;
};

class Solution {
public:
    // Returns the number of pairs (i, j), i < j, with low <= (nums[i] ^ nums[j]) <= high.
    int countPairs(vector<int>& nums, int low, int high) {
        // f(x) counts pairs whose XOR is strictly less than x, so the closed
        // range [low, high] is f(high + 1) - f(low).
        return f(nums, high + 1) - f(nums, low);
    }

    // Returns the number of pairs (i, j), i < j, with (nums[i] ^ nums[j]) < x.
    // Uses a fresh trie: for each a, earlier values are counted first, then a is inserted.
    int f(vector<int>& nums, int x) {
        TrieNode root;  // the whole trie is released when root goes out of scope
        int ans = 0;
        for (int a : nums) {
            // Query: follow stored values v whose XOR with a matches x on the bits seen so far.
            TrieNode* node = &root;
            for (int i = kHighBit; i >= 0 && node != nullptr; --i) {
                int b = (a >> i) & 1;
                if ((x >> i) & 1) {
                    // Values under child[b] have XOR bit 0 where x has 1: their XOR is already < x.
                    if (node->child[b]) ans += node->child[b]->cnt;
                    node = node->child[1 - b];
                } else {
                    node = node->child[b];
                }
            }
            // Values whose XOR equals x exactly are never added (strict inequality).

            // Insert a, bumping the pass-through count on every node of its path.
            node = &root;
            for (int i = kHighBit; i >= 0; --i) {
                int b = (a >> i) & 1;
                if (!node->child[b]) node->child[b] = new TrieNode();
                node = node->child[b];
                ++node->cnt;
            }
        }
        return ans;
    }

private:
    // 2^15 > 2 * 10^4 + 1, so bits 14..0 cover every value and high + 1.
    static constexpr int kHighBit = 14;
};
Go
// TrieNode is a node of a binary trie over the bits of inserted numbers.
type TrieNode struct {
	child [2]*TrieNode // child[bit], or nil if no inserted value continues with that bit
	cnt   int          // number of inserted values whose bit path passes through this node
}

// countPairs returns the number of pairs (i, j), i < j, with
// low <= nums[i]^nums[j] <= high.
//
// f(x) counts pairs whose XOR is strictly less than x, so the closed range
// [low, high] is f(high+1) - f(low).
func countPairs(nums []int, low, high int) int {
	return f(nums, high+1) - f(nums, low)
}

// f returns the number of pairs (i, j), i < j, with nums[i]^nums[j] < x.
// It uses a fresh trie: for each a, earlier values are counted first, then a is inserted.
func f(nums []int, x int) int {
	// 2^15 > 2*10^4 + 1, so bits 14..0 cover every value and high+1.
	const highBit = 14
	root := &TrieNode{}
	ans := 0
	for _, a := range nums {
		// Query: follow stored values v whose XOR with a matches x on the bits seen so far.
		node := root
		for i := highBit; i >= 0 && node != nil; i-- {
			b := (a >> i) & 1
			if (x>>i)&1 == 1 {
				// Values under child[b] have XOR bit 0 where x has 1: their XOR is already < x.
				if same := node.child[b]; same != nil {
					ans += same.cnt
				}
				node = node.child[1-b]
			} else {
				node = node.child[b]
			}
		}
		// Values whose XOR equals x exactly are never added (strict inequality).

		// Insert a, bumping the pass-through count on every node of its path.
		node = root
		for i := highBit; i >= 0; i-- {
			b := (a >> i) & 1
			if node.child[b] == nil {
				node.child[b] = &TrieNode{}
			}
			node = node.child[b]
			node.cnt++
		}
	}
	return ans
}
Java
/** Node of a binary trie over the bits of inserted numbers. */
class TrieNode {
    /** child[bit], or null when no inserted value continues with that bit. */
    TrieNode[] child = new TrieNode[2];
    /** Number of inserted values whose bit path passes through this node. */
    int cnt = 0;
}

class Solution {
    /** 2^15 > 2 * 10^4 + 1, so bits 14..0 cover every value and high + 1. */
    private static final int HIGH_BIT = 14;

    /**
     * Returns the number of pairs (i, j), i < j, with low <= (nums[i] ^ nums[j]) <= high.
     * countLessThan(x) counts pairs whose XOR is strictly less than x, so the closed
     * range [low, high] is countLessThan(high + 1) - countLessThan(low).
     */
    public int countPairs(int[] nums, int low, int high) {
        return countLessThan(nums, high + 1) - countLessThan(nums, low);
    }

    /**
     * Returns the number of pairs (i, j), i < j, with (nums[i] ^ nums[j]) < x.
     * Uses a fresh trie: for each a, earlier values are counted first, then a is inserted.
     */
    private int countLessThan(int[] nums, int x) {
        TrieNode root = new TrieNode();
        int ans = 0;
        for (int a : nums) {
            // Query: follow stored values v whose XOR with a matches x on the bits seen so far.
            TrieNode node = root;
            for (int i = HIGH_BIT; i >= 0 && node != null; --i) {
                int b = (a >> i) & 1;
                if (((x >> i) & 1) == 1) {
                    // Values under child[b] have XOR bit 0 where x has 1: their XOR is already < x.
                    if (node.child[b] != null) ans += node.child[b].cnt;
                    node = node.child[1 - b];
                } else {
                    node = node.child[b];
                }
            }
            // Values whose XOR equals x exactly are never added (strict inequality).

            // Insert a, bumping the pass-through count on every node of its path.
            node = root;
            for (int i = HIGH_BIT; i >= 0; --i) {
                int b = (a >> i) & 1;
                if (node.child[b] == null) node.child[b] = new TrieNode();
                node = node.child[b];
                node.cnt++;
            }
        }
        return ans;
    }
}
Kotlin
/**
 * Node of a binary trie over the bits of inserted numbers.
 *
 * [child] is indexed by bit value (0 or 1); [cnt] is the number of inserted values
 * whose bit path passes through this node.
 */
class TrieNode {
    val child = arrayOfNulls<TrieNode>(2)
    var cnt = 0
}

class Solution {
    /**
     * Returns the number of pairs (i, j), i < j, with `low <= (nums[i] xor nums[j]) <= high`.
     *
     * [countLessThan] counts pairs whose XOR is strictly below its bound, so the closed
     * range `[low, high]` equals `countLessThan(high + 1) - countLessThan(low)`.
     */
    fun countPairs(nums: IntArray, low: Int, high: Int): Int =
        countLessThan(nums, high + 1) - countLessThan(nums, low)

    /**
     * Returns the number of pairs (i, j), i < j, with `(nums[i] xor nums[j]) < x`.
     * Uses a fresh trie: for each `a`, earlier values are counted first, then `a` is inserted.
     */
    private fun countLessThan(nums: IntArray, x: Int): Int {
        val root = TrieNode()
        var ans = 0
        for (a in nums) {
            // Query: follow stored values v whose XOR with a matches x on the bits seen so far.
            var node = root
            for (i in HIGH_BIT downTo 0) {
                val b = (a shr i) and 1
                node = if (((x shr i) and 1) == 1) {
                    // Values under child[b] have XOR bit 0 where x has 1, so their XOR is already < x.
                    ans += node.child[b]?.cnt ?: 0
                    node.child[1 - b] ?: break
                } else {
                    node.child[b] ?: break
                }
            }
            // Values whose XOR equals x exactly are never added (strict inequality).

            // Insert a, bumping the pass-through count on every node of its path.
            node = root
            for (i in HIGH_BIT downTo 0) {
                val b = (a shr i) and 1
                val next = node.child[b] ?: TrieNode().also { node.child[b] = it }
                next.cnt++
                node = next
            }
        }
        return ans
    }

    private companion object {
        /** 2^15 > 2 * 10^4 + 1, so bits 14..0 cover every value and `high + 1`. */
        const val HIGH_BIT = 14
    }
}
Python
class TrieNode:
    """Node of a binary trie over the bits of inserted numbers."""

    def __init__(self):
        self.child = [None, None]  # child[bit] -> TrieNode, or None if no value continues with that bit
        self.cnt = 0  # number of inserted values whose bit path passes through this node


class Solution:
    def countPairs(self, nums: list[int], low: int, high: int) -> int:
        """Return the number of pairs (i, j), i < j, with low <= nums[i] ^ nums[j] <= high.

        count_less_than(x) counts pairs whose XOR is strictly below x, so the closed
        range [low, high] is count_less_than(high + 1) - count_less_than(low).
        """
        # 2**15 > 2 * 10**4 + 1, so bits 14..0 cover every value and high + 1.
        high_bit = 14

        def count_less_than(x: int) -> int:
            """Return the number of pairs (i, j), i < j, with nums[i] ^ nums[j] < x.

            Uses a fresh trie: for each a, earlier values are counted first, then a is inserted.
            """
            root = TrieNode()
            ans = 0
            for a in nums:
                # Query: follow stored values v whose XOR with a matches x on the bits seen so far.
                node = root
                for i in range(high_bit, -1, -1):
                    b = (a >> i) & 1
                    if (x >> i) & 1:
                        # Values under child[b] have XOR bit 0 where x has 1: their XOR is already < x.
                        if node.child[b] is not None:
                            ans += node.child[b].cnt
                        node = node.child[1 - b]
                    else:
                        node = node.child[b]
                    if node is None:
                        break
                # Values whose XOR equals x exactly are never added (strict inequality).

                # Insert a, bumping the pass-through count on every node of its path.
                node = root
                for i in range(high_bit, -1, -1):
                    b = (a >> i) & 1
                    if node.child[b] is None:
                        node.child[b] = TrieNode()
                    node = node.child[b]
                    node.cnt += 1
            return ans

        return count_less_than(high + 1) - count_less_than(low)
Rust
/// Node of a binary trie over the bits of inserted numbers.
struct TrieNode {
    /// `child[bit]`, or `None` when no inserted value continues with that bit.
    child: [Option<Box<TrieNode>>; 2],
    /// Number of inserted values whose bit path passes through this node.
    cnt: i32,
}

impl TrieNode {
    fn new() -> Self {
        TrieNode { child: [None, None], cnt: 0 }
    }
}

impl Solution {
    /// Returns the number of pairs (i, j), i < j, with `low <= nums[i] ^ nums[j] <= high`.
    ///
    /// `count_less_than(x)` counts pairs whose XOR is strictly below `x`, so the closed
    /// range `[low, high]` is `count_less_than(high + 1) - count_less_than(low)`.
    pub fn count_pairs(nums: Vec<i32>, low: i32, high: i32) -> i32 {
        // 2^15 > 2 * 10^4 + 1, so bits 14..0 cover every value and `high + 1`.
        const HIGH_BIT: u32 = 14;

        // Returns the number of pairs (i, j), i < j, with `nums[i] ^ nums[j] < x`.
        // Uses a fresh trie: for each `a`, earlier values are counted first, then `a` is inserted.
        fn count_less_than(nums: &[i32], x: i32) -> i32 {
            let mut root = TrieNode::new();
            let mut ans = 0;
            for &a in nums {
                // Query (read-only): follow stored values v whose XOR with a matches x
                // on the bits seen so far.
                let mut node = &root;
                for i in (0..=HIGH_BIT).rev() {
                    let b = ((a >> i) & 1) as usize;
                    let next = if ((x >> i) & 1) == 1 {
                        // Values under child[b] have XOR bit 0 where x has 1: their XOR is already < x.
                        if let Some(same) = &node.child[b] {
                            ans += same.cnt;
                        }
                        node.child[b ^ 1].as_deref()
                    } else {
                        node.child[b].as_deref()
                    };
                    match next {
                        Some(n) => node = n,
                        None => break,
                    }
                }
                // Values whose XOR equals x exactly are never added (strict inequality).

                // Insert a, bumping the pass-through count on every node of its path.
                let mut node = &mut root;
                for i in (0..=HIGH_BIT).rev() {
                    let b = ((a >> i) & 1) as usize;
                    node = node.child[b].get_or_insert_with(|| Box::new(TrieNode::new()));
                    node.cnt += 1;
                }
            }
            ans
        }

        count_less_than(&nums, high + 1) - count_less_than(&nums, low)
    }
}
TypeScript
/** Node of a binary trie over the bits of inserted numbers. */
class TrieNode {
    /** child[bit], or undefined when no inserted value continues with that bit. */
    child: [TrieNode?, TrieNode?] = [undefined, undefined];
    /** Number of inserted values whose bit path passes through this node. */
    cnt = 0;
}

class Solution {
    /**
     * Returns the number of pairs (i, j), i < j, with low <= (nums[i] ^ nums[j]) <= high.
     * countLessThan(x) counts pairs whose XOR is strictly below x, so the closed range
     * [low, high] is countLessThan(high + 1) - countLessThan(low).
     */
    countPairs(nums: number[], low: number, high: number): number {
        // 2^15 > 2 * 10^4 + 1, so bits 14..0 cover every value and high + 1.
        const HIGH_BIT = 14;

        /**
         * Returns the number of pairs (i, j), i < j, with (nums[i] ^ nums[j]) < x.
         * Uses a fresh trie: for each a, earlier values are counted first, then a is inserted.
         */
        const countLessThan = (x: number): number => {
            const root = new TrieNode();
            let ans = 0;
            for (const a of nums) {
                // Query: follow stored values v whose XOR with a matches x on the bits seen so far.
                let node: TrieNode | undefined = root;
                for (let i = HIGH_BIT; i >= 0 && node !== undefined; --i) {
                    const b = (a >> i) & 1;
                    if ((x >> i) & 1) {
                        // Values under child[b] have XOR bit 0 where x has 1: their XOR is already < x.
                        const same: TrieNode | undefined = node.child[b];
                        if (same !== undefined) ans += same.cnt;
                        node = node.child[1 - b];
                    } else {
                        node = node.child[b];
                    }
                }
                // Values whose XOR equals x exactly are never added (strict inequality).

                // Insert a, bumping the pass-through count on every node of its path.
                let cur: TrieNode = root;
                for (let i = HIGH_BIT; i >= 0; --i) {
                    const b = (a >> i) & 1;
                    let next: TrieNode | undefined = cur.child[b];
                    if (next === undefined) {
                        next = new TrieNode();
                        cur.child[b] = next;
                    }
                    next.cnt++;
                    cur = next;
                }
            }
            return ans;
        };

        return countLessThan(high + 1) - countLessThan(low);
    }
}
Complexity#
⏰ Time complexity: O(n log M), where n is the number of elements and M is the maximum value. Each of the two passes queries and then inserts every number along a root-to-leaf path of O(log M) bits.
🧺 Space complexity: O(n log M), for the Trie storing all inserted numbers.