PROBLEM LINK:
Author: Dmytro Berezin
Tester: Kevin Atienza
Translators: Sergey Kulik (Russian), Team VNOI (Vietnamese) and Hu Zecong (Mandarin)
Editorialist: Kevin Atienza
DIFFICULTY:
Easy
PREREQUISITES:
Dynamic programming
PROBLEM:
There are N dogs numbered from 1 to N in a line. A ball will be passed around. Initially, dog s has the ball.
A dog with the ball can pass it to another dog. With a pass strength of x, dog i can pass the ball to dog i - x or dog i + x (provided such a dog exists).
There will be M passes of the ball, and the pass strength of the j-th pass is A_j.
For each dog, how many ways are there for the ball to end up at that dog? Output your answers modulo 10^9 + 7.
QUICK EXPLANATION:
Use dynamic programming. Let f(i,j) be the number of ways for the ball to end up at dog i after the first j passes. Then we have the following recurrence:
f(i,j) = f(i - A_j, j - 1) + f(i + A_j, j - 1)
where we define f(i,j) = 0 if i is not in the range [1,N]. Also, for base cases, we have f(s,0) = 1 and f(i,0) = 0 for i \not= s.
There are only N possible values for i and M+1 possible values for j, so we can store these values in an N\times (M+1) table and then fill the table up with this recurrence in increasing order of j. The answers that we want are f(i,M)\bmod (10^9 + 7) for 1 \le i \le N.
EXPLANATION:
Slow solution
The simplest solution to this problem that is guaranteed correct is to just enumerate all possible choices for passes. There are two choices for each pass, either left or right, and there are M passes, so there are 2^M possible choices. A simple recursion like the following will do: (C++)
```cpp
#include <stdio.h>
#define mod 1000000007

int A[1111];
int count[1111];
int n, m, s;

void enumerate_choices(int curr, int i) {
    if (!(0 <= curr && curr < n)) {
        // out of bounds, so skip
        return;
    } else if (i < m) {
        // do i'th pass
        enumerate_choices(curr - A[i], i+1); // left
        enumerate_choices(curr + A[i], i+1); // right
    } else {
        // all passes done.
        // 'curr' is now the final destination.
        // increment count at 'curr'.
        count[curr]++;
        // reduce modulo
        if (count[curr] >= mod) count[curr] -= mod;
    }
}

int main() {
    int z;
    scanf("%d", &z);
    while (z--) {
        scanf("%d%d%d", &n, &m, &s);
        for (int i = 0; i < m; i++) scanf("%d", &A[i]);
        for (int i = 0; i < n; i++) count[i] = 0;
        enumerate_choices(s-1, 0); // zero-indexing
        for (int i = 0; i < n; i++) {
            if (i) printf(" ");
            printf("%d", count[i]);
        }
        puts("");
    }
}
```
Notice the following:
- We reduce the count modulo `mod` every time we increment it, because we want the answers modulo `mod`. However, to reduce modulo `mod`, we use the snippet `if (count[curr] >= mod) count[curr] -= mod;` instead of `count[curr] %= mod;`. This is because the former only requires a comparison and a subtraction, which is usually much faster than the division behind `%`. It is correct here because `count[curr]` grows by only 1 at a time, so it can never reach 2·mod before being reduced.
- Notice that we subtracted 1 from `s` during the initial call, because we converted everything to zero-indexing. It's usually beneficial to do that.
Now, this code tries all 2^M possibilities, so it runs in O(2^M) time. This passes the first subtask, but is too slow for the second. For instance, if M = 1000, then 2^M = 2^{1000} is an astronomically large number.
Fast solution
This means we need to find a better solution for the second subtask. To do so, we'll first describe a different brute-force approach. Let's define the function f(i,j) as the number of ways for the ball to end up at dog i after the first j passes. With this definition, we immediately notice two things:
- If j = 0, then it means there is no pass performed at all. Therefore, the ball is still at s, and we have f(s,0) = 1 and f(i,0) = 0 if i \not= s.
- If j = M, then it means all passes are performed. Therefore, the answers we are looking for are actually f(1,M), f(2,M), f(3,M), \ldots, f(N,M)!
The nice thing about f(i,j) is that it satisfies a simple recurrence. If we want the ball to end up at i after the j-th pass, then just before that pass it must be exactly A_j units away from position i. There are only two such positions, i - A_j and i + A_j, so we have the recurrence
f(i,j) = f(i - A_j, j - 1) + f(i + A_j, j - 1)
where we define f(i,j) = 0 if i is not in the range [1,N] (which means the position doesn't exist).
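As a small hand-worked illustration (the input is made up for this example), take N = 5, s = 2 and A = (1, 2), treating every out-of-range term as 0. Only the nonzero entries are expanded; every other entry in each row works out to 0.

```latex
\begin{aligned}
f(\cdot,0) &= (0,1,0,0,0) \\
f(1,1) &= f(0,0) + f(2,0) = 0 + 1 = 1, \qquad f(3,1) = f(2,0) + f(4,0) = 1 + 0 = 1 \\
f(\cdot,1) &= (1,0,1,0,0) \\
f(1,2) &= f(-1,1) + f(3,1) = 0 + 1 = 1 \\
f(3,2) &= f(1,1) + f(5,1) = 1 + 0 = 1 \\
f(5,2) &= f(3,1) + f(7,1) = 1 + 0 = 1 \\
f(\cdot,2) &= (1,0,1,0,1)
\end{aligned}
```

This agrees with direct enumeration: the only valid pass sequences are 2 → 1 → 3, 2 → 3 → 1 and 2 → 3 → 5 (the sequence 2 → 1 → -1 would leave the line), so dogs 1, 3 and 5 each receive the ball in exactly one way.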
Now, with this definition, a straightforward recursive implementation can be done:
```cpp
#include <stdio.h>
#define mod 1000000007

int A[1111];
int n, m, s;

int f(int i, int j) {
    if (j == 0) {
        return i == s;
    } else {
        int result = 0;
        // in the following, we use A[j-1] because of zero-indexing
        if (i + A[j-1] < n) result += f(i + A[j-1], j-1);
        if (i - A[j-1] >= 0) result += f(i - A[j-1], j-1);
        return result % mod;
    }
}

int main() {
    int z;
    scanf("%d", &z);
    while (z--) {
        scanf("%d%d%d", &n, &m, &s);
        s--; // zero-indexing
        for (int i = 0; i < m; i++) scanf("%d", &A[i]);
        for (int i = 0; i < n; i++) {
            if (i) printf(" ");
            printf("%d", f(i, m));
        }
        puts("");
    }
}
```
Unfortunately, this is also slow. In fact, it's even worse than the previous one: it runs in O(2^M\cdot N) time in the worst case, since each of the N top-level calls f(i,M) can branch into up to 2^M recursive calls. So what now?
It turns out that there are two insights that will help us:
- If we have already computed f(i,j) once for a given (i,j) pair, then we don’t have to compute f(i,j) for the same (i,j) pair any more, because we already know the answer.
- The only arguments (i,j) that will ever be needed from f are 1 \le i \le N and 0 \le j \le M.
Thus, we can employ a different strategy to compute f(1,M), f(2,M), \ldots, f(N,M). Instead of creating a recursive function, we will just create a 2D array that will contain all values of f(i,j) that we need, namely for 1 \le i \le N and 0 \le j \le M, and then fill them up in increasing order of j. The following illustrates it:
```cpp
#include <stdio.h>
#define mod 1000000007

int A[1111];
int n, m, s;
int f[1111][1111];

int main() {
    int z;
    scanf("%d", &z);
    while (z--) {
        scanf("%d%d%d", &n, &m, &s);
        s--; // zero-indexing
        for (int i = 0; i < m; i++) scanf("%d", &A[i]);
        for (int j = 0; j <= m; j++) {
            for (int i = 0; i < n; i++) {
                if (j == 0) {
                    f[i][j] = i == s;
                } else {
                    int result = 0;
                    // in the following, we use A[j-1] because of zero-indexing
                    if (i + A[j-1] < n) result += f[i + A[j-1]][j-1];
                    if (i - A[j-1] >= 0) result += f[i - A[j-1]][j-1];
                    f[i][j] = result % mod;
                }
            }
        }
        for (int i = 0; i < n; i++) {
            if (i) printf(" ");
            printf("%d", f[i][m]);
        }
        puts("");
    }
}
```
This technique of tabulating values and then reusing previously-computed ones is called dynamic programming.
The advantage of this technique is that it's much faster! Notice that it only takes O(1) time to compute each entry of f. Since there are N\times (M+1) entries, the whole algorithm runs in O(NM) time. This passes the second subtask!
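The first insight can also be applied directly to the recursive f from the previous program, a variant usually called memoization (top-down dynamic programming). This is an alternative to the table above, not what the original solution does. A minimal sketch, assuming N ≤ 1111 and M ≤ 1110 like the fixed-size arrays in the editorial's code; `memo`, `count_ways` and `reset_memo` are hypothetical names:

```cpp
#include <cassert>

const int MOD = 1000000007;
const int MAXN = 1111;

int n, m, s;          // s is zero-indexed
int A[MAXN];          // A[0..m-1] are the pass strengths
int memo[MAXN][MAXN]; // memo[i][j] == -1 means "not computed yet"

// Marks every needed state as not computed. Call once per test case,
// after n and m are set and before any count_ways call.
void reset_memo() {
    assert(n <= MAXN && m < MAXN);
    for (int i = 0; i < n; i++)
        for (int j = 0; j <= m; j++) memo[i][j] = -1;
}

// Number of ways (mod MOD) for the ball to be at position i after j passes.
int count_ways(int i, int j) {
    if (j == 0) return i == s;
    int &res = memo[i][j];
    if (res != -1) return res;  // already computed: reuse it
    long long total = 0;
    // A[j-1] is the strength of the j-th pass (zero-indexed array)
    if (i + A[j-1] < n) total += count_ways(i + A[j-1], j-1);
    if (i - A[j-1] >= 0) total += count_ways(i - A[j-1], j-1);
    res = (int)(total % MOD);
    return res;
}
```

Each (i, j) pair is computed at most once, so answering all N positions with `count_ways(i, m)` still takes O(NM) time. The recursion goes M levels deep, which is usually acceptable for M around 1000 but is one reason the bottom-up table is often preferred.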
Minor improvements
There are a few minor improvements that can be done. The first is to reduce the memory requirements. Notice that the above requires O(NM) memory for the whole table. However,
- Computing the entries for a fixed j only requires the values for j-1.
- In the end, we only need the entries for j = M.
Therefore, instead of storing the whole table, we only need to store the current and previous rows! This reduces the memory usage from O(NM) to O(2N) = O(N).
The following illustrates it:
```cpp
#include <stdio.h>
#define mod 1000000007

int n, m, s;
int *f_curr = new int[1111];
int *f_prev = new int[1111];

int main() {
    int z;
    scanf("%d", &z);
    while (z--) {
        scanf("%d%d%d", &n, &m, &s);
        s--; // zero-indexing
        for (int i = 0; i < n; i++) f_curr[i] = i == s;
        for (int j = 0; j < m; j++) {
            // get next A
            int A;
            scanf("%d", &A);
            // copy to f_prev
            for (int i = 0; i < n; i++) f_prev[i] = f_curr[i];
            for (int i = 0; i < n; i++) {
                int result = 0;
                if (i + A < n) result += f_prev[i + A];
                if (i - A >= 0) result += f_prev[i - A];
                f_curr[i] = result % mod;
            }
        }
        for (int i = 0; i < n; i++) {
            if (i) printf(" ");
            printf("%d", f_curr[i]);
        }
        puts("");
    }
}
```
Here are a few things to notice in this code:
- There is no `f` table any more. Instead, we have two arrays: `f_curr` and `f_prev`.
- We don't have the `A` array any more. Instead, we just read each next A_j as we need it, and then throw the value away in the next iteration of the loop.
- There's a separate loop that computes the entries for j = 0. This allows us to remove an `if`-`else` conditional inside the inner loop, potentially improving the running time by a constant factor.
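The copy into `f_prev` costs O(N) per pass. Since `f_curr` and `f_prev` are already pointers, one possible refinement (not part of the code above) is to swap the pointers and write the new row straight into the other buffer. A minimal sketch of one pass, with the hypothetical function name `apply_pass`:

```cpp
#include <algorithm>  // std::swap
#include <cassert>

const int MOD = 1000000007;

// One pass of strength a on a line of n positions.
// On entry, cur points to the current row; on exit, cur points to the new row
// and prev points to the old one. Only the pointers are swapped, no copying.
void apply_pass(int n, int a, int *&cur, int *&prev) {
    assert(n >= 0 && a >= 0);
    std::swap(cur, prev);  // the old row is now reachable through prev
    for (int i = 0; i < n; i++) {
        long long result = 0;
        if (i + a < n) result += prev[i + a];
        if (i - a >= 0) result += prev[i - a];
        cur[i] = (int)(result % MOD);
    }
}
```

In the program above, the copy loop and the inner loop could then be replaced by the single call `apply_pass(n, A, f_curr, f_prev)`, where `A` is the strength just read.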
There's a further improvement that can be done, and the key insight is that for a fixed j, about half of the entries are 0. This is because a pass of strength A_k moves the ball by \pm A_k, which has the same parity as A_k, so after j passes the ball's position always has the same parity as s + A_1 + A_2 + \ldots + A_j. Hence the entries at positions of the other parity are 0, and there are approximately N/2 such positions. Thus, we can roughly halve the running time and memory by only considering positions of the correct parity. This is slightly more complicated, though, and there isn't a huge benefit, so this improvement is mainly for fun.
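A minimal sketch of that parity trick, assuming zero-indexed positions (so `s` below is already decremented) and non-negative strengths; `solve_by_parity` is a hypothetical name. Position p is stored at index p / 2 in an array of about N/2 entries, and only positions whose parity matches s + A_1 + … + A_j are computed:

```cpp
#include <cassert>
#include <utility>
#include <vector>

const int MOD = 1000000007;

// Returns the answers for positions 0..n-1, starting from zero-indexed s.
// Only positions p with p % 2 == par are stored, at index p / 2, where par is
// the parity of s plus the strengths of the passes made so far.
std::vector<int> solve_by_parity(int n, int s, const std::vector<int> &A) {
    assert(0 <= s && s < n);
    int half = (n + 1) / 2;             // most positions any one parity can have
    std::vector<int> cur(half, 0), nxt(half, 0);
    int par = s % 2;
    cur[s / 2] = 1;
    for (int a : A) {
        int npar = par ^ (a & 1);       // parity of every position after this pass
        int cnt = (n - npar + 1) / 2;   // number of p in [0, n) with p % 2 == npar
        for (int k = 0; k < cnt; k++) {
            int p = 2 * k + npar;
            long long result = 0;
            // p - a and p + a both have parity par, so they sit at index (p -/+ a) / 2.
            if (p - a >= 0) result += cur[(p - a) / 2];
            if (p + a < n) result += cur[(p + a) / 2];
            nxt[k] = (int)(result % MOD);
        }
        std::swap(cur, nxt);
        par = npar;
    }
    // Expand back to a full row: positions of the other parity stay 0.
    std::vector<int> answer(n, 0);
    for (int p = par; p < n; p += 2) answer[p] = cur[p / 2];
    return answer;
}
```

Since both source positions p - A_j and p + A_j of a stored position have the previous parity, every lookup stays inside the half-sized previous row. The final loop expands the result back to all N positions, filling the other parity with 0.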
Time Complexity:
O(NM)