Fast Multiplication - explanation

I will try to explain briefly how the fast multiplication function works. If we want to compute something like 10^18 * 10^18 % MOD, where MOD is also of the order of 10^18, the direct product (10^36) overflows even unsigned long long, whose maximum is about 1.8 * 10^19. This is why many people got WA for MTRICK in January 14 Long.

A standard way to solve this is to split the multiplication into several steps so that the modulo operator can be applied to the intermediate results before they overflow. The idea is similar to the one used in fast exponentiation. For even b we can write a * b as 2 * (a * (b/2)); for odd b an extra a is needed, so a * b = a + 2 * (a * (b/2)), where b/2 is integer division. We then write a * (b/2) as 2 * (a * (b/4)) (plus a if b/2 is odd), and so on. This is the recursion we use to calculate a * b. Since b is halved at each step, we get the answer in logarithmic time, i.e. O(log b) steps.
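The recursion described above can be written compactly as follows, with ⌊b/2⌋ denoting integer division (b >> 1). The second identity is what allows reducing modulo c after every step, since replacing a * ⌊b/2⌋ by its remainder does not change the final result modulo c.

```latex
a \cdot b =
\begin{cases}
0 & \text{if } b = 0,\\
2\,\bigl(a \cdot \lfloor b/2 \rfloor\bigr) & \text{if } b \text{ is even},\\
a + 2\,\bigl(a \cdot \lfloor b/2 \rfloor\bigr) & \text{if } b \text{ is odd},
\end{cases}
\qquad
(a \cdot b) \bmod c = \Bigl(2\,\bigl((a \cdot \lfloor b/2 \rfloor) \bmod c\bigr) + (b \bmod 2)\,a\Bigr) \bmod c .
```

For b >= 1, b reaches 0 after ⌊log2 b⌋ + 1 halvings, which is where the logarithmic running time comes from.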

For example, if b = 13, then a * 13 can be expanded as follows (13/2 = 6 and 3/2 = 1 are integer divisions):

    a * 13 = a + 2 * (13/2 * a)
           = a + 2 * (6 * a)
           = a + 2 * (2 * (6/2 * a))
           = a + 2 * (2 * (3 * a))
           = a + 2 * (2 * (a + 2 * (3/2 * a)))
           = a + 2 * (2 * (a + 2 * (1 * a)))

This splitting is easy to do with recursion.
Its main advantage is that we can take the modulus at the intermediate steps, which is not possible with direct multiplication. Note: for non-negative b, b >> 1 gives the quotient obtained when b is divided by 2.

Now consider the code:

    long long multiple(long long a, long long b, long long c) // returns a * b % c
    {
        // assumes 0 <= a < c, b >= 0 and c < 2^62, so ret + ret and ret + a cannot overflow
        if (b == 0) {   // base case: a * 0 = 0
            return 0;
        }

        long long ret = multiple(a, b >> 1, c);   // multiply a by (b >> 1)

        ret = (ret + ret) % c;   // double ret, i.e. 2 * (a * (b >> 1)); take the modulus in this step

        if (b & 1) {   // b is odd
            // for odd b, a * b = a + 2 * (a * (b >> 1)); ret already holds
            // 2 * (a * (b >> 1)) mod c, so we add the extra a
            ret = (ret + a) % c;
        }

        return ret;
    }

Hope the explanation is clear. Please feel free to ask anything if you have doubts.

(P.S. There is another shorter alternative but it is pretty difficult to understand. You can refer to the link given by @iscsi below in the comments)


Please format the code properly and fix the star issues: the star character is used for italics, so either put spaces around * or use \ as an escape character.


Formatting done. Thanks.

Check yeputons' comment.


When using b >> 1, we can also use ret << 1 instead of ret + ret :wink:
A conditional modulus can also be applied (does someone have a link?): instead of always computing % c, reduce only when the intermediate value is >= c. The code gets longer, though, and that is not needed to show how it works…

Just to add a little to @kcahdog's explanation and show why this method works, consider an example of multiplication by hand in base 10. In this case we'll do 327 * 104 (it's important to have a zero in the second number to understand this better):

           327
         * 104
          ----
          1308
          000
       + 327
       -------
         34008

Since we can skip the multiplication by zero, we might as well write:

           327
         * 104
          ----
          1308
       + 327
       -------
         34008

or

           327
         * 104
          ----
          1308
       + 32700
       -------
         34008

So 327 * 104 = 327 * 4 + 327 * 0 * 10 + 327 * 1 * 100 = 327 * 4 + 327 * 100 = 1308 + 32700 = 34008.

When multiplying by hand, we go through each digit of the second number; as we move along, we add more trailing zeros to the partial products, and we only perform a multiplication if the current digit of the second number is greater than zero. Using recursion, we could do something like this:

    int multiply(int a, int b) {
        if (b == 0) {
            return 0; // of course
        } else { // same as multiplication by hand

            // add a trailing zero for the next multiplication, same as multiplication by hand
            int result = multiply(a, b / 10) * 10;

            int lastDigit = b % 10; // get the last digit in b

            if (lastDigit > 0) {
                result = result + (a * lastDigit);
            }

            return result;
        }
    }

So far we've only done this in base 10, and it doesn't make much of a difference in performance, since we are still using multiplication, addition and division (which is a more costly operation). However, base-10 multiplication is very similar to base-2 multiplication, and there are two differences between these bases that we can exploit:

  • In base 10 we have 9 non-zero digits, whereas in base 2 we have only one non-zero digit (1 itself). Since in base 2 we only have to add when the last digit is one, we can change the line result = result + (a * lastDigit) to result = result + a.
  • In base 10, appending a trailing zero multiplies a number by 10; in base 2 it multiplies the number by 2. The same goes for removing the last digit: in base 10 this divides the number by 10, whereas in base 2 it divides the number by 2 (integer division). Therefore b / 10 becomes b / 2, or even better b >> 1.

It doesn't seem like much, but thanks to those two differences we can use bitwise operations, which are typically cheaper than taking remainders and performing divisions, and even more changes than the ones shown can be applied:

    int multiply(int a, int b) {
        if (b == 0) {
            return 0; // of course, same as before
        } else { // same as multiplication by hand

            // add a trailing zero for the next multiplication, same as multiplication by hand
            int result = multiply(a, b >> 1) << 1;

            /* The line above is the same as:
               int result = multiply(a, b >> 1);
               result = result + result;
            */

            // get the last digit in b; this could also be written as b % 2, but bitwise operations can be cheaper
            int lastDigit = b & 1;

            // or lastDigit == 1; in C or C++ the condition could also be written as if (b & 1),
            // as in @kcahdog's code
            if (lastDigit > 0) {
                result = result + a;
            }

            return result;
        }
    }

However, my favourite version of the code above is the non-recursive one:

    int multiply(int a, int b) {
        int result = 0;

        while (b > 0) {
            int lastDigit = b & 1;   // last bit of b

            if (lastDigit > 0) {
                result = result + a;
            }

            a = a << 1;   // add a trailing zero to a, i.e. a * 2
            b = b >> 1;   // drop the last bit of b
        }

        return result;
    }

This method is only “useful” in problems like MTRICK, where the product is expected to overflow but we have to output the result modulo something. I used int instead of long long and didn't apply the modulo operation for the sake of simplicity; I just wanted to show why this works for anyone having trouble with it. We don't have to use bitwise operations, but shifts and masks are typically cheaper than division and remainder (although modern compilers often turn division or remainder by 2 into them anyway), and they look cooler… Besides, they're really useful for bit extraction and testing, where they become really handy…


The one MTRICK requires is not fast multiplication but modular multiplication :stuck_out_tongue:


@chandan721 to perform modular multiplication efficiently without overflow you need fast multiplication…

@junior94 How is this fast? It only seems to help keep the number reduced modulo c.

@hitesh_noty If you compute a * b by repeated addition (adding a to the result b times) so that you can apply the modulus after each step and avoid overflow, it takes O(b) time. The method above divides b by 2 at each step, so it runs in O(log b) time, hence "fast" multiplication.


Why is this code not working? The judge reports Wrong Answer…
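    #include <iostream>
    using namespace std;

    int main() {
        int t, x, y;
        int a[200];   // digits of the result, least significant first
        cin >> t;

        while (t--) {
            int temp = 0;   // carry
            int index = 0;
            int i = 0;
            cin >> x >> y;

            // store the digits of x in a[], least significant digit first
            a[index] = x % 10;
            x /= 10;

            while (x > 0) {
                a[++index] = x % 10;
                x /= 10;
            }

            // multiply the digit array by y, propagating the carry
            for (; i <= index;) {
                int mul;

                mul = y * a[i];   // note: can overflow int if y is large
                if (i == index) {
                    mul = mul + temp;

                    a[i] = mul % 10;
                    mul /= 10;
                    while (mul) {   // write out the remaining carry digits
                        a[++i] = mul % 10;
                        mul /= 10;
                    }
                    break;
                }
                mul += temp;
                a[i] = mul % 10;
                temp = mul / 10;
                i++;
            }

            // print the digits, most significant first
            // note: if y == 0 this prints leading zeros (e.g. 000 for x = 123)
            while (i >= 0) {
                cout << a[i];
                i--;
            }
            cout << endl;
        }

        return 0;
    }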