大整数乘法

分治法的应用

【算法】

Mul(A[0…n-1], B[0…n-1], n)

//计算两个大整数A[], B[]的乘积

//输入:字符数组(或字串)表示的两个大整数

//输出:以字串形式输出的两个大整数的乘积

if (n == 1)

       return A[0] * B[0];

//高位补0,使n成为偶数(二分需要)

if (n%2 == 0)

{

       A[0…n] =   0’ + A[0…n-1] ;

       B[0…n] =   0’ + B[0…n-1] ;

       n++;

}

//进行二分

a1 = A[0, n/2];     //A的前半部分

a0 = A[n/2, n-1];   //A的后半部分

b1 = B[0, n/2];     //B的前半部分

b0 = B[n/2, n-1];   //B的后半部分

/*那么A = a1*10^n/2 + a0   B = b1*10^n/2 + b0

  利用与计算两位数相同的方法可以得到:

c = a*b = (a1*10^n/2 + a0)*(b1*10^n/2 + b0)

        = (a1*b1)10^n + (a1*b0 + a0*b1)10^n/2 + (a0*b0)

        =c2*10^n + c1*10^n/2 + c0

其中,c2 = a1 * b1, 是它们前半部分的积,c0 = a0 * b0,是它们后半部分的积

c1 = (a1*b0 + a0*b1) = (a1 + a0) * (b1 + b0) – (c2 + c0) ----利用已经算出来的积(c2, c0),减少乘法的数量(2 à 1 )

*/

c2 = Mul(a1, b1);

c0 = Mul(a0, b0);

c1 = Mul(a1+a0, b1+b0) – (c2 + c0);

return c2*10^n + c1*10^n/2 + c0;

 

【效率分析】

该算法会做多少次位乘呢?因为n位数的乘法需要对n/2位数做三次乘法运算,乘法次数M(n)的递推式将会是:

                                   n>1时,M(n) = 3M (n/2) M(1) = 1

n = 2^k时, 我们可以利用反向替换法对它求解:

                     M(2^k) = 3M ( 2^(k-1) ) = 3^ 2 M ( 2^(k-2)

                  = 3^k M(2^(k-k)) = 3^k

因为k = log2n

                            M(n) = 3log2n = n^1.585

 

C语言实现

#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <math.h>

//支持的大整数的位数
#define  N 1000

/*
	@brief reverse a string 
*/
void reverseStr(char *A, int n)
{
	int i = 0, j = n-1;
	char temp;

	while (i < j)
	{
		temp = A[i];
		A[i] = A[j];
		A[j] = temp;
		i++;
		j--;
	}
}
/*
	@brief return the maximum of three numbers 
*/
int max3(int a, int b, int c)
{
	if (a>b)
	{
		return a>c?a:c;
	}
	else 
		return b>c?b:c;
}

/*
	@brief caculate the difference of A, B, restore the result in C. 
	       the three numbers are represented in the form of string
*/
void substract2str(const char *A, const char *B, int n, char *C)
{
	int i, borrow = 0;
	
	for (i = 0; i<n; i++)
	{
		if (A[i] - borrow >= B[i])
		{
			C[i] = (A[i] - B[i] - borrow) + '0';
			borrow = 0;
		}
		else
		{
			C[i] = (A[i] +10 - B[i] - borrow) + '0';
			borrow = 1;
		}
	}
}
/*
	@brief caculate the sum of A, B, restore the result in C. 
	       the three numbers are represented in the form of string
*/
void add2str(const char *A, const char *B, int n, char *C)
{
	int i, sum, carry;

	carry = 0;

	for (i = 0; i<n; i++)
	{
		sum = (A[i] - '0') + (B[i] - '0')+carry;
		C[i] = sum%10 + '0'; //int --> char
		carry = sum/10;
	}
	if (carry)
		C[i] = '1';
}

/*
	@brief return the product of two big integers A, B, which are represented in
	 the form of string, also the product.
*/
char* Mul(char *A, char *B, size_t n)
{
	//一位数相乘,直接返回结果
	if (n == 1)
	{
		char *rst = (char *)malloc(3*sizeof(char));
		if (rst == NULL)
		{
			printf("Allocate memory failed!\n");
			exit(1);
		}
		rst[2] = 0;

		int temp = (A[0] - '0') * (B[0] - '0');
		rst[0] = temp%10 + '0';
		if (temp/10)
			rst[1] = temp/10 + '0';
		else
			rst[1] = 0;
		return rst;
	}
	
	//高位补0,使n为偶数
	if (n % 2 != 0)
	{
		*(A+n) = '0';
		*(B+n) = '0';
		n++;
	}
	char *a0 = A;     //A 的后半部分
	char *a1 = A + n/2; //A 的前半部分
	char *b0 = B;     //B 的后半部分
	char *b1 = B + n/2; //B 的前半部分
	size_t i;

	//多分配一位,因为下一层次的计算中有可能要补一位
	char *tp1 = (char *)calloc(n/2+2, sizeof(char));
	char *tp2 = (char *)calloc(n/2+2, sizeof(char));
	
	strncpy(tp1, a1, n/2);
	strncpy(tp2, b1, n/2);
	char *c2 = Mul(tp1, tp2, n/2);
	strncpy(tp1, a0, n/2);
	strncpy(tp2, b0, n/2);
	char *c0 = Mul(tp1, tp2, n/2);
	free(tp1); tp1 = NULL;
	free(tp2); tp2 = NULL;
	//两个乘积对齐
	if (strlen(c2) > strlen(c0))
	{
		for (i=strlen(c0); i<strlen(c2); i++)
		{
			c0[i] = '0';
		}
	}
	else if (strlen(c2) < strlen(c0))
	{
		for (i=strlen(c2); i<strlen(c0); i++)
		{
			c2[i] = '0';
		}
	}
	char *pa;
	//两个n/2位数的和,可能为n/2 + 1位
	size_t len = n/2+2;

	pa = (char *)calloc(len, sizeof(char));
	if (pa == NULL)
	{
		printf("Allocate memory failed!\n");
		exit(1);
	}	
	add2str(a1, a0, n/2, pa);

	char *pb = (char *)calloc(len, sizeof(char));
	if (pb == NULL)
	{
		printf("Allocate memory failed!\n");
		exit(1);
	}	
	add2str(b1, b0, n/2, pb);

	len = strlen(pa)>strlen(pb)?strlen(pa):strlen(pb);
	//对齐pa和pb
	if (len > strlen(pa))
	{
		pa[len-1] = '0';
	}
	else if (len > strlen(pb))
	{
		pb[len-1] = '0';
	}
	char *pd;
	pd = Mul(pa, pb, len);
	len = strlen(pd);
	
	free(pa); pa=NULL;
	free(pb); pb=NULL;
	//两个n位数的积的位数肯定大于两个n位数的和的位数
	char *pc = (char *)calloc((len+1), sizeof(char));
	if (pc == NULL)
	{
		printf("Allocate memory failed!\n");
		exit(1);
	}
	add2str(c2, c0, strlen(c0), pc);

	//pd  pc 对齐
	for (i=strlen(pc); i<len; i++)
		pc[i] = '0';
	//两个n位数的差最多为n位数
	char *c1 = (char *)calloc((len+1), sizeof(char));
	if (c1 == NULL)
	{
		printf("Allocate memory failed!\n");
		exit(1);
	}
	substract2str(pd, pc, len, c1);
	
	free(pd); pd = NULL;
	free(pc); pc = NULL;

	//两个n位数的乘积最多为2n位数
	char *tpc0 = (char *)calloc((2*n+1), sizeof(char));
	char *tpc1 = (char *)calloc((2*n+1), sizeof(char));
	char *tpc2 = (char *)calloc((2*n+1), sizeof(char));
	if (tpc0 == NULL || tpc1 == NULL || tpc2 == NULL)
	{
		printf("Allocate memory failed!\n");
		exit(1);
	}
	len = 2*n;
	
	//为了计算方便,都补齐到2n位
	strcpy(tpc0, c0);
	for (i=strlen(c0); i<len; i++)
		tpc0[i] = '0';
	tpc0[len] = 0;

	strncpy(tpc1+n/2, c1, len-n/2);
	for (i=0; i<n/2; i++)
		tpc1[i] = '0';
	for (i=n/2+strlen(c1); i<len; i++)
		tpc1[i] = '0';
	tpc1[len] = 0;
	//在此遇到内存错误,如果用strcpy(tpc+n, c2) -- C语言一定要注意内存的操作,确定不会越界吗?
	strncpy(tpc2+n, c2, len-n);
	for (i=0; i<n; i++)
		tpc2[i] = '0';
	for (i=n+strlen(c2); i<len; i++)
		tpc2[i] = '0';
	tpc2[len] = 0;
	
	free(c0); c0=NULL;
	free(c1); c1=NULL;
	free(c2); c2=NULL;
	char *rst = (char *)calloc((2*n+1), sizeof(char));
	char *tmp = (char *)calloc((2*n+1), sizeof(char));
	if (rst == NULL || tmp == NULL)
	{
		printf("Allocate memory failed!\n");
		exit(1);
	}
	add2str(tpc2, tpc0, len, tmp);
	add2str(tmp, tpc1, strlen(tmp), rst);
	free(tpc0); tpc0 = NULL;
	free(tpc1); tpc1 = NULL;
	free(tpc2); tpc2 = NULL;
	free(tmp);  tmp = NULL;

	return rst;
}

int main()
{
	char *A = (char *)calloc(N+1, sizeof(char));
	char *B = (char *)calloc(N+1, sizeof(char));
	size_t i;
	char *rst;
	do 
	{
		printf("input A: ");
		scanf("%s", A);
		printf("input B: ");
		scanf("%s", B);	
		/*
		  如果我们是按高位 --> 地位顺序输入, 则A的低地址存的是最高位的数。
		  比如输入95432,那么内存中会是这样的A[0] A[1] ... A[0] : 9 5 ... 2
		  而我们的算法是建立在低地址存最低位基础上的,所以需要一个reverse操作。
		*/
		reverseStr(A, strlen(A));
		reverseStr(B, strlen(B));
	
		//对齐A和B --> 通过高位补零
		if (strlen(A) > strlen(B))
		{
			for (i=strlen(B); i<strlen(A); i++)
				B[i] = '0';
		}
		else if (strlen(A) < strlen(B))
		{
			for (i=strlen(A); i<strlen(B); i++)
				A[i] = '0';
		}		
	
		rst = Mul(A, B, strlen(A));
		reverseStr(rst, strlen(rst));
		//remove the leading zero
		for (i=0; i<strlen(rst); i++)
			if (rst[i] != '0')
				break;
		if (i!=strlen(rst))
			puts(rst+i);
		else
			puts("0");
		free(rst);
	} while( !((A[0] == '0' && strlen(A) == 1) && (B[0] == '0' && strlen(B) == 1))); //输入零结束	
	free(A);
	free(B);

	return 0;
}

你可能感兴趣的:(大整数乘法)