PROBLEM LINK:
Author: Kamil Debowski
Primary Tester: Marek Sokołowski
Secondary Tester: Praveen Dhinwa
Editorialist: Praveen Dhinwa
DIFFICULTY:
medium
PREREQUISITES:
segment tree, understanding of bit wise xor operations
PROBLEM:
Given an array a of n integers, each between 0 and 50, find the xor of the sums of all non-empty segments of the array.
In other words, for each of \binom{n}{2} + n non-empty subarrays (consecutive subsequences) find its sum of elements, and print the xor of those sums.
Let’s start with the bruteforce solution that won’t pass.
Let a[i, j] denote the subarray a_i, a_{i+1}, \dots, a_j. We iterate over all subarrays a[i, j] and find their sum in constant time by precomputing prefix sums, and thus compute the xor of these sums. This approach has a time complexity of \mathcal{O}(n^2) which has no chance of passing as n can be as large as 3 \cdot 10^5.
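For reference, the bruteforce described above can be sketched as follows. This is a minimal illustration, not the setter's code; the function name is made up for this example, and the array is 0-indexed here rather than 1-indexed.

```python
from itertools import accumulate


def xor_of_subarray_sums_bruteforce(a):
    """XOR of the sums of all non-empty subarrays, in O(n^2) time."""
    # prefix[i] = a[0] + ... + a[i-1], so sum(a[i..j]) = prefix[j+1] - prefix[i]
    prefix = [0] + list(accumulate(a))
    result = 0
    for i in range(len(a)):
        for j in range(i, len(a)):
            result ^= prefix[j + 1] - prefix[i]
    return result
```

It is too slow for the real constraints, but it is handy for checking a faster solution on small random arrays.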
An important observation
Claim: For a number x, its b-th bit (zero-based indexing from the least significant bit to the most significant bit) is set if and only if x \text{ mod } {2^{b+1}} \geq 2^b.
Proof: Taking mod 2^{b+1} for number x leaves only the last b+1 bits of x, i.e. the i-th bit of the resulting number for i>b will be zero. Now, we want to check whether the most significant bit of the resulting b+1 bit number is 1 or not. Consider the number whose only set bit is b-th bit. The number will be 2^b. This is the least b+1 bit number with its b-th bit set. Thus, the result of x \text{ mod } {2^{b+1}} should be greater than or equal to 2^b in order to have its b-th bit set.
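The claim is easy to check mechanically. The helper below is a small sketch (its name is hypothetical) that tests the b-th bit using only the modulo comparison from the claim, so it can be compared against the usual shift-and-mask test.

```python
def bit_is_set_via_mod(x, b):
    """True iff bit b of the non-negative integer x is set, using x mod 2^(b+1) >= 2^b."""
    return x % (1 << (b + 1)) >= (1 << b)
```

For example, x = 6 (binary 110) and b = 1 gives 6 mod 4 = 2 \geq 2, so bit 1 is set.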
Think in terms of finding the bits of the answer
When dealing with bitwise operations, it is generally a good idea to find the bits of the answer one by one. We want to check whether the b-th bit of the answer is set (1) or not. If the number of subarrays with the b-th bit set in their sum is odd, then the b-th bit of the answer will be 1, and 0 otherwise. This follows from the properties of xor: the xor of two equal bits is zero, so an even number of 1's gives 0 and an odd number of 1's gives 1.
So, it means that if we can efficiently find the number of subarrays with b-th bit set in their sum, then we can find the answer.
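The parity argument can be illustrated on an explicit list of values. This sketch (the function name is an assumption made for the example) rebuilds the xor of a list bit by bit, counting how many values have each bit set.

```python
def xor_from_bit_parities(values):
    """XOR of non-negative integers, computed one bit at a time via parity of set bits."""
    result = 0
    top = max(values, default=0)
    b = 0
    while (1 << b) <= top:
        ones = sum((v >> b) & 1 for v in values)  # how many values have bit b set
        if ones % 2 == 1:
            result |= 1 << b  # odd count means bit b survives the xor
        b += 1
    return result
```

In the actual solution the list of subarray sums is never materialized; only the count for each bit is computed.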
Solving the reduced version of the problem
Let s be the prefix sum array for array a, modulo 2^{b+1}. Assume a has 1-based indexing and s 0-based. We define s_0 = 0 and s_i = (a_1 + a_2 + \dots + a_i) \text{ mod } 2^{b + 1}. You can see that sum of subarray a[j, i] modulo 2^{b+1} will be (s_i - s_{j-1}) \pmod {2^{b+1}}.
Finding total number of subarrays with b-th bit set in their sum will be equivalent to finding number of pairs (j, i) where 1 \leq j \leq i \leq n, such that (s_i - s_{j-1}) \pmod{2^{b+1}} \geq 2^b.
We need to take care of the following two cases.
- s_i \geq s_{j-1}. We have (s_i - s_{j-1}) \pmod{2^{b+1}} = s_i - s_{j-1}. We want s_i - s_{j-1} \geq 2^b. It means s_{j-1} \leq s_i - 2^b. We know that s_{j-1} \geq 0. Combining these, we get 0 \leq s_{j-1} \leq s_i - 2^b, i.e. s_{j-1} \in [0, s_i - 2^b]. This range is empty when s_i < 2^b.
- s_i < s_{j-1}. We have (s_i - s_{j-1}) \pmod{2^{b+1}} = s_i - s_{j-1} + 2^{b+1}. We want s_i - s_{j-1} + 2^{b+1} \geq 2^b. After rearranging the expression, we have s_{j-1} \leq s_i + 2^b. We also know that s_i < s_{j-1}. So, overall s_i < s_{j-1} \leq s_i + 2^b, i.e. s_{j-1} \in [s_i + 1, s_i + 2^b].
Notice that the ranges [0, s_i - 2^b] and [s_i + 1, s_i + 2^b] are disjoint. This means that if we are able to count the s_{j-1}'s that fall in each of these ranges, their total is the number of valid j's for the index i.
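The two-range characterization can be sanity-checked against the direct definition. In the sketch below, both function names are made up for illustration, and s is assumed to already hold the prefix sums reduced modulo 2^{b+1} (with s[0] = 0). Both functions are quadratic; they only verify the case analysis and are not the fast solution.

```python
def count_pairs_direct(s, b):
    """Pairs (j, i), j < i, with (s[i] - s[j]) mod 2^(b+1) >= 2^b, straight from the definition."""
    modulus = 1 << (b + 1)
    half = 1 << b
    return sum(
        1
        for i in range(1, len(s))
        for j in range(i)
        if (s[i] - s[j]) % modulus >= half
    )


def count_pairs_by_ranges(s, b):
    """Same count, using the two disjoint ranges derived from the case analysis."""
    half = 1 << b
    total = 0
    for i in range(1, len(s)):
        for prev in s[:i]:
            in_low = 0 <= prev <= s[i] - half           # case s[i] >= s[j-1]
            in_high = s[i] + 1 <= prev <= s[i] + half   # case s[i] <  s[j-1]
            total += int(in_low or in_high)
    return total
```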
Segment tree to help answering the range queries fast
To find the number of valid j's for a given i, we need to answer queries of the form: how many elements of the array s, considering only indices 0 to i-1, lie in a particular range of values? We can answer such queries using a segment tree.
The elements of s lie between 0 and max_{s_i}, where max_{s_i} can be at most a_1 + a_2 + \dots + a_n. Let us denote a_1 + a_2 + \dots + a_n by S. Assume we have an array cnt of size S+1, where cnt_x denotes the number of times the value x has appeared in the array s[0, i-1]. Counting the elements of s[0, i-1] whose values lie in the range [L, R] is the same as finding the sum cnt_L + cnt_{L+1} + \dots + cnt_R. When you go from i to i+1, you need to increment cnt_{s_{i}} by 1. Effectively, you need to maintain these two operations over an array of size S+1.
- Find sum of elements of the range [L, R] of the array, where L \leq R \leq S.
- Increment the i-th element by 1.
Both of these operations can be performed in \mathcal{O}(\log{S}) time using a segment tree, where each node stores the sum of the elements in the range it is responsible for.
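A minimal sketch of such a structure is shown below, assuming a bottom-up (iterative) array layout; the class and method names are illustrative and not taken from the setter's or tester's code. It supports a point update by a delta and an inclusive range-sum query over cnt.

```python
class CountSegmentTree:
    """Sum segment tree over cnt[0..size-1], stored bottom-up in a flat list."""

    def __init__(self, size):
        self.size = size
        self.tree = [0] * (2 * size)  # leaves live at indices size .. 2*size-1

    def add(self, pos, delta=1):
        """cnt[pos] += delta, updating every ancestor of the leaf."""
        i = pos + self.size
        while i >= 1:
            self.tree[i] += delta
            i //= 2

    def query(self, left, right):
        """Return cnt[left] + ... + cnt[right] (inclusive); 0 for an empty range."""
        if left > right:
            return 0
        lo = left + self.size
        hi = right + self.size + 1  # half-open upper bound
        total = 0
        while lo < hi:
            if lo & 1:
                total += self.tree[lo]
                lo += 1
            if hi & 1:
                hi -= 1
                total += self.tree[hi]
            lo //= 2
            hi //= 2
        return total
```

Calling add(x, -1) plays the role of a removal, which is one way to reset the tree between bits.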
Finally the solution with its time complexity
So, we can now count the number of valid j's for all the i's in \mathcal{O}(n \cdot \log{S}), i.e. we can find the number of subarrays a[j, i] with the b-th bit set in their sum. We do this for each bit of the answer (\log{S} bits). Hence, the time complexity of this approach is \mathcal{O}(n \cdot \log^2{S}). The memory complexity is \mathcal{O}(S), which is what the segment tree needs to operate over S elements.
Pseudo Code
ans = 0;
// tree is a segment tree over the values 0..S. When we are at the i-th number, it stores the number of occurrences of each value among s[0..i-1]. The add, remove and query operations below act on it.
for b = 0; 2^b <= S; b++:
    s[0] = 0;
    for i = 1 to n:
        s[i] = (s[i - 1] + a[i]) mod 2^(b+1)
    count = 0;
    add(s[0]);
    for i = 1 to n:
        // s[j-1] > s[i] case.
        {
            L = s[i] + 1;
            R = s[i] + 2^b;
            if R >= S:
                R = S;
            // skip the query when the clipped range is empty
            if L <= R:
                count += query(L, R);
        }
        // s[j-1] <= s[i] case.
        {
            R = s[i] - 2^b;
            if R >= 0:
                count += query(0, R);
        }
        add(s[i]);
    // if count is odd, then b-th bit is set.
    if (count % 2 == 1):
        ans |= (1 << b);
    // remove the elements that were added while processing the b-th bit, so that the segment tree is fresh for the (b+1)-th bit.
    for i = 0 to n:
        remove(s[i]);
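The pseudocode can be turned into a runnable sketch along the following lines. Two deliberate deviations, both assumptions of this example rather than the editorial's method: a Fenwick (binary indexed) tree is swapped in for the segment tree because it is shorter and supports the same two operations, and a fresh tree is built for every bit instead of removing the inserted elements. All names are illustrative, the array is 0-indexed, and elements are assumed non-negative.

```python
class Fenwick:
    """Counts of values in [0, size-1] with point insert and prefix count."""

    def __init__(self, size):
        self.size = size
        self.data = [0] * (size + 1)

    def add(self, value):
        i = value + 1
        while i <= self.size:
            self.data[i] += 1
            i += i & -i

    def count_at_most(self, value):
        # Clip to the stored value range; a negative value gives 0.
        i = min(value, self.size - 1) + 1
        total = 0
        while i > 0:
            total += self.data[i]
            i -= i & -i
        return total

    def count_in_range(self, left, right):
        if left > right:
            return 0
        return self.count_at_most(right) - self.count_at_most(left - 1)


def xor_of_subarray_sums(a):
    """XOR of the sums of all non-empty subarrays of a (non-negative integers)."""
    total_sum = sum(a)
    answer = 0
    b = 0
    while (1 << b) <= total_sum:
        half = 1 << b
        modulus = 1 << (b + 1)
        # Prefix sums mod 2^(b+1) never exceed min(total_sum, modulus - 1).
        tree = Fenwick(min(total_sum, modulus - 1) + 1)
        tree.add(0)  # s[0] = 0
        current = 0
        count = 0
        for x in a:
            current = (current + x) % modulus
            count += tree.count_in_range(current + 1, current + half)  # s[j-1] > s[i]
            count += tree.count_in_range(0, current - half)            # s[j-1] <= s[i]
            tree.add(current)
        if count % 2 == 1:
            answer |= half
        b += 1
    return answer
```

For example, for a = [1, 2, 4] the subarray sums are 1, 3, 7, 2, 6, 4, whose xor is 5, and the function returns 5.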