全整数无浮点运算的 快速傅里叶变换FFT 加速 大整数乘法,整系数多项式乘法

全整数无浮点运算的 快速傅里叶变换FFT 加速 大整数乘法,整系数多项式乘法


我的模板,第一次实现,代码不够精简优化



  1  #include  < iostream >
  2  #include  < cstdio >
  3  #include  < cstring >
  4 
  5  using   namespace  std;
  6 
  7  template <   int  L,  class  T  =   int class  LT  =   long   long   >
  8  class   FFT
  9  {
 10  public  : 
 11          FFT() {
 12                  n  =   - 1 ;
 13          }
 14           void  fft( T e[],  int   & m,  int  minL ) {
 15                   in ( e, m, minL );
 16                  m  =  n;
 17                  fft();
 18                   out ( e );
 19          }
 20           void  ifft( T e[],  int   & m,  int  minL ) {
 21                   in ( e, m, minL );
 22                  m  =  n;
 23                  ifft();
 24                   out ( e );
 25          }
 26          T getP() {
 27                   return  p;
 28          }
 29  private  : 
 30           int  isPrime( T x ) {
 31                  T i;
 32                   if  ( x  <   2  ) {
 33                           return   0 ;
 34                  }
 35                   /*  overflow !!  */
 36                   for  ( i  =   2 ; (LT)i * <=  x;  ++ i ) {
 37                           if  ( x  %  i  ==   0  ) {
 38                                   return   0 ;
 39                          }
 40                  }
 41                   return   1 ;
 42          }
 43          T powMod( T a, T b, T c ) {
 44                  T ans  =   1 ;
 45                  a  %=  c;
 46                   while  ( b  >   0  ) {
 47                           if  ( b  &   1  ) {
 48                                  ans  =  ( (LT)ans  *  a )  %  c;
 49                          }
 50                          a  =  ( (LT)a  *  a )  %  c;
 51                          b  >>=   1 ;
 52                  }
 53                   return  ans;
 54          }
 55           /*  p is a prime number  */
 56           int  isG( T g, T p ) {
 57                  T p0  =  p  -   1 , i;
 58                   for  ( i  =   1 ; (LT)i * <=  p0;  ++ i ) {
 59                           if  ( p0  %  i  ==   0  ) {
 60                                   if  ( (powMod(g,i,p) == 1 &&  (i < p0) ) {
 61                                           return   0 ;
 62                                  }
 63                                   if  ( (powMod(g,p0 / i,p) == 1 &&  (p0 / i < p0) ) {
 64                                           return   0 ;
 65                                  }
 66                          }
 67                  }
 68                   return   1 ;
 69          }
 70           int  rev_bit(  int  i ) {
 71                   int  j  =   0 , k;
 72                   for  ( k  =   0 ; k  <  bit;  ++ k ) {
 73                          j  =  ( (j << 1 ) | (i & 1 ) );
 74                          i  >>=   1 ;
 75                  }
 76                   return  j;
 77          }
 78           void  reverse() {
 79                   int  i, j;
 80                  T t;
 81                   for  ( i  =   0 ; i  <  n;  ++ i ) {
 82                          j  =  rev_bit( i );
 83                           if  ( i  <  j ) {
 84                                  t  =  a[ i ];
 85                                  a[ i ]  =  a[ j ];
 86                                  a[ j ]  =  t;
 87                          }
 88                  }
 89          }
 90           void   in ( T e[],  int  m,  int  minL ) {
 91                   int  i, need  =   0 ;
 92                  bit  =   0 ;
 93                   while  ( ( 1 << ( ++ bit))  <  minL )
 94                          ;
 95                   if  ( n  !=  ( 1 << bit) ) {
 96                          need  =   1 ;
 97                          n  =  ( 1 << bit);
 98                  }
 99                   for  ( i  =   0 ; i  <  m;  ++ i ) {
100                          a[ i ]  =  e[ i ];
101                  }
102                   for  ( i  =  m; i  <  n;  ++ i ) {
103                          a[ i ]  =   0 ;
104                  }
105                   if  ( need ) {
106                          init(  21 10000000  );
107                  }
108          }
109           //  lim2 >= bit
110           void  init(  int  lim2, T minP ) {
111                  T k  =   2 , ig  =   2 ;
112                   int  i;
113                   do  {
114                           ++ k;
115                          p  =  ( (k << lim2)  |   1  );
116                  }  while  ( (p < minP)  ||  ( ! isPrime(p)) );
117                   while  (  ! isG(ig,p) ) {
118                           ++ ig;
119                  }
120                   for  ( i  =   0 ; i  <  bit;  ++ i ) {
121                          g[ i ]  =  powMod( ig, (k << (lim2 - bit + i)), p );
122                  }
123          }
124           void  fft() {
125                  T w, wm, u, t;
126                   int  s, m, m2, j, k;
127                  reverse();
128                   for  ( s  =  bit - 1 ; s  >=   0 -- s ) {
129                          m2  =  ( 1 << (bit - s));
130                          m  =  (m2 >> 1 );
131                          wm  =  g[ s ];
132                           for  ( k  =   0 ; k  <  n; k  +=  m2 ) {
133                                  w  =   1 ;
134                                   for  ( j  =   0 ; j  <  m;  ++ j ) {
135                                          t  =  ((LT)(w))  *  a[k + j + m]  %  p;
136                                          u  =  a[ k  +  j ];
137                                          a[ k  +  j ]  =  ( u  +  t )  %  p;
138                                          a[ k  +  j  +  m ]  =  ( u  +  p  -  t )  %  p;
139                                          w  =  ( ((LT)w)  *  wm )  %  p;
140                                  }
141                          }
142                  }
143          }
144           void  ifft() {
145                  T w, wm, u, t, inv;
146                   int  s, m, m2, j, k;
147                  reverse();
148                   for  ( s  =  bit - 1 ; s  >=   0 -- s ) {
149                          m2  =  ( 1 << (bit - s));
150                          m  =  (m2 >> 1 );
151                          wm  =  powMod( g[s], p - 2 , p );
152                           for  ( k  =   0 ; k  <  n; k  +=  m2 ) {
153                                  w  =   1 ;
154                                   for  ( j  =   0 ; j  <  m;  ++ j ) {
155                                          t  =  ((LT)(w))  *  a[k + j + m]  %  p;
156                                          u  =  a[ k  +  j ];
157                                          a[ k  +  j ]  =  ( u  +  t )  %  p;
158                                          a[ k  +  j  +  m ]  =  ( u  +  p  -  t )  %  p;
159                                          w  =  ( ((LT)w)  *  wm )  %  p;
160                                  }
161                          }
162                  }
163                  inv  =  powMod( n, p - 2 , p );
164                   for  ( k  =   0 ; k  <  n;  ++ k ) {
165                          a[ k ]  =  ( ((LT)inv)  *  a[ k ] )  %  p;
166                  }
167          }
168           void   out ( T e[] ) {
169                   int  i;
170                   for  ( i  =   0 ; i  <  n;  ++ i ) {
171                          e[ i ]  =  a[ i ];
172                  }
173          }
174 
175          T a[ L ], g[  100  ], p;
176           int  n, bit;
177  };
178 
179 
180 
181 
182 
183  #define   L  140000
184  typedef   long   long  Lint;
185 
186  FFT <  L,  int , Lint  >  fft;
187  char  s[ L ];
188 
189  void  bi_out(  int  a[] ) {
190           int  i, n;
191          n  =  a[  0  ];
192           for  ( i  =   0 ; i  <  n;  ++ i ) {
193                  s[ i ]  =   ' 0 '   +  a[ n  -  i ];
194          }
195          s[ n ]  =   0 ;
196          puts( s );
197  }
198 
199  int  bi_in(  int  a[] ) {
200           int  i, n;
201           if  ( scanf(  " %s " , s )  !=   1  ) {
202                   return   0 ;
203          }
204          a[  0  ]  =  n  =  strlen( s );
205           for  ( i  =   0 ; i  <  n;  ++ i ) {
206                  a[ n  -  i ]  =  s[ i ]  -   ' 0 ' ;
207          }
208           return   1 ;
209  }
210 
211  void  bi_mul(  int  c[],  int  a[],  int  b[] ) {
212           int  m, n, p, g, i;
213 
214          n  =  ( (a[ 0 ] > b[ 0 ])  ?  a[ 0 ] : b[ 0 ] );
215          n  <<=   1 ;
216 
217          m  =  a[  0  ];
218          fft.fft( a + 1 , m, n );
219 
220          m  =  b[  0  ];
221          fft.fft( b + 1 , m, n );
222 
223          p  =  fft.getP();
224 
225           for  ( i  =   1 ; i  <=  m;  ++ i ) {
226                  c[ i ]  =  (Lint)a[ i ]  *  b[ i ]  %  p;
227          }
228          fft.ifft( c + 1 , m, m );
229          g  =   0 ;
230           for  ( i  =   1 ; i  <=  m;  ++ i ) {
231                  g  +=  c[ i ];
232                  c[ i ]  =  g  %   10 ;
233                  g  /=   10 ;
234          }
235           for  ( i  =  a[ 0 ] + b[ 0 ]; (i > 1 ) && (c[i] == 0 );  -- i )
236                  ;
237          c[  0  ]  =  i;
238  }
239 
240  int  a[ L ], b[ L ], c[ L ];
241 
242  int  main() {
243           while  ( bi_in( a )  &&  bi_in( b ) ) {
244                  bi_mul( c, a, b );
245                  bi_out( c );
246          }
247           return   0 ;
248  }
249 


你可能感兴趣的:(全整数无浮点运算的 快速傅里叶变换FFT 加速 大整数乘法,整系数多项式乘法)