Problem

You are given an integer array nums of length n, where nums is a permutation of the numbers in the range [1, n].

An XOR triplet is defined as the XOR of three elements, nums[i] XOR nums[j] XOR nums[k], where i <= j <= k.

Return the number of unique XOR triplet values from all possible triplets (i, j, k).
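Read literally, the definition can be checked by enumerating every index triplet with i <= j <= k. The sketch below does that. The function name `brute_force_unique_xor_triplets` is hypothetical and not part of the problem. It runs in O(n^3) time, so it only suits small inputs, but it is handy for validating a faster solution.

```python
def brute_force_unique_xor_triplets(nums: list[int]) -> int:
    """Count distinct nums[i] ^ nums[j] ^ nums[k] over all i <= j <= k (O(n^3))."""
    n = len(nums)
    seen = set()
    for i in range(n):
        for j in range(i, n):          # j starts at i, so indices may repeat
            for k in range(j, n):      # k starts at j, enforcing i <= j <= k
                seen.add(nums[i] ^ nums[j] ^ nums[k])
    return len(seen)


print(brute_force_unique_xor_triplets([1, 2]))     # 2
print(brute_force_unique_xor_triplets([3, 1, 2]))  # 4
```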

Example 1

Input: nums = [1,2]
Output: 2
Explanation:
The possible XOR triplet values are:
* `(0, 0, 0) -> 1 XOR 1 XOR 1 = 1`
* `(0, 0, 1) -> 1 XOR 1 XOR 2 = 2`
* `(0, 1, 1) -> 1 XOR 2 XOR 2 = 1`
* `(1, 1, 1) -> 2 XOR 2 XOR 2 = 2`
The unique XOR values are `{1, 2}`, so the output is 2.

Example 2

Input: nums = [3,1,2]
Output: 4
Explanation:
The possible XOR triplet values include:
* `(0, 0, 0) -> 3 XOR 3 XOR 3 = 3`
* `(0, 0, 1) -> 3 XOR 3 XOR 1 = 1`
* `(0, 0, 2) -> 3 XOR 3 XOR 2 = 2`
* `(0, 1, 2) -> 3 XOR 1 XOR 2 = 0`
The unique XOR values are `{0, 1, 2, 3}`, so the output is 4.
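The explanation above shows only four of the index triplets. As a sketch, `itertools.combinations_with_replacement` over the indices yields exactly the sorted triples i <= j <= k, so it lists every triplet once. The helper name `list_triplets` is hypothetical. For nums = [3,1,2] there are 10 such triplets, and no value outside {0, 1, 2, 3} appears.

```python
from itertools import combinations_with_replacement


def list_triplets(nums: list[int]) -> list[tuple[tuple[int, int, int], int]]:
    """Return every index triplet (i, j, k) with i <= j <= k and its XOR value."""
    # combinations_with_replacement emits non-decreasing index tuples,
    # which is exactly the i <= j <= k condition from the problem.
    return [
        ((i, j, k), nums[i] ^ nums[j] ^ nums[k])
        for i, j, k in combinations_with_replacement(range(len(nums)), 3)
    ]


triplets = list_triplets([3, 1, 2])
for idx, value in triplets:
    print(idx, "->", value)
print(sorted({value for _, value in triplets}))  # [0, 1, 2, 3]
```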

Constraints

  • 1 <= n == nums.length <= 10^5
  • 1 <= nums[i] <= n
  • nums is a permutation of integers from 1 to n.
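Every element is at most n, and XOR never sets a bit above the highest bit of its operands. Any triplet value is therefore below 2^B, where B is the bit length of n. With n <= 10^5, that means B <= 17. The small exhaustive check below is a sketch that confirms this bound for small n. The helper name `xor_bound_holds` is hypothetical.

```python
def xor_bound_holds(n: int) -> bool:
    """Check that a ^ b ^ c < 2**n.bit_length() for all a, b, c in [1, n]."""
    limit = 1 << n.bit_length()
    values = range(1, n + 1)
    return all(a ^ b ^ c < limit for a in values for b in values for c in values)


print(all(xor_bound_holds(n) for n in range(1, 33)))  # True
print((10**5).bit_length())                           # 17
```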


Solution

Method 1 – Bit Length (Closed Form)

Intuition

The indices only need to satisfy i ≤ j ≤ k, so they may repeat. Taking i = j gives nums[i] ^ nums[i] ^ nums[k] = nums[k], so every element of nums is itself a triplet value. XOR is also commutative, so the order of the indices does not matter. The triplet values are exactly the values a ^ b ^ c with a, b, c drawn (with repetition) from nums. Because nums is a permutation of [1, n], the answer depends only on n and not on how the elements are arranged.

Let B be the bit length of n, so 2^(B-1) ≤ n < 2^B. Every element fits in B bits, and XOR never sets a bit above the highest bit of its operands. Every triplet value therefore lies in [0, 2^B), so there are at most 2^B distinct values.

For n ≥ 3, all 2^B of them are reachable. Let h = 2^(B-1), the largest power of two not exceeding n:
  • 0 = 1 ^ 2 ^ 3.
  • Every v in [1, n] is a single element (i = j = k).
  • Every v in (n, 2^B) has bit B-1 set, so v = h ^ w with 1 ≤ w < h. If w ≥ 2, then v = h ^ 1 ^ (w ^ 1). If w = 1, then v = h ^ 2 ^ 3 (here n = h ≥ 4, so 2 and 3 are below h). In both cases the three operands are distinct elements of [1, n].

For n ≤ 2, the only values are the elements themselves. For [1,2], for example, 1 ^ 1 ^ 2 = 2 and 1 ^ 2 ^ 2 = 1, so the answer is n. For n = 3, B = 2 and the values are {0, 1, 2, 3}, which matches Example 2.

Approach

  1. If n ≤ 2, return n.
  2. Otherwise, return 2^B, the smallest power of two strictly greater than n (B is the bit length of n). Only the length of nums matters; its elements are never read.

Code

```cpp
#include <vector>
using std::vector;

class Solution {
public:
    int countTriplets(vector<int>& nums) {
        int n = static_cast<int>(nums.size());
        // With n <= 2 only the elements themselves can be produced.
        if (n <= 2) return n;
        // Double p until it exceeds n: p becomes 2^B.
        int p = 1;
        while (p <= n) p <<= 1;
        // For n >= 3 every value in [0, p) is a triplet value.
        return p;
    }
};
```

```go
func countTriplets(nums []int) int {
    n := len(nums)
    // With n <= 2 only the elements themselves can be produced.
    if n <= 2 {
        return n
    }
    // Double p until it exceeds n: p becomes 2^B.
    p := 1
    for p <= n {
        p <<= 1
    }
    // For n >= 3 every value in [0, p) is a triplet value.
    return p
}
```

```java
class Solution {
    public int countTriplets(int[] nums) {
        int n = nums.length;
        // With n <= 2 only the elements themselves can be produced.
        if (n <= 2) return n;
        // highestOneBit(n) is 2^(B-1); doubling it gives 2^B.
        return Integer.highestOneBit(n) << 1;
    }
}
```

```kotlin
class Solution {
    fun countTriplets(nums: IntArray): Int {
        val n = nums.size
        // With n <= 2 only the elements themselves can be produced.
        if (n <= 2) return n
        // highestOneBit(n) is 2^(B-1); doubling it gives 2^B.
        return Integer.highestOneBit(n) shl 1
    }
}
```

```python
class Solution:
    def countTriplets(self, nums: list[int]) -> int:
        n = len(nums)
        # With n <= 2 only the elements themselves can be produced.
        if n <= 2:
            return n
        # 2^B, where B = n.bit_length(): the smallest power of two above n.
        return 1 << n.bit_length()
```

```rust
impl Solution {
    pub fn count_triplets(nums: Vec<i32>) -> i32 {
        let n = nums.len();
        // With n <= 2 only the elements themselves can be produced.
        if n <= 2 {
            return n as i32;
        }
        // Smallest power of two strictly greater than n, i.e. 2^B.
        (n + 1).next_power_of_two() as i32
    }
}
```

```typescript
class Solution {
    countTriplets(nums: number[]): number {
        const n = nums.length
        // With n <= 2 only the elements themselves can be produced.
        if (n <= 2) return n
        // 32 - Math.clz32(n) is the bit length B of n.
        return 1 << (32 - Math.clz32(n))
    }
}
```

Complexity

  • ⏰ Time complexity: O(1) with a built-in bit-length operation. The loop-based C++ and Go versions take O(log n) iterations, at most 17 for n ≤ 10^5. The elements of nums are never inspected.
  • 🧺 Space complexity: O(1)
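A brute force on small inputs can cross-check the count. The count should equal n for n ≤ 2 and 2^B otherwise, where B is the bit length of n. The sketch below compares the two on shuffled permutations of [1, n] for n = 1..64. The helper names `brute_force_count` and `closed_form` are hypothetical. To keep the brute force cheap, it first collects all pairwise XORs a ^ b and then XORs each one with every element. XOR is commutative and indices may repeat, so this produces the same set as enumerating i ≤ j ≤ k directly.

```python
import random


def brute_force_count(nums: list[int]) -> int:
    """Distinct a ^ b ^ c over elements of nums (same set as i <= j <= k)."""
    pair_xors = {a ^ b for a in nums for b in nums}  # includes 0 from a == b
    return len({p ^ c for p in pair_xors for c in nums})


def closed_form(n: int) -> int:
    """n for n <= 2, otherwise the smallest power of two greater than n."""
    return n if n <= 2 else 1 << n.bit_length()


rng = random.Random(0)
for n in range(1, 65):
    perm = list(range(1, n + 1))
    rng.shuffle(perm)  # the count depends only on the values, not their order
    assert brute_force_count(perm) == closed_form(n), n
print("closed form matches brute force for n = 1..64")
```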