PROBLEM LINK:
Author: Full name
Tester: Full name
Editorialist: Oleksandr Kulkov
DIFFICULTY:
HARD
PREREQUISITES:
Heavy-Light Decomposition
PROBLEM:
You are given a persistent tree. Each branch v has a limit w_v on the number of acorns in its subtree. If the limit is exceeded, the branch falls off together with its whole subtree. You have to handle queries of two types, each applied to some previously created version of the tree:
- Add x acorns to branch u.
- Remove all acorns from the subtree of u (set their number to 0).
QUICK EXPLANATION:
Use heavy-light decomposition on top of a carefully written persistent segment tree.
EXPLANATION:
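Let's first see how to handle these queries in the non-persistent setting. For each branch v, keep h_v: its capacity w_v minus the number of acorns in its subtree. Then a query of the first kind consists of the following steps:
- Add x acorns on the path from u to the root, i.e. decrease h by x on this path.
- Find the first vertex v (if any) with negative h_v on the path from u to the root, that is, the deepest overloaded branch.
- Subtract the number of acorns in the subtree of v from every vertex on the path from p_v to the root.
- Set the number of acorns in the subtree of v to 0.
At most one branch falls: before the query all h were non-negative, and the subtree of v contains the x new acorns, so after it is removed every ancestor of v has non-negative h again.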
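As a reference, the four steps can be simulated directly on parent pointers. The sketch below is a naive, non-persistent checker with O(n) work per step in the worst case. The names NaiveTree, par and cnt are hypothetical; vertices are assumed to be 1..n with par[v] = 0 for children of the ground, and "first" is taken to mean the first overloaded branch met while walking up from u (the deepest one), which is also what get_path below returns.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Naive, non-persistent reference for the two query types (a sketch).
// Assumptions: vertices are 1..n, par[v] is the parent of v and 0 stands for the ground.
struct NaiveTree {
    std::vector<int> par;        // par[v], 0 = ground
    std::vector<long long> w;    // capacity of branch v
    std::vector<long long> cnt;  // acorns currently in the subtree of v

    NaiveTree(std::vector<int> p, std::vector<long long> cap)
        : par(std::move(p)), w(std::move(cap)), cnt(par.size(), 0) {}

    // Is d in the subtree of v? Walk up from d until the ground.
    bool in_subtree(int d, int v) const {
        for (; d != 0; d = par[d])
            if (d == v) return true;
        return false;
    }

    // Query of the second type: clear the subtree of v, return the number of acorns removed.
    long long cut(int v) {
        long long s = cnt[v];
        for (int a = par[v]; a != 0; a = par[a]) cnt[a] -= s;  // step 3: fix ancestors
        for (int d = 1; d < (int)cnt.size(); d++)                // step 4: empty the subtree
            if (in_subtree(d, v)) cnt[d] = 0;
        return s;
    }

    // Query of the first type: returns the branch that falls, or 0 if none does.
    int add(int u, long long x) {
        for (int v = u; v != 0; v = par[v]) cnt[v] += x;         // step 1
        for (int v = u; v != 0; v = par[v])                       // step 2: first from u upwards
            if (w[v] - cnt[v] < 0) { cut(v); return v; }          // steps 3 and 4
        return 0;
    }
};
```

Running the same queries on this checker and on a fast solution is a cheap way to test the latter on small trees.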
Steps 3 and 4 are exactly a query of the second type applied to v. Using heavy-light decomposition carefully, we can arrange the vertices in an array in which any path from a vertex to the root splits into O(\log n) contiguous subarrays and any subtree is a single contiguous subarray. Then it only remains to support, in a persistent segment tree over this array, range addition, point sum (the number of acorns recorded at a vertex), search for the last negative value on a segment, and resetting a segment to its initial state. Let's see how it's done.
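The reduction can be pinned down by a brute-force model of the array operations. In this sketch the struct and method names are hypothetical, init holds the capacities in HLD order, and position 0 is a dummy slot that every operation clamps away, mirroring the segment tree range [1, n + 1) used below.

```cpp
#include <cassert>
#include <vector>

// Brute-force model of the array operations the segment tree has to support (a sketch).
// Positions are 1..n; position 0 is a dummy slot that is clamped away.
struct ArrayModel {
    std::vector<long long> init, h, acorns;

    explicit ArrayModel(std::vector<long long> w) : init(w), h(w), acorns(w.size(), 0) {}

    // Add x acorns at every position in [a, b); h decreases by x (add_path).
    void range_add(int a, int b, long long x) {
        if (a < 1) a = 1;
        for (int i = a; i < b; i++) { acorns[i] += x; h[i] -= x; }
    }
    // Number of acorns recorded at position p (get_sum).
    long long point_sum(int p) const { return acorns[p]; }
    // Largest position in [a, b) with negative h, or 0 if there is none (get).
    int last_negative(int a, int b) const {
        if (a < 1) a = 1;
        for (int i = b - 1; i >= a; i--)
            if (h[i] < 0) return i;
        return 0;
    }
    // Restore [a, b) to the initial state (upd with nul).
    void reset(int a, int b) {
        if (a < 1) a = 1;
        for (int i = a; i < b; i++) { acorns[i] = 0; h[i] = init[i]; }
    }
};
```

Every operation here is linear; the persistent segment tree below is meant to provide the same behaviour in O(\log n) per call.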
Key idea of heavy-light decomposition is as follows:
For each vertex v, consider its child with the largest subtree and call the edge to it heavy; all other edges are light. If u is a child of v and the edge (v, u) is light, then the heavy child of v has a subtree of size at least sz_u, so sz_v \geq 1 + 2 \times sz_u > 2 \times sz_u. Hence every light edge on the way up at least doubles the subtree size, and any path from a vertex to the root contains at most \log_2 n light edges.
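The doubling argument is easy to check empirically. The helper below is hypothetical and assumes vertex 0 is the root and par[v] < v for v > 0, so subtree sizes can be accumulated by one backward scan. It picks heavy children, asserts sz_v \geq 2 \times sz_u on every light edge, and returns the maximum number of light edges on a path to the root.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Checks the light-edge bound on a tree given by parent links (a sketch).
// Assumption: vertex 0 is the root and par[v] < v for every v > 0.
int max_light_edges(const std::vector<int>& par) {
    int n = (int)par.size();
    std::vector<int> sz(n, 1), heavy(n, -1);
    for (int v = n - 1; v > 0; v--) sz[par[v]] += sz[v];  // children have larger indices
    for (int v = 1; v < n; v++) {                          // heavy child: largest subtree
        int p = par[v];
        if (heavy[p] == -1 || sz[v] > sz[heavy[p]]) heavy[p] = v;
    }
    int best = 0;
    for (int v = 0; v < n; v++) {
        int light = 0;
        for (int u = v; u != 0; u = par[u])
            if (heavy[par[u]] != u) {
                assert(sz[par[u]] >= 2 * sz[u]);  // a light child has at most half the parent's size
                light++;
            }
        best = std::max(best, light);
    }
    return best;
}
```

On a complete binary tree with 15 vertices it returns 3, and on a path it returns 0, both within \log_2 n.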
Now order the vertices by a DFS that, at every vertex v, first descends into its largest child. Then every path of heavy edges forms a contiguous subarray of this order, because each vertex is immediately followed by its heavy child. On the other hand, it is still a DFS order, so every subtree also forms a contiguous subarray. A query touches O(\log n) heavy paths and each of them costs O(\log n) in the persistent segment tree, so the whole problem is solved in O(\log^2 n) time and memory per query.
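The following sketch builds the heavy-first order and splits a root path into segments of positions. It uses the same assumptions as before (root 0, par[v] < v) and hypothetical names; std::max_element moves the largest child to the front, which has the same effect as the swap in dfs_sz below. The root's heavy path is recognised by head[v] == 0.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Heavy-first DFS order and path decomposition (a sketch, recursive for brevity).
// Assumption: vertex 0 is the root and par[v] < v for every v > 0.
struct HLD {
    std::vector<std::vector<int>> g;
    std::vector<int> par, sz, in, out, head;
    int timer = 0;

    explicit HLD(const std::vector<int>& p)
        : g(p.size()), par(p), sz(p.size(), 1), in(p.size()), out(p.size()), head(p.size(), 0) {
        int n = (int)p.size();
        for (int v = 1; v < n; v++) g[p[v]].push_back(v);
        for (int v = n - 1; v > 0; v--) sz[p[v]] += sz[v];
        for (auto& ch : g)  // put the largest child first
            if (!ch.empty())
                std::swap(ch[0], *std::max_element(ch.begin(), ch.end(),
                                                   [&](int a, int b) { return sz[a] < sz[b]; }));
        dfs(0);
    }
    void dfs(int v) {
        in[v] = timer++;
        for (int u : g[v]) {
            head[u] = (u == g[v][0] ? head[v] : u);  // heavy child continues v's path
            dfs(u);
        }
        out[v] = timer;  // subtree of v is [in[v], out[v])
    }
    // Segments [l, r] of positions that together cover the path from v to the root.
    std::vector<std::pair<int, int>> path_segments(int v) const {
        std::vector<std::pair<int, int>> segs;
        while (true) {
            segs.push_back({in[head[v]], in[v]});
            if (head[v] == 0) break;  // reached the root's heavy path
            v = par[head[v]];
        }
        return segs;
    }
};
```

Each segment is one heavy path, so the number of segments is one more than the number of light edges on the path.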
Now let's consider some technical details. First, for each vertex compute the size of its subtree and move its largest child to the front of its child list.
void dfs_sz(int v = 0)
{
    sz[v] = 1;
    for(auto &u: g[v])
    {
        dfs_sz(u);
        sz[v] += sz[u];
        if(sz[u] > sz[g[v][0]]) // keep the largest child at g[v][0]
            swap(u, g[v][0]);   // the old first child moves to an already processed slot
    }
}
Now for each vertex we compute some essential information. in_v is the time we enter the vertex and out_v is the time we leave it, so the subtree of v is the segment [in_v;out_v). rin is the inverse of in. nxt_v is the topmost vertex of the heavy path containing v, so the part of the path to the root that lies on this heavy path is the segment [in_{nxt_v};in_v].
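void dfs_hld(int v = 0)
{
    in[v] = t++;     // position of v in the order
    rin[in[v]] = v;  // inverse mapping: position -> vertex
    for(auto u: g[v])
    {
        nxt[u] = (u == g[v][0] ? nxt[v] : u); // heavy child continues v's path, light child starts a new one
        dfs_hld(u);
    }
    out[v] = t;      // subtree of v is [in[v], out[v])
}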
The search for the first vertex with negative h on the path to the root and the addition of a number on this path can be processed as follows:
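int get_path(int v)
{
    int ans = 0;
    for(; v; v = p[nxt[v]]) // climb heavy path by heavy path until the ground (vertex 0)
        ans = max(ans, get(in[nxt[v]], in[v] + 1)); // lowest vertex with negative h on the path from nxt[v] to v
    return ans; // deeper vertices have larger positions, so this is the deepest overloaded vertex (0 if none)
}
void add_path(int v, int x)
{
    for(; v; v = p[nxt[v]])
        upd(in[nxt[v]], in[v] + 1, add(x)); // add x to the path from nxt[v] to v
}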
upd and get are segment tree functions described later. Let's now see how to process the queries, starting with the query of the second type:
int cut(int v)
{
    int x = get_sum(in[v]);  // number of acorns in the subtree of v
    upd(in[v], out[v], nul); // return the subtree of v to its initial state
    add_path(p[v], -x);      // remove these acorns from all ancestors of v
    return x;
}
Now we can process all queries with the following code:
void solve()
{
    clean(); // remove data from the previous testcase
    int m;
    cin >> n >> m;
    for(int i = 1; i <= n; i++)
    {
        cin >> p[i] >> w[i];
        g[p[i]].push_back(i);
    }
    dfs_sz();
    dfs_hld();
    build(); // build the initial segment tree
    while(m--)
    {
        int state;
        cin >> state;
        make_root(state); // create a new version of the tree from version state
        int q, u;
        cin >> q >> u;
        if(q == 1)
        {
            int x;
            cin >> x;
            add_path(u, x);
            int v = rin[get_path(u)]; // branch that falls, 0 if none
            cout << v << "\n";
            if(v)
                cut(v);
        }
        else
        {
            cout << cut(u) << "\n";
        }
    }
}
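Finally, consider the implementation of the persistent segment tree. Each node v stores a lazy addition ad[v], which is never pushed down, and mn[v], the minimum of h over the segment of v taking into account only the additions stored in v and its descendants; the true minimum is mn[v] minus st, the sum of ad over the strict ancestors of v. lvl[v] is the version in which node v was created, so a node may be modified in place only while it belongs to the current version. ground[r] lists the nodes (l, v) of the initial tree whose segment ends at r; nul uses it to restore a segment to its initial state.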
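const int maxn = 1e5 + 7, logn = 200;
int L[maxn * logn], R[maxn * logn], lvl[maxn * logn]; // children and creation version of each node
int ad[maxn * logn], mn[maxn * logn];                 // lazy addition and shifted minimum of h
vector<pair<int, int>> ground[maxn];                  // nodes of the initial tree, grouped by right border
int root[maxn], rt, st;                               // version roots, next version index, next free node
// copy node v into slot u: a new node is allocated if u does not belong to version lv,
// otherwise u is overwritten in place
int copy(int &u, int v, int lv)
{
    if(lvl[u] != lv)
        u = st++;
    lvl[u] = lv;
    L[u] = L[v];
    R[u] = R[v];
    ad[u] = ad[v];
    mn[u] = mn[v];
    return u;
}
void make_root(int state)
{
    copy(root[rt], root[state], rt); // version rt starts as a copy of version state
    rt++;
}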
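auto add = [](int x)
{
    return [x](int v, ...) // full-cover action: x more acorns, so h drops by x
    {
        ad[v] += x;
        mn[v] -= x;
    };
};
auto nul = [](int v, int l, int r, int st) // full-cover action: restore [l, r) to the initial tree
{
    // node of the initial tree covering exactly [l, r)
    int u = lower_bound(begin(ground[r]), end(ground[r]), make_pair(l, 0))->second;
    if(r - l > 1)
    {
        copy(L[v], L[u], lvl[v]);
        copy(R[v], R[u], lvl[v]);
    }
    ad[v] = -st;        // cancels the ancestors' additions, so the acorn count becomes 0
    mn[v] = mn[u] + st; // the true minimum mn[v] - st equals the initial one
};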
template<class T>
void upd(int a, int b, T gen, int st = 0, int v = root[rt - 1], int l = 1, int r = n + 1)
{
    // st is the sum of ad over the strict ancestors of v
    if(a <= l && r <= b)
    {
        gen(v, l, r, st); // either add(x) or nul
        return;
    }
    if(r <= a || b <= l)
        return;
    st += ad[v];
    int m = (l + r) / 2;
    // children are moved into the current version before they are modified
    upd(a, b, gen, st, copy(L[v], L[v], lvl[v]), l, m);
    upd(a, b, gen, st, copy(R[v], R[v], lvl[v]), m, r);
    mn[v] = min(mn[L[v]], mn[R[v]]) - ad[v];
}
// number of acorns recorded at position p: the sum of ad on the root-to-leaf path
int get_sum(int p, int v = root[rt - 1], int l = 1, int r = n + 1)
{
    if(r - l == 1)
        return ad[v];
    int m = (l + r) / 2;
    if(p < m)
        return ad[v] + get_sum(p, L[v], l, m);
    else
        return ad[v] + get_sum(p, R[v], m, r);
}
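// largest position in [a, b) with negative h, or 0 if there is none
int get(int a, int b, int st = 0, int v = root[rt - 1], int l = 1, int r = n + 1)
{
    if(r <= a || b <= l)
        return 0;
    if(r - l == 1)
        return l * (mn[v] - st < 0);
    st += ad[v];
    int m = (l + r) / 2;
    if(a <= l && r <= b)
    {
        if(min(mn[L[v]], mn[R[v]]) - st >= 0)
            return 0; // no negative value in this segment: do not descend
        if(mn[R[v]] - st < 0) // prefer the right child: larger position, deeper vertex
            return get(a, b, st, R[v], m, r);
        else
            return get(a, b, st, L[v], l, m);
    }
    int t = get(a, b, st, R[v], m, r);
    if(t)
        return t;
    else
        return get(a, b, st, L[v], l, m);
}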
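// builds the initial version; node 0 is its root, so the node counter st must start above 0
void build(int v = 0, int l = 1, int r = n + 1)
{
    ground[r].push_back({l, v}); // remember v as the initial node covering [l, r)
    if(r - l == 1)
    {
        mn[v] = w[rin[l]]; // initially h equals the capacity
        return;
    }
    int m = (l + r) / 2;
    build(L[v] = st++, l, m);
    build(R[v] = st++, m, r);
    mn[v] = min(mn[L[v]], mn[R[v]]);
}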
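To experiment with the invariant behind ad[] and mn[], here is a deliberately simplified persistent segment tree with additions that are never pushed down. It is a sketch, not the editorialist's code: the struct name is hypothetical, positions are [0, n), values are increased instead of h being decreased (so mn[v] = min(mn[L[v]], mn[R[v]]) + ad[v]), and every update copies the touched nodes instead of reusing current-version nodes through lvl[].

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <vector>

// Persistent segment tree: range add, range min, tags stored permanently in nodes.
// Simplified sketch: positions [0, n), every update copies the nodes it touches.
struct PersistentAddMin {
    std::vector<int> L, R;
    std::vector<long long> ad, mn;  // mn[v]: minimum over v's segment, counting tags at v and below

    int new_node(int l, int r, long long a, long long m) {
        L.push_back(l); R.push_back(r); ad.push_back(a); mn.push_back(m);
        return (int)mn.size() - 1;
    }
    int build(const std::vector<long long>& a, int l, int r) {
        if (r - l == 1) return new_node(-1, -1, 0, a[l]);
        int m = (l + r) / 2;
        int lc = build(a, l, m), rc = build(a, m, r);
        return new_node(lc, rc, 0, std::min(mn[lc], mn[rc]));
    }
    // Returns the root of a new version in which x is added on [a, b).
    int add(int v, int l, int r, int a, int b, long long x) {
        if (b <= l || r <= a) return v;  // untouched subtree is shared between versions
        if (a <= l && r <= b) return new_node(L[v], R[v], ad[v] + x, mn[v] + x);
        int m = (l + r) / 2;
        int lc = add(L[v], l, m, a, b, x), rc = add(R[v], m, r, a, b, x);
        return new_node(lc, rc, ad[v], std::min(mn[lc], mn[rc]) + ad[v]);
    }
    // Minimum over [a, b) in the version rooted at v.
    long long query(int v, int l, int r, int a, int b) const {
        if (b <= l || r <= a) return LLONG_MAX;
        if (a <= l && r <= b) return mn[v];
        int m = (l + r) / 2;
        return std::min(query(L[v], l, m, a, b), query(R[v], m, r, a, b)) + ad[v];
    }
};
```

Each add returns a new root, and queries on older roots keep seeing the old values.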
The complete code is available via the link to the editorialist's solution below.
AUTHOR’S AND TESTER’S SOLUTIONS:
Author’s solution will be updated soon.
Tester’s solution will be updated soon.
Editorialist’s solution will be updated soon.