A Simple Sum - Editorial

Problem link : contest, practice

Difficulty : Medium

Pre-requisites : recursion, repeated squaring

Solution :

Solution 1

Explanation by Amit Saharana
We have 1 + x + x^2 + x^3 + … + x^m = (x^(m+1) - 1) / (x - 1) for x != 1.

We cannot simply compute the numerator modulo n and then divide by x - 1. The inverse of x - 1 modulo n may not exist either, because gcd(x - 1, n) = 1 is not guaranteed. Instead, take newmod = n(x - 1), compute the numerator modulo newmod, and then divide by x - 1 to get the final result. (For x = 1 the formula divides by zero; the sum is then simply m + 1.)

This is similar to the problem of dividing by 30 in the FLOORI4 question of the September Long Challenge 2014.

Why this is correct: suppose you need (a/b) mod c, where b divides a, but a is too large to compute directly and gcd(b, c) != 1.
Let (a/b) mod c = y. Then a/b = kc + y for some integer k, with 0 <= y < c.
=> a = kbc + yb = k(bc) + yb. Since 0 <= yb < bc, this gives a mod (bc) = yb. So to calculate y, compute a mod bc to get yb and divide by b.
In the above problem, a = x^(m+1) - 1, b = x - 1, c = n.
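A quick brute-force check of the identity above, as a Python sketch (the function name `check_exact_division_identity` and the search bound are illustrative choices, not part of the editorial). It also replays one small instance: x = 3, m = 3, n = 6, so a = 3^4 - 1 = 80, b = 2, c = 6.

```python
def check_exact_division_identity(limit: int = 40) -> bool:
    """Brute-force check: if b divides a, then (a // b) % c == (a % (b * c)) // b."""
    for b in range(1, limit):
        for c in range(1, limit):
            for k in range(limit):
                a = b * k  # every a tested here is a multiple of b
                if (a // b) % c != (a % (b * c)) // b:
                    return False
    return True


print(check_exact_division_identity())  # True
# Small instance: a = 3^4 - 1 = 80, b = x - 1 = 2, c = n = 6.
# a mod (b*c) = 80 % 12 = 8, and 8 // 2 = 4 = (1 + 3 + 9 + 27) % 6.
print((80 % 12) // 2, (1 + 3 + 9 + 27) % 6)  # 4 4
```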

Note that (x - 1) * n won't fit in a long long. For this, use either a hand-written big integer library or the BigInteger class in Java. Many people also used Python for precisely this reason.
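A minimal Python sketch of Solution 1, assuming x >= 1 and relying on Python's arbitrary-precision integers so that newmod = n(x - 1) cannot overflow. The function name `geo_sum_mod` is hypothetical.

```python
def geo_sum_mod(x: int, m: int, n: int) -> int:
    """Return (1 + x + ... + x^m) % n via (x^(m+1) - 1) / (x - 1), assuming x >= 1."""
    if x == 1:
        # Every term is 1; the closed form would divide by zero.
        return (m + 1) % n
    newmod = n * (x - 1)  # Python ints never overflow
    # (x^(m+1) - 1) mod n(x-1); Python's % keeps the result non-negative.
    numerator = (pow(x, m + 1, newmod) - 1) % newmod
    # Exact division: numerator is (answer) * (x - 1) with answer < n.
    return numerator // (x - 1)


print(geo_sum_mod(3, 3, 6))  # 4, since 1 + 3 + 9 + 27 = 40 and 40 % 6 = 4
```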

Solution 2

To find f(m) = 1 + x + x^2 + … + x^m, assume m is even and group the even and odd powers separately:
(1 + x^2 + … + x^m) + (x + x^3 + … + x^(m - 1))
= (1 + x^2 + … + x^m) + x(1 + x^2 + … + x^(m - 2))

Both brackets are sums of powers of x^2 with about half as many terms, so we can write a recursive function. Note that the first bracket equals the second plus the extra term x^m, so f(m) = (1 + x)(1 + x^2 + … + x^(m - 2)) + x^m.
So we need to calculate x^m separately at each step. For that we use the normal repeated squaring idea.
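A minimal Python sketch of this recursion. It assumes Python's arbitrary-precision integers, so the multiplication-overflow issue discussed next does not arise here. The editorial only treats even m; handling odd m by peeling off the constant term, 1 + x(1 + x + … + x^(m - 1)), is an assumption of this sketch, as is the name `geo_sum_even_odd`.

```python
def geo_sum_even_odd(x: int, m: int, mod: int) -> int:
    """Return (1 + x + ... + x^m) % mod using the even/odd-power grouping."""
    x %= mod
    if m == 0:
        return 1 % mod
    if m % 2 == 1:
        # Odd m (assumption): 1 + x + ... + x^m = 1 + x * f(m - 1), m - 1 even.
        return (1 + x * geo_sum_even_odd(x, m - 1, mod)) % mod
    # Even m: f(m) = (1 + x) * (1 + x^2 + ... + x^(m-2)) + x^m.
    # The bracket is the same sum with base x^2 and exponent (m - 2) / 2.
    inner = geo_sum_even_odd(x * x % mod, (m - 2) // 2, mod)
    # The extra x^m term costs a separate repeated-squaring call per level.
    return ((1 + x) * inner + pow(x, m, mod)) % mod


print(geo_sum_even_odd(3, 3, 6))  # 4
```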

If (x * y) % mod is computed in O(log n) steps (repeated doubling, the multiplication analogue of repeated squaring; see the tester's code), then the overall time complexity of the recursion above follows
T(n) = T(n / 2) + O(log^2 n) = O(log^3 n), which won't pass.
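The O(log n) multiplication mentioned above can be sketched as repeated doubling (Russian peasant multiplication). This is a Python illustration with a hypothetical name, not the tester's code; Python integers do not overflow, so it only shows the step structure. In C++ the same loop keeps intermediate values below 2 * mod, which is why it avoids 64-bit overflow when values up to 2 * mod fit in the integer type.

```python
def mul_mod(a: int, b: int, mod: int) -> int:
    """Return (a * b) % mod using O(log b) doublings and additions."""
    a %= mod
    b %= mod  # also makes b non-negative in Python
    result = 0
    while b > 0:
        if b & 1:
            result = (result + a) % mod  # add the current power-of-two multiple of a
        a = (a + a) % mod  # double a
        b >>= 1
    return result


print(mul_mod(10**16, 10**16, 10**16 + 7))  # 49, since 10^16 = -7 (mod 10^16 + 7)
```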

So compute (x * y) % mod faster, in O(1), using a trick which you can find on the internet (basically, estimating the quotient in long double and avoiding overflow). I will
add the link soon.

Solution 3

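This time, instead of grouping even and odd powers, we group differently.
Group each pair of consecutive terms together.
Write 1 + x + x^2 + … + x^m (assume m is odd):
= (1 + x) + (x^2 + x^3) + (x^4 + x^5) + … + (x^(m - 1) + x^m)
= (1 + x)(1 + x^2 + x^4 + … + x^(m - 1)).
The bracket is the same kind of sum with base x^2 and half as many terms, so no separate x^m has to be computed. If m is even, first write the sum as 1 + x(1 + x + … + x^(m - 1)), where m - 1 is odd.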

Now the recurrence relation for the time complexity is T(n) = T(n / 2) + O(log n).
So this takes only O(log^2 n) steps. See the tester's code for more details.
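A minimal Python sketch of the pairing recursion, assuming Python's big integers (so plain * stands in for the O(log n) multiplication). The name `geo_sum_pairs` is hypothetical, and this is not the tester's code. Each level does a constant number of multiplications and no separate power, which is where the T(n) = T(n / 2) + O(log n) recurrence comes from.

```python
def geo_sum_pairs(x: int, m: int, mod: int) -> int:
    """Return (1 + x + ... + x^m) % mod by pairing consecutive terms."""
    x %= mod
    if m == 0:
        return 1 % mod
    if m % 2 == 1:
        # Odd m: (1 + x) * (1 + x^2 + ... + x^(m-1)); the bracket is the
        # same sum with base x^2 and exponent (m - 1) / 2.
        return (1 + x) * geo_sum_pairs(x * x % mod, (m - 1) // 2, mod) % mod
    # Even m: peel off the constant term so that m - 1 is odd.
    return (1 + x * geo_sum_pairs(x, m - 1, mod)) % mod


print(geo_sum_pairs(3, 3, 6))  # 4
```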

Tester’s solution: link


I coded the question in Python 2.7 and 3.1. Same solution. But one is AC, the other is WA. Any idea why?

Solution link of AC in Python 2.7 http://www.codechef.com/viewsolution/4942649

Solution link of WA in Python 3.1 http://www.codechef.com/viewsolution/4942676

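# ans in Python 3 (brute force: loops m + 1 times, so only practical for small m)

t = int(input())
ans = []
for _ in range(t):
    # Read all three numbers from the line; indexing single characters
    # (l[0], l[2], l[4]) only works when every number has one digit.
    x, m, n = map(int, input().split())
    total = 0  # avoid shadowing the built-in sum()
    for i in range(m + 1):
        total += x ** i
    ans.append(total % n)
for a in ans:
    print(a)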

Clicking the practice link redirects to the first problem (IITKP105)… please fix it!

In Solution 1, newmod is n*(x-1), but since 1 <= n, x <= 10^16, won't n*(x-1) overflow the long long range in C++?

Yes, it will. For this, use either a big integer library, or write the solution in Java or Python.

Added a note about this to the editorial.

My point is: what is wrong with my solution in Python 3 when the same code gets AC in Python 2.7?

I used the following (for odd m):
f(m) = 1 + x + x^2 + … + x^m
= (1 + x + x^2 + … + x^((m-1)/2)) + x^((m+1)/2) * (1 + x + … + x^((m-1)/2))
= f((m-1)/2) * (1 + x^((m+1)/2))
But I am getting wrong answer. Please help me.
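For reference, a Python sketch of the halving recursion above, with reduction modulo mod at every step. It is not a diagnosis of the linked submission. Handling even m by peeling off the constant term is an assumption (the post only covers odd m), and `geo_sum_halves` is a hypothetical name.

```python
def geo_sum_halves(x: int, m: int, mod: int) -> int:
    """Return (1 + x + ... + x^m) % mod by splitting the terms into two halves."""
    x %= mod
    if m == 0:
        return 1 % mod
    if m % 2 == 1:
        # Odd m: the m + 1 terms split into two halves of (m + 1) / 2 terms,
        # and the upper half is x^((m+1)/2) times the lower half.
        half = geo_sum_halves(x, (m - 1) // 2, mod)
        return half * (1 + pow(x, (m + 1) // 2, mod)) % mod
    # Even m (assumption): 1 + x * f(m - 1), where m - 1 is odd.
    return (1 + x * geo_sum_halves(x, m - 1, mod)) % mod


print(geo_sum_halves(3, 3, 6))  # 4
```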
http://www.codechef.com/viewsolution/4937762

Here is a comment by yeputons on the QPOLYSUM CodeChef editorial describing an O(1) way to compute (x * y) % m where x and y are long long numbers:
http://discuss.codechef.com/questions/4505/qpolysum-editorial?page=1#4511

I submitted this code, but it still gives me a wrong answer:

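t = int(input())
for _ in range(t):
    x, m, n = map(int, input().split())
    if x == 1:
        # Every term is 1, so the sum is m + 1.
        print((m + 1) % n)
    else:
        # Work modulo n*(x-1) so that the exact division by (x-1) is valid.
        newmod = n * (x - 1)
        ans = (pow(x, m + 1, newmod) - 1) % newmod
        # Use integer division: '/' returns a float in Python 3 and loses precision.
        print(ans // (x - 1))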

python 3.4