### PROBLEM LINK:

**Setter-** DOlaBMOon

**Tester-** Teja Vardhan Reddy

**Editorialist-** Abhishek Pandey

### DIFFICULTY:

MEDIUM-HARD

### PRE-REQUISITES:

Matrix Exponentiation, Recurrences, Finding Modular Inverses. Some people solved it via simulation as well.

### PROBLEM:

Given M Selinas walking in a tree of N nodes, each going either up or down until it collides with another Selina, the root or a leaf, we need to find, for every node, the sum of the expected number of Selinas that have passed through that node so far.

### QUICK-EXPLANATION:

**Key to AC-** Collisions do not matter. We can find the answer for each Selina individually and add it to the final answer.

Let's first see how we can find relations or recurrences. One way to decompose the problem into a system of linear equations is to have 2 variables for every (internal) node: one for Selinas at that node going down to its children, and one for Selinas going up. (The root and leaves have only 1 variable each.) Now, we know that-

**"Expected" Selinas coming down to a child from its parent** = \frac{1}{K} * **Expected Selinas going down at the parent**, where K is the number of children of the parent.

**"Expected" Selinas going up to the parent from the current node = Expected Selinas going up at the current node.**

Use this to construct the matrix and we are done with the question.

### EXPLANATION:

The editorial will be divided roughly into 3 sections.

- Collisions, variables and the recurrences
- Constructing the matrix
- Concluding notes, alternate approaches etc.

**1. Collisions, Variables and Recurrences-**

First, let's talk about collisions. Take these few examples.

- What happens if 1 Selina going down collides with 1 Selina going up? Ans- Overall, we still see 1 Selina going up and the other going down. Since Selinas are indistinguishable, we don't care which one goes up and which one goes down.
- What if 2 Selinas going down collide with 1 Selina going up? Still, 2 Selinas will go down and 1 will go up.

Notice that if we don't care about which Selina is going up or down, the final result is the same as if the Selinas passed through each other without colliding. This is the basic intuition for why collisions are irrelevant and we can calculate the answer for each Selina individually. Hence, let's find the answer considering only 1 Selina in the tree, and later sum the answers over all Selinas.

Now let's see what our recurrence can be. Suppose we know the expected number of Selinas at every node at time t. How can we find it for time t+1? Remember that we have separate states for each direction. Formally, knowing dp[i][direction][time], can we find dp[i][direction][time+1]? Try working it out yourself before reading the answer below. Don't worry about corner cases right now; find the general recurrence.


dp[i][down][time+1]=\frac{1}{K}dp[parent[i]][down][time], **where K is the number of children of i's parent. To deal with this \frac{1}{K}, store K^{-1} \bmod MOD instead of the floating-point value of \frac{1}{K}.**

dp[i][up][time+1]=\sum dp[child[i]][up][time]
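To make the K^{-1} \bmod MOD remark concrete, here is a minimal C++ sketch of two standard ways to get modular inverses for MOD = 998244353 (the prime modulus used in the solution code below): Fermat's little theorem, and the O(n) table that the solution code builds from inv[mod%i]*(mod/i) followed by a negation. The names `power_mod`, `inverse_fermat` and `inverse_table` are illustrative, not taken from the original solutions.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Prime modulus used by the solution code.
const int64_t MOD = 998244353;

// base^exp mod MOD by binary exponentiation.
int64_t power_mod(int64_t base, int64_t exp) {
    int64_t result = 1;
    base %= MOD;
    if (base < 0) base += MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// Fermat's little theorem: k^(MOD-2) is the inverse of k because MOD is prime.
int64_t inverse_fermat(int64_t k) {
    assert(k % MOD != 0);  // 0 has no inverse
    return power_mod(k, MOD - 2);
}

// inv[i] for 1 <= i <= n in O(n): from MOD = (MOD / i) * i + MOD % i we get
// inv[i] = -(MOD / i) * inv[MOD % i] (mod MOD).
std::vector<int64_t> inverse_table(int n) {
    std::vector<int64_t> inv(n + 1, 1);  // inv[0] is unused
    for (int i = 2; i <= n; ++i)
        inv[i] = (MOD - (MOD / i) * inv[MOD % i] % MOD) % MOD;
    return inv;
}
```

The table is the natural choice here because K never exceeds N, so all needed inverses cost O(N) in total.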

From the above, we have found the recurrence. The only corner case is that the root and leaves have only 1 direction: a Selina at the root can only go down, and a Selina at a leaf can only go (bounce back) up.
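The recurrence, including the root and leaf corner cases, can be sketched as a single time step. This is an illustrative C++ sketch, not the setter's code; it assumes node 0 is the root, `parent[0] == -1`, `kids[v]` is the number of children of v, and N ≥ 2 (the solution code special-cases N = 1). As a convention of this sketch only, the root keeps its single variable in slot 0 ("down") and every leaf keeps its single variable in slot 1 ("up"). `State` and `step` are hypothetical names.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

const int64_t MOD = 998244353;

int64_t power_mod(int64_t base, int64_t exp) {
    int64_t result = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// s[v][0] = expected Selinas going down at v, s[v][1] = going up at v.
// Sketch convention: the root only uses s[0][0], every leaf only s[v][1].
using State = std::vector<std::array<int64_t, 2>>;

// One step: dp[.][.][time] -> dp[.][.][time+1].
State step(const std::vector<int>& parent, const std::vector<int>& kids,
           const State& s) {
    int n = static_cast<int>(parent.size());
    assert(n >= 2);
    State next(n, std::array<int64_t, 2>{0, 0});
    for (int v = 1; v < n; ++v) {
        int p = parent[v];
        // dp[v][down] gets a 1/K share of dp[p][down], K = number of children of p;
        // a leaf stores it in its single "up" slot (it bounces back).
        int slot = (kids[v] == 0) ? 1 : 0;
        int64_t share = s[p][0] * power_mod(kids[p], MOD - 2) % MOD;
        next[v][slot] = (next[v][slot] + share) % MOD;
        // dp[p][up] collects dp[v][up] from every child v;
        // the root stores it in its single "down" slot (it turns around).
        int pslot = (p == 0) ? 0 : 1;
        next[p][pslot] = (next[p][pslot] + s[v][1]) % MOD;
    }
    return next;
}
```

For the path 0-1-2, a Selina starting at the root is at node 1 (going down) after one step, at the leaf 2 after two, at node 1 (going up) after three, and back at the root after four.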

Now, let's make it easier to frame a solution with matrix exponentiation. It is better to define "going up at node i" and "going down from node i" as separate variables.

Let me call D_i "going down from node i" and define U_i on similar grounds. What many people did is map the variables to "numbers": starting from 0 or 1, give each internal node the next two free numbers, one for U_i and one for D_i, then move to the next node and do the same. The root and leaves get only a single number, i.e. U_i=D_i. While not necessary, this makes understanding the next part easier, because we can use the mapped number as the index of the variable in the matrix. Let's get to this in the next part!
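The mapping can be sketched as follows, mirroring the index loop in the solution code. `Layout` and `assign_indices` are hypothetical names; `first` is the first free index, and the solution code effectively uses `first = n` so that indices 0..n-1 can hold the sums X_i introduced in the next part. Node 0 is assumed to be the root.

```cpp
#include <cassert>
#include <vector>

// up[v] / down[v]: index of U_v / D_v in the state vector;
// size: the next free index after the mapping.
struct Layout {
    std::vector<int> up, down;
    int size;
};

// The root (node 0) and leaves get one shared index (U_v = D_v);
// internal nodes get two consecutive indices.
Layout assign_indices(const std::vector<int>& kids, int first) {
    int n = static_cast<int>(kids.size());
    assert(first >= 0);
    Layout lay{std::vector<int>(n), std::vector<int>(n), first};
    for (int v = 0; v < n; ++v) {
        if (v == 0 || kids[v] == 0) {
            lay.down[v] = lay.up[v] = lay.size++;
        } else {
            lay.down[v] = lay.size++;
            lay.up[v] = lay.size++;
        }
    }
    return lay;
}
```

For the path 0-1-2 with `first = 3`, this gives the root index 3, node 1 indices 4 (down) and 5 (up), and the leaf index 6.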

**2. Construction of Matrix-**

Now, we have *roughly* 2n variables. But wait, there's a problem! Matrix exponentiation gives us the t-th term of the series, but what we need is the *sum* up to the t-th term! We will solve this by adding N extra variables X_i, 1 for each node, where X_i stores the running total \sum_t (U_i(t)+D_i(t)) (for the root and leaves, where U_i=D_i, that single variable is added only once). Basically, X_i stores the sum of expected values so far for node i.
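The trick of carrying a running total inside the state is easiest to see on a scalar recurrence. In this minimal C++ sketch (all names hypothetical), a_{t+1} = c * a_t is augmented with S_{t+1} = S_t + a_t, so the 2×2 matrix [[c, 0], [1, 1]] raised to the T-th power and applied to (a_0, 0) yields a_0 + a_1 + ... + a_{T-1}. In the actual problem, the X_i variables play the role of S for every node.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

const int64_t MOD = 998244353;
using Mat2 = std::array<std::array<int64_t, 2>, 2>;

Mat2 mul(const Mat2& a, const Mat2& b) {
    Mat2 c{};  // zero-initialised
    for (int i = 0; i < 2; ++i)
        for (int k = 0; k < 2; ++k)
            for (int j = 0; j < 2; ++j)
                c[i][j] = (c[i][j] + a[i][k] * b[k][j]) % MOD;
    return c;
}

Mat2 mat_pow(Mat2 m, int64_t e) {
    assert(e >= 0);
    Mat2 r{};
    r[0][0] = r[1][1] = 1;  // identity
    while (e > 0) {
        if (e & 1) r = mul(r, m);
        m = mul(m, m);
        e >>= 1;
    }
    return r;
}

// State (a_t, S_t): a_{t+1} = c * a_t, S_{t+1} = S_t + a_t,
// so S_T = a_0 + ... + a_{T-1}.
int64_t prefix_sum(int64_t c, int64_t a0, int64_t T) {
    Mat2 step{};
    step[0][0] = c % MOD;  // a' = c * a
    step[1][0] = 1;        // S' = a + S
    step[1][1] = 1;
    Mat2 p = mat_pow(step, T);
    // Apply to the column vector (a0, 0): only p[1][0] matters for S_T.
    return p[1][0] * (a0 % MOD) % MOD;
}
```

For example, `prefix_sum(2, 1, 5)` computes 1 + 2 + 4 + 8 + 16 = 31.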

First, let's see how to construct our usual matrix that finds the t-th term, and then what changes we need to make to it to get the sum.

Refer to the recurrences we made so far. Define the matrix entry A[i][j] as "coming at variable i from variable j", i.e. the coefficient with which variable j at time t contributes to variable i at time t+1. Variable i can be U_i or D_i, and j can be downwards from the parent (D_{par[i]}) or upwards from a child (U_{child[i]}) etc., depending on what exactly i is.

Now, you might have to do a lot of visualization in the steps below, but having mapped the variables makes our life easier.

First, define what your previous case (the vector you multiply with the matrix to get the next term's values) represents. Let's take it as [U_1,D_1,U_2,D_2,U_3,...] for now.

Now, for the case of going from parent to child-

dp[i][down][time+1]=\frac{1}{K}dp[parent[i]][down][time].

We have dp[parent[i]][down][time] in our base case, so we simply store K^{-1} at A[D_i][D_{parent[i]}], where K is the number of children of parent[i]. This is all we need to find the next term of this recurrence. What about leaves? Which direction do we assign to a Selina at a leaf: bouncing upwards or going downwards? Note that we assigned U_i=D_i for the leaves and the root, so we can use either of them as per our convenience.

Next, the case of going up from a child to its parent.

For this case, expand the summation-

dp[i][up][time+1]=\sum dp[child[i]][up][time]=1*dp[C_{i1}][up][time]+1*dp[C_{i2}][up][time]+..., where C_{ij} represents the j'th child of node i. We have all of these in our base case, so we simply put a 1 at all valid places. These valid places are exactly A[U_i][U_{C_{ij}}].

We made the matrix!
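Putting both cases together, here is a C++ sketch of filling in the transition part of A. It assumes node 0 is the root with `parent[0] == -1`, and that `up`/`down` come from an index mapping like the one described above (root and leaves share one index). Unlike the solution code, it iterates over child-parent pairs instead of over children lists; because U_i and D_i coincide for leaves and the root, the same two assignments also cover the corner cases. `build_transition` is a hypothetical name.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

const int64_t MOD = 998244353;
using Matrix = std::vector<std::vector<int64_t>>;

int64_t power_mod(int64_t base, int64_t exp) {
    int64_t result = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// A[a][b]: coefficient with which variable b at time t contributes to
// variable a at time t+1.
Matrix build_transition(const std::vector<int>& parent, const std::vector<int>& kids,
                        const std::vector<int>& up, const std::vector<int>& down,
                        int size) {
    Matrix A(size, std::vector<int64_t>(size, 0));
    for (int v = 1; v < static_cast<int>(parent.size()); ++v) {
        int p = parent[v];
        assert(kids[p] > 0);
        // Parent to child: D_v(t+1) gets K^{-1} * D_p(t), K = number of children of p.
        A[down[v]][down[p]] = power_mod(kids[p], MOD - 2);
        // Child to parent: U_p(t+1) gets 1 * U_v(t).
        A[up[p]][up[v]] = 1;
    }
    return A;
}
```

For the path 0-1-2 mapped to indices root 0, node 1 down 1 / up 2, leaf 3, this sets exactly A[1][0], A[3][1], A[2][3] and A[0][2] to 1.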

Now only 1 thing is left: incorporating the sum. For this, we add N extra variables X_1,X_2,...,X_n to the base case, where X_i=\sum_t (U_i(t)+D_i(t)).

Now, if we know X_i at time t, how will we find it at time t+1? Simple!

X_i(t+1)=X_i(t)+U_i(t)+D_i(t).

This means that, in the row of X_i, we put a 1 at the column of X_i itself, at A[X_i][U_i] and at A[X_i][D_i] (just once for the root and leaves, since U_i=D_i there). The column of X_i depends on where you put X_i in your base case. The tester formed a base case of [X_1,X_2,....X_n,U_1,D_1,....,D_n], so for him the entry A[i][i] corresponded to the position of X_i. Note that when I say "position of X_i", I mean the position in the i'th row that must be made 1 so that the previous value of X_i carries over to the next step.

You might have to use some paper, since the part above is hard to visualize, but that is the only crux of this question; after this, it's all standard algorithms. Feel free to ask doubts.

**3. Concluding Notes-**

What is now left to do is to find the contribution of each Selina. We can say that the contribution is-

A^{time}*B, where B is the base case (only 1 Selina at the root at t=0, everything else 0; this might vary depending on your definition of the matrix) and time is the number of time steps for which this Selina contributes, which is Q-t[i]+1. A Selina with t[i]>Q contributes nothing.
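Summing over all Selinas, the answer for node v can be written as below, assuming a layout in which X_v is a component of the state vector and B has a single 1 at the root's variable; Selinas with t_j > Q are skipped, matching the `num<=0` check in the solution code.

```latex
\text{ans}_v \equiv \sum_{j \,:\, t_j \le Q} \left[ A^{\,Q - t_j + 1} B \right]_{X_v} \pmod{998244353}
```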

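One optimization we can use in the matrix exponentiation part is to precompute the powers A^{2^k} once, instead of calculating them again and again for each Selina. Then, for each Selina, we only multiply the vector B by the needed A^{2^k}, which costs O(S^2) per multiplication for an S \times S matrix instead of O(S^3). Hence, the tester used a 3-D array A[k][i][j], where A[k] is A^{2^k} for the matrix A we described above.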

Once that's done, we are all good to go. All that's left is multiplying the matrices to get the answer. Don't forget to use long longs and take mods!

### SOLUTION

Setter


```
//teja349
#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip>
//setbase - cout << setbase (16); cout << 100 << endl; Prints 64
//setfill - cout << setfill ('x') << setw (5); cout << 77 << endl; prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x) cout<<fixed<<val; // prints x digits after decimal in val
using namespace std;
#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define sz(a) a.size()
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
#define mod (998244353)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
ll matr[25][2000][2000];
ll ans[2000],ans1[2000];
ll par[412],childs[412],inv[412];
ll iinf;
vector<vi> adj(412);
// dfs.
ll dfs(ll cur,ll previ){
par[cur]=previ;
ll i;
ll cnt=0;
rep(i,adj[cur].size()){
if(adj[cur][i]!=previ){
dfs(adj[cur][i],cur);
cnt++;
}
}
childs[cur]=cnt;
return 0;
}
ll t[412],dp[412];
int up[412],down[412];
int main(){
std::ios::sync_with_stdio(false);
iinf =mod;
iinf*=mod;
ll n;
cin>>n;
ll i;
ll u,v;
ll j,k;
// taking input
rep(i,n-1){
cin>>u>>v;
assert(u>=1 && u<=n);
assert(v>=1 && v<=n);
u--;
v--;
adj[u].pb(v);
adj[v].pb(u);
}
inv[0]=1;
inv[1]=1;
// finding inverse of numbers
f(i,2,n+10){
inv[i]=inv[mod%i]*(mod/i);
inv[i]%=mod;
inv[i]*=-1;
inv[i]%=mod;
inv[i]+=mod;
}
ll m;
cin>>m;
rep(i,m){
cin>>t[i];
}
ll q;
cin>>q;
ll sumi=0;
// special case.
if(n==1){
rep(i,m){
if(t[i]<=q)
sumi+=q-t[i]+1;
}
cout<<sumi<<endl;
return 0;
}
dfs(0,-1);
ll counter=n;
ll papa;
// Indices 0..n-1 hold the sums X_i; direction variables start at index n (counter).
// Internal nodes get two indices (down, up); the root and leaves get one shared index.
rep(i,n){
if(i==0 || childs[i]==0){
down[i]=counter++;
up[i]=down[i];
}
else{
down[i]=counter++;
up[i]=counter++;
}
}
// construction of matrix: matr[0][a][b] = coefficient of variable b (time t) in variable a (time t+1)
rep(i,n){
if(i!=0){
papa=par[i];
matr[0][down[i]][down[papa]]=inv[childs[papa]];
if(!childs[i]){
matr[0][up[i]][down[papa]]=inv[childs[papa]];
}
else{
rep(j,adj[i].size()){
if(adj[i][j]!=par[i]){
papa=adj[i][j];
matr[0][up[i]][up[papa]]=1;
}
}
}
}
else{
rep(j,adj[i].size()){
if(adj[i][j]!=par[i]){
papa=adj[i][j];
matr[0][up[i]][up[papa]]=1;
matr[0][down[i]][up[papa]]=1;
//matr[0][2*i+counter][2*papa+counter+1]=1;
}
}
}
}
// sum part: X_i(t+1) = X_i(t) + U_i(t) + D_i(t), with X_i stored at index i
rep(i,n){
matr[0][i][i]=1;
if(i==0 || childs[i]==0){
matr[0][i][up[i]]=1;
}
else{
matr[0][i][up[i]]=1;
matr[0][i][down[i]]=1;
}
}
// precomputing matr[mask] = matr[0]^(2^mask) for every 2^mask <= q
ll mask=1;
while((1<<mask)<=q){
rep(i,counter){
rep(k,counter){
// optimisation
if(matr[mask-1][i][k]==0)
continue;
rep(j,counter){
matr[mask][i][j]+=matr[mask-1][i][k]*matr[mask-1][k][j];
if(matr[mask][i][j]>iinf)
matr[mask][i][j]-=iinf;
}
}
}
rep(i,counter){
rep(j,counter){
matr[mask][i][j]%=mod;
}
}
mask++;
}
ll num;
rep(i,m){
num=q-t[i]+1;
rep(j,counter){
ans[j]=0;
}
ans[up[0]]=1;
ans[down[0]]=1;
if(num<=0)
continue;
mask=0;
// finding matrix power * base case in K*K *logn ( assuming K is matrix size and n is power to which it is being raised to )
while(num){
if(num%2){
rep(k,counter){
// optimisation.
if(ans[k]==0)
continue;
rep(j,counter){
ans1[j]+=matr[mask][j][k]*ans[k];
if(ans1[j]>iinf)
ans1[j]-=iinf;
}
}
rep(j,counter){
ans[j]=ans1[j]%mod;
ans1[j]=0;
}
}
num/=2;
mask++;
}
rep(j,n){
dp[j]+=ans[j];
}
}
rep(i,n){
dp[i]%=mod;
cout<<dp[i]<<" ";
}
cout<<endl;
return 0;
}
```

Editorialist

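Time Complexity=O(N^3\log Q+M\times N^2\times \log Q) (precomputing the powers A^{2^k}, then one vector-matrix product per set bit for each Selina)

Space Complexity=O(M+N^2\log Q)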

### CHEF VIJJU’S CORNER

**1. Go through our recurrence again. Why can we not use it to directly find the answer? Is it the memory limit? Can you find a way to overcome that? (Hint: We only need the last 1-2 rows to calculate the next row!)**
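One possible answer to the hint, as a sketch: by linearity, all Selinas can be simulated together, injecting each one at the root at its start time and keeping only the current time step. That needs O(N) memory but O((Q - \min t) \cdot N) time, so it only pays off when Q is small. This C++ sketch assumes node 0 is the root with `parent[0] == -1`, N ≥ 2, and the same counting as the solution code (a Selina starting at time t is counted at times t, ..., Q); `simulate` is a hypothetical name, and the slot convention (root in slot 0, leaves in slot 1) is an assumption of the sketch.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

const int64_t MOD = 998244353;

int64_t power_mod(int64_t base, int64_t exp) {
    int64_t result = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// Rolling-array simulation of all Selinas at once; returns the per-node totals.
std::vector<int64_t> simulate(const std::vector<int>& parent, const std::vector<int>& kids,
                              std::vector<int64_t> starts, int64_t Q) {
    int n = static_cast<int>(parent.size());
    assert(n >= 2);  // the solution code handles N == 1 separately
    std::vector<int64_t> invK(n, 0);
    for (int v = 0; v < n; ++v)
        if (kids[v] > 0) invK[v] = power_mod(kids[v], MOD - 2);
    std::vector<int64_t> total(n, 0);
    if (starts.empty()) return total;
    std::sort(starts.begin(), starts.end());
    // cur[v][0] = going down, cur[v][1] = going up; only the current step is kept.
    std::vector<std::array<int64_t, 2>> cur(n, std::array<int64_t, 2>{0, 0});
    size_t idx = 0;
    for (int64_t time = starts[0]; time <= Q; ++time) {
        // Selinas starting now appear at the root (root uses slot 0).
        while (idx < starts.size() && starts[idx] == time) {
            cur[0][0] = (cur[0][0] + 1) % MOD;
            ++idx;
        }
        // Accumulate; unused slots of the root and leaves are always 0.
        for (int v = 0; v < n; ++v)
            total[v] = (total[v] + cur[v][0] + cur[v][1]) % MOD;
        // Advance one time step using the recurrence.
        std::vector<std::array<int64_t, 2>> nxt(n, std::array<int64_t, 2>{0, 0});
        for (int v = 1; v < n; ++v) {
            int p = parent[v];
            int slot = (kids[v] == 0) ? 1 : 0;  // a leaf bounces back up
            nxt[v][slot] = (nxt[v][slot] + cur[p][0] * invK[p]) % MOD;
            int pslot = (p == 0) ? 0 : 1;       // the root turns back down
            nxt[p][pslot] = (nxt[p][pslot] + cur[v][1]) % MOD;
        }
        cur.swap(nxt);
    }
    return total;
}
```

For the path 0-1-2 with Selinas starting at times 1 and 3 and Q = 5, this returns (3, 3, 2).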

**2. Hall of Fame for Noteworthy Solutions-**

- jtnydv25 - Solved the question in 0.5 seconds!!
- fjzzq2002 - Solved using the Berlekamp-Massey algorithm
- ????? - Spot reserved for a well commented simulation solution.

**3. Setter’s Notes**


**First, from observation we know that:**

**Collision is of no use.**

**Then we can calculate the contribution separately for each ball. First, we precompute the matrix raised to powers of 2; since multiplying a vector by a matrix is O(n^2), the total time complexity is O(m\times n^2\times logQ).**

**4. Tester’s notes on collision-**


**The main intent behind collisions in the question is that momentum should be conserved (which implies that the superposition principle, i.e. considering no collisions, should work). Adding to your doubts, I think one more important doubt is what happens when two Selinas moving in the same direction at the same node collide with a Selina moving in the opposite direction. Will both of them change their direction, or only one of them? Since momentum is conserved, only one of the two moving in the same direction should change direction.**