PROBLEM LINK:
Author: Vlad Mosko
Tester: Hiroto Sekido
Editorialist: Kevin Atienza
DIFFICULTY:
Medium
PREREQUISITES:
Divisor lists, intervals, combinatorics, ad hoc
PROBLEM:
Given an array of integers [A_1, \ldots, A_N] and a number K, find the number of subarrays containing no bad pairs. A pair (i,j) is a bad pair if i < j and (A_i \bmod A_j) = K.
QUICK EXPLANATION:
- For each position j, compute the largest i_j < j such that (i_j,j) is a bad pair. Collect all such bad pairs gotten this way, sorted by j. (There are at most N such pairs, but may be less, because for some j s there could be no such $i_j$s).
- Remove a pair (i_j,j) if j > 1 and i_{j-1} \ge i_j.
- Add the pair (N, N+1) at the end of the list.
- Assume that the list of pairs acquired is [(i_1,j_1), (i_2,j_2), \ldots, (i_m,j_m)]. Then the answer is:
To compute all pairs (i_j,j), we need to store, for each value v > K, the maximum i < j such that A_i > K and (A_i \bmod v) = K (there are at most \max A_i such lists). We also store the maximum l < j such that A_l = K. As we increase j, we need update some of these lists: When going to a new value of j we only need to update at most O(\sqrt{\max A_j}) lists.
EXPLANATION:
Letās first assume that we have all the bad pairs, and we want to compute the number of subarrays containing no bad pairs. This is just equal to the number of subarrays (which is N(N+1)/2), minus the number of subarrays containing at least one bad pair. We will compute the latter instead, and weāll call a subarray bad if it contains at leaset one bad pair.
First, notice that all {N \choose 2} pairs can be bad pairs (for example, if A_i = 1 for all i and K = 0). Therefore we canāt even compute all the bad pairs, because that is too slow. However, notice that not all bad pairs are important. For example, if the pair p_1 completely contains the pair p_2, then any subarray containing p_1 also contains p_2. Thus, the pair p_1 is redundant, i.e. if we remove the pair p_1 from our list, then the set of bad subarrays stay the same.
Thus we can remove all bad pairs completely containing other bad pairs. Note that after doing so, no two bad pairs share the same left (or right) endpoint any more (otherwise one will contain the other). Since there are at most N endpoints, this reduces the number of bad pairs we have to at most N-1, which is more manageable
This also shows that for a given endpoint j, there is only at most one important bad pair (i,j) ending in j, and itās the one with the maximum i. Therefore, the first step is to compute the maximum i for each j such that (i,j) is bad (ignoring the j s for which there are no such i).
Bad pairs ending in every position
Assuming A_j > K, (A_i \bmod A_j) = K is equivalent to:
This means that A_j is a divisor of (A_i - K).
For every endpoint j, we need to compute the maximum i such that (i,j) is bad. Surely, there are no bad pairs such that A_i < K or A_j \le K, because (A_i \bmod A_j) < K in those cases.
Suppose that, for every value v, K < v \le \max A_i, we store the value i_v, which we define as the maximum i_v < j such that (A_{i_v} \bmod v) = K. Note that i_v is also dependent on j, so as we increase j, we will need to update some of the $i_v$s, but for now, assume we can do those updates. If so, the i we are looking for is simply i_{A_j}
Now for the updates. As we leave a particular value of j and go to the next one (j+1), we need to update some of the $i_v$s, specifically for those v s having (A_j \bmod v) = K. As shown above, this is equivalent to v being a divisor of A_j - K, so we only update for those v s that are divisors of the number A_j - K. This is good, because it has only at most 2\sqrt{A_j - K} such divisors
Unfortunately, the last statement is incorrect for the case A_j = K. If A_j = K, then A_j - K = 0. But every v is a divisor of zero! In this case, updating all the lists may be too slow, so we need to do something else. This is simple; we can simply maintain a separate value l, which is the maximum l < j such that A_l = K. i_v is now modified to exclude those i s such that A_i = K, and for an endpoint j, the i we are looking for is now \max(i_{A_j}, l) Thus, we maintain the fast running time!
The total running time of this step is O(N \sqrt{\max A_i}).
After getting the bad pairs (i_j,j) for every such j, we can simply remove those pairs for which i_{j-1} \ge i_j, because those pairs are not important. Now we have a list of bad pairs [(i_1,j_1), (i_2,j_2), \ldots, (i_m,j_m)] satisfying the following:
Subarrays containing bad pairs
We now want to compute the number of subarrays containing at least one bad pair.
Each such subarray contains a unique ārightmostā bad pair, so we can compute, for each 1 \le k \le m, the number of subarrays whose rightmost bad pair is (i_k,j_k). Letās denote this by f(k). Then the number of subarrays containing at least one bad pair is:
Letās first compute f(m) (the last one). For a subarray [i\ldots j] to contain (i_m,j_m), it must be true that 1 \le i \le i_m and j_m \le j \le N. There are i_m(N+1-j_m) such choices for such subarrays, therefore f(m) = i_m(N+1-j_m).
Now, letās compute f(k) for 1 \le k < m. The subarray [i\ldots j] should contain (i_k,j_k), but we need to ensure that this is the rightmost bad pair, i.e. the subarray doesnāt contain (i_{k+1},j_{k+1}). This means that 1 \le i \le i_k and j_k \le j. However, j < j_{k+1} should be true, otherwise it will contain the pair (i_{k+1},j_{k+1}). Therefore, there are i_k (j_{k+1} - j_k) choices for such subarrays, and thus we have f(k) = i_k(j_{k+1} - j_k)
This step runs in O(N) time
Hereās an implementation in C++:
#include <stdio.h>
#include <algorithm>
using namespace std;
#define ll long long
#define LIM 100111
int I[LIM];
int main() {
for (int j = 0; j < LIM; j++) I[j] = -1;
int n, k;
scanf("%d%d", &n, &k);
ll total = n*(n+1LL) >> 1;
int pi = 0, pj = 1;
#define add_pair(ci,cj) do {\
total += pi * (pj - (cj));\
pi = ci;\
pj = (cj);\
} while (0)
int l = 0;
for (int j = 1; j <= n; j++) {
int a;
scanf("%d", &a);
if (a == k) {
l = j;
} else if (a > k) {
int i = max(I[a], l);
if (pi < i) {
add_pair(i, j);
}
a -= k;
for (int v = 1; v * v <= a; v++) {
if (a % v == 0) I[v] = I[a / v] = j;
}
}
}
add_pair(n, n+1);
printf("%lld\n", total);
}
Some implementation details:
- We added a sentinel leftmost bad pair (0,1) to make the
add_pair
macro simpler. - There is also a sentinel rightmost bad pair (N,N+1) so that f(m) (the last pair) is not special
- There is no array A allocated; instead the individual values of A are inputted and processed on the fly.
Time Complexity:
O(N\sqrt{\max A_i})