KOL15E - Editorial

PROBLEM LINK:

Contest
Practice

Author: Devendra Aggarwal
Tester: Kevin Atienza
Editorialist: Kevin Atienza

DIFFICULTY:

Medium-hard

PREREQUISITES:

Rooted trees, dynamic programming

PROBLEM:

You are given a tree where nodes are assigned distinct values. A weird path is a path where the values increase and decrease alternately.

You are given Q queries. For each query (r,s), root the original tree at r and find the sum of all weird paths that lie entirely in the subtree rooted at s. The sum of a path is the sum of the values of its nodes, and each path is counted once. Paths must contain at least two nodes.

Output all answers modulo M.
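For small inputs, a brute-force reference is handy for testing a faster solution. The sketch below reads the statement as: the sum of a path is the sum of its node values, and each unordered path is counted once (consistent with the formulas in the explanation). Nodes are assumed to be numbered 1..n, `adj` holds adjacency lists with index 0 unused, and all function names are hypothetical. It takes roughly cubic time, so it is only meant for tiny trees.

```python
from itertools import combinations


def tree_path(adj, u, v):
    """Nodes on the unique u-v path, found with a BFS from u."""
    parent = {u: None}
    queue = [u]
    for x in queue:  # appending while iterating turns this into a BFS
        for y in adj[x]:
            if y not in parent:
                parent[y] = x
                queue.append(y)
    path = [v]
    while path[-1] != u:
        path.append(parent[path[-1]])
    return path[::-1]


def is_weird(values):
    """True if there are >= 2 values and the steps alternate up/down."""
    if len(values) < 2:
        return False
    steps = [b > a for a, b in zip(values, values[1:])]
    return all(s != t for s, t in zip(steps, steps[1:]))


def subtree_nodes(adj, n, r, s):
    """Nodes of the subtree of s when the tree is rooted at r."""
    if r == s:
        return list(range(1, n + 1))
    p = tree_path(adj, r, s)[-2]  # neighbour of s on the way to r
    seen = {s, p}                 # marking p blocks the walk towards r
    stack = [s]
    while stack:
        x = stack.pop()
        for y in adj[x]:
            if y not in seen:
                seen.add(y)
                stack.append(y)
    seen.discard(p)
    return sorted(seen)


def brute_query(adj, val, n, r, s, M):
    """Sum of all weird paths inside the subtree of s (root r), modulo M."""
    total = 0
    for u, v in combinations(subtree_nodes(adj, n, r, s), 2):
        path = tree_path(adj, u, v)
        if is_weird([val[x] for x in path]):
            total += sum(val[x] for x in path)
    return total % M
```

For example, on the path 1-2-3 with values 1, 3, 2, all three paths are weird (sums 4, 5 and 6), so the query (1,1) gives 15.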

QUICK EXPLANATION:

Group the queries (r,s) according to r.

Root the tree at some node, say node 1. For each node i, compute the following values corresponding to the subtree rooted at i:

  • The sum of all weird paths whose endpoints belong to the subtrees of two different children of i (so the path passes through i).
  • The sum of all weird paths whose endpoints both belong to the subtree of the same child of i.
  • The number of weird paths where one of the endpoints is i, and the value at i is higher than the next one.
  • The sum of all weird paths where one of the endpoints is i, and the value at i is higher than the next one.
  • The number of weird paths where one of the endpoints is i, and the value at i is lower than the next one.
  • The sum of all weird paths where one of the endpoints is i, and the value at i is lower than the next one.

These values can be computed for all nodes in linear time with a single pass over the tree. They already answer every query (r,s) with r = 1: the answer is the sum of the four path sums (not the counts) listed above for node s.

For the other queries, perform another pass over the tree so that each node becomes the root at some point. Each time the root moves to an adjacent node, only the values at those two nodes change, and the six values above can be updated in constant time. As soon as a node r becomes the root, answer all queries (r,s) for that root.

EXPLANATION:

Let’s first answer a particular query (r,s). To do this, we must first root our tree at r.

Fix any node i. The weird paths lying in the subtree rooted at i can be classified into three types:

  • weird paths where one of the endpoints is the root i. Let’s call such paths root paths.
  • weird paths whose endpoints belong to the subtrees of two different children of i (so the path passes through i). Let’s call such paths through paths.
  • weird paths whose endpoints both belong to the subtree of the same child of i. Let’s call such paths under paths.

Let’s understand these different types of paths. Take a look at the following tree:

Image illustrating different types of weird paths.

Consider the node B.

  • Some examples of root paths at B are: B-E, B-J, B-K, and B-I. Note that B-H is not a root path because the values don’t alternately increase then decrease.
  • Some examples of through paths at B are: D-E, I-E, and J-G.
  • Some examples of under paths at B are: D-I and J-K. Note that all these paths (and the previous examples) are also under paths at A.

To answer the query (r,s), it’s natural to root the tree at r and store, in each node, the sums of these types of paths in its subtree; the answer is then the total of all three types at s. But to do so, we must be able to compute these values for every node i from the values already computed for its children.

Root paths

First, let’s compute the sums of all root paths. Such a path starts at i and continues to some child j; the rest of the path is either just j or a root path at j. However, it can’t be just any root path at j! It must still be weird after i is prepended to it. To compute the sum correctly, we further classify root paths at i into two types (and we’ll sum these types separately):

  • root paths where the value at i is higher than the next one. Let’s call such paths high root paths.
  • root paths where the value at i is lower than the next one. Let’s call such paths low root paths.

Now, a high root path at i is simply i plus a low root path at a child whose value is lower than that of i. Similarly, a low root path at i is i plus a high root path at a child whose value is higher than that of i.

How do we calculate the sums of all high root paths at i? As described above, we simply get all low root paths at all children with values lower than that of i, and append i at each of them. But to do so, we must also know the counts of these paths!

More specifically, let s_H(i) and s_L(i) denote the sums of all high and low root paths at i, respectively, and let c_H(i) and c_L(i) denote their counts. Then:

s_H(i) = \sum_{\substack{\text{$j$ child of $i$} \\ \text{val} _ {j} < \text{val} _ {i}}} \left( \text{val} _ {i}\cdot (c_L(j)+1) + s_L(j) + \text{val} _ {j} \right)

Let’s figure out what this means.

Let’s fix a child j whose value is less than that of i (i.e. \text{val} _ {j} < \text{val} _ {i}). Then (c_L(j)+1) is the number of high root paths at i whose next node is j: one for each low root path at j, plus one for the length-two path i-j (hence the “+1”). Thus, \text{val} _ {i} is added exactly (c_L(j)+1) times in the sum, hence \text{val} _ {i}\cdot (c_L(j)+1). We must also add the contributions of the remaining nodes, which is simple: s_L(j) + \text{val} _ {j}, where “+\text{val} _ {j}” accounts for j in the length-two path i-j.
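As a small sanity check (a hypothetical three-node example, not the tree in the figure): let i have value 5 and a single child j with value 3, and let j have a single child k with value 4. The only low root path at j is j-k, so c_L(j) = 1 and s_L(j) = 3 + 4 = 7. The high root paths at i are i-j (sum 8) and i-j-k (sum 12), and the formula agrees:

\text{val} _ {i}\cdot (c_L(j)+1) + s_L(j) + \text{val} _ {j} = 5\cdot 2 + 7 + 3 = 20 = 8 + 12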

Similarly, for low root paths we have:

s_L(i) = \sum_{\substack{\text{$j$ child of $i$} \\ \text{val} _ {j} > \text{val} _ {i}}} \left( \text{val} _ {i}\cdot (c_H(j)+1) + s_H(j) + \text{val} _ {j} \right)

Since we also need c_H(i) and c_L(i), these must also be calculated. But that’s even simpler:

c_H(i) = \sum_{\substack{\text{$j$ child of $i$} \\ \text{val} _ {j} < \text{val} _ {i}}} (c_L(j)+1)
c_L(i) = \sum_{\substack{\text{$j$ child of $i$} \\ \text{val} _ {j} > \text{val} _ {i}}} (c_H(j)+1)
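These four recurrences can be evaluated bottom-up. The sketch below is one possible Python version, assuming nodes numbered 1..n, adjacency lists `adj` (index 0 unused) and a list `val` of distinct values; the function name `root_paths` is hypothetical. It roots the tree with a BFS and then processes nodes in reverse BFS order, so every child is finished before its parent.

```python
def root_paths(adj, val, root, M):
    """Return lists cH, sH, cL, sL (mod M) for the tree rooted at `root`."""
    n = len(adj) - 1
    parent = [0] * (n + 1)          # 0 means "no parent" (nodes start at 1)
    order = [root]
    for x in order:                 # BFS: the list grows while we iterate
        for y in adj[x]:
            if y != parent[x]:
                parent[y] = x
                order.append(y)

    cH, sH, cL, sL = ([0] * (n + 1) for _ in range(4))
    for i in reversed(order):       # children are processed before parents
        for j in adj[i]:
            if j == parent[i]:
                continue
            if val[j] < val[i]:
                # high root paths at i: i-j, or i followed by a low root path at j
                cH[i] = (cH[i] + cL[j] + 1) % M
                sH[i] = (sH[i] + val[i] * (cL[j] + 1) + sL[j] + val[j]) % M
            else:
                # low root paths at i: i-j, or i followed by a high root path at j
                cL[i] = (cL[i] + cH[j] + 1) % M
                sL[i] = (sL[i] + val[i] * (cH[j] + 1) + sH[j] + val[j]) % M
    return cH, sH, cL, sL
```

For instance, on the path 1-2-3 with values 5, 3, 4 rooted at node 1, this gives c_H(1) = 2 and s_H(1) = 20 (the paths 1-2 and 1-2-3).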

Through paths

Now, let’s focus on through paths. Let s_T(i) be the sum of the through paths at i. A through path is simply two root paths at i that go into different subtrees, merged at i. Again, not every pair works: for the merged path to be weird, the two root paths must be of the same type, i.e. either both high root paths (i is a peak) or both low root paths (i is a valley).

Using this characterization, we can now find an expression for s_T(i). Let j_1 and j_2 be two different children of i. Then the sum of all through paths with one endpoint in the subtree of j_1 and the other in the subtree of j_2 is:

\begin{cases} \text{val} _ {i}\cdot (c_L(j_1)+1)(c_L(j_2)+1) + (c_L(j_1)+1)(s_L(j_2)+\text{val} _ {j_2}) + (c_L(j_2)+1)(s_L(j_1)+\text{val} _ {j_1}) & \text{if $\text{val} _ {j_1},\text{val} _ {j_2} < \text{val} _ {i}$} \\ \text{val} _ {i}\cdot (c_H(j_1)+1)(c_H(j_2)+1) + (c_H(j_1)+1)(s_H(j_2)+\text{val} _ {j_2}) + (c_H(j_2)+1)(s_H(j_1)+\text{val} _ {j_1}) & \text{if $\text{val} _ {j_1},\text{val} _ {j_2} > \text{val} _ {i}$} \\ 0 & \text{otherwise} \end{cases}

This formula is similar to the ones we have for root paths. For example, let’s look at the case “\text{val} _ {j_1},\text{val} _ {j_2} < \text{val} _ {i}”. The first term is the contribution of \text{val} _ {i} in the sum, the second term is the contribution of all nodes in the subtree j_2, and the third term is the contribution of all nodes in the subtree j_1.

Thus s_T(i) is simply the sum of this value for all (unordered) pairs of children (j_1,j_2).
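A direct evaluation of this double sum is useful as a reference when testing the faster method. In the sketch below, each child j is summarized by a hypothetical tuple (val_j, c_H(j), s_H(j), c_L(j), s_L(j)); the function name is also just for illustration.

```python
from itertools import combinations


def through_sum_pairwise(val_i, kids, M):
    """s_T(i) by the direct O(k^2) formula over unordered pairs of children.

    kids: list of (val_j, c_H(j), s_H(j), c_L(j), s_L(j)) for the children j of i.
    """
    total = 0
    for (v1, cH1, sH1, cL1, sL1), (v2, cH2, sH2, cL2, sL2) in combinations(kids, 2):
        if v1 < val_i and v2 < val_i:      # i is a peak: extend with low root paths
            n1, t1, n2, t2 = cL1 + 1, sL1 + v1, cL2 + 1, sL2 + v2
        elif v1 > val_i and v2 > val_i:    # i is a valley: extend with high root paths
            n1, t1, n2, t2 = cH1 + 1, sH1 + v1, cH2 + 1, sH2 + v2
        else:                              # one step up, one down: not weird at i
            continue
        # val_i appears in n1*n2 paths; each half in j1 is reused n2 times, and vice versa
        total += val_i * n1 * n2 + n1 * t2 + n2 * t1
    return total % M
```

With node value 5 and four leaf children of values 3, 1, 6 and 7, the only through paths are 3-5-1 (sum 9) and 6-5-7 (sum 18), so the function returns 27.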

Under paths

Finally, let’s focus on under paths. Let s_U(i) be the sum of the under paths at i. Consider the under paths lying in the subtree of a child j of i. These are exactly the weird paths inside that subtree, so relative to j they come in all four types described above: high root paths, low root paths, through paths and under paths. Their total is s_H(j) + s_L(j) + s_T(j) + s_U(j). Thus, the sum of all under paths at i is simply:

s_U(i) = \sum_{\text{$j$ child of $i$}} \left( s_H(j) + s_L(j) + s_T(j) + s_U(j) \right)

Computing s_T(i) quickly

Notice that the formula for s_T(i) involves a double sum over pairs (j_1,j_2), so its cost is quadratic in the number of children of i. This is too slow when some node has many children: for a star with N-1 leaves, it takes O(N^2) time.

To compute s_T(i) quickly, imagine adding the subtrees of i one by one, and after adding each subtree, increase s_T(i) by the sum of the new through paths it creates. We do this not only for s_T(i) but also for all the other values: c_H(i), s_H(i), c_L(i), s_L(i) and s_U(i).

Initially, we assume that i has no subtrees, so we have s_T(i) = c_H(i) = s_H(i) = c_L(i) = s_L(i) = s_U(i) = 0. We add the subtrees one by one, and update these values along the way.

Suppose we have a new subtree j to add, and that these six values are correct prior to adding j. We now want to update them. We first update s_T(i). There are two possibilities: either \text{val} _ {j} < \text{val} _ {i} or \text{val} _ {j} > \text{val} _ {i}.

  • If \text{val} _ {j} < \text{val} _ {i}, we increment s_T(i) by s_H(i)(c_L(j)+1) + c_H(i)(s_L(j)+\text{val} _ {j}).
  • If \text{val} _ {j} > \text{val} _ {i}, we increment s_T(i) by s_L(i)(c_H(j)+1) + c_L(i)(s_H(j)+\text{val} _ {j}).

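(We invite you to check that the added value is indeed the sum of the new through paths. Hint: in the first case, each new through path joins one of the c_H(i) existing high root paths at i, which already contain i, with one of the c_L(j)+1 paths going from i into the subtree of j, whose parts inside that subtree sum to s_L(j)+\text{val} _ {j}.)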

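Finally, the values c_H(i), s_H(i), c_L(i), s_L(i) and s_U(i) can be updated using the formulas above. The order matters: s_T(i) must be updated first, using the values of c_H(i), s_H(i), c_L(i) and s_L(i) from before j is added; otherwise pairs of paths that both go into the subtree of j would be counted as through paths.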
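The incremental rule can be checked against the pairwise formula on small inputs. Below is a sketch of the one-child-at-a-time update for a single node i, using a hypothetical tuple (val_j, c_H(j), s_H(j), c_L(j), s_L(j)) per child; it returns s_T(i) after all children are added. Note that s_T is updated before s_H/c_H (or s_L/c_L), so j is only paired with previously added children.

```python
def through_sum_incremental(val_i, kids, M):
    """s_T(i) computed by adding the children of i one at a time.

    kids: list of (val_j, c_H(j), s_H(j), c_L(j), s_L(j)) for the children j of i.
    """
    cH = sH = cL = sL = sT = 0
    for v, cHj, sHj, cLj, sLj in kids:
        if v < val_i:
            # new through paths: an earlier high root path at i joined with
            # one of the cLj + 1 paths from i into the subtree of j
            sT = (sT + sH * (cLj + 1) + cH * (sLj + v)) % M
            sH = (sH + val_i * (cLj + 1) + sLj + v) % M
            cH = (cH + cLj + 1) % M
        else:
            # symmetric case: i is a valley, join with earlier low root paths
            sT = (sT + sL * (cHj + 1) + cL * (sHj + v)) % M
            sL = (sL + val_i * (cHj + 1) + sHj + v) % M
            cL = (cL + cHj + 1) % M
    return sT
```

With node value 5 and leaf children of values 3, 1, 6 and 7 (in any order), this returns 27, the sum of the through paths 3-5-1 and 6-5-7, matching the pairwise definition. Each child costs O(1), so a node with k children takes O(k) time instead of O(k^2).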

By combining all these techniques, we can compute all the values we want with a single pass of the whole tree!
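Putting the pieces together, one post-order pass that attaches every child to its parent computes all six values. The sketch below assumes 1-indexed nodes, an edge list, and a root chosen by the caller; `single_pass` is a hypothetical name, and its inner `attach` implements the update rules above as a closure over local lists. It returns, for every node s, the answer to the query (root, s).

```python
def single_pass(n, val, edges, root, M):
    """ans[s] = answer to the query (root, s), computed with one post-order pass.

    Assumes nodes 1..n, distinct val[1..n] (val[0] unused), edges as (u, v) pairs.
    """
    adj = [[] for _ in range(n + 1)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    cH, sH, cL, sL, sT, sU = ([0] * (n + 1) for _ in range(6))

    def attach(i, j):
        # add the finished subtree j as a new child subtree of i (sT first)
        if val[j] < val[i]:
            sT[i] = (sT[i] + sH[i] * (cL[j] + 1) + cH[i] * (sL[j] + val[j])) % M
            sH[i] = (sH[i] + val[i] * (cL[j] + 1) + sL[j] + val[j]) % M
            cH[i] = (cH[i] + cL[j] + 1) % M
        else:
            sT[i] = (sT[i] + sL[i] * (cH[j] + 1) + cL[i] * (sH[j] + val[j])) % M
            sL[i] = (sL[i] + val[i] * (cH[j] + 1) + sH[j] + val[j]) % M
            cL[i] = (cL[i] + cH[j] + 1) % M
        sU[i] = (sU[i] + sH[j] + sL[j] + sT[j] + sU[j]) % M

    parent = [0] * (n + 1)                # 0 means "no parent"
    order = [root]
    for x in order:                       # BFS order: parents before children
        for y in adj[x]:
            if y != parent[x]:
                parent[y] = x
                order.append(y)
    for x in reversed(order):             # x is complete before it is attached
        if parent[x]:
            attach(parent[x], x)
    return [(sH[s] + sL[s] + sT[s] + sU[s]) % M for s in range(n + 1)]
```

On the path 1-2-3 with values 1, 3, 2 rooted at node 1, the answers for s = 1, 2, 3 are 15, 5 and 0. The BFS plus reversed order replaces recursion, which sidesteps Python's default recursion limit on deep trees.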

Answering queries

The above solution is only good for queries (r,s) whose root r is the root we chose in the beginning. What about the other queries?

Suppose the tree is rooted at r_1 and we want to re-root it at r_2. For simplicity, assume r_2 is a child of r_1. Re-rooting can be done by first detaching the subtree r_2 from r_1, and then attaching what remains of the tree (rooted at r_1) to r_2 as a new child subtree. While doing so, we must keep the six values above up to date. We already know how to update them when attaching a subtree; what about detaching? Since the six values do not depend on the order in which the children were attached, we may pretend that j was the last child attached to i and simply undo the updates in reverse order. For example, the following are possible implementations of attach and detach:

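# Assumes global lists val, cH, sH, cL, sL, sT, sU indexed by node, and the modulus M.

def attach(i, j):
    """Attach the subtree rooted at j as a new child subtree of i."""
    if val[j] < val[i]:
        # new through paths pair an existing high root path at i with a path into j,
        # so sT must use sH[i] and cH[i] from before j is added
        sT[i] = (sT[i] + sH[i] * (cL[j] + 1) + cH[i] * (sL[j] + val[j])) % M
        sH[i] = (sH[i] + val[i] * (cL[j] + 1) + sL[j] + val[j]) % M
        cH[i] = (cH[i] + cL[j] + 1) % M
    else:
        sT[i] = (sT[i] + sL[i] * (cH[j] + 1) + cL[i] * (sH[j] + val[j])) % M
        sL[i] = (sL[i] + val[i] * (cH[j] + 1) + sH[j] + val[j]) % M
        cL[i] = (cL[i] + cH[j] + 1) % M

    # every weird path inside subtree j becomes an under path at i
    sU[i] = (sU[i] + sL[j] + sH[j] + sT[j] + sU[j]) % M


def detach(i, j):
    """Undo attach(i, j): the same updates, subtracted, in reverse order."""
    sU[i] = (sU[i] - (sL[j] + sH[j] + sT[j] + sU[j])) % M

    if val[j] < val[i]:
        cH[i] = (cH[i] - (cL[j] + 1)) % M
        sH[i] = (sH[i] - (val[i] * (cL[j] + 1) + sL[j] + val[j])) % M
        # sH[i] and cH[i] are now back to their values from before j was attached
        sT[i] = (sT[i] - (sH[i] * (cL[j] + 1) + cH[i] * (sL[j] + val[j]))) % M
    else:
        cL[i] = (cL[i] - (cH[j] + 1)) % M
        sL[i] = (sL[i] - (val[i] * (cH[j] + 1) + sH[j] + val[j])) % M
        sT[i] = (sT[i] - (sL[i] * (cH[j] + 1) + cL[i] * (sH[j] + val[j]))) % M

(All values are kept reduced modulo M. In Python, % M also brings the negative intermediate results in detach back into the range [0, M); in C or C++, add M before reducing.)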

Using this, we now have a way to switch roots, as long as the new root is a child of the current root. But even with just this operation, we can now answer all queries with the following method:

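  • Perform a depth-first traversal of the tree in which each node becomes the root when we enter it. When moving from the current root x to a child y, call detach(x, y) and then attach(y, x); when leaving y, call detach(y, x) and then attach(x, y) to move the root back. Each change of root is always to a child of the current root (while y is the root, x is its child), so the method above always applies.
  • When some node r becomes the root, answer all queries (r,s) for that root right away: the answer is s_H(s) + s_L(s) + s_T(s) + s_U(s), using the values currently stored at s.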

When doing a single pass of the whole tree this way, each node becomes a root at some point, which means that all queries will be answered! This also means that our solution is offline, because it requires getting all queries before processing any of them.
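Here is an end-to-end sketch of the whole offline method, assuming 1-indexed nodes, an edge list, a list of (r, s) queries and the modulus M as inputs; the function `solve` and its input format are hypothetical, not the setter's or tester's code. The first pass roots the tree at node 1; the second pass is an iterative depth-first traversal (to avoid Python's recursion limit) in which the node on top of the stack is always the current root.

```python
from collections import defaultdict


def solve(n, val, edges, queries, M):
    """Offline answers to queries (r, s) by rerooting one edge at a time."""
    adj = [[] for _ in range(n + 1)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    cH, sH, cL, sL, sT, sU = ([0] * (n + 1) for _ in range(6))

    def attach(i, j):
        if val[j] < val[i]:
            sT[i] = (sT[i] + sH[i] * (cL[j] + 1) + cH[i] * (sL[j] + val[j])) % M
            sH[i] = (sH[i] + val[i] * (cL[j] + 1) + sL[j] + val[j]) % M
            cH[i] = (cH[i] + cL[j] + 1) % M
        else:
            sT[i] = (sT[i] + sL[i] * (cH[j] + 1) + cL[i] * (sH[j] + val[j])) % M
            sL[i] = (sL[i] + val[i] * (cH[j] + 1) + sH[j] + val[j]) % M
            cL[i] = (cL[i] + cH[j] + 1) % M
        sU[i] = (sU[i] + sH[j] + sL[j] + sT[j] + sU[j]) % M

    def detach(i, j):
        # exact inverse of attach(i, j), applied in reverse order
        sU[i] = (sU[i] - sH[j] - sL[j] - sT[j] - sU[j]) % M
        if val[j] < val[i]:
            cH[i] = (cH[i] - cL[j] - 1) % M
            sH[i] = (sH[i] - val[i] * (cL[j] + 1) - sL[j] - val[j]) % M
            sT[i] = (sT[i] - sH[i] * (cL[j] + 1) - cH[i] * (sL[j] + val[j])) % M
        else:
            cL[i] = (cL[i] - cH[j] - 1) % M
            sL[i] = (sL[i] - val[i] * (cH[j] + 1) - sH[j] - val[j]) % M
            sT[i] = (sT[i] - sL[i] * (cH[j] + 1) - cL[i] * (sH[j] + val[j])) % M

    # first pass: root the tree at node 1 and attach children bottom-up
    parent = [0] * (n + 1)
    order = [1]
    for x in order:
        for y in adj[x]:
            if y != parent[x]:
                parent[y] = x
                order.append(y)
    for x in reversed(order):
        if parent[x]:
            attach(parent[x], x)

    by_root = defaultdict(list)
    for idx, (r, s) in enumerate(queries):
        by_root[r].append((idx, s))
    ans = [0] * len(queries)

    def answer_at(r):
        # with r as the root, every node s stores the values of its own subtree
        for idx, s in by_root[r]:
            ans[idx] = (sH[s] + sL[s] + sT[s] + sU[s]) % M

    # second pass: iterative DFS; the node on top of the stack is the current root
    answer_at(1)
    stack = [(1, 0, iter(adj[1]))]
    while stack:
        x, p, it = stack[-1]
        y = next((c for c in it if c != p), None)
        if y is not None:
            detach(x, y)      # move the root from x down to its child y
            attach(y, x)
            answer_at(y)
            stack.append((y, x, iter(adj[y])))
        else:
            stack.pop()
            if p:
                detach(x, p)  # move the root back up from x to its parent p
                attach(p, x)
    return ans
```

On the path 1-2-3 with values 1, 3, 2, the query (1,1) gives 15 (paths 1-2, 2-3 and 1-2-3 with sums 4, 5 and 6), while (3,2) gives 4, since with root 3 the subtree of node 2 contains only nodes 1 and 2. Every edge is crossed twice by the traversal and each crossing costs O(1), which matches the O(N + Q) bound.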

Time Complexity:

O(N + Q)

AUTHOR’S AND TESTER’S SOLUTIONS:

setter
tester