PTREE - EDITORIAL

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2

Setter: Yash Chandnani
Tester: Encho Mishinev
Editorialist: Taranpreet Singh

DIFFICULTY:

Medium-Hard

PREREQUISITES:

Arithmetic-Geometric Progression, Observations and Preprocessing.

PROBLEM:

A perfect tree is a tree in which all leaves are at the same distance from the root. Given a perfect tree with N vertices, its value at time x is defined as W(x) = \max_p\left(\sum_{i = 1}^N dist(p_i)*x^{i-1} \right), where p ranges over all permutations of the integers from 1 to N and dist(v) is the distance of vertex v from the root.

We need to answer queries of the form: find the value of the subtree rooted at a given node at a given time. Queries are given in encoded form (each one depends on the previous answer), so they must be answered online.

QUICK EXPLANATION

  • For any query, the optimal permutation is one in which all vertices are ordered by their distance from the root of the queried subtree, in ascending order.
  • Because the tree is perfect, the number of nodes at depth i+1 is at least the number of nodes at depth i. Only O(\sqrt N) depths (fewer than \sqrt{2N}) contain at least one node with outdegree greater than 1; call such nodes special.
  • The number of vertices at depth x+1 is greater than at depth x if and only if there is a special vertex at depth x. So we preprocess all special vertices and, for a given query, handle each maximal range of depths [lo, hi] in which every level has the same number of vertices. The contribution of such a range is essentially the sum of an Arithmetic-Geometric Progression.
  • To evaluate each AGP sum in O(1), we precompute powers of x for each query.

EXPLANATION

First of all, we need to find the best permutation, the one that maximizes the value of the tree at time x. Here dist(p_i) is the distance from the root to p_i. Since x \geq 1, the weights x^{i-1} are non-decreasing in i, so by the rearrangement inequality the sum is maximized by pairing smaller distances with smaller weights: if a deeper node appears before a shallower one, swapping them never decreases the sum. So the optimal permutation is any permutation in which the depths of the nodes are in ascending order.
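
A brute-force check on a tiny input is a quick way to convince yourself of this exchange argument. The sketch below is illustrative only: the helper names are hypothetical, it uses plain 64-bit integers with no modulus, and it enumerates every ordering of a short list of depths with std::next_permutation, comparing the best value with the value of the ascending order.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative sketch; all function names are hypothetical.
// Value of one ordering: depths[i] weighted by x^i (0-indexed, i.e. x^(i-1) in the 1-indexed formula).
// Plain 64-bit arithmetic without a modulus, so keep the inputs tiny.
int64_t orderingValue(const std::vector<int64_t>& depths, int64_t x) {
    int64_t value = 0, power = 1;
    for (int64_t d : depths) {
        value += d * power;
        power *= x;
    }
    return value;
}

// Brute force: maximum over every distinct permutation of the depths.
int64_t bruteForceW(std::vector<int64_t> depths, int64_t x) {
    std::sort(depths.begin(), depths.end());  // next_permutation must start from the smallest order
    int64_t best = orderingValue(depths, x);
    while (std::next_permutation(depths.begin(), depths.end()))
        best = std::max(best, orderingValue(depths, x));
    return best;
}

// Value of the ascending order, which the exchange argument says is optimal for x >= 1.
int64_t sortedW(std::vector<int64_t> depths, int64_t x) {
    assert(x >= 1);
    std::sort(depths.begin(), depths.end());
    return orderingValue(depths, x);
}
```

For the depths {2, 0, 1, 2, 1} and x = 3, both functions return 0 + 3 + 9 + 54 + 162 = 228.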

Now let us look at the number of nodes at each depth. Let D[x] be the number of nodes at depth x. D[x+1] can be smaller than D[x] only if there are leaf nodes at depth x, but the definition of a perfect tree guarantees that there are no leaves except on the deepest level, so D[x+1] \geq D[x] everywhere above it. Also, if one node at depth x has two children and all other nodes at depth x have one child, then D[x+1] = D[x]+1. Let's call a node special if it has outdegree greater than 1. Then D[x+1] > D[x] holds for depth x if and only if there is a special vertex at depth x.
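
The observation can be checked with a BFS that counts vertices level by level. This is a minimal sketch, assuming the tree is given as a list of children with vertex 0 as the root (levelCounts is a hypothetical name); it returns D and marks every depth that contains a special vertex.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Illustrative sketch; levelCounts is a hypothetical name.
// children[v] lists the children of v and vertex 0 is the root.
// Returns D (D[d] = number of vertices at depth d) and sets special[d] to true
// when some vertex at depth d has more than one child.
std::vector<int> levelCounts(const std::vector<std::vector<int>>& children,
                             std::vector<bool>& special) {
    assert(!children.empty());
    std::vector<int> D;
    special.clear();
    std::vector<int> frontier{0};  // vertices of the current level
    while (!frontier.empty()) {
        D.push_back(static_cast<int>(frontier.size()));
        bool hasSpecial = false;
        std::vector<int> next;
        for (int v : frontier) {
            if (children[v].size() > 1) hasSpecial = true;
            for (int c : children[v]) next.push_back(c);
        }
        special.push_back(hasSpecial);
        frontier = std::move(next);
    }
    return D;
}
```

On a perfect tree whose levels have 1, 2, 3 and 3 vertices, the flags are true exactly on the first two levels, which are the levels right before D grows.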

Lemma: The number of levels having at least one special node in a perfect tree with N nodes cannot exceed x, where x is the largest integer such that x*(x+1)/2 \leq N.

Proof: Let d_1 < d_2 < \ldots < d_k be the depths that contain a special node. Above the deepest level D never decreases, and it grows by at least one right after each d_j, so D[d_j] \geq j. The extreme case is the one where every level has exactly one node with two children and all other nodes have one child: one node on the root level, two on the next level, three on the one after that, and so on. Since the total number of nodes is N, we get 1+2+3+\ldots+k \leq N, so k is at most the largest x with x*(x+1)/2 \leq N. In particular x < \sqrt{2N}, so there are O(\sqrt N) such levels.

Now, for each node u, we find all tuples (di, cnt) such that the number of nodes in the subtree of u at depth di+1 is cnt more than the number of nodes in the subtree of u at depth di (only tuples with cnt > 0 are kept). Applying the lemma to the subtree of u, every node has O(\sqrt N) tuples, and from them we can determine the number of nodes at any depth in the subtree of any node. This will be useful shortly.
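
One possible way to build these tuples, close in spirit to the editorialist's TreeMap merge but not the setter's exact code, is a post-order DFS in which each node with c > 1 children records (its depth, c-1) and every child's map is added into its parent's map. The names buildGrowth and subtreeLevels are hypothetical, depths are absolute depths from the root, and the naive map merging and recursion are meant for illustration rather than for the full constraints.

```cpp
#include <cassert>
#include <map>
#include <vector>

// Illustrative sketch; buildGrowth and subtreeLevels are hypothetical names.
// Post-order DFS from u (at absolute depth `depth`). Afterwards grow[u][d] = extra means:
// inside the subtree of u, level d+1 has `extra` more vertices than level d.
// Only depths holding a special vertex (more than one child) get an entry.
void buildGrowth(int u, int depth, const std::vector<std::vector<int>>& children,
                 std::vector<std::map<int, int>>& grow) {
    for (int c : children[u]) {
        buildGrowth(c, depth + 1, children, grow);
        for (const auto& entry : grow[c]) grow[u][entry.first] += entry.second;
    }
    if (children[u].size() > 1) grow[u][depth] += static_cast<int>(children[u].size()) - 1;
}

// Rebuild the vertex count of every level of u's subtree, from depthU down to
// maxDepth (the common depth of all leaves), using only u's tuples.
std::vector<int> subtreeLevels(const std::map<int, int>& growU, int depthU, int maxDepth) {
    assert(depthU <= maxDepth);
    std::vector<int> levels{1};  // u itself
    for (int d = depthU; d < maxDepth; ++d) {
        auto it = growU.find(d);
        levels.push_back(levels.back() + (it == growU.end() ? 0 : it->second));
    }
    return levels;
}
```

For a 9-vertex tree where the root 0 has children 1 and 2, vertex 2 has two children, and all leaves are at depth 3, the root's map is {0: 1, 1: 1} and subtreeLevels rebuilds the level sizes 1, 2, 3, 3.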

Suppose some node v has 1 node at distance zero (the node v itself), 2 nodes at each distance in the range [1, 3], and 3 nodes at each distance in the range [4, 5]. For any given x, the answer for this query is

S = 0*x^0 + 1*x^1+ 1*x^2 + 2*x^3+2*x^4+ 3*x^5+3*x^6 +4*x^7 + 4*x^8+ 4*x^9 + 5*x^{10}+5*x^{11} + 5*x^{12}. Grouping terms by depth, we have

S = 1*(x^1+ x^2) + 2*(x^3+x^4)+ 3*(x^5+x^6) +4*(x^7 + x^8+ x^9) + 5*(x^{10}+x^{11} + x^{12}). Now factor the lowest power of x out of each group.

S = 1*x(1+x) + 2*x^3*(1+x)+ 3*x^5*(1+x) +4*x^7*(1+x+x^2) + 5*x^{10}*(1+x+x^2). Noticed anything? For depths having the same number of nodes, the same GP factor appears automatically.

S = (1+x)*(1*x + 2*x^3+ 3*x^5) + (1+x+x^2)*(4*x^7 + 5*x^{10}). Once again, factor the lowest power of x out of each bracket.

S = (1+x)*x*(1 + 2*x^2+ 3*x^4) + (1+x+x^2)*x^7*(4 + 5*x^3). Now things start to get interesting. For each range of depths having the same number of vertices on every level (the ranges [1, 3] and [4, 5] in this example), the contribution is a product of exactly three terms. The first is a geometric progression with first term 1, common ratio x, and as many terms as there are vertices per level in the range. The second is x^p, where p is the number of vertices at depths smaller than the current range (one vertex for the first range and 7 vertices for the second). The third term is a mess. :smiley:
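
The regrouping can be double-checked numerically. This sketch (hypothetical function names, arithmetic modulo 10^9+7) evaluates the example once term by term and once in the final factored form; since both are the same polynomial, they must agree for every x.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative sketch; function names are hypothetical.
const int64_t MOD = 1000000007LL;

int64_t modpow(int64_t a, int64_t e) {
    int64_t result = 1;
    a %= MOD;
    while (e > 0) {
        if (e & 1) result = result * a % MOD;
        a = a * a % MOD;
        e >>= 1;
    }
    return result;
}

// Term by term: depths in ascending order, the i-th vertex (0-indexed) weighted by x^i.
int64_t directSum(int64_t x) {
    const std::vector<int64_t> depths = {0, 1, 1, 2, 2, 3, 3, 4, 4, 4, 5, 5, 5};
    x %= MOD;
    int64_t sum = 0, power = 1;
    for (int64_t d : depths) {
        sum = (sum + d * power) % MOD;
        power = power * x % MOD;
    }
    return sum;
}

// Factored form: (1+x) * x * (1 + 2x^2 + 3x^4) + (1+x+x^2) * x^7 * (4 + 5x^3).
int64_t factoredSum(int64_t x) {
    x %= MOD;
    int64_t first = (1 + x) % MOD * x % MOD *
                    ((1 + 2 * modpow(x, 2) + 3 * modpow(x, 4)) % MOD) % MOD;
    int64_t second = (1 + x + modpow(x, 2)) % MOD * modpow(x, 7) % MOD *
                     ((4 + 5 * modpow(x, 3)) % MOD) % MOD;
    return (first + second) % MOD;
}
```

For x = 2 both functions give 39766 (the two groups contribute 342 and 39424).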

The third term is actually the sum of an arithmetico-geometric progression: the term-by-term product of an AP with a GP. The common difference of the AP is always 1, and the common ratio of the GP is x^q, where q is the number of nodes on each level of the current range.

Evaluating every range of depths (with a constant number of nodes per level) using the closed-form GP and AGP summation formulas, and computing each modular inverse with fast exponentiation in O(log(mod)), we can now solve the problem in O(N*\sqrt N+Q*\sqrt N*log(mod)).

However, the author considers this too slow and wants you to do quite a bit more work to achieve O((Q+N)*\sqrt N): we need to get rid of that log factor.

Let us consider the expression again.

Let [l_i, h_i] be the i-th range of depths for the current node and n = h_i-l_i+1 its length, let cur be the number of nodes on each level of the range, and let sum be the number of nodes above the range. Then S = \sum_i x^{sum}*GP(x, cur)*AGP(l_i, x^{cur}, n), where GP(r, m) = 1 + r + \ldots + r^{m-1}.

Here AGP(a, r, n) denotes the sum of the first n terms of the arithmetico-geometric progression a, (a+1)*r, (a+2)*r^2, \ldots, that is, first term a, common difference 1, and common ratio r.

The sum of the first n terms of an AGP with first term a, common difference d, and common ratio r \neq 1 is \frac{a-(a+(n-1)*d)*r^n}{1-r}+\frac{d*r*(1-r^{n-1})}{(1-r)^2}. Here d = 1 and r = x^{cur}, so we have \frac{a-(a+n-1)*x^{n*cur}}{1-x^{cur}}+\frac{x^{cur}*(1-x^{cur*(n-1)})}{(1-x^{cur})^2}.
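
Closed forms like this are easy to get subtly wrong, so it helps to compare them against a direct loop. The sketch below (hypothetical names) evaluates the AGP formula with d = 1 using Fermat inverses modulo 10^9+7, which is exactly the O(log(mod)) route mentioned earlier. It assumes r \neq 1 (mod 10^9+7), since otherwise 1-r has no inverse.

```cpp
#include <cassert>
#include <cstdint>

// Illustrative sketch; function names are hypothetical.
const int64_t MOD = 1000000007LL;

int64_t modpow(int64_t a, int64_t e) {
    int64_t result = 1;
    a %= MOD;
    while (e > 0) {
        if (e & 1) result = result * a % MOD;
        a = a * a % MOD;
        e >>= 1;
    }
    return result;
}

// Modular inverse by Fermat's little theorem (MOD is prime), O(log MOD).
int64_t inverse(int64_t a) { return modpow((a % MOD + MOD) % MOD, MOD - 2); }

// Direct sum of a + (a+1) r + (a+2) r^2 + ... with n terms.
int64_t agpNaive(int64_t a, int64_t r, int64_t n) {
    int64_t sum = 0, power = 1;
    r %= MOD;
    for (int64_t k = 0; k < n; ++k) {
        sum = (sum + (a + k) % MOD * power) % MOD;
        power = power * r % MOD;
    }
    return sum;
}

// Closed form with d = 1, valid for r != 1 (mod MOD) and n >= 1:
// (a - (a+n-1) r^n) / (1-r) + r (1 - r^(n-1)) / (1-r)^2
int64_t agpClosed(int64_t a, int64_t r, int64_t n) {
    r %= MOD;
    assert(r != 1 && n >= 1);
    const int64_t inv = inverse(1 - r + MOD);  // 1 / (1 - r)
    int64_t first = (a % MOD - (a + n - 1) % MOD * modpow(r, n) % MOD + MOD) % MOD * inv % MOD;
    int64_t second = r * ((1 - modpow(r, n - 1) + MOD) % MOD) % MOD * inv % MOD * inv % MOD;
    return (first + second) % MOD;
}
```

With x = 2, the first range of the example (a = 1, cur = 2, n = 3) gives agpClosed(1, 4, 3) = 1 + 2*4 + 3*16 = 57, the value of 1 + 2*x^2 + 3*x^4.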

The contribution of one range becomes x^{sum}*GP(x, cur)*\left [ \frac{a-(a+n-1)*x^{n*cur}}{1-x^{cur}}+\frac{x^{cur}*(1-x^{cur*(n-1)})}{(1-x^{cur})^2} \right ]

Expanding the GP term as GP(x, cur) = \frac{1-x^{cur}}{1-x}, we have x^{sum}*\frac{1-x^{cur}}{1-x}*\left [ \frac{a-(a+n-1)*x^{n*cur}}{1-x^{cur}}+\frac{x^{cur}*(1-x^{cur*(n-1)})}{(1-x^{cur})^2} \right ]

The factor (1-x^{cur}) in the numerator of the GP term cancels one factor of (1-x^{cur}) from the denominators of both AGP terms.

We are left with \frac{x^{sum}}{1-x}* \left [ a-(a+n-1)*x^{cur*n} + \frac{x^{cur}*(1-x^{cur*(n-1)})}{1-x^{cur}} \right ]

To get rid of the log factor, we must avoid computing a modular inverse for every range, so we need to remove the remaining denominator too. Note that \frac{1-x^{cur*(n-1)}}{1-x^{cur}} is just the GP sum with first term 1, common ratio x^{cur}, and n-1 terms.

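So the contribution of one range is \frac{x^{sum}}{1-x}* [a-(a+n-1)*x^{cur*n} + x^{cur}*GP(x^{cur}, n-1)], and S is the sum of these contributions over all ranges, with a = l_i and GP(r, 0) = 0. This needs x \neq 1; for x = 1 every weight is 1 and the answer is simply the sum of all depths.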
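
Putting the pieces together, here is a sketch of one query, assuming the level counts of the queried subtree have already been compressed into (cur, n) ranges listed top to bottom. It walks the ranges while tracking sum (vertices seen so far) and a (first depth of the range) and adds the per-range expression above. For brevity it uses modpow and a plain-loop gp instead of the O(1) power table and the recursive GP sum; queryValue and gp are hypothetical names, and x = 1 is handled by summing depths directly, since the setter's code special-cases it as well.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Illustrative sketch; queryValue and gp are hypothetical names.
const int64_t MOD = 1000000007LL;

int64_t modpow(int64_t a, int64_t e) {
    int64_t result = 1;
    a %= MOD;
    while (e > 0) {
        if (e & 1) result = result * a % MOD;
        a = a * a % MOD;
        e >>= 1;
    }
    return result;
}

// GP(r, m) = 1 + r + ... + r^(m-1), with GP(r, 0) = 0. Plain loop for illustration.
int64_t gp(int64_t r, int64_t m) {
    int64_t sum = 0, power = 1;
    for (int64_t k = 0; k < m; ++k) {
        sum = (sum + power) % MOD;
        power = power * r % MOD;
    }
    return sum;
}

// ranges[i] = {cur, n}: n consecutive levels with cur vertices each, listed top to bottom,
// where the first range starts at the query vertex (distance 0).
int64_t queryValue(const std::vector<std::pair<int64_t, int64_t>>& ranges, int64_t x) {
    x %= MOD;
    int64_t total = 0, seen = 0, depth = 0;  // seen = "sum", depth = "a" in the text
    if (x == 1) {
        // Every weight is 1, so the value is just the sum of all depths.
        for (const auto& range : ranges)
            for (int64_t k = 0; k < range.second; ++k, ++depth)
                total = (total + range.first % MOD * depth) % MOD;
        return total;
    }
    const int64_t invOneMinusX = modpow(1 - x + MOD, MOD - 2);  // one inverse per query
    for (const auto& range : ranges) {
        const int64_t cur = range.first, n = range.second;
        assert(cur >= 1 && n >= 1);
        const int64_t r = modpow(x, cur);  // x^cur
        // a - (a+n-1) * x^(cur*n) + x^cur * GP(x^cur, n-1)
        int64_t bracket = (depth % MOD - (depth + n - 1) % MOD * modpow(r, n) % MOD + MOD) % MOD;
        bracket = (bracket + r * gp(r, n - 1)) % MOD;
        total = (total + modpow(x, seen) * invOneMinusX % MOD * bracket) % MOD;
        seen += cur * n;
        depth += n;
    }
    return total;
}
```

For the running example, queryValue({{1, 1}, {2, 3}, {3, 2}}, 2) returns 342 + 39424 = 39766, matching the term-by-term sum.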

Now we can precompute powers of x. We keep two arrays, p1 and p2, each with about lim elements: p1[i] stores x^i and p2[i] stores x^{i*lim}, so that for any y, x^y = p2[\lfloor y/lim \rfloor]*p1[y \bmod lim]. So, with O(lim) preprocessing per query, we can find any power of x up to lim*lim in O(1) time. We can choose lim = 450, since 450^2 = 202500 > 2*10^5.
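
A minimal sketch of the two-table trick follows; the struct and member names are hypothetical, while the setter's pw and pw2 arrays play the same role. The tables are rebuilt for each query in O(lim) time, and a lookup is valid while y / lim <= lim.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

const int64_t MOD = 1000000007LL;

// Illustrative sketch; PowerTable is a hypothetical name (the setter uses pw/pw2).
struct PowerTable {
    int64_t lim;
    std::vector<int64_t> p1, p2;  // p1[i] = x^i and p2[i] = x^(i*lim), for 0 <= i <= lim

    PowerTable(int64_t x, int64_t limit) : lim(limit), p1(limit + 1), p2(limit + 1) {
        assert(limit >= 1);
        x %= MOD;
        p1[0] = 1;
        for (int64_t i = 1; i <= lim; ++i) p1[i] = p1[i - 1] * x % MOD;
        p2[0] = 1;
        for (int64_t i = 1; i <= lim; ++i) p2[i] = p2[i - 1] * p1[lim] % MOD;
    }

    // x^y in O(1): split y = (y / lim) * lim + (y % lim).
    int64_t power(int64_t y) const {
        assert(y >= 0 && y / lim <= lim);
        return p2[y / lim] * p1[y % lim] % MOD;
    }
};
```

For example, PowerTable(3, 450).power(5) is 243, and every exponent below 2*10^5 can be looked up with two table reads and one multiplication.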

Now we can get any power of x in O(1) time. The inverse of (1-x) can be computed in O(log(mod)) and is needed only once per query. The final bottleneck is the GP term.

Let us compute it in O(log n) with a divide-and-conquer recursion similar to fast exponentiation. If n = 1, the GP sum is 1 (and GP(r, 0) = 0). If n is even, GP(r, n) = (1+r)*GP(r*r, n/2). If n is odd, GP(r, n) = 1+r*GP(r, n-1). Using these relations, we can compute the GP sum in O(log n) time.
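
The recursion translates almost literally into code. This sketch mirrors the setter's cal function but adds the m = 0 base case (GP(r, 0) = 0, which the formula above needs when a range has a single level and n-1 = 0); gpRec is a hypothetical name.

```cpp
#include <cassert>
#include <cstdint>

const int64_t MOD = 1000000007LL;

// Illustrative sketch; gpRec is a hypothetical name.
// GP(r, m) = 1 + r + r^2 + ... + r^(m-1) (mod MOD), without any modular inverse.
// Even m: GP(r, m) = (1 + r) * GP(r^2, m/2).  Odd m: GP(r, m) = 1 + r * GP(r, m-1).
// Every odd step is followed by an even step, so the recursion depth is O(log m).
int64_t gpRec(int64_t r, int64_t m) {
    assert(m >= 0);
    r %= MOD;
    if (m == 0) return 0;
    if (m == 1) return 1;
    if (m % 2 == 1) return (1 + r * gpRec(r, m - 1)) % MOD;
    return (1 + r) % MOD * gpRec(r * r % MOD, m / 2) % MOD;
}
```

For instance, gpRec(2, 10) = 1 + 2 + ... + 2^9 = 1023.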

But we started all this to remove the log factor. It turns out that, over all ranges of a single query, these GP computations take O(\sqrt N) time in total, i.e., amortized O(1) per range. The proof, given by the setter, is below.


Let l_i be the number of levels in the i-th range, i.e., the value of n passed to the GP routine (cal in the setter's code) for that range; the routine runs in O(log(l_i)). Each range has strictly more nodes per level than the one above it, so the i-th range has at least i nodes on each level and \sum i*l_i \leq N. Since \sum log(l_i) = log(\prod l_i) and l_i is non-zero for at most O(\sqrt N) values of i, we want to maximize \prod l_i = \prod (i*l_i) / \prod i under this constraint. By AM-GM, \prod (i*l_i) is largest when all the products i*l_i are equal, so in the worst case i*l_i \approx \sqrt N for every i \leq \sqrt N. The worst-case cost of all GP calls in one query is then about \sum_{i \leq \sqrt N} log(\sqrt N/i), which is at most 2*\sqrt N: round each \sqrt N/i up to the next power of 2, and note that at most \sqrt N/2^{j-1} indices i need j or more halvings, and these counts sum to at most 2*\sqrt N.
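
The final counting step can also be checked numerically. This small sketch (hypothetical names) computes \sum_{i=1}^{M} \lceil log_2(M/i) \rceil, with M standing in for \sqrt N, and compares it with 2*M.

```cpp
#include <cassert>
#include <cstdint>

// Illustrative sketch; both function names are hypothetical.
// Smallest j with i * 2^j >= M, i.e. ceil(log2(M / i)) for 1 <= i <= M.
int64_t ceilLog2Ratio(int64_t M, int64_t i) {
    assert(1 <= i && i <= M);
    int64_t j = 0;
    while ((i << j) < M) ++j;
    return j;
}

// Total over i = 1..M: the worst-case number of halving steps in one query
// under the setter's argument (i * l_i ~ M for every range i).
int64_t worstCaseSteps(int64_t M) {
    int64_t total = 0;
    for (int64_t i = 1; i <= M; ++i) total += ceilLog2Ratio(M, i);
    return total;
}
```

For M = 4 the terms are 2, 1, 1, 0, giving 4, well within 2*M = 8.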

Also, the user mnbvmar (who submitted the fastest solution) has explained his approach, which you may refer to here.

Time Complexity

Time complexity is O((Q+N)*\sqrt N) per test case.

AUTHOR’S AND TESTER’S SOLUTIONS:

Setter’s solution

#pragma comment(linker, "/stack:200000000")
#pragma GCC optimize ("Ofast")
#pragma GCC target ("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#include <bits/stdc++.h>
using namespace std;
 
 
#define TRACE
 
#ifdef TRACE
#define trace(...) __f(#__VA_ARGS__, __VA_ARGS__)
template <typename Arg1>
void __f(const char* name, Arg1&& arg1){
	cerr << name << " : " << arg1 << std::endl;
}
template <typename Arg1, typename... Args>
void __f(const char* names, Arg1&& arg1, Args&&... args){
	const char* comma = strchr(names + 1, ',');cerr.write(names, comma - names) << " : " << arg1<<" | ";__f(comma+1, args...);
}
#else
#define trace(...)
#endif
 
#define rep(i, n)    for(int i = 0; i < (n); ++i)
#define repA(i, a, n)  for(int i = a; i <= (n); ++i)
#define repD(i, a, n)  for(int i = a; i >= (n); --i)
#define trav(a, x) for(auto& a : x)
#define all(x) x.begin(), x.end()
#define sz(x) (int)(x).size()
#define fill(a)  memset(a, 0, sizeof (a))
#define fst first
#define snd second
#define mp make_pair
#define pb push_back
typedef long double ld;
typedef long long ll;
typedef pair<int, int> pii;
typedef vector<int> vi;
 
ll ans = 0;
vector<vi> g;
const ll mod=1e9+7;
const int N = 4e5+9;
ll pw[40000],pw2[40000];
bool fg[N];
int h[N];
void pre(){
}
void dfs(int v,int p=-1){
	int cnt = 0;
	trav(i,g[v]){
		if(i!=p){
			h[i] = h[v]+1;
			cnt++,dfs(i,v);
		}
	}
	if(cnt!=1) fg[h[v]]=1;
}
vector<vector<pii>> dist;
bool chk[N];
ll mem=0;
void dfs2(int v,int p=-1){
	if(fg[h[v]]){
		dist[v].pb(mp(0,1));
	}
	chk[v]=0;
	trav(i,g[v]){
		if(i!=p){
			dfs2(i,v);
			if(!chk[v]) {
				dist[v].resize(sz(dist[v])+sz(dist[i]),mp(0,0));
				chk[v]=1;
			}
			rep(j,sz(dist[i])){
				dist[v][sz(dist[v])-1-j].fst = dist[i][sz(dist[i])-1-j].fst+1;
				dist[v][sz(dist[v])-1-j].snd += dist[i][sz(dist[i])-1-j].snd;
			}
		}
	}
}
int lim;
ll modpow(ll a, ll e) {
	if (e == 0) return 1;
	ll x = modpow(a * a % mod, e >> 1);
	return e & 1 ? x * a % mod : x;
}
ll pwf(int y){
	if(y<lim) return pw[y%lim];
	return pw[y%lim]*pw2[y/lim]%mod;
//	return modpow(pw[1],y);
//	return 0;
}
ll cal(ll r,int n){
	if(n==1) return 1;
	if(n%2==1) return (1+r*(cal(r,n-1)))%mod;
	else return (1+r)*cal(r*r%mod,n/2)%mod;
}
ll invx;
ll compute(ll& k,ll lst,ll h,int c){
	int l = h-lst;
	ll y = pwf(c);
	if(pw[1]==1){
		return k*c%mod*(h*(h+1)/2-lst*(lst+1)/2)%mod;
	}
	else if(y==1){
		return 0;
	}
	else {
		ll yl = pwf(c*l);
		return k*invx%mod*((yl*h%mod+(mod-lst)+(mod-cal(y,l))))%mod;
	}
}
void showtime(){
	trace(double(clock())/CLOCKS_PER_SEC);
}
void solve(){
	int n,q;cin>>n>>q;
	g.clear();g.resize(n+10);
	ans = 0;
	rep(i,n-1){
		int u,v;cin>>u>>v;
		u--,v--;
		g[u].pb(v);
		g[v].pb(u);
	}
	fill(fg);
	dfs(0);
	dist.clear();dist.resize(n+10);
	trace(n,q);
	dfs2(0);
	lim = sqrt(n)+10;
	showtime();
	rep(i,q){
		int v,x;cin>>v>>x;
		v^=ans,x^=ans;
		v--;
		if(x!=1) invx = modpow(x-1,mod-2);
		ans=0;
		ll tot =1,lst=-1;
		pw[0]=1,pw2[0]=1,pw[1]=x;
		repA(i,1,lim) pw[i] = pw[i-1]*x%mod;
		repA(i,1,lim) pw2[i] = pw2[i-1]*pw[lim]%mod;
		trav(j,dist[v]){
			ans+=compute(tot,lst,j.fst,j.snd);
			tot=tot*pwf((j.fst-lst)*j.snd)%mod;
			lst = j.fst;
		}
		ans=ans%mod;
		cout<<ans<<'\n';
	}
	showtime();
}
 
int main() {
	cin.sync_with_stdio(0); cin.tie(0);
	cin.exceptions(cin.failbit);
	pre();
	showtime();
	int n;cin>>n;
	rep(i,n) solve();	

	return 0;
}

Tester’s solution

#include <iostream>
#include <stdio.h>
#include <algorithm>
#include <vector>
#include <queue>
using namespace std;
typedef long long llong;
 
const llong MOD = 1000000007LL;
const int THRESHOLD = 700;
 
int t;
int n,q;
vector<int> Graph[200111];
 
int rowPrec[200111];
 
int inVal[200111];
int inCtr = 0;
int realId[200111];
int lastInVal[200111];
 
int branchFactor[200111];
 
int MIN(int a,int b)
{
    if (a > b)
	return a;
    else
	return b;
}
 
llong FastPow(llong k,llong p)
{
    if (p == 0LL)
	return 1LL;
    else if (p == 1LL)
	return k;
 
    llong P = FastPow(k, p/2LL);
 
    P = (P * P) % MOD;
 
    if (p % 2LL == 1LL)
    {
	P = (P * k) % MOD;
    }
 
    return P;
}
 
inline llong Div(llong k)
{
    return FastPow(k, MOD - 2LL);
}
 
int bottom;
int Depth[200111];
vector<int> depthRows[200111];
int rowPos[200111];
 
int leftLeaf[200111];
int rightLeaf[200111];
 
int father[200111];
int nextBranching[200111];
bool isBranching[200111];
 
void DFS(int ver,int dad,int depth)
{
    int i;
 
    if (depth > bottom)
	bottom = depth;
 
    father[ver] = dad;
 
    inCtr++;
    inVal[ver] = inCtr;
    realId[inCtr] = ver;
 
    Depth[ver] = depth;
    rowPos[ver] = depthRows[depth].size();
    depthRows[depth].push_back(ver);
 
    if ( (ver == 1 && Graph[ver].size() > 1) || Graph[ver].size() > 2 )
	isBranching[ver] = true;
 
    if (isBranching[dad])
	nextBranching[ver] = dad;
    else
	nextBranching[ver] = nextBranching[dad];
 
    leftLeaf[ver] = -1;
    for (i=0;i<Graph[ver].size();i++)
    {
	if (Graph[ver][i] == dad)
	    continue;
 
	DFS(Graph[ver][i], ver, depth+1);
 
	branchFactor[ver]++;
 
	if (leftLeaf[ver] == -1)
	    leftLeaf[ver] = leftLeaf[ Graph[ver][i] ];
 
	rightLeaf[ver] = rightLeaf[ Graph[ver][i] ];
    }
 
    if (leftLeaf[ver] == -1)
    {
	leftLeaf[ver] = ver;
	rightLeaf[ver] = ver;
    }
 
    lastInVal[ver] = inCtr;
 
    return;
}
 
vector<int> intVers[200111];
vector<int> allInt;
 
void getInteresting(int ver,int dad)
{
    int i;
 
    if (isBranching[ver] && bottom - Depth[ver] > THRESHOLD)
    {
	allInt.push_back(ver);
    }
 
    for (i=0;i<Graph[ver].size();i++)
    {
	if (Graph[ver][i] == dad)
	    continue;
 
	getInteresting(Graph[ver][i], ver);
    }
}
 
bool SAI(int a,int b)
{
    return Depth[a] > Depth[b];
}
 
void precInteresting()
{
    int i;
 
    if (!allInt.empty())
	sort(allInt.begin(), allInt.end(), SAI);
 
    //fprintf(stderr,"Size = %d\n",allInt.size());
 
    for (i=0;i<allInt.size();i++)
    {
	int cur = allInt[i];
 
	while(cur > 0)
	{
	    intVers[cur].push_back(allInt[i]);
 
	    cur = father[cur];
	}
    }
 
    for (i=1;i<=n;i++)
    {
	if (intVers[i].empty() || intVers[i].back() != i)
	    intVers[i].push_back(i);
    }
}
 
int sq;
llong precSmall[1011];
llong precLarge[1011];
 
llong invPrecSmall[1011];
llong invPrecLarge[1011];
 
llong xmDiv;
 
void precPowers(llong x)
{
    if (x != 1)
	xmDiv = Div(x-1);
    else
	xmDiv = 1LL;
 
    precLarge[0] = 1LL;
    precSmall[0] = 1LL;
    precSmall[1] = x;
 
    invPrecLarge[0] = 1LL;
    invPrecSmall[0] = 1LL;
    invPrecSmall[1] = xmDiv;
 
    sq = 1;
    while(sq * sq <= n + 1)
    {
	sq++;
	precSmall[sq] = (precSmall[sq-1] * x) % MOD;
	invPrecSmall[sq] = (invPrecSmall[sq-1] * xmDiv) % MOD;
    }
 
    int i;
 
    for (i=1;i<=sq;i++)
    {
	precLarge[i] = (precLarge[i-1] * precSmall[sq]) % MOD;
	invPrecLarge[i] = (invPrecLarge[i-1] * invPrecSmall[sq]) % MOD;
    }
 
    return;
}
 
inline llong xp(llong p)
{
    return (precLarge[p/sq] * precSmall[p%sq]) % MOD;
}
 
inline llong invxp(llong p)
{
    return (invPrecLarge[p/sq] * invPrecSmall[p%sq]) % MOD;
}
 
llong divCache = -1LL;
 
inline llong qPowerSum(llong x, int xpow, llong R)
{
    if (x == 1)
    {
	return R + 1LL;
    }
 
    if (R < 0LL)
	return 0LL;
    else if (R == 0LL)
	return 1LL;
    else
    {
	//llong ans = FastPow(x, R+1) - 1;
 
	//fprintf(stderr,"%d\n",(R+1)*xpow);
	llong ans = xp( (R+1) * xpow ) - 1LL;
 
	if (ans < 0)
	    ans += MOD;
 
	//ans *= Div(x-1LL);
	//ans *= xmDiv;
	if (xpow == 1)
	{
	    divCache = xmDiv;
	    ans *= xmDiv;
	}
	else
	{
	    divCache = Div( xp(xpow) - 1LL + MOD );
	    ans *= divCache;
	}
 
	return ans % MOD;
    }
}
 
inline llong qPowerSum(llong x,int xpow,llong L,llong R)
{
    return (qPowerSum(x, xpow, R) - qPowerSum(x, xpow, L - 1) + MOD) % MOD;
}
 
//1+2x+3x^2...+(R+1)x^R
inline llong qArithPowerSum(llong x, int xpow,llong R)
{
    if (x == 1LL)
    {
	return (((R + 1LL) * (R + 2LL)) / 2LL) % MOD;
    }
 
    llong rx = xp( (R+1) * xpow );
    llong realX = xp(xpow);
    llong divVal = divCache;
 
    llong topVal = (rx * (R + 1)) % MOD;
    topVal *= (1LL - realX + MOD);
    topVal %= MOD;
 
    topVal = (1LL - rx + MOD) - topVal;
 
    if (topVal < 0)
	topVal += MOD;
 
    topVal *= divVal;
    topVal %= MOD;
 
    topVal *= divVal;
    topVal %= MOD;
 
    return topVal;
}
 
inline llong complexCalc(llong A,llong D1,llong D2,llong K,llong x)
{
    llong B = qPowerSum(x, 1, K, K + A - 1LL);
    llong shift = ( (D1 - 1LL) * qPowerSum(x, A, D2 - D1) ) % MOD;
    llong sum = B * (qArithPowerSum(x, A, D2 - D1) + shift);
 
    return sum % MOD;
}
 
int State[200111];
int Key = 1;
priority_queue< pair<int,int> > pq;
 
llong solveQuery(int gver,llong x)
{
    //cout<<"Solve "<<gver<<" ; "<<x<<endl;
    //fprintf(stderr,"%d %lld\n",gver,x);
 
    llong ans = 0;
    int leftL = leftLeaf[gver], rightL = rightLeaf[gver];
    int dep = Depth[leftL];
    int verCount = lastInVal[gver] - inVal[gver] + 1;
    int i;
 
    while(!pq.empty())
	pq.pop();
 
    //fprintf(stderr,"VC = %d\nWidth = %d",rowPos[rightL] - rowPos[leftL], verCount);
    //cout<<"vercount is "<<verCount<<" and depth is from "<<Depth[gver]<<" to "<<dep<<endl;
 
    //fprintf(stderr,"%d to %d\n",rowPos[leftL],rowPos[rightL]);
    int ops = 0;
 
    ///Heavy rows
    //while(dep >= Depth[gver] && rowPos[rightL] - rowPos[leftL] > THRESHOLD)
    while(bottom - dep < THRESHOLD && dep > Depth[gver])
    {
	ops++;
	//fprintf(stderr,"Heavy\n");
	//cout<<"Heavy row"<<endl;
	//cout<<dep<<endl;
 
	int dst = dep - Depth[gver];
	int firstPower = verCount - 1 - (rowPos[rightL] - rowPos[leftL]);
 
	ans += (llong)(dst) * qPowerSum(x, 1, firstPower, verCount - 1);
	ans %= MOD;
 
	dep--;
	verCount -= (rowPos[rightL] - rowPos[leftL] + 1);
	rightL = father[rightL];
	leftL = father[leftL];
    }
 
    if (dep <= Depth[gver])
	return ans;
 
    int rowCount = rowPos[rightL] - rowPos[leftL] + 1;
 
    for (i=0;i<intVers[gver].size();i++)
    {
	int lver = intVers[gver][i];
	int d = Depth[lver];
 
	if (d < Depth[gver])
	    d = Depth[gver];
 
	if (d != dep) //Move up
	{
	    llong val = complexCalc(rowCount, (d+1) - Depth[gver], dep - Depth[gver], verCount - (dep - d) * rowCount, x);
 
	    ans += val;
	    ans %= MOD;
 
	    verCount -= (dep - d) * rowCount;
	    dep = d;
	}
 
	rowCount -= (branchFactor[lver] - 1);
    }
 
    return ans;
}
 
int main()
{
    //freopen("4.in.txt","r",stdin);
    //freopen("ans.txt","w",stdout);
    //freopen("t.txt","r",stdin);
 
    int test;
    int i,j;
 
    scanf("%d",&t);
 
    for (test=1;test<=t;test++)
    {
	inCtr = 0;
 
	scanf("%d %d",&n,&q);
 
	allInt.clear();
	for (i=0;i<=n;i++)
	{
	    Graph[i].clear();
	    depthRows[i].clear();
	    intVers[i].clear();
	}
 
	for (i=1;i<n;i++)
	{
	    int a,b;
 
	    scanf("%d %d",&a,&b);
 
	    Graph[a].push_back(b);
	    Graph[b].push_back(a);
	}
 
	bottom = 0;
	DFS(1,0,1);
 
	getInteresting(1,0);
	precInteresting();
 
	llong ans = 0;
	for (i=1;i<=q;i++)
	{
	    /*if (i % 1000 == 0)
	    {
	        fprintf(stderr,"%d\n",i);
 
	        if (i == 20000)
	            return 0;
	    }*/
 
	    int a,b;
 
	    scanf("%d %d",&a,&b);
	    //a ^= ans;
	    //b ^= ans;
 
	    int v = ans ^ a;
	    int y = ans ^ b;
 
	    precPowers(y);
 
	    ans = solveQuery(v, y);
 
	    printf("%lld\n",ans);
	}
    }
 
    return 0;
}

Editorialist’s solution (gets only 20 points because of a slower implementation, but covers all the ideas explained above)

import java.util.*;
import java.io.*;
import java.text.*;
//Solution Credits: Taranpreet Singh
public class Main{
    //SOLUTION BEGIN
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        int n = ni(), q = ni();
        int[][] ee = new int[n-1][];
        for(int i = 0; i< n-1; i++)ee[i] = new int[]{ni()-1, ni()-1};
        int[][] g = makeU(n,ee);
        int[] d = new int[n], sub = new int[n];int[][] ti = new int[n][2];
        TreeMap<Integer, Integer>[] map = new TreeMap[n];
        for(int i = 0; i< n; i++)map[i] = new TreeMap<>();
        dfs(g,map,ti,d,sub,0,-1);
        int mx=  0;
        for(int i = 0; i< n; i++)mx = Math.max(mx, d[i]);
        long ans = 0;
        for(int qq = 0; qq<q; qq++){
            long a = nl(), x = nl();
            int u = (int)(a^ans)-1;x^=ans;ans=0;
            int cur = 1;
            long cnt = 0;
            for(int di = d[u]; di<= mx; ){
                Integer xx = map[u].ceilingKey(di+1);
                int nxt = 0;
                if(xx==null)nxt = mx;
                else nxt = Math.max(di, xx-1);
                ans+= ((pow(x, cnt)*gp(x,cur))%mod*agp(di-d[u], pow(x, cur), 1, nxt-di+1))%mod;
                if(ans>=mod)ans-=mod;
                cnt+=(nxt-di+1)*cur;di = nxt+1;
                if(xx!=null)cur+=map[u].ceilingEntry(xx).getValue();
            }
            pn(ans);
        }
    }
    long gp(long r, long n){
        if(r==1)return n;
        return (((pow(r, n)+mod-1)%mod)*pow((r+mod-1)%mod, mod-2))%mod;
    }
    long mul(long a, long b){
        if(a>=mod)a%=mod;
        if(b>=mod)b%=mod;
        return (a*b)%mod;
    }
    long ap(long a, long d, long n){
        return ((((2*a+d*(n-1))%mod*n)%mod)*pow(2, mod-2))%mod; // reduce before multiplying by the inverse of 2 to avoid long overflow
    }
    long agp(long a1,long r, long d, long n){
        if(r==1)
            return ap(a1,d,n);
        long x = (((pow(r,n-1)+mod-1)%mod)*pow(r+mod-1, mod-2))%mod;
        long ans = a1 + ((d*r)%mod*x)%mod;
        long y = ((a1+(n-1)*d)%mod*pow(r,n))%mod;
        ans = ((ans+mod-y)%mod);
        ans = ans*pow(1+mod-r, mod-2);
        return ans%mod;
    }
    long pow(long a, long p){
        long o = 1;a%=mod;
        while(p>0){
            if((p&1)==1)o = o*a%mod;
            a = a*a%mod;
            p>>=1;
        }
        return o;
    }
    int T = -1;
    void dfs(int[][] g,TreeMap<Integer, Integer>[] map, int[][] ti, int[] d,int[]sub, int u, int p){
        ti[u][0] = ++T;int ch = 0;sub[u]=1;
        for(int v:g[u]){
            if(v==p)continue;ch++;
            d[v] = d[u]+1;
            dfs(g,map,ti,d,sub,v,u);
            sub[u]+=sub[v];
        }
        ti[u][1] = T;
        if(ch>1)map[u].put(d[u]+1, ch-1);
        if(p!=-1)map[u].entrySet().forEach((e) -> {
            map[p].put(e.getKey(),map[p].getOrDefault(e.getKey(), 0)+e.getValue());
        });
    }
    int[][] makeU(int n, int[][] edge){
        int[][] g = new int[n][];int[] cnt = new int[n];
        for(int i = 0; i< edge.length; i++){cnt[edge[i][0]]++;cnt[edge[i][1]]++;}
        for(int i = 0; i< n; i++)g[i] = new int[cnt[i]];
        for(int i = 0; i< edge.length; i++){
            g[edge[i][0]][--cnt[edge[i][0]]] = edge[i][1];
            g[edge[i][1]][--cnt[edge[i][1]]] = edge[i][0];
        }
        return g;
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    long mod = (long)1e9+7, IINF = (long)1e18;
    final int INF = (int)1e9, MX = (int)2e3+1;
    DecimalFormat df = new DecimalFormat("0.00000000000");
    double PI = 3.1415926535897932384626433832792884197169399375105820974944, eps = 1e-8;
    static boolean multipleTC = true, memory = false;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        int T = (multipleTC)?ni():1;
        //Solution Credits: Taranpreet Singh
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        if(memory)new Thread(null, new Runnable() {public void run(){try{new Main().run();}catch(Exception e){e.printStackTrace();}}}, "1", 1 << 28).start();
        else new Main().run();
    }
    long gcd(long a, long b){return (b==0)?a:gcd(b,a%b);}
    int gcd(int a, int b){return (b==0)?a:gcd(b,a%b);}
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}
 
    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }
 
        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }
 
        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException  e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }
 
        String nextLine() throws Exception{
            String str = "";
            try{   
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }  
            return str;
        }
    }
} 

Feel free to share your approach if it differs. Suggestions are always welcome. :slight_smile: