WRDSUM - Editorial

PROBLEM LINK:

Practice
Contest

Author: Alexey Zayakin
Tester: Istvan Nagy
Editorialist: Xellos

DIFFICULTY:

Hard

PREREQUISITES:

Multiplication of big integers, polynomial interpolation, modular inverses

PROBLEM:

The function F(x) returns the smallest integer root of x. You’re given a big integer N and are supposed to find the sum F(2)+F(3)+\dots+F(N).

SHORT EXPLANATION:

Compute all roots of N using fast exponentiation, brute-force multiplication and binary search. The answer can be found using a DP over the roots of N, where you often need to compute sums of the form \sum_{k=2}^a k^r. Use polynomial interpolation to get those sums.

EXPLANATION:

The first thing to notice is that F(x) can be described much more simply: it’s just the g-th root of x for the maximum possible g that gives a root that’s also an integer.

The other thing is that N is YUGE. That means we can’t find the roots using standard methods like factorisation or taking the result in doubles and correcting it if necessary. Now’s not the time for that, though - roots come later. Let’s just estimate how many of them need to be computed: if 2^g \le N\approx 10^n, then g \le n\frac{\log{10}}{\log{2}} \approx 6700.

Most integers aren’t powers of other integers. For such numbers, F(x)=x; if we used this as an “estimate” of the answer, we’d get N(N+1)/2-1. From that, we need to subtract x-F(x) for all relevant x. We subtract them for square x, i.e. for x=y^2 (and all y from the range [2,\mathrm{floor}(\sqrt{N})]); since F(y^2)=F(y), this means that we actually subtract \sum y^2 and add back the answer for \mathrm{floor}(\sqrt{N}), which can be computed recursively. We do the same for x-s that are third powers (perfect cubes). Fourth powers… oh wait, a fourth power is also a square, so we can skip 4.

Here’s something that breaks the (very short) pattern. If we tried subtracting x-F(x) for all x \le N that are r-th powers and all r, we could end up subtracting it multiple times for the same x. The most obvious fix that the situation with r=4 suggests is to ignore all r which aren’t products of distinct primes, but that’s not enough.

Let’s take x-s that are sixth (but not any higher) powers, so g=6. We subtracted x-F(x) for them with r=2 and with r=3, but we should subtract just once, not twice. If we added for r=6 instead of subtracting, the contribution of numbers for which g=6 to the answer would be correct.

If we repeat the same idea for larger r and choose what to add/subtract so that the numbers with each g would finally contribute to the answer just the sum of their F() values, we arrive at the inclusion-exclusion formula: if r is a product of an odd number of distinct primes, subtract x-F(x) for all x=y^r and y \le \mathrm{floor}(\sqrt[r]{N}); if the number of primes is even, add x-F(x) instead. And if r is not a product of distinct primes, ignore it.
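The sign pattern is exactly the Möbius function of r. Here’s a small sanity check in plain Python (exact integers and tiny N only - this is not the intended big-integer solution, and all helper names are just illustrative) that compares the inclusion-exclusion formula against a direct brute force over F:

def F(x):
    """Smallest integer y such that y^k = x for some k >= 1."""
    for y in range(2, x):
        p = y
        while p < x:
            p *= y
        if p == x:
            return y
    return x

def mobius(r):
    """0 if r is divisible by a square, else (-1)^(number of prime factors)."""
    res, d = 1, 2
    while d * d <= r:
        if r % d == 0:
            r //= d
            if r % d == 0:
                return 0
            res = -res
        d += 1
    return -res if r > 1 else res

def answer_formula(n):
    total = n * (n + 1) // 2 - 1                 # first estimate: sum of all x from 2 to n
    for r in range(2, n.bit_length() + 1):       # 2^r <= n bounds the useful exponents
        mu = mobius(r)
        if mu == 0:
            continue                             # r with a squared prime factor is ignored
        corr, y = 0, 2
        while y ** r <= n:
            corr += y ** r - F(y ** r)           # x - F(x) for x = y^r
            y += 1
        total += mu * corr                       # odd number of primes: subtract, even: add
    return total

def answer_brute(n):
    return sum(F(x) for x in range(2, n + 1))

assert all(answer_formula(n) == answer_brute(n) for n in range(2, 200))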

There are two questions arising here apart from the computation of roots: how costly is the recursion and how to compute \left(\sum y^r\right) \% mod for an arbitrary range of y and rather large r?

In the first question, we can once again use the fact that \sqrt[ab]{N}=\sqrt[a]{\sqrt[b]{N}} and so \mathrm{floor}(\sqrt[a]{\mathrm{floor}(\sqrt[b]{N})})=\mathrm{floor}(\sqrt[ab]{N}); that means we only need to compute the answers for all possible floor(root)s of N, even if we may need some of them multiple times. We can store them in an array and compute them just once when necessary, making this into a recursive DP that takes the current root of N as a parameter:

array answer[1..6700] of ?	# answer[R] = F(2) + ... + F(roots[R]), memoised

array roots[1..6700] of N	# roots[R] = floor(R-th root of N), roots[1] = N
# compute roots

compute_answer(R): # answer for roots[R]
	if answer[R] != ?: return
	K = roots[R]

	answer[R] = K*(K+1)/2-1 # modulo omitted
	for x (2 <= x <= 6700, product of distinct primes):
		dif = (sum of j**x for 2 <= j <= roots[x*R])
		compute_answer(x*R)
		dif -= answer[x*R]
		if (x = product of odd number of primes): answer[R] -= dif
		else: answer[R] += dif

compute_answer(1)

Obvious optimisations such as "only try xR \le 6700" apply.
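For moderate N, the recursion can be rendered directly in plain Python with exact integer arithmetic (no modulus), which is a useful reference implementation to test the real, modular one against; the names below are mine, not the reference solutions’:

from functools import lru_cache

def iroot(n, r):
    """floor(n ** (1/r)), computed exactly by binary search on integers."""
    lo, hi = 0, 1
    while hi ** r <= n:
        hi *= 2
    while lo < hi - 1:
        mid = (lo + hi) // 2
        if mid ** r <= n:
            lo = mid
        else:
            hi = mid
    return lo

def mobius(r):
    """0 for non-squarefree r, otherwise (-1)^(number of prime factors)."""
    res, d = 1, 2
    while d * d <= r:
        if r % d == 0:
            r //= d
            if r % d == 0:
                return 0
            res = -res
        d += 1
    return -res if r > 1 else res

def word_sum(N):
    max_exp = N.bit_length()                  # 2^g <= N, so g <= log2(N)
    roots = {r: iroot(N, r) for r in range(1, max_exp + 1)}

    @lru_cache(maxsize=None)
    def answer(R):                            # F(2) + ... + F(roots[R])
        K = roots[R]
        res = K * (K + 1) // 2 - 1
        for x in range(2, max_exp // R + 1):  # only x with x*R <= max_exp matter
            mu = mobius(x)
            if mu == 0 or roots[x * R] < 2:
                continue
            a = roots[x * R]                  # floor((x*R)-th root of N) = floor(x-th root of K)
            dif = sum(j ** x for j in range(2, a + 1)) - answer(x * R)
            res += mu * dif                   # mu = -1: subtract, mu = +1: add
        return res

    return answer(1)

assert word_sum(8) == 27                      # F(2..8) = 2+3+2+5+6+7+2
print(word_sum(10 ** 6))                      # the power sums are naive here, so keep N moderate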

The time complexity of this is given by two things: the complexity of computing the sums of j^x and the number of x-s to try (the number of edges in a “recursion tree”). The number of products of distinct primes \le 6700 is about 4000, which gives an upper bound of 6700\cdot4000 \approx 27\cdot10^6, but if we only take x \le 6700/R for all R, it drops quickly to around 6700+4000+\frac{4000}{2}+\frac{4000}{3}+\dots \approx 6700+4000\cdot\ln{4000} \approx 40000 (actually, it’s asymptotically around O(n\log{n})). That’s much better.

The next step is figuring out how to compute \sum_{j=2}^a j^r modulo mod. Good news: this sum is the same for a and a \% mod, since j^r \% mod = (j\% mod)^r \% mod and \left(\sum_{j=0}^{mod-1} j^r \right) \% mod = 0. The second identity can be seen by taking a primitive root g of the prime mod: the non-zero residues are exactly g^0, g^1, \dots, g^{mod-2}, so the sum is \sum_{k=0}^{mod-2} g^{rk} = \frac{g^{r(mod-1)}-1}{g^r-1} \equiv 0 as long as g^r \neq 1, i.e. as long as mod-1 doesn’t divide r - which certainly holds for our r \le 6700. This also means that we only need the floor(root)s of N modulo mod when evaluating these sums.
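A tiny illustrative check of both identities on a small prime (the contest modulus is of course much larger; p here is just a stand-in):

p = 101                                        # small prime stand-in for the real modulus
a = 5000                                       # note a % p = 51, comfortably >= 2
for r in range(1, 60):
    if r % (p - 1) == 0:
        continue                               # the identity needs (p-1) not dividing r
    # a full block of p consecutive values sums to 0 modulo p
    assert sum(pow(j, r, p) for j in range(p)) % p == 0
    # consequence: the partial sum over 2..a only depends on a % p
    full = sum(pow(j, r, p) for j in range(2, a + 1)) % p
    reduced = sum(pow(j, r, p) for j in range(2, a % p + 1)) % p
    assert full == reduced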

Anyway, we know that for r=1, the sum is a(a+1)/2-1 and we can also find a formula for r=2: \frac{a(a+1)(2a+1)}{6}-1. For bigger r, we can use WolframAlpha; in general, the sum for r turns out to be a polynomial S_r(a) of degree r+1. And for polynomials of sufficiently small degree, we can use Lagrange interpolation: a polynomial of degree r+1 is determined by r+2 points, so if the first r+2 values are S_r(1)=y_1, \dots, S_r(r+2)=y_{r+2}, then we get the formula

S_r(a) = \sum_{j=1}^{r+2}\frac{\prod_{k \neq j} (a-k)}{\prod_{k \neq j} (j-k)}y_j = \prod_{k=1}^{r+2} (a-k)\sum_{j=1}^{r+2}\frac{(-1)^{r+2-j}}{(j-1)!(r+2-j)!(a-j)}y_j

(where k runs from 1 to r+2, skipping k=j where indicated). The second expression is only valid when a is not one of 1, \dots, r+2 (otherwise we’d divide by zero), but that’s not a problem - we’ve already computed the values for a \le r+2, so we can use them directly. Otherwise, we can precompute factorials and use modular inverses to compute S_r(a) from the second expression in O(r\log{mod}) time.
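Here is a hedged Python sketch of that evaluation (the prime p and all names are illustrative): it samples S_r at 1..r+2, then evaluates at an arbitrary a using prefix/suffix products of (a-k), precomputed factorials and Fermat inverses.

p = 998244353                                  # an NTT-friendly prime (illustrative; use the modulus from the statement)

def sum_of_powers(a, r):
    """S_r(a) = (2^r + 3^r + ... + a^r) mod p, by Lagrange interpolation."""
    m = r + 2                                  # S_r is a polynomial of degree r+1
    a %= p
    ys, acc = [0] * (m + 1), 0                 # sample values y_j = S_r(j), j = 1..m
    for j in range(2, m + 1):
        acc = (acc + pow(j, r, p)) % p
        ys[j] = acc
    if 1 <= a <= m:                            # a is itself a sample point
        return ys[a]
    prefix, suffix = [1] * (m + 2), [1] * (m + 2)
    for j in range(1, m + 1):                  # prefix[j] = (a-1)(a-2)...(a-j)
        prefix[j] = prefix[j - 1] * (a - j) % p
    for j in range(m, 0, -1):                  # suffix[j] = (a-j)(a-j-1)...(a-m)
        suffix[j] = suffix[j + 1] * (a - j) % p
    fact = [1] * (m + 1)
    for j in range(1, m + 1):
        fact[j] = fact[j - 1] * j % p
    res = 0
    for j in range(1, m + 1):
        num = prefix[j - 1] * suffix[j + 1] % p             # prod_{k != j} (a - k)
        den = fact[j - 1] * fact[m - j] % p                 # (j-1)! * (m-j)!
        term = num * pow(den, p - 2, p) % p * ys[j] % p     # modular inverse via Fermat
        res = (res - term if (m - j) % 2 else res + term) % p   # sign (-1)^(m-j)
    return res

inv2, inv6 = pow(2, p - 2, p), pow(6, p - 2, p)
a = 10 ** 1000 + 12345                          # a "big" a: only a % p matters
assert sum_of_powers(a, 1) == ((a % p) * ((a + 1) % p) % p * inv2 - 1) % p
assert sum_of_powers(a, 2) == ((a % p) * ((a + 1) % p) % p * ((2 * a + 1) % p) % p * inv6 - 1) % p
assert sum_of_powers(10, 2) == sum(j * j for j in range(2, 11)) % p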

Now, let’s finally get to computing floor(root)s. We’ll do that recursively (starting from N as the “1-st root”), multiplying by prime numbers in non-decreasing order - if p is the greatest prime divisor of r and r=pq, then \mathrm{floor}(\sqrt[r]{N}) is computed as the p-th root of \mathrm{floor}(\sqrt[q]{N}). This way, most roots are computed from numbers with fewer digits, which speeds up the computations.

Multiplying two d-digit numbers together is possible in O(d^2); it can be done faster using the Fast Fourier Transform / Number Theoretic Transform or Karatsuba’s algorithm, but that’s not necessary here. Using fast exponentiation, which computes x^e as x^{e/2}\cdot x^{e/2} for even e or x^{e-1}\cdot x for odd e, x^e can be found in time O((ed)^2\log{e}) (since x^{e/2} has O(ed) digits).
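The exponentiation pattern itself is the standard square-and-multiply; a minimal sketch (the mul argument stands in for the O(d^2) big-integer multiplication an actual solution would plug in):

def fast_pow(x, e, mul):
    """x^e with O(log e) multiplications: square for even e, multiply by x for odd e."""
    if e == 1:
        return x
    if e % 2 == 0:
        half = fast_pow(x, e // 2, mul)
        return mul(half, half)
    return mul(fast_pow(x, e - 1, mul), x)

assert fast_pow(3, 13, lambda u, v: u * v) == 3 ** 13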

We can now compute roots using binary search. For big integers, a good implementation goes bit by bit from the most significant bit of the computed root x of y: tentatively set the next bit to 1; if that makes x^e > y, reset it back to 0, otherwise keep it. If y has d digits, then x has only O(d/e) digits, so only O(d/e) bits have to be tried (treating the number of bits per digit as a constant) and the time complexity is O((d/e)d^2\log{e})=O(d^3\log{e}/e); the sum over all prime e \le n is O(d^3\log^2{n}) and the sum over all relevant d (which are n,\frac{n}{2},\frac{n}{3},\dots) is O(n^3\log^2{n}), since \sum \frac{1}{k^3} is convergent.
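A minimal sketch of that bit-by-bit search (Python’s built-in big integers stand in for the hand-rolled base-2^{28} arithmetic described below; the function name is mine):

def kth_root(y, e):
    """Largest integer x with x**e <= y, built bit by bit from the top."""
    bits = (y.bit_length() + e - 1) // e       # x has at most ceil(log2(y) / e) bits
    x = 0
    for b in range(bits - 1, -1, -1):
        cand = x | (1 << b)                    # tentatively set this bit...
        if cand ** e <= y:                     # ...and keep it only while x^e stays <= y
            x = cand
    return x

n = 10 ** 60 - 1
assert kth_root(n, 2) ** 2 <= n < (kth_root(n, 2) + 1) ** 2
assert kth_root(n, 7) ** 7 <= n < (kth_root(n, 7) + 1) ** 7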

For d=2016, that’s way too much. We need to reduce the value of d somehow. The key is that O(d^2) multiplication works as long as no intermediate value overflows 64-bit integers, which lets us use any base b such that b^2d \le 2^{64} (approximately). For convenience in the binary search, let’s look for a large base that’s a power of 2. We also know that if a number has d_1 digits in base b_1, then it has approximately d_2=\frac{\log{b_1}}{\log{b_2}}d_1 digits in base b_2; with b_1=10 and d_1=2016, we can estimate that b=2^{28} gives up to d=250 digits and satisfies b^2d \le 2^{64}. Therefore, we can try b=2^{28} or a slightly smaller power of 2.
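A sketch of what base 2^{28} looks like in practice (written in Python for brevity, so 64-bit overflow is only mirrored in the comments - a real solution would use arrays of 64-bit integers in C++; the helper names are illustrative):

BASE_BITS = 28
BASE = 1 << BASE_BITS

def to_limbs(x):
    """Little-endian list of base-2^28 digits ("limbs") of a non-negative integer."""
    limbs = []
    while x:
        limbs.append(x & (BASE - 1))
        x >>= BASE_BITS
    return limbs or [0]

def from_limbs(limbs):
    x = 0
    for limb in reversed(limbs):
        x = (x << BASE_BITS) | limb
    return x

def multiply(a, b):
    """Schoolbook O(d^2) multiplication: accumulate whole columns, carry once at the
    end.  Each column sum stays below d * BASE^2 - exactly the b^2 d <= 2^64 condition
    from the text - so the same code works with unsigned 64-bit integers."""
    res = [0] * (len(a) + len(b))
    for i, ai in enumerate(a):
        for j, bj in enumerate(b):
            res[i + j] += ai * bj              # no carries yet
    carry = 0
    for k in range(len(res)):                  # single carry pass
        cur = res[k] + carry
        res[k] = cur & (BASE - 1)
        carry = cur >> BASE_BITS
    return res                                 # the product fits in len(a)+len(b) limbs

x, y = 123456789 ** 21, 987654321 ** 19
assert from_limbs(multiply(to_limbs(x), to_limbs(y))) == x * y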

If we convert N into base b and take N\approx b^n, then all the required roots can be computed in time O(n^3\log^2{n}); we only need their remainders modulo mod, and with those we can run the recursion, computing the sums S(), in time O(n\log{n}\cdot n\log{mod})=O(n^2\log{n}\log{mod}). Precomputation of small values of those sums runs in O(n^2\log^2{n}) as well, so the total time complexity is O(n^2\log{n}(n\log{n}+\log{mod})). With n\approx 250, and since the O(n^3\log^2{n}) term only comes from simple multiplications/additions, this passes quite comfortably (actually, it’s hard to estimate the real running time in any way other than by submitting).

A final note: probably everyone who solved this noticed that the modulus is a prime commonly used in the Number Theoretic Transform. That, however, is not necessary for multiplying 250-digit numbers (and in fact, we’d need a smaller b if we wanted to use it), so it’s only there to misdirect people. Gr8 b8 m8.

Sample Solutions

Author
Tester
