Problem

Given a (0-indexed) integer array nums and two integers low and high, return the number of nice pairs.

A nice pair is a pair (i, j) where 0 <= i < j < nums.length and low <= (nums[i] XOR nums[j]) <= high.

Examples

Example 1

Input: nums = [1,4,2,7], low = 2, high = 6
Output: 6
Explanation: All nice pairs (i, j) are as follows:
    - (0, 1): nums[0] XOR nums[1] = 5 
    - (0, 2): nums[0] XOR nums[2] = 3
    - (0, 3): nums[0] XOR nums[3] = 6
    - (1, 2): nums[1] XOR nums[2] = 6
    - (1, 3): nums[1] XOR nums[3] = 3
    - (2, 3): nums[2] XOR nums[3] = 5

Example 2

Input: nums = [9,8,4,2,1], low = 5, high = 14
Output: 8
Explanation: All nice pairs (i, j) are as follows:
    - (0, 2): nums[0] XOR nums[2] = 13
    - (0, 3): nums[0] XOR nums[3] = 11
    - (0, 4): nums[0] XOR nums[4] = 8
    - (1, 2): nums[1] XOR nums[2] = 12
    - (1, 3): nums[1] XOR nums[3] = 10
    - (1, 4): nums[1] XOR nums[4] = 9
    - (2, 3): nums[2] XOR nums[3] = 6
    - (2, 4): nums[2] XOR nums[4] = 5

Constraints

  • 1 <= nums.length <= 2 * 10^4
  • 1 <= nums[i] <= 2 * 10^4
  • 1 <= low <= high <= 2 * 10^4

Solution

Method 1 – Trie (Bitwise Prefix Tree)

Intuition

To efficiently count pairs (i, j) with low <= (nums[i] XOR nums[j]) <= high, we use a Trie to store the binary representation of numbers seen so far. For each number, we count how many previous numbers have XOR with it in the given range by querying the Trie.

Approach

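  1. Define a Trie node with two children (0 and 1) and a count of the numbers passing through it.
  2. For each number in nums:
    1. Count how many previous numbers have XOR with the current number <= high, i.e. < high + 1 (Trie query).
    2. Subtract the count of previous numbers whose XOR with the current number is < low (Trie query).
    3. Insert the current number into the Trie.
  3. The answer is the sum of these counts over all numbers.

  The Trie query for "XOR < x" walks from the highest bit down. Where x has bit 1, every stored number whose bit equals the current number's bit makes the XOR bit 0, so that whole subtree is below x. Add that child's count and continue into the other child. Where x has bit 0, continue into the child matching the current bit, because the other child would make the XOR exceed x. Numbers whose XOR equals x exactly reach the bottom and are not counted. This strict comparison is why the upper bound is queried as high + 1.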

Code

C++
// Binary-trie node: child[bit] is the subtree for the next bit value, and
// cnt is how many inserted numbers pass through this node.
class TrieNode {
public:
    TrieNode* child[2] = {nullptr, nullptr};
    int cnt = 0;
};
class Solution {
public:
    // Counts pairs (i, j), i < j, with low <= (nums[i] ^ nums[j]) <= high.
    // f(x) counts pairs whose XOR is strictly below x, so the inclusive
    // range [low, high] is f(high + 1) - f(low).
    int countPairs(vector<int>& nums, int low, int high) {
        return f(nums, high + 1) - f(nums, low);
    }
    // Counts pairs (i, j), i < j, with (nums[i] ^ nums[j]) < x (strict).
    // Each element is queried against the elements inserted before it and
    // is then inserted itself.
    int f(vector<int>& nums, int x) {
        TrieNode* root = new TrieNode();
        int ans = 0;
        for (int a : nums) {
            // Query: follow the path on which the XOR with `a` matches x's prefix.
            TrieNode* node = root;
            for (int i = kHighBit; i >= 0 && node; --i) {
                int b = (a >> i) & 1;
                if ((x >> i) & 1) {
                    // XOR bit is 0 here while x has 1: that whole subtree is below x.
                    if (node->child[b]) ans += node->child[b]->cnt;
                    node = node->child[1 - b];
                } else {
                    node = node->child[b];
                }
            }
            // Values whose XOR equals x exactly reach the bottom and are not counted.
            // Insert `a` so later elements can pair with it.
            node = root;
            for (int i = kHighBit; i >= 0; --i) {
                int b = (a >> i) & 1;
                if (!node->child[b]) node->child[b] = new TrieNode();
                node = node->child[b];
                node->cnt++;
            }
        }
        destroy(root);
        return ans;
    }
private:
    // Highest bit examined: values up to 2 * 10^4 + 1 fit in bits 14..0.
    static constexpr int kHighBit = 14;
    // Releases every node of a trie built by f (post-order delete).
    static void destroy(TrieNode* node) {
        if (!node) return;
        destroy(node->child[0]);
        destroy(node->child[1]);
        delete node;
    }
};
Go
// TrieNode is a binary-trie node: child[bit] is the subtree for the next bit
// value and cnt is how many inserted numbers pass through this node.
type TrieNode struct {
    child [2]*TrieNode
    cnt   int
}

// countPairs returns the number of pairs (i, j), i < j, with
// low <= nums[i]^nums[j] <= high. f(x) counts pairs whose XOR is strictly
// below x, so the inclusive range [low, high] is f(high+1) - f(low).
func countPairs(nums []int, low, high int) int {
    return f(nums, high+1) - f(nums, low)
}

// f counts pairs (i, j), i < j, whose XOR is strictly less than x.
// Bits 14..0 are examined, which covers values up to 2*10^4+1.
func f(nums []int, x int) int {
    root := &TrieNode{}
    ans := 0
    for _, a := range nums {
        // Query: follow the path on which the XOR with a matches x's prefix.
        for node, i := root, 14; node != nil && i >= 0; i-- {
            b := (a >> i) & 1
            if (x>>i)&1 == 1 {
                // XOR bit is 0 here while x has 1: that whole subtree is below x.
                if node.child[b] != nil {
                    ans += node.child[b].cnt
                }
                node = node.child[1-b]
            } else {
                node = node.child[b]
            }
        }
        // Values whose XOR equals x exactly reach the bottom and are not counted.
        // Insert a so later elements can pair with it.
        node := root
        for i := 14; i >= 0; i-- {
            b := (a >> i) & 1
            if node.child[b] == nil {
                node.child[b] = &TrieNode{}
            }
            node = node.child[b]
            node.cnt++
        }
    }
    return ans
}
Java
/** Binary-trie node: child[bit] is the subtree for the next bit value; cnt counts inserted numbers passing through. */
class TrieNode {
    TrieNode[] child = new TrieNode[2];
    int cnt = 0;
}
class Solution {
    /** Highest bit examined: values up to 2 * 10^4 + 1 fit in bits 14..0. */
    private static final int HIGH_BIT = 14;

    /**
     * Counts pairs (i, j), i < j, with low <= (nums[i] ^ nums[j]) <= high.
     * f(x) counts pairs whose XOR is strictly below x, so the inclusive
     * range [low, high] is f(high + 1) - f(low).
     */
    public int countPairs(int[] nums, int low, int high) {
        return f(nums, high + 1) - f(nums, low);
    }

    /**
     * Counts pairs (i, j), i < j, whose XOR is strictly less than x.
     * Each element is queried against earlier elements, then inserted.
     */
    private int f(int[] nums, int x) {
        TrieNode root = new TrieNode();
        int ans = 0;
        for (int a : nums) {
            // Query: follow the path on which the XOR with a matches x's prefix.
            TrieNode node = root;
            for (int i = HIGH_BIT; i >= 0 && node != null; --i) {
                int b = (a >> i) & 1;
                if (((x >> i) & 1) == 1) {
                    // XOR bit is 0 here while x has 1: that whole subtree is below x.
                    if (node.child[b] != null) ans += node.child[b].cnt;
                    node = node.child[1 - b];
                } else {
                    node = node.child[b];
                }
            }
            // Values whose XOR equals x exactly reach the bottom and are not counted.
            // Insert a so later elements can pair with it.
            node = root;
            for (int i = HIGH_BIT; i >= 0; --i) {
                int b = (a >> i) & 1;
                if (node.child[b] == null) node.child[b] = new TrieNode();
                node = node.child[b];
                node.cnt++;
            }
        }
        return ans;
    }
}
Kotlin
/**
 * Node of a binary trie over fixed-width bit strings.
 *
 * [child] holds the subtrees for the next bit (index 0 or 1); [cnt] is the number of
 * inserted values whose path passes through this node.
 */
class TrieNode {
    val child = arrayOfNulls<TrieNode>(2)
    var cnt = 0
}

class Solution {
    /**
     * Returns the number of pairs (i, j) with i < j and low <= (nums[i] xor nums[j]) <= high.
     *
     * The local `countBelow(x)` counts pairs whose XOR is strictly below `x`, so the
     * inclusive range [low, high] is `countBelow(high + 1) - countBelow(low)`.
     */
    fun countPairs(nums: IntArray, low: Int, high: Int): Int {
        // Counts pairs (i, j), i < j, with (nums[i] xor nums[j]) < x.
        fun countBelow(x: Int): Int {
            val root = TrieNode()
            var ans = 0
            for (a in nums) {
                // Query: follow the path on which the XOR with `a` matches x's prefix.
                var node: TrieNode? = root
                for (i in HIGH_BIT downTo 0) {
                    val cur = node ?: break
                    val b = (a shr i) and 1
                    node = if ((x shr i) and 1 == 1) {
                        // XOR bit is 0 here while x has 1: that whole subtree is below x.
                        ans += cur.child[b]?.cnt ?: 0
                        cur.child[1 - b]
                    } else {
                        cur.child[b]
                    }
                }
                // Values whose XOR equals x exactly reach the bottom and are not counted.
                // Insert `a` so later elements can pair with it.
                var ins = root
                for (i in HIGH_BIT downTo 0) {
                    val b = (a shr i) and 1
                    val next = ins.child[b] ?: TrieNode().also { ins.child[b] = it }
                    next.cnt++
                    ins = next
                }
            }
            return ans
        }
        return countBelow(high + 1) - countBelow(low)
    }

    private companion object {
        /** Highest bit examined: values up to 2 * 10^4 + 1 fit in bits 14..0. */
        const val HIGH_BIT = 14
    }
}
Python
class TrieNode:
    """Node of a binary trie over fixed-width bit strings."""

    def __init__(self):
        # child[bit] is the subtree for the next bit value (0 or 1), or None.
        self.child = [None, None]
        # Number of inserted values whose path passes through this node.
        self.cnt = 0


class Solution:
    # Highest bit examined: values up to 2 * 10**4 + 1 fit in bits 14..0.
    _HIGH_BIT = 14

    def countPairs(self, nums: list[int], low: int, high: int) -> int:
        """Count pairs (i, j), i < j, with ``low <= nums[i] ^ nums[j] <= high``.

        ``f(x)`` counts pairs whose XOR is strictly below ``x``, so the
        inclusive range [low, high] is ``f(high + 1) - f(low)``.
        """
        high_bit = self._HIGH_BIT

        def f(nums: list[int], x: int) -> int:
            """Count pairs (i, j), i < j, whose XOR is strictly less than ``x``."""
            root = TrieNode()
            ans = 0
            for a in nums:
                # Query: follow the path on which the XOR with `a` matches x's prefix.
                node = root
                for i in range(high_bit, -1, -1):
                    b = (a >> i) & 1
                    if (x >> i) & 1:
                        # XOR bit is 0 here while x has 1: that whole subtree is below x.
                        if node.child[b] is not None:
                            ans += node.child[b].cnt
                        node = node.child[1 - b]
                    else:
                        node = node.child[b]
                    if node is None:
                        break
                # Values whose XOR equals x exactly reach the bottom and are not counted.
                # Insert `a` so later elements can pair with it.
                node = root
                for i in range(high_bit, -1, -1):
                    b = (a >> i) & 1
                    if node.child[b] is None:
                        node.child[b] = TrieNode()
                    node = node.child[b]
                    node.cnt += 1
            return ans

        return f(nums, high + 1) - f(nums, low)
Rust
/// Binary-trie node: `child[bit]` is the subtree for the next bit value and
/// `cnt` is how many inserted numbers pass through this node.
struct TrieNode {
    child: [Option<Box<TrieNode>>; 2],
    cnt: i32,
}
impl TrieNode {
    fn new() -> Self {
        TrieNode { child: [None, None], cnt: 0 }
    }
}
impl Solution {
    /// Counts pairs (i, j), i < j, with `low <= nums[i] ^ nums[j] <= high`.
    ///
    /// `f(x)` counts pairs whose XOR is strictly below `x`, so the inclusive
    /// range [low, high] is `f(high + 1) - f(low)`.
    pub fn count_pairs(nums: Vec<i32>, low: i32, high: i32) -> i32 {
        // Counts pairs (i, j), i < j, whose XOR is strictly less than `x`.
        fn f(nums: &[i32], x: i32) -> i32 {
            // Highest bit examined: values up to 2 * 10^4 + 1 fit in bits 14..=0.
            const HIGH_BIT: i32 = 14;
            let mut root = TrieNode::new();
            let mut ans = 0;
            for &a in nums {
                // Query (read-only): follow the path on which the XOR with `a`
                // matches x's prefix.
                let mut node: &TrieNode = &root;
                for i in (0..=HIGH_BIT).rev() {
                    let b = ((a >> i) & 1) as usize;
                    let next = if (x >> i) & 1 == 1 {
                        // XOR bit is 0 here while x has 1: that whole subtree is below x.
                        if let Some(n) = node.child[b].as_deref() {
                            ans += n.cnt;
                        }
                        &node.child[1 - b]
                    } else {
                        &node.child[b]
                    };
                    match next.as_deref() {
                        Some(n) => node = n,
                        None => break,
                    }
                }
                // Values whose XOR equals x exactly reach the bottom and are not counted.
                // Insert `a` so later elements can pair with it.
                let mut node: &mut TrieNode = &mut root;
                for i in (0..=HIGH_BIT).rev() {
                    let b = ((a >> i) & 1) as usize;
                    node = &mut **node.child[b].get_or_insert_with(|| Box::new(TrieNode::new()));
                    node.cnt += 1;
                }
            }
            ans
        }
        f(&nums, high + 1) - f(&nums, low)
    }
}
TypeScript
/** Binary-trie node: `child[bit]` is the subtree for the next bit value; `cnt` counts inserted numbers passing through. */
class TrieNode {
    child: [TrieNode?, TrieNode?] = [undefined, undefined];
    cnt = 0;
}
class Solution {
    /**
     * Counts pairs (i, j), i < j, with low <= (nums[i] ^ nums[j]) <= high.
     * `f(x)` counts pairs whose XOR is strictly below `x`, so the inclusive
     * range [low, high] is f(high + 1) - f(low).
     */
    countPairs(nums: number[], low: number, high: number): number {
        // Highest bit examined: values up to 2 * 10^4 + 1 fit in bits 14..0.
        const HIGH_BIT = 14;
        // Counts pairs (i, j), i < j, whose XOR is strictly less than x.
        function f(nums: number[], x: number): number {
            const root = new TrieNode();
            let ans = 0;
            for (const a of nums) {
                // Query: follow the path on which the XOR with `a` matches x's prefix.
                let node: TrieNode | undefined = root;
                for (let i = HIGH_BIT; i >= 0; --i) {
                    const b = (a >> i) & 1;
                    if ((x >> i) & 1) {
                        // XOR bit is 0 here while x has 1: that whole subtree is below x.
                        const same: TrieNode | undefined = node.child[b];
                        if (same) ans += same.cnt;
                        node = node.child[1 - b];
                    } else {
                        node = node.child[b];
                    }
                    if (!node) break;
                }
                // Values whose XOR equals x exactly reach the bottom and are not counted.
                // Insert `a` so later elements can pair with it.
                let cur: TrieNode = root;
                for (let i = HIGH_BIT; i >= 0; --i) {
                    const b = (a >> i) & 1;
                    let next: TrieNode | undefined = cur.child[b];
                    if (!next) {
                        next = new TrieNode();
                        cur.child[b] = next;
                    }
                    next.cnt++;
                    cur = next;
                }
            }
            return ans;
        }
        return f(nums, high + 1) - f(nums, low);
    }
}

Complexity

  • ⏰ Time complexity: O(n * log M), where n is the number of elements and M is the max value (since we process each bit for each number).
  • 🧺 Space complexity: O(n * log M), for the Trie storing all inserted numbers.