hdu 5307 He is Flying(推公式+FFT)

题意:总区间中有n个数(n<=100000),求每个区间和所对应的区间长度(j-i+1)和;

思路:母函数求得多项式为(∑ix^si)(∑x−^si−1)−(∑x^si)(∑(i−1)x−^si−1);

        用FFT求出多项式,由于精度范围(s=50000),需要使用long double;

#include<cstdio>

#include<cstring>

#include<cmath>

#include<algorithm>

using namespace std;

#define ld long double

ld pi=acos(-1.0);   //注意π的精度

int num[500010],sum[500010];

long long snum[500010],ans[500010];

long long p,res;

struct Complex{

   ld r;

   ld i;

   Complex(){};

   Complex(ld a,ld b){

       r=a;i=b;

   }

   Complex operator +(const Complex &t)const{

      return Complex(r+t.r,i+t.i);

   }

   Complex operator -(const Complex &t)const{

      return Complex(r-t.r,i-t.i);

   }

   Complex operator *(const Complex &t)const{

      return Complex(r*t.r-i*t.i,r*t.i+i*t.r);

   }

}a1[500010],a2[500010];

void fft(Complex y[],int n,int rev){

    for(int i=1,j,k,t;i<n;i++){

        for(j=0,k=n>>1,t=i;k;k>>=1,t>>=1) j=j<<1|t&1;

        if(i<j) swap(y[i],y[j]);

    }

    for(int s=2,ds=1;s<=n;ds=s,s<<=1){

      Complex wn=Complex(cos(rev*2*pi/s),sin(rev*2*pi/s)),w=Complex(1,0),t;

      for(int k=0;k<ds;k++,w=w*wn){

        for(int i=k;i<n;i+=s){

            t=w*y[i+ds];

            y[i+ds]=y[i]-t;

            y[i]=y[i]+t;

        }

      }

    }

    if(rev==-1) for(int i=0;i<n;i++) y[i].r/=n;

}

int main()

{

    int i,j,k,t,n;

    snum[0]=0;

    for(long long i=1;i<=100000;i++)

        snum[i]=snum[i-1]+i*(i+1)/2;  //区间长度

    scanf("%d",&t);

    while(t--){

        memset(sum,0,sizeof(sum));

        scanf("%d",&n);

        for(i=0;i<n;i++){

            scanf("%d",&num[i]);

        }

        sum[0]=num[0];

        for(i=1;i<n;i++){

           sum[i]=sum[i-1]+num[i];  //区间和

        }

        p=0;res=0;

        for(i=0;i<n;i++){

          if(num[i]==0){

             p++;

          }

          else{

            res+=snum[p];

             p=0;

          }

        }

        res+=snum[p];

        printf("%lld\n",res);  //预处理区间和为0的情况

        int total=sum[n-1];    //多项式长度

        int total2=total*2;

        memset(a1,0,sizeof(a1));

        memset(a2,0,sizeof(a2));

        int len=1;

        while(len<=total2) len<<=1;

        for(i=0;i<n;i++){       //构造多项式

            a1[sum[i]].r+=i+1;

            if(i!=n-1)

                a2[total-sum[i]].r+=1;

        }

        a2[total].r+=1;

        fft(a1,len,1);

        fft(a2,len,1);

        for(i=0;i<=len;i++) a1[i]=a1[i]*a2[i];

        fft(a1,len,-1);

        for(i=0;i<=len;i++) ans[i]=(long long)(a1[i].r+0.5);



        memset(a1,0,sizeof(a1));

        memset(a2,0,sizeof(a2));

        for(i=0;i<n;i++){    //构造多项式

           a1[sum[i]].r+=1;

           if(i!=n-1){

             a2[total-sum[i]].r+=i+1;

           }

        }

        fft(a1,len,1);

        fft(a2,len,1);

        for(i=0;i<=len;i++) a1[i]=a1[i]*a2[i];

        fft(a1,len,-1);

        for(i=0;i<=len;i++) ans[i]-=(long long)(a1[i].r+0.5);

        for(i=total+1;i<=total2;i++)  //结果

            printf("%lld\n",ans[i]);

    }

    return 0;

}

 

你可能感兴趣的:(HDU)