PROBLEM LINK:
Author: Lalit Kundu
Tester: Hiroto Sekido
Editorialist: Kevin Atienza
DIFFICULTY:
SIMPLE
PREREQUISITES:
Dynamic programming, preprocessing, binary search, cumulative sums
PROBLEM:
You are given a string S of length N consisting only of 0 s and 1 s. You are also given an integer K.
You have to answer Q queries. In the i th query, two integers L and R are given, and you should print the number of substrings of S[L, R] which contain at most K 0 s and at most K 1 s, where S[L, R] denotes the substring from the L th to the R th character of S.
In other words, you have to count the number of pairs (i, j) of integers with L \le i \le j \le R such that no character in the substring S[i, j] occurs more than K times.
QUICK EXPLANATION:
Let \text{far}[i] be the last index j such that S[i, j] has at most K 0 s and at most K 1 s. Then for a given query (L, R), the number of valid substrings starting at index i is \min(R,\text{far}[i])-i+1 (for L \le i \le R). Therefore, the answer for the query (L, R) is the following:

\sum_{i=L}^{R} \big(\min(R, \text{far}[i]) - i + 1\big)
Note that \text{far}[i] never decreases, so we can use binary search to find the maximum index k such that \text{far}[k] \le R. The sum then becomes:

\sum_{i=L}^{k} \big(\text{far}[i] - i + 1\big) + \sum_{i=k+1}^{R} \big(R - i + 1\big)
Each of these is solvable in closed form, except possibly for range sums of \text{far}[i]. But for that, we can simply compute cumulative sums of \text{far}[i].
The algorithm runs in O(N + Q \log N), but can be sped up to O(N + Q) by getting rid of the binary search. We'll describe this below.
EXPLANATION:
We will begin by describing a slow solution, and then work on improving it incrementally. We will call a string valid if it contains at most K 0 s and at most K 1 s.
Slow solution
First, we can simply test each substring and count all those that are valid. This can be accomplished, for example, with the following:
def answer_query(L, R):
answer = 0
for i in L...R
for j in i...R
# try the substring S[i, j]
count0 = count1 = 0
for k in i...j
if S[k] == '1'
count1 += 1
else
count0 += 1
if count0 <= K and count1 <= K
answer += 1
return answer
How fast does this go? The running time is proportional to the sum of the lengths of all substrings of S[L, R], which one can easily compute to be approximately \frac{(R-L+1)^3}{6}. Since R-L+1 is at most N, in the worst case each query takes approximately \frac{N^3}{6} steps. This won't pass any of the subtasks!
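For reference, the same brute force as runnable Python (the function name and the use of Python's 0-indexed slicing are our own choices, not part of the original solution):

```python
def count_valid_substrings(S, K, L, R):
    """Brute force: count substrings S[i..j] with L <= i <= j <= R
    (1-indexed, as in the problem) that contain at most K '0's and K '1's."""
    answer = 0
    for i in range(L, R + 1):
        for j in range(i, R + 1):
            sub = S[i - 1:j]        # S[i..j] in Python's 0-indexed slicing
            if sub.count('0') <= K and sub.count('1') <= K:
                answer += 1
    return answer
```

Even for this tiny sketch, the cubic behavior is visible: every substring is re-scanned from scratch by `count`.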
Breaking early
The above brute-force solution can be optimized to pass subtask 2, by the observation that valid strings are at most 2K in length. Therefore, we can simply check all substrings that are at most 2K in length, and there are O(NK) of them:
def answer_query(L, R):
answer = 0
for i in L...R
# consider only substrings of length at most 2K
for j in i...min(i+2*K-1, R)
# try the substring S[i, j]
count0 = count1 = 0
for k in i...j
if S[k] == '1'
count1 += 1
else
count0 += 1
if count0 > K or count1 > K:
break
if count0 <= K and count1 <= K
answer += 1
return answer
Notice that aside from reducing the limit of j to \min(i+2K-1,R), we have also added a break statement: once either count0 or count1 exceeds K, there is no point in inspecting the rest of the letters, since we already know the string won't be valid.
Thus the algorithm now runs in O(NK\min(N,K)) time per query and O(QNK\min(N,K)) time overall, and passes subtask 2!
Dynamic programming
In fact, we can extend the argument we used for the break statement! Notice that any substring of a valid string is also valid, and any superstring of an invalid string is also invalid! Therefore, we can break out of the j loop once we encounter an invalid string. However, this alone won't improve the code in the worst case, because it might be the case that we don't encounter an invalid string, or we encounter one late enough.

However, we can exploit even more properties of the substrings we are inspecting! To check whether a substring S[i, j] is valid or not, we only need to count the 0 s and 1 s in it, and these counts are easy to update once we know them for the substring S[i, j-1]. Specifically, as we process the substrings S[i, j] for a fixed i and increasing j, we only need to increment count0 or count1 depending on S[j]:
def answer_query(L, R):
answer = 0
for i in L...R
count0 = count1 = 0
for j in i...min(i+2*K-1, R)
# try the substring S[i, j]
if S[j] == '1'
count1 += 1
else
count0 += 1
if count0 <= K and count1 <= K
answer += 1
else
break
return answer
Now, this runs in O(N\min(N,K)) time per query, and O(QN\min(N,K)) time overall!
However, using this might still not pass subtask 1, because there are up to 10^5 queries. Thankfully, in subtask 1, N is at most 100, so there are only N(N+1)/2 possible substrings S[L, R]. Thus, we can simply precompute the answer for all those substrings, and then answer the Q queries with a simple lookup. This approach runs in O(N^3\min(N,K)+Q) and should be able to pass subtask 1!
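The precompute-everything approach for subtask 1 might look like the following Python sketch (the function name and table layout are ours; the inner loops are the optimized brute force with the length-2K limit and early break):

```python
def precompute_all_answers(S, K):
    """Answer table for every pair: ans[(L, R)] = number of valid substrings
    of S[L..R] (1-indexed). O(N^3 * min(N, K)) total, then O(1) per query."""
    N = len(S)
    ans = {}
    for L in range(1, N + 1):
        for R in range(L, N + 1):
            total = 0
            for i in range(L, R + 1):
                count0 = count1 = 0
                # valid substrings have length at most 2K
                for j in range(i, min(i + 2 * K - 1, R) + 1):
                    if S[j - 1] == '1':
                        count1 += 1
                    else:
                        count0 += 1
                    if count0 <= K and count1 <= K:
                        total += 1
                    else:
                        break           # supersets of an invalid string are invalid
            ans[(L, R)] = total
    return ans
```

Each query then becomes a dictionary lookup on `ans[(L, R)]`.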
Walking algorithm
There is still a way to optimize the above! Remember what we said above: any substring of a valid string is also valid, and any superstring of an invalid string is also invalid. Therefore, if S[i, j] is valid, then S[i+1, j] is also valid! So when we process the next i, we don't have to start j from i any more, because we already know many strings are valid from the previous i. Specifically, if j' is the last j such that S[i, j] is valid, then S[i+1, j'] is also valid, so we can simply start iterating j from j'+1 onwards. What's more, we can compute count0 and count1 of S[i+1, j'+1] from those of S[i, j'+1] by simply decrementing one of them depending on S[i]! This is illustrated in the following code:
def answer_query(L, R):
answer = 0
j = L
count0 = count1 = 0
if S[L] == '1'
count1 += 1
else
count0 += 1
for i in L...R
while j <= R and count0 <= K and count1 <= K:
j += 1
if j > R
break
if S[j] == '1'
count1 += 1
else
count0 += 1
# at this point, we know S[i, j-1] is valid but S[i, j] is invalid
answer += j - i
# decrement
if S[i] == '1'
count1 -= 1
else
count0 -= 1
return answer
Now, how fast does this go? There are still nested loops, but notice that every time the inner loop iterates, j increases by 1, and j never decreases. Therefore, the inner loop runs at most R-L+1 times over the whole query, or O(N). Thus the whole algorithm runs in O(N) time per query, and O(QN) time overall! This should be able to pass subtasks 1, 2 and 3.
Preprocessing and binary search
The previous algorithm is too slow for subtask 4, because it still takes O(N) time per query. In fact, this is probably the fastest we can do without some sort of preprocessing, because at the very least we have to read the string S[L, R] to compute the answer, and this already takes O(N) time. Thus, we will try to speed up the algorithm by preprocessing.
First, note that the crucial part of the previous solution is finding, for each i, the first j such that S[i, j] is invalid, or j > R. By ignoring the "or j > R" part for now, we see that the first such j for each i depends only on the string S! For a given i, let's denote that j by \text{far}[i]. (If you read the quick explanation above, note that this \text{far} is different from the \text{far} there; specifically, this one is larger by exactly 1.) Thus, we can precompute \text{far} at the beginning:
far[1...N]
def precompute():
count0 = count1 = 0
j = 1
if S[1] == '1'
count1 += 1
else
count0 += 1
for i in 1...N
while j <= N and count0 <= K and count1 <= K:
j += 1
if j > N
break
if S[j] == '1'
count1 += 1
else
count0 += 1
# at this point, we know S[i, j-1] is valid but S[i, j] is invalid
far[i] = j
# decrement
if S[i] == '1'
count1 -= 1
else
count0 -= 1
and use that for our queries:
def answer_query(L, R):
answer = 0
for i in L...R
j = min(far[i], R+1) # R+1 is the first j such that j > R
answer += j - i
return answer
The precomputation works similarly to the previous code and runs in O(N), but each query still takes O(N) time. However, the answer_query code is now much simpler, and in fact can be expressed by the following mathematical expression:

\sum_{i=L}^{R} \big(\min(\text{far}[i], R+1) - i\big)

We can now compute such a sum using some simple manipulations:

\sum_{i=L}^{R} \min(\text{far}[i], R+1) - \sum_{i=L}^{R} i

However, we still have this nasty \min term to take care of. Thankfully, we can use the fact that \text{far}[i] never decreases: \text{far}[i] starts out \le R, and as i increases it eventually exceeds R, and once it does, it stays greater than R. Thus, it makes sense to find the last k such that \text{far}[k] \le R, so the above expression becomes

\sum_{i=L}^{k} \text{far}[i] + (R-k)(R+1) - \sum_{i=L}^{R} i
Now we are almost down to an O(1) computation, aside from two things: finding k, and computing a range sum of \text{far}[i]. But these are simple. First, k can be computed with binary search, as the index such that \text{far}[k] \le R < \text{far}[k+1], because \text{far}[i] is monotonically nondecreasing. Second, to compute range sums of \text{far}[i], one can simply use cumulative or prefix sums: let \text{sumfar}[i] be the sum of the \text{far} values up to the i th index. Then the sum \text{far}[i] + \text{far}[i+1] + \cdots + \text{far}[j] is simply \text{sumfar}[j] - \text{sumfar}[i-1]!
These are illustrated in the following code:
far[1...N]
sumfar[0...N]
def precompute():
# precompute far
count0 = count1 = 0
j = 1
if S[1] == '1'
count1 += 1
else
count0 += 1
for i in 1...N
while j <= N and count0 <= K and count1 <= K:
j += 1
if j > N
break
if S[j] == '1'
count1 += 1
else
count0 += 1
far[i] = j
# decrement
if S[i] == '1'
count1 -= 1
else
count0 -= 1
# precompute sumfar
sumfar[0] = 0
for i in 1...N
sumfar[i] = sumfar[i-1] + far[i]
def answer_query(L, R):
# binary search to find k such that far[k] <= R < far[k+1]
# we maintain the invariant far[k1] <= R < far[k2]
k1 = L-1
k2 = R+1
while k2 - k1 > 1:
km = (k1 + k2) / 2 # here, "/" denotes floor division
if far[km] <= R
k1 = km
else
k2 = km
k = k1 # k is now equal to k1 because k2 - k1 = 1 and far[k1] <= R < far[k2]
answer = sumfar[k] - sumfar[L-1] + (R-k)*(R+1) - (R*(R+1)/2 - L*(L-1)/2)
return answer
Using this, one can see that precomputation still runs in O(N) time, and queries now run in O(\log N) time each, due to the binary search. Thus, the overall algorithm runs in O(N + Q \log N) time, which comfortably passes all the subtasks!
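Translated into runnable Python (using the standard bisect module for the binary search, and a dummy 0th entry so the arrays stay 1-indexed as in the pseudocode; the helper names are ours), the whole O(N + Q \log N) solution might look like:

```python
from bisect import bisect_right

def build_tables(S, K):
    """far[i] (1-indexed) = first j with S[i..j] invalid, or N+1; plus prefix sums."""
    N = len(S)
    far = [0] * (N + 1)              # far[0] is a dummy so indices match the editorial
    count = {'0': 0, '1': 0}
    j = 1                            # S[i..j-1] is the current valid window
    for i in range(1, N + 1):
        if j < i:
            j = i                    # can lag behind only when K == 0
        while j <= N and count[S[j - 1]] < K:
            count[S[j - 1]] += 1
            j += 1
        far[i] = j
        if j > i:
            count[S[i - 1]] -= 1     # slide the left end of the window
    sumfar = [0] * (N + 1)
    for i in range(1, N + 1):
        sumfar[i] = sumfar[i - 1] + far[i]
    return far, sumfar

def answer_query(far, sumfar, L, R):
    # largest k with far[k] <= R, clamped to at least L-1 (far is nondecreasing)
    k = bisect_right(far, R, L, R + 1) - 1
    return sumfar[k] - sumfar[L - 1] + (R - k) * (R + 1) \
           - (R * (R + 1) // 2 - L * (L - 1) // 2)
```

Python integers never overflow, so the warning above applies only to fixed-width languages; in C++, for example, sumfar and the answer should be 64-bit.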
Be careful with overflows! Use the right data type for this.
Bonus: \text{far}[i] and \text{raf}[i]
The above solution already works, but we will introduce a final optimization here. Specifically, we will try to improve our algorithm to compute the k in each query.
Note that k is the largest index such that \text{far}[k] \le R, or k = L-1 if no such index \ge L exists. The key idea is that, ignoring the "k = L-1" clamp for now, k depends only on R! Thus it would be nice to precompute the k s for all possible values of R, and in fact it is easy to do so.
Let's define a similar array \text{raf}, where \text{raf}[R] is the smallest index i such that \text{far}[i] > R. Using this array, one can compute k simply as \max(L,\text{raf}[R])-1 (we'll leave it to the reader to see why). Now, how do we compute all the \text{raf}[R] s? The idea is that \text{raf} is essentially the reverse of \text{far}, only with the direction to the left rather than to the right. Thus, a similar O(N) time walking algorithm can be used to compute it. We will leave that as an exercise for the reader, however, because we will show here a different way to compute it in O(N) time, by exploiting the relationship between \text{far} and \text{raf}. See the following pseudocode for details:
far[1...N]
raf[1...N]
sumfar[0...N]
def precompute():
# precompute far
.....
# precompute sumfar
....
# precompute raf
# initialize
for i in 1...N
raf[i] = -1
# we know that far[i]-1 < far[i], so if j = far[i]-1,
# then raf[j] must be <= i (raf[j] is the least index whose far value exceeds j)
# we process each i in decreasing order so that the least such i is the one kept
for i in N...1 by -1
raf[far[i]-1] = i
# for all the raf[i]'s that we didn't encounter, set its value to raf[i+1]
# because raf[i] <= raf[i+1]
for i in N...1 by -1
if raf[i] == -1
raf[i] = raf[i+1]
def answer_query(L, R):
k = max(L, raf[R])-1
answer = sumfar[k] - sumfar[L-1] + (R-k)*(R+1) - (R*(R+1)/2 - L*(L-1)/2)
return answer
Finally, one can now see that it runs in O(1) time per query, and O(N + Q) time overall!
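Putting everything together, here is a runnable Python sketch of the final O(N + Q) solution, with \text{raf} derived from \text{far} as in the pseudocode (helper names and the sentinel raf[N+1] are our own details):

```python
def precompute(S, K):
    """Build sumfar and raf (1-indexed via a dummy 0th entry)."""
    N = len(S)
    far = [0] * (N + 1)
    count = {'0': 0, '1': 0}
    j = 1                                    # S[i..j-1] is the current valid window
    for i in range(1, N + 1):
        if j < i:
            j = i
        while j <= N and count[S[j - 1]] < K:
            count[S[j - 1]] += 1
            j += 1
        far[i] = j                           # first j with S[i..j] invalid, or N+1
        if j > i:
            count[S[i - 1]] -= 1
    sumfar = [0] * (N + 1)
    for i in range(1, N + 1):
        sumfar[i] = sumfar[i - 1] + far[i]
    raf = [-1] * (N + 2)
    raf[N + 1] = N + 1                       # sentinel for the gap-filling pass
    for i in range(N, 0, -1):                # decreasing i: the least i wins
        raf[far[i] - 1] = i
    for i in range(N, 0, -1):                # raf is nondecreasing; fill the gaps
        if raf[i] == -1:
            raf[i] = raf[i + 1]
    return sumfar, raf

def answer_query(sumfar, raf, L, R):
    k = max(L, raf[R]) - 1                   # last k with far[k] <= R, clamped >= L-1
    return sumfar[k] - sumfar[L - 1] + (R - k) * (R + 1) \
           - (R * (R + 1) // 2 - L * (L - 1) // 2)
```

Once `precompute` has run, each query is pure arithmetic plus two array lookups, which is the promised O(1).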
Time Complexity:
O(N + Q)