CHEFSOC2 - Editorial

PROBLEM LINK:

Contest
Practice

Author: Dmytro Berezin
Tester: Kevin Atienza
Translators: Sergey Kulik (Russian), Team VNOI (Vietnamese) and Hu Zecong (Mandarin)
Editorialist: Kevin Atienza

DIFFICULTY:

Easy

PREREQUISITES:

Dynamic programming

PROBLEM:

There are N dogs numbered from 1 to N in a line. A ball will be passed around. Initially, dog s has the ball.

A dog with the ball can pass it to another dog. With a pass strength of x, dog i can pass the ball to dog i - x or dog i + x (provided that dog exists).

There will be M passes of the ball, and the pass strength of the j-th pass is A_j.

For each dog, how many ways are there for the ball to end up at that dog? Output your answers modulo 10^9 + 7.

QUICK EXPLANATION:

Use dynamic programming. Let f(i,j) be the number of ways for the ball to end up at dog i after the first j passes. Then we have the following recurrence:

f(i,j) = f(i - A_j, j-1) + f(i + A_j, j-1)

where we define f(i,j) = 0 if i is not in the range [1,N]. Also, for base cases, we have f(s,0) = 1 and f(i,0) = 0 for i \not= s.

There are only N possible values for i and M+1 possible values for j, so we can store these values in an N\times (M+1) table and then fill the table up with this recurrence in increasing order of j. The answers that we want are f(i,M)\bmod (10^9 + 7) for 1 \le i \le N.

EXPLANATION:

Slow solution

The simplest solution to this problem that is guaranteed to be correct is to enumerate all possible choices for the passes. Each pass has two choices, left or right, and there are M passes, so there are 2^M possible sequences of choices. A simple recursion like the following (C++) will do:

#include <stdio.h>
#define mod 1000000007

int A[1111];
int count[1111];
int n, m, s;
void enumerate_choices(int curr, int i) {
    if (!(0 <= curr && curr < n)) {
        // out of bounds, so skip
        return;
    } else if (i < m) {
        // do i'th pass
        enumerate_choices(curr - A[i], i+1); // left
        enumerate_choices(curr + A[i], i+1); // right
    } else {
        // all passes done.
        // 'curr' is now the final destination.
        // increment count at 'curr'.
        count[curr]++;
        // reduce modulo
        if (count[curr] >= mod) count[curr] -= mod;
    }
}

int main() {
    int z;
    scanf("%d", &z);
    while (z--) {
        scanf("%d%d%d", &n, &m, &s);
        for (int i = 0; i < m; i++) scanf("%d", &A[i]);
        for (int i = 0; i < n; i++) count[i] = 0;
        enumerate_choices(s-1, 0); // zero-indexing
        for (int i = 0; i < n; i++) {
            if (i) printf(" ");
            printf("%d", count[i]);
        }
        puts("");
    }
}

Notice the following:

  • We reduce the count modulo mod every time we increment it, because we want the answers modulo mod. To do the reduction, we use the snippet if (count[curr] >= mod) count[curr] -= mod; instead of count[curr] %= mod;. This works because count[curr] is less than mod before the increment, so it is at most mod afterwards, and a comparison plus a subtraction is usually much faster than the division behind %.
  • Notice that we subtracted 1 from s during the initial call, because we converted everything to zero-indexing. It’s usually beneficial to do that.

Now, this code tries all 2^M possibilities, so it runs in O(2^M) time. This passes the first subtask but is too slow for the second. Notice that if M = 1000, for instance, then 2^M = 2^{1000} is an astronomically large number.

Fast solution

This means we need a better solution for the second subtask. To find one, we’ll first describe a different brute force. Let’s define the function f(i,j) as the number of ways for the ball to end up at dog i after the first j passes. With this definition, we immediately notice two things:

  • If j = 0, then no passes have been performed yet. Therefore, the ball is still at s, and we have f(s,0) = 1 and f(i,0) = 0 if i \not= s.
  • If j = M, then it means all passes are performed. Therefore, the answers we are looking for are actually f(1,M), f(2,M), f(3,M), \ldots, f(N,M)!

The nice thing about f(i,j) is that it satisfies a simple recurrence. If the ball ends up at i after the j-th pass, then just before the j-th pass it must have been exactly A_j units away from position i. There are only two such positions, i - A_j and i + A_j, so we have the following recurrence:

f(i,j) = f(i - A_j, j-1) + f(i + A_j, j-1)

where we define f(i,j) = 0 if i is not in the range [1,N] (which means the position doesn’t exist).
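As a small worked example, here is the full table for the input N = 5, M = 3, s = 2, A = (1, 1, 2), which is the same input discussed in the comments at the end of this page. For readability, rows are j and columns are i (the transpose of the N\times(M+1) layout), and each row is computed from the row above it with the recurrence:

```latex
\begin{array}{l|ccccc}
 & i=1 & i=2 & i=3 & i=4 & i=5 \\ \hline
j=0 & 0 & 1 & 0 & 0 & 0 \\
j=1\ (A_1=1) & 1 & 0 & 1 & 0 & 0 \\
j=2\ (A_2=1) & 0 & 2 & 0 & 1 & 0 \\
j=3\ (A_3=2) & 0 & 1 & 0 & 2 & 0
\end{array}
```

For instance, f(4,3) = f(2,2) + f(6,2) = 2 + 0 = 2, so there are exactly 2 ways for the ball to end at dog 4, and the last row is the required output.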

Now, with this definition, a straightforward recursive implementation can be done:

#include <stdio.h>
#define mod 1000000007

int A[1111];
int n, m, s;
int f(int i, int j) {
    if(j == 0) {
        return i == s;
    } else {
        int result = 0;
        // in the following, we use A[j-1] because of zero-indexing
        if (i + A[j-1] < n)  result += f(i + A[j-1], j-1);
        if (i - A[j-1] >= 0) result += f(i - A[j-1], j-1);
        return result % mod;
    }
}
int main() {
    int z;
    scanf("%d", &z);
    while (z--) {
        scanf("%d%d%d", &n, &m, &s);
        s--; // zero-indexing
        for (int i = 0; i < m; i++) scanf("%d", &A[i]);
        for (int i = 0; i < n; i++) {
            if (i) printf(" ");
            printf("%d", f(i, m));
        }
        puts("");
    }
}

Unfortunately, this is also slow. In fact, it’s even worse than the previous one: it runs in O(2^M\cdot N) time in the worst case, because each of the N top-level calls can take O(2^M) time. So what now?

It turns out that there are two insights that will help us:

  1. If we have already computed f(i,j) once for a given (i,j) pair, then we don’t have to compute f(i,j) for the same (i,j) pair any more, because we already know the answer.
  2. The only arguments (i,j) that will ever be needed from f are 1 \le i \le N and 0 \le j \le M.

Thus, we can employ a different strategy to compute f(1,M), f(2,M), \ldots, f(N,M). Instead of creating a recursive function, we will just create a 2D array that will contain all values of f(i,j) that we need, namely for 1 \le i \le N and 0 \le j \le M, and then fill them up in increasing order of j. The following illustrates it:

#include <stdio.h>
#define mod 1000000007

int A[1111];
int n, m, s;
int f[1111][1111];
int main() {
    int z;
    scanf("%d", &z);
    while (z--) {
        scanf("%d%d%d", &n, &m, &s);
        s--; // zero-indexing
        for (int i = 0; i < m; i++) scanf("%d", &A[i]);
        for(int j = 0; j <= m; j++) {
            for (int i = 0; i < n; i++) {
                if(j == 0) {
                    f[i][j] = i == s;
                } else {
                    int result = 0;
                    // in the following, we use A[j-1] because of zero-indexing
                    if (i + A[j-1] < n)  result += f[i + A[j-1]][j-1];
                    if (i - A[j-1] >= 0) result += f[i - A[j-1]][j-1];
                    f[i][j] = result % mod;
                }
            }
        }
        for (int i = 0; i < n; i++) {
            if (i) printf(" ");
            printf("%d", f[i][m]);
        }
        puts("");
    }
}

This technique of tabulating values and then reusing previously-computed ones is called dynamic programming.

The advantage of this technique is that it’s much faster! Notice that each entry of f takes only O(1) time to compute. Since there are N\times (M+1) entries, the whole algorithm runs in O(NM) time. This passes the second subtask!

Minor improvements

There are a few minor improvements that can be done. The first is to reduce the memory requirements. Notice that the above requires O(NM) memory for the whole table. However,

  • Computing the entries for a fixed j only requires values for j-1.
  • In the end, we only need the entries for j = M.

Therefore, instead of storing the whole table, we only need to store the current and previous rows! This reduces the memory usage from O(NM) to O(2N) = O(N).

The following illustrates it:

#include <stdio.h>
#define mod 1000000007

int n, m, s;
int *f_curr = new int[1111];
int *f_prev = new int[1111];
int main() {
    int z;
    scanf("%d", &z);
    while (z--) {
        scanf("%d%d%d", &n, &m, &s);
        s--; // zero-indexing
        for (int i = 0; i < n; i++) f_curr[i] = i == s;
        for(int j = 0; j < m; j++) {
            // get next A
            int A;
            scanf("%d", &A);
            // copy to f_prev
            for (int i = 0; i < n; i++) f_prev[i] = f_curr[i];
            for (int i = 0; i < n; i++) {
                int result = 0;
                if (i + A < n)  result += f_prev[i + A];
                if (i - A >= 0) result += f_prev[i - A];
                f_curr[i] = result % mod;
            }
        }
        for (int i = 0; i < n; i++) {
            if (i) printf(" ");
            printf("%d", f_curr[i]);
        }
        puts("");
    }
}

Here are a couple of things to notice in this code:

  • There is no f table any more. Instead, we have two arrays: f_curr and f_prev.
  • We don’t have the A array any more. Instead, we just take each next A_j as we need it, and then throw the value away in the next iteration of the loop.
  • There’s a separate loop that computes the entries for j = 0. This removes an if-else conditional from the inner loop, potentially improving the running time by a constant factor.

There’s a further improvement that can be done, and the key insight is that for a fixed j, about half of the entries are 0. Each pass moves the ball by A_j in one direction or the other, and both moves change the parity of the position in the same way, so after j passes the parity of the ball’s position is always the same as the parity of s + A_1 + A_2 + \ldots + A_j. It means that entries at positions of the other parity are 0, and there are approximately N/2 such positions. Thus, we can cut the running time and the required memory roughly in half by only considering positions of the correct parity. This is slightly more complicated, though, and the benefit is small, so this improvement is mainly for fun :slight_smile:

Time Complexity:

O(NM)

AUTHOR’S AND TESTER’S SOLUTIONS:

Setter
Tester


#doubt

input ->

1

5 3 2

1 1 2

ways to end at position 4 (1-based indexing) ->

2->3->2->4

2->1->2->4

2->4->5->4

2->4->3->4

But when I run an AC code, it shows only 2 ways to end at position 4. Can someone explain why only 2?

Initially the ball is with 2.

1] Pass strength 1.

2->3 &
2->1

2] Pass strength 1 again.

2->3->2 &

2->3->4 &

2->1->2

3] Pass strength 2.

2->3->2->4

2->3->4->2

2->1->2->4

So, after these passes, the ball ends up at position 4 two times and at position 2 once.

In the last two paths, why 2->4? The first pass has strength 1, so the ball can only move 1 place to the left or right.


@glow
There are only two ways to end at position 4,

2->3->2->4

2->1->2->4

For the first pass, the pass strength is 1.

So dog 2 can only pass to

2->1 or

2->3

So it is not possible for dog 2 to pass to dog 4 in the first pass. Therefore, the pass 2->4 is invalid.

Hence the passes 2->4->5->4 and 2->4->3->4 are invalid.


Thanks, got it. :slight_smile:

Here’s my algorithm: I compared every sequence of moves to a bit string. If the i-th bit of the bit string is 0, I move right with power a[i]; if it is 1, I move left with power a[i]. I got subtask #1 correct but TLE for the second one.
Can anyone please suggest some optimization for this code?

My solution: https://www.codechef.com/viewsolution/10059777

Thanks in advance…