Given three integer arrays a, b, and c, return the number of triplets
(a[i], b[j], c[k]), such that the bitwise XOR between the elements of each triplet has an even number of set bits.
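The key observation is that the parity of the number of set bits is preserved under XOR: popcount(x ^ y) has the same parity as popcount(x) + popcount(y), because every bit position set in both x and y removes exactly two set bits. Hence a[i] ^ b[j] ^ c[k] has an even number of set bits exactly when parity(a[i]) ^ parity(b[j]) ^ parity(c[k]) == 0, where parity(x) = popcount(x) mod 2. So we only need to count, for each array, how many elements have even and odd parity, and then sum cntA[pa] * cntB[pb] * cntC[pc] over the parity combinations (pa, pb, pc) with pa ^ pb ^ pc == 0. This takes O(|a| + |b| + |c|) time and O(1) extra space.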
// Counts triplets (a[i], b[j], c[k]) whose XOR has an even number of set bits.
class Solution {
public:
    // Returns the number of index triplets (i, j, k) such that
    // a[i] ^ b[j] ^ c[k] has an even popcount.
    //
    // Relies on popcount(x ^ y) having the same parity as
    // popcount(x) + popcount(y), so only the per-array parity counts matter.
    long long countTriplets(std::vector<int>& a, std::vector<int>& b, std::vector<int>& c) {
        // Parity (0 = even, 1 = odd) of the number of set bits in x.
        auto parity = [](int x) {
            return __builtin_popcount(static_cast<unsigned>(x)) % 2;
        };

        // cntX[p] = how many elements of X have set-bit parity p.
        long long cntA[2] = {}, cntB[2] = {}, cntC[2] = {};
        for (int x : a) cntA[parity(x)]++;
        for (int x : b) cntB[parity(x)]++;
        for (int x : c) cntC[parity(x)]++;

        long long ans = 0;
        for (int pa = 0; pa < 2; ++pa) {
            for (int pb = 0; pb < 2; ++pb) {
                for (int pc = 0; pc < 2; ++pc) {
                    // The XOR of the three values has even parity iff the
                    // three parities XOR to zero.
                    if ((pa ^ pb ^ pc) == 0) {
                        ans += cntA[pa] * cntB[pb] * cntC[pc];
                    }
                }
            }
        }
        return ans;
    }
};
/**
 * Counts triplets (a[i], b[j], c[k]) whose XOR has an even number of set bits.
 */
class Solution {
    /**
     * Returns the number of index triplets (i, j, k) such that
     * {@code a[i] ^ b[j] ^ c[k]} has an even bit count.
     *
     * <p>Relies on bitCount(x ^ y) having the same parity as
     * bitCount(x) + bitCount(y), so only per-array parity counts are needed.
     *
     * @param a first array
     * @param b second array
     * @param c third array
     * @return number of qualifying triplets
     */
    public long countTriplets(int[] a, int[] b, int[] c) {
        // cntX[p] = how many elements of X have set-bit parity p (0 even, 1 odd).
        long[] cntA = new long[2], cntB = new long[2], cntC = new long[2];
        for (int x : a) cntA[Integer.bitCount(x) % 2]++;
        for (int x : b) cntB[Integer.bitCount(x) % 2]++;
        for (int x : c) cntC[Integer.bitCount(x) % 2]++;

        long ans = 0;
        for (int pa = 0; pa < 2; pa++) {
            for (int pb = 0; pb < 2; pb++) {
                for (int pc = 0; pc < 2; pc++) {
                    // Even overall parity iff the three parities XOR to zero.
                    if ((pa ^ pb ^ pc) == 0) {
                        ans += cntA[pa] * cntB[pb] * cntC[pc];
                    }
                }
            }
        }
        return ans;
    }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
/**
 * Counts triplets (a[i], b[j], c[k]) whose XOR has an even number of set bits.
 */
class Solution {
    /**
     * Returns the number of index triplets (i, j, k) such that
     * `a[i] xor b[j] xor c[k]` has an even bit count.
     *
     * Relies on bitCount(x xor y) having the same parity as
     * bitCount(x) + bitCount(y), so only per-array parity counts are needed.
     */
    fun countTriplets(a: IntArray, b: IntArray, c: IntArray): Long {
        // cntX[p] = how many elements of X have set-bit parity p (0 even, 1 odd).
        val cntA = LongArray(2)
        val cntB = LongArray(2)
        val cntC = LongArray(2)
        for (x in a) cntA[Integer.bitCount(x) % 2]++
        for (x in b) cntB[Integer.bitCount(x) % 2]++
        for (x in c) cntC[Integer.bitCount(x) % 2]++

        var ans = 0L
        for (pa in 0..1) {
            for (pb in 0..1) {
                for (pc in 0..1) {
                    // Even overall parity iff the three parities XOR to zero.
                    if ((pa xor pb xor pc) == 0) {
                        ans += cntA[pa] * cntB[pb] * cntC[pc]
                    }
                }
            }
        }
        return ans
    }
}
class Solution:
    """Counts triplets (a[i], b[j], c[k]) whose XOR has an even number of set bits."""

    def countTriplets(self, a: list[int], b: list[int], c: list[int]) -> int:
        """Return the number of index triplets (i, j, k) such that
        a[i] ^ b[j] ^ c[k] has an even number of set bits.

        Relies on popcount(x ^ y) having the same parity as
        popcount(x) + popcount(y), so only per-array parity counts are needed.
        """
        def parity(x: int) -> int:
            # Parity (0 even, 1 odd) of the count of '1' digits in bin(x).
            # NOTE(review): for negative x, bin() shows the magnitude with a
            # leading '-', so inputs are presumably non-negative; confirm.
            return bin(x).count('1') % 2

        # cntX[p] = how many elements of X have set-bit parity p.
        cntA = [0, 0]
        cntB = [0, 0]
        cntC = [0, 0]
        for x in a:
            cntA[parity(x)] += 1
        for x in b:
            cntB[parity(x)] += 1
        for x in c:
            cntC[parity(x)] += 1

        ans = 0
        for pa in range(2):
            for pb in range(2):
                for pc in range(2):
                    # Even overall parity iff the three parities XOR to zero.
                    if (pa ^ pb ^ pc) == 0:
                        ans += cntA[pa] * cntB[pb] * cntC[pc]
        return ans