这是一道模板题。
给你两个多项式,请输出乘起来后的多项式。
输入格式
第一行两个整数 nn 和 mm,分别表示两个多项式的次数。
第二行 n+1n+1 个整数,分别表示第一个多项式的 00 到 nn 次项前的系数。
第三行 m+1m+1 个整数,分别表示第一个多项式的 00 到 mm 次项前的系数。
输出格式
一行 n+m+1n+m+1 个整数,分别表示乘起来后的多项式的 00 到 n+mn+m 次项前的系数。
样例一
input
1 2
1 2
1 2 1
output
1 4 5 2
explanation
(1+2x)⋅(1+2x+x2)=1+4x+5x2+2x3 。
限制与约定
0≤n,m≤1050≤n,m≤105 ,保证输入中的系数大于等于 00 且小于等于 99。
时间限制:1s1s
空间限制:256MB
如题目,这道题就是快速傅里叶变换的模板题,我在此粘模板,感谢Quack_quack大神的讲解。
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#define MAXN 300000
using namespace std;
int n,m,N;
struct cpx{
double r,i;
cpx(){}
cpx(double rr,double ii){
r=rr,i=ii;
}
inline cpx operator+(const cpx &x){
return cpx(r+x.r,i+x.i);
}
inline cpx operator-(const cpx&x){
return cpx(r-x.r,i-x.i);
}
inline cpx operator*(const cpx&x){
return cpx(r*x.r-i*x.i,r*x.i+i*x.r);
}
inline void operator*=(const cpx&x){
*this=*this*x;
}
}a[MAXN+10],b[MAXN+10],c[MAXN+10],d[MAXN+10];
template<class T>
void Read(T &x){
char c;
while(c=getchar(),c!=EOF)
if(c>='0'&&c<='9'){
x=c-'0';
while(c=getchar(),c>='0'&&c<='9')
x=x*10+c-'0';
ungetc(c,stdin);
return;
}
}
void read(){
Read(n),Read(m);
int i;
for(i=0;i<=n;i++)
Read(a[i].r);
for(i=0;i<=m;i++)
Read(b[i].r);
}
inline cpx cpx_pow(double n){
return cpx(cos(n),sin(n));
}
void fft(cpx *in,cpx *out,int step,int size,int dir){
if(size==1){
out[0]=in[0];
return;
}
fft(in,out,(step<<1),(size>>1),dir);
fft(in+step,out+(size>>1),(step<<1),(size>>1),dir);
int t=size>>1;
cpx w(1,0),w1(cos(dir*2*M_PI/size),sin(dir*2*M_PI/size)),tt;
for(int i=0;i<t;i++,w*=w1){
cpx even=out[i],odd=out[i+(size>>1)];
tt=w*odd;
out[i]=even+tt;
out[i+t]=even-tt;
}
}
void solve(){
int i,L=n+m+1;
for(N=1;N<L;N<<=1);
fft(a,c,1,N,1);
fft(b,d,1,N,1);
for(i=0;i<N;i++)
c[i]*=d[i];
fft(c,a,1,N,-1);
}
void print(){
int i,t=n+m;
for(i=0;i<=t;i++)
printf("%d ",int((a[i].r/N)+0.1));
}
int main()
{
read();
solve();
print();
}
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#define MAXN 300000
using namespace std;
template<class T>
void Read(T &x){
char c;
while(c=getchar(),c!=EOF)
if(c>='0'&&c<='9'){
x=c-'0';
while(c=getchar(),c>='0'&&c<='9')
x=x*10+c-'0';
ungetc(c,stdin);
return;
}
}
struct cpx{
double r,i;
cpx(){
}
cpx(double rr,double ii){
r=rr,i=ii;
}
inline cpx operator+(const cpx &x)const{
return cpx(r+x.r,i+x.i);
}
inline cpx operator-(const cpx &x)const{
return cpx(r-x.r,i-x.i);
}
inline cpx operator*(const cpx &x)const{
return cpx(r*x.r-i*x.i,r*x.i+i*x.r);
}
inline void operator*=(const cpx &x){
*this=*this*x;
}
}a[MAXN+10],b[MAXN+10];
int n,m,r[MAXN+10],Log,Len,N;
void read(){
Read(n),Read(m);
int i;
for(i=0;i<=n;i++)
Read(a[i].r);
for(i=0;i<=m;i++)
Read(b[i].r);
}
void fft(cpx *a,int f){
int i,j,k;
for(i=0;i<N;i++)
if(i<r[i])
swap(a[i],a[r[i]]);
for(i=1;i<N;i<<=1){
cpx wn(cos(M_PI/i),f*sin(M_PI/i)); //i是size/2,所以不用乘2
int t=i<<1;
for(j=0;j<N;j+=t){
cpx w(1,0);
for(k=0;k<i;k++,w*=wn){
cpx x=a[j+k],y=w*a[j+k+i];
a[j+k]=x+y,a[j+k+i]=x-y;
}
}
}
if(f==-1)
for(i=0;i<N;i++)
a[i].r/=N;
}
void solve(){
int i;
Len=n+m;
for(N=1;N<=Len;N<<=1)
Log++;
for(i=0;i<N;i++)
r[i]=(r[i>>1]>>1)|((i&1)<<(Log-1));
fft(a,1);
fft(b,1);
for(i=0;i<N;i++)
a[i]*=b[i];
fft(a,-1);
}
void print(){
int i;
for(i=0;i<=Len;i++)
printf("%d ",int(a[i].r+0.1));
}
int main()
{
read();
solve();
print();
}
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#define MAXN 300000
using namespace std;
template<class T>
void Read(T &x){
char c;
while(c=getchar(),c!=EOF)
if(c>='0'&&c<='9'){
x=c-'0';
while(c=getchar(),c>='0'&&c<='9')
x=x*10+c-'0';
ungetc(c,stdin);
return;
}
}
struct cpx{
double r,i;
cpx(){
}
cpx(double rr,double ii){
r=rr,i=ii;
}
inline cpx operator+(const cpx &x)const{
return cpx(r+x.r,i+x.i);
}
inline cpx operator-(const cpx &x)const{
return cpx(r-x.r,i-x.i);
}
inline cpx operator*(const cpx &x)const{
return cpx(r*x.r-i*x.i,r*x.i+i*x.r);
}
inline void operator*=(const cpx &x){
*this=*this*x;
}
}a[MAXN+10],b[MAXN+10];
int n,m,Len,N;
void read(){
Read(n),Read(m);
int i;
for(i=0;i<=n;i++)
Read(a[i].r);
for(i=0;i<=m;i++)
Read(b[i].r);
}
void fft(cpx *a,int f){
int i,j,k;
for(i=1,j=0;i<N-1;++i)
{
for(int d=N;j^=d>>=1,~j&d;);
if(i<j)swap(a[i],a[j]);
}
for(i=1;i<N;i<<=1){
cpx wn(cos(M_PI/i),f*sin(M_PI/i)); //i是size/2,所以不用乘2
for(j=0;j<N;j+=i<<1){
cpx w(1,0);
for(k=0;k<i;k++,w*=wn){
cpx x=a[j+k],y=w*a[j+k+i];
a[j+k]=x+y,a[j+k+i]=x-y;
}
}
}
if(f==-1)
for(i=0;i<N;i++)
a[i].r/=N;
}
void solve(){
int i;
Len=n+m;
for(N=1;N<=Len;N<<=1);
fft(a,1);
fft(b,1);
for(i=0;i<N;i++)
a[i]*=b[i];
fft(a,-1);
}
void print(){
for(int i=0;i<=Len;i++)
printf("%d ",int(a[i].r+0.1));
}
int main()
{
read();
solve();
print();
}