PROBLEM LINK:
Author: Kamil Debowski
Primary Tester: Marek Sokołowski
Secondary Tester: Praveen Dhinwa
Editorialist: Praveen Dhinwa
DIFFICULTY:
medium
PREREQUISITES:
segment tree, understanding of bitwise xor operations
PROBLEM:
Given an array a containing n integers between 0 and 50, find the xor of the sums of all nonempty segments of the array.
In other words, for each of the \binom{n}{2} + n nonempty subarrays (contiguous subsequences), find the sum of its elements, and print the xor of those sums.
Let’s start with the brute-force solution, which won’t pass.
Let a[i, j] denote the subarray a_i, a_{i+1}, \dots, a_j. We iterate over all subarrays a[i, j], find each sum in constant time using precomputed prefix sums, and xor these sums together. This approach has a time complexity of \mathcal{O}(n^2), which has no chance of passing as n can be as large as 3 \cdot 10^5.
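As a baseline for checking faster solutions on small inputs, here is a minimal Python sketch of this brute force. The function name is our own, and the array is a 0-based Python list rather than the 1-based a of the text.

```python
from itertools import accumulate


def xor_of_subarray_sums_bruteforce(a):
    """O(n^2) reference: xor of the sums of all nonempty subarrays of a."""
    # prefix[k] = a[0] + ... + a[k-1], with prefix[0] = 0
    prefix = [0] + list(accumulate(a))
    n = len(a)
    result = 0
    for i in range(n):
        for j in range(i, n):
            # sum of a[i..j] (0-based, inclusive) in O(1)
            result ^= prefix[j + 1] - prefix[i]
    return result


print(xor_of_subarray_sums_bruteforce([1, 2, 4]))  # 5
```

For [1, 2, 4] the subarray sums are 1, 2, 4, 3, 6, 7, and their xor is 5.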
An important observation
Claim: For a number x, its bth bit (zero-based, counted from the least significant bit) is set if and only if x \text{ mod } {2^{b+1}} \geq 2^b.
Proof: Taking x \text{ mod } 2^{b+1} keeps only the lowest b+1 bits of x: bits 0 to b are unchanged and every bit above position b becomes zero. So the bth bit of x equals the most significant bit of this (b+1)-bit number. The (b+1)-bit numbers with their bth bit set are exactly those in [2^b, 2^{b+1} - 1]: the smallest is 2^b (only the bth bit set) and the largest is 2^{b+1} - 1. Since x \text{ mod } 2^{b+1} is always less than 2^{b+1}, its bth bit is set if and only if x \text{ mod } {2^{b+1}} \geq 2^b. For example, with x = 13 = 1101_2 and b = 2, we get 13 \text{ mod } 8 = 5 \geq 4, and indeed bit 2 of 13 is set.
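The claim can be checked exhaustively for small values with a short Python sketch; the helper name bit_is_set is our own and uses only the modulo criterion, which is then compared against the usual shift-and-mask test.

```python
def bit_is_set(x, b):
    """Test bit b of x using only the criterion x mod 2^(b+1) >= 2^b."""
    return x % (1 << (b + 1)) >= (1 << b)


# Compare against the shift-and-mask test for all small x and b.
for x in range(1024):
    for b in range(11):
        assert bit_is_set(x, b) == bool((x >> b) & 1)

print(bit_is_set(13, 2))  # True, since 13 = 1101 in binary
```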
Think in terms of finding the bits of the answer
When dealing with bitwise operations, it is generally a good idea to compute the answer one bit at a time. We want to check whether the bth bit of the answer is set (1) or not. If the number of subarrays whose sum has its bth bit set is odd, then the bth bit of the answer is 1; otherwise it is 0. This follows from the properties of xor: it acts on each bit independently, equal bits cancel out (their xor is zero), so the xor of a sequence of bits is 1 exactly when it contains an odd number of 1's.
So, if we can efficiently count the subarrays whose sum has its bth bit set, for every b, we can construct the answer.
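The reduction can be illustrated with a Python sketch that rebuilds the answer bit by bit from parities. The counting is still brute force (every subarray sum is materialized), so this only demonstrates that the per-bit parities determine the xor; the function name is our own.

```python
from itertools import accumulate


def xor_via_bit_parities(a):
    """Rebuild the xor of all subarray sums one bit at a time.

    Bit b of the answer is 1 iff an odd number of subarray sums have
    bit b set. Counting here is O(n^2) per bit; illustration only.
    """
    prefix = [0] + list(accumulate(a))
    n = len(a)
    sums = [prefix[j + 1] - prefix[i] for i in range(n) for j in range(i, n)]
    total = prefix[-1]  # S, the largest possible subarray sum
    answer = 0
    b = 0
    while (1 << b) <= total:
        count = sum(1 for x in sums if (x >> b) & 1)
        if count % 2 == 1:
            answer |= 1 << b
        b += 1
    return answer


print(xor_via_bit_parities([1, 2, 4]))  # 5
```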
Solving the reduced version of the problem
Let s be the prefix sum array of a, taken modulo 2^{b+1}. Assume a uses 1-based indexing and s uses 0-based indexing. We define s_0 = 0 and s_i = (a_1 + a_2 + \dots + a_i) \text{ mod } 2^{b + 1}. Then the sum of the subarray a[j, i] modulo 2^{b+1} is (s_i - s_{j-1}) \pmod {2^{b+1}}.
Counting the subarrays whose sum has its bth bit set is therefore equivalent to counting the pairs (j, i) with 1 \leq j \leq i \leq n such that (s_i - s_{j-1}) \pmod{2^{b+1}} \geq 2^b.
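A direct Python sketch of this reformulation checks every pair (j, i) against the modular condition on s. It is still \mathcal{O}(n^2) and only serves to confirm the equivalence; the function name is our own, and the Python list a holds a_1, \dots, a_n at indices 0 to n-1.

```python
def count_pairs_with_bit(a, b):
    """Count pairs (j, i), 1 <= j <= i <= n, with
    (s_i - s_{j-1}) mod 2^(b+1) >= 2^b, where s[0] = 0 and s is the
    prefix-sum array of a taken modulo 2^(b+1)."""
    mod = 1 << (b + 1)
    s = [0]
    for x in a:
        s.append((s[-1] + x) % mod)
    n = len(a)
    count = 0
    for i in range(1, n + 1):
        for j in range(1, i + 1):
            # Python's % already returns a value in [0, mod) for negatives
            if (s[i] - s[j - 1]) % mod >= (1 << b):
                count += 1
    return count


print([count_pairs_with_bit([1, 2, 4], b) for b in range(3)])  # [3, 4, 3]
```

The parities of [3, 4, 3] are 1, 0, 1, which gives 101_2 = 5, matching the xor of the subarray sums of [1, 2, 4].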
We need to handle the following two cases.

Case 1: s_i \geq s_{j-1}. Then (s_i - s_{j-1}) \pmod{2^{b+1}} = s_i - s_{j-1}. We want s_i - s_{j-1} \geq 2^b, i.e. s_{j-1} \leq s_i - 2^b. Since s_{j-1} \geq 0, we get 0 \leq s_{j-1} \leq s_i - 2^b, i.e. s_{j-1} \in [0, s_i - 2^b]. This range is empty when s_i < 2^b.

Case 2: s_i < s_{j-1}. Then (s_i - s_{j-1}) \pmod{2^{b+1}} = s_i - s_{j-1} + 2^{b+1}.
We want s_i - s_{j-1} + 2^{b+1} \geq 2^b. Rearranging gives s_{j-1} \leq s_i + 2^b. Combined with s_i < s_{j-1}, we get s_i < s_{j-1} \leq s_i + 2^b, i.e. s_{j-1} \in [s_i + 1, s_i + 2^b].
Notice that the ranges [0, s_i - 2^b] and [s_i + 1, s_i + 2^b] are disjoint. Hence the number of valid j's for index i is the number of values among s_0, \dots, s_{i-1} that lie in the first range plus the number that lie in the second.
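The case analysis can be verified exhaustively for small b with a Python sketch: counting values in the two disjoint ranges agrees with testing the modular condition directly. The helper name count_in_ranges is our own.

```python
def count_in_ranges(prev, si, b):
    """Count values v in prev lying in [0, si - 2^b] or [si + 1, si + 2^b].

    prev holds s_0, ..., s_{i-1} (already reduced modulo 2^(b+1)) and
    si is s_i. The two ranges are disjoint, so their counts simply add.
    """
    half = 1 << b
    low = sum(1 for v in prev if 0 <= v <= si - half)        # case s_{j-1} <= s_i
    high = sum(1 for v in prev if si + 1 <= v <= si + half)  # case s_{j-1} > s_i
    return low + high


# Exhaustive check against the direct modular condition for small b.
for b in range(4):
    mod = 1 << (b + 1)
    for si in range(mod):
        for v in range(mod):
            direct = (si - v) % mod >= (1 << b)
            assert count_in_ranges([v], si, b) == int(direct)

print(count_in_ranges([0, 1, 3], 3, 1))  # 2
```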
Segment tree to answer the range queries fast
To find the number of valid j's for a given i, we need to answer queries of the form: how many elements of s[0, i-1] (only indices 0 to i-1 are considered) have values in a given range? We can answer such queries using a segment tree.
The elements of s lie between 0 and \max_i s_i, which is at most a_1 + a_2 + \dots + a_n. Let us denote a_1 + a_2 + \dots + a_n by S. Assume we maintain an array cnt of size S+1, where cnt_x is the number of times the value x appears in s[0, i-1]. The number of elements of s[0, i-1] with values in the range [L, R] is then cnt_L + cnt_{L+1} + \dots + cnt_R. When moving from i to i+1, we increment cnt_{s_i} by 1. Effectively, we need to support the following two operations on an array of size S+1:
- Find the sum of the elements in the range [L, R] of the array, where 0 \leq L \leq R \leq S.
- Increment the xth element by 1.
Both operations can be performed in \mathcal{O}(\log{S}) time with a segment tree in which each node stores the sum of the elements of the range it is responsible for.
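A minimal Python sketch of such a segment tree follows, written in the iterative bottom-up style (one of several common layouts, chosen here for brevity). The class and method names are our own; add(pos, delta) performs the point update and query(lo, hi) returns the inclusive range sum, returning 0 for an empty range so callers need not special-case it.

```python
class SegmentTree:
    """Bottom-up segment tree of counts over positions 0..size-1.

    add(pos, delta) updates one position and query(lo, hi) sums an
    inclusive range, each in O(log size).
    """

    def __init__(self, size):
        self.n = size
        self.tree = [0] * (2 * size)  # leaves live at tree[size:2*size]

    def add(self, pos, delta=1):
        # update the leaf and every ancestor on the path to the root
        pos += self.n
        while pos >= 1:
            self.tree[pos] += delta
            pos //= 2

    def query(self, lo, hi):
        """Sum of counts at positions lo..hi inclusive; 0 if lo > hi."""
        if lo > hi:
            return 0
        res = 0
        lo += self.n
        hi += self.n + 1  # switch to the half-open interval [lo, hi)
        while lo < hi:
            if lo & 1:
                res += self.tree[lo]
                lo += 1
            if hi & 1:
                hi -= 1
                res += self.tree[hi]
            lo //= 2
            hi //= 2
        return res


t = SegmentTree(8)
for v in [0, 1, 3, 3]:
    t.add(v)
print(t.query(0, 1), t.query(2, 7), t.query(4, 3))  # 2 2 0
```

Removing a value is add(v, -1), which is how the tree can be emptied between bits.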
Finally the solution with its time complexity
So, for a fixed b, we can count the valid j's for all i's in \mathcal{O}(n \cdot \log{S}) total, i.e. count the subarrays a[j, i] whose sum has its bth bit set. We do this for each bit of the answer (\mathcal{O}(\log{S}) bits), so the overall time complexity is \mathcal{O}(n \cdot {\log{S}}^2). The memory complexity is \mathcal{O}(S), required for the segment tree over S+1 elements.
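For a sense of scale, assuming the bounds stated above (n up to 3 \cdot 10^5 and every a_i at most 50), the arithmetic works out roughly as follows; this ignores the segment tree's constant factor.

```latex
S \le 50 \cdot 3 \cdot 10^5 = 1.5 \cdot 10^7 < 2^{24} = 16\,777\,216
\;\Rightarrow\; \text{at most } 24 \text{ bits } (b = 0, \dots, 23),
\qquad
n \cdot (\log_2 S)^2 \approx 3 \cdot 10^5 \cdot 24^2 \approx 1.7 \cdot 10^8 .
```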
Pseudo Code
ans = 0;
Let tree be a segment tree over the values 0..S that maintains the number of occurrences of the elements of s[0..i-1] while we process the ith number. The add, remove and query operations are performed on it.
for b = 0; 2^b <= S; b++:
    s[0] = 0;
    for i = 1 to n:
        s[i] = (s[i - 1] + a[i]) mod 2^(b+1);
    count = 0;
    add(s[0]);
    for i = 1 to n:
        // s[j-1] > s[i] case: s[j-1] lies in [s[i] + 1, s[i] + 2^b].
        {
            L = s[i] + 1;
            R = s[i] + 2^b;
            if R > S:
                R = S;
            if L <= R:
                count += query(L, R);
        }
        // s[j-1] <= s[i] case: s[j-1] lies in [0, s[i] - 2^b].
        {
            R = s[i] - 2^b;
            if R >= 0:
                count += query(0, R);
        }
        add(s[i]);
    // if count is odd, then the bth bit of the answer is set.
    if (count % 2 == 1):
        ans |= (1 << b);
    // remove the elements added while processing bit b, so that the segment tree is empty again for bit b+1.
    for i = 0 to n:
        remove(s[i]);
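A runnable Python translation of this pseudocode is sketched below. It inlines a bottom-up segment tree over the values 0..S; the helpers add and query mirror the pseudocode operations (remove is add with delta -1), and all Python names are our own. At the full constraints Python would likely be too slow, so treat this as a way to check the logic on small inputs against the brute force.

```python
def xor_of_subarray_sums(a):
    """O(n log^2 S) sketch of the pseudocode; a is a 0-based list of a_1..a_n."""
    n = len(a)
    S = sum(a)
    size = S + 1
    tree = [0] * (2 * size)  # bottom-up segment tree over values 0..S

    def add(pos, delta):
        # point update: change the count of value pos by delta
        pos += size
        while pos >= 1:
            tree[pos] += delta
            pos //= 2

    def query(lo, hi):
        # number of stored values in [lo, hi]; 0 for an empty range
        if lo > hi:
            return 0
        res = 0
        lo += size
        hi += size + 1
        while lo < hi:
            if lo & 1:
                res += tree[lo]
                lo += 1
            if hi & 1:
                hi -= 1
                res += tree[hi]
            lo //= 2
            hi //= 2
        return res

    ans = 0
    b = 0
    while (1 << b) <= S:
        mod, half = 1 << (b + 1), 1 << b
        s = [0] * (n + 1)
        for i in range(1, n + 1):
            s[i] = (s[i - 1] + a[i - 1]) % mod
        count = 0
        add(s[0], 1)
        for i in range(1, n + 1):
            # case s[j-1] > s[i]: s[j-1] in [s[i] + 1, s[i] + 2^b], clipped to S
            count += query(s[i] + 1, min(s[i] + half, S))
            # case s[j-1] <= s[i]: s[j-1] in [0, s[i] - 2^b]
            count += query(0, s[i] - half)
            add(s[i], 1)
        if count % 2 == 1:
            ans |= 1 << b
        # undo this round's insertions so the tree is empty for bit b + 1
        for v in s:
            add(v, -1)
        b += 1
    return ans


print(xor_of_subarray_sums([3, 5]))  # 14
```

For [3, 5] the subarray sums are 3, 5 and 8, whose xor is 14.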