STFM - Editorial

Hi,

I have re-attempted this problem in C++ using the idea of precomputing all the factorial values until the value of M (maximum value of mod)…

However I’m getting WA on the last sub-task and also on 1st subtask of task 3:

#include <algorithm>
#include <iostream>
#include <cmath>
#include <vector>
using namespace std;

//We compute all the factorials up to the modulo max value
/* the product will contain M on it, which means its value will always be 0 mod M */
#define MAXFACT 19999999

long long int precalc[MAXFACT]={0LL};

void PreComputeModdedFactorials(int m)
{
	precalc[0]=1;
	for(int i = 1; i <= MAXFACT; i++)
		precalc[i] = (i%m*precalc[i-1]%m)%m;
}

long long int prodVal(unsigned long long int x)
{
	if(x<=MAXFACT)
	return precalc[x+1];
	else
	return 0;
}

long long int negval(int m)
{
	return m-1; // -1 mod M = M-1
}

long long int divparcel(unsigned long long int x, int m)
{
	if(x%2==0)
		return (x%m*(x/2)%m*(x+1)%m)%m;
	else
		return (x%m*x%m*((x+1)/2)%m)%m;
}

long long int f(unsigned long long int x, int m)
{
	return (((divparcel(x,m)%m+negval(m)%m)%m)+prodVal(x))%m;
}
	
int main()
{
	ios_base::sync_with_stdio(false);
	cin.tie(NULL);
	int n,m;
	cin >> n >> m;
	PreComputeModdedFactorials(m);
	vector<unsigned long long> v(n);
	for(int i = 0; i < n; i++)
		cin >> v[i];
	
	unsigned long long ans = 0ULL;
	for(int i = 0; i < n; i++)
	{
		//cout << prodVal(v[i]) << endl;
		ans = (ans%m + f(v[i], m)%m)%m;
	}
	cout << ans%m << endl;

	return 0;
}

In advance, sorry for the “mod spam”, but, besides that, is there anything wrong with the code? If yes, what? =))

Best,

Bruno

Just a guess,

     if(x%2==0)
        return (x%m*(x/2)%m*(x+1)%m)%m;
    else
        return (x%m*x%m*((x+1)/2)%m)%m;

It looks like these expression gets overflowed. I am not sure about though. I can give you the testfile on which you are getting wrong answer though.

1 Like

Thanks for the prompt reply but, that doesnt seem to be the problem… I’ve changed the return type to ULL and added yet another mod just to be sure and I only get WA on those 3 flies… :confused: weird

I have the same problem… did you find the error?

Modulo and multiplication have the same precedence, so that expression is evaluated left-to-right:

  • x=a
  • a%m=b
  • b*(x/2)=c
  • c%m=d

etc. Since b can be around 10^9 and x/2 around 10^18, it’s not hard to see that overflow occurs when computing c.

1 Like

Hey @xellos0, nice, thanks for the tip, so… how can I fix that above code? :confused:

"This means, we can reduce the task to multiply 3 numbers up to 10^7 (the modulo value) by doing modulo M for each of the numbers before multiplication. "

isnt this what I’m doing?

M <= 10^7… and if x < M then, yes, you do, as stated in the editorial. Let’s store the sum of 1∗1!+2∗2!+…x∗x! for each x≤M

No. We precalculate all the values upto M beforehand. So we can answer for each integer in O(1). That’s why O(M+N)

1 Like

Is this correct?

unsigned long long int divparcel(unsigned long long int x, int m)
{
if(x%2==0)
return (((x%m*(x/2)%m)%m)(x+1)%m)%m;
else
return ((x%m
x%m)%m*((x+1)/2)%m)%m;
}

Is this correct? To handle the divison parcel?

unsigned long long int divparcel(unsigned long long int x, int m)
{
	if(x%2==0)
		return (((x%m*(x/2)%m)%m)*(x+1)%m)%m;
	else
		return ((x%m*x%m)%m*((x+1)/2)%m)%m;
}

Just use parentheses, Luke. (And use them correctly. You added them everywhere except where you needed them.)

x/2 is not up to 1e7. x/2 is up to 5e18. (x/2)%M is up to 1e7. It’s b*(x/2)=c, not b*((x/2)%M)=c (note the parentheses).

((((x%m)((x/2)%m))%m)((x+1)%m))%m is the right way. Or just store x,x/2,x+1 in separate variables first, mod these variables and then multiply them while modding again, if you want a cleaner code.

@kuruma,@xellos0 :We can also use modular multiplication for finding (a*b)%m in O(log(n))(it can be optimised further to calculate in constant time).

You can check my solution for further details :-

http://www.codechef.com/viewsolution/6119518

@knb_dtu Occam’s razor: don’t use more than necessary. Anything more than ((a%m)*(b%m))%m is unnecessary. (How much time would you waste during a contest by overcomplicating things like this?)

@xellos0, Finally got AC… Such a stupid thing!! Thanks a ton!!

1 Like

WHAT IS WRONG IN MY CODE … I M NOT GETTING IT
#include <stdio.h>
#include <stdlib.h>
#include <math.h>

#define size6 10000003
#define size5 100001

#define s(n) n=Scan_f()
#define loop(i,N) for(i=0;i<N;i++)
#define pll(n) printf("%lld\n",n)

typedef long long int ulli;

inline unsigned long long int Scan_f()
{
int c;
do
c = fgetc(stdin);
while ( (c < ‘0’ || c > ‘9’) && c != EOF );

unsigned long long int a = 0;
while ( c >= '0' && c <= '9' )
{
    a = a*10 + (c - '0');
    c = fgetc(stdin);
}
return a;

}
ulli i,N,K,memofact[size6]= {0} ,memofinal[size6]= {0};
ulli fact(ulli n)
{
if(n==1) return 1;
{
ulli x;
if(memofact[n]!=0) return memofact[n];
{
x=(n*(fact(n-1)))%K;
memofact[n]=x;
return x;
}
}
}
ulli modulo(ulli n)
{
ulli y;
if(n%2==0)
{
y=n/2;
y%=K;
y*=((n+1)K); y=K;
y*=((n)K); y=K;

}
else
{
    {
        y=(n+1)/2;
        y%=K;
        y*=((n)%K);
        y%=K;
        y*=((n)%K);
        y%=K;
    }
}
return y;

}

ulli tag=10000;

ulli calc(ulli n)
{
if(n==1) return 2%K;
{
ulli x;
if(memofinal[n]!=0) return memofinal[n];
{
while(n>tag)
{
x=fact(tag);
tag+=1000;
}
ulli y;
x=(fact(n+1)-1)%K;
y=modulo(n);

        x=(x+y)%K;

        memofinal[n]=x;
        return x;
    }
}

}

int main()
{
memofact[1]=1;
s(N);
s(K);
// memofinal[0] fOR STOREING NO.
{
ulli ans=0;
{
loop(i,N)
{
s(memofinal[0]);
if(memofinal[0]>=(K-1))
ans=(ans+modulo(memofinal[0]))%K;
else
ans=(ans+calc(memofinal[0]))%K;
}
pll(ans);
}
}
return 0;
}

please Explain this (i!=0, for i>=M)

Good proof for 1∗1!+2∗2!+…+x∗x! = (x+1)!−1:

Let Sum = 1∗1!+2∗2!+…+x∗x!

Add (1! + 2! + … + x!) to both sides.

Sum + (1! + 2! + … + x!) = (1.1! + 1!) + (2.2! + 2!) + (3.3! + 3!) + … + (x.x! + x!)

Sum + (1! + 2! + … + x!) = (2!) + (3!) + (4!) + … + (x+1)!

Sum = (x+1)! - 1.

1 Like

please help me out. what is wrong in my code

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;

ll S1(ll x, ll m)
{
    if(x%2 == 0)
        return ((x/2)%m * x%m * (x+1)%m)%m;
    else
        return (x%m * x%m * ((x+1)/2)%m)%m;
}

ll S2(ll fact[], ll max, ll x, ll m)
{
    if(x <= max+1)
        return (fact[x+1]-1+m)%m;
    else
        return (m-1);
}

int main()
{
    ll n,m;
    scanf("%lld%lld",&n,&m);
    ll p[n];
    ll max = 1;
    for(ll i=0;i<n;i++)
    {
         scanf("%lld",&p[i]);
         max = (max < p[i] && p[i] < m)?p[i]:max;
    }
    ll fact[max+2];
    fact[0] = 1;
    for(ll i=1;i<=max+1;i++)
        fact[i] = (fact[i-1]%m * i%m)%m;
    ll sum = 0;
    for(ll i=0;i<n;i++)
        sum = (sum%m + (S1(p[i],m)%m + S2(fact,max,p[i],m)%m)%m)%m;
    printf("%lld\n",sum);
    return 0;
}

My solution is failing on test case 7. All other cases are passing. What might be the issue?
I am unable to figure this out.

#include<iostream>
#include<stdio.h>
#include<math.h>
#include<string.h>
#include<map>
#include<climits>
#include<vector>
#include<algorithm>
using namespace std;

long long precompute[10000001];

// calculating i+1 ! in precompute[i]
void pre_compute(int m) {
    precompute[0] = 0;
    precompute[1] = 2;
    for (long long i = 2; i<m; i++) {
        precompute[i] = precompute[i-1]*(i+1);
        precompute[i] %= m;
    }
}

long long fun(long long x, int m) {
    long long ans = ((x%m) * ((x+1)%m))/2;
    ans %= m;
    ans *= (x%m);
    ans %= m;
    if (x > m)
        x = m-1;
    ans += (precompute[x] - 1);
    ans %= m;
    return ans;
}

int main()
{
    int n,m;
    scanf("%d %d",&n,&m);
    pre_compute(m);
    long long temp;
    long long sum = 0;
    for(int i = 0; i < n; i++) {
        scanf("%lld",&temp);
        sum += fun(temp, m);
        sum %= m;
    }
    printf("%lld\n",sum);
    return 0;
}