Problem

You are given an array of integers arr.

We want to select three indices i, j and k where (0 <= i < j <= k < arr.length).

Let’s define a and b as follows:

  • a = arr[i] ^ arr[i + 1] ^ ... ^ arr[j - 1]
  • b = arr[j] ^ arr[j + 1] ^ ... ^ arr[k]

Note that ^ denotes the bitwise-xor operation.

Return the number of triplets (i, j, k) where a == b.

Examples

Example 1:

Input: arr = [2,3,1,6,7]
Output: 4
Explanation: The triplets are (0,1,2), (0,2,2), (2,3,4) and (2,4,4)
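As a quick arithmetic check of Example 1, here is each listed triplet evaluated against the definitions of a (XOR of arr[i..j-1]) and b (XOR of arr[j..k]), with ⊕ standing for the ^ operation:

```latex
\begin{aligned}
(0,1,2):&\quad a = 2, & b &= 3 \oplus 1 = 2 \\
(0,2,2):&\quad a = 2 \oplus 3 = 1, & b &= 1 \\
(2,3,4):&\quad a = 1, & b &= 6 \oplus 7 = 1 \\
(2,4,4):&\quad a = 1 \oplus 6 = 7, & b &= 7
\end{aligned}
```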

Example 2:

Input: arr = [1,1,1,1,1]
Output: 10

Solution

We need to count the triplets for which the XOR of the elements from index i to j - 1 equals the XOR of the elements from index j to k.

Method 1 - Brute Force

We run three nested loops to pick i, j and k, and for each triplet a fourth loop recomputes a and b from scratch over arr[i..k]. We count the triplets where a == b.
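A minimal Java sketch of this brute force follows. The class name BruteForceTriplets is hypothetical (not from the original); the innermost loop walks arr[i..k] once and splits it at j, so a and b are rebuilt from scratch for every triplet, which is where the fourth factor of n comes from.

```java
class BruteForceTriplets {

    // O(n^4) brute force: choose i < j <= k, then recompute a and b from scratch.
    static int countTriplets(int[] arr) {
        int n = arr.length;
        int ans = 0;
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                for (int k = j; k < n; k++) {
                    int a = 0; // XOR of arr[i..j-1]
                    int b = 0; // XOR of arr[j..k]
                    // Fourth loop: walk arr[i..k] and send each element to a or b.
                    for (int x = i; x <= k; x++) {
                        if (x < j) {
                            a ^= arr[x];
                        } else {
                            b ^= arr[x];
                        }
                    }
                    if (a == b) {
                        ans++;
                    }
                }
            }
        }
        return ans;
    }

    public static void main(String[] args) {
        System.out.println(countTriplets(new int[] {2, 3, 1, 6, 7})); // prints 4
        System.out.println(countTriplets(new int[] {1, 1, 1, 1, 1})); // prints 10
    }
}
```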

Complexity

  • ⏰ Time complexity: O(n^4)
  • 🧺 Space complexity: O(1)

Method 2 - Improve brute force with Prefix XOR

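Instead of recomputing a and b for every triplet, we keep them as running XORs: when j moves right, a picks up arr[j - 1], and when k moves right, b picks up arr[k]. That leaves three nested loops, one each for i, j and k.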

Code

Java
class Solution {

	public int countTriplets(int[] nums) {
		int ans = 0;

		for (int i = 0; i < nums.length - 1; i++) {
			int a = 0; // XOR for every subarray starting from i (inclusive) to j-1
			for (int j = i + 1; j < nums.length; j++) {
				a ^= nums[j - 1];
				int b = 0; // XOR for every subarray starting from j (inclusive) to k
				for (int k = j; k < nums.length; k++) {
					b ^= nums[k];
					if (a == b) {
						ans++;
					}
				}
			}
		}

		return ans;
	}
}

Complexity

  • ⏰ Time complexity: O(n^3)
  • 🧺 Space complexity: O(1)

Method 3 - Optimized Prefix XOR

We can go further by rewriting the condition a == b itself. Recall that a is the XOR of arr[i..j-1] and b is the XOR of arr[j..k]:

a = arr[i] ^ arr[i + 1] ^ ... ^ arr[j - 1]
b = arr[j] ^ arr[j + 1] ^ ... ^ arr[k]

Suppose a == b for some triplet (i, j, k).

XOR both sides with a:

a = b
a ^ a = b ^ a
0 = b ^ a
arr[i] ^ arr[i + 1] ^ ... ^ arr[j - 1] ^ arr[j] ^ arr[j + 1] ^ ... ^ arr[k] = 0

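Hence, the XOR of arr[i..k] is 0. The converse also holds: if the XOR of arr[i..k] is 0, then a ^ b = 0, so a == b for every j with i < j <= k. Define prefix[0] = 0 and prefix[x] = arr[0] ^ ... ^ arr[x - 1]; the XOR of arr[i..k] is then prefix[i] ^ prefix[k + 1], which is 0 exactly when prefix[i] == prefix[k + 1]. So we only need to count the pairs of equal prefix values.

We can therefore build the prefix array first, then check every pair of prefix indices with two nested loops.

Once the pair (i, k) is fixed, j can be any index with i < j <= k, which gives k - i choices. In the code below, the loop variables i and j are prefix indices, with the code's j playing the role of k + 1, so the same count is written as j - i - 1 and added to ans.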
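As a worked check against Example 1 (arr = [2,3,1,6,7]), with prefix[x] being the XOR of the first x elements, only two pairs of prefix indices hold equal values. The pair (0, 3) corresponds to i = 0, k = 2 and yields the triplets (0,1,2) and (0,2,2); the pair (2, 5) corresponds to i = 2, k = 4 and yields (2,3,4) and (2,4,4):

```latex
\begin{aligned}
\text{prefix} &= [0,\ 2,\ 1,\ 0,\ 6,\ 1] \\
\text{prefix}[0] = \text{prefix}[3] = 0 &\;\Rightarrow\; 3 - 0 - 1 = 2 \\
\text{prefix}[2] = \text{prefix}[5] = 1 &\;\Rightarrow\; 5 - 2 - 1 = 2 \\
\text{ans} &= 2 + 2 = 4
\end{aligned}
```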

Code

Java
class Solution {

	public int countTriplets(int[] arr) {
		int n = arr.length + 1, ans = 0;
		int[] prefix = new int[n];

		for (int i = 1; i < n; ++i) {
			prefix[i] = arr[i - 1] ^ prefix[i - 1];
		}

		for (int i = 0; i < n; ++i) {
			for (int j = i + 1; j < n; ++j) {
				if (prefix[i] == prefix[j]) {
					ans += j - i - 1;
				}
			}
		}

		return ans;
	}
}

Complexity

  • ⏰ Time complexity: O(n^2)
  • 🧺 Space complexity: O(n)
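The pair-counting loop in Method 3 is still quadratic. One possible further refinement, which is not part of the original write-up, replaces it with a single pass and two hash maps: for each prefix value, track how many times it has appeared and the sum of the prefix indices where it appeared. For a new prefix index q, every earlier index p with the same value contributes q - p - 1, and summing that over c earlier indices gives c * (q - 1) minus their index sum. The class name LinearTriplets is hypothetical; treat this as a sketch.

```java
import java.util.HashMap;
import java.util.Map;

class LinearTriplets {

    // Single pass over prefix XORs; same counting rule as Method 3.
    static int countTriplets(int[] arr) {
        Map<Integer, Integer> count = new HashMap<>();    // prefix value -> occurrences
        Map<Integer, Integer> indexSum = new HashMap<>(); // prefix value -> sum of prefix indices
        // prefix[0] = 0 occurs at prefix index 0.
        count.put(0, 1);
        indexSum.put(0, 0);

        int prefix = 0;
        int ans = 0;
        for (int q = 1; q <= arr.length; q++) {
            prefix ^= arr[q - 1]; // prefix now equals prefix[q]
            int c = count.getOrDefault(prefix, 0);
            int s = indexSum.getOrDefault(prefix, 0);
            // Sum of (q - p - 1) over all earlier p with prefix[p] == prefix[q].
            ans += c * (q - 1) - s;
            count.put(prefix, c + 1);
            indexSum.put(prefix, s + q);
        }
        return ans;
    }
}
```

This should run in O(n) expected time with O(n) extra space, and it returns 4 for [2,3,1,6,7] and 10 for [1,1,1,1,1], matching the examples.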